]> git.madduck.net Git - etc/vim.git/commitdiff

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Implementing mypyc support pt. 2 (#2431)
authorRichard Si <63936253+ichard26@users.noreply.github.com>
Tue, 16 Nov 2021 04:24:16 +0000 (23:24 -0500)
committerGitHub <noreply@github.com>
Tue, 16 Nov 2021 04:24:16 +0000 (20:24 -0800)
22 files changed:
mypy.ini
pyproject.toml
setup.py
src/black/__init__.py
src/black/brackets.py
src/black/comments.py
src/black/files.py
src/black/handle_ipynb_magics.py
src/black/linegen.py
src/black/mode.py
src/black/nodes.py
src/black/output.py
src/black/parsing.py
src/black/strings.py
src/black/trans.py
src/black_primer/cli.py
src/blib2to3/README
src/blib2to3/pgen2/driver.py
src/blib2to3/pgen2/parse.py
src/blib2to3/pgen2/tokenize.py
src/blib2to3/pytree.py
tests/test_black.py

index 62c1c7fefaa4eb098cf2caad0547254dea573a06..cfceaa3ee867b0ca9ca218c2147a2fbd911a4e91 100644 (file)
--- a/mypy.ini
+++ b/mypy.ini
@@ -3,7 +3,6 @@
 # free to run mypy on Windows, Linux, or macOS and get consistent
 # results.
 python_version=3.6
-platform=linux
 
 mypy_path=src
 
@@ -24,6 +23,10 @@ warn_redundant_casts=True
 warn_unused_ignores=True
 disallow_any_generics=True
 
+# Unreachable blocks have been an issue when compiling mypyc, let's try
+# to avoid 'em in the first place.
+warn_unreachable=True
+
 # The following are off by default.  Flip them on if you feel
 # adventurous.
 disallow_untyped_defs=True
@@ -32,6 +35,11 @@ check_untyped_defs=True
 # No incremental mode
 cache_dir=/dev/null
 
+[mypy-black]
+# The following is because of `patch_click()`. Remove when
+# we drop Python 3.6 support.
+warn_unused_ignores=False
+
 [mypy-black_primer.*]
 # Until we're not supporting 3.6 primer needs this
 disallow_any_generics=False
index 73e19608108d5541a4d1e49081b1ef2e9207b829..aebbc0da29c84b32a77676309f9324bbf13252b3 100644 (file)
@@ -33,3 +33,6 @@ optional-tests = [
   "no_blackd: run when `d` extra NOT installed",
   "no_jupyter: run when `jupyter` extra NOT installed",
 ]
+markers = [
+  "incompatible_with_mypyc: run when testing mypyc compiled black"
+]
index 1914ba745e779a02b5622bd1882a0498aee468c7..7022b24345c66a7fb162faa3421d8ee87363b522 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -5,6 +5,7 @@ import os
 
 assert sys.version_info >= (3, 6, 2), "black requires Python 3.6.2+"
 from pathlib import Path  # noqa E402
+from typing import List  # noqa: E402
 
 CURRENT_DIR = Path(__file__).parent
 sys.path.insert(0, str(CURRENT_DIR))  # for setuptools.build_meta
@@ -18,6 +19,17 @@ def get_long_description() -> str:
     )
 
 
+def find_python_files(base: Path) -> List[Path]:
+    files = []
+    for entry in base.iterdir():
+        if entry.is_file() and entry.suffix == ".py":
+            files.append(entry)
+        elif entry.is_dir():
+            files.extend(find_python_files(entry))
+
+    return files
+
+
 USE_MYPYC = False
 # To compile with mypyc, a mypyc checkout must be present on the PYTHONPATH
 if len(sys.argv) > 1 and sys.argv[1] == "--use-mypyc":
@@ -27,21 +39,34 @@ if os.getenv("BLACK_USE_MYPYC", None) == "1":
     USE_MYPYC = True
 
 if USE_MYPYC:
+    from mypyc.build import mypycify
+
+    src = CURRENT_DIR / "src"
+    # TIP: filepaths are normalized to use forward slashes and are relative to ./src/
+    # before being checked against.
+    blocklist = [
+        # Not performance sensitive, so save bytes + compilation time:
+        "blib2to3/__init__.py",
+        "blib2to3/pgen2/__init__.py",
+        "black/output.py",
+        "black/concurrency.py",
+        "black/files.py",
+        "black/report.py",
+        # Breaks the test suite when compiled (and is also useless):
+        "black/debug.py",
+        # Compiled modules can't be run directly and that's a problem here:
+        "black/__main__.py",
+    ]
+    discovered = []
+    # black-primer and blackd have no good reason to be compiled.
+    discovered.extend(find_python_files(src / "black"))
+    discovered.extend(find_python_files(src / "blib2to3"))
     mypyc_targets = [
-        "src/black/__init__.py",
-        "src/blib2to3/pytree.py",
-        "src/blib2to3/pygram.py",
-        "src/blib2to3/pgen2/parse.py",
-        "src/blib2to3/pgen2/grammar.py",
-        "src/blib2to3/pgen2/token.py",
-        "src/blib2to3/pgen2/driver.py",
-        "src/blib2to3/pgen2/pgen.py",
+        str(p) for p in discovered if p.relative_to(src).as_posix() not in blocklist
     ]
 
-    from mypyc.build import mypycify
-
     opt_level = os.getenv("MYPYC_OPT_LEVEL", "3")
-    ext_modules = mypycify(mypyc_targets, opt_level=opt_level)
+    ext_modules = mypycify(mypyc_targets, opt_level=opt_level, verbose=True)
 else:
     ext_modules = []
 
index ad4ee1a0d1a6dff08825d6580f9d9cf94793c0ff..a5ddec9122105cae9393d126c92f3c3c5bfb3ec3 100644 (file)
@@ -30,8 +30,9 @@ from typing import (
     Union,
 )
 
-from dataclasses import replace
 import click
+from dataclasses import replace
+from mypy_extensions import mypyc_attr
 
 from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
 from black.const import STDIN_PLACEHOLDER
@@ -66,6 +67,8 @@ from blib2to3.pgen2 import token
 
 from _black_version import version as __version__
 
+COMPILED = Path(__file__).suffix in (".pyd", ".so")
+
 # types
 FileContent = str
 Encoding = str
@@ -177,7 +180,12 @@ def validate_regex(
         raise click.BadParameter("Not a valid regular expression") from None
 
 
-@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
+@click.command(
+    context_settings=dict(help_option_names=["-h", "--help"]),
+    # While Click does set this field automatically using the docstring, mypyc
+    # (annoyingly) strips 'em so we need to set it here too.
+    help="The uncompromising code formatter.",
+)
 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
 @click.option(
     "-l",
@@ -346,7 +354,10 @@ def validate_regex(
         " due to exclusion patterns."
     ),
 )
-@click.version_option(version=__version__)
+@click.version_option(
+    version=__version__,
+    message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
+)
 @click.argument(
     "src",
     nargs=-1,
@@ -387,7 +398,7 @@ def main(
     experimental_string_processing: bool,
     quiet: bool,
     verbose: bool,
-    required_version: str,
+    required_version: Optional[str],
     include: Pattern[str],
     exclude: Optional[Pattern[str]],
     extend_exclude: Optional[Pattern[str]],
@@ -655,6 +666,9 @@ def reformat_one(
         report.failed(src, str(exc))
 
 
+# diff-shades depends on being to monkeypatch this function to operate. I know it's
+# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
+@mypyc_attr(patchable=True)
 def reformat_many(
     sources: Set[Path],
     fast: bool,
@@ -669,6 +683,7 @@ def reformat_many(
     worker_count = workers if workers is not None else DEFAULT_WORKERS
     if sys.platform == "win32":
         # Work around https://bugs.python.org/issue26903
+        assert worker_count is not None
         worker_count = min(worker_count, 60)
     try:
         executor = ProcessPoolExecutor(max_workers=worker_count)
index bb865a0d5b709173ac5147549d41a27563ba1f40..c5ed4bf5b9f242752ad2ca7bf8902f853b64c4cf 100644 (file)
@@ -49,7 +49,7 @@ MATH_PRIORITIES: Final = {
 DOT_PRIORITY: Final = 1
 
 
-class BracketMatchError(KeyError):
+class BracketMatchError(Exception):
     """Raised when an opening bracket is unable to be matched to a closing bracket."""
 
 
index c7513c21ef5693cb5e5800eecbbc94a3fab88e7c..a8152d687a3f7d48dd48f709582737cd1d798c1d 100644 (file)
@@ -1,8 +1,14 @@
+import sys
 from dataclasses import dataclass
 from functools import lru_cache
 import regex as re
 from typing import Iterator, List, Optional, Union
 
+if sys.version_info >= (3, 8):
+    from typing import Final
+else:
+    from typing_extensions import Final
+
 from blib2to3.pytree import Node, Leaf
 from blib2to3.pgen2 import token
 
@@ -12,11 +18,10 @@ from black.nodes import STANDALONE_COMMENT, WHITESPACE
 # types
 LN = Union[Leaf, Node]
 
-
-FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
-FMT_SKIP = {"# fmt: skip", "# fmt:skip"}
-FMT_PASS = {*FMT_OFF, *FMT_SKIP}
-FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
+FMT_OFF: Final = {"# fmt: off", "# fmt:off", "# yapf: disable"}
+FMT_SKIP: Final = {"# fmt: skip", "# fmt:skip"}
+FMT_PASS: Final = {*FMT_OFF, *FMT_SKIP}
+FMT_ON: Final = {"# fmt: on", "# fmt:on", "# yapf: enable"}
 
 
 @dataclass
index 4d7b47aaa9fcbb4b7abe96ded72c9811d0155402..560aa05080da9136beca90109247088e0d8dd67c 100644 (file)
@@ -17,6 +17,7 @@ from typing import (
     TYPE_CHECKING,
 )
 
+from mypy_extensions import mypyc_attr
 from pathspec import PathSpec
 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
 import tomli
@@ -88,13 +89,14 @@ def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]:
         return None
 
 
+@mypyc_attr(patchable=True)
 def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
     """Parse a pyproject toml file, pulling out relevant parts for Black
 
     If parsing fails, will raise a tomli.TOMLDecodeError
     """
     with open(path_config, encoding="utf8") as f:
-        pyproject_toml = tomli.load(f)  # type: ignore  # due to deprecated API usage
+        pyproject_toml = tomli.loads(f.read())
     config = pyproject_toml.get("tool", {}).get("black", {})
     return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
 
index f10eaed4f3ef950e6671c59e169ee967e78ca86e..2fe6739209dd626597a43011b8bc576ae203cc9a 100644 (file)
@@ -333,7 +333,7 @@ class CellMagic:
         return f"%%{self.name}"
 
 
-@dataclasses.dataclass
+# ast.NodeVisitor + dataclass = breakage under mypyc.
 class CellMagicFinder(ast.NodeVisitor):
     """Find cell magics.
 
@@ -352,7 +352,8 @@ class CellMagicFinder(ast.NodeVisitor):
     and we look for instances of the latter.
     """
 
-    cell_magic: Optional[CellMagic] = None
+    def __init__(self, cell_magic: Optional[CellMagic] = None) -> None:
+        self.cell_magic = cell_magic
 
     def visit_Expr(self, node: ast.Expr) -> None:
         """Find cell magic, extract header and body."""
@@ -372,7 +373,8 @@ class OffsetAndMagic:
     magic: str
 
 
-@dataclasses.dataclass
+# Unsurprisingly, subclassing ast.NodeVisitor means we can't use dataclasses here
+# as mypyc will generate broken code.
 class MagicFinder(ast.NodeVisitor):
     """Visit cell to look for get_ipython calls.
 
@@ -392,9 +394,8 @@ class MagicFinder(ast.NodeVisitor):
     types of magics).
     """
 
-    magics: Dict[int, List[OffsetAndMagic]] = dataclasses.field(
-        default_factory=lambda: collections.defaultdict(list)
-    )
+    def __init__(self) -> None:
+        self.magics: Dict[int, List[OffsetAndMagic]] = collections.defaultdict(list)
 
     def visit_Assign(self, node: ast.Assign) -> None:
         """Look for system assign magics.
index 8cf32c973bb6a4880e543ec6a51ec06ff54dc492..4cba4164fb357f73b971ff0070fc0fd624323bc2 100644 (file)
@@ -5,8 +5,6 @@ from functools import partial, wraps
 import sys
 from typing import Collection, Iterator, List, Optional, Set, Union
 
-from dataclasses import dataclass, field
-
 from black.nodes import WHITESPACE, RARROW, STATEMENT, STANDALONE_COMMENT
 from black.nodes import ASSIGNMENTS, OPENING_BRACKETS, CLOSING_BRACKETS
 from black.nodes import Visitor, syms, first_child_is_arith, ensure_visible
@@ -40,7 +38,8 @@ class CannotSplit(CannotTransform):
     """A readable split that fits the allotted line length is impossible."""
 
 
-@dataclass
+# This isn't a dataclass because @dataclass + Generic breaks mypyc.
+# See also https://github.com/mypyc/mypyc/issues/827.
 class LineGenerator(Visitor[Line]):
     """Generates reformatted Line objects.  Empty lines are not emitted.
 
@@ -48,9 +47,11 @@ class LineGenerator(Visitor[Line]):
     in ways that will no longer stringify to valid Python code on the tree.
     """
 
-    mode: Mode
-    remove_u_prefix: bool = False
-    current_line: Line = field(init=False)
+    def __init__(self, mode: Mode, remove_u_prefix: bool = False) -> None:
+        self.mode = mode
+        self.remove_u_prefix = remove_u_prefix
+        self.current_line: Line
+        self.__post_init__()
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
@@ -339,7 +340,9 @@ def transform_line(
         transformers = [left_hand_split]
     else:
 
-        def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
+        def _rhs(
+            self: object, line: Line, features: Collection[Feature]
+        ) -> Iterator[Line]:
             """Wraps calls to `right_hand_split`.
 
             The calls increasingly `omit` right-hand trailers (bracket pairs with
@@ -366,6 +369,12 @@ def transform_line(
                 line, line_length=mode.line_length, features=features
             )
 
+        # HACK: nested functions (like _rhs) compiled by mypyc don't retain their
+        # __name__ attribute which is needed in `run_transformer` further down.
+        # Unfortunately a nested class breaks mypyc too. So a class must be created
+        # via type ... https://github.com/mypyc/mypyc/issues/884
+        rhs = type("rhs", (), {"__call__": _rhs})()
+
         if mode.experimental_string_processing:
             if line.inside_brackets:
                 transformers = [
@@ -980,7 +989,7 @@ def run_transformer(
         result.extend(transform_line(transformed_line, mode=mode, features=features))
 
     if (
-        transform.__name__ != "rhs"
+        transform.__class__.__name__ != "rhs"
         or not line.bracket_tracker.invisible
         or any(bracket.value for bracket in line.bracket_tracker.invisible)
         or line.contains_multiline_strings()
index b24c9c60dedc568758abdea95d086ba4c0ad0833..3c167569498172daa48560f863059da02fa2ba9b 100644 (file)
@@ -6,6 +6,7 @@ chosen by the user.
 
 from dataclasses import dataclass, field
 from enum import Enum
+from operator import attrgetter
 from typing import Dict, Set
 
 from black.const import DEFAULT_LINE_LENGTH
@@ -134,7 +135,7 @@ class Mode:
         if self.target_versions:
             version_str = ",".join(
                 str(version.value)
-                for version in sorted(self.target_versions, key=lambda v: v.value)
+                for version in sorted(self.target_versions, key=attrgetter("value"))
             )
         else:
             version_str = "-"
index 8f2e15b2cc3013926bc0636c05995749645f8b20..36dd18905114506ec0cac5c920dea482b02899b0 100644 (file)
@@ -15,10 +15,12 @@ from typing import (
     Union,
 )
 
-if sys.version_info < (3, 8):
-    from typing_extensions import Final
-else:
+if sys.version_info >= (3, 8):
     from typing import Final
+else:
+    from typing_extensions import Final
+
+from mypy_extensions import mypyc_attr
 
 # lib2to3 fork
 from blib2to3.pytree import Node, Leaf, type_repr
@@ -30,7 +32,7 @@ from black.strings import has_triple_quotes
 
 
 pygram.initialize(CACHE_DIR)
-syms = pygram.python_symbols
+syms: Final = pygram.python_symbols
 
 
 # types
@@ -128,16 +130,21 @@ ASSIGNMENTS: Final = {
     "//=",
 }
 
-IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
-BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
-OPENING_BRACKETS = set(BRACKET.keys())
-CLOSING_BRACKETS = set(BRACKET.values())
-BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
-ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
+IMPLICIT_TUPLE: Final = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
+BRACKET: Final = {
+    token.LPAR: token.RPAR,
+    token.LSQB: token.RSQB,
+    token.LBRACE: token.RBRACE,
+}
+OPENING_BRACKETS: Final = set(BRACKET.keys())
+CLOSING_BRACKETS: Final = set(BRACKET.values())
+BRACKETS: Final = OPENING_BRACKETS | CLOSING_BRACKETS
+ALWAYS_NO_SPACE: Final = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
 
 RARROW = 55
 
 
+@mypyc_attr(allow_interpreted_subclasses=True)
 class Visitor(Generic[T]):
     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
 
@@ -178,9 +185,9 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
     `complex_subscript` signals whether the given leaf is part of a subscription
     which has non-trivial arguments, like arithmetic expressions or function calls.
     """
-    NO = ""
-    SPACE = " "
-    DOUBLESPACE = "  "
+    NO: Final = ""
+    SPACE: Final = " "
+    DOUBLESPACE: Final = "  "
     t = leaf.type
     p = leaf.parent
     v = leaf.value
@@ -441,8 +448,8 @@ def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> b
 
 def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]:
     """Return (penultimate, last) leaves skipping brackets in `omit` and contents."""
-    stop_after = None
-    last = None
+    stop_after: Optional[Leaf] = None
+    last: Optional[Leaf] = None
     for leaf in reversed(leaves):
         if stop_after:
             if leaf is stop_after:
index fd3dbb376279497c2fc23f6eb7f03d6db96c4fc7..c85b253c159c885c2fd2e04f9691fdec279fa736 100644 (file)
@@ -11,6 +11,7 @@ import tempfile
 from click import echo, style
 
 
+@mypyc_attr(patchable=True)
 def _out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
     if message is not None:
         if "bold" not in styles:
@@ -19,6 +20,7 @@ def _out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
     echo(message, nl=nl, err=True)
 
 
+@mypyc_attr(patchable=True)
 def _err(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
     if message is not None:
         if "fg" not in styles:
@@ -27,6 +29,7 @@ def _err(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
     echo(message, nl=nl, err=True)
 
 
+@mypyc_attr(patchable=True)
 def out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
     _out(message, nl=nl, **styles)
 
index ee6aae1e7ff9db26be2f5f4c439a25d31baf2b7d..504e20be0022d890759e2f3f081987b60131e349 100644 (file)
@@ -4,11 +4,16 @@ Parse Python code and perform AST validation.
 import ast
 import platform
 import sys
-from typing import Iterable, Iterator, List, Set, Union, Tuple
+from typing import Any, Iterable, Iterator, List, Set, Tuple, Type, Union
+
+if sys.version_info < (3, 8):
+    from typing_extensions import Final
+else:
+    from typing import Final
 
 # lib2to3 fork
 from blib2to3.pytree import Node, Leaf
-from blib2to3 import pygram, pytree
+from blib2to3 import pygram
 from blib2to3.pgen2 import driver
 from blib2to3.pgen2.grammar import Grammar
 from blib2to3.pgen2.parse import ParseError
@@ -16,6 +21,9 @@ from blib2to3.pgen2.parse import ParseError
 from black.mode import TargetVersion, Feature, supports_feature
 from black.nodes import syms
 
+ast3: Any
+ast27: Any
+
 _IS_PYPY = platform.python_implementation() == "PyPy"
 
 try:
@@ -86,7 +94,7 @@ def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -
         src_txt += "\n"
 
     for grammar in get_grammars(set(target_versions)):
-        drv = driver.Driver(grammar, pytree.convert)
+        drv = driver.Driver(grammar)
         try:
             result = drv.parse_string(src_txt, True)
             break
@@ -148,6 +156,10 @@ def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
     raise SyntaxError(first_error)
 
 
+ast3_AST: Final[Type[ast3.AST]] = ast3.AST
+ast27_AST: Final[Type[ast27.AST]] = ast27.AST
+
+
 def stringify_ast(
     node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
 ) -> Iterator[str]:
@@ -189,7 +201,13 @@ def stringify_ast(
                 elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
                     yield from stringify_ast(item, depth + 2)
 
-        elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
+        # Note that we are referencing the typed-ast ASTs via global variables and not
+        # direct module attribute accesses because that breaks mypyc. It's probably
+        # something to do with the ast3 / ast27 variables being marked as Any leading
+        # mypy to think this branch is always taken, leaving the rest of the code
+        # unanalyzed. Tighting up the types for the typed-ast AST types avoids the
+        # mypyc crash.
+        elif isinstance(value, (ast.AST, ast3_AST, ast27_AST)):
             yield from stringify_ast(value, depth + 2)
 
         else:
index d7b6c240e80215ef1c268634433074d82e4c5aeb..97debe3b5de08569e24f5ca8abcdbbfe4e258d21 100644 (file)
@@ -4,10 +4,20 @@ Simple formatting on strings. Further string formatting code is in trans.py.
 
 import regex as re
 import sys
+from functools import lru_cache
 from typing import List, Pattern
 
+if sys.version_info < (3, 8):
+    from typing_extensions import Final
+else:
+    from typing import Final
 
-STRING_PREFIX_CHARS = "furbFURB"  # All possible string prefix characters.
+
+STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
+STRING_PREFIX_RE: Final = re.compile(
+    r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", re.DOTALL
+)
+FIRST_NON_WHITESPACE_RE: Final = re.compile(r"\s*\t+\s*(\S)")
 
 
 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
@@ -37,7 +47,7 @@ def lines_with_leading_tabs_expanded(s: str) -> List[str]:
     for line in s.splitlines():
         # Find the index of the first non-whitespace character after a string of
         # whitespace that includes at least one tab
-        match = re.match(r"\s*\t+\s*(\S)", line)
+        match = FIRST_NON_WHITESPACE_RE.match(line)
         if match:
             first_non_whitespace_idx = match.start(1)
 
@@ -133,7 +143,7 @@ def normalize_string_prefix(s: str, remove_u_prefix: bool = False) -> str:
 
     If remove_u_prefix is given, also removes any u prefix from the string.
     """
-    match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", s, re.DOTALL)
+    match = STRING_PREFIX_RE.match(s)
     assert match is not None, f"failed to match string {s!r}"
     orig_prefix = match.group(1)
     new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
@@ -142,6 +152,14 @@ def normalize_string_prefix(s: str, remove_u_prefix: bool = False) -> str:
     return f"{new_prefix}{match.group(2)}"
 
 
+# Re(gex) does actually cache patterns internally but this still improves
+# performance on a long list literal of strings by 5-9% since lru_cache's
+# caching overhead is much lower.
+@lru_cache(maxsize=64)
+def _cached_compile(pattern: str) -> re.Pattern:
+    return re.compile(pattern)
+
+
 def normalize_string_quotes(s: str) -> str:
     """Prefer double quotes but only if it doesn't cause more escaping.
 
@@ -166,9 +184,9 @@ def normalize_string_quotes(s: str) -> str:
         return s  # There's an internal error
 
     prefix = s[:first_quote_pos]
-    unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
-    escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
-    escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
+    unescaped_new_quote = _cached_compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
+    escaped_new_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
+    escaped_orig_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
     body = s[first_quote_pos + len(orig_quote) : -len(orig_quote)]
     if "r" in prefix.casefold():
         if unescaped_new_quote.search(body):
index 023dcd3618a5c0edb6b6f5e7e91e09bb6d3add30..d918ef111a21993b517f10b3ba9c30fad05e1a8a 100644 (file)
@@ -8,6 +8,7 @@ import regex as re
 from typing import (
     Any,
     Callable,
+    ClassVar,
     Collection,
     Dict,
     Iterable,
@@ -20,6 +21,14 @@ from typing import (
     TypeVar,
     Union,
 )
+import sys
+
+if sys.version_info < (3, 8):
+    from typing_extensions import Final
+else:
+    from typing import Final
+
+from mypy_extensions import trait
 
 from black.rusty import Result, Ok, Err
 
@@ -62,7 +71,6 @@ def TErr(err_msg: str) -> Err[CannotTransform]:
     return Err(cant_transform)
 
 
-@dataclass  # type: ignore
 class StringTransformer(ABC):
     """
     An implementation of the Transformer protocol that relies on its
@@ -90,9 +98,13 @@ class StringTransformer(ABC):
         as much as possible.
     """
 
-    line_length: int
-    normalize_strings: bool
-    __name__ = "StringTransformer"
+    __name__: Final = "StringTransformer"
+
+    # Ideally this would be a dataclass, but unfortunately mypyc breaks when used with
+    # `abc.ABC`.
+    def __init__(self, line_length: int, normalize_strings: bool) -> None:
+        self.line_length = line_length
+        self.normalize_strings = normalize_strings
 
     @abstractmethod
     def do_match(self, line: Line) -> TMatchResult:
@@ -184,6 +196,7 @@ class CustomSplit:
     break_idx: int
 
 
+@trait
 class CustomSplitMapMixin:
     """
     This mixin class is used to map merged strings to a sequence of
@@ -191,8 +204,10 @@ class CustomSplitMapMixin:
     the resultant substrings go over the configured max line length.
     """
 
-    _Key = Tuple[StringID, str]
-    _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
+    _Key: ClassVar = Tuple[StringID, str]
+    _CUSTOM_SPLIT_MAP: ClassVar[Dict[_Key, Tuple[CustomSplit, ...]]] = defaultdict(
+        tuple
+    )
 
     @staticmethod
     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
@@ -243,7 +258,7 @@ class CustomSplitMapMixin:
         return key in self._CUSTOM_SPLIT_MAP
 
 
-class StringMerger(CustomSplitMapMixin, StringTransformer):
+class StringMerger(StringTransformer, CustomSplitMapMixin):
     """StringTransformer that merges strings together.
 
     Requirements:
@@ -739,7 +754,7 @@ class BaseStringSplitter(StringTransformer):
         * The target string is not a multiline (i.e. triple-quote) string.
     """
 
-    STRING_OPERATORS = [
+    STRING_OPERATORS: Final = [
         token.EQEQUAL,
         token.GREATER,
         token.GREATEREQUAL,
@@ -927,7 +942,7 @@ class BaseStringSplitter(StringTransformer):
         return max_string_length
 
 
-class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
+class StringSplitter(BaseStringSplitter, CustomSplitMapMixin):
     """
     StringTransformer that splits "atom" strings (i.e. strings which exist on
     lines by themselves).
@@ -965,9 +980,9 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         CustomSplit objects and add them to the custom split map.
     """
 
-    MIN_SUBSTR_SIZE = 6
+    MIN_SUBSTR_SIZE: Final = 6
     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
-    RE_FEXPR = r"""
+    RE_FEXPR: Final = r"""
     (?<!\{) (?:\{\{)* \{ (?!\{)
         (?:
             [^\{\}]
@@ -1426,7 +1441,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         return string_op_leaves
 
 
-class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
+class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin):
     """
     StringTransformer that splits non-"atom" strings (i.e. strings that do not
     exist on lines by themselves).
@@ -1811,20 +1826,20 @@ class StringParser:
         ```
     """
 
-    DEFAULT_TOKEN = -1
+    DEFAULT_TOKEN: Final = 20210605
 
     # String Parser States
-    START = 1
-    DOT = 2
-    NAME = 3
-    PERCENT = 4
-    SINGLE_FMT_ARG = 5
-    LPAR = 6
-    RPAR = 7
-    DONE = 8
+    START: Final = 1
+    DOT: Final = 2
+    NAME: Final = 3
+    PERCENT: Final = 4
+    SINGLE_FMT_ARG: Final = 5
+    LPAR: Final = 6
+    RPAR: Final = 7
+    DONE: Final = 8
 
     # Lookup Table for Next State
-    _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
+    _goto: Final[Dict[Tuple[ParserState, NodeType], ParserState]] = {
         # A string trailer may start with '.' OR '%'.
         (START, token.DOT): DOT,
         (START, token.PERCENT): PERCENT,
index 2395d35886a7cc8910c71b9171c3d1b7de131bb5..8524b59a6322026b33fc77064ba9db91e1e19d42 100644 (file)
@@ -104,13 +104,12 @@ async def async_main(
             no_diff,
         )
         return int(ret_val)
+
     finally:
         if not keep and work_path.exists():
             LOG.debug(f"Removing {work_path}")
             rmtree(work_path, onerror=lib.handle_PermissionError)
 
-    return -2
-
 
 @click.command(context_settings={"help_option_names": ["-h", "--help"]})
 @click.option(
index ccad28337b697b3344bf323d2e7409784ad9e5b0..0d3c607c9c7c1723bee89337652ba51bc8f800f0 100644 (file)
@@ -19,3 +19,5 @@ Change Log:
     https://github.com/python/cpython/commit/cae60187cf7a7b26281d012e1952fafe4e2e97e9
   - "bpo-42316: Allow unparenthesized walrus operator in indexes (GH-23317)"
     https://github.com/python/cpython/commit/b0aba1fcdc3da952698d99aec2334faa79a8b68c
+- Tweaks to help mypyc compile faster code (including inlining type information,
+  "Final-ing", etc.)
index 5edd75b1333991d692a8f57ac84101b1a609e916..8fe820651da728ac423fb8823753f829b2a9abd7 100644 (file)
@@ -23,6 +23,7 @@ import pkgutil
 import sys
 from typing import (
     Any,
+    cast,
     IO,
     Iterable,
     List,
@@ -34,14 +35,15 @@ from typing import (
     Generic,
     Union,
 )
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 
 # Pgen imports
 from . import grammar, parse, token, tokenize, pgen
 from logging import Logger
-from blib2to3.pytree import _Convert, NL
+from blib2to3.pytree import NL
 from blib2to3.pgen2.grammar import Grammar
-from contextlib import contextmanager
+from blib2to3.pgen2.tokenize import GoodTokenInfo
 
 Path = Union[str, "os.PathLike[str]"]
 
@@ -115,29 +117,23 @@ class TokenProxy:
 
 
 class Driver(object):
-    def __init__(
-        self,
-        grammar: Grammar,
-        convert: Optional[_Convert] = None,
-        logger: Optional[Logger] = None,
-    ) -> None:
+    def __init__(self, grammar: Grammar, logger: Optional[Logger] = None) -> None:
         self.grammar = grammar
         if logger is None:
             logger = logging.getLogger(__name__)
         self.logger = logger
-        self.convert = convert
 
-    def parse_tokens(self, tokens: Iterable[Any], debug: bool = False) -> NL:
+    def parse_tokens(self, tokens: Iterable[GoodTokenInfo], debug: bool = False) -> NL:
         """Parse a series of tokens and return the syntax tree."""
         # XXX Move the prefix computation into a wrapper around tokenize.
         proxy = TokenProxy(tokens)
 
-        p = parse.Parser(self.grammar, self.convert)
+        p = parse.Parser(self.grammar)
         p.setup(proxy=proxy)
 
         lineno = 1
         column = 0
-        indent_columns = []
+        indent_columns: List[int] = []
         type = value = start = end = line_text = None
         prefix = ""
 
@@ -163,6 +159,7 @@ class Driver(object):
             if type == token.OP:
                 type = grammar.opmap[value]
             if debug:
+                assert type is not None
                 self.logger.debug(
                     "%s %r (prefix=%r)", token.tok_name[type], value, prefix
                 )
@@ -174,7 +171,7 @@ class Driver(object):
             elif type == token.DEDENT:
                 _indent_col = indent_columns.pop()
                 prefix, _prefix = self._partially_consume_prefix(prefix, _indent_col)
-            if p.addtoken(type, value, (prefix, start)):
+            if p.addtoken(cast(int, type), value, (prefix, start)):
                 if debug:
                     self.logger.debug("Stop.")
                 break
index dc405264bad454a92469cb613b6dd962c394e837..792e8e66698246f578fef7ce88d7bf69258af3e9 100644 (file)
@@ -29,7 +29,7 @@ from typing import (
     TYPE_CHECKING,
 )
 from blib2to3.pgen2.grammar import Grammar
-from blib2to3.pytree import NL, Context, RawNode, Leaf, Node
+from blib2to3.pytree import convert, NL, Context, RawNode, Leaf, Node
 
 if TYPE_CHECKING:
     from blib2to3.driver import TokenProxy
@@ -70,9 +70,7 @@ class Recorder:
         finally:
             self.parser.stack = self._start_point
 
-    def add_token(
-        self, tok_type: int, tok_val: Optional[Text], raw: bool = False
-    ) -> None:
+    def add_token(self, tok_type: int, tok_val: Text, raw: bool = False) -> None:
         func: Callable[..., Any]
         if raw:
             func = self.parser._addtoken
@@ -86,9 +84,7 @@ class Recorder:
                     args.insert(0, ilabel)
                 func(*args)
 
-    def determine_route(
-        self, value: Optional[Text] = None, force: bool = False
-    ) -> Optional[int]:
+    def determine_route(self, value: Text = None, force: bool = False) -> Optional[int]:
         alive_ilabels = self.ilabels
         if len(alive_ilabels) == 0:
             *_, most_successful_ilabel = self._dead_ilabels
@@ -164,6 +160,11 @@ class Parser(object):
         to be converted.  The syntax tree is converted from the bottom
         up.
 
+        **post-note: the convert argument is ignored since for Black's
+        usage, convert will always be blib2to3.pytree.convert. Allowing
+        this to be dynamic hurts mypyc's ability to use early binding.
+        These docs are left for historical and informational value.
+
         A concrete syntax tree node is a (type, value, context, nodes)
         tuple, where type is the node type (a token or symbol number),
         value is None for symbols and a string for tokens, context is
@@ -176,6 +177,7 @@ class Parser(object):
 
         """
         self.grammar = grammar
+        # See note in docstring above. TL;DR this is ignored.
         self.convert = convert or lam_sub
 
     def setup(self, proxy: "TokenProxy", start: Optional[int] = None) -> None:
@@ -203,7 +205,7 @@ class Parser(object):
         self.used_names: Set[str] = set()
         self.proxy = proxy
 
-    def addtoken(self, type: int, value: Optional[Text], context: Context) -> bool:
+    def addtoken(self, type: int, value: Text, context: Context) -> bool:
         """Add a token; return True iff this is the end of the program."""
         # Map from token to label
         ilabels = self.classify(type, value, context)
@@ -237,7 +239,7 @@ class Parser(object):
 
                 next_token_type, next_token_value, *_ = proxy.eat(counter)
                 if next_token_type == tokenize.OP:
-                    next_token_type = grammar.opmap[cast(str, next_token_value)]
+                    next_token_type = grammar.opmap[next_token_value]
 
                 recorder.add_token(next_token_type, next_token_value)
                 counter += 1
@@ -247,9 +249,7 @@ class Parser(object):
 
         return self._addtoken(ilabel, type, value, context)
 
-    def _addtoken(
-        self, ilabel: int, type: int, value: Optional[Text], context: Context
-    ) -> bool:
+    def _addtoken(self, ilabel: int, type: int, value: Text, context: Context) -> bool:
         # Loop until the token is shifted; may raise exceptions
         while True:
             dfa, state, node = self.stack[-1]
@@ -257,10 +257,18 @@ class Parser(object):
             arcs = states[state]
             # Look for a state with this label
             for i, newstate in arcs:
-                t, v = self.grammar.labels[i]
-                if ilabel == i:
+                t = self.grammar.labels[i][0]
+                if t >= 256:
+                    # See if it's a symbol and if we're in its first set
+                    itsdfa = self.grammar.dfas[t]
+                    itsstates, itsfirst = itsdfa
+                    if ilabel in itsfirst:
+                        # Push a symbol
+                        self.push(t, itsdfa, newstate, context)
+                        break  # To continue the outer while loop
+
+                elif ilabel == i:
                     # Look it up in the list of labels
-                    assert t < 256
                     # Shift a token; we're done with it
                     self.shift(type, value, newstate, context)
                     # Pop while we are in an accept-only state
@@ -274,14 +282,7 @@ class Parser(object):
                         states, first = dfa
                     # Done with this token
                     return False
-                elif t >= 256:
-                    # See if it's a symbol and if we're in its first set
-                    itsdfa = self.grammar.dfas[t]
-                    itsstates, itsfirst = itsdfa
-                    if ilabel in itsfirst:
-                        # Push a symbol
-                        self.push(t, self.grammar.dfas[t], newstate, context)
-                        break  # To continue the outer while loop
+
             else:
                 if (0, state) in arcs:
                     # An accepting state, pop it and try something else
@@ -293,14 +294,13 @@ class Parser(object):
                     # No success finding a transition
                     raise ParseError("bad input", type, value, context)
 
-    def classify(self, type: int, value: Optional[Text], context: Context) -> List[int]:
+    def classify(self, type: int, value: Text, context: Context) -> List[int]:
         """Turn a token into a label.  (Internal)
 
         Depending on whether the value is a soft-keyword or not,
         this function may return multiple labels to choose from."""
         if type == token.NAME:
             # Keep a listing of all used names
-            assert value is not None
             self.used_names.add(value)
             # Check for reserved words
             if value in self.grammar.keywords:
@@ -317,18 +317,13 @@ class Parser(object):
             raise ParseError("bad token", type, value, context)
         return [ilabel]
 
-    def shift(
-        self, type: int, value: Optional[Text], newstate: int, context: Context
-    ) -> None:
+    def shift(self, type: int, value: Text, newstate: int, context: Context) -> None:
         """Shift a token.  (Internal)"""
         dfa, state, node = self.stack[-1]
-        assert value is not None
-        assert context is not None
         rawnode: RawNode = (type, value, context, None)
-        newnode = self.convert(self.grammar, rawnode)
-        if newnode is not None:
-            assert node[-1] is not None
-            node[-1].append(newnode)
+        newnode = convert(self.grammar, rawnode)
+        assert node[-1] is not None
+        node[-1].append(newnode)
         self.stack[-1] = (dfa, newstate, node)
 
     def push(self, type: int, newdfa: DFAS, newstate: int, context: Context) -> None:
@@ -341,12 +336,11 @@ class Parser(object):
     def pop(self) -> None:
         """Pop a nonterminal.  (Internal)"""
         popdfa, popstate, popnode = self.stack.pop()
-        newnode = self.convert(self.grammar, popnode)
-        if newnode is not None:
-            if self.stack:
-                dfa, state, node = self.stack[-1]
-                assert node[-1] is not None
-                node[-1].append(newnode)
-            else:
-                self.rootnode = newnode
-                self.rootnode.used_names = self.used_names
+        newnode = convert(self.grammar, popnode)
+        if self.stack:
+            dfa, state, node = self.stack[-1]
+            assert node[-1] is not None
+            node[-1].append(newnode)
+        else:
+            self.rootnode = newnode
+            self.rootnode.used_names = self.used_names
index bad79b2dc2c73d4ecac36b4a573c0d8bc42f9270..283fac2d5375a8f9de49879447fb80bf622c8111 100644 (file)
@@ -27,6 +27,7 @@ are the same, except instead of generating tokens, tokeneater is a callback
 function to which the 5 fields described above are passed as 5 arguments,
 each time a new token is found."""
 
+import sys
 from typing import (
     Callable,
     Iterable,
@@ -39,6 +40,12 @@ from typing import (
     Union,
     cast,
 )
+
+if sys.version_info >= (3, 8):
+    from typing import Final
+else:
+    from typing_extensions import Final
+
 from blib2to3.pgen2.token import *
 from blib2to3.pgen2.grammar import Grammar
 
@@ -139,7 +146,7 @@ ContStr = group(
 PseudoExtras = group(r"\\\r?\n", Comment, Triple)
 PseudoToken = Whitespace + group(PseudoExtras, Number, Funny, ContStr, Name)
 
-pseudoprog = re.compile(PseudoToken, re.UNICODE)
+pseudoprog: Final = re.compile(PseudoToken, re.UNICODE)
 single3prog = re.compile(Single3)
 double3prog = re.compile(Double3)
 
@@ -149,7 +156,7 @@ _strprefixes = (
     | {"u", "U", "ur", "uR", "Ur", "UR"}
 )
 
-endprogs = {
+endprogs: Final = {
     "'": re.compile(Single),
     '"': re.compile(Double),
     "'''": single3prog,
@@ -159,12 +166,12 @@ endprogs = {
     **{prefix: None for prefix in _strprefixes},
 }
 
-triple_quoted = (
+triple_quoted: Final = (
     {"'''", '"""'}
     | {f"{prefix}'''" for prefix in _strprefixes}
     | {f'{prefix}"""' for prefix in _strprefixes}
 )
-single_quoted = (
+single_quoted: Final = (
     {"'", '"'}
     | {f"{prefix}'" for prefix in _strprefixes}
     | {f'{prefix}"' for prefix in _strprefixes}
@@ -418,7 +425,7 @@ def generate_tokens(
     logical line; continuation lines are included.
     """
     lnum = parenlev = continued = 0
-    numchars = "0123456789"
+    numchars: Final = "0123456789"
     contstr, needcont = "", 0
     contline: Optional[str] = None
     indents = [0]
@@ -427,7 +434,7 @@ def generate_tokens(
     # `await` as keywords.
     async_keywords = False if grammar is None else grammar.async_keywords
     # 'stashed' and 'async_*' are used for async/await parsing
-    stashed = None
+    stashed: Optional[GoodTokenInfo] = None
     async_def = False
     async_def_indent = 0
     async_def_nl = False
@@ -440,7 +447,7 @@ def generate_tokens(
             line = readline()
         except StopIteration:
             line = ""
-        lnum = lnum + 1
+        lnum += 1
         pos, max = 0, len(line)
 
         if contstr:  # continued string
@@ -481,14 +488,14 @@ def generate_tokens(
             column = 0
             while pos < max:  # measure leading whitespace
                 if line[pos] == " ":
-                    column = column + 1
+                    column += 1
                 elif line[pos] == "\t":
                     column = (column // tabsize + 1) * tabsize
                 elif line[pos] == "\f":
                     column = 0
                 else:
                     break
-                pos = pos + 1
+                pos += 1
             if pos == max:
                 break
 
@@ -507,7 +514,7 @@ def generate_tokens(
                     COMMENT,
                     comment_token,
                     (lnum, pos),
-                    (lnum, pos + len(comment_token)),
+                    (lnum, nl_pos),
                     line,
                 )
                 yield (NL, line[nl_pos:], (lnum, nl_pos), (lnum, len(line)), line)
@@ -652,16 +659,16 @@ def generate_tokens(
                     continued = 1
                 else:
                     if initial in "([{":
-                        parenlev = parenlev + 1
+                        parenlev += 1
                     elif initial in ")]}":
-                        parenlev = parenlev - 1
+                        parenlev -= 1
                     if stashed:
                         yield stashed
                         stashed = None
                     yield (OP, token, spos, epos, line)
             else:
                 yield (ERRORTOKEN, line[pos], (lnum, pos), (lnum, pos + 1), line)
-                pos = pos + 1
+                pos += 1
 
     if stashed:
         yield stashed
index 001652df09fb85d623ede558f75a2ea102549e4e..bd86270b8e232d1230fde477b38ff4ca2836d710 100644 (file)
@@ -14,7 +14,6 @@ There's also a pattern matching implementation here.
 
 from typing import (
     Any,
-    Callable,
     Dict,
     Iterator,
     List,
@@ -92,8 +91,6 @@ class Base(object):
             return NotImplemented
         return self._eq(other)
 
-    __hash__ = None  # type: Any  # For Py3 compatibility.
-
     @property
     def prefix(self) -> Text:
         raise NotImplementedError
@@ -437,7 +434,7 @@ class Leaf(Base):
 
         This reproduces the input source exactly.
         """
-        return self.prefix + str(self.value)
+        return self._prefix + str(self.value)
 
     def _eq(self, other) -> bool:
         """Compare two nodes for equality."""
@@ -672,8 +669,11 @@ class NodePattern(BasePattern):
             newcontent = list(content)
             for i, item in enumerate(newcontent):
                 assert isinstance(item, BasePattern), (i, item)
-                if isinstance(item, WildcardPattern):
-                    self.wildcards = True
+                # I don't even think this code is used anywhere, but it does cause
+                # unreachable errors from mypy. This function's signature does look
+                # odd though *shrug*.
+                if isinstance(item, WildcardPattern):  # type: ignore[unreachable]
+                    self.wildcards = True  # type: ignore[unreachable]
         self.type = type
         self.content = newcontent
         self.name = name
@@ -978,6 +978,3 @@ def generate_matches(
                     r.update(r0)
                     r.update(r1)
                     yield c0 + c1, r
-
-
-_Convert = Callable[[Grammar, RawNode], Any]
index 301a3a5b3637a3b2a0c43031b43269e238bfa72e..3d5d39828173657f708f372219757871baa03432 100644 (file)
@@ -122,7 +122,7 @@ def invokeBlack(
     runner = BlackRunner()
     if ignore_config:
         args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
-    result = runner.invoke(black.main, args)
+    result = runner.invoke(black.main, args, catch_exceptions=False)
     assert result.stdout_bytes is not None
     assert result.stderr_bytes is not None
     msg = (
@@ -841,6 +841,7 @@ class BlackTestCase(BlackBaseTestCase):
         )
         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
 
+    @pytest.mark.incompatible_with_mypyc
     def test_debug_visitor(self) -> None:
         source, _ = read_data("debug_visitor.py")
         expected, _ = read_data("debug_visitor.out")
@@ -891,6 +892,7 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertEqual(len(n.children), 1)
         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
 
+    @pytest.mark.incompatible_with_mypyc
     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
     def test_assertFormatEqual(self) -> None:
         out_lines = []
@@ -1055,6 +1057,7 @@ class BlackTestCase(BlackBaseTestCase):
         actual = result.output
         self.assertFormatEqual(actual, expected)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1072,6 +1075,7 @@ class BlackTestCase(BlackBaseTestCase):
             fsts.assert_called_once()
             report.done.assert_called_with(path, black.Changed.YES)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin_filename(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1094,6 +1098,7 @@ class BlackTestCase(BlackBaseTestCase):
             # __BLACK_STDIN_FILENAME__ should have been stripped
             report.done.assert_called_with(expected, black.Changed.YES)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1118,6 +1123,7 @@ class BlackTestCase(BlackBaseTestCase):
             # __BLACK_STDIN_FILENAME__ should have been stripped
             report.done.assert_called_with(expected, black.Changed.YES)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1142,6 +1148,7 @@ class BlackTestCase(BlackBaseTestCase):
             # __BLACK_STDIN_FILENAME__ should have been stripped
             report.done.assert_called_with(expected, black.Changed.YES)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1296,6 +1303,7 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertEqual(config["exclude"], r"\.pyi?$")
         self.assertEqual(config["include"], r"\.py?$")
 
+    @pytest.mark.incompatible_with_mypyc
     def test_find_project_root(self) -> None:
         with TemporaryDirectory() as workspace:
             root = Path(workspace)
@@ -1483,6 +1491,7 @@ class BlackTestCase(BlackBaseTestCase):
         assert output == result_diff, "The output did not match the expected value."
         assert result.exit_code == 0, "The exit code is incorrect."
 
+    @pytest.mark.incompatible_with_mypyc
     def test_code_option_safe(self) -> None:
         """Test that the code option throws an error when the sanity checks fail."""
         # Patch black.assert_equivalent to ensure the sanity checks fail
@@ -1507,6 +1516,7 @@ class BlackTestCase(BlackBaseTestCase):
 
             self.compare_results(result, formatted, 0)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_code_option_config(self) -> None:
         """
         Test that the code option finds the pyproject.toml in the current directory.
@@ -1527,6 +1537,7 @@ class BlackTestCase(BlackBaseTestCase):
                 call_args[0].lower() == str(pyproject_path).lower()
             ), "Incorrect config loaded."
 
+    @pytest.mark.incompatible_with_mypyc
     def test_code_option_parent_config(self) -> None:
         """
         Test that the code option finds the pyproject.toml in the parent directory.
@@ -1894,6 +1905,7 @@ class TestFileCollection:
             src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
         )
 
+    @pytest.mark.incompatible_with_mypyc
     def test_symlink_out_of_root_directory(self) -> None:
         path = MagicMock()
         root = THIS_DIR.resolve()
@@ -2047,8 +2059,12 @@ def test_python_2_deprecation_autodetection_extended() -> None:
         }, non_python2_case
 
 
-with open(black.__file__, "r", encoding="utf-8") as _bf:
-    black_source_lines = _bf.readlines()
+try:
+    with open(black.__file__, "r", encoding="utf-8") as _bf:
+        black_source_lines = _bf.readlines()
+except UnicodeDecodeError:
+    if not black.COMPILED:
+        raise
 
 
 def tracefunc(