From d8fa8df0526de9c0968e0a3568008f58eae45364 Mon Sep 17 00:00:00 2001 From: Zsolt Dollenstein Date: Sun, 28 Jul 2019 16:03:23 +0100 Subject: [PATCH] Add support for walrus operator (#935) * Parse `:=` properly * never unwrap parenthesis around `:=` * When checking for AST-equivalence, use `ast` instead of `typed-ast` when running on python >=3.8 * Assume code that uses `:=` is at least 3.8 --- black.py | 96 ++++++++++++++++++++++++++++---------- blib2to3/Grammar.txt | 8 ++-- blib2to3/pgen2/grammar.py | 1 + blib2to3/pgen2/token.py | 3 +- blib2to3/pgen2/tokenize.py | 2 +- blib2to3/pygram.pyi | 1 + tests/data/pep_572.py | 40 ++++++++++++++++ tests/test_black.py | 17 +++++++ 8 files changed, 138 insertions(+), 30 deletions(-) create mode 100644 tests/data/pep_572.py diff --git a/black.py b/black.py index 180163c..9938b37 100644 --- a/black.py +++ b/black.py @@ -1,3 +1,4 @@ +import ast import asyncio from concurrent.futures import Executor, ProcessPoolExecutor from contextlib import contextmanager @@ -141,6 +142,7 @@ class Feature(Enum): # set for every version of python. ASYNC_IDENTIFIERS = 6 ASYNC_KEYWORDS = 7 + ASSIGNMENT_EXPRESSIONS = 8 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { @@ -175,6 +177,7 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF, Feature.ASYNC_KEYWORDS, + Feature.ASSIGNMENT_EXPRESSIONS, }, } @@ -2863,6 +2866,8 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: check_lpar = True if check_lpar: + if is_walrus_assignment(child): + continue if child.type == syms.atom: if maybe_make_parens_invisible_in_atom(child, parent=node): lpar = Leaf(token.LPAR, "") @@ -3017,18 +3022,24 @@ def is_empty_tuple(node: LN) -> bool: ) +def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]: + """Returns `wrapped` if `node` is of the shape ( wrapped ). + + Parenthesis can be optional. Returns None otherwise""" + if len(node.children) != 3: + return None + lpar, wrapped, rpar = node.children + if not (lpar.type == token.LPAR and rpar.type == token.RPAR): + return None + + return wrapped + + def is_one_tuple(node: LN) -> bool: """Return True if `node` holds a tuple with one element, with or without parens.""" if node.type == syms.atom: - if len(node.children) != 3: - return False - - lpar, gexp, rpar = node.children - if not ( - lpar.type == token.LPAR - and gexp.type == syms.testlist_gexp - and rpar.type == token.RPAR - ): + gexp = unwrap_singleton_parenthesis(node) + if gexp is None or gexp.type != syms.testlist_gexp: return False return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA @@ -3040,6 +3051,12 @@ def is_one_tuple(node: LN) -> bool: ) +def is_walrus_assignment(node: LN) -> bool: + """Return True iff `node` is of the shape ( test := test )""" + inner = unwrap_singleton_parenthesis(node) + return inner is not None and inner.type == syms.namedexpr_test + + def is_yield(node: LN) -> bool: """Return True if `node` holds a `yield` or `yield from` expression.""" if node.type == syms.yield_expr: @@ -3198,6 +3215,9 @@ def get_features_used(node: Node) -> Set[Feature]: if "_" in n.value: # type: ignore features.add(Feature.NUMERIC_UNDERSCORES) + elif n.type == token.COLONEQUAL: + features.add(Feature.ASSIGNMENT_EXPRESSIONS) + elif ( n.type in {syms.typedargslist, syms.arglist} and n.children @@ -3479,32 +3499,58 @@ class Report: return ", ".join(report) + "." -def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]: - for feature_version in (7, 6): - try: - return ast3.parse(src, feature_version=feature_version) - except SyntaxError: - continue +def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]: + filename = "" + if sys.version_info >= (3, 8): + # TODO: support Python 4+ ;) + for minor_version in range(sys.version_info[1], 4, -1): + try: + return ast.parse(src, filename, feature_version=(3, minor_version)) + except SyntaxError: + continue + else: + for feature_version in (7, 6): + try: + return ast3.parse(src, filename, feature_version=feature_version) + except SyntaxError: + continue return ast27.parse(src) +def _fixup_ast_constants( + node: Union[ast.AST, ast3.AST, ast27.AST] +) -> Union[ast.AST, ast3.AST, ast27.AST]: + """Map ast nodes deprecated in 3.8 to Constant.""" + # casts are required until this is released: + # https://github.com/python/typeshed/pull/3142 + if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)): + return cast(ast.AST, ast.Constant(value=node.s)) + elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)): + return cast(ast.AST, ast.Constant(value=node.n)) + elif isinstance(node, (ast.NameConstant, ast3.NameConstant)): + return cast(ast.AST, ast.Constant(value=node.value)) + return node + + def assert_equivalent(src: str, dst: str) -> None: """Raise AssertionError if `src` and `dst` aren't equivalent.""" - def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]: + def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]: """Simple visitor generating strings to compare ASTs by content.""" + + node = _fixup_ast_constants(node) + yield f"{' ' * depth}{node.__class__.__name__}(" for field in sorted(node._fields): # TypeIgnore has only one field 'lineno' which breaks this comparison - if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)): + type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore) + if sys.version_info >= (3, 8): + type_ignore_classes += (ast.TypeIgnore,) + if isinstance(node, type_ignore_classes): break - # Ignore str kind which is case sensitive / and ignores unicode_literals - if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind": - continue - try: value = getattr(node, field) except AttributeError: @@ -3518,15 +3564,15 @@ def assert_equivalent(src: str, dst: str) -> None: # parentheses and they change the AST. if ( field == "targets" - and isinstance(node, (ast3.Delete, ast27.Delete)) - and isinstance(item, (ast3.Tuple, ast27.Tuple)) + and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete)) + and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple)) ): for item in item.elts: yield from _v(item, depth + 2) - elif isinstance(item, (ast3.AST, ast27.AST)): + elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)): yield from _v(item, depth + 2) - elif isinstance(value, (ast3.AST, ast27.AST)): + elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)): yield from _v(value, depth + 2) else: diff --git a/blib2to3/Grammar.txt b/blib2to3/Grammar.txt index c9cb3a7..1061ac8 100644 --- a/blib2to3/Grammar.txt +++ b/blib2to3/Grammar.txt @@ -67,7 +67,7 @@ assert_stmt: 'assert' test [',' test] compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef | decorated | async_stmt async_stmt: ASYNC (funcdef | with_stmt | for_stmt) -if_stmt: 'if' test ':' suite ('elif' test ':' suite)* ['else' ':' suite] +if_stmt: 'if' namedexpr_test ':' suite ('elif' namedexpr_test ':' suite)* ['else' ':' suite] while_stmt: 'while' test ':' suite ['else' ':' suite] for_stmt: 'for' exprlist 'in' testlist ':' suite ['else' ':' suite] try_stmt: ('try' ':' suite @@ -91,6 +91,7 @@ testlist_safe: old_test [(',' old_test)+ [',']] old_test: or_test | old_lambdef old_lambdef: 'lambda' [varargslist] ':' old_test +namedexpr_test: test [':=' test] test: or_test ['if' or_test 'else' test] | lambdef or_test: and_test ('or' and_test)* and_test: not_test ('and' not_test)* @@ -111,8 +112,8 @@ atom: ('(' [yield_expr|testlist_gexp] ')' | '{' [dictsetmaker] '}' | '`' testlist1 '`' | NAME | NUMBER | STRING+ | '.' '.' '.') -listmaker: (test|star_expr) ( old_comp_for | (',' (test|star_expr))* [','] ) -testlist_gexp: (test|star_expr) ( old_comp_for | (',' (test|star_expr))* [','] ) +listmaker: (namedexpr_test|star_expr) ( old_comp_for | (',' (namedexpr_test|star_expr))* [','] ) +testlist_gexp: (namedexpr_test|star_expr) ( old_comp_for | (',' (namedexpr_test|star_expr))* [','] ) lambdef: 'lambda' [varargslist] ':' test trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME subscriptlist: subscript (',' subscript)* [','] @@ -137,6 +138,7 @@ arglist: argument (',' argument)* [','] # multiple (test comp_for) arguments are blocked; keyword unpackings # that precede iterable unpackings are blocked; etc. argument: ( test [comp_for] | + test ':=' test | test '=' test | '**' test | '*' test ) diff --git a/blib2to3/pgen2/grammar.py b/blib2to3/pgen2/grammar.py index 32d1d8b..aa025cf 100644 --- a/blib2to3/pgen2/grammar.py +++ b/blib2to3/pgen2/grammar.py @@ -184,6 +184,7 @@ opmap_raw = """ // DOUBLESLASH //= DOUBLESLASHEQUAL -> RARROW +:= COLONEQUAL """ opmap = {} diff --git a/blib2to3/pgen2/token.py b/blib2to3/pgen2/token.py index c37b0d5..40aa89d 100644 --- a/blib2to3/pgen2/token.py +++ b/blib2to3/pgen2/token.py @@ -63,7 +63,8 @@ RARROW = 55 AWAIT = 56 ASYNC = 57 ERRORTOKEN = 58 -N_TOKENS = 59 +COLONEQUAL = 59 +N_TOKENS = 60 NT_OFFSET = 256 #--end constants-- diff --git a/blib2to3/pgen2/tokenize.py b/blib2to3/pgen2/tokenize.py index 0912f43..a5c6462 100644 --- a/blib2to3/pgen2/tokenize.py +++ b/blib2to3/pgen2/tokenize.py @@ -89,7 +89,7 @@ String = group(_litprefix + r"'[^\n'\\]*(?:\\.[^\n'\\]*)*'", # recognized as two instances of =). Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"<>", r"!=", r"//=?", r"->", - r"[+\-*/%&@|^=<>]=?", + r"[+\-*/%&@|^=<>:]=?", r"~") Bracket = '[][(){}]' diff --git a/blib2to3/pygram.pyi b/blib2to3/pygram.pyi index 1660900..11bf295 100644 --- a/blib2to3/pygram.pyi +++ b/blib2to3/pygram.pyi @@ -57,6 +57,7 @@ class python_symbols(Symbols): import_stmt: int lambdef: int listmaker: int + namedexpr_test: int not_test: int old_comp_for: int old_comp_if: int diff --git a/tests/data/pep_572.py b/tests/data/pep_572.py new file mode 100644 index 0000000..2b240be --- /dev/null +++ b/tests/data/pep_572.py @@ -0,0 +1,40 @@ +(a := 1) +(a := a) +if (match := pattern.search(data)) is None: + pass +[y := f(x), y ** 2, y ** 3] +filtered_data = [y for x in data if (y := f(x)) is None] +(y := f(x)) +y0 = (y1 := f(x)) +foo(x=(y := f(x))) + + +def foo(answer=(p := 42)): + pass + + +def foo(answer: (p := 42) = 5): + pass + + +lambda: (x := 1) +(x := lambda: 1) +(x := lambda: (y := 1)) +lambda line: (m := re.match(pattern, line)) and m.group(1) +x = (y := 0) +(z := (y := (x := 0))) +(info := (name, phone, *rest)) +(x := 1, 2) +(total := total + tax) +len(lines := f.readlines()) +foo(x := 3, cat="vector") +foo(cat=(category := "vector")) +if any(len(longline := l) >= 100 for l in lines): + print(longline) +if env_base := os.environ.get("PYTHONUSERBASE", None): + return env_base +if self._is_special and (ans := self._check_nans(context=context)): + return ans +foo(b := 2, a=1) +foo((b := 2), a=1) +foo(c=(b := 2), a=1) diff --git a/tests/test_black.py b/tests/test_black.py index 828b3e4..be6d98a 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -280,6 +280,23 @@ class BlackTestCase(unittest.TestCase): black.assert_equivalent(source, actual) black.assert_stable(source, actual, black.FileMode()) + @patch("black.dump_to_file", dump_to_stderr) + def test_pep_572(self) -> None: + source, expected = read_data("pep_572") + actual = fs(source) + self.assertFormatEqual(expected, actual) + black.assert_stable(source, actual, black.FileMode()) + if sys.version_info >= (3, 8): + black.assert_equivalent(source, actual) + + def test_pep_572_version_detection(self) -> None: + source, _ = read_data("pep_572") + root = black.lib2to3_parse(source) + features = black.get_features_used(root) + self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features) + versions = black.detect_target_versions(root) + self.assertIn(black.TargetVersion.PY38, versions) + def test_expression_ff(self) -> None: source, expected = read_data("expression") tmp_file = Path(black.dump_to_file(source)) -- 2.39.5