X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/448885b256ca1741fda7c4ef17f80f750ea614c0..9394de150ebf0adc426523f46dc08e8b2b2b0b63:/black.py diff --git a/black.py b/black.py index 17aea7a..8318674 100644 --- a/black.py +++ b/black.py @@ -16,6 +16,7 @@ import signal import sys import tempfile import tokenize +import traceback from typing import ( Any, Callable, @@ -336,7 +337,7 @@ def read_pyproject_toml( "--quiet", is_flag=True, help=( - "Don't emit non-error messages to stderr. Errors are still emitted, " + "Don't emit non-error messages to stderr. Errors are still emitted; " "silence those with 2>/dev/null." ), ) @@ -457,8 +458,7 @@ def main( ) if verbose or not quiet: - bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨" - out(f"All done! {bang}") + out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨") click.secho(str(report), err=True) ctx.exit(report.return_code) @@ -468,8 +468,7 @@ def reformat_one( ) -> None: """Reformat a single file under `src` without spawning child processes. - If `quiet` is True, non-error messages are not output. `line_length`, - `write_back`, `fast` and `pyi` options are passed to + `fast`, `write_back`, and `mode` options are passed to :func:`format_file_in_place` or :func:`format_stdin_to_stdout`. """ try: @@ -609,7 +608,7 @@ def format_file_in_place( If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted code to the file. - `line_length` and `fast` options are passed to :func:`format_file_contents`. + `mode` and `fast` options are passed to :func:`format_file_contents`. """ if src.suffix == ".pyi": mode = evolve(mode, is_pyi=True) @@ -687,7 +686,7 @@ def format_file_contents( If `fast` is False, additionally confirm that the reformatted code is valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. - `line_length` is passed to :func:`format_str`. + `mode` is passed to :func:`format_str`. """ if src_contents.strip() == "": raise NothingChanged @@ -705,10 +704,11 @@ def format_file_contents( def format_str(src_contents: str, *, mode: FileMode) -> FileContent: """Reformat a string and return new contents. - `line_length` determines how many characters per line are allowed. + `mode` determines formatting options, such as how many characters per line are + allowed. """ src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) - dst_contents = "" + dst_contents = [] future_imports = get_future_imports(src_node) if mode.target_versions: versions = mode.target_versions @@ -731,15 +731,15 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent: } for current_line in lines.visit(src_node): for _ in range(after): - dst_contents += str(empty_line) + dst_contents.append(str(empty_line)) before, after = elt.maybe_empty_lines(current_line) for _ in range(before): - dst_contents += str(empty_line) + dst_contents.append(str(empty_line)) for line in split_line( current_line, line_length=mode.line_length, features=split_line_features ): - dst_contents += str(line) - return dst_contents + dst_contents.append(str(line)) + return "".join(dst_contents) def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]: @@ -1061,7 +1061,7 @@ class BracketTracker: """Return True if there is an yet unmatched open bracket on the line.""" return bool(self.bracket_match) - def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int: + def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority: """Return the highest priority of a delimiter found on the line. Values are consistent with what `is_split_*_delimiter()` return. @@ -1069,7 +1069,7 @@ class BracketTracker: """ return max(v for k, v in self.delimiters.items() if k not in exclude) - def delimiter_count_with_priority(self, priority: int = 0) -> int: + def delimiter_count_with_priority(self, priority: Priority = 0) -> int: """Return the number of delimiters with the given `priority`. If no `priority` is passed, defaults to max priority on the line. @@ -1352,7 +1352,10 @@ class Line: bracket_depth = leaf.bracket_depth if bracket_depth == depth and leaf.type == token.COMMA: commas += 1 - if leaf.parent and leaf.parent.type == syms.arglist: + if leaf.parent and leaf.parent.type in { + syms.arglist, + syms.typedargslist, + }: commas += 1 break @@ -1602,6 +1605,26 @@ class LineGenerator(Visitor[Line]): self.current_line.append(node) yield from super().visit_default(node) + def visit_atom(self, node: Node) -> Iterator[Line]: + # Always make parentheses invisible around a single node, because it should + # not be needed (except in the case of yield, where removing the parentheses + # produces a SyntaxError). + if ( + len(node.children) == 3 + and isinstance(node.children[0], Leaf) + and node.children[0].type == token.LPAR + and isinstance(node.children[2], Leaf) + and node.children[2].type == token.RPAR + and isinstance(node.children[1], Leaf) + and not ( + node.children[1].type == token.NAME + and node.children[1].value == "yield" + ) + ): + node.children[0].value = "" + node.children[2].value = "" + yield from super().visit_default(node) + def visit_INDENT(self, node: Node) -> Iterator[Line]: """Increase indentation level, maybe yield a line.""" # In blib2to3 INDENT never holds comments. @@ -2015,7 +2038,7 @@ def container_of(leaf: Leaf) -> LN: return container -def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int: +def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority: """Return the priority of the `leaf` delimiter, given a line break after it. The delimiter priorities returned here are from those delimiters that would @@ -2029,7 +2052,7 @@ def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int return 0 -def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int: +def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority: """Return the priority of the `leaf` delimiter, given a line break before it. The delimiter priorities returned here are from those delimiters that would @@ -2468,9 +2491,13 @@ def bracket_split_build_line( if leaves: # Since body is a new indent level, remove spurious leading whitespace. normalize_prefix(leaves[0], inside_brackets=True) - # Ensure a trailing comma for imports, but be careful not to add one after - # any comments. - if original.is_import: + # Ensure a trailing comma for imports and standalone function arguments, but + # be careful not to add one after any comments. + no_commas = original.is_def and not any( + l.type == token.COMMA for l in leaves + ) + + if original.is_import or no_commas: for i in range(len(leaves) - 1, -1, -1): if leaves[i].type == STANDALONE_COMMENT: continue @@ -2709,7 +2736,15 @@ def normalize_string_quotes(leaf: Leaf) -> None: new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body) new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body) if "f" in prefix.casefold(): - matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body) + matches = re.findall( + r""" + (?:[^{]|^)\{ # start of the string or a non-{ followed by a single { + ([^{].*?) # contents of the brackets except if begins with {{ + \}(?:[^}]|$) # A } followed by end of the string or a non-} + """, + new_body, + re.VERBOSE, + ) for m in matches: if "\\" in str(m): # Do not introduce backslashes in interpolated expressions @@ -2776,7 +2811,7 @@ def format_float_or_int_string(text: str) -> str: def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: """Make existing optional parentheses invisible or create new ones. - `parens_after` is a set of string leaf values immeditely after which parens + `parens_after` is a set of string leaf values immediately after which parens should be put. Standardizes on visible parentheses for single-element tuples, and keeps @@ -3054,7 +3089,7 @@ def is_stub_body(node: LN) -> bool: ) -def max_delimiter_priority_in_atom(node: LN) -> int: +def max_delimiter_priority_in_atom(node: LN) -> Priority: """Return maximum delimiter priority inside `node`. This is specific to atoms with contents contained in a pair of parentheses. @@ -3427,8 +3462,6 @@ def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]: def assert_equivalent(src: str, dst: str) -> None: """Raise AssertionError if `src` and `dst` aren't equivalent.""" - import traceback - def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]: """Simple visitor generating strings to compare ASTs by content.""" yield f"{' ' * depth}{node.__class__.__name__}(" @@ -3519,8 +3552,6 @@ def assert_stable(src: str, dst: str, mode: FileMode) -> None: def dump_to_file(*output: str) -> str: """Dump `output` to a temporary file. Return path to the file.""" - import tempfile - with tempfile.NamedTemporaryFile( mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8" ) as f: @@ -3618,7 +3649,6 @@ def enumerate_with_length( if "\n" in leaf.value: return # Multiline strings, we can't continue. - comment: Optional[Leaf] for comment in line.comments_after(leaf): length += len(comment.value)