From 117891878e5be4d6b771ae5de299e51b679cea27 Mon Sep 17 00:00:00 2001 From: Richard Si <63936253+ichard26@users.noreply.github.com> Date: Mon, 15 Nov 2021 23:24:16 -0500 Subject: [PATCH] Implementing mypyc support pt. 2 (#2431) --- mypy.ini | 10 +++- pyproject.toml | 3 ++ setup.py | 47 ++++++++++++++----- src/black/__init__.py | 23 +++++++-- src/black/brackets.py | 2 +- src/black/comments.py | 15 ++++-- src/black/files.py | 4 +- src/black/handle_ipynb_magics.py | 13 +++--- src/black/linegen.py | 25 ++++++---- src/black/mode.py | 3 +- src/black/nodes.py | 37 +++++++++------ src/black/output.py | 3 ++ src/black/parsing.py | 26 +++++++++-- src/black/strings.py | 30 +++++++++--- src/black/trans.py | 59 ++++++++++++++--------- src/black_primer/cli.py | 3 +- src/blib2to3/README | 2 + src/blib2to3/pgen2/driver.py | 23 ++++----- src/blib2to3/pgen2/parse.py | 80 +++++++++++++++----------------- src/blib2to3/pgen2/tokenize.py | 33 +++++++------ src/blib2to3/pytree.py | 15 +++--- tests/test_black.py | 22 +++++++-- 22 files changed, 310 insertions(+), 168 deletions(-) diff --git a/mypy.ini b/mypy.ini index 62c1c7f..cfceaa3 100644 --- a/mypy.ini +++ b/mypy.ini @@ -3,7 +3,6 @@ # free to run mypy on Windows, Linux, or macOS and get consistent # results. python_version=3.6 -platform=linux mypy_path=src @@ -24,6 +23,10 @@ warn_redundant_casts=True warn_unused_ignores=True disallow_any_generics=True +# Unreachable blocks have been an issue when compiling mypyc, let's try +# to avoid 'em in the first place. +warn_unreachable=True + # The following are off by default. Flip them on if you feel # adventurous. disallow_untyped_defs=True @@ -32,6 +35,11 @@ check_untyped_defs=True # No incremental mode cache_dir=/dev/null +[mypy-black] +# The following is because of `patch_click()`. Remove when +# we drop Python 3.6 support. +warn_unused_ignores=False + [mypy-black_primer.*] # Until we're not supporting 3.6 primer needs this disallow_any_generics=False diff --git a/pyproject.toml b/pyproject.toml index 73e1960..aebbc0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,3 +33,6 @@ optional-tests = [ "no_blackd: run when `d` extra NOT installed", "no_jupyter: run when `jupyter` extra NOT installed", ] +markers = [ + "incompatible_with_mypyc: run when testing mypyc compiled black" +] diff --git a/setup.py b/setup.py index 1914ba7..7022b24 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ import os assert sys.version_info >= (3, 6, 2), "black requires Python 3.6.2+" from pathlib import Path # noqa E402 +from typing import List # noqa: E402 CURRENT_DIR = Path(__file__).parent sys.path.insert(0, str(CURRENT_DIR)) # for setuptools.build_meta @@ -18,6 +19,17 @@ def get_long_description() -> str: ) +def find_python_files(base: Path) -> List[Path]: + files = [] + for entry in base.iterdir(): + if entry.is_file() and entry.suffix == ".py": + files.append(entry) + elif entry.is_dir(): + files.extend(find_python_files(entry)) + + return files + + USE_MYPYC = False # To compile with mypyc, a mypyc checkout must be present on the PYTHONPATH if len(sys.argv) > 1 and sys.argv[1] == "--use-mypyc": @@ -27,21 +39,34 @@ if os.getenv("BLACK_USE_MYPYC", None) == "1": USE_MYPYC = True if USE_MYPYC: + from mypyc.build import mypycify + + src = CURRENT_DIR / "src" + # TIP: filepaths are normalized to use forward slashes and are relative to ./src/ + # before being checked against. + blocklist = [ + # Not performance sensitive, so save bytes + compilation time: + "blib2to3/__init__.py", + "blib2to3/pgen2/__init__.py", + "black/output.py", + "black/concurrency.py", + "black/files.py", + "black/report.py", + # Breaks the test suite when compiled (and is also useless): + "black/debug.py", + # Compiled modules can't be run directly and that's a problem here: + "black/__main__.py", + ] + discovered = [] + # black-primer and blackd have no good reason to be compiled. + discovered.extend(find_python_files(src / "black")) + discovered.extend(find_python_files(src / "blib2to3")) mypyc_targets = [ - "src/black/__init__.py", - "src/blib2to3/pytree.py", - "src/blib2to3/pygram.py", - "src/blib2to3/pgen2/parse.py", - "src/blib2to3/pgen2/grammar.py", - "src/blib2to3/pgen2/token.py", - "src/blib2to3/pgen2/driver.py", - "src/blib2to3/pgen2/pgen.py", + str(p) for p in discovered if p.relative_to(src).as_posix() not in blocklist ] - from mypyc.build import mypycify - opt_level = os.getenv("MYPYC_OPT_LEVEL", "3") - ext_modules = mypycify(mypyc_targets, opt_level=opt_level) + ext_modules = mypycify(mypyc_targets, opt_level=opt_level, verbose=True) else: ext_modules = [] diff --git a/src/black/__init__.py b/src/black/__init__.py index ad4ee1a..a5ddec9 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -30,8 +30,9 @@ from typing import ( Union, ) -from dataclasses import replace import click +from dataclasses import replace +from mypy_extensions import mypyc_attr from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES from black.const import STDIN_PLACEHOLDER @@ -66,6 +67,8 @@ from blib2to3.pgen2 import token from _black_version import version as __version__ +COMPILED = Path(__file__).suffix in (".pyd", ".so") + # types FileContent = str Encoding = str @@ -177,7 +180,12 @@ def validate_regex( raise click.BadParameter("Not a valid regular expression") from None -@click.command(context_settings=dict(help_option_names=["-h", "--help"])) +@click.command( + context_settings=dict(help_option_names=["-h", "--help"]), + # While Click does set this field automatically using the docstring, mypyc + # (annoyingly) strips 'em so we need to set it here too. + help="The uncompromising code formatter.", +) @click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( "-l", @@ -346,7 +354,10 @@ def validate_regex( " due to exclusion patterns." ), ) -@click.version_option(version=__version__) +@click.version_option( + version=__version__, + message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})", +) @click.argument( "src", nargs=-1, @@ -387,7 +398,7 @@ def main( experimental_string_processing: bool, quiet: bool, verbose: bool, - required_version: str, + required_version: Optional[str], include: Pattern[str], exclude: Optional[Pattern[str]], extend_exclude: Optional[Pattern[str]], @@ -655,6 +666,9 @@ def reformat_one( report.failed(src, str(exc)) +# diff-shades depends on being to monkeypatch this function to operate. I know it's +# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26 +@mypyc_attr(patchable=True) def reformat_many( sources: Set[Path], fast: bool, @@ -669,6 +683,7 @@ def reformat_many( worker_count = workers if workers is not None else DEFAULT_WORKERS if sys.platform == "win32": # Work around https://bugs.python.org/issue26903 + assert worker_count is not None worker_count = min(worker_count, 60) try: executor = ProcessPoolExecutor(max_workers=worker_count) diff --git a/src/black/brackets.py b/src/black/brackets.py index bb865a0..c5ed4bf 100644 --- a/src/black/brackets.py +++ b/src/black/brackets.py @@ -49,7 +49,7 @@ MATH_PRIORITIES: Final = { DOT_PRIORITY: Final = 1 -class BracketMatchError(KeyError): +class BracketMatchError(Exception): """Raised when an opening bracket is unable to be matched to a closing bracket.""" diff --git a/src/black/comments.py b/src/black/comments.py index c7513c2..a8152d6 100644 --- a/src/black/comments.py +++ b/src/black/comments.py @@ -1,8 +1,14 @@ +import sys from dataclasses import dataclass from functools import lru_cache import regex as re from typing import Iterator, List, Optional, Union +if sys.version_info >= (3, 8): + from typing import Final +else: + from typing_extensions import Final + from blib2to3.pytree import Node, Leaf from blib2to3.pgen2 import token @@ -12,11 +18,10 @@ from black.nodes import STANDALONE_COMMENT, WHITESPACE # types LN = Union[Leaf, Node] - -FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"} -FMT_SKIP = {"# fmt: skip", "# fmt:skip"} -FMT_PASS = {*FMT_OFF, *FMT_SKIP} -FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"} +FMT_OFF: Final = {"# fmt: off", "# fmt:off", "# yapf: disable"} +FMT_SKIP: Final = {"# fmt: skip", "# fmt:skip"} +FMT_PASS: Final = {*FMT_OFF, *FMT_SKIP} +FMT_ON: Final = {"# fmt: on", "# fmt:on", "# yapf: enable"} @dataclass diff --git a/src/black/files.py b/src/black/files.py index 4d7b47a..560aa05 100644 --- a/src/black/files.py +++ b/src/black/files.py @@ -17,6 +17,7 @@ from typing import ( TYPE_CHECKING, ) +from mypy_extensions import mypyc_attr from pathspec import PathSpec from pathspec.patterns.gitwildmatch import GitWildMatchPatternError import tomli @@ -88,13 +89,14 @@ def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]: return None +@mypyc_attr(patchable=True) def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: """Parse a pyproject toml file, pulling out relevant parts for Black If parsing fails, will raise a tomli.TOMLDecodeError """ with open(path_config, encoding="utf8") as f: - pyproject_toml = tomli.load(f) # type: ignore # due to deprecated API usage + pyproject_toml = tomli.loads(f.read()) config = pyproject_toml.get("tool", {}).get("black", {}) return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} diff --git a/src/black/handle_ipynb_magics.py b/src/black/handle_ipynb_magics.py index f10eaed..2fe6739 100644 --- a/src/black/handle_ipynb_magics.py +++ b/src/black/handle_ipynb_magics.py @@ -333,7 +333,7 @@ class CellMagic: return f"%%{self.name}" -@dataclasses.dataclass +# ast.NodeVisitor + dataclass = breakage under mypyc. class CellMagicFinder(ast.NodeVisitor): """Find cell magics. @@ -352,7 +352,8 @@ class CellMagicFinder(ast.NodeVisitor): and we look for instances of the latter. """ - cell_magic: Optional[CellMagic] = None + def __init__(self, cell_magic: Optional[CellMagic] = None) -> None: + self.cell_magic = cell_magic def visit_Expr(self, node: ast.Expr) -> None: """Find cell magic, extract header and body.""" @@ -372,7 +373,8 @@ class OffsetAndMagic: magic: str -@dataclasses.dataclass +# Unsurprisingly, subclassing ast.NodeVisitor means we can't use dataclasses here +# as mypyc will generate broken code. class MagicFinder(ast.NodeVisitor): """Visit cell to look for get_ipython calls. @@ -392,9 +394,8 @@ class MagicFinder(ast.NodeVisitor): types of magics). """ - magics: Dict[int, List[OffsetAndMagic]] = dataclasses.field( - default_factory=lambda: collections.defaultdict(list) - ) + def __init__(self) -> None: + self.magics: Dict[int, List[OffsetAndMagic]] = collections.defaultdict(list) def visit_Assign(self, node: ast.Assign) -> None: """Look for system assign magics. diff --git a/src/black/linegen.py b/src/black/linegen.py index 8cf32c9..4cba416 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -5,8 +5,6 @@ from functools import partial, wraps import sys from typing import Collection, Iterator, List, Optional, Set, Union -from dataclasses import dataclass, field - from black.nodes import WHITESPACE, RARROW, STATEMENT, STANDALONE_COMMENT from black.nodes import ASSIGNMENTS, OPENING_BRACKETS, CLOSING_BRACKETS from black.nodes import Visitor, syms, first_child_is_arith, ensure_visible @@ -40,7 +38,8 @@ class CannotSplit(CannotTransform): """A readable split that fits the allotted line length is impossible.""" -@dataclass +# This isn't a dataclass because @dataclass + Generic breaks mypyc. +# See also https://github.com/mypyc/mypyc/issues/827. class LineGenerator(Visitor[Line]): """Generates reformatted Line objects. Empty lines are not emitted. @@ -48,9 +47,11 @@ class LineGenerator(Visitor[Line]): in ways that will no longer stringify to valid Python code on the tree. """ - mode: Mode - remove_u_prefix: bool = False - current_line: Line = field(init=False) + def __init__(self, mode: Mode, remove_u_prefix: bool = False) -> None: + self.mode = mode + self.remove_u_prefix = remove_u_prefix + self.current_line: Line + self.__post_init__() def line(self, indent: int = 0) -> Iterator[Line]: """Generate a line. @@ -339,7 +340,9 @@ def transform_line( transformers = [left_hand_split] else: - def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]: + def _rhs( + self: object, line: Line, features: Collection[Feature] + ) -> Iterator[Line]: """Wraps calls to `right_hand_split`. The calls increasingly `omit` right-hand trailers (bracket pairs with @@ -366,6 +369,12 @@ def transform_line( line, line_length=mode.line_length, features=features ) + # HACK: nested functions (like _rhs) compiled by mypyc don't retain their + # __name__ attribute which is needed in `run_transformer` further down. + # Unfortunately a nested class breaks mypyc too. So a class must be created + # via type ... https://github.com/mypyc/mypyc/issues/884 + rhs = type("rhs", (), {"__call__": _rhs})() + if mode.experimental_string_processing: if line.inside_brackets: transformers = [ @@ -980,7 +989,7 @@ def run_transformer( result.extend(transform_line(transformed_line, mode=mode, features=features)) if ( - transform.__name__ != "rhs" + transform.__class__.__name__ != "rhs" or not line.bracket_tracker.invisible or any(bracket.value for bracket in line.bracket_tracker.invisible) or line.contains_multiline_strings() diff --git a/src/black/mode.py b/src/black/mode.py index b24c9c6..3c16756 100644 --- a/src/black/mode.py +++ b/src/black/mode.py @@ -6,6 +6,7 @@ chosen by the user. from dataclasses import dataclass, field from enum import Enum +from operator import attrgetter from typing import Dict, Set from black.const import DEFAULT_LINE_LENGTH @@ -134,7 +135,7 @@ class Mode: if self.target_versions: version_str = ",".join( str(version.value) - for version in sorted(self.target_versions, key=lambda v: v.value) + for version in sorted(self.target_versions, key=attrgetter("value")) ) else: version_str = "-" diff --git a/src/black/nodes.py b/src/black/nodes.py index 8f2e15b..36dd189 100644 --- a/src/black/nodes.py +++ b/src/black/nodes.py @@ -15,10 +15,12 @@ from typing import ( Union, ) -if sys.version_info < (3, 8): - from typing_extensions import Final -else: +if sys.version_info >= (3, 8): from typing import Final +else: + from typing_extensions import Final + +from mypy_extensions import mypyc_attr # lib2to3 fork from blib2to3.pytree import Node, Leaf, type_repr @@ -30,7 +32,7 @@ from black.strings import has_triple_quotes pygram.initialize(CACHE_DIR) -syms = pygram.python_symbols +syms: Final = pygram.python_symbols # types @@ -128,16 +130,21 @@ ASSIGNMENTS: Final = { "//=", } -IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist} -BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE} -OPENING_BRACKETS = set(BRACKET.keys()) -CLOSING_BRACKETS = set(BRACKET.values()) -BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS -ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT} +IMPLICIT_TUPLE: Final = {syms.testlist, syms.testlist_star_expr, syms.exprlist} +BRACKET: Final = { + token.LPAR: token.RPAR, + token.LSQB: token.RSQB, + token.LBRACE: token.RBRACE, +} +OPENING_BRACKETS: Final = set(BRACKET.keys()) +CLOSING_BRACKETS: Final = set(BRACKET.values()) +BRACKETS: Final = OPENING_BRACKETS | CLOSING_BRACKETS +ALWAYS_NO_SPACE: Final = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT} RARROW = 55 +@mypyc_attr(allow_interpreted_subclasses=True) class Visitor(Generic[T]): """Basic lib2to3 visitor that yields things of type `T` on `visit()`.""" @@ -178,9 +185,9 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa: C901 `complex_subscript` signals whether the given leaf is part of a subscription which has non-trivial arguments, like arithmetic expressions or function calls. """ - NO = "" - SPACE = " " - DOUBLESPACE = " " + NO: Final = "" + SPACE: Final = " " + DOUBLESPACE: Final = " " t = leaf.type p = leaf.parent v = leaf.value @@ -441,8 +448,8 @@ def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> b def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]: """Return (penultimate, last) leaves skipping brackets in `omit` and contents.""" - stop_after = None - last = None + stop_after: Optional[Leaf] = None + last: Optional[Leaf] = None for leaf in reversed(leaves): if stop_after: if leaf is stop_after: diff --git a/src/black/output.py b/src/black/output.py index fd3dbb3..c85b253 100644 --- a/src/black/output.py +++ b/src/black/output.py @@ -11,6 +11,7 @@ import tempfile from click import echo, style +@mypyc_attr(patchable=True) def _out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None: if message is not None: if "bold" not in styles: @@ -19,6 +20,7 @@ def _out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None: echo(message, nl=nl, err=True) +@mypyc_attr(patchable=True) def _err(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None: if message is not None: if "fg" not in styles: @@ -27,6 +29,7 @@ def _err(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None: echo(message, nl=nl, err=True) +@mypyc_attr(patchable=True) def out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None: _out(message, nl=nl, **styles) diff --git a/src/black/parsing.py b/src/black/parsing.py index ee6aae1..504e20b 100644 --- a/src/black/parsing.py +++ b/src/black/parsing.py @@ -4,11 +4,16 @@ Parse Python code and perform AST validation. import ast import platform import sys -from typing import Iterable, Iterator, List, Set, Union, Tuple +from typing import Any, Iterable, Iterator, List, Set, Tuple, Type, Union + +if sys.version_info < (3, 8): + from typing_extensions import Final +else: + from typing import Final # lib2to3 fork from blib2to3.pytree import Node, Leaf -from blib2to3 import pygram, pytree +from blib2to3 import pygram from blib2to3.pgen2 import driver from blib2to3.pgen2.grammar import Grammar from blib2to3.pgen2.parse import ParseError @@ -16,6 +21,9 @@ from blib2to3.pgen2.parse import ParseError from black.mode import TargetVersion, Feature, supports_feature from black.nodes import syms +ast3: Any +ast27: Any + _IS_PYPY = platform.python_implementation() == "PyPy" try: @@ -86,7 +94,7 @@ def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) - src_txt += "\n" for grammar in get_grammars(set(target_versions)): - drv = driver.Driver(grammar, pytree.convert) + drv = driver.Driver(grammar) try: result = drv.parse_string(src_txt, True) break @@ -148,6 +156,10 @@ def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]: raise SyntaxError(first_error) +ast3_AST: Final[Type[ast3.AST]] = ast3.AST +ast27_AST: Final[Type[ast27.AST]] = ast27.AST + + def stringify_ast( node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0 ) -> Iterator[str]: @@ -189,7 +201,13 @@ def stringify_ast( elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)): yield from stringify_ast(item, depth + 2) - elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)): + # Note that we are referencing the typed-ast ASTs via global variables and not + # direct module attribute accesses because that breaks mypyc. It's probably + # something to do with the ast3 / ast27 variables being marked as Any leading + # mypy to think this branch is always taken, leaving the rest of the code + # unanalyzed. Tighting up the types for the typed-ast AST types avoids the + # mypyc crash. + elif isinstance(value, (ast.AST, ast3_AST, ast27_AST)): yield from stringify_ast(value, depth + 2) else: diff --git a/src/black/strings.py b/src/black/strings.py index d7b6c24..97debe3 100644 --- a/src/black/strings.py +++ b/src/black/strings.py @@ -4,10 +4,20 @@ Simple formatting on strings. Further string formatting code is in trans.py. import regex as re import sys +from functools import lru_cache from typing import List, Pattern +if sys.version_info < (3, 8): + from typing_extensions import Final +else: + from typing import Final -STRING_PREFIX_CHARS = "furbFURB" # All possible string prefix characters. + +STRING_PREFIX_CHARS: Final = "furbFURB" # All possible string prefix characters. +STRING_PREFIX_RE: Final = re.compile( + r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", re.DOTALL +) +FIRST_NON_WHITESPACE_RE: Final = re.compile(r"\s*\t+\s*(\S)") def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str: @@ -37,7 +47,7 @@ def lines_with_leading_tabs_expanded(s: str) -> List[str]: for line in s.splitlines(): # Find the index of the first non-whitespace character after a string of # whitespace that includes at least one tab - match = re.match(r"\s*\t+\s*(\S)", line) + match = FIRST_NON_WHITESPACE_RE.match(line) if match: first_non_whitespace_idx = match.start(1) @@ -133,7 +143,7 @@ def normalize_string_prefix(s: str, remove_u_prefix: bool = False) -> str: If remove_u_prefix is given, also removes any u prefix from the string. """ - match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", s, re.DOTALL) + match = STRING_PREFIX_RE.match(s) assert match is not None, f"failed to match string {s!r}" orig_prefix = match.group(1) new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u") @@ -142,6 +152,14 @@ def normalize_string_prefix(s: str, remove_u_prefix: bool = False) -> str: return f"{new_prefix}{match.group(2)}" +# Re(gex) does actually cache patterns internally but this still improves +# performance on a long list literal of strings by 5-9% since lru_cache's +# caching overhead is much lower. +@lru_cache(maxsize=64) +def _cached_compile(pattern: str) -> re.Pattern: + return re.compile(pattern) + + def normalize_string_quotes(s: str) -> str: """Prefer double quotes but only if it doesn't cause more escaping. @@ -166,9 +184,9 @@ def normalize_string_quotes(s: str) -> str: return s # There's an internal error prefix = s[:first_quote_pos] - unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}") - escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}") - escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}") + unescaped_new_quote = _cached_compile(rf"(([^\\]|^)(\\\\)*){new_quote}") + escaped_new_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}") + escaped_orig_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}") body = s[first_quote_pos + len(orig_quote) : -len(orig_quote)] if "r" in prefix.casefold(): if unescaped_new_quote.search(body): diff --git a/src/black/trans.py b/src/black/trans.py index 023dcd3..d918ef1 100644 --- a/src/black/trans.py +++ b/src/black/trans.py @@ -8,6 +8,7 @@ import regex as re from typing import ( Any, Callable, + ClassVar, Collection, Dict, Iterable, @@ -20,6 +21,14 @@ from typing import ( TypeVar, Union, ) +import sys + +if sys.version_info < (3, 8): + from typing_extensions import Final +else: + from typing import Final + +from mypy_extensions import trait from black.rusty import Result, Ok, Err @@ -62,7 +71,6 @@ def TErr(err_msg: str) -> Err[CannotTransform]: return Err(cant_transform) -@dataclass # type: ignore class StringTransformer(ABC): """ An implementation of the Transformer protocol that relies on its @@ -90,9 +98,13 @@ class StringTransformer(ABC): as much as possible. """ - line_length: int - normalize_strings: bool - __name__ = "StringTransformer" + __name__: Final = "StringTransformer" + + # Ideally this would be a dataclass, but unfortunately mypyc breaks when used with + # `abc.ABC`. + def __init__(self, line_length: int, normalize_strings: bool) -> None: + self.line_length = line_length + self.normalize_strings = normalize_strings @abstractmethod def do_match(self, line: Line) -> TMatchResult: @@ -184,6 +196,7 @@ class CustomSplit: break_idx: int +@trait class CustomSplitMapMixin: """ This mixin class is used to map merged strings to a sequence of @@ -191,8 +204,10 @@ class CustomSplitMapMixin: the resultant substrings go over the configured max line length. """ - _Key = Tuple[StringID, str] - _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple) + _Key: ClassVar = Tuple[StringID, str] + _CUSTOM_SPLIT_MAP: ClassVar[Dict[_Key, Tuple[CustomSplit, ...]]] = defaultdict( + tuple + ) @staticmethod def _get_key(string: str) -> "CustomSplitMapMixin._Key": @@ -243,7 +258,7 @@ class CustomSplitMapMixin: return key in self._CUSTOM_SPLIT_MAP -class StringMerger(CustomSplitMapMixin, StringTransformer): +class StringMerger(StringTransformer, CustomSplitMapMixin): """StringTransformer that merges strings together. Requirements: @@ -739,7 +754,7 @@ class BaseStringSplitter(StringTransformer): * The target string is not a multiline (i.e. triple-quote) string. """ - STRING_OPERATORS = [ + STRING_OPERATORS: Final = [ token.EQEQUAL, token.GREATER, token.GREATEREQUAL, @@ -927,7 +942,7 @@ class BaseStringSplitter(StringTransformer): return max_string_length -class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): +class StringSplitter(BaseStringSplitter, CustomSplitMapMixin): """ StringTransformer that splits "atom" strings (i.e. strings which exist on lines by themselves). @@ -965,9 +980,9 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): CustomSplit objects and add them to the custom split map. """ - MIN_SUBSTR_SIZE = 6 + MIN_SUBSTR_SIZE: Final = 6 # Matches an "f-expression" (e.g. {var}) that might be found in an f-string. - RE_FEXPR = r""" + RE_FEXPR: Final = r""" (? None: + def __init__(self, grammar: Grammar, logger: Optional[Logger] = None) -> None: self.grammar = grammar if logger is None: logger = logging.getLogger(__name__) self.logger = logger - self.convert = convert - def parse_tokens(self, tokens: Iterable[Any], debug: bool = False) -> NL: + def parse_tokens(self, tokens: Iterable[GoodTokenInfo], debug: bool = False) -> NL: """Parse a series of tokens and return the syntax tree.""" # XXX Move the prefix computation into a wrapper around tokenize. proxy = TokenProxy(tokens) - p = parse.Parser(self.grammar, self.convert) + p = parse.Parser(self.grammar) p.setup(proxy=proxy) lineno = 1 column = 0 - indent_columns = [] + indent_columns: List[int] = [] type = value = start = end = line_text = None prefix = "" @@ -163,6 +159,7 @@ class Driver(object): if type == token.OP: type = grammar.opmap[value] if debug: + assert type is not None self.logger.debug( "%s %r (prefix=%r)", token.tok_name[type], value, prefix ) @@ -174,7 +171,7 @@ class Driver(object): elif type == token.DEDENT: _indent_col = indent_columns.pop() prefix, _prefix = self._partially_consume_prefix(prefix, _indent_col) - if p.addtoken(type, value, (prefix, start)): + if p.addtoken(cast(int, type), value, (prefix, start)): if debug: self.logger.debug("Stop.") break diff --git a/src/blib2to3/pgen2/parse.py b/src/blib2to3/pgen2/parse.py index dc40526..792e8e6 100644 --- a/src/blib2to3/pgen2/parse.py +++ b/src/blib2to3/pgen2/parse.py @@ -29,7 +29,7 @@ from typing import ( TYPE_CHECKING, ) from blib2to3.pgen2.grammar import Grammar -from blib2to3.pytree import NL, Context, RawNode, Leaf, Node +from blib2to3.pytree import convert, NL, Context, RawNode, Leaf, Node if TYPE_CHECKING: from blib2to3.driver import TokenProxy @@ -70,9 +70,7 @@ class Recorder: finally: self.parser.stack = self._start_point - def add_token( - self, tok_type: int, tok_val: Optional[Text], raw: bool = False - ) -> None: + def add_token(self, tok_type: int, tok_val: Text, raw: bool = False) -> None: func: Callable[..., Any] if raw: func = self.parser._addtoken @@ -86,9 +84,7 @@ class Recorder: args.insert(0, ilabel) func(*args) - def determine_route( - self, value: Optional[Text] = None, force: bool = False - ) -> Optional[int]: + def determine_route(self, value: Text = None, force: bool = False) -> Optional[int]: alive_ilabels = self.ilabels if len(alive_ilabels) == 0: *_, most_successful_ilabel = self._dead_ilabels @@ -164,6 +160,11 @@ class Parser(object): to be converted. The syntax tree is converted from the bottom up. + **post-note: the convert argument is ignored since for Black's + usage, convert will always be blib2to3.pytree.convert. Allowing + this to be dynamic hurts mypyc's ability to use early binding. + These docs are left for historical and informational value. + A concrete syntax tree node is a (type, value, context, nodes) tuple, where type is the node type (a token or symbol number), value is None for symbols and a string for tokens, context is @@ -176,6 +177,7 @@ class Parser(object): """ self.grammar = grammar + # See note in docstring above. TL;DR this is ignored. self.convert = convert or lam_sub def setup(self, proxy: "TokenProxy", start: Optional[int] = None) -> None: @@ -203,7 +205,7 @@ class Parser(object): self.used_names: Set[str] = set() self.proxy = proxy - def addtoken(self, type: int, value: Optional[Text], context: Context) -> bool: + def addtoken(self, type: int, value: Text, context: Context) -> bool: """Add a token; return True iff this is the end of the program.""" # Map from token to label ilabels = self.classify(type, value, context) @@ -237,7 +239,7 @@ class Parser(object): next_token_type, next_token_value, *_ = proxy.eat(counter) if next_token_type == tokenize.OP: - next_token_type = grammar.opmap[cast(str, next_token_value)] + next_token_type = grammar.opmap[next_token_value] recorder.add_token(next_token_type, next_token_value) counter += 1 @@ -247,9 +249,7 @@ class Parser(object): return self._addtoken(ilabel, type, value, context) - def _addtoken( - self, ilabel: int, type: int, value: Optional[Text], context: Context - ) -> bool: + def _addtoken(self, ilabel: int, type: int, value: Text, context: Context) -> bool: # Loop until the token is shifted; may raise exceptions while True: dfa, state, node = self.stack[-1] @@ -257,10 +257,18 @@ class Parser(object): arcs = states[state] # Look for a state with this label for i, newstate in arcs: - t, v = self.grammar.labels[i] - if ilabel == i: + t = self.grammar.labels[i][0] + if t >= 256: + # See if it's a symbol and if we're in its first set + itsdfa = self.grammar.dfas[t] + itsstates, itsfirst = itsdfa + if ilabel in itsfirst: + # Push a symbol + self.push(t, itsdfa, newstate, context) + break # To continue the outer while loop + + elif ilabel == i: # Look it up in the list of labels - assert t < 256 # Shift a token; we're done with it self.shift(type, value, newstate, context) # Pop while we are in an accept-only state @@ -274,14 +282,7 @@ class Parser(object): states, first = dfa # Done with this token return False - elif t >= 256: - # See if it's a symbol and if we're in its first set - itsdfa = self.grammar.dfas[t] - itsstates, itsfirst = itsdfa - if ilabel in itsfirst: - # Push a symbol - self.push(t, self.grammar.dfas[t], newstate, context) - break # To continue the outer while loop + else: if (0, state) in arcs: # An accepting state, pop it and try something else @@ -293,14 +294,13 @@ class Parser(object): # No success finding a transition raise ParseError("bad input", type, value, context) - def classify(self, type: int, value: Optional[Text], context: Context) -> List[int]: + def classify(self, type: int, value: Text, context: Context) -> List[int]: """Turn a token into a label. (Internal) Depending on whether the value is a soft-keyword or not, this function may return multiple labels to choose from.""" if type == token.NAME: # Keep a listing of all used names - assert value is not None self.used_names.add(value) # Check for reserved words if value in self.grammar.keywords: @@ -317,18 +317,13 @@ class Parser(object): raise ParseError("bad token", type, value, context) return [ilabel] - def shift( - self, type: int, value: Optional[Text], newstate: int, context: Context - ) -> None: + def shift(self, type: int, value: Text, newstate: int, context: Context) -> None: """Shift a token. (Internal)""" dfa, state, node = self.stack[-1] - assert value is not None - assert context is not None rawnode: RawNode = (type, value, context, None) - newnode = self.convert(self.grammar, rawnode) - if newnode is not None: - assert node[-1] is not None - node[-1].append(newnode) + newnode = convert(self.grammar, rawnode) + assert node[-1] is not None + node[-1].append(newnode) self.stack[-1] = (dfa, newstate, node) def push(self, type: int, newdfa: DFAS, newstate: int, context: Context) -> None: @@ -341,12 +336,11 @@ class Parser(object): def pop(self) -> None: """Pop a nonterminal. (Internal)""" popdfa, popstate, popnode = self.stack.pop() - newnode = self.convert(self.grammar, popnode) - if newnode is not None: - if self.stack: - dfa, state, node = self.stack[-1] - assert node[-1] is not None - node[-1].append(newnode) - else: - self.rootnode = newnode - self.rootnode.used_names = self.used_names + newnode = convert(self.grammar, popnode) + if self.stack: + dfa, state, node = self.stack[-1] + assert node[-1] is not None + node[-1].append(newnode) + else: + self.rootnode = newnode + self.rootnode.used_names = self.used_names diff --git a/src/blib2to3/pgen2/tokenize.py b/src/blib2to3/pgen2/tokenize.py index bad79b2..283fac2 100644 --- a/src/blib2to3/pgen2/tokenize.py +++ b/src/blib2to3/pgen2/tokenize.py @@ -27,6 +27,7 @@ are the same, except instead of generating tokens, tokeneater is a callback function to which the 5 fields described above are passed as 5 arguments, each time a new token is found.""" +import sys from typing import ( Callable, Iterable, @@ -39,6 +40,12 @@ from typing import ( Union, cast, ) + +if sys.version_info >= (3, 8): + from typing import Final +else: + from typing_extensions import Final + from blib2to3.pgen2.token import * from blib2to3.pgen2.grammar import Grammar @@ -139,7 +146,7 @@ ContStr = group( PseudoExtras = group(r"\\\r?\n", Comment, Triple) PseudoToken = Whitespace + group(PseudoExtras, Number, Funny, ContStr, Name) -pseudoprog = re.compile(PseudoToken, re.UNICODE) +pseudoprog: Final = re.compile(PseudoToken, re.UNICODE) single3prog = re.compile(Single3) double3prog = re.compile(Double3) @@ -149,7 +156,7 @@ _strprefixes = ( | {"u", "U", "ur", "uR", "Ur", "UR"} ) -endprogs = { +endprogs: Final = { "'": re.compile(Single), '"': re.compile(Double), "'''": single3prog, @@ -159,12 +166,12 @@ endprogs = { **{prefix: None for prefix in _strprefixes}, } -triple_quoted = ( +triple_quoted: Final = ( {"'''", '"""'} | {f"{prefix}'''" for prefix in _strprefixes} | {f'{prefix}"""' for prefix in _strprefixes} ) -single_quoted = ( +single_quoted: Final = ( {"'", '"'} | {f"{prefix}'" for prefix in _strprefixes} | {f'{prefix}"' for prefix in _strprefixes} @@ -418,7 +425,7 @@ def generate_tokens( logical line; continuation lines are included. """ lnum = parenlev = continued = 0 - numchars = "0123456789" + numchars: Final = "0123456789" contstr, needcont = "", 0 contline: Optional[str] = None indents = [0] @@ -427,7 +434,7 @@ def generate_tokens( # `await` as keywords. async_keywords = False if grammar is None else grammar.async_keywords # 'stashed' and 'async_*' are used for async/await parsing - stashed = None + stashed: Optional[GoodTokenInfo] = None async_def = False async_def_indent = 0 async_def_nl = False @@ -440,7 +447,7 @@ def generate_tokens( line = readline() except StopIteration: line = "" - lnum = lnum + 1 + lnum += 1 pos, max = 0, len(line) if contstr: # continued string @@ -481,14 +488,14 @@ def generate_tokens( column = 0 while pos < max: # measure leading whitespace if line[pos] == " ": - column = column + 1 + column += 1 elif line[pos] == "\t": column = (column // tabsize + 1) * tabsize elif line[pos] == "\f": column = 0 else: break - pos = pos + 1 + pos += 1 if pos == max: break @@ -507,7 +514,7 @@ def generate_tokens( COMMENT, comment_token, (lnum, pos), - (lnum, pos + len(comment_token)), + (lnum, nl_pos), line, ) yield (NL, line[nl_pos:], (lnum, nl_pos), (lnum, len(line)), line) @@ -652,16 +659,16 @@ def generate_tokens( continued = 1 else: if initial in "([{": - parenlev = parenlev + 1 + parenlev += 1 elif initial in ")]}": - parenlev = parenlev - 1 + parenlev -= 1 if stashed: yield stashed stashed = None yield (OP, token, spos, epos, line) else: yield (ERRORTOKEN, line[pos], (lnum, pos), (lnum, pos + 1), line) - pos = pos + 1 + pos += 1 if stashed: yield stashed diff --git a/src/blib2to3/pytree.py b/src/blib2to3/pytree.py index 001652d..bd86270 100644 --- a/src/blib2to3/pytree.py +++ b/src/blib2to3/pytree.py @@ -14,7 +14,6 @@ There's also a pattern matching implementation here. from typing import ( Any, - Callable, Dict, Iterator, List, @@ -92,8 +91,6 @@ class Base(object): return NotImplemented return self._eq(other) - __hash__ = None # type: Any # For Py3 compatibility. - @property def prefix(self) -> Text: raise NotImplementedError @@ -437,7 +434,7 @@ class Leaf(Base): This reproduces the input source exactly. """ - return self.prefix + str(self.value) + return self._prefix + str(self.value) def _eq(self, other) -> bool: """Compare two nodes for equality.""" @@ -672,8 +669,11 @@ class NodePattern(BasePattern): newcontent = list(content) for i, item in enumerate(newcontent): assert isinstance(item, BasePattern), (i, item) - if isinstance(item, WildcardPattern): - self.wildcards = True + # I don't even think this code is used anywhere, but it does cause + # unreachable errors from mypy. This function's signature does look + # odd though *shrug*. + if isinstance(item, WildcardPattern): # type: ignore[unreachable] + self.wildcards = True # type: ignore[unreachable] self.type = type self.content = newcontent self.name = name @@ -978,6 +978,3 @@ def generate_matches( r.update(r0) r.update(r1) yield c0 + c1, r - - -_Convert = Callable[[Grammar, RawNode], Any] diff --git a/tests/test_black.py b/tests/test_black.py index 301a3a5..3d5d398 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -122,7 +122,7 @@ def invokeBlack( runner = BlackRunner() if ignore_config: args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args] - result = runner.invoke(black.main, args) + result = runner.invoke(black.main, args, catch_exceptions=False) assert result.stdout_bytes is not None assert result.stderr_bytes is not None msg = ( @@ -841,6 +841,7 @@ class BlackTestCase(BlackBaseTestCase): ) self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node)) + @pytest.mark.incompatible_with_mypyc def test_debug_visitor(self) -> None: source, _ = read_data("debug_visitor.py") expected, _ = read_data("debug_visitor.out") @@ -891,6 +892,7 @@ class BlackTestCase(BlackBaseTestCase): self.assertEqual(len(n.children), 1) self.assertEqual(n.children[0].type, black.token.ENDMARKER) + @pytest.mark.incompatible_with_mypyc @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT") def test_assertFormatEqual(self) -> None: out_lines = [] @@ -1055,6 +1057,7 @@ class BlackTestCase(BlackBaseTestCase): actual = result.output self.assertFormatEqual(actual, expected) + @pytest.mark.incompatible_with_mypyc def test_reformat_one_with_stdin(self) -> None: with patch( "black.format_stdin_to_stdout", @@ -1072,6 +1075,7 @@ class BlackTestCase(BlackBaseTestCase): fsts.assert_called_once() report.done.assert_called_with(path, black.Changed.YES) + @pytest.mark.incompatible_with_mypyc def test_reformat_one_with_stdin_filename(self) -> None: with patch( "black.format_stdin_to_stdout", @@ -1094,6 +1098,7 @@ class BlackTestCase(BlackBaseTestCase): # __BLACK_STDIN_FILENAME__ should have been stripped report.done.assert_called_with(expected, black.Changed.YES) + @pytest.mark.incompatible_with_mypyc def test_reformat_one_with_stdin_filename_pyi(self) -> None: with patch( "black.format_stdin_to_stdout", @@ -1118,6 +1123,7 @@ class BlackTestCase(BlackBaseTestCase): # __BLACK_STDIN_FILENAME__ should have been stripped report.done.assert_called_with(expected, black.Changed.YES) + @pytest.mark.incompatible_with_mypyc def test_reformat_one_with_stdin_filename_ipynb(self) -> None: with patch( "black.format_stdin_to_stdout", @@ -1142,6 +1148,7 @@ class BlackTestCase(BlackBaseTestCase): # __BLACK_STDIN_FILENAME__ should have been stripped report.done.assert_called_with(expected, black.Changed.YES) + @pytest.mark.incompatible_with_mypyc def test_reformat_one_with_stdin_and_existing_path(self) -> None: with patch( "black.format_stdin_to_stdout", @@ -1296,6 +1303,7 @@ class BlackTestCase(BlackBaseTestCase): self.assertEqual(config["exclude"], r"\.pyi?$") self.assertEqual(config["include"], r"\.py?$") + @pytest.mark.incompatible_with_mypyc def test_find_project_root(self) -> None: with TemporaryDirectory() as workspace: root = Path(workspace) @@ -1483,6 +1491,7 @@ class BlackTestCase(BlackBaseTestCase): assert output == result_diff, "The output did not match the expected value." assert result.exit_code == 0, "The exit code is incorrect." + @pytest.mark.incompatible_with_mypyc def test_code_option_safe(self) -> None: """Test that the code option throws an error when the sanity checks fail.""" # Patch black.assert_equivalent to ensure the sanity checks fail @@ -1507,6 +1516,7 @@ class BlackTestCase(BlackBaseTestCase): self.compare_results(result, formatted, 0) + @pytest.mark.incompatible_with_mypyc def test_code_option_config(self) -> None: """ Test that the code option finds the pyproject.toml in the current directory. @@ -1527,6 +1537,7 @@ class BlackTestCase(BlackBaseTestCase): call_args[0].lower() == str(pyproject_path).lower() ), "Incorrect config loaded." + @pytest.mark.incompatible_with_mypyc def test_code_option_parent_config(self) -> None: """ Test that the code option finds the pyproject.toml in the parent directory. @@ -1894,6 +1905,7 @@ class TestFileCollection: src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude" ) + @pytest.mark.incompatible_with_mypyc def test_symlink_out_of_root_directory(self) -> None: path = MagicMock() root = THIS_DIR.resolve() @@ -2047,8 +2059,12 @@ def test_python_2_deprecation_autodetection_extended() -> None: }, non_python2_case -with open(black.__file__, "r", encoding="utf-8") as _bf: - black_source_lines = _bf.readlines() +try: + with open(black.__file__, "r", encoding="utf-8") as _bf: + black_source_lines = _bf.readlines() +except UnicodeDecodeError: + if not black.COMPILED: + raise def tracefunc( -- 2.39.5