X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/1d8b4d766d912c7b9e91fa885419730c334345ef..aebd3c37b28bbc0183a58d13b80e7595db3c09bb:/src/black/__init__.py diff --git a/src/black/__init__.py b/src/black/__init__.py index 7e13a5d..7c1a013 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -42,7 +42,6 @@ from typing import ( cast, TYPE_CHECKING, ) -from typing_extensions import Final from mypy_extensions import mypyc_attr from appdirs import user_cache_dir @@ -61,6 +60,11 @@ from blib2to3.pgen2.parse import ParseError from _black_version import version as __version__ +if sys.version_info < (3, 8): + from typing_extensions import Final +else: + from typing import Final + if TYPE_CHECKING: import colorama # noqa: F401 @@ -68,6 +72,7 @@ DEFAULT_LINE_LENGTH = 88 DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950 DEFAULT_INCLUDES = r"\.pyi?$" CACHE_DIR = Path(user_cache_dir("black", version=__version__)) +STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__" STRING_PREFIX_CHARS: Final = "furbFURB" # All possible string prefix characters. 
@@ -88,7 +93,7 @@ Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]] Timestamp = float FileSize = int CacheInfo = Tuple[Timestamp, FileSize] -Cache = Dict[Path, CacheInfo] +Cache = Dict[str, CacheInfo] out = partial(click.secho, bold=True, err=True) err = partial(click.secho, fg="red", err=True) @@ -255,6 +260,7 @@ class Mode: target_versions: Set[TargetVersion] = field(default_factory=set) line_length: int = DEFAULT_LINE_LENGTH string_normalization: bool = True + magic_trailing_comma: bool = True experimental_string_processing: bool = False is_pyi: bool = False @@ -392,6 +398,12 @@ def target_version_option_callback( is_flag=True, help="Don't normalize string quotes or prefixes.", ) +@click.option( + "-C", + "--skip-magic-trailing-comma", + is_flag=True, + help="Don't use trailing commas as a reason to split lines.", +) @click.option( "--experimental-string-processing", is_flag=True, @@ -457,6 +469,15 @@ def target_version_option_callback( "excluded even when they are passed explicitly as arguments." ), ) +@click.option( + "--stdin-filename", + type=str, + help=( + "The name of the file when passing it through stdin. Useful to make " + "sure Black will respect --force-exclude option on some " + "editors that rely on using stdin." 
+ ), +) @click.option( "-q", "--quiet", @@ -510,12 +531,14 @@ def main( fast: bool, pyi: bool, skip_string_normalization: bool, + skip_magic_trailing_comma: bool, experimental_string_processing: bool, quiet: bool, verbose: bool, include: str, exclude: str, force_exclude: Optional[str], + stdin_filename: Optional[str], src: Tuple[str, ...], config: Optional[str], ) -> None: @@ -531,6 +554,7 @@ def main( line_length=line_length, is_pyi=pyi, string_normalization=not skip_string_normalization, + magic_trailing_comma=not skip_magic_trailing_comma, experimental_string_processing=experimental_string_processing, ) if config and verbose: @@ -548,6 +572,7 @@ def main( exclude=exclude, force_exclude=force_exclude, report=report, + stdin_filename=stdin_filename, ) path_empty( @@ -587,6 +612,7 @@ def get_sources( exclude: str, force_exclude: Optional[str], report: "Report", + stdin_filename: Optional[str], ) -> Set[Path]: """Compute the set of files to be formatted.""" try: @@ -613,22 +639,14 @@ def get_sources( gitignore = get_gitignore(root) for s in src: - p = Path(s) - if p.is_dir(): - sources.update( - gen_python_files( - p.iterdir(), - root, - include_regex, - exclude_regex, - force_exclude_regex, - report, - gitignore, - ) - ) - elif s == "-": - sources.add(p) - elif p.is_file(): + if s == "-" and stdin_filename: + p = Path(stdin_filename) + is_stdin = True + else: + p = Path(s) + is_stdin = False + + if is_stdin or p.is_file(): normalized_path = normalize_path_maybe_ignore(p, root, report) if normalized_path is None: continue @@ -643,6 +661,23 @@ def get_sources( report.path_ignored(p, "matches the --force-exclude regular expression") continue + if is_stdin: + p = Path(f"{STDIN_PLACEHOLDER}{str(p)}") + + sources.add(p) + elif p.is_dir(): + sources.update( + gen_python_files( + p.iterdir(), + root, + include_regex, + exclude_regex, + force_exclude_regex, + report, + gitignore, + ) + ) + elif s == "-": sources.add(p) else: err(f"invalid path: {s}") @@ -670,7 +705,18 @@ 
def reformat_one( """ try: changed = Changed.NO - if not src.is_file() and str(src) == "-": + + if str(src) == "-": + is_stdin = True + elif str(src).startswith(STDIN_PLACEHOLDER): + is_stdin = True + # Use the original name again in case we want to print something + # to the user + src = Path(str(src)[len(STDIN_PLACEHOLDER) :]) + else: + is_stdin = False + + if is_stdin: if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode): changed = Changed.YES else: @@ -678,7 +724,8 @@ def reformat_one( if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF): cache = read_cache(mode) res_src = src.resolve() - if res_src in cache and cache[res_src] == get_cache_info(res_src): + res_src_s = str(res_src) + if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src): changed = Changed.CACHED if changed is not Changed.CACHED and format_file_in_place( src, fast=fast, write_back=write_back, mode=mode @@ -704,7 +751,7 @@ def reformat_many( worker_count = os.cpu_count() if sys.platform == "win32": # Work around https://bugs.python.org/issue26903 - worker_count = min(worker_count, 61) + worker_count = min(worker_count, 60) try: executor = ProcessPoolExecutor(max_workers=worker_count) except (ImportError, OSError): @@ -953,7 +1000,7 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent: allowed. Example: >>> import black - >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode())) + >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode())) def f(arg: str = "") -> None: ... 
@@ -985,13 +1032,12 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent: versions = detect_target_versions(src_node) normalize_fmt_off(src_node) lines = LineGenerator( + mode=mode, remove_u_prefix="unicode_literals" in future_imports or supports_feature(versions, Feature.UNICODE_LITERALS), - is_pyi=mode.is_pyi, - normalize_strings=mode.string_normalization, ) elt = EmptyLineTracker(is_pyi=mode.is_pyi) - empty_line = Line() + empty_line = Line(mode=mode) after = 0 split_line_features = { feature @@ -1427,6 +1473,7 @@ class BracketTracker: class Line: """Holds leaves and comments. Can be printed with `str(line)`.""" + mode: Mode depth: int = 0 leaves: List[Leaf] = field(default_factory=list) # keys ordered like `leaves` @@ -1459,8 +1506,11 @@ class Line: ) if self.inside_brackets or not preformatted: self.bracket_tracker.mark(leaf) - if self.maybe_should_explode(leaf): - self.should_explode = True + if self.mode.magic_trailing_comma: + if self.has_magic_trailing_comma(leaf): + self.should_explode = True + elif self.has_magic_trailing_comma(leaf, ensure_removable=True): + self.remove_trailing_comma() if not self.append_comment(leaf): self.leaves.append(leaf) @@ -1636,10 +1686,14 @@ class Line: def contains_multiline_strings(self) -> bool: return any(is_multiline_string(leaf) for leaf in self.leaves) - def maybe_should_explode(self, closing: Leaf) -> bool: - """Return True if this line should explode (always be split), that is when: - - there's a trailing comma here; and - - it's not a one-tuple. 
+ def has_magic_trailing_comma( + self, closing: Leaf, ensure_removable: bool = False + ) -> bool: + """Return True if we have a magic trailing comma, that is when: + - there's a trailing comma here + - it's not a one-tuple + Additionally, if ensure_removable: + - it's not from square bracket indexing """ if not ( closing.type in CLOSING_BRACKETS @@ -1648,9 +1702,15 @@ class Line: ): return False - if closing.type in {token.RBRACE, token.RSQB}: + if closing.type == token.RBRACE: return True + if closing.type == token.RSQB: + if not ensure_removable: + return True + comma = self.leaves[-1] + return bool(comma.parent and comma.parent.type == syms.listmaker) + if self.is_import: return True @@ -1728,6 +1788,7 @@ class Line: def clone(self) -> "Line": return Line( + mode=self.mode, depth=self.depth, inside_brackets=self.inside_brackets, should_explode=self.should_explode, @@ -1886,10 +1947,9 @@ class LineGenerator(Visitor[Line]): in ways that will no longer stringify to valid Python code on the tree. """ - is_pyi: bool = False - normalize_strings: bool = True - current_line: Line = field(default_factory=Line) + mode: Mode remove_u_prefix: bool = False + current_line: Line = field(init=False) def line(self, indent: int = 0) -> Iterator[Line]: """Generate a line. @@ -1904,7 +1964,7 @@ class LineGenerator(Visitor[Line]): return # Line is empty, don't emit. Creating a new one unnecessary. 
complete_line = self.current_line - self.current_line = Line(depth=complete_line.depth + indent) + self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent) yield complete_line def visit_default(self, node: LN) -> Iterator[Line]: @@ -1928,7 +1988,7 @@ class LineGenerator(Visitor[Line]): yield from self.line() normalize_prefix(node, inside_brackets=any_open_brackets) - if self.normalize_strings and node.type == token.STRING: + if self.mode.string_normalization and node.type == token.STRING: normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix) normalize_string_quotes(node) if node.type == token.NUMBER: @@ -1980,7 +2040,7 @@ class LineGenerator(Visitor[Line]): def visit_suite(self, node: Node) -> Iterator[Line]: """Visit a suite.""" - if self.is_pyi and is_stub_suite(node): + if self.mode.is_pyi and is_stub_suite(node): yield from self.visit(node.children[2]) else: yield from self.visit_default(node) @@ -1989,7 +2049,7 @@ class LineGenerator(Visitor[Line]): """Visit a statement without nested statements.""" is_suite_like = node.parent and node.parent.type in STATEMENT if is_suite_like: - if self.is_pyi and is_stub_body(node): + if self.mode.is_pyi and is_stub_body(node): yield from self.visit_default(node) else: yield from self.line(+1) @@ -1997,7 +2057,11 @@ class LineGenerator(Visitor[Line]): yield from self.line(-1) else: - if not self.is_pyi or not node.parent or not is_stub_suite(node.parent): + if ( + not self.mode.is_pyi + or not node.parent + or not is_stub_suite(node.parent) + ): yield from self.line() yield from self.visit_default(node) @@ -2073,6 +2137,8 @@ class LineGenerator(Visitor[Line]): def __post_init__(self) -> None: """You are in a twisty little maze of passages.""" + self.current_line = Line(mode=self.mode) + v = self.visit_stmt Ø: Set[str] = set() self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","}) @@ -3238,7 +3304,8 @@ class StringParenStripper(StringTransformer): Requirements: The 
line contains a string which is surrounded by parentheses and: - - The target string is NOT the only argument to a function call). + - The target string is NOT the only argument to a function call. + - The target string is NOT a "pointless" string. - If the target string contains a PERCENT, the brackets are not preceeded or followed by an operator with higher precedence than PERCENT. @@ -3262,6 +3329,14 @@ class StringParenStripper(StringTransformer): if leaf.type != token.STRING: continue + # If this is a "pointless" string... + if ( + leaf.parent + and leaf.parent.parent + and leaf.parent.parent.type == syms.simple_stmt + ): + continue + # Should be preceded by a non-empty LPAR... if ( not is_valid_index(idx - 1) @@ -4304,6 +4379,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): # `StringSplitter` will break it down further if necessary. string_value = LL[string_idx].value string_line = Line( + mode=line.mode, depth=line.depth + 1, inside_brackets=True, should_explode=line.should_explode, @@ -4897,7 +4973,7 @@ def bracket_split_build_line( If `is_body` is True, the result line is one-indented inside brackets and as such has its first leaf's prefix normalized and a trailing comma added when expected. 
""" - result = Line(depth=original.depth) + result = Line(mode=original.mode, depth=original.depth) if is_body: result.inside_brackets = True result.depth += 1 @@ -4969,7 +5045,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ if bt.delimiter_count_with_priority(delimiter_priority) == 1: raise CannotSplit("Splitting a single attribute from its owner looks wrong") - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) lowest_depth = sys.maxsize trailing_comma_safe = True @@ -4981,7 +5059,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ except ValueError: yield current_line - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) current_line.append(leaf) for leaf in line.leaves: @@ -5005,7 +5085,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ if leaf_priority == delimiter_priority: yield current_line - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) if current_line: if ( trailing_comma_safe @@ -5026,7 +5108,9 @@ def standalone_comment_split( if not line.contains_standalone_comments(0): raise CannotSplit("Line does not have any standalone comments") - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) def append_to_line(leaf: Leaf) -> Iterator[Line]: """Append `leaf` to current line or to new line if appending impossible.""" @@ -5036,7 +5120,9 @@ def standalone_comment_split( except ValueError: yield current_line - current_line = Line(depth=line.depth, 
inside_brackets=line.inside_brackets) + current_line = Line( + line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) current_line.append(leaf) for leaf in line.leaves: @@ -5192,31 +5278,52 @@ def normalize_numeric_literal(leaf: Leaf) -> None: # Leave octal and binary literals alone. pass elif text.startswith("0x"): - # Change hex literals to upper case. - before, after = text[:2], text[2:] - text = f"{before}{after.upper()}" + text = format_hex(text) elif "e" in text: - before, after = text.split("e") - sign = "" - if after.startswith("-"): - after = after[1:] - sign = "-" - elif after.startswith("+"): - after = after[1:] - before = format_float_or_int_string(before) - text = f"{before}e{sign}{after}" + text = format_scientific_notation(text) elif text.endswith(("j", "l")): - number = text[:-1] - suffix = text[-1] - # Capitalize in "2L" because "l" looks too similar to "1". - if suffix == "l": - suffix = "L" - text = f"{format_float_or_int_string(number)}{suffix}" + text = format_long_or_complex_number(text) else: text = format_float_or_int_string(text) leaf.value = text +def format_hex(text: str) -> str: + """ + Formats a hexadecimal string like "0x12b3" + + Uses lowercase because of similarity between "B" and "8", which + can cause security issues. + see: https://github.com/psf/black/issues/1692 + """ + + before, after = text[:2], text[2:] + return f"{before}{after.lower()}" + + +def format_scientific_notation(text: str) -> str: + """Formats a numeric string utilizing scientific notation""" + before, after = text.split("e") + sign = "" + if after.startswith("-"): + after = after[1:] + sign = "-" + elif after.startswith("+"): + after = after[1:] + before = format_float_or_int_string(before) + return f"{before}e{sign}{after}" + + +def format_long_or_complex_number(text: str) -> str: + """Formats a long or complex string like `10L` or `10j`""" + number = text[:-1] + suffix = text[-1] + # Capitalize in "2L" because "l" looks too similar to "1".
+ if suffix == "l": + suffix = "L" + return f"{format_float_or_int_string(number)}{suffix}" + + def format_float_or_int_string(text: str) -> str: """Formats a float string like "1.0".""" if "." not in text: @@ -5700,7 +5807,7 @@ def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool: return False return max_priority == COMMA_PRIORITY and ( - trailing_comma + (line.mode.magic_trailing_comma and trailing_comma) # always explode imports or opening_bracket.parent.type in {syms.atom, syms.import_from} ) @@ -6675,8 +6782,8 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set """ todo, done = set(), set() for src in sources: - src = src.resolve() - if cache.get(src) != get_cache_info(src): + res_src = src.resolve() + if cache.get(str(res_src)) != get_cache_info(res_src): todo.add(src) else: done.add(src) @@ -6688,7 +6795,10 @@ def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None: cache_file = get_cache_file(mode) try: CACHE_DIR.mkdir(parents=True, exist_ok=True) - new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}} + new_cache = { + **cache, + **{str(src.resolve()): get_cache_info(src) for src in sources}, + } with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f: pickle.dump(new_cache, f, protocol=4) os.replace(f.name, cache_file)