# free to run mypy on Windows, Linux, or macOS and get consistent
# results.
python_version=3.6
-platform=linux
mypy_path=src
warn_unused_ignores=True
disallow_any_generics=True
+# Unreachable blocks have been an issue when compiling with mypyc, let's try
+# to avoid 'em in the first place.
+warn_unreachable=True
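+# (Note: with platform=linux removed above, mypy no longer considers
+# `if sys.platform == "win32":` branches unreachable, so they won't trip
+# this flag.)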
+
# The following are off by default. Flip them on if you feel
# adventurous.
disallow_untyped_defs=True
# No incremental mode
cache_dir=/dev/null
+[mypy-black]
+# The following is because of `patch_click()`. Remove when
+# we drop Python 3.6 support.
+warn_unused_ignores=False
+
[mypy-black_primer.*]
# Until we drop Python 3.6 support, black-primer needs this
disallow_any_generics=False
"no_blackd: run when `d` extra NOT installed",
"no_jupyter: run when `jupyter` extra NOT installed",
]
+markers = [
+ "incompatible_with_mypyc: run when testing mypyc compiled black"
+]
assert sys.version_info >= (3, 6, 2), "black requires Python 3.6.2+"
from pathlib import Path  # noqa: E402
+from typing import List # noqa: E402
CURRENT_DIR = Path(__file__).parent
sys.path.insert(0, str(CURRENT_DIR)) # for setuptools.build_meta
)
+def find_python_files(base: Path) -> List[Path]:
+ files = []
+ for entry in base.iterdir():
+ if entry.is_file() and entry.suffix == ".py":
+ files.append(entry)
+ elif entry.is_dir():
+ files.extend(find_python_files(entry))
+
+ return files
+
+
USE_MYPYC = False
# To compile with mypyc, a mypyc checkout must be present on the PYTHONPATH
if len(sys.argv) > 1 and sys.argv[1] == "--use-mypyc":
USE_MYPYC = True
if USE_MYPYC:
+ from mypyc.build import mypycify
+
+ src = CURRENT_DIR / "src"
+ # TIP: filepaths are normalized to use forward slashes and are relative to ./src/
+# before being checked against the blocklist.
+ blocklist = [
+ # Not performance sensitive, so save bytes + compilation time:
+ "blib2to3/__init__.py",
+ "blib2to3/pgen2/__init__.py",
+ "black/output.py",
+ "black/concurrency.py",
+ "black/files.py",
+ "black/report.py",
+ # Breaks the test suite when compiled (and is also useless):
+ "black/debug.py",
+ # Compiled modules can't be run directly and that's a problem here:
+ "black/__main__.py",
+ ]
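+    # For example, src/black/debug.py is compared as "black/debug.py":
+    #   (src / "black" / "debug.py").relative_to(src).as_posix() == "black/debug.py"
+    # so blocklist entries match on every platform.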
+ discovered = []
+ # black-primer and blackd have no good reason to be compiled.
+ discovered.extend(find_python_files(src / "black"))
+ discovered.extend(find_python_files(src / "blib2to3"))
mypyc_targets = [
- "src/black/__init__.py",
- "src/blib2to3/pytree.py",
- "src/blib2to3/pygram.py",
- "src/blib2to3/pgen2/parse.py",
- "src/blib2to3/pgen2/grammar.py",
- "src/blib2to3/pgen2/token.py",
- "src/blib2to3/pgen2/driver.py",
- "src/blib2to3/pgen2/pgen.py",
+ str(p) for p in discovered if p.relative_to(src).as_posix() not in blocklist
]
- from mypyc.build import mypycify
-
opt_level = os.getenv("MYPYC_OPT_LEVEL", "3")
- ext_modules = mypycify(mypyc_targets, opt_level=opt_level)
+ ext_modules = mypycify(mypyc_targets, opt_level=opt_level, verbose=True)
else:
ext_modules = []
Union,
)
-from dataclasses import replace
import click
+from dataclasses import replace
+from mypy_extensions import mypyc_attr
from black.const import DEFAULT_LINE_LENGTH, DEFAULT_INCLUDES, DEFAULT_EXCLUDES
from black.const import STDIN_PLACEHOLDER
from _black_version import version as __version__
+COMPILED = Path(__file__).suffix in (".pyd", ".so")
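+# (mypyc compiles modules into C extensions, so a compiled black's __init__
+# lives in a .so (Linux/macOS) or .pyd (Windows) file instead of a .py file.)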
+
# types
FileContent = str
Encoding = str
raise click.BadParameter("Not a valid regular expression") from None
-@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
+@click.command(
+ context_settings=dict(help_option_names=["-h", "--help"]),
+ # While Click does set this field automatically using the docstring, mypyc
+ # (annoyingly) strips 'em so we need to set it here too.
+ help="The uncompromising code formatter.",
+)
@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
"-l",
" due to exclusion patterns."
),
)
-@click.version_option(version=__version__)
+@click.version_option(
+ version=__version__,
+ message=f"%(prog)s, %(version)s (compiled: {'yes' if COMPILED else 'no'})",
+)
@click.argument(
"src",
nargs=-1,
experimental_string_processing: bool,
quiet: bool,
verbose: bool,
- required_version: str,
+ required_version: Optional[str],
include: Pattern[str],
exclude: Optional[Pattern[str]],
extend_exclude: Optional[Pattern[str]],
report.failed(src, str(exc))
+# diff-shades depends on being able to monkeypatch this function to operate. I
+# know it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
+@mypyc_attr(patchable=True)
def reformat_many(
sources: Set[Path],
fast: bool,
worker_count = workers if workers is not None else DEFAULT_WORKERS
if sys.platform == "win32":
# Work around https://bugs.python.org/issue26903
+ assert worker_count is not None
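+        # The assert narrows Optional[int] for mypy: DEFAULT_WORKERS comes
+        # from os.cpu_count(), which can return None.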
worker_count = min(worker_count, 60)
try:
executor = ProcessPoolExecutor(max_workers=worker_count)
DOT_PRIORITY: Final = 1
-class BracketMatchError(KeyError):
+class BracketMatchError(Exception):
"""Raised when an opening bracket is unable to be matched to a closing bracket."""
+import sys
from dataclasses import dataclass
from functools import lru_cache
import regex as re
from typing import Iterator, List, Optional, Union
+if sys.version_info >= (3, 8):
+ from typing import Final
+else:
+ from typing_extensions import Final
+
from blib2to3.pytree import Node, Leaf
from blib2to3.pgen2 import token
# types
LN = Union[Leaf, Node]
-
-FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
-FMT_SKIP = {"# fmt: skip", "# fmt:skip"}
-FMT_PASS = {*FMT_OFF, *FMT_SKIP}
-FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
+FMT_OFF: Final = {"# fmt: off", "# fmt:off", "# yapf: disable"}
+FMT_SKIP: Final = {"# fmt: skip", "# fmt:skip"}
+FMT_PASS: Final = {*FMT_OFF, *FMT_SKIP}
+FMT_ON: Final = {"# fmt: on", "# fmt:on", "# yapf: enable"}
@dataclass
TYPE_CHECKING,
)
+from mypy_extensions import mypyc_attr
from pathspec import PathSpec
from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
import tomli
return None
+@mypyc_attr(patchable=True)
def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
"""Parse a pyproject toml file, pulling out relevant parts for Black
If parsing fails, will raise a tomli.TOMLDecodeError
"""
with open(path_config, encoding="utf8") as f:
- pyproject_toml = tomli.load(f) # type: ignore # due to deprecated API usage
+ pyproject_toml = tomli.loads(f.read())
config = pyproject_toml.get("tool", {}).get("black", {})
return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
return f"%%{self.name}"
-@dataclasses.dataclass
+# ast.NodeVisitor + dataclass = breakage under mypyc.
class CellMagicFinder(ast.NodeVisitor):
"""Find cell magics.
and we look for instances of the latter.
"""
- cell_magic: Optional[CellMagic] = None
+ def __init__(self, cell_magic: Optional[CellMagic] = None) -> None:
+ self.cell_magic = cell_magic
def visit_Expr(self, node: ast.Expr) -> None:
"""Find cell magic, extract header and body."""
magic: str
-@dataclasses.dataclass
+# Unsurprisingly, subclassing ast.NodeVisitor means we can't use dataclasses here
+# as mypyc will generate broken code.
class MagicFinder(ast.NodeVisitor):
"""Visit cell to look for get_ipython calls.
types of magics).
"""
- magics: Dict[int, List[OffsetAndMagic]] = dataclasses.field(
- default_factory=lambda: collections.defaultdict(list)
- )
+ def __init__(self) -> None:
+ self.magics: Dict[int, List[OffsetAndMagic]] = collections.defaultdict(list)
def visit_Assign(self, node: ast.Assign) -> None:
"""Look for system assign magics.
import sys
from typing import Collection, Iterator, List, Optional, Set, Union
-from dataclasses import dataclass, field
-
from black.nodes import WHITESPACE, RARROW, STATEMENT, STANDALONE_COMMENT
from black.nodes import ASSIGNMENTS, OPENING_BRACKETS, CLOSING_BRACKETS
from black.nodes import Visitor, syms, first_child_is_arith, ensure_visible
"""A readable split that fits the allotted line length is impossible."""
-@dataclass
+# This isn't a dataclass because @dataclass + Generic breaks mypyc.
+# See also https://github.com/mypyc/mypyc/issues/827.
class LineGenerator(Visitor[Line]):
"""Generates reformatted Line objects. Empty lines are not emitted.
in ways that will no longer stringify to valid Python code on the tree.
"""
- mode: Mode
- remove_u_prefix: bool = False
- current_line: Line = field(init=False)
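+    # This is a hand-written replacement for the dataclass-generated __init__
+    # (see the mypyc note above); __post_init__ still does the actual setup.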
+ def __init__(self, mode: Mode, remove_u_prefix: bool = False) -> None:
+ self.mode = mode
+ self.remove_u_prefix = remove_u_prefix
+ self.current_line: Line
+ self.__post_init__()
def line(self, indent: int = 0) -> Iterator[Line]:
"""Generate a line.
transformers = [left_hand_split]
else:
- def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
+ def _rhs(
+ self: object, line: Line, features: Collection[Feature]
+ ) -> Iterator[Line]:
"""Wraps calls to `right_hand_split`.
The calls increasingly `omit` right-hand trailers (bracket pairs with
line, line_length=mode.line_length, features=features
)
+ # HACK: nested functions (like _rhs) compiled by mypyc don't retain their
+ # __name__ attribute which is needed in `run_transformer` further down.
+ # Unfortunately a nested class breaks mypyc too. So a class must be created
+ # via type ... https://github.com/mypyc/mypyc/issues/884
+ rhs = type("rhs", (), {"__call__": _rhs})()
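+    # The resulting instance's class is still named "rhs", which the
+    # transform.__class__.__name__ check further down relies on.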
+
if mode.experimental_string_processing:
if line.inside_brackets:
transformers = [
result.extend(transform_line(transformed_line, mode=mode, features=features))
if (
- transform.__name__ != "rhs"
+ transform.__class__.__name__ != "rhs"
or not line.bracket_tracker.invisible
or any(bracket.value for bracket in line.bracket_tracker.invisible)
or line.contains_multiline_strings()
from dataclasses import dataclass, field
from enum import Enum
+from operator import attrgetter
from typing import Dict, Set
from black.const import DEFAULT_LINE_LENGTH
if self.target_versions:
version_str = ",".join(
str(version.value)
- for version in sorted(self.target_versions, key=lambda v: v.value)
+ for version in sorted(self.target_versions, key=attrgetter("value"))
)
else:
version_str = "-"
Union,
)
-if sys.version_info < (3, 8):
- from typing_extensions import Final
-else:
+if sys.version_info >= (3, 8):
from typing import Final
+else:
+ from typing_extensions import Final
+
+from mypy_extensions import mypyc_attr
# lib2to3 fork
from blib2to3.pytree import Node, Leaf, type_repr
pygram.initialize(CACHE_DIR)
-syms = pygram.python_symbols
+syms: Final = pygram.python_symbols
# types
"//=",
}
-IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
-BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
-OPENING_BRACKETS = set(BRACKET.keys())
-CLOSING_BRACKETS = set(BRACKET.values())
-BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
-ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
+IMPLICIT_TUPLE: Final = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
+BRACKET: Final = {
+ token.LPAR: token.RPAR,
+ token.LSQB: token.RSQB,
+ token.LBRACE: token.RBRACE,
+}
+OPENING_BRACKETS: Final = set(BRACKET.keys())
+CLOSING_BRACKETS: Final = set(BRACKET.values())
+BRACKETS: Final = OPENING_BRACKETS | CLOSING_BRACKETS
+ALWAYS_NO_SPACE: Final = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
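+# Marking these Final documents that they are constants and helps mypyc
+# generate faster code for them (the "Final-ing" noted in the changelog).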
RARROW = 55
+@mypyc_attr(allow_interpreted_subclasses=True)
class Visitor(Generic[T]):
"""Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
`complex_subscript` signals whether the given leaf is part of a subscription
which has non-trivial arguments, like arithmetic expressions or function calls.
"""
- NO = ""
- SPACE = " "
- DOUBLESPACE = " "
+ NO: Final = ""
+ SPACE: Final = " "
+ DOUBLESPACE: Final = " "
t = leaf.type
p = leaf.parent
v = leaf.value
def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]:
"""Return (penultimate, last) leaves skipping brackets in `omit` and contents."""
- stop_after = None
- last = None
+ stop_after: Optional[Leaf] = None
+ last: Optional[Leaf] = None
for leaf in reversed(leaves):
if stop_after:
if leaf is stop_after:
from click import echo, style
+@mypyc_attr(patchable=True)
def _out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
if message is not None:
if "bold" not in styles:
echo(message, nl=nl, err=True)
+@mypyc_attr(patchable=True)
def _err(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
if message is not None:
if "fg" not in styles:
echo(message, nl=nl, err=True)
+@mypyc_attr(patchable=True)
def out(message: Optional[str] = None, nl: bool = True, **styles: Any) -> None:
_out(message, nl=nl, **styles)
import ast
import platform
import sys
-from typing import Iterable, Iterator, List, Set, Union, Tuple
+from typing import Any, Iterable, Iterator, List, Set, Tuple, Type, Union
+
+if sys.version_info < (3, 8):
+ from typing_extensions import Final
+else:
+ from typing import Final
# lib2to3 fork
from blib2to3.pytree import Node, Leaf
-from blib2to3 import pygram, pytree
+from blib2to3 import pygram
from blib2to3.pgen2 import driver
from blib2to3.pgen2.grammar import Grammar
from blib2to3.pgen2.parse import ParseError
from black.mode import TargetVersion, Feature, supports_feature
from black.nodes import syms
+ast3: Any
+ast27: Any
+
_IS_PYPY = platform.python_implementation() == "PyPy"
try:
src_txt += "\n"
for grammar in get_grammars(set(target_versions)):
- drv = driver.Driver(grammar, pytree.convert)
+ drv = driver.Driver(grammar)
try:
result = drv.parse_string(src_txt, True)
break
raise SyntaxError(first_error)
+ast3_AST: Final[Type[ast3.AST]] = ast3.AST
+ast27_AST: Final[Type[ast27.AST]] = ast27.AST
+
+
def stringify_ast(
node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
) -> Iterator[str]:
elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
yield from stringify_ast(item, depth + 2)
- elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
+ # Note that we are referencing the typed-ast ASTs via global variables and not
+ # direct module attribute accesses because that breaks mypyc. It's probably
+ # something to do with the ast3 / ast27 variables being marked as Any leading
+ # mypy to think this branch is always taken, leaving the rest of the code
+# unanalyzed. Tightening up the types for the typed-ast AST types avoids the
+ # mypyc crash.
+ elif isinstance(value, (ast.AST, ast3_AST, ast27_AST)):
yield from stringify_ast(value, depth + 2)
else:
import regex as re
import sys
+from functools import lru_cache
from typing import List, Pattern
+if sys.version_info < (3, 8):
+ from typing_extensions import Final
+else:
+ from typing import Final
-STRING_PREFIX_CHARS = "furbFURB" # All possible string prefix characters.
+
+STRING_PREFIX_CHARS: Final = "furbFURB" # All possible string prefix characters.
+STRING_PREFIX_RE: Final = re.compile(
+ r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", re.DOTALL
+)
+FIRST_NON_WHITESPACE_RE: Final = re.compile(r"\s*\t+\s*(\S)")
def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
for line in s.splitlines():
# Find the index of the first non-whitespace character after a string of
# whitespace that includes at least one tab
- match = re.match(r"\s*\t+\s*(\S)", line)
+ match = FIRST_NON_WHITESPACE_RE.match(line)
if match:
first_non_whitespace_idx = match.start(1)
If remove_u_prefix is given, also removes any u prefix from the string.
"""
- match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", s, re.DOTALL)
+ match = STRING_PREFIX_RE.match(s)
assert match is not None, f"failed to match string {s!r}"
orig_prefix = match.group(1)
new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
return f"{new_prefix}{match.group(2)}"
+# Re(gex) does actually cache patterns internally but this still improves
+# performance on a long list literal of strings by 5-9% since lru_cache's
+# caching overhead is much lower.
+@lru_cache(maxsize=64)
+def _cached_compile(pattern: str) -> Pattern[str]:
+ return re.compile(pattern)
+
+
def normalize_string_quotes(s: str) -> str:
"""Prefer double quotes but only if it doesn't cause more escaping.
return s # There's an internal error
prefix = s[:first_quote_pos]
- unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
- escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
- escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
+ unescaped_new_quote = _cached_compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
+ escaped_new_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
+ escaped_orig_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
body = s[first_quote_pos + len(orig_quote) : -len(orig_quote)]
if "r" in prefix.casefold():
if unescaped_new_quote.search(body):
from typing import (
Any,
Callable,
+ ClassVar,
Collection,
Dict,
Iterable,
TypeVar,
Union,
)
+import sys
+
+if sys.version_info < (3, 8):
+ from typing_extensions import Final
+else:
+ from typing import Final
+
+from mypy_extensions import trait
from black.rusty import Result, Ok, Err
return Err(cant_transform)
-@dataclass # type: ignore
class StringTransformer(ABC):
"""
An implementation of the Transformer protocol that relies on its
as much as possible.
"""
- line_length: int
- normalize_strings: bool
- __name__ = "StringTransformer"
+ __name__: Final = "StringTransformer"
+
+ # Ideally this would be a dataclass, but unfortunately mypyc breaks when used with
+ # `abc.ABC`.
+ def __init__(self, line_length: int, normalize_strings: bool) -> None:
+ self.line_length = line_length
+ self.normalize_strings = normalize_strings
@abstractmethod
def do_match(self, line: Line) -> TMatchResult:
break_idx: int
+@trait
class CustomSplitMapMixin:
"""
This mixin class is used to map merged strings to a sequence of
the resultant substrings go over the configured max line length.
"""
- _Key = Tuple[StringID, str]
- _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
+ _Key: ClassVar = Tuple[StringID, str]
+ _CUSTOM_SPLIT_MAP: ClassVar[Dict[_Key, Tuple[CustomSplit, ...]]] = defaultdict(
+ tuple
+ )
@staticmethod
def _get_key(string: str) -> "CustomSplitMapMixin._Key":
return key in self._CUSTOM_SPLIT_MAP
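+# mypyc requires the concrete base class to be listed before @trait mixins,
+# hence CustomSplitMapMixin now comes last in the bases below.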
-class StringMerger(CustomSplitMapMixin, StringTransformer):
+class StringMerger(StringTransformer, CustomSplitMapMixin):
"""StringTransformer that merges strings together.
Requirements:
* The target string is not a multiline (i.e. triple-quote) string.
"""
- STRING_OPERATORS = [
+ STRING_OPERATORS: Final = [
token.EQEQUAL,
token.GREATER,
token.GREATEREQUAL,
return max_string_length
-class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
+class StringSplitter(BaseStringSplitter, CustomSplitMapMixin):
"""
StringTransformer that splits "atom" strings (i.e. strings which exist on
lines by themselves).
CustomSplit objects and add them to the custom split map.
"""
- MIN_SUBSTR_SIZE = 6
+ MIN_SUBSTR_SIZE: Final = 6
# Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
- RE_FEXPR = r"""
+ RE_FEXPR: Final = r"""
(?<!\{) (?:\{\{)* \{ (?!\{)
(?:
[^\{\}]
return string_op_leaves
-class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
+class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin):
"""
StringTransformer that splits non-"atom" strings (i.e. strings that do not
exist on lines by themselves).
```
"""
- DEFAULT_TOKEN = -1
+ DEFAULT_TOKEN: Final = 20210605
# String Parser States
- START = 1
- DOT = 2
- NAME = 3
- PERCENT = 4
- SINGLE_FMT_ARG = 5
- LPAR = 6
- RPAR = 7
- DONE = 8
+ START: Final = 1
+ DOT: Final = 2
+ NAME: Final = 3
+ PERCENT: Final = 4
+ SINGLE_FMT_ARG: Final = 5
+ LPAR: Final = 6
+ RPAR: Final = 7
+ DONE: Final = 8
# Lookup Table for Next State
- _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
+ _goto: Final[Dict[Tuple[ParserState, NodeType], ParserState]] = {
# A string trailer may start with '.' OR '%'.
(START, token.DOT): DOT,
(START, token.PERCENT): PERCENT,
no_diff,
)
return int(ret_val)
+
finally:
if not keep and work_path.exists():
LOG.debug(f"Removing {work_path}")
rmtree(work_path, onerror=lib.handle_PermissionError)
- return -2
-
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
https://github.com/python/cpython/commit/cae60187cf7a7b26281d012e1952fafe4e2e97e9
- "bpo-42316: Allow unparenthesized walrus operator in indexes (GH-23317)"
https://github.com/python/cpython/commit/b0aba1fcdc3da952698d99aec2334faa79a8b68c
+- Tweaks to help mypyc compile faster code (including inlining type information,
+ "Final-ing", etc.)
import sys
from typing import (
Any,
+ cast,
IO,
Iterable,
List,
Generic,
Union,
)
+from contextlib import contextmanager
from dataclasses import dataclass, field
# Pgen imports
from . import grammar, parse, token, tokenize, pgen
from logging import Logger
-from blib2to3.pytree import _Convert, NL
+from blib2to3.pytree import NL
from blib2to3.pgen2.grammar import Grammar
-from contextlib import contextmanager
+from blib2to3.pgen2.tokenize import GoodTokenInfo
Path = Union[str, "os.PathLike[str]"]
class Driver(object):
- def __init__(
- self,
- grammar: Grammar,
- convert: Optional[_Convert] = None,
- logger: Optional[Logger] = None,
- ) -> None:
+ def __init__(self, grammar: Grammar, logger: Optional[Logger] = None) -> None:
self.grammar = grammar
if logger is None:
logger = logging.getLogger(__name__)
self.logger = logger
- self.convert = convert
- def parse_tokens(self, tokens: Iterable[Any], debug: bool = False) -> NL:
+ def parse_tokens(self, tokens: Iterable[GoodTokenInfo], debug: bool = False) -> NL:
"""Parse a series of tokens and return the syntax tree."""
# XXX Move the prefix computation into a wrapper around tokenize.
proxy = TokenProxy(tokens)
- p = parse.Parser(self.grammar, self.convert)
+ p = parse.Parser(self.grammar)
p.setup(proxy=proxy)
lineno = 1
column = 0
- indent_columns = []
+ indent_columns: List[int] = []
type = value = start = end = line_text = None
prefix = ""
if type == token.OP:
type = grammar.opmap[value]
if debug:
+ assert type is not None
self.logger.debug(
"%s %r (prefix=%r)", token.tok_name[type], value, prefix
)
elif type == token.DEDENT:
_indent_col = indent_columns.pop()
prefix, _prefix = self._partially_consume_prefix(prefix, _indent_col)
- if p.addtoken(type, value, (prefix, start)):
+ if p.addtoken(cast(int, type), value, (prefix, start)):
if debug:
self.logger.debug("Stop.")
break
TYPE_CHECKING,
)
from blib2to3.pgen2.grammar import Grammar
-from blib2to3.pytree import NL, Context, RawNode, Leaf, Node
+from blib2to3.pytree import convert, NL, Context, RawNode, Leaf, Node
if TYPE_CHECKING:
from blib2to3.driver import TokenProxy
finally:
self.parser.stack = self._start_point
- def add_token(
- self, tok_type: int, tok_val: Optional[Text], raw: bool = False
- ) -> None:
+ def add_token(self, tok_type: int, tok_val: Text, raw: bool = False) -> None:
func: Callable[..., Any]
if raw:
func = self.parser._addtoken
args.insert(0, ilabel)
func(*args)
- def determine_route(
- self, value: Optional[Text] = None, force: bool = False
- ) -> Optional[int]:
+    def determine_route(self, value: Optional[Text] = None, force: bool = False) -> Optional[int]:
alive_ilabels = self.ilabels
if len(alive_ilabels) == 0:
*_, most_successful_ilabel = self._dead_ilabels
to be converted. The syntax tree is converted from the bottom
up.
+ **post-note: the convert argument is ignored since for Black's
+ usage, convert will always be blib2to3.pytree.convert. Allowing
+ this to be dynamic hurts mypyc's ability to use early binding.
+ These docs are left for historical and informational value.
+
A concrete syntax tree node is a (type, value, context, nodes)
tuple, where type is the node type (a token or symbol number),
value is None for symbols and a string for tokens, context is
"""
self.grammar = grammar
+ # See note in docstring above. TL;DR this is ignored.
self.convert = convert or lam_sub
def setup(self, proxy: "TokenProxy", start: Optional[int] = None) -> None:
self.used_names: Set[str] = set()
self.proxy = proxy
- def addtoken(self, type: int, value: Optional[Text], context: Context) -> bool:
+ def addtoken(self, type: int, value: Text, context: Context) -> bool:
"""Add a token; return True iff this is the end of the program."""
# Map from token to label
ilabels = self.classify(type, value, context)
next_token_type, next_token_value, *_ = proxy.eat(counter)
if next_token_type == tokenize.OP:
- next_token_type = grammar.opmap[cast(str, next_token_value)]
+ next_token_type = grammar.opmap[next_token_value]
recorder.add_token(next_token_type, next_token_value)
counter += 1
return self._addtoken(ilabel, type, value, context)
- def _addtoken(
- self, ilabel: int, type: int, value: Optional[Text], context: Context
- ) -> bool:
+ def _addtoken(self, ilabel: int, type: int, value: Text, context: Context) -> bool:
# Loop until the token is shifted; may raise exceptions
while True:
dfa, state, node = self.stack[-1]
arcs = states[state]
# Look for a state with this label
for i, newstate in arcs:
- t, v = self.grammar.labels[i]
- if ilabel == i:
+ t = self.grammar.labels[i][0]
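+                # Token labels always have t < 256, so when labels[i] names a
+                # symbol (t >= 256) it can never equal ilabel; checking the
+                # symbol case first preserves behaviour.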
+ if t >= 256:
+ # See if it's a symbol and if we're in its first set
+ itsdfa = self.grammar.dfas[t]
+ itsstates, itsfirst = itsdfa
+ if ilabel in itsfirst:
+ # Push a symbol
+ self.push(t, itsdfa, newstate, context)
+ break # To continue the outer while loop
+
+ elif ilabel == i:
# Look it up in the list of labels
- assert t < 256
# Shift a token; we're done with it
self.shift(type, value, newstate, context)
# Pop while we are in an accept-only state
states, first = dfa
# Done with this token
return False
- elif t >= 256:
- # See if it's a symbol and if we're in its first set
- itsdfa = self.grammar.dfas[t]
- itsstates, itsfirst = itsdfa
- if ilabel in itsfirst:
- # Push a symbol
- self.push(t, self.grammar.dfas[t], newstate, context)
- break # To continue the outer while loop
+
else:
if (0, state) in arcs:
# An accepting state, pop it and try something else
# No success finding a transition
raise ParseError("bad input", type, value, context)
- def classify(self, type: int, value: Optional[Text], context: Context) -> List[int]:
+ def classify(self, type: int, value: Text, context: Context) -> List[int]:
"""Turn a token into a label. (Internal)
Depending on whether the value is a soft-keyword or not,
this function may return multiple labels to choose from."""
if type == token.NAME:
# Keep a listing of all used names
- assert value is not None
self.used_names.add(value)
# Check for reserved words
if value in self.grammar.keywords:
raise ParseError("bad token", type, value, context)
return [ilabel]
- def shift(
- self, type: int, value: Optional[Text], newstate: int, context: Context
- ) -> None:
+ def shift(self, type: int, value: Text, newstate: int, context: Context) -> None:
"""Shift a token. (Internal)"""
dfa, state, node = self.stack[-1]
- assert value is not None
- assert context is not None
rawnode: RawNode = (type, value, context, None)
- newnode = self.convert(self.grammar, rawnode)
- if newnode is not None:
- assert node[-1] is not None
- node[-1].append(newnode)
+ newnode = convert(self.grammar, rawnode)
+ assert node[-1] is not None
+ node[-1].append(newnode)
self.stack[-1] = (dfa, newstate, node)
def push(self, type: int, newdfa: DFAS, newstate: int, context: Context) -> None:
def pop(self) -> None:
"""Pop a nonterminal. (Internal)"""
popdfa, popstate, popnode = self.stack.pop()
- newnode = self.convert(self.grammar, popnode)
- if newnode is not None:
- if self.stack:
- dfa, state, node = self.stack[-1]
- assert node[-1] is not None
- node[-1].append(newnode)
- else:
- self.rootnode = newnode
- self.rootnode.used_names = self.used_names
+ newnode = convert(self.grammar, popnode)
+ if self.stack:
+ dfa, state, node = self.stack[-1]
+ assert node[-1] is not None
+ node[-1].append(newnode)
+ else:
+ self.rootnode = newnode
+ self.rootnode.used_names = self.used_names
function to which the 5 fields described above are passed as 5 arguments,
each time a new token is found."""
+import sys
from typing import (
Callable,
Iterable,
Union,
cast,
)
+
+if sys.version_info >= (3, 8):
+ from typing import Final
+else:
+ from typing_extensions import Final
+
from blib2to3.pgen2.token import *
from blib2to3.pgen2.grammar import Grammar
PseudoExtras = group(r"\\\r?\n", Comment, Triple)
PseudoToken = Whitespace + group(PseudoExtras, Number, Funny, ContStr, Name)
-pseudoprog = re.compile(PseudoToken, re.UNICODE)
+pseudoprog: Final = re.compile(PseudoToken, re.UNICODE)
single3prog = re.compile(Single3)
double3prog = re.compile(Double3)
| {"u", "U", "ur", "uR", "Ur", "UR"}
)
-endprogs = {
+endprogs: Final = {
"'": re.compile(Single),
'"': re.compile(Double),
"'''": single3prog,
**{prefix: None for prefix in _strprefixes},
}
-triple_quoted = (
+triple_quoted: Final = (
{"'''", '"""'}
| {f"{prefix}'''" for prefix in _strprefixes}
| {f'{prefix}"""' for prefix in _strprefixes}
)
-single_quoted = (
+single_quoted: Final = (
{"'", '"'}
| {f"{prefix}'" for prefix in _strprefixes}
| {f'{prefix}"' for prefix in _strprefixes}
logical line; continuation lines are included.
"""
lnum = parenlev = continued = 0
- numchars = "0123456789"
+ numchars: Final = "0123456789"
contstr, needcont = "", 0
contline: Optional[str] = None
indents = [0]
# `await` as keywords.
async_keywords = False if grammar is None else grammar.async_keywords
# 'stashed' and 'async_*' are used for async/await parsing
- stashed = None
+ stashed: Optional[GoodTokenInfo] = None
async_def = False
async_def_indent = 0
async_def_nl = False
line = readline()
except StopIteration:
line = ""
- lnum = lnum + 1
+ lnum += 1
pos, max = 0, len(line)
if contstr: # continued string
column = 0
while pos < max: # measure leading whitespace
if line[pos] == " ":
- column = column + 1
+ column += 1
elif line[pos] == "\t":
column = (column // tabsize + 1) * tabsize
elif line[pos] == "\f":
column = 0
else:
break
- pos = pos + 1
+ pos += 1
if pos == max:
break
COMMENT,
comment_token,
(lnum, pos),
- (lnum, pos + len(comment_token)),
+ (lnum, nl_pos),
line,
)
yield (NL, line[nl_pos:], (lnum, nl_pos), (lnum, len(line)), line)
continued = 1
else:
if initial in "([{":
- parenlev = parenlev + 1
+ parenlev += 1
elif initial in ")]}":
- parenlev = parenlev - 1
+ parenlev -= 1
if stashed:
yield stashed
stashed = None
yield (OP, token, spos, epos, line)
else:
yield (ERRORTOKEN, line[pos], (lnum, pos), (lnum, pos + 1), line)
- pos = pos + 1
+ pos += 1
if stashed:
yield stashed
from typing import (
Any,
- Callable,
Dict,
Iterator,
List,
return NotImplemented
return self._eq(other)
- __hash__ = None # type: Any # For Py3 compatibility.
-
@property
def prefix(self) -> Text:
raise NotImplementedError
This reproduces the input source exactly.
"""
- return self.prefix + str(self.value)
+ return self._prefix + str(self.value)
def _eq(self, other) -> bool:
"""Compare two nodes for equality."""
newcontent = list(content)
for i, item in enumerate(newcontent):
assert isinstance(item, BasePattern), (i, item)
- if isinstance(item, WildcardPattern):
- self.wildcards = True
+ # I don't even think this code is used anywhere, but it does cause
+ # unreachable errors from mypy. This function's signature does look
+ # odd though *shrug*.
+ if isinstance(item, WildcardPattern): # type: ignore[unreachable]
+ self.wildcards = True # type: ignore[unreachable]
self.type = type
self.content = newcontent
self.name = name
r.update(r0)
r.update(r1)
yield c0 + c1, r
-
-
-_Convert = Callable[[Grammar, RawNode], Any]
runner = BlackRunner()
if ignore_config:
args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
- result = runner.invoke(black.main, args)
+ result = runner.invoke(black.main, args, catch_exceptions=False)
assert result.stdout_bytes is not None
assert result.stderr_bytes is not None
msg = (
)
self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
+ @pytest.mark.incompatible_with_mypyc
def test_debug_visitor(self) -> None:
source, _ = read_data("debug_visitor.py")
expected, _ = read_data("debug_visitor.out")
self.assertEqual(len(n.children), 1)
self.assertEqual(n.children[0].type, black.token.ENDMARKER)
+ @pytest.mark.incompatible_with_mypyc
@unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
def test_assertFormatEqual(self) -> None:
out_lines = []
actual = result.output
self.assertFormatEqual(actual, expected)
+ @pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin(self) -> None:
with patch(
"black.format_stdin_to_stdout",
fsts.assert_called_once()
report.done.assert_called_with(path, black.Changed.YES)
+ @pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin_filename(self) -> None:
with patch(
"black.format_stdin_to_stdout",
# __BLACK_STDIN_FILENAME__ should have been stripped
report.done.assert_called_with(expected, black.Changed.YES)
+ @pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin_filename_pyi(self) -> None:
with patch(
"black.format_stdin_to_stdout",
# __BLACK_STDIN_FILENAME__ should have been stripped
report.done.assert_called_with(expected, black.Changed.YES)
+ @pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
with patch(
"black.format_stdin_to_stdout",
# __BLACK_STDIN_FILENAME__ should have been stripped
report.done.assert_called_with(expected, black.Changed.YES)
+ @pytest.mark.incompatible_with_mypyc
def test_reformat_one_with_stdin_and_existing_path(self) -> None:
with patch(
"black.format_stdin_to_stdout",
self.assertEqual(config["exclude"], r"\.pyi?$")
self.assertEqual(config["include"], r"\.py?$")
+ @pytest.mark.incompatible_with_mypyc
def test_find_project_root(self) -> None:
with TemporaryDirectory() as workspace:
root = Path(workspace)
assert output == result_diff, "The output did not match the expected value."
assert result.exit_code == 0, "The exit code is incorrect."
+ @pytest.mark.incompatible_with_mypyc
def test_code_option_safe(self) -> None:
"""Test that the code option throws an error when the sanity checks fail."""
# Patch black.assert_equivalent to ensure the sanity checks fail
self.compare_results(result, formatted, 0)
+ @pytest.mark.incompatible_with_mypyc
def test_code_option_config(self) -> None:
"""
Test that the code option finds the pyproject.toml in the current directory.
call_args[0].lower() == str(pyproject_path).lower()
), "Incorrect config loaded."
+ @pytest.mark.incompatible_with_mypyc
def test_code_option_parent_config(self) -> None:
"""
Test that the code option finds the pyproject.toml in the parent directory.
src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
)
+ @pytest.mark.incompatible_with_mypyc
def test_symlink_out_of_root_directory(self) -> None:
path = MagicMock()
root = THIS_DIR.resolve()
}, non_python2_case
-with open(black.__file__, "r", encoding="utf-8") as _bf:
- black_source_lines = _bf.readlines()
+try:
+ with open(black.__file__, "r", encoding="utf-8") as _bf:
+ black_source_lines = _bf.readlines()
+except UnicodeDecodeError:
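+    # A compiled black's __file__ points at a binary extension module, which
+    # can't be decoded as UTF-8 text; only treat that as fatal for a
+    # pure-Python installation.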
+ if not black.COMPILED:
+ raise
def tracefunc(