
Add support for walrus operator (#935)
author Zsolt Dollenstein <zsol.zsol@gmail.com>
Sun, 28 Jul 2019 15:03:23 +0000 (16:03 +0100)
committer GitHub <noreply@github.com>
Sun, 28 Jul 2019 15:03:23 +0000 (16:03 +0100)
* Parse `:=` properly
* Never unwrap parentheses around `:=`
* When checking for AST equivalence, use `ast` instead of `typed-ast` when running on Python >= 3.8
* Assume code that uses `:=` targets at least Python 3.8

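As a usage sketch (not part of the patch, and assuming a Black build that includes it): parentheses around an assignment expression are preserved rather than unwrapped. The `pattern`/`data` names inside the sample source are placeholders borrowed from the new test data.

import black

# Source using a PEP 572 assignment expression; Black keeps the parentheses
# around ":=" instead of treating them as removable "invisible" parens.
src = 'if (match := pattern.search(data)) is None:\n    pass\n'
print(black.format_str(src, mode=black.FileMode()))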
black.py
blib2to3/Grammar.txt
blib2to3/pgen2/grammar.py
blib2to3/pgen2/token.py
blib2to3/pgen2/tokenize.py
blib2to3/pygram.pyi
tests/data/pep_572.py [new file with mode: 0644]
tests/test_black.py

index 180163c74c138090214b3da6e32399864a3fc7be..9938b37f6dfe61b219442056c17b10ce1183ecf9 100644
--- a/black.py
+++ b/black.py
@@ -1,3 +1,4 @@
+import ast
 import asyncio
 from concurrent.futures import Executor, ProcessPoolExecutor
 from contextlib import contextmanager
@@ -141,6 +142,7 @@ class Feature(Enum):
     # set for every version of python.
     ASYNC_IDENTIFIERS = 6
     ASYNC_KEYWORDS = 7
+    ASSIGNMENT_EXPRESSIONS = 8
 
 
 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
@@ -175,6 +177,7 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
         Feature.TRAILING_COMMA_IN_CALL,
         Feature.TRAILING_COMMA_IN_DEF,
         Feature.ASYNC_KEYWORDS,
+        Feature.ASSIGNMENT_EXPRESSIONS,
     },
 }
 
@@ -2863,6 +2866,8 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
             check_lpar = True
 
         if check_lpar:
+            if is_walrus_assignment(child):
+                continue
             if child.type == syms.atom:
                 if maybe_make_parens_invisible_in_atom(child, parent=node):
                     lpar = Leaf(token.LPAR, "")
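The early `continue` added above is what keeps the parentheses intact: a walrus expression usually needs them to stay syntactically valid, so they must never be made invisible. A quick illustration of why (not part of the patch; requires Python 3.8+):

import ast

# A parenthesized assignment expression parses fine as a statement...
ast.parse("(value := 42)")

# ...but stripping the parentheses produces code that no longer parses,
# which is why normalize_invisible_parens now skips walrus atoms.
try:
    ast.parse("value := 42")
except SyntaxError as exc:
    print("rejected:", exc.msg)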
@@ -3017,18 +3022,24 @@ def is_empty_tuple(node: LN) -> bool:
     )
 
 
+def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
+    """Returns `wrapped` if `node` is of the shape ( wrapped ).
+
+    Parenthesis can be optional. Returns None otherwise"""
+    if len(node.children) != 3:
+        return None
+    lpar, wrapped, rpar = node.children
+    if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
+        return None
+
+    return wrapped
+
+
 def is_one_tuple(node: LN) -> bool:
     """Return True if `node` holds a tuple with one element, with or without parens."""
     if node.type == syms.atom:
-        if len(node.children) != 3:
-            return False
-
-        lpar, gexp, rpar = node.children
-        if not (
-            lpar.type == token.LPAR
-            and gexp.type == syms.testlist_gexp
-            and rpar.type == token.RPAR
-        ):
+        gexp = unwrap_singleton_parenthesis(node)
+        if gexp is None or gexp.type != syms.testlist_gexp:
             return False
 
         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
@@ -3040,6 +3051,12 @@ def is_one_tuple(node: LN) -> bool:
     )
 
 
+def is_walrus_assignment(node: LN) -> bool:
+    """Return True iff `node` is of the shape ( test := test )"""
+    inner = unwrap_singleton_parenthesis(node)
+    return inner is not None and inner.type == syms.namedexpr_test
+
+
 def is_yield(node: LN) -> bool:
     """Return True if `node` holds a `yield` or `yield from` expression."""
     if node.type == syms.yield_expr:
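To make the node shapes concrete, here is a simplified, stand-alone analogue of the two new helpers (illustrative only; the real ones inspect blib2to3 Leaf/Node children and check for `syms.namedexpr_test`):

from typing import Optional, Sequence

def unwrap_singleton_parens(children: Sequence[str]) -> Optional[str]:
    # Mirrors unwrap_singleton_parenthesis: return the wrapped element when
    # the children look like "(", wrapped, ")"; otherwise return None.
    if len(children) != 3:
        return None
    lpar, wrapped, rpar = children
    if lpar != "(" or rpar != ")":
        return None
    return wrapped

print(unwrap_singleton_parens(["(", "x := 1", ")"]))  # x := 1
print(unwrap_singleton_parens(["x", ":=", "1"]))      # None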
@@ -3198,6 +3215,9 @@ def get_features_used(node: Node) -> Set[Feature]:
             if "_" in n.value:  # type: ignore
                 features.add(Feature.NUMERIC_UNDERSCORES)
 
+        elif n.type == token.COLONEQUAL:
+            features.add(Feature.ASSIGNMENT_EXPRESSIONS)
+
         elif (
             n.type in {syms.typedargslist, syms.arglist}
             and n.children
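With the COLONEQUAL branch in place, feature detection can be exercised directly; this sketch mirrors the new test added to tests/test_black.py further down:

import black

src = "if (n := len(items)) > 10:\n    pass\n"
root = black.lib2to3_parse(src)

# The token-level check above marks the file as using assignment expressions,
# which implies a minimum target of Python 3.8.
print(black.Feature.ASSIGNMENT_EXPRESSIONS in black.get_features_used(root))  # True
print(black.TargetVersion.PY38 in black.detect_target_versions(root))         # True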
@@ -3479,32 +3499,58 @@ class Report:
         return ", ".join(report) + "."
 
 
-def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
-    for feature_version in (7, 6):
-        try:
-            return ast3.parse(src, feature_version=feature_version)
-        except SyntaxError:
-            continue
+def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
+    filename = "<unknown>"
+    if sys.version_info >= (3, 8):
+        # TODO: support Python 4+ ;)
+        for minor_version in range(sys.version_info[1], 4, -1):
+            try:
+                return ast.parse(src, filename, feature_version=(3, minor_version))
+            except SyntaxError:
+                continue
+    else:
+        for feature_version in (7, 6):
+            try:
+                return ast3.parse(src, filename, feature_version=feature_version)
+            except SyntaxError:
+                continue
 
     return ast27.parse(src)
 
 
+def _fixup_ast_constants(
+    node: Union[ast.AST, ast3.AST, ast27.AST]
+) -> Union[ast.AST, ast3.AST, ast27.AST]:
+    """Map ast nodes deprecated in 3.8 to Constant."""
+    # casts are required until this is released:
+    # https://github.com/python/typeshed/pull/3142
+    if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
+        return cast(ast.AST, ast.Constant(value=node.s))
+    elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
+        return cast(ast.AST, ast.Constant(value=node.n))
+    elif isinstance(node, (ast.NameConstant, ast3.NameConstant)):
+        return cast(ast.AST, ast.Constant(value=node.value))
+    return node
+
+
 def assert_equivalent(src: str, dst: str) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
 
-    def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
+    def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
         """Simple visitor generating strings to compare ASTs by content."""
         """Simple visitor generating strings to compare ASTs by content."""
+
+        node = _fixup_ast_constants(node)
+
         yield f"{'  ' * depth}{node.__class__.__name__}("
 
         for field in sorted(node._fields):
             # TypeIgnore has only one field 'lineno' which breaks this comparison
-            if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
+            type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
+            if sys.version_info >= (3, 8):
+                type_ignore_classes += (ast.TypeIgnore,)
+            if isinstance(node, type_ignore_classes):
                 break
 
-            # Ignore str kind which is case sensitive / and ignores unicode_literals
-            if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
-                continue
-
             try:
                 value = getattr(node, field)
             except AttributeError:
@@ -3518,15 +3564,15 @@ def assert_equivalent(src: str, dst: str) -> None:
                     # parentheses and they change the AST.
                     if (
                         field == "targets"
-                        and isinstance(node, (ast3.Delete, ast27.Delete))
-                        and isinstance(item, (ast3.Tuple, ast27.Tuple))
+                        and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
+                        and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
                     ):
                         for item in item.elts:
                             yield from _v(item, depth + 2)
-                    elif isinstance(item, (ast3.AST, ast27.AST)):
+                    elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
                         yield from _v(item, depth + 2)
 
-            elif isinstance(value, (ast3.AST, ast27.AST)):
+            elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
                 yield from _v(value, depth + 2)
 
             else:
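The reworked parse_ast tries the newest grammar first and walks backwards through feature versions, only falling back to typed_ast's ast3/ast27 parsers on older interpreters. A stripped-down sketch of that fallback pattern on Python 3.8+ (the typed_ast branches are omitted here):

import ast
import sys

def parse_newest_first(src: str) -> ast.AST:
    # Walk from the running interpreter's grammar down to Python 3.5 and
    # return the first AST that parses cleanly.
    if sys.version_info >= (3, 8):
        for minor in range(sys.version_info[1], 4, -1):
            try:
                return ast.parse(src, "<unknown>", feature_version=(3, minor))
            except SyntaxError:
                continue
    raise SyntaxError("no supported grammar could parse the source")

print(type(parse_newest_first("(x := 1)\n")).__name__)  # Module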
index c9cb3a7e8c9bfe159684f3e23afaad8a2cb9a2d0..1061ac81e24645618f5abb23fc2ba02e3b914200 100644
--- a/blib2to3/Grammar.txt
+++ b/blib2to3/Grammar.txt
@@ -67,7 +67,7 @@ assert_stmt: 'assert' test [',' test]
 
 compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef | decorated | async_stmt
 async_stmt: ASYNC (funcdef | with_stmt | for_stmt)
-if_stmt: 'if' test ':' suite ('elif' test ':' suite)* ['else' ':' suite]
+if_stmt: 'if' namedexpr_test ':' suite ('elif' namedexpr_test ':' suite)* ['else' ':' suite]
 while_stmt: 'while' test ':' suite ['else' ':' suite]
 for_stmt: 'for' exprlist 'in' testlist ':' suite ['else' ':' suite]
 try_stmt: ('try' ':' suite
@@ -91,6 +91,7 @@ testlist_safe: old_test [(',' old_test)+ [',']]
 old_test: or_test | old_lambdef
 old_lambdef: 'lambda' [varargslist] ':' old_test
 
+namedexpr_test: test [':=' test]
 test: or_test ['if' or_test 'else' test] | lambdef
 or_test: and_test ('or' and_test)*
 and_test: not_test ('and' not_test)*
@@ -111,8 +112,8 @@ atom: ('(' [yield_expr|testlist_gexp] ')' |
        '{' [dictsetmaker] '}' |
        '`' testlist1 '`' |
        NAME | NUMBER | STRING+ | '.' '.' '.')
-listmaker: (test|star_expr) ( old_comp_for | (',' (test|star_expr))* [','] )
-testlist_gexp: (test|star_expr) ( old_comp_for | (',' (test|star_expr))* [','] )
+listmaker: (namedexpr_test|star_expr) ( old_comp_for | (',' (namedexpr_test|star_expr))* [','] )
+testlist_gexp: (namedexpr_test|star_expr) ( old_comp_for | (',' (namedexpr_test|star_expr))* [','] )
 lambdef: 'lambda' [varargslist] ':' test
 trailer: '(' [arglist] ')' | '[' subscriptlist ']' | '.' NAME
 subscriptlist: subscript (',' subscript)* [',']
@@ -137,6 +138,7 @@ arglist: argument (',' argument)* [',']
 # multiple (test comp_for) arguments are blocked; keyword unpackings
 # that precede iterable unpackings are blocked; etc.
 argument: ( test [comp_for] |
+            test ':=' test |
             test '=' test |
            '**' test |
             '*' test )
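The grammar changes above allow `:=` in if/elif conditions, in list and generator element positions, and as a call argument. A few now-parseable forms, adapted from the new tests/data/pep_572.py cases (Python 3.8+ only):

import re

if (match := re.search(r"\d+", "commit 935")) is not None:  # namedexpr_test in if_stmt
    print(match.group())                                     # 935

squares = [y for x in range(5) if (y := x * x) > 4]          # namedexpr_test in listmaker/testlist_gexp
print(squares)                                               # [9, 16]

print(len(lines := ["a", "b"]))                              # test ':=' test in argument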
index 32d1d8b7226e3b4b8f27068d70628c12cdf42fa6..aa025cfd2543de34fe3eea6ec51be0a6138c9c9a 100644
--- a/blib2to3/pgen2/grammar.py
+++ b/blib2to3/pgen2/grammar.py
@@ -184,6 +184,7 @@ opmap_raw = """
 // DOUBLESLASH
 //= DOUBLESLASHEQUAL
 -> RARROW
+:= COLONEQUAL
 """
 
 opmap = {}
 """
 
 opmap = {}
index c37b0d58481b16450746f66210f88921969d2d7d..40aa89d3a100182d5254dc928b98ea5bb1d6fa23 100644
--- a/blib2to3/pgen2/token.py
+++ b/blib2to3/pgen2/token.py
@@ -63,7 +63,8 @@ RARROW = 55
 AWAIT = 56
 ASYNC = 57
 ERRORTOKEN = 58
-N_TOKENS = 59
+COLONEQUAL = 59
+N_TOKENS = 60
 NT_OFFSET = 256
 #--end constants--
 
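With the new constant in place, the vendored token module exposes COLONEQUAL alongside the other operator token types. A quick check against an installed copy of Black's vendored blib2to3 (illustrative; assumes blib2to3 is importable):

from blib2to3.pgen2 import token

print(token.COLONEQUAL)                  # 59
print(token.tok_name[token.COLONEQUAL])  # 'COLONEQUAL'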
index 0912f43b867719f69e987bf1111725f81ed3693a..a5c6462e7cfaeeaf47f5b685c723d1f079d5bbc2 100644
--- a/blib2to3/pgen2/tokenize.py
+++ b/blib2to3/pgen2/tokenize.py
@@ -89,7 +89,7 @@ String = group(_litprefix + r"'[^\n'\\]*(?:\\.[^\n'\\]*)*'",
 # recognized as two instances of =).
 Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"<>", r"!=",
                  r"//=?", r"->",
-                 r"[+\-*/%&@|^=<>]=?",
+                 r"[+\-*/%&@|^=<>:]=?",
                  r"~")
 
 Bracket = '[][(){}]'
                  r"~")
 
 Bracket = '[][(){}]'
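Adding `:` to the trailing character class is what makes the tokenizer emit `:=` as a single operator token instead of a COLON followed by an EQUAL. The widened regex can be checked in isolation (illustrative only; `group` is the same small helper the blib2to3 tokenizer defines):

import re

def group(*choices):
    return "(" + "|".join(choices) + ")"

Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"<>", r"!=",
                 r"//=?", r"->",
                 r"[+\-*/%&@|^=<>:]=?",
                 r"~")

print(re.match(Operator, ":= rest of the line").group())  # :=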
index 1660900097ac377d43c98853f40a753f5c60c7b2..11bf295a319b0f4709dd4eec749fd2be344070fa 100644
--- a/blib2to3/pygram.pyi
+++ b/blib2to3/pygram.pyi
@@ -57,6 +57,7 @@ class python_symbols(Symbols):
     import_stmt: int
     lambdef: int
     listmaker: int
+    namedexpr_test: int
     not_test: int
     old_comp_for: int
     old_comp_if: int
diff --git a/tests/data/pep_572.py b/tests/data/pep_572.py
new file mode 100644
index 0000000..2b240be
--- /dev/null
+++ b/tests/data/pep_572.py
@@ -0,0 +1,40 @@
+(a := 1)
+(a := a)
+if (match := pattern.search(data)) is None:
+    pass
+[y := f(x), y ** 2, y ** 3]
+filtered_data = [y for x in data if (y := f(x)) is None]
+(y := f(x))
+y0 = (y1 := f(x))
+foo(x=(y := f(x)))
+
+
+def foo(answer=(p := 42)):
+    pass
+
+
+def foo(answer: (p := 42) = 5):
+    pass
+
+
+lambda: (x := 1)
+(x := lambda: 1)
+(x := lambda: (y := 1))
+lambda line: (m := re.match(pattern, line)) and m.group(1)
+x = (y := 0)
+(z := (y := (x := 0)))
+(info := (name, phone, *rest))
+(x := 1, 2)
+(total := total + tax)
+len(lines := f.readlines())
+foo(x := 3, cat="vector")
+foo(cat=(category := "vector"))
+if any(len(longline := l) >= 100 for l in lines):
+    print(longline)
+if env_base := os.environ.get("PYTHONUSERBASE", None):
+    return env_base
+if self._is_special and (ans := self._check_nans(context=context)):
+    return ans
+foo(b := 2, a=1)
+foo((b := 2), a=1)
+foo(c=(b := 2), a=1)
index 828b3e45bf9a4b73af5ce065f51cf158b79be6a2..be6d98ad5da847e9ece74fefa895b7f3814bdfee 100644
--- a/tests/test_black.py
+++ b/tests/test_black.py
@@ -280,6 +280,23 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, black.FileMode())
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_pep_572(self) -> None:
+        source, expected = read_data("pep_572")
+        actual = fs(source)
+        self.assertFormatEqual(expected, actual)
+        black.assert_stable(source, actual, black.FileMode())
+        if sys.version_info >= (3, 8):
+            black.assert_equivalent(source, actual)
+
+    def test_pep_572_version_detection(self) -> None:
+        source, _ = read_data("pep_572")
+        root = black.lib2to3_parse(source)
+        features = black.get_features_used(root)
+        self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features)
+        versions = black.detect_target_versions(root)
+        self.assertIn(black.TargetVersion.PY38, versions)
+
     def test_expression_ff(self) -> None:
         source, expected = read_data("expression")
         tmp_file = Path(black.dump_to_file(source))