From 3e731527e4418b0b6d9791d6e32caee9227ba69d Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Mon, 10 Jan 2022 21:22:00 +0300 Subject: [PATCH] Speed up new backtracking parser (#2728) --- CHANGES.md | 2 + src/blib2to3/pgen2/parse.py | 89 ++++++++++++++------ tests/data/pattern_matching_generic.py | 107 +++++++++++++++++++++++++ tests/test_format.py | 1 + 4 files changed, 176 insertions(+), 23 deletions(-) create mode 100644 tests/data/pattern_matching_generic.py diff --git a/CHANGES.md b/CHANGES.md index f6e8343..a1c8ccb 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -24,6 +24,8 @@ at least one pre-existing blank line (#2736) - Verbose mode also now describes how a project root was discovered and which paths will be formatted. (#2526) +- Speed-up the new backtracking parser about 4X in general (enabled when + `--target-version` is set to 3.10 and higher). (#2728) ### Packaging diff --git a/src/blib2to3/pgen2/parse.py b/src/blib2to3/pgen2/parse.py index e5dad3a..8fe9667 100644 --- a/src/blib2to3/pgen2/parse.py +++ b/src/blib2to3/pgen2/parse.py @@ -46,6 +46,17 @@ def lam_sub(grammar: Grammar, node: RawNode) -> NL: return Node(type=node[0], children=node[3], context=node[2]) +# A placeholder node, used when parser is backtracking. +DUMMY_NODE = (-1, None, None, None) + + +def stack_copy( + stack: List[Tuple[DFAS, int, RawNode]] +) -> List[Tuple[DFAS, int, RawNode]]: + """Nodeless stack copy.""" + return [(copy.deepcopy(dfa), label, DUMMY_NODE) for dfa, label, _ in stack] + + class Recorder: def __init__(self, parser: "Parser", ilabels: List[int], context: Context) -> None: self.parser = parser @@ -54,7 +65,7 @@ class Recorder: self._dead_ilabels: Set[int] = set() self._start_point = self.parser.stack - self._points = {ilabel: copy.deepcopy(self._start_point) for ilabel in ilabels} + self._points = {ilabel: stack_copy(self._start_point) for ilabel in ilabels} @property def ilabels(self) -> Set[int]: @@ -62,13 +73,32 @@ class Recorder: @contextmanager def switch_to(self, ilabel: int) -> Iterator[None]: - self.parser.stack = self._points[ilabel] + with self.backtrack(): + self.parser.stack = self._points[ilabel] + try: + yield + except ParseError: + self._dead_ilabels.add(ilabel) + finally: + self.parser.stack = self._start_point + + @contextmanager + def backtrack(self) -> Iterator[None]: + """ + Use the node-level invariant ones for basic parsing operations (push/pop/shift). + These still will operate on the stack; but they won't create any new nodes, or + modify the contents of any other existing nodes. + + This saves us a ton of time when we are backtracking, since we + want to restore to the initial state as quick as possible, which + can only be done by having as little mutatations as possible. + """ + is_backtracking = self.parser.is_backtracking try: + self.parser.is_backtracking = True yield - except ParseError: - self._dead_ilabels.add(ilabel) finally: - self.parser.stack = self._start_point + self.parser.is_backtracking = is_backtracking def add_token(self, tok_type: int, tok_val: Text, raw: bool = False) -> None: func: Callable[..., Any] @@ -179,6 +209,7 @@ class Parser(object): self.grammar = grammar # See note in docstring above. TL;DR this is ignored. self.convert = convert or lam_sub + self.is_backtracking = False def setup(self, proxy: "TokenProxy", start: Optional[int] = None) -> None: """Prepare for parsing. @@ -319,28 +350,40 @@ class Parser(object): def shift(self, type: int, value: Text, newstate: int, context: Context) -> None: """Shift a token. (Internal)""" - dfa, state, node = self.stack[-1] - rawnode: RawNode = (type, value, context, None) - newnode = convert(self.grammar, rawnode) - assert node[-1] is not None - node[-1].append(newnode) - self.stack[-1] = (dfa, newstate, node) + if self.is_backtracking: + dfa, state, _ = self.stack[-1] + self.stack[-1] = (dfa, newstate, DUMMY_NODE) + else: + dfa, state, node = self.stack[-1] + rawnode: RawNode = (type, value, context, None) + newnode = convert(self.grammar, rawnode) + assert node[-1] is not None + node[-1].append(newnode) + self.stack[-1] = (dfa, newstate, node) def push(self, type: int, newdfa: DFAS, newstate: int, context: Context) -> None: """Push a nonterminal. (Internal)""" - dfa, state, node = self.stack[-1] - newnode: RawNode = (type, None, context, []) - self.stack[-1] = (dfa, newstate, node) - self.stack.append((newdfa, 0, newnode)) + if self.is_backtracking: + dfa, state, _ = self.stack[-1] + self.stack[-1] = (dfa, newstate, DUMMY_NODE) + self.stack.append((newdfa, 0, DUMMY_NODE)) + else: + dfa, state, node = self.stack[-1] + newnode: RawNode = (type, None, context, []) + self.stack[-1] = (dfa, newstate, node) + self.stack.append((newdfa, 0, newnode)) def pop(self) -> None: """Pop a nonterminal. (Internal)""" - popdfa, popstate, popnode = self.stack.pop() - newnode = convert(self.grammar, popnode) - if self.stack: - dfa, state, node = self.stack[-1] - assert node[-1] is not None - node[-1].append(newnode) + if self.is_backtracking: + self.stack.pop() else: - self.rootnode = newnode - self.rootnode.used_names = self.used_names + popdfa, popstate, popnode = self.stack.pop() + newnode = convert(self.grammar, popnode) + if self.stack: + dfa, state, node = self.stack[-1] + assert node[-1] is not None + node[-1].append(newnode) + else: + self.rootnode = newnode + self.rootnode.used_names = self.used_names diff --git a/tests/data/pattern_matching_generic.py b/tests/data/pattern_matching_generic.py new file mode 100644 index 0000000..00a0e4a --- /dev/null +++ b/tests/data/pattern_matching_generic.py @@ -0,0 +1,107 @@ +re.match() +match = a +with match() as match: + match = f"{match}" + +re.match() +match = a +with match() as match: + match = f"{match}" + + +def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]: + if not target_versions: + # No target_version specified, so try all grammars. + return [ + # Python 3.7+ + pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords, + # Python 3.0-3.6 + pygram.python_grammar_no_print_statement_no_exec_statement, + # Python 2.7 with future print_function import + pygram.python_grammar_no_print_statement, + # Python 2.7 + pygram.python_grammar, + ] + + match match: + case case: + match match: + case case: + pass + + if all(version.is_python2() for version in target_versions): + # Python 2-only code, so try Python 2 grammars. + return [ + # Python 2.7 with future print_function import + pygram.python_grammar_no_print_statement, + # Python 2.7 + pygram.python_grammar, + ] + + re.match() + match = a + with match() as match: + match = f"{match}" + + def test_patma_139(self): + x = False + match x: + case bool(z): + y = 0 + self.assertIs(x, False) + self.assertEqual(y, 0) + self.assertIs(z, x) + + # Python 3-compatible code, so only try Python 3 grammar. + grammars = [] + if supports_feature(target_versions, Feature.PATTERN_MATCHING): + # Python 3.10+ + grammars.append(pygram.python_grammar_soft_keywords) + # If we have to parse both, try to parse async as a keyword first + if not supports_feature( + target_versions, Feature.ASYNC_IDENTIFIERS + ) and not supports_feature(target_versions, Feature.PATTERN_MATCHING): + # Python 3.7-3.9 + grammars.append( + pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords + ) + if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS): + # Python 3.0-3.6 + grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement) + + def test_patma_155(self): + x = 0 + y = None + match x: + case 1e1000: + y = 0 + self.assertEqual(x, 0) + self.assertIs(y, None) + + x = range(3) + match x: + case [y, case as x, z]: + w = 0 + + # At least one of the above branches must have been taken, because every Python + # version has exactly one of the two 'ASYNC_*' flags + return grammars + + +def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node: + """Given a string with source, return the lib2to3 Node.""" + if not src_txt.endswith("\n"): + src_txt += "\n" + + grammars = get_grammars(set(target_versions)) + + +re.match() +match = a +with match() as match: + match = f"{match}" + +re.match() +match = a +with match() as match: + match = f"{match}" diff --git a/tests/test_format.py b/tests/test_format.py index 6651272..db39678 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -69,6 +69,7 @@ PY310_CASES = [ "pattern_matching_complex", "pattern_matching_extras", "pattern_matching_style", + "pattern_matching_generic", "parenthesized_context_managers", ] -- 2.39.5