git.madduck.net Git - etc/vim.git/commitdiff

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. I would be especially grateful if you read over the Git project's submission guidelines and adhered to them.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

black/parser: partial support for pattern matching (#2586)
author Batuhan Taskaya <isidentical@gmail.com>
Sun, 14 Nov 2021 03:15:31 +0000 (06:15 +0300)
committer GitHub <noreply@github.com>
Sun, 14 Nov 2021 03:15:31 +0000 (19:15 -0800)
Partial implementation for #2242. It only takes effect when -t py310 (--target-version py310) is explicitly specified.

Co-authored-by: Richard Si <63936253+ichard26@users.noreply.github.com>
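
In practice this means invoking Black with --target-version py310 (the long form of the -t py310 mentioned above); only then is the soft-keyword grammar introduced below added to the list of grammars the parser tries. A file as small as the following (taken from the new test data) is enough to see the difference, since without the flag it cannot be parsed at all:

    match command.split():
        case [action, obj]:
            ...  # interpret action, obj
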
14 files changed:
CHANGES.md
src/black/linegen.py
src/black/mode.py
src/black/parsing.py
src/blib2to3/Grammar.txt
src/blib2to3/pgen2/driver.py
src/blib2to3/pgen2/grammar.py
src/blib2to3/pgen2/parse.py
src/blib2to3/pgen2/pgen.py
src/blib2to3/pygram.py
tests/data/parenthesized_context_managers.py [new file with mode: 0644]
tests/data/pattern_matching_complex.py [new file with mode: 0644]
tests/data/pattern_matching_simple.py [new file with mode: 0644]
tests/test_format.py

diff --git a/CHANGES.md b/CHANGES.md
index 4b8dc57388c4990d3e62a990dd0fa8e8e652a18d..b2e8f7439b7eba99ed711128f319e34bfbd41f8c 100644 (file)
@@ -6,6 +6,9 @@
 
 - Warn about Python 2 deprecation in more cases by improving Python 2 only syntax
   detection (#2592)
+- Add partial support for the match statement. As it's experimental, it's only enabled
+  when `--target-version py310` is explicitly specified (#2586)
+- Add support for parenthesized with (#2586)
 
 ## 21.10b0
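
The second entry corresponds to the reworked with_stmt and new asexpr_test rules in Grammar.txt below: 'as' bindings inside parentheses, which the old with_item rule could not express, now parse. The new test file tests/data/parenthesized_context_managers.py exercises forms such as:

    with (
        CtxManager1() as example1,
        CtxManager2() as example2,
        CtxManager3() as example3,
    ):
        ...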
 
diff --git a/src/black/linegen.py b/src/black/linegen.py
index eb53fa0ac56e0aa2520d3b27248d3e4123c7c7ef..8cf32c973bb6a4880e543ec6a51ec06ff54dc492 100644 (file)
@@ -126,7 +126,7 @@ class LineGenerator(Visitor[Line]):
         """Visit a statement.
 
         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
-        `def`, `with`, `class`, `assert` and assignments.
+        `def`, `with`, `class`, `assert`, `match`, `case` and assignments.
 
         The relevant Python language `keywords` for a given statement will be
         NAME leaves within it. This methods puts those on a separate line.
@@ -292,6 +292,10 @@ class LineGenerator(Visitor[Line]):
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
+        # PEP 634
+        self.visit_match_stmt = partial(v, keywords={"match"}, parens=Ø)
+        self.visit_case_block = partial(v, keywords={"case"}, parens=Ø)
+
 
 def transform_line(
     line: Line, mode: Mode, features: Collection[Feature] = ()
diff --git a/src/black/mode.py b/src/black/mode.py
index 01ee336366c8e113d3e5cb0d2eb397c6db0c2760..b24c9c60dedc568758abdea95d086ba4c0ad0833 100644 (file)
@@ -20,6 +20,7 @@ class TargetVersion(Enum):
     PY37 = 7
     PY38 = 8
     PY39 = 9
+    PY310 = 10
 
     def is_python2(self) -> bool:
         return self is TargetVersion.PY27
@@ -39,6 +40,7 @@ class Feature(Enum):
     ASSIGNMENT_EXPRESSIONS = 8
     POS_ONLY_ARGUMENTS = 9
     RELAXED_DECORATORS = 10
+    PATTERN_MATCHING = 11
     FORCE_OPTIONAL_PARENTHESES = 50
 
     # temporary for Python 2 deprecation
@@ -108,6 +110,9 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
         Feature.RELAXED_DECORATORS,
         Feature.POS_ONLY_ARGUMENTS,
     },
+    TargetVersion.PY310: {
+        Feature.PATTERN_MATCHING,
+    },
 }
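
Note that PY310 maps only to PATTERN_MATCHING, so selecting it does not pull in the other 3.x feature flags listed above. A rough sketch of exercising the new target through the Python API (behaviour as of this commit, not a stable interface):

    import black

    src = "match x:\n    case 0:\n        pass\n"
    mode = black.Mode(target_versions={black.TargetVersion.PY310})
    print(black.format_str(src, mode=mode))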
 
 
diff --git a/src/black/parsing.py b/src/black/parsing.py
index 0b8d984cedd2a43dcf6e80ab59a9041e481588f0..fc540ad021dc240547d332a70dd91ffee256a2cc 100644 (file)
@@ -59,6 +59,9 @@ def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
 
     # Python 3-compatible code, so only try Python 3 grammar.
     grammars = []
+    if supports_feature(target_versions, Feature.PATTERN_MATCHING):
+        # Python 3.10+
+        grammars.append(pygram.python_grammar_soft_keywords)
     # If we have to parse both, try to parse async as a keyword first
     if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
         # Python 3.7+
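
The soft-keyword grammar is appended before the existing ones, so it is tried first whenever every requested target supports pattern matching; otherwise the grammar list is unchanged and match statements simply fail to parse. A quick, illustrative check of the mapping behind that decision (names as defined in src/black/mode.py above):

    from black.mode import Feature, TargetVersion, VERSION_TO_FEATURES

    assert Feature.PATTERN_MATCHING in VERSION_TO_FEATURES[TargetVersion.PY310]
    assert Feature.PATTERN_MATCHING not in VERSION_TO_FEATURES[TargetVersion.PY39]
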
diff --git a/src/blib2to3/Grammar.txt b/src/blib2to3/Grammar.txt
index ac8a067378d71ec0ba4f05a2a2af5b2a68fbfad7..49680323d8b8fe83c98bee1940ee4e1d1e1f94d1 100644 (file)
@@ -105,7 +105,7 @@ global_stmt: ('global' | 'nonlocal') NAME (',' NAME)*
 exec_stmt: 'exec' expr ['in' test [',' test]]
 assert_stmt: 'assert' test [',' test]
 
-compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef | decorated | async_stmt
+compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef | decorated | async_stmt | match_stmt
 async_stmt: ASYNC (funcdef | with_stmt | for_stmt)
 if_stmt: 'if' namedexpr_test ':' suite ('elif' namedexpr_test ':' suite)* ['else' ':' suite]
 while_stmt: 'while' namedexpr_test ':' suite ['else' ':' suite]
@@ -115,9 +115,8 @@ try_stmt: ('try' ':' suite
            ['else' ':' suite]
            ['finally' ':' suite] |
           'finally' ':' suite))
-with_stmt: 'with' with_item (',' with_item)*  ':' suite
-with_item: test ['as' expr]
-with_var: 'as' expr
+with_stmt: 'with' asexpr_test (',' asexpr_test)*  ':' suite
+
 # NB compile.c makes sure that the default except clause is last
 except_clause: 'except' [test [(',' | 'as') test]]
 suite: simple_stmt | NEWLINE INDENT stmt+ DEDENT
@@ -131,7 +130,15 @@ testlist_safe: old_test [(',' old_test)+ [',']]
 old_test: or_test | old_lambdef
 old_lambdef: 'lambda' [varargslist] ':' old_test
 
-namedexpr_test: test [':=' test]
+namedexpr_test: asexpr_test [':=' asexpr_test]
+
+# This is actually not a real rule, though since the parser is very
+# limited in terms of the strategy about match/case rules, we are inserting
+# a virtual case (<expr> as <expr>) as a valid expression. Unless a better
+# approach is thought, the only side effect of this seem to be just allowing
+# more stuff to be parser (which would fail on the ast).
+asexpr_test: test ['as' test]
+
 test: or_test ['if' or_test 'else' test] | lambdef
 or_test: and_test ('or' and_test)*
 and_test: not_test ('and' not_test)*
@@ -213,3 +220,27 @@ encoding_decl: NAME
 
 yield_expr: 'yield' [yield_arg]
 yield_arg: 'from' test | testlist_star_expr
+
+
+# 3.10 match statement definition
+
+# PS: normally the grammar is much much more restricted, but
+# at this moment for not trying to bother much with encoding the
+# exact same DSL in a LL(1) parser, we will just accept an expression
+# and let the ast.parse() step of the safe mode to reject invalid
+# grammar.
+
+# The reason why it is more restricted is that, patterns are some
+# sort of a DSL (more advanced than our LHS on assignments, but
+# still in a very limited python subset). They are not really
+# expressions, but who cares. If we can parse them, that is enough
+# to reformat them.
+
+match_stmt: "match" subject_expr ':' NEWLINE INDENT case_block+ DEDENT
+subject_expr: namedexpr_test
+
+# cases
+case_block: "case" patterns [guard] ':' suite
+guard: 'if' namedexpr_test
+patterns: pattern ['as' pattern]
+pattern: (expr|star_expr) (',' (expr|star_expr))* [',']
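
The practical consequence of the deliberately loose pattern rule is that blib2to3 builds a tree for case bodies that CPython's real pattern grammar rejects; as the comment says, such input is only caught later by the ast.parse() pass of safe mode. A small sketch of that split, using Black's internal helper lib2to3_parse from src/black/parsing.py (illustrative, not part of the test suite):

    import ast

    from black.mode import TargetVersion
    from black.parsing import lib2to3_parse

    src = "match x:\n    case x + 1:\n        ...\n"

    # blib2to3 happily builds a tree for this under the py310 grammar...
    lib2to3_parse(src, {TargetVersion.PY310})

    # ...while CPython's own pattern grammar rejects 'x + 1' as a pattern.
    try:
        ast.parse(src)
    except SyntaxError:
        print("rejected by ast.parse(), as the grammar comment predicts")
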
diff --git a/src/blib2to3/pgen2/driver.py b/src/blib2to3/pgen2/driver.py
index af1dc6b8aebe032b701df42e9a9ff11257cf3dc4..5edd75b1333991d692a8f57ac84101b1a609e916 100644 (file)
@@ -28,19 +28,92 @@ from typing import (
     List,
     Optional,
     Text,
+    Iterator,
     Tuple,
+    TypeVar,
+    Generic,
     Union,
 )
+from dataclasses import dataclass, field
 
 # Pgen imports
 from . import grammar, parse, token, tokenize, pgen
 from logging import Logger
 from blib2to3.pytree import _Convert, NL
 from blib2to3.pgen2.grammar import Grammar
+from contextlib import contextmanager
 
 Path = Union[str, "os.PathLike[str]"]
 
 
+@dataclass
+class ReleaseRange:
+    start: int
+    end: Optional[int] = None
+    tokens: List[Any] = field(default_factory=list)
+
+    def lock(self) -> None:
+        total_eaten = len(self.tokens)
+        self.end = self.start + total_eaten
+
+
+class TokenProxy:
+    def __init__(self, generator: Any) -> None:
+        self._tokens = generator
+        self._counter = 0
+        self._release_ranges: List[ReleaseRange] = []
+
+    @contextmanager
+    def release(self) -> Iterator["TokenProxy"]:
+        release_range = ReleaseRange(self._counter)
+        self._release_ranges.append(release_range)
+        try:
+            yield self
+        finally:
+            # Lock the last release range to the final position that
+            # has been eaten.
+            release_range.lock()
+
+    def eat(self, point: int) -> Any:
+        eaten_tokens = self._release_ranges[-1].tokens
+        if point < len(eaten_tokens):
+            return eaten_tokens[point]
+        else:
+            while point >= len(eaten_tokens):
+                token = next(self._tokens)
+                eaten_tokens.append(token)
+            return token
+
+    def __iter__(self) -> "TokenProxy":
+        return self
+
+    def __next__(self) -> Any:
+        # If the current position is already compromised (looked up)
+        # return the eaten token, if not just go further on the given
+        # token producer.
+        for release_range in self._release_ranges:
+            assert release_range.end is not None
+
+            start, end = release_range.start, release_range.end
+            if start <= self._counter < end:
+                token = release_range.tokens[self._counter - start]
+                break
+        else:
+            token = next(self._tokens)
+        self._counter += 1
+        return token
+
+    def can_advance(self, to: int) -> bool:
+        # Try to eat, fail if it can't. The eat operation is cached
+        # so there wont be any additional cost of eating here
+        try:
+            self.eat(to)
+        except StopIteration:
+            return False
+        else:
+            return True
+
+
 class Driver(object):
     def __init__(
         self,
@@ -57,14 +130,18 @@ class Driver(object):
     def parse_tokens(self, tokens: Iterable[Any], debug: bool = False) -> NL:
         """Parse a series of tokens and return the syntax tree."""
         # XXX Move the prefix computation into a wrapper around tokenize.
+        proxy = TokenProxy(tokens)
+
         p = parse.Parser(self.grammar, self.convert)
-        p.setup()
+        p.setup(proxy=proxy)
+
         lineno = 1
         column = 0
         indent_columns = []
         type = value = start = end = line_text = None
         prefix = ""
-        for quintuple in tokens:
+
+        for quintuple in proxy:
             type, value, start, end, line_text = quintuple
             if start != (lineno, column):
                 assert (lineno, column) <= start, ((lineno, column), start)
diff --git a/src/blib2to3/pgen2/grammar.py b/src/blib2to3/pgen2/grammar.py
index 2882cdac89b2ac138151edd8d1791ea55db799f0..56851070933a708d7a0fadce36bdb6c6ffd6e4c5 100644 (file)
@@ -89,6 +89,7 @@ class Grammar(object):
         self.dfas: Dict[int, DFAS] = {}
         self.labels: List[Label] = [(0, "EMPTY")]
         self.keywords: Dict[str, int] = {}
+        self.soft_keywords: Dict[str, int] = {}
         self.tokens: Dict[int, int] = {}
         self.symbol2label: Dict[str, int] = {}
         self.start = 256
@@ -136,6 +137,7 @@
             "number2symbol",
             "dfas",
             "keywords",
+            "soft_keywords",
             "tokens",
             "symbol2label",
         ):
diff --git a/src/blib2to3/pgen2/parse.py b/src/blib2to3/pgen2/parse.py
index 47c8f02b4f5c8ee081ed70f4ba26f351483095a0..dc405264bad454a92469cb613b6dd962c394e837 100644 (file)
@@ -9,22 +9,31 @@ See Parser/parser.c in the Python distribution for additional info on
 how this parsing engine works.
 
 """
+import copy
+from contextlib import contextmanager
 
 # Local imports
-from . import token
+from . import grammar, token, tokenize
 from typing import (
+    cast,
+    Any,
     Optional,
     Text,
     Union,
     Tuple,
     Dict,
     List,
+    Iterator,
     Callable,
     Set,
+    TYPE_CHECKING,
 )
 from blib2to3.pgen2.grammar import Grammar
 from blib2to3.pytree import NL, Context, RawNode, Leaf, Node
 
+if TYPE_CHECKING:
+    from blib2to3.driver import TokenProxy
+
 
 Results = Dict[Text, NL]
 Convert = Callable[[Grammar, RawNode], Union[Node, Leaf]]
@@ -37,6 +46,61 @@ def lam_sub(grammar: Grammar, node: RawNode) -> NL:
     return Node(type=node[0], children=node[3], context=node[2])
 
 
+class Recorder:
+    def __init__(self, parser: "Parser", ilabels: List[int], context: Context) -> None:
+        self.parser = parser
+        self._ilabels = ilabels
+        self.context = context  # not really matter
+
+        self._dead_ilabels: Set[int] = set()
+        self._start_point = copy.deepcopy(self.parser.stack)
+        self._points = {ilabel: copy.deepcopy(self._start_point) for ilabel in ilabels}
+
+    @property
+    def ilabels(self) -> Set[int]:
+        return self._dead_ilabels.symmetric_difference(self._ilabels)
+
+    @contextmanager
+    def switch_to(self, ilabel: int) -> Iterator[None]:
+        self.parser.stack = self._points[ilabel]
+        try:
+            yield
+        except ParseError:
+            self._dead_ilabels.add(ilabel)
+        finally:
+            self.parser.stack = self._start_point
+
+    def add_token(
+        self, tok_type: int, tok_val: Optional[Text], raw: bool = False
+    ) -> None:
+        func: Callable[..., Any]
+        if raw:
+            func = self.parser._addtoken
+        else:
+            func = self.parser.addtoken
+
+        for ilabel in self.ilabels:
+            with self.switch_to(ilabel):
+                args = [tok_type, tok_val, self.context]
+                if raw:
+                    args.insert(0, ilabel)
+                func(*args)
+
+    def determine_route(
+        self, value: Optional[Text] = None, force: bool = False
+    ) -> Optional[int]:
+        alive_ilabels = self.ilabels
+        if len(alive_ilabels) == 0:
+            *_, most_successful_ilabel = self._dead_ilabels
+            raise ParseError("bad input", most_successful_ilabel, value, self.context)
+
+        ilabel, *rest = alive_ilabels
+        if force or not rest:
+            return ilabel
+        else:
+            return None
+
+
 class ParseError(Exception):
     """Exception to signal the parser is stuck."""
 
@@ -114,7 +178,7 @@ class Parser(object):
         self.grammar = grammar
         self.convert = convert or lam_sub
 
-    def setup(self, start: Optional[int] = None) -> None:
+    def setup(self, proxy: "TokenProxy", start: Optional[int] = None) -> None:
         """Prepare for parsing.
 
         This *must* be called before starting to parse.
@@ -137,11 +201,55 @@ class Parser(object):
         self.stack: List[Tuple[DFAS, int, RawNode]] = [stackentry]
         self.rootnode: Optional[NL] = None
         self.used_names: Set[str] = set()
+        self.proxy = proxy
 
     def addtoken(self, type: int, value: Optional[Text], context: Context) -> bool:
         """Add a token; return True iff this is the end of the program."""
         # Map from token to label
-        ilabel = self.classify(type, value, context)
+        ilabels = self.classify(type, value, context)
+        assert len(ilabels) >= 1
+
+        # If we have only one state to advance, we'll directly
+        # take it as is.
+        if len(ilabels) == 1:
+            [ilabel] = ilabels
+            return self._addtoken(ilabel, type, value, context)
+
+        # If there are multiple states which we can advance (only
+        # happen under soft-keywords), then we will try all of them
+        # in parallel and as soon as one state can reach further than
+        # the rest, we'll choose that one. This is a pretty hacky
+        # and hopefully temporary algorithm.
+        #
+        # For a more detailed explanation, check out this post:
+        # https://tree.science/what-the-backtracking.html
+
+        with self.proxy.release() as proxy:
+            counter, force = 0, False
+            recorder = Recorder(self, ilabels, context)
+            recorder.add_token(type, value, raw=True)
+
+            next_token_value = value
+            while recorder.determine_route(next_token_value) is None:
+                if not proxy.can_advance(counter):
+                    force = True
+                    break
+
+                next_token_type, next_token_value, *_ = proxy.eat(counter)
+                if next_token_type == tokenize.OP:
+                    next_token_type = grammar.opmap[cast(str, next_token_value)]
+
+                recorder.add_token(next_token_type, next_token_value)
+                counter += 1
+
+            ilabel = cast(int, recorder.determine_route(next_token_value, force=force))
+            assert ilabel is not None
+
+        return self._addtoken(ilabel, type, value, context)
+
+    def _addtoken(
+        self, ilabel: int, type: int, value: Optional[Text], context: Context
+    ) -> bool:
         # Loop until the token is shifted; may raise exceptions
         while True:
             dfa, state, node = self.stack[-1]
@@ -185,20 +293,29 @@ class Parser(object):
                     # No success finding a transition
                     raise ParseError("bad input", type, value, context)
 
-    def classify(self, type: int, value: Optional[Text], context: Context) -> int:
-        """Turn a token into a label.  (Internal)"""
+    def classify(self, type: int, value: Optional[Text], context: Context) -> List[int]:
+        """Turn a token into a label.  (Internal)
+
+        Depending on whether the value is a soft-keyword or not,
+        this function may return multiple labels to choose from."""
         if type == token.NAME:
             # Keep a listing of all used names
             assert value is not None
             self.used_names.add(value)
             # Check for reserved words
-            ilabel = self.grammar.keywords.get(value)
-            if ilabel is not None:
-                return ilabel
+            if value in self.grammar.keywords:
+                return [self.grammar.keywords[value]]
+            elif value in self.grammar.soft_keywords:
+                assert type in self.grammar.tokens
+                return [
+                    self.grammar.soft_keywords[value],
+                    self.grammar.tokens[type],
+                ]
+
         ilabel = self.grammar.tokens.get(type)
         if ilabel is None:
             raise ParseError("bad token", type, value, context)
-        return ilabel
+        return [ilabel]
 
     def shift(
         self, type: int, value: Optional[Text], newstate: int, context: Context
diff --git a/src/blib2to3/pgen2/pgen.py b/src/blib2to3/pgen2/pgen.py
index 564ebbd1184c9738a10ce5752757a1262cee4573..631682a77c9b65654ddf8cd90579fa4337895a99 100644 (file)
@@ -115,12 +115,17 @@ class ParserGenerator(object):
             assert label[0] in ('"', "'"), label
             value = eval(label)
             if value[0].isalpha():
+                if label[0] == '"':
+                    keywords = c.soft_keywords
+                else:
+                    keywords = c.keywords
+
                 # A keyword
-                if value in c.keywords:
-                    return c.keywords[value]
+                if value in keywords:
+                    return keywords[value]
                 else:
                     c.labels.append((token.NAME, value))
-                    c.keywords[value] = ilabel
+                    keywords[value] = ilabel
                     return ilabel
             else:
                 # An operator (any non-numeric token)
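
Put differently, the quoting style in Grammar.txt now selects the keyword kind: single-quoted alphabetic labels such as 'if' stay hard keywords, while double-quoted ones such as "match" and "case" end up in soft_keywords. A quick way to check the generated grammar (the path to Grammar.txt is an assumption; adjust it to your checkout):

    from blib2to3.pgen2 import pgen

    g = pgen.generate_grammar("src/blib2to3/Grammar.txt")
    assert "if" in g.keywords            # 'if' is single-quoted: hard keyword
    assert "match" in g.soft_keywords    # "match" is double-quoted: soft keyword
    assert "match" not in g.keywords
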
diff --git a/src/blib2to3/pygram.py b/src/blib2to3/pygram.py
index b8362b814735afffb1d7a2804dc0361f0fa88a76..aa20b8104aea914d4751c6ed40a93a18cb57f76b 100644 (file)
@@ -39,12 +39,14 @@ class _python_symbols(Symbols):
     arglist: int
     argument: int
     arith_expr: int
+    asexpr_test: int
     assert_stmt: int
     async_funcdef: int
     async_stmt: int
     atom: int
     augassign: int
     break_stmt: int
+    case_block: int
     classdef: int
     comp_for: int
     comp_if: int
@@ -74,6 +76,7 @@ class _python_symbols(Symbols):
     for_stmt: int
     funcdef: int
     global_stmt: int
+    guard: int
     if_stmt: int
     import_as_name: int
     import_as_names: int
@@ -82,6 +85,7 @@ class _python_symbols(Symbols):
     import_stmt: int
     lambdef: int
     listmaker: int
+    match_stmt: int
     namedexpr_test: int
     not_test: int
     old_comp_for: int
@@ -92,6 +96,8 @@ class _python_symbols(Symbols):
     or_test: int
     parameters: int
     pass_stmt: int
+    pattern: int
+    patterns: int
     power: int
     print_stmt: int
     raise_stmt: int
@@ -101,6 +107,7 @@ class _python_symbols(Symbols):
     single_input: int
     sliceop: int
     small_stmt: int
+    subject_expr: int
     star_expr: int
     stmt: int
     subscript: int
@@ -124,9 +131,7 @@ class _python_symbols(Symbols):
     vfplist: int
     vname: int
     while_stmt: int
-    with_item: int
     with_stmt: int
-    with_var: int
     xor_expr: int
     yield_arg: int
     yield_expr: int
@@ -149,6 +154,7 @@ python_grammar_no_print_statement_no_exec_statement: Grammar
 python_grammar_no_print_statement_no_exec_statement_async_keywords: Grammar
 python_grammar_no_exec_statement: Grammar
 pattern_grammar: Grammar
+python_grammar_soft_keywords: Grammar
 
 python_symbols: _python_symbols
 pattern_symbols: _pattern_symbols
@@ -159,6 +165,7 @@ def initialize(cache_dir: Union[str, "os.PathLike[str]", None] = None) -> None:
     global python_grammar_no_print_statement
     global python_grammar_no_print_statement_no_exec_statement
     global python_grammar_no_print_statement_no_exec_statement_async_keywords
+    global python_grammar_soft_keywords
     global python_symbols
     global pattern_grammar
     global pattern_symbols
@@ -171,6 +178,8 @@ def initialize(cache_dir: Union[str, "os.PathLike[str]", None] = None) -> None:
 
     # Python 2
     python_grammar = driver.load_packaged_grammar("blib2to3", _GRAMMAR_FILE, cache_dir)
+    soft_keywords = python_grammar.soft_keywords.copy()
+    python_grammar.soft_keywords.clear()
 
     python_symbols = _python_symbols(python_grammar)
 
@@ -191,6 +200,12 @@ def initialize(cache_dir: Union[str, "os.PathLike[str]", None] = None) -> None:
         True
     )
 
+    # Python 3.10+
+    python_grammar_soft_keywords = (
+        python_grammar_no_print_statement_no_exec_statement_async_keywords.copy()
+    )
+    python_grammar_soft_keywords.soft_keywords = soft_keywords
+
     pattern_grammar = driver.load_packaged_grammar(
         "blib2to3", _PATTERN_GRAMMAR_FILE, cache_dir
     )
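
Clearing soft_keywords on the base grammar keeps match and case behaving as ordinary identifiers for every target except the dedicated 3.10 grammar built above. A rough sanity check, assuming a working Black/blib2to3 installation (the dict values are internal label numbers, so only membership is asserted):

    from blib2to3 import pygram

    pygram.initialize()  # cache_dir defaults to None
    assert pygram.python_grammar.soft_keywords == {}
    assert "match" in pygram.python_grammar_soft_keywords.soft_keywords
    assert "case" in pygram.python_grammar_soft_keywords.soft_keywords
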
diff --git a/tests/data/parenthesized_context_managers.py b/tests/data/parenthesized_context_managers.py
new file mode 100644 (file)
index 0000000..ccf1f94
--- /dev/null
@@ -0,0 +1,21 @@
+with (CtxManager() as example):
+    ...
+
+with (CtxManager1(), CtxManager2()):
+    ...
+
+with (CtxManager1() as example, CtxManager2()):
+    ...
+
+with (CtxManager1(), CtxManager2() as example):
+    ...
+
+with (CtxManager1() as example1, CtxManager2() as example2):
+    ...
+
+with (
+    CtxManager1() as example1,
+    CtxManager2() as example2,
+    CtxManager3() as example3,
+):
+    ...
diff --git a/tests/data/pattern_matching_complex.py b/tests/data/pattern_matching_complex.py
new file mode 100644 (file)
index 0000000..97ee194
--- /dev/null
@@ -0,0 +1,144 @@
+# Cases sampled from Lib/test/test_patma.py
+
+# case black_test_patma_098
+match x:
+    case -0j:
+        y = 0
+# case black_test_patma_142
+match x:
+    case bytes(z):
+        y = 0
+# case black_test_patma_073
+match x:
+    case 0 if 0:
+        y = 0
+    case 0 if 1:
+        y = 1
+# case black_test_patma_006
+match 3:
+    case 0 | 1 | 2 | 3:
+        x = True
+# case black_test_patma_049
+match x:
+    case [0, 1] | [1, 0]:
+        y = 0
+# case black_check_sequence_then_mapping
+match x:
+    case [*_]:
+        return "seq"
+    case {}:
+        return "map"
+# case black_test_patma_035
+match x:
+    case {0: [1, 2, {}]}:
+        y = 0
+    case {0: [1, 2, {}] | True} | {1: [[]]} | {0: [1, 2, {}]} | [] | "X" | {}:
+        y = 1
+    case []:
+        y = 2
+# case black_test_patma_107
+match x:
+    case 0.25 + 1.75j:
+        y = 0
+# case black_test_patma_097
+match x:
+    case -0j:
+        y = 0
+# case black_test_patma_007
+match 4:
+    case 0 | 1 | 2 | 3:
+        x = True
+# case black_test_patma_154
+match x:
+    case 0 if x:
+        y = 0
+# case black_test_patma_134
+match x:
+    case {1: 0}:
+        y = 0
+    case {0: 0}:
+        y = 1
+    case {**z}:
+        y = 2
+# case black_test_patma_185
+match Seq():
+    case [*_]:
+        y = 0
+# case black_test_patma_063
+match x:
+    case 1:
+        y = 0
+    case 1:
+        y = 1
+# case black_test_patma_248
+match x:
+    case {"foo": bar}:
+        y = bar
+# case black_test_patma_019
+match (0, 1, 2):
+    case [0, 1, *x, 2]:
+        y = 0
+# case black_test_patma_052
+match x:
+    case [0]:
+        y = 0
+    case [1, 0] if (x := x[:0]):
+        y = 1
+    case [1, 0]:
+        y = 2
+# case black_test_patma_191
+match w:
+    case [x, y, *_]:
+        z = 0
+# case black_test_patma_110
+match x:
+    case -0.25 - 1.75j:
+        y = 0
+# case black_test_patma_151
+match (x,):
+    case [y]:
+        z = 0
+# case black_test_patma_114
+match x:
+    case A.B.C.D:
+        y = 0
+# case black_test_patma_232
+match x:
+    case None:
+        y = 0
+# case black_test_patma_058
+match x:
+    case 0:
+        y = 0
+# case black_test_patma_233
+match x:
+    case False:
+        y = 0
+# case black_test_patma_078
+match x:
+    case []:
+        y = 0
+    case [""]:
+        y = 1
+    case "":
+        y = 2
+# case black_test_patma_156
+match x:
+    case z:
+        y = 0
+# case black_test_patma_189
+match w:
+    case [x, y, *rest]:
+        z = 0
+# case black_test_patma_042
+match x:
+    case (0 as z) | (1 as z) | (2 as z) if z == x % 2:
+        y = 0
+# case black_test_patma_034
+match x:
+    case {0: [1, 2, {}]}:
+        y = 0
+    case {0: [1, 2, {}] | False} | {1: [[]]} | {0: [1, 2, {}]} | [] | "X" | {}:
+        y = 1
+    case []:
+        y = 2
diff --git a/tests/data/pattern_matching_simple.py b/tests/data/pattern_matching_simple.py
new file mode 100644 (file)
index 0000000..5ed6241
--- /dev/null
@@ -0,0 +1,92 @@
+# Cases sampled from PEP 636 examples
+
+match command.split():
+    case [action, obj]:
+        ...  # interpret action, obj
+
+match command.split():
+    case [action]:
+        ...  # interpret single-verb action
+    case [action, obj]:
+        ...  # interpret action, obj
+
+match command.split():
+    case ["quit"]:
+        print("Goodbye!")
+        quit_game()
+    case ["look"]:
+        current_room.describe()
+    case ["get", obj]:
+        character.get(obj, current_room)
+    case ["go", direction]:
+        current_room = current_room.neighbor(direction)
+    # The rest of your commands go here
+
+match command.split():
+    case ["drop", *objects]:
+        for obj in objects:
+            character.drop(obj, current_room)
+    # The rest of your commands go here
+
+match command.split():
+    case ["quit"]:
+        pass
+    case ["go", direction]:
+        print("Going:", direction)
+    case ["drop", *objects]:
+        print("Dropping: ", *objects)
+    case _:
+        print(f"Sorry, I couldn't understand {command!r}")
+
+match command.split():
+    case ["north"] | ["go", "north"]:
+        current_room = current_room.neighbor("north")
+    case ["get", obj] | ["pick", "up", obj] | ["pick", obj, "up"]:
+        ...  # Code for picking up the given object
+
+match command.split():
+    case ["go", ("north" | "south" | "east" | "west")]:
+        current_room = current_room.neighbor(...)
+        # how do I know which direction to go?
+
+match command.split():
+    case ["go", ("north" | "south" | "east" | "west") as direction]:
+        current_room = current_room.neighbor(direction)
+
+match command.split():
+    case ["go", direction] if direction in current_room.exits:
+        current_room = current_room.neighbor(direction)
+    case ["go", _]:
+        print("Sorry, you can't go that way")
+
+match event.get():
+    case Click(position=(x, y)):
+        handle_click_at(x, y)
+    case KeyPress(key_name="Q") | Quit():
+        game.quit()
+    case KeyPress(key_name="up arrow"):
+        game.go_north()
+    case KeyPress():
+        pass  # Ignore other keystrokes
+    case other_event:
+        raise ValueError(f"Unrecognized event: {other_event}")
+
+match event.get():
+    case Click((x, y), button=Button.LEFT):  # This is a left click
+        handle_click_at(x, y)
+    case Click():
+        pass  # ignore other clicks
+
+
+def where_is(point):
+    match point:
+        case Point(x=0, y=0):
+            print("Origin")
+        case Point(x=0, y=y):
+            print(f"Y={y}")
+        case Point(x=x, y=0):
+            print(f"X={x}")
+        case Point():
+            print("Somewhere else")
+        case _:
+            print("Not a point")
diff --git a/tests/test_format.py b/tests/test_format.py
index 649c1572bee44a2651beef931d474ba6cdaa0674..4359deea92ba0b290fee3ac32fdf0fb5b0eefbdd 100644 (file)
@@ -70,6 +70,11 @@ EXPERIMENTAL_STRING_PROCESSING_CASES = [
     "percent_precedence",
 ]
 
+PY310_CASES = [
+    "pattern_matching_simple",
+    "pattern_matching_complex",
+    "parenthesized_context_managers",
+]
 
 SOURCES = [
     "src/black/__init__.py",
@@ -187,6 +192,13 @@ def test_pep_570() -> None:
     assert_format(source, expected, minimum_version=(3, 8))
 
 
+@pytest.mark.parametrize("filename", PY310_CASES)
+def test_python_310(filename: str) -> None:
+    source, expected = read_data(filename)
+    mode = black.Mode(target_versions={black.TargetVersion.PY310})
+    assert_format(source, expected, mode, minimum_version=(3, 10))
+
+
 def test_docstring_no_string_normalization() -> None:
     """Like test_docstring but with string normalization off."""
     source, expected = read_data("docstring_no_string_normalization")