X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/6fdbdb4ee301d8f9466da4da8364ac05611a1b19..3bfb66971f03da39ae1f4c98c30d55e60f63d33b:/black.py?ds=inline diff --git a/black.py b/black.py index 1978fd5..0776983 100644 --- a/black.py +++ b/black.py @@ -1,6 +1,7 @@ +import ast import asyncio -from asyncio.base_events import BaseEventLoop from concurrent.futures import Executor, ProcessPoolExecutor +from contextlib import contextmanager from datetime import datetime from enum import Enum from functools import lru_cache, partial, wraps @@ -11,11 +12,12 @@ from multiprocessing import Manager, freeze_support import os from pathlib import Path import pickle -import re +import regex as re import signal import sys import tempfile import tokenize +import traceback from typing import ( Any, Callable, @@ -40,6 +42,7 @@ from appdirs import user_cache_dir from attr import dataclass, evolve, Factory import click import toml +from typed_ast import ast3, ast27 # lib2to3 fork from blib2to3.pytree import Node, Leaf, type_repr @@ -48,8 +51,8 @@ from blib2to3.pgen2 import driver, token from blib2to3.pgen2.grammar import Grammar from blib2to3.pgen2.parse import ParseError +from _version import version as __version__ -__version__ = "19.3b0" DEFAULT_LINE_LENGTH = 88 DEFAULT_EXCLUDES = ( r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/" @@ -135,19 +138,30 @@ class Feature(Enum): NUMERIC_UNDERSCORES = 3 TRAILING_COMMA_IN_CALL = 4 TRAILING_COMMA_IN_DEF = 5 + # The following two feature-flags are mutually exclusive, and exactly one should be + # set for every version of python. + ASYNC_IDENTIFIERS = 6 + ASYNC_KEYWORDS = 7 + ASSIGNMENT_EXPRESSIONS = 8 + POS_ONLY_ARGUMENTS = 9 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { - TargetVersion.PY27: set(), - TargetVersion.PY33: {Feature.UNICODE_LITERALS}, - TargetVersion.PY34: {Feature.UNICODE_LITERALS}, - TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL}, + TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS}, + TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS}, + TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS}, + TargetVersion.PY35: { + Feature.UNICODE_LITERALS, + Feature.TRAILING_COMMA_IN_CALL, + Feature.ASYNC_IDENTIFIERS, + }, TargetVersion.PY36: { Feature.UNICODE_LITERALS, Feature.F_STRINGS, Feature.NUMERIC_UNDERSCORES, Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF, + Feature.ASYNC_IDENTIFIERS, }, TargetVersion.PY37: { Feature.UNICODE_LITERALS, @@ -155,6 +169,7 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { Feature.NUMERIC_UNDERSCORES, Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF, + Feature.ASYNC_KEYWORDS, }, TargetVersion.PY38: { Feature.UNICODE_LITERALS, @@ -162,6 +177,9 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { Feature.NUMERIC_UNDERSCORES, Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF, + Feature.ASYNC_KEYWORDS, + Feature.ASSIGNMENT_EXPRESSIONS, + Feature.POS_ONLY_ARGUMENTS, }, } @@ -324,7 +342,7 @@ def read_pyproject_toml( "--quiet", is_flag=True, help=( - "Don't emit non-error messages to stderr. Errors are still emitted, " + "Don't emit non-error messages to stderr. Errors are still emitted; " "silence those with 2>/dev/null." ), ) @@ -415,6 +433,7 @@ def main( report = Report(check=check, quiet=quiet, verbose=verbose) root = find_project_root(src) sources: Set[Path] = set() + path_empty(src, quiet, verbose, ctx) for s in src: p = Path(s) if p.is_dir(): @@ -428,7 +447,7 @@ def main( err(f"invalid path: {s}") if len(sources) == 0: if verbose or not quiet: - out("No paths given. Nothing to do 😴") + out("No Python files are present to be formatted. Nothing to do 😴") ctx.exit(0) if len(sources) == 1: @@ -445,19 +464,27 @@ def main( ) if verbose or not quiet: - bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨" - out(f"All done! {bang}") + out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨") click.secho(str(report), err=True) ctx.exit(report.return_code) +def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None: + """ + Exit if there is no `src` provided for formatting + """ + if not src: + if verbose or not quiet: + out("No Path provided. Nothing to do 😴") + ctx.exit(0) + + def reformat_one( src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report" ) -> None: """Reformat a single file under `src` without spawning child processes. - If `quiet` is True, non-error messages are not output. `line_length`, - `write_back`, `fast` and `pyi` options are passed to + `fast`, `write_back`, and `mode` options are passed to :func:`format_file_in_place` or :func:`format_stdin_to_stdout`. """ try: @@ -513,6 +540,7 @@ def reformat_many( ) finally: shutdown(loop) + executor.shutdown() async def schedule_formatting( @@ -521,14 +549,14 @@ async def schedule_formatting( write_back: WriteBack, mode: FileMode, report: "Report", - loop: BaseEventLoop, + loop: asyncio.AbstractEventLoop, executor: Executor, ) -> None: """Run formatting of `sources` in parallel using the provided `executor`. (Use ProcessPoolExecutors for actual parallelism.) - `line_length`, `write_back`, `fast`, and `pyi` options are passed to + `write_back`, `fast`, and `mode` options are passed to :func:`format_file_in_place`. """ cache: Cache = {} @@ -597,7 +625,7 @@ def format_file_in_place( If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted code to the file. - `line_length` and `fast` options are passed to :func:`format_file_contents`. + `mode` and `fast` options are passed to :func:`format_file_contents`. """ if src.suffix == ".pyi": mode = evolve(mode, is_pyi=True) @@ -618,9 +646,8 @@ def format_file_in_place( src_name = f"{src}\t{then} +0000" dst_name = f"{src}\t{now} +0000" diff_contents = diff(src_contents, dst_contents, src_name, dst_name) - if lock: - lock.acquire() - try: + + with lock or nullcontext(): f = io.TextIOWrapper( sys.stdout.buffer, encoding=encoding, @@ -629,9 +656,7 @@ def format_file_in_place( ) f.write(diff_contents) f.detach() - finally: - if lock: - lock.release() + return True @@ -675,7 +700,7 @@ def format_file_contents( If `fast` is False, additionally confirm that the reformatted code is valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. - `line_length` is passed to :func:`format_str`. + `mode` is passed to :func:`format_str`. """ if src_contents.strip() == "": raise NothingChanged @@ -693,10 +718,11 @@ def format_file_contents( def format_str(src_contents: str, *, mode: FileMode) -> FileContent: """Reformat a string and return new contents. - `line_length` determines how many characters per line are allowed. + `mode` determines formatting options, such as how many characters per line are + allowed. """ src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) - dst_contents = "" + dst_contents = [] future_imports = get_future_imports(src_node) if mode.target_versions: versions = mode.target_versions @@ -719,15 +745,15 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent: } for current_line in lines.visit(src_node): for _ in range(after): - dst_contents += str(empty_line) + dst_contents.append(str(empty_line)) before, after = elt.maybe_empty_lines(current_line) for _ in range(before): - dst_contents += str(empty_line) + dst_contents.append(str(empty_line)) for line in split_line( current_line, line_length=mode.line_length, features=split_line_features ): - dst_contents += str(line) - return dst_contents + dst_contents.append(str(line)) + return "".join(dst_contents) def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]: @@ -751,16 +777,38 @@ def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]: if not target_versions: # No target_version specified, so try all grammars. return [ + # Python 3.7+ + pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords, + # Python 3.0-3.6 pygram.python_grammar_no_print_statement_no_exec_statement, + # Python 2.7 with future print_function import pygram.python_grammar_no_print_statement, + # Python 2.7 pygram.python_grammar, ] elif all(version.is_python2() for version in target_versions): # Python 2-only code, so try Python 2 grammars. - return [pygram.python_grammar_no_print_statement, pygram.python_grammar] + return [ + # Python 2.7 with future print_function import + pygram.python_grammar_no_print_statement, + # Python 2.7 + pygram.python_grammar, + ] else: # Python 3-compatible code, so only try Python 3 grammar. - return [pygram.python_grammar_no_print_statement_no_exec_statement] + grammars = [] + # If we have to parse both, try to parse async as a keyword first + if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS): + # Python 3.7+ + grammars.append( + pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords # noqa: B950 + ) + if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS): + # Python 3.0-3.6 + grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement) + # At least one of the above branches must have been taken, because every Python + # version has exactly one of the two 'ASYNC_*' flags + return grammars def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node: @@ -900,6 +948,7 @@ MATH_OPERATORS = { token.DOUBLESTAR, } STARS = {token.STAR, token.DOUBLESTAR} +VARARGS_SPECIALS = STARS | {token.SLASH} VARARGS_PARENTS = { syms.arglist, syms.argument, # double star in arglist @@ -1027,7 +1076,7 @@ class BracketTracker: """Return True if there is an yet unmatched open bracket on the line.""" return bool(self.bracket_match) - def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int: + def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority: """Return the highest priority of a delimiter found on the line. Values are consistent with what `is_split_*_delimiter()` return. @@ -1035,7 +1084,7 @@ class BracketTracker: """ return max(v for k, v in self.delimiters.items() if k not in exclude) - def delimiter_count_with_priority(self, priority: int = 0) -> int: + def delimiter_count_with_priority(self, priority: Priority = 0) -> int: """Return the number of delimiters with the given `priority`. If no `priority` is passed, defaults to max priority on the line. @@ -1243,27 +1292,67 @@ class Line: return True return False - def contains_inner_type_comments(self) -> bool: + def contains_uncollapsable_type_comments(self) -> bool: ignored_ids = set() try: last_leaf = self.leaves[-1] ignored_ids.add(id(last_leaf)) - if last_leaf.type == token.COMMA: - # When trailing commas are inserted by Black for consistency, comments - # after the previous last element are not moved (they don't have to, - # rendering will still be correct). So we ignore trailing commas. + if last_leaf.type == token.COMMA or ( + last_leaf.type == token.RPAR and not last_leaf.value + ): + # When trailing commas or optional parens are inserted by Black for + # consistency, comments after the previous last element are not moved + # (they don't have to, rendering will still be correct). So we ignore + # trailing commas and invisible. last_leaf = self.leaves[-2] ignored_ids.add(id(last_leaf)) except IndexError: return False + # A type comment is uncollapsable if it is attached to a leaf + # that isn't at the end of the line (since that could cause it + # to get associated to a different argument) or if there are + # comments before it (since that could cause it to get hidden + # behind a comment. + comment_seen = False for leaf_id, comments in self.comments.items(): - if leaf_id in ignored_ids: - continue - for comment in comments: if is_type_comment(comment): - return True + if leaf_id not in ignored_ids or comment_seen: + return True + + comment_seen = True + + return False + + def contains_unsplittable_type_ignore(self) -> bool: + if not self.leaves: + return False + + # If a 'type: ignore' is attached to the end of a line, we + # can't split the line, because we can't know which of the + # subexpressions the ignore was meant to apply to. + # + # We only want this to apply to actual physical lines from the + # original source, though: we don't want the presence of a + # 'type: ignore' at the end of a multiline expression to + # justify pushing it all onto one line. Thus we + # (unfortunately) need to check the actual source lines and + # only report an unsplittable 'type: ignore' if this line was + # one line in the original code. + + # Grab the first and last line numbers, skipping generated leaves + first_line = next((l.lineno for l in self.leaves if l.lineno != 0), 0) + last_line = next((l.lineno for l in reversed(self.leaves) if l.lineno != 0), 0) + + if first_line == last_line: + # We look at the last two leaves since a comma or an + # invisible paren could have been added at the end of the + # line. + for node in self.leaves[-2:]: + for comment in self.comments.get(id(node), []): + if is_type_comment(comment, " ignore"): + return True return False @@ -1318,7 +1407,10 @@ class Line: bracket_depth = leaf.bracket_depth if bracket_depth == depth and leaf.type == token.COMMA: commas += 1 - if leaf.parent and leaf.parent.type == syms.arglist: + if leaf.parent and leaf.parent.type in { + syms.arglist, + syms.typedargslist, + }: commas += 1 break @@ -1345,7 +1437,23 @@ class Line: comment.prefix = "" return False - self.comments.setdefault(id(self.leaves[-1]), []).append(comment) + last_leaf = self.leaves[-1] + if ( + last_leaf.type == token.RPAR + and not last_leaf.value + and last_leaf.parent + and len(list(last_leaf.parent.leaves())) <= 3 + and not is_type_comment(comment) + ): + # Comments on an optional parens wrapping a single leaf should belong to + # the wrapped node except if it's a type comment. Pinning the comment like + # this avoids unstable formatting caused by comment migration. + if len(self.leaves) < 2: + comment.type = STANDALONE_COMMENT + comment.prefix = "" + return False + last_leaf = self.leaves[-2] + self.comments.setdefault(id(last_leaf), []).append(comment) return True def comments_after(self, leaf: Leaf) -> List[Leaf]: @@ -1420,7 +1528,13 @@ class EmptyLineTracker: lines (two on module-level). """ before, after = self._maybe_empty_lines(current_line) - before -= self.previous_after + before = ( + # Black should not insert empty lines at the beginning + # of the file + 0 + if self.previous_line is None + else before - self.previous_after + ) self.previous_after = after self.previous_line = current_line return before, after @@ -1568,6 +1682,39 @@ class LineGenerator(Visitor[Line]): self.current_line.append(node) yield from super().visit_default(node) + def visit_atom(self, node: Node) -> Iterator[Line]: + # Always make parentheses invisible around a single node, because it should + # not be needed (except in the case of yield, where removing the parentheses + # produces a SyntaxError). + if ( + len(node.children) == 3 + and isinstance(node.children[0], Leaf) + and node.children[0].type == token.LPAR + and isinstance(node.children[2], Leaf) + and node.children[2].type == token.RPAR + and isinstance(node.children[1], Leaf) + and not ( + node.children[1].type == token.NAME + and node.children[1].value == "yield" + ) + ): + node.children[0].value = "" + node.children[2].value = "" + yield from super().visit_default(node) + + def visit_factor(self, node: Node) -> Iterator[Line]: + """Force parentheses between a unary op and a binary power: + + -2 ** 8 -> -(2 ** 8) + """ + child = node.children[1] + if child.type == syms.power and len(child.children) == 3: + lpar = Leaf(token.LPAR, "(") + rpar = Leaf(token.RPAR, ")") + index = child.remove() or 0 + node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + yield from self.visit_default(node) + def visit_INDENT(self, node: Node) -> Iterator[Line]: """Increase indentation level, maybe yield a line.""" # In blib2to3 INDENT never holds comments. @@ -1757,7 +1904,7 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa: C901 # that, too. return prevp.prefix - elif prevp.type in STARS: + elif prevp.type in VARARGS_SPECIALS: if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS): return NO @@ -1847,7 +1994,7 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa: C901 if not prevp or prevp.type == token.LPAR: return NO - elif prev.type in {token.EQUAL} | STARS: + elif prev.type in {token.EQUAL} | VARARGS_SPECIALS: return NO elif p.type == syms.decorator: @@ -1981,7 +2128,7 @@ def container_of(leaf: Leaf) -> LN: return container -def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int: +def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority: """Return the priority of the `leaf` delimiter, given a line break after it. The delimiter priorities returned here are from those delimiters that would @@ -1995,7 +2142,7 @@ def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int return 0 -def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int: +def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority: """Return the priority of the `leaf` delimiter, given a line break before it. The delimiter priorities returned here are from those delimiters that would @@ -2214,9 +2361,12 @@ def split_line( line_str = str(line).strip("\n") if ( - not line.contains_inner_type_comments() + not line.contains_uncollapsable_type_comments() and not line.should_explode - and is_line_short_enough(line, line_length=line_length, line_str=line_str) + and ( + is_line_short_enough(line, line_length=line_length, line_str=line_str) + or line.contains_unsplittable_type_ignore() + ) ): yield line return @@ -2434,9 +2584,13 @@ def bracket_split_build_line( if leaves: # Since body is a new indent level, remove spurious leading whitespace. normalize_prefix(leaves[0], inside_brackets=True) - # Ensure a trailing comma for imports, but be careful not to add one after - # any comments. - if original.is_import: + # Ensure a trailing comma for imports and standalone function arguments, but + # be careful not to add one after any comments. + no_commas = original.is_def and not any( + l.type == token.COMMA for l in leaves + ) + + if original.is_import or no_commas: for i in range(len(leaves) - 1, -1, -1): if leaves[i].type == STANDALONE_COMMENT: continue @@ -2585,12 +2739,14 @@ def is_import(leaf: Leaf) -> bool: ) -def is_type_comment(leaf: Leaf) -> bool: +def is_type_comment(leaf: Leaf, suffix: str = "") -> bool: """Return True if the given leaf is a special comment. Only returns true for type comments for now.""" t = leaf.type v = leaf.value - return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:") + return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith( + "# type:" + suffix + ) def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: @@ -2675,7 +2831,15 @@ def normalize_string_quotes(leaf: Leaf) -> None: new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body) new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body) if "f" in prefix.casefold(): - matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body) + matches = re.findall( + r""" + (?:[^{]|^)\{ # start of the string or a non-{ followed by a single { + ([^{].*?) # contents of the brackets except if begins with {{ + \}(?:[^}]|$) # A } followed by end of the string or a non-} + """, + new_body, + re.VERBOSE, + ) for m in matches: if "\\" in str(m): # Do not introduce backslashes in interpolated expressions @@ -2742,7 +2906,7 @@ def format_float_or_int_string(text: str) -> str: def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: """Make existing optional parentheses invisible or create new ones. - `parens_after` is a set of string leaf values immeditely after which parens + `parens_after` is a set of string leaf values immediately after which parens should be put. Standardizes on visible parentheses for single-element tuples, and keeps @@ -2764,8 +2928,18 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: check_lpar = True if check_lpar: + if is_walrus_assignment(child): + continue if child.type == syms.atom: - if maybe_make_parens_invisible_in_atom(child, parent=node): + # Determines if the underlying atom should be surrounded with + # invisible params - also makes parens invisible recursively + # within the atom and removes repeated invisible parens within + # the atom + should_surround_with_parens = maybe_make_parens_invisible_in_atom( + child, parent=node + ) + + if should_surround_with_parens: lpar = Leaf(token.LPAR, "") rpar = Leaf(token.RPAR, "") index = child.remove() or 0 @@ -2882,6 +3056,8 @@ def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]: def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool: """If it's safe, make the parens in the atom `node` invisible, recursively. + Additionally, remove repeated, adjacent invisible parens from the atom `node` + as they are redundant. Returns whether the node should itself be wrapped in invisible parentheses. @@ -2898,16 +3074,40 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool: first = node.children[0] last = node.children[-1] if first.type == token.LPAR and last.type == token.RPAR: + middle = node.children[1] # make parentheses invisible first.value = "" # type: ignore last.value = "" # type: ignore - if len(node.children) > 1: - maybe_make_parens_invisible_in_atom(node.children[1], parent=parent) + maybe_make_parens_invisible_in_atom(middle, parent=parent) + + if is_atom_with_invisible_parens(middle): + # Strip the invisible parens from `middle` by replacing + # it with the child in-between the invisible parens + middle.replace(middle.children[1]) + return False return True +def is_atom_with_invisible_parens(node: LN) -> bool: + """Given a `LN`, determines whether it's an atom `node` with invisible + parens. Useful in dedupe-ing and normalizing parens. + """ + if isinstance(node, Leaf) or node.type != syms.atom: + return False + + first, last = node.children[0], node.children[-1] + return ( + isinstance(first, Leaf) + and first.type == token.LPAR + and first.value == "" + and isinstance(last, Leaf) + and last.type == token.RPAR + and last.value == "" + ) + + def is_empty_tuple(node: LN) -> bool: """Return True if `node` holds an empty tuple.""" return ( @@ -2918,18 +3118,24 @@ def is_empty_tuple(node: LN) -> bool: ) +def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]: + """Returns `wrapped` if `node` is of the shape ( wrapped ). + + Parenthesis can be optional. Returns None otherwise""" + if len(node.children) != 3: + return None + lpar, wrapped, rpar = node.children + if not (lpar.type == token.LPAR and rpar.type == token.RPAR): + return None + + return wrapped + + def is_one_tuple(node: LN) -> bool: """Return True if `node` holds a tuple with one element, with or without parens.""" if node.type == syms.atom: - if len(node.children) != 3: - return False - - lpar, gexp, rpar = node.children - if not ( - lpar.type == token.LPAR - and gexp.type == syms.testlist_gexp - and rpar.type == token.RPAR - ): + gexp = unwrap_singleton_parenthesis(node) + if gexp is None or gexp.type != syms.testlist_gexp: return False return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA @@ -2941,6 +3147,12 @@ def is_one_tuple(node: LN) -> bool: ) +def is_walrus_assignment(node: LN) -> bool: + """Return True iff `node` is of the shape ( test := test )""" + inner = unwrap_singleton_parenthesis(node) + return inner is not None and inner.type == syms.namedexpr_test + + def is_yield(node: LN) -> bool: """Return True if `node` holds a `yield` or `yield from` expression.""" if node.type == syms.yield_expr: @@ -2970,7 +3182,7 @@ def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool: extended iterable unpacking (PEP 3132) and additional unpacking generalizations (PEP 448). """ - if leaf.type not in STARS or not leaf.parent: + if leaf.type not in VARARGS_SPECIALS or not leaf.parent: return False p = leaf.parent @@ -3020,7 +3232,7 @@ def is_stub_body(node: LN) -> bool: ) -def max_delimiter_priority_in_atom(node: LN) -> int: +def max_delimiter_priority_in_atom(node: LN) -> Priority: """Return maximum delimiter priority inside `node`. This is specific to atoms with contents contained in a pair of parentheses. @@ -3052,7 +3264,7 @@ def ensure_visible(leaf: Leaf) -> None: """Make sure parentheses are visible. They could be invisible as part of some statements (see - :func:`normalize_invible_parens` and :func:`visit_import_from`). + :func:`normalize_invisible_parens` and :func:`visit_import_from`). """ if leaf.type == token.LPAR: leaf.value = "(" @@ -3085,8 +3297,9 @@ def get_features_used(node: Node) -> Set[Feature]: Currently looking for: - f-strings; - - underscores in numeric literals; and - - trailing commas after * or ** in function signatures and calls. + - underscores in numeric literals; + - trailing commas after * or ** in function signatures and calls; + - positional only arguments in function signatures and lambdas; """ features: Set[Feature] = set() for n in node.pre_order(): @@ -3099,6 +3312,13 @@ def get_features_used(node: Node) -> Set[Feature]: if "_" in n.value: # type: ignore features.add(Feature.NUMERIC_UNDERSCORES) + elif n.type == token.SLASH: + if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}: + features.add(Feature.POS_ONLY_ARGUMENTS) + + elif n.type == token.COLONEQUAL: + features.add(Feature.ASSIGNMENT_EXPRESSIONS) + elif ( n.type in {syms.typedargslist, syms.arglist} and n.children @@ -3380,17 +3600,58 @@ class Report: return ", ".join(report) + "." +def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]: + filename = "" + if sys.version_info >= (3, 8): + # TODO: support Python 4+ ;) + for minor_version in range(sys.version_info[1], 4, -1): + try: + return ast.parse(src, filename, feature_version=(3, minor_version)) + except SyntaxError: + continue + else: + for feature_version in (7, 6): + try: + return ast3.parse(src, filename, feature_version=feature_version) + except SyntaxError: + continue + + return ast27.parse(src) + + +def _fixup_ast_constants( + node: Union[ast.AST, ast3.AST, ast27.AST] +) -> Union[ast.AST, ast3.AST, ast27.AST]: + """Map ast nodes deprecated in 3.8 to Constant.""" + # casts are required until this is released: + # https://github.com/python/typeshed/pull/3142 + if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)): + return cast(ast.AST, ast.Constant(value=node.s)) + elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)): + return cast(ast.AST, ast.Constant(value=node.n)) + elif isinstance(node, (ast.NameConstant, ast3.NameConstant)): + return cast(ast.AST, ast.Constant(value=node.value)) + return node + + def assert_equivalent(src: str, dst: str) -> None: """Raise AssertionError if `src` and `dst` aren't equivalent.""" - import ast - import traceback - - def _v(node: ast.AST, depth: int = 0) -> Iterator[str]: + def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]: """Simple visitor generating strings to compare ASTs by content.""" + + node = _fixup_ast_constants(node) + yield f"{' ' * depth}{node.__class__.__name__}(" for field in sorted(node._fields): + # TypeIgnore has only one field 'lineno' which breaks this comparison + type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore) + if sys.version_info >= (3, 8): + type_ignore_classes += (ast.TypeIgnore,) + if isinstance(node, type_ignore_classes): + break + try: value = getattr(node, field) except AttributeError: @@ -3404,15 +3665,15 @@ def assert_equivalent(src: str, dst: str) -> None: # parentheses and they change the AST. if ( field == "targets" - and isinstance(node, ast.Delete) - and isinstance(item, ast.Tuple) + and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete)) + and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple)) ): for item in item.elts: yield from _v(item, depth + 2) - elif isinstance(item, ast.AST): + elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)): yield from _v(item, depth + 2) - elif isinstance(value, ast.AST): + elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)): yield from _v(value, depth + 2) else: @@ -3421,22 +3682,20 @@ def assert_equivalent(src: str, dst: str) -> None: yield f"{' ' * depth}) # /{node.__class__.__name__}" try: - src_ast = ast.parse(src) + src_ast = parse_ast(src) except Exception as exc: - major, minor = sys.version_info[:2] raise AssertionError( - f"cannot use --safe with this file; failed to parse source file " - f"with Python {major}.{minor}'s builtin AST. Re-run with --fast " - f"or stop using deprecated Python 2 syntax. AST error message: {exc}" + f"cannot use --safe with this file; failed to parse source file. " + f"AST error message: {exc}" ) try: - dst_ast = ast.parse(dst) + dst_ast = parse_ast(dst) except Exception as exc: log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst) raise AssertionError( f"INTERNAL ERROR: Black produced invalid code: {exc}. " - f"Please report a bug on https://github.com/python/black/issues. " + f"Please report a bug on https://github.com/psf/black/issues. " f"This invalid output might be helpful: {log}" ) from None @@ -3447,7 +3706,7 @@ def assert_equivalent(src: str, dst: str) -> None: raise AssertionError( f"INTERNAL ERROR: Black produced code that is not equivalent to " f"the source. " - f"Please report a bug on https://github.com/python/black/issues. " + f"Please report a bug on https://github.com/psf/black/issues. " f"This diff might be helpful: {log}" ) from None @@ -3463,15 +3722,13 @@ def assert_stable(src: str, dst: str, mode: FileMode) -> None: raise AssertionError( f"INTERNAL ERROR: Black produced different code on the second pass " f"of the formatter. " - f"Please report a bug on https://github.com/python/black/issues. " + f"Please report a bug on https://github.com/psf/black/issues. " f"This diff might be helpful: {log}" ) from None def dump_to_file(*output: str) -> str: """Dump `output` to a temporary file. Return path to the file.""" - import tempfile - with tempfile.NamedTemporaryFile( mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8" ) as f: @@ -3482,6 +3739,13 @@ def dump_to_file(*output: str) -> str: return f.name +@contextmanager +def nullcontext() -> Iterator[None]: + """Return context manager that does nothing. + Similar to `nullcontext` from python 3.7""" + yield + + def diff(a: str, b: str, a_name: str, b_name: str) -> str: """Return a unified diff string between strings `a` and `b`.""" import difflib @@ -3500,7 +3764,7 @@ def cancel(tasks: Iterable[asyncio.Task]) -> None: task.cancel() -def shutdown(loop: BaseEventLoop) -> None: +def shutdown(loop: asyncio.AbstractEventLoop) -> None: """Cancel all pending tasks on `loop`, wait for them, and close the loop.""" try: if sys.version_info[:2] >= (3, 7): @@ -3542,7 +3806,8 @@ def re_compile_maybe_verbose(regex: str) -> Pattern[str]: """ if "\n" in regex: regex = "(?x)" + regex - return re.compile(regex) + compiled: Pattern[str] = re.compile(regex) + return compiled def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]: @@ -3569,7 +3834,6 @@ def enumerate_with_length( if "\n" in leaf.value: return # Multiline strings, we can't continue. - comment: Optional[Leaf] for comment in line.comments_after(leaf): length += len(comment.value)