X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/b396f137618e4eb7c73b49033530383c45b160f3..729f2d8cafd1b8e44d7c0a6bd841453ffac01c8e:/black.py diff --git a/black.py b/black.py index 680b1f4..31859d1 100644 --- a/black.py +++ b/black.py @@ -1,6 +1,7 @@ +import ast import asyncio -from asyncio.base_events import BaseEventLoop from concurrent.futures import Executor, ProcessPoolExecutor +from contextlib import contextmanager from datetime import datetime from enum import Enum from functools import lru_cache, partial, wraps @@ -11,11 +12,12 @@ from multiprocessing import Manager, freeze_support import os from pathlib import Path import pickle -import re +import regex as re import signal import sys import tempfile import tokenize +import traceback from typing import ( Any, Callable, @@ -35,11 +37,15 @@ from typing import ( Union, cast, ) +from typing_extensions import Final +from mypy_extensions import mypyc_attr from appdirs import user_cache_dir -from attr import dataclass, evolve, Factory +from dataclasses import dataclass, field, replace import click import toml +from typed_ast import ast3, ast27 +from pathspec import PathSpec # lib2to3 fork from blib2to3.pytree import Node, Leaf, type_repr @@ -48,12 +54,10 @@ from blib2to3.pgen2 import driver, token from blib2to3.pgen2.grammar import Grammar from blib2to3.pgen2.parse import ParseError +from _black_version import version as __version__ -__version__ = "19.3b0" DEFAULT_LINE_LENGTH = 88 -DEFAULT_EXCLUDES = ( - r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/" -) +DEFAULT_EXCLUDES = r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950 DEFAULT_INCLUDES = r"\.pyi?$" CACHE_DIR = Path(user_cache_dir("black", version=__version__)) @@ -68,7 +72,7 @@ LeafID = int Priority = int Index = int LN = Union[Leaf, Node] -SplitFunc = Callable[["Line", bool], Iterator["Line"]] +SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]] Timestamp = float FileSize = int CacheInfo = Tuple[Timestamp, FileSize] @@ -133,38 +137,57 @@ class Feature(Enum): UNICODE_LITERALS = 1 F_STRINGS = 2 NUMERIC_UNDERSCORES = 3 - TRAILING_COMMA = 4 + TRAILING_COMMA_IN_CALL = 4 + TRAILING_COMMA_IN_DEF = 5 + # The following two feature-flags are mutually exclusive, and exactly one should be + # set for every version of python. + ASYNC_IDENTIFIERS = 6 + ASYNC_KEYWORDS = 7 + ASSIGNMENT_EXPRESSIONS = 8 + POS_ONLY_ARGUMENTS = 9 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { - TargetVersion.PY27: set(), - TargetVersion.PY33: {Feature.UNICODE_LITERALS}, - TargetVersion.PY34: {Feature.UNICODE_LITERALS}, - TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA}, + TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS}, + TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS}, + TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS}, + TargetVersion.PY35: { + Feature.UNICODE_LITERALS, + Feature.TRAILING_COMMA_IN_CALL, + Feature.ASYNC_IDENTIFIERS, + }, TargetVersion.PY36: { Feature.UNICODE_LITERALS, Feature.F_STRINGS, Feature.NUMERIC_UNDERSCORES, - Feature.TRAILING_COMMA, + Feature.TRAILING_COMMA_IN_CALL, + Feature.TRAILING_COMMA_IN_DEF, + Feature.ASYNC_IDENTIFIERS, }, TargetVersion.PY37: { Feature.UNICODE_LITERALS, Feature.F_STRINGS, Feature.NUMERIC_UNDERSCORES, - Feature.TRAILING_COMMA, + Feature.TRAILING_COMMA_IN_CALL, + Feature.TRAILING_COMMA_IN_DEF, + Feature.ASYNC_KEYWORDS, }, TargetVersion.PY38: { Feature.UNICODE_LITERALS, Feature.F_STRINGS, Feature.NUMERIC_UNDERSCORES, - Feature.TRAILING_COMMA, + Feature.TRAILING_COMMA_IN_CALL, + Feature.TRAILING_COMMA_IN_DEF, + Feature.ASYNC_KEYWORDS, + Feature.ASSIGNMENT_EXPRESSIONS, + Feature.POS_ONLY_ARGUMENTS, }, } @dataclass -class FileMode: - target_versions: Set[TargetVersion] = Factory(set) +class Mode: + target_versions: Set[TargetVersion] = field(default_factory=set) line_length: int = DEFAULT_LINE_LENGTH string_normalization: bool = True is_pyi: bool = False @@ -186,10 +209,31 @@ class FileMode: return ".".join(parts) +# Legacy name, left for integrations. +FileMode = Mode + + def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool: return all(feature in VERSION_TO_FEATURES[version] for version in target_versions) +def find_pyproject_toml(path_search_start: str) -> Optional[str]: + """Find the absolute filepath to a pyproject.toml if it exists""" + path_project_root = find_project_root(path_search_start) + path_pyproject_toml = path_project_root / "pyproject.toml" + return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None + + +def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: + """Parse a pyproject toml file, pulling out relevant parts for Black + + If parsing fails, will raise a toml.TomlDecodeError + """ + pyproject_toml = toml.load(path_config) + config = pyproject_toml.get("tool", {}).get("black", {}) + return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} + + def read_pyproject_toml( ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None] ) -> Optional[str]: @@ -200,16 +244,12 @@ def read_pyproject_toml( """ assert not isinstance(value, (int, bool)), "Invalid parameter type passed" if not value: - root = find_project_root(ctx.params.get("src", ())) - path = root / "pyproject.toml" - if path.is_file(): - value = str(path) - else: + value = find_pyproject_toml(ctx.params.get("src", ())) + if value is None: return None try: - pyproject_toml = toml.load(value) - config = pyproject_toml.get("tool", {}).get("black", {}) + config = parse_pyproject_toml(value) except (toml.TomlDecodeError, OSError) as e: raise click.FileError( filename=value, hint=f"Error reading configuration file: {e}" @@ -220,13 +260,23 @@ def read_pyproject_toml( if ctx.default_map is None: ctx.default_map = {} - ctx.default_map.update( # type: ignore # bad types in .pyi - {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} - ) + ctx.default_map.update(config) # type: ignore # bad types in .pyi return value +def target_version_option_callback( + c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...] +) -> List[TargetVersion]: + """Compute the target versions from a --target-version flag. + + This is its own function because mypy couldn't infer the type correctly + when it was a lambda, causing mypyc trouble. + """ + return [TargetVersion[val.upper()] for val in v] + + @click.command(context_settings=dict(help_option_names=["-h", "--help"])) +@click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( "-l", "--line-length", @@ -239,7 +289,7 @@ def read_pyproject_toml( "-t", "--target-version", type=click.Choice([v.name.lower() for v in TargetVersion]), - callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v], + callback=target_version_option_callback, multiple=True, help=( "Python versions that should be supported by Black's output. [default: " @@ -319,7 +369,7 @@ def read_pyproject_toml( "--quiet", is_flag=True, help=( - "Don't emit non-error messages to stderr. Errors are still emitted, " + "Don't emit non-error messages to stderr. Errors are still emitted; " "silence those with 2>/dev/null." ), ) @@ -353,6 +403,7 @@ def read_pyproject_toml( @click.pass_context def main( ctx: click.Context, + code: Optional[str], line_length: int, target_version: List[TargetVersion], check: bool, @@ -365,7 +416,7 @@ def main( verbose: bool, include: str, exclude: str, - src: Tuple[str], + src: Tuple[str, ...], config: Optional[str], ) -> None: """The uncompromising code formatter.""" @@ -385,7 +436,7 @@ def main( else: # We'll autodetect later. versions = set() - mode = FileMode( + mode = Mode( target_versions=versions, line_length=line_length, is_pyi=pyi, @@ -393,6 +444,9 @@ def main( ) if config and verbose: out(f"Using configuration from {config}.", bold=False, fg="blue") + if code is not None: + print(format_str(code, mode=mode)) + ctx.exit(0) try: include_regex = re_compile_maybe_verbose(include) except re.error: @@ -403,14 +457,17 @@ def main( except re.error: err(f"Invalid regular expression for exclude given: {exclude!r}") ctx.exit(2) - report = Report(check=check, quiet=quiet, verbose=verbose) + report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose) root = find_project_root(src) sources: Set[Path] = set() + path_empty(src, quiet, verbose, ctx) for s in src: p = Path(s) if p.is_dir(): sources.update( - gen_python_files_in_dir(p, root, include_regex, exclude_regex, report) + gen_python_files_in_dir( + p, root, include_regex, exclude_regex, report, get_gitignore(root) + ) ) elif p.is_file() or s == "-": # if a file was explicitly given, we don't care about its extension @@ -419,7 +476,7 @@ def main( err(f"invalid path: {s}") if len(sources) == 0: if verbose or not quiet: - out("No paths given. Nothing to do 😴") + out("No Python files are present to be formatted. Nothing to do 😴") ctx.exit(0) if len(sources) == 1: @@ -431,36 +488,34 @@ def main( report=report, ) else: - loop = asyncio.get_event_loop() - executor = ProcessPoolExecutor(max_workers=os.cpu_count()) - try: - loop.run_until_complete( - schedule_formatting( - sources=sources, - fast=fast, - write_back=write_back, - mode=mode, - report=report, - loop=loop, - executor=executor, - ) - ) - finally: - shutdown(loop) + reformat_many( + sources=sources, fast=fast, write_back=write_back, mode=mode, report=report + ) + if verbose or not quiet: - bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨" - out(f"All done! {bang}") + out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨") click.secho(str(report), err=True) ctx.exit(report.return_code) +def path_empty( + src: Tuple[str, ...], quiet: bool, verbose: bool, ctx: click.Context +) -> None: + """ + Exit if there is no `src` provided for formatting + """ + if not src: + if verbose or not quiet: + out("No Path provided. Nothing to do 😴") + ctx.exit(0) + + def reformat_one( - src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report" + src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report" ) -> None: """Reformat a single file under `src` without spawning child processes. - If `quiet` is True, non-error messages are not output. `line_length`, - `write_back`, `fast` and `pyi` options are passed to + `fast`, `write_back`, and `mode` options are passed to :func:`format_file_in_place` or :func:`format_stdin_to_stdout`. """ try: @@ -488,20 +543,47 @@ def reformat_one( report.failed(src, str(exc)) +def reformat_many( + sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report" +) -> None: + """Reformat multiple files using a ProcessPoolExecutor.""" + loop = asyncio.get_event_loop() + worker_count = os.cpu_count() + if sys.platform == "win32": + # Work around https://bugs.python.org/issue26903 + worker_count = min(worker_count, 61) + executor = ProcessPoolExecutor(max_workers=worker_count) + try: + loop.run_until_complete( + schedule_formatting( + sources=sources, + fast=fast, + write_back=write_back, + mode=mode, + report=report, + loop=loop, + executor=executor, + ) + ) + finally: + shutdown(loop) + executor.shutdown() + + async def schedule_formatting( sources: Set[Path], fast: bool, write_back: WriteBack, - mode: FileMode, + mode: Mode, report: "Report", - loop: BaseEventLoop, + loop: asyncio.AbstractEventLoop, executor: Executor, ) -> None: """Run formatting of `sources` in parallel using the provided `executor`. (Use ProcessPoolExecutors for actual parallelism.) - `line_length`, `write_back`, `fast`, and `pyi` options are passed to + `write_back`, `fast`, and `mode` options are passed to :func:`format_file_in_place`. """ cache: Cache = {} @@ -522,12 +604,14 @@ async def schedule_formatting( manager = Manager() lock = manager.Lock() tasks = { - loop.run_in_executor( - executor, format_file_in_place, src, fast, mode, write_back, lock + asyncio.ensure_future( + loop.run_in_executor( + executor, format_file_in_place, src, fast, mode, write_back, lock + ) ): src for src in sorted(sources) } - pending: Iterable[asyncio.Task] = tasks.keys() + pending: Iterable["asyncio.Future[bool]"] = tasks.keys() try: loop.add_signal_handler(signal.SIGINT, cancel, pending) loop.add_signal_handler(signal.SIGTERM, cancel, pending) @@ -560,7 +644,7 @@ async def schedule_formatting( def format_file_in_place( src: Path, fast: bool, - mode: FileMode, + mode: Mode, write_back: WriteBack = WriteBack.NO, lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy ) -> bool: @@ -568,10 +652,10 @@ def format_file_in_place( If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted code to the file. - `line_length` and `fast` options are passed to :func:`format_file_contents`. + `mode` and `fast` options are passed to :func:`format_file_contents`. """ if src.suffix == ".pyi": - mode = evolve(mode, is_pyi=True) + mode = replace(mode, is_pyi=True) then = datetime.utcfromtimestamp(src.stat().st_mtime) with open(src, "rb") as buf: @@ -581,17 +665,16 @@ def format_file_in_place( except NothingChanged: return False - if write_back == write_back.YES: + if write_back == WriteBack.YES: with open(src, "w", encoding=encoding, newline=newline) as f: f.write(dst_contents) - elif write_back == write_back.DIFF: + elif write_back == WriteBack.DIFF: now = datetime.utcnow() src_name = f"{src}\t{then} +0000" dst_name = f"{src}\t{now} +0000" diff_contents = diff(src_contents, dst_contents, src_name, dst_name) - if lock: - lock.acquire() - try: + + with lock or nullcontext(): f = io.TextIOWrapper( sys.stdout.buffer, encoding=encoding, @@ -600,14 +683,12 @@ def format_file_in_place( ) f.write(diff_contents) f.detach() - finally: - if lock: - lock.release() + return True def format_stdin_to_stdout( - fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode + fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode ) -> bool: """Format file on stdin. Return True if changed. @@ -639,14 +720,12 @@ def format_stdin_to_stdout( f.detach() -def format_file_contents( - src_contents: str, *, fast: bool, mode: FileMode -) -> FileContent: +def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: """Reformat contents a file and return new contents. If `fast` is False, additionally confirm that the reformatted code is valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. - `line_length` is passed to :func:`format_str`. + `mode` is passed to :func:`format_str`. """ if src_contents.strip() == "": raise NothingChanged @@ -661,13 +740,37 @@ def format_file_contents( return dst_contents -def format_str(src_contents: str, *, mode: FileMode) -> FileContent: +def format_str(src_contents: str, *, mode: Mode) -> FileContent: """Reformat a string and return new contents. - `line_length` determines how many characters per line are allowed. + `mode` determines formatting options, such as how many characters per line are + allowed. Example: + + >>> import black + >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode())) + def f(arg: str = "") -> None: + ... + + A more complex example: + >>> print( + ... black.format_str( + ... "def f(arg:str='')->None: hey", + ... mode=black.Mode( + ... target_versions={black.TargetVersion.PY36}, + ... line_length=10, + ... string_normalization=False, + ... is_pyi=False, + ... ), + ... ), + ... ) + def f( + arg: str = '', + ) -> None: + hey + """ src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) - dst_contents = "" + dst_contents = [] future_imports = get_future_imports(src_node) if mode.target_versions: versions = mode.target_versions @@ -683,19 +786,20 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent: elt = EmptyLineTracker(is_pyi=mode.is_pyi) empty_line = Line() after = 0 + split_line_features = { + feature + for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF} + if supports_feature(versions, feature) + } for current_line in lines.visit(src_node): - for _ in range(after): - dst_contents += str(empty_line) + dst_contents.append(str(empty_line) * after) before, after = elt.maybe_empty_lines(current_line) - for _ in range(before): - dst_contents += str(empty_line) + dst_contents.append(str(empty_line) * before) for line in split_line( - current_line, - line_length=mode.line_length, - supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA), + current_line, line_length=mode.line_length, features=split_line_features ): - dst_contents += str(line) - return dst_contents + dst_contents.append(str(line)) + return "".join(dst_contents) def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]: @@ -715,24 +819,43 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]: return tiow.read(), encoding, newline -GRAMMARS = [ - pygram.python_grammar_no_print_statement_no_exec_statement, - pygram.python_grammar_no_print_statement, - pygram.python_grammar, -] - - def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]: if not target_versions: - return GRAMMARS - elif all(not version.is_python2() for version in target_versions): - # Python 3-compatible code, so don't try Python 2 grammar + # No target_version specified, so try all grammars. return [ + # Python 3.7+ + pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords, + # Python 3.0-3.6 pygram.python_grammar_no_print_statement_no_exec_statement, + # Python 2.7 with future print_function import pygram.python_grammar_no_print_statement, + # Python 2.7 + pygram.python_grammar, + ] + + if all(version.is_python2() for version in target_versions): + # Python 2-only code, so try Python 2 grammars. + return [ + # Python 2.7 with future print_function import + pygram.python_grammar_no_print_statement, + # Python 2.7 + pygram.python_grammar, ] - else: - return [pygram.python_grammar_no_print_statement, pygram.python_grammar] + + # Python 3-compatible code, so only try Python 3 grammar. + grammars = [] + # If we have to parse both, try to parse async as a keyword first + if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS): + # Python 3.7+ + grammars.append( + pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords + ) + if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS): + # Python 3.0-3.6 + grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement) + # At least one of the above branches must have been taken, because every Python + # version has exactly one of the two 'ASYNC_*' flags + return grammars def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node: @@ -787,8 +910,16 @@ class Visitor(Generic[T]): if node.type < 256: name = token.tok_name[node.type] else: - name = type_repr(node.type) - yield from getattr(self, f"visit_{name}", self.visit_default)(node) + name = str(type_repr(node.type)) + # We explicitly branch on whether a visitor exists (instead of + # using self.visit_default as the default arg to getattr) in order + # to save needing to create a bound method object and so mypyc can + # generate a native call to visit_default. + visitf = getattr(self, f"visit_{name}", None) + if visitf: + yield from visitf(node) + else: + yield from self.visit_default(node) def visit_default(self, node: LN) -> Iterator[T]: """Default `visit_*()` implementation. Recurses to children of `node`.""" @@ -833,8 +964,8 @@ class DebugVisitor(Visitor[T]): list(v.visit(code)) -WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE} -STATEMENT = { +WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE} +STATEMENT: Final = { syms.if_stmt, syms.while_stmt, syms.for_stmt, @@ -844,10 +975,10 @@ STATEMENT = { syms.funcdef, syms.classdef, } -STANDALONE_COMMENT = 153 +STANDALONE_COMMENT: Final = 153 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT" -LOGIC_OPERATORS = {"and", "or"} -COMPARATORS = { +LOGIC_OPERATORS: Final = {"and", "or"} +COMPARATORS: Final = { token.LESS, token.GREATER, token.EQEQUAL, @@ -855,7 +986,7 @@ COMPARATORS = { token.LESSEQUAL, token.GREATEREQUAL, } -MATH_OPERATORS = { +MATH_OPERATORS: Final = { token.VBAR, token.CIRCUMFLEX, token.AMPER, @@ -871,22 +1002,23 @@ MATH_OPERATORS = { token.TILDE, token.DOUBLESTAR, } -STARS = {token.STAR, token.DOUBLESTAR} -VARARGS_PARENTS = { +STARS: Final = {token.STAR, token.DOUBLESTAR} +VARARGS_SPECIALS: Final = STARS | {token.SLASH} +VARARGS_PARENTS: Final = { syms.arglist, syms.argument, # double star in arglist syms.trailer, # single argument to call syms.typedargslist, syms.varargslist, # lambdas } -UNPACKING_PARENTS = { +UNPACKING_PARENTS: Final = { syms.atom, # single element of a list or set literal syms.dictsetmaker, syms.listmaker, syms.testlist_gexp, syms.testlist_star_expr, } -TEST_DESCENDANTS = { +TEST_DESCENDANTS: Final = { syms.test, syms.lambdef, syms.or_test, @@ -903,7 +1035,7 @@ TEST_DESCENDANTS = { syms.term, syms.power, } -ASSIGNMENTS = { +ASSIGNMENTS: Final = { "=", "+=", "-=", @@ -919,13 +1051,13 @@ ASSIGNMENTS = { "**=", "//=", } -COMPREHENSION_PRIORITY = 20 -COMMA_PRIORITY = 18 -TERNARY_PRIORITY = 16 -LOGIC_PRIORITY = 14 -STRING_PRIORITY = 12 -COMPARATOR_PRIORITY = 10 -MATH_PRIORITIES = { +COMPREHENSION_PRIORITY: Final = 20 +COMMA_PRIORITY: Final = 18 +TERNARY_PRIORITY: Final = 16 +LOGIC_PRIORITY: Final = 14 +STRING_PRIORITY: Final = 12 +COMPARATOR_PRIORITY: Final = 10 +MATH_PRIORITIES: Final = { token.VBAR: 9, token.CIRCUMFLEX: 8, token.AMPER: 7, @@ -941,7 +1073,7 @@ MATH_PRIORITIES = { token.TILDE: 3, token.DOUBLESTAR: 2, } -DOT_PRIORITY = 1 +DOT_PRIORITY: Final = 1 @dataclass @@ -949,11 +1081,11 @@ class BracketTracker: """Keeps track of brackets on a line.""" depth: int = 0 - bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict) - delimiters: Dict[LeafID, Priority] = Factory(dict) + bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict) + delimiters: Dict[LeafID, Priority] = field(default_factory=dict) previous: Optional[Leaf] = None - _for_loop_depths: List[int] = Factory(list) - _lambda_argument_depths: List[int] = Factory(list) + _for_loop_depths: List[int] = field(default_factory=list) + _lambda_argument_depths: List[int] = field(default_factory=list) def mark(self, leaf: Leaf) -> None: """Mark `leaf` with bracket-related metadata. Keep track of delimiters. @@ -999,7 +1131,7 @@ class BracketTracker: """Return True if there is an yet unmatched open bracket on the line.""" return bool(self.bracket_match) - def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int: + def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority: """Return the highest priority of a delimiter found on the line. Values are consistent with what `is_split_*_delimiter()` return. @@ -1007,7 +1139,7 @@ class BracketTracker: """ return max(v for k, v in self.delimiters.items() if k not in exclude) - def delimiter_count_with_priority(self, priority: int = 0) -> int: + def delimiter_count_with_priority(self, priority: Priority = 0) -> int: """Return the number of delimiters with the given `priority`. If no `priority` is passed, defaults to max priority on the line. @@ -1081,9 +1213,10 @@ class Line: """Holds leaves and comments. Can be printed with `str(line)`.""" depth: int = 0 - leaves: List[Leaf] = Factory(list) - comments: Dict[LeafID, List[Leaf]] = Factory(dict) # keys ordered like `leaves` - bracket_tracker: BracketTracker = Factory(BracketTracker) + leaves: List[Leaf] = field(default_factory=list) + # keys ordered like `leaves` + comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict) + bracket_tracker: BracketTracker = field(default_factory=BracketTracker) inside_brackets: bool = False should_explode: bool = False @@ -1163,6 +1296,69 @@ class Line: Leaf(token.DOT, ".") for _ in range(3) ] + @property + def is_collection_with_optional_trailing_comma(self) -> bool: + """Is this line a collection literal with a trailing comma that's optional? + + Note that the trailing comma in a 1-tuple is not optional. + """ + if not self.leaves or len(self.leaves) < 4: + return False + + # Look for and address a trailing colon. + if self.leaves[-1].type == token.COLON: + closer = self.leaves[-2] + close_index = -2 + else: + closer = self.leaves[-1] + close_index = -1 + if closer.type not in CLOSING_BRACKETS or self.inside_brackets: + return False + + if closer.type == token.RPAR: + # Tuples require an extra check, because if there's only + # one element in the tuple removing the comma unmakes the + # tuple. + # + # We also check for parens before looking for the trailing + # comma because in some cases (eg assigning a dict + # literal) the literal gets wrapped in temporary parens + # during parsing. This case is covered by the + # collections.py test data. + opener = closer.opening_bracket + for _open_index, leaf in enumerate(self.leaves): + if leaf is opener: + break + + else: + # Couldn't find the matching opening paren, play it safe. + return False + + commas = 0 + comma_depth = self.leaves[close_index - 1].bracket_depth + for leaf in self.leaves[_open_index + 1 : close_index]: + if leaf.bracket_depth == comma_depth and leaf.type == token.COMMA: + commas += 1 + if commas > 1: + # We haven't looked yet for the trailing comma because + # we might also have caught noop parens. + return self.leaves[close_index - 1].type == token.COMMA + + elif commas == 1: + return False # it's either a one-tuple or didn't have a trailing comma + + if self.leaves[close_index - 1].type in CLOSING_BRACKETS: + close_index -= 1 + closer = self.leaves[close_index] + if closer.type == token.RPAR: + # TODO: this is a gut feeling. Will we ever see this? + return False + + if self.leaves[close_index - 1].type != token.COMMA: + return False + + return True + @property def is_def(self) -> bool: """Is this a function definition? (Also returns True for async defs.)""" @@ -1210,95 +1406,117 @@ class Line: def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool: """If so, needs to be split before emitting.""" for leaf in self.leaves: - if leaf.type == STANDALONE_COMMENT: - if leaf.bracket_depth <= depth_limit: - return True + if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit: + return True + return False - def contains_inner_type_comments(self) -> bool: + def contains_uncollapsable_type_comments(self) -> bool: ignored_ids = set() try: last_leaf = self.leaves[-1] ignored_ids.add(id(last_leaf)) - if last_leaf.type == token.COMMA: - # When trailing commas are inserted by Black for consistency, comments - # after the previous last element are not moved (they don't have to, - # rendering will still be correct). So we ignore trailing commas. + if last_leaf.type == token.COMMA or ( + last_leaf.type == token.RPAR and not last_leaf.value + ): + # When trailing commas or optional parens are inserted by Black for + # consistency, comments after the previous last element are not moved + # (they don't have to, rendering will still be correct). So we ignore + # trailing commas and invisible. last_leaf = self.leaves[-2] ignored_ids.add(id(last_leaf)) except IndexError: return False + # A type comment is uncollapsable if it is attached to a leaf + # that isn't at the end of the line (since that could cause it + # to get associated to a different argument) or if there are + # comments before it (since that could cause it to get hidden + # behind a comment. + comment_seen = False for leaf_id, comments in self.comments.items(): - if leaf_id in ignored_ids: - continue - for comment in comments: if is_type_comment(comment): - return True + if comment_seen or ( + not is_type_comment(comment, " ignore") + and leaf_id not in ignored_ids + ): + return True + + comment_seen = True return False - def contains_multiline_strings(self) -> bool: - for leaf in self.leaves: - if is_multiline_string(leaf): - return True + def contains_unsplittable_type_ignore(self) -> bool: + if not self.leaves: + return False + + # If a 'type: ignore' is attached to the end of a line, we + # can't split the line, because we can't know which of the + # subexpressions the ignore was meant to apply to. + # + # We only want this to apply to actual physical lines from the + # original source, though: we don't want the presence of a + # 'type: ignore' at the end of a multiline expression to + # justify pushing it all onto one line. Thus we + # (unfortunately) need to check the actual source lines and + # only report an unsplittable 'type: ignore' if this line was + # one line in the original code. + + # Grab the first and last line numbers, skipping generated leaves + first_line = next((l.lineno for l in self.leaves if l.lineno != 0), 0) + last_line = next((l.lineno for l in reversed(self.leaves) if l.lineno != 0), 0) + + if first_line == last_line: + # We look at the last two leaves since a comma or an + # invisible paren could have been added at the end of the + # line. + for node in self.leaves[-2:]: + for comment in self.comments.get(id(node), []): + if is_type_comment(comment, " ignore"): + return True return False + def contains_multiline_strings(self) -> bool: + return any(is_multiline_string(leaf) for leaf in self.leaves) + def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: """Remove trailing comma if there is one and it's safe.""" + if not (self.leaves and self.leaves[-1].type == token.COMMA): + return False + + # We remove trailing commas only in the case of importing a + # single name from a module. if not ( self.leaves + and self.is_import + and len(self.leaves) > 4 and self.leaves[-1].type == token.COMMA and closing.type in CLOSING_BRACKETS + and self.leaves[-4].type == token.NAME + and ( + # regular `from foo import bar,` + self.leaves[-4].value == "import" + # `from foo import (bar as baz,) + or ( + len(self.leaves) > 6 + and self.leaves[-6].value == "import" + and self.leaves[-3].value == "as" + ) + # `from foo import bar as baz,` + or ( + len(self.leaves) > 5 + and self.leaves[-5].value == "import" + and self.leaves[-3].value == "as" + ) + ) + and closing.type == token.RPAR ): return False - if closing.type == token.RBRACE: - self.remove_trailing_comma() - return True - - if closing.type == token.RSQB: - comma = self.leaves[-1] - if comma.parent and comma.parent.type == syms.listmaker: - self.remove_trailing_comma() - return True - - # For parens let's check if it's safe to remove the comma. - # Imports are always safe. - if self.is_import: - self.remove_trailing_comma() - return True - - # Otherwise, if the trailing one is the only one, we might mistakenly - # change a tuple into a different type by removing the comma. - depth = closing.bracket_depth + 1 - commas = 0 - opening = closing.opening_bracket - for _opening_index, leaf in enumerate(self.leaves): - if leaf is opening: - break - - else: - return False - - for leaf in self.leaves[_opening_index + 1 :]: - if leaf is closing: - break - - bracket_depth = leaf.bracket_depth - if bracket_depth == depth and leaf.type == token.COMMA: - commas += 1 - if leaf.parent and leaf.parent.type == syms.arglist: - commas += 1 - break - - if commas > 1: - self.remove_trailing_comma() - return True - - return False + self.remove_trailing_comma() + return True def append_comment(self, comment: Leaf) -> bool: """Add an inline or standalone comment to the line.""" @@ -1317,7 +1535,24 @@ class Line: comment.prefix = "" return False - self.comments.setdefault(id(self.leaves[-1]), []).append(comment) + last_leaf = self.leaves[-1] + if ( + last_leaf.type == token.RPAR + and not last_leaf.value + and last_leaf.parent + and len(list(last_leaf.parent.leaves())) <= 3 + and not is_type_comment(comment) + ): + # Comments on an optional parens wrapping a single leaf should belong to + # the wrapped node except if it's a type comment. Pinning the comment like + # this avoids unstable formatting caused by comment migration. + if len(self.leaves) < 2: + comment.type = STANDALONE_COMMENT + comment.prefix = "" + return False + + last_leaf = self.leaves[-2] + self.comments.setdefault(id(last_leaf), []).append(comment) return True def comments_after(self, leaf: Leaf) -> List[Leaf]: @@ -1383,7 +1618,7 @@ class EmptyLineTracker: is_pyi: bool = False previous_line: Optional[Line] = None previous_after: int = 0 - previous_defs: List[int] = Factory(list) + previous_defs: List[int] = field(default_factory=list) def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: """Return the number of extra empty lines before and after the `current_line`. @@ -1392,7 +1627,13 @@ class EmptyLineTracker: lines (two on module-level). """ before, after = self._maybe_empty_lines(current_line) - before -= self.previous_after + before = ( + # Black should not insert empty lines at the beginning + # of the file + 0 + if self.previous_line is None + else before - self.previous_after + ) self.previous_after = after self.previous_line = current_line return before, after @@ -1491,7 +1732,7 @@ class LineGenerator(Visitor[Line]): is_pyi: bool = False normalize_strings: bool = True - current_line: Line = Factory(Line) + current_line: Line = field(default_factory=Line) remove_u_prefix: bool = False def line(self, indent: int = 0) -> Iterator[Line]: @@ -1540,13 +1781,13 @@ class LineGenerator(Visitor[Line]): self.current_line.append(node) yield from super().visit_default(node) - def visit_INDENT(self, node: Node) -> Iterator[Line]: + def visit_INDENT(self, node: Leaf) -> Iterator[Line]: """Increase indentation level, maybe yield a line.""" # In blib2to3 INDENT never holds comments. yield from self.line(+1) yield from self.visit_default(node) - def visit_DEDENT(self, node: Node) -> Iterator[Line]: + def visit_DEDENT(self, node: Leaf) -> Iterator[Line]: """Decrease indentation level, maybe yield a line.""" # The current line might still wait for trailing comments. At DEDENT time # there won't be any (they would be prefixes on the preceding NEWLINE). @@ -1639,7 +1880,24 @@ class LineGenerator(Visitor[Line]): yield from self.line() yield from self.visit_default(leaf) - def __attrs_post_init__(self) -> None: + def visit_factor(self, node: Node) -> Iterator[Line]: + """Force parentheses between a unary op and a binary power: + + -2 ** 8 -> -(2 ** 8) + """ + _operator, operand = node.children + if ( + operand.type == syms.power + and len(operand.children) == 3 + and operand.children[1].type == token.DOUBLESTAR + ): + lpar = Leaf(token.LPAR, "(") + rpar = Leaf(token.RPAR, ")") + index = operand.remove() or 0 + node.insert_child(index, Node(syms.atom, [lpar, operand, rpar])) + yield from self.visit_default(node) + + def __post_init__(self) -> None: """You are in a twisty little maze of passages.""" v = self.visit_stmt Ø: Set[str] = set() @@ -1729,7 +1987,7 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa: C901 # that, too. return prevp.prefix - elif prevp.type in STARS: + elif prevp.type in VARARGS_SPECIALS: if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS): return NO @@ -1819,7 +2077,7 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa: C901 if not prevp or prevp.type == token.LPAR: return NO - elif prev.type in {token.EQUAL} | STARS: + elif prev.type in {token.EQUAL} | VARARGS_SPECIALS: return NO elif p.type == syms.decorator: @@ -1953,7 +2211,7 @@ def container_of(leaf: Leaf) -> LN: return container -def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int: +def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority: """Return the priority of the `leaf` delimiter, given a line break after it. The delimiter priorities returned here are from those delimiters that would @@ -1967,7 +2225,7 @@ def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int return 0 -def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int: +def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority: """Return the priority of the `leaf` delimiter, given a line break before it. The delimiter priorities returned here are from those delimiters that would @@ -2117,15 +2375,21 @@ def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]: consumed = 0 nlines = 0 + ignored_lines = 0 for index, line in enumerate(prefix.split("\n")): consumed += len(line) + 1 # adding the length of the split '\n' line = line.lstrip() if not line: nlines += 1 if not line.startswith("#"): + # Escaped newlines outside of a comment are not really newlines at + # all. We treat a single-line comment following an escaped newline + # as a simple trailing comment. + if line.endswith("\\"): + ignored_lines += 1 continue - if index == 0 and not is_endmarker: + if index == ignored_lines and not is_endmarker: comment_type = token.COMMENT # simple trailing comment else: comment_type = STANDALONE_COMMENT @@ -2162,7 +2426,7 @@ def split_line( line: Line, line_length: int, inner: bool = False, - supports_trailing_commas: bool = False, + features: Collection[Feature] = (), ) -> Iterator[Line]: """Split a `line` into potentially many lines. @@ -2171,7 +2435,7 @@ def split_line( current `line`, possibly transitively. This means we can fallback to splitting by delimiters if the LHS/RHS don't yield any results. - If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature. + `features` are syntactical features that may be used in the output. """ if line.is_comment: yield line @@ -2180,9 +2444,13 @@ def split_line( line_str = str(line).strip("\n") if ( - not line.contains_inner_type_comments() + not line.contains_uncollapsable_type_comments() and not line.should_explode - and is_line_short_enough(line, line_length=line_length, line_str=line_str) + and not line.is_collection_with_optional_trailing_comma + and ( + is_line_short_enough(line, line_length=line_length, line_str=line_str) + or line.contains_unsplittable_type_ignore() + ) ): yield line return @@ -2192,13 +2460,9 @@ def split_line( split_funcs = [left_hand_split] else: - def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]: + def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]: for omit in generate_trailers_to_omit(line, line_length): - lines = list( - right_hand_split( - line, line_length, supports_trailing_commas, omit=omit - ) - ) + lines = list(right_hand_split(line, line_length, features, omit=omit)) if is_line_short_enough(lines[0], line_length=line_length): yield from lines return @@ -2206,7 +2470,9 @@ def split_line( # All splits failed, best effort split with no omits. # This mostly happens to multiline strings that are by definition # reported as not fitting a single line. - yield from right_hand_split(line, line_length, supports_trailing_commas) + # line_length=1 here was historically a bug that somehow became a feature. + # See #762 and #781 for the full story. + yield from right_hand_split(line, line_length=1, features=features) if line.inside_brackets: split_funcs = [delimiter_split, standalone_comment_split, rhs] @@ -2218,16 +2484,13 @@ def split_line( # split altogether. result: List[Line] = [] try: - for l in split_func(line, supports_trailing_commas): + for l in split_func(line, features): if str(l).strip("\n") == line_str: raise CannotSplit("Split function returned an unchanged result") result.extend( split_line( - l, - line_length=line_length, - inner=True, - supports_trailing_commas=supports_trailing_commas, + l, line_length=line_length, inner=True, features=features ) ) except CannotSplit: @@ -2241,9 +2504,7 @@ def split_line( yield line -def left_hand_split( - line: Line, supports_trailing_commas: bool = False -) -> Iterator[Line]: +def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]: """Split line into many lines, starting with the first matching bracket pair. Note: this usually looks weird, only use this for function definitions. @@ -2254,7 +2515,7 @@ def left_hand_split( body_leaves: List[Leaf] = [] head_leaves: List[Leaf] = [] current_leaves = head_leaves - matching_bracket = None + matching_bracket: Optional[Leaf] = None for leaf in line.leaves: if ( current_leaves is body_leaves @@ -2282,7 +2543,7 @@ def left_hand_split( def right_hand_split( line: Line, line_length: int, - supports_trailing_commas: bool = False, + features: Collection[Feature] = (), omit: Collection[LeafID] = (), ) -> Iterator[Line]: """Split line into many lines, starting with the last matching bracket pair. @@ -2297,8 +2558,8 @@ def right_hand_split( body_leaves: List[Leaf] = [] head_leaves: List[Leaf] = [] current_leaves = tail_leaves - opening_bracket = None - closing_bracket = None + opening_bracket: Optional[Leaf] = None + closing_bracket: Optional[Leaf] = None for leaf in reversed(line.leaves): if current_leaves is body_leaves: if leaf is opening_bracket: @@ -2341,12 +2602,7 @@ def right_hand_split( ): omit = {id(closing_bracket), *omit} try: - yield from right_hand_split( - line, - line_length, - supports_trailing_commas=supports_trailing_commas, - omit=omit, - ) + yield from right_hand_split(line, line_length, features=features, omit=omit) return except CannotSplit: @@ -2414,10 +2670,23 @@ def bracket_split_build_line( if leaves: # Since body is a new indent level, remove spurious leading whitespace. normalize_prefix(leaves[0], inside_brackets=True) - # Ensure a trailing comma when expected. - if original.is_import: - if leaves[-1].type != token.COMMA: - leaves.append(Leaf(token.COMMA, ",")) + # Ensure a trailing comma for imports and standalone function arguments, but + # be careful not to add one after any comments or within type annotations. + no_commas = ( + original.is_def + and opening_bracket.value == "(" + and not any(l.type == token.COMMA for l in leaves) + ) + + if original.is_import or no_commas: + for i in range(len(leaves) - 1, -1, -1): + if leaves[i].type == STANDALONE_COMMENT: + continue + + if leaves[i].type != token.COMMA: + leaves.insert(i + 1, Leaf(token.COMMA, ",")) + break + # Populate the line for leaf in leaves: result.append(leaf, preformatted=True) @@ -2435,10 +2704,8 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc: """ @wraps(split_func) - def split_wrapper( - line: Line, supports_trailing_commas: bool = False - ) -> Iterator[Line]: - for l in split_func(line, supports_trailing_commas): + def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]: + for l in split_func(line, features): normalize_prefix(l.leaves[0], inside_brackets=True) yield l @@ -2446,13 +2713,11 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc: @dont_increase_indentation -def delimiter_split( - line: Line, supports_trailing_commas: bool = False -) -> Iterator[Line]: +def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]: """Split according to delimiters of the highest priority. - If `supports_trailing_commas` is True, the split will add trailing commas - also in function signatures that contain `*` and `**`. + If the appropriate Features are given, the split will add trailing commas + also in function signatures and calls that contain `*` and `**`. """ try: last_leaf = line.leaves[-1] @@ -2491,10 +2756,16 @@ def delimiter_split( yield from append_to_line(comment_after) lowest_depth = min(lowest_depth, leaf.bracket_depth) - if leaf.bracket_depth == lowest_depth and is_vararg( - leaf, within=VARARGS_PARENTS - ): - trailing_comma_safe = trailing_comma_safe and supports_trailing_commas + if leaf.bracket_depth == lowest_depth: + if is_vararg(leaf, within={syms.typedargslist}): + trailing_comma_safe = ( + trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features + ) + elif is_vararg(leaf, within={syms.arglist, syms.argument}): + trailing_comma_safe = ( + trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features + ) + leaf_priority = bt.delimiters.get(id(leaf)) if leaf_priority == delimiter_priority: yield current_line @@ -2513,7 +2784,7 @@ def delimiter_split( @dont_increase_indentation def standalone_comment_split( - line: Line, supports_trailing_commas: bool = False + line: Line, features: Collection[Feature] = () ) -> Iterator[Line]: """Split standalone comments from the rest of the line.""" if not line.contains_standalone_comments(0): @@ -2556,12 +2827,12 @@ def is_import(leaf: Leaf) -> bool: ) -def is_type_comment(leaf: Leaf) -> bool: +def is_type_comment(leaf: Leaf, suffix: str = "") -> bool: """Return True if the given leaf is a special comment. Only returns true for type comments for now.""" t = leaf.type v = leaf.value - return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:") + return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:" + suffix) def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: @@ -2592,7 +2863,7 @@ def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None: match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL) assert match is not None, f"failed to match string {leaf.value!r}" orig_prefix = match.group(1) - new_prefix = orig_prefix.lower() + new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u") if remove_u_prefix: new_prefix = new_prefix.replace("u", "") leaf.value = f"{new_prefix}{match.group(2)}" @@ -2646,11 +2917,20 @@ def normalize_string_quotes(leaf: Leaf) -> None: new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body) new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body) if "f" in prefix.casefold(): - matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body) + matches = re.findall( + r""" + (?:[^{]|^)\{ # start of the string or a non-{ followed by a single { + ([^{].*?) # contents of the brackets except if begins with {{ + \}(?:[^}]|$) # A } followed by end of the string or a non-} + """, + new_body, + re.VERBOSE, + ) for m in matches: if "\\" in str(m): # Do not introduce backslashes in interpolated expressions return + if new_quote == '"""' and new_body[-1:] == '"': # edge case: new_body = new_body[:-1] + '\\"' @@ -2713,7 +2993,7 @@ def format_float_or_int_string(text: str) -> str: def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: """Make existing optional parentheses invisible or create new ones. - `parens_after` is a set of string leaf values immeditely after which parens + `parens_after` is a set of string leaf values immediately after which parens should be put. Standardizes on visible parentheses for single-element tuples, and keeps @@ -2723,22 +3003,25 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: if pc.value in FMT_OFF: # This `node` has a prefix with `# fmt: off`, don't mess with parens. return - check_lpar = False for index, child in enumerate(list(node.children)): + # Add parentheses around long tuple unpacking in assignments. + if ( + index == 0 + and isinstance(child, Node) + and child.type == syms.testlist_star_expr + ): + check_lpar = True + if check_lpar: + if is_walrus_assignment(child): + continue + if child.type == syms.atom: - if maybe_make_parens_invisible_in_atom(child): - lpar = Leaf(token.LPAR, "") - rpar = Leaf(token.RPAR, "") - index = child.remove() or 0 - node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + if maybe_make_parens_invisible_in_atom(child, parent=node): + wrap_in_parentheses(node, child, visible=False) elif is_one_tuple(child): - # wrap child in visible parentheses - lpar = Leaf(token.LPAR, "(") - rpar = Leaf(token.RPAR, ")") - child.remove() - node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + wrap_in_parentheses(node, child, visible=True) elif node.type == syms.import_from: # "import from" nodes store parentheses directly as part of # the statement @@ -2753,11 +3036,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: break elif not (isinstance(child, Leaf) and is_multiline_string(child)): - # wrap child in invisible parentheses - lpar = Leaf(token.LPAR, "") - rpar = Leaf(token.RPAR, "") - index = child.remove() or 0 - node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + wrap_in_parentheses(node, child, visible=False) check_lpar = isinstance(child, Leaf) and child.value in parens_after @@ -2801,7 +3080,7 @@ def convert_one_fmt_off_pair(node: Node) -> bool: # That happens when one of the `ignored_nodes` ended with a NEWLINE # leaf (possibly followed by a DEDENT). hidden_value = hidden_value[:-1] - first_idx = None + first_idx: Optional[int] = None for ignored in ignored_nodes: index = ignored.remove() if first_idx is None: @@ -2830,17 +3109,24 @@ def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]: """ container: Optional[LN] = container_of(leaf) while container is not None and container.type != token.ENDMARKER: + is_fmt_on = False for comment in list_comments(container.prefix, is_endmarker=False): if comment.value in FMT_ON: - return + is_fmt_on = True + elif comment.value in FMT_OFF: + is_fmt_on = False + if is_fmt_on: + return yield container container = container.next_sibling -def maybe_make_parens_invisible_in_atom(node: LN) -> bool: +def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool: """If it's safe, make the parens in the atom `node` invisible, recursively. + Additionally, remove repeated, adjacent invisible parens from the atom `node` + as they are redundant. Returns whether the node should itself be wrapped in invisible parentheses. @@ -2849,7 +3135,7 @@ def maybe_make_parens_invisible_in_atom(node: LN) -> bool: node.type != syms.atom or is_empty_tuple(node) or is_one_tuple(node) - or is_yield(node) + or (is_yield(node) and parent.type != syms.expr_stmt) or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY ): return False @@ -2857,16 +3143,40 @@ def maybe_make_parens_invisible_in_atom(node: LN) -> bool: first = node.children[0] last = node.children[-1] if first.type == token.LPAR and last.type == token.RPAR: + middle = node.children[1] # make parentheses invisible first.value = "" # type: ignore last.value = "" # type: ignore - if len(node.children) > 1: - maybe_make_parens_invisible_in_atom(node.children[1]) + maybe_make_parens_invisible_in_atom(middle, parent=parent) + + if is_atom_with_invisible_parens(middle): + # Strip the invisible parens from `middle` by replacing + # it with the child in-between the invisible parens + middle.replace(middle.children[1]) + return False return True +def is_atom_with_invisible_parens(node: LN) -> bool: + """Given a `LN`, determines whether it's an atom `node` with invisible + parens. Useful in dedupe-ing and normalizing parens. + """ + if isinstance(node, Leaf) or node.type != syms.atom: + return False + + first, last = node.children[0], node.children[-1] + return ( + isinstance(first, Leaf) + and first.type == token.LPAR + and first.value == "" + and isinstance(last, Leaf) + and last.type == token.RPAR + and last.value == "" + ) + + def is_empty_tuple(node: LN) -> bool: """Return True if `node` holds an empty tuple.""" return ( @@ -2877,18 +3187,43 @@ def is_empty_tuple(node: LN) -> bool: ) +def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]: + """Returns `wrapped` if `node` is of the shape ( wrapped ). + + Parenthesis can be optional. Returns None otherwise""" + if len(node.children) != 3: + return None + + lpar, wrapped, rpar = node.children + if not (lpar.type == token.LPAR and rpar.type == token.RPAR): + return None + + return wrapped + + +def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None: + """Wrap `child` in parentheses. + + This replaces `child` with an atom holding the parentheses and the old + child. That requires moving the prefix. + + If `visible` is False, the leaves will be valueless (and thus invisible). + """ + lpar = Leaf(token.LPAR, "(" if visible else "") + rpar = Leaf(token.RPAR, ")" if visible else "") + prefix = child.prefix + child.prefix = "" + index = child.remove() or 0 + new_child = Node(syms.atom, [lpar, child, rpar]) + new_child.prefix = prefix + parent.insert_child(index, new_child) + + def is_one_tuple(node: LN) -> bool: """Return True if `node` holds a tuple with one element, with or without parens.""" if node.type == syms.atom: - if len(node.children) != 3: - return False - - lpar, gexp, rpar = node.children - if not ( - lpar.type == token.LPAR - and gexp.type == syms.testlist_gexp - and rpar.type == token.RPAR - ): + gexp = unwrap_singleton_parenthesis(node) + if gexp is None or gexp.type != syms.testlist_gexp: return False return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA @@ -2900,6 +3235,12 @@ def is_one_tuple(node: LN) -> bool: ) +def is_walrus_assignment(node: LN) -> bool: + """Return True iff `node` is of the shape ( test := test )""" + inner = unwrap_singleton_parenthesis(node) + return inner is not None and inner.type == syms.namedexpr_test + + def is_yield(node: LN) -> bool: """Return True if `node` holds a `yield` or `yield from` expression.""" if node.type == syms.yield_expr: @@ -2929,7 +3270,7 @@ def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool: extended iterable unpacking (PEP 3132) and additional unpacking generalizations (PEP 448). """ - if leaf.type not in STARS or not leaf.parent: + if leaf.type not in VARARGS_SPECIALS or not leaf.parent: return False p = leaf.parent @@ -2979,7 +3320,7 @@ def is_stub_body(node: LN) -> bool: ) -def max_delimiter_priority_in_atom(node: LN) -> int: +def max_delimiter_priority_in_atom(node: LN) -> Priority: """Return maximum delimiter priority inside `node`. This is specific to atoms with contents contained in a pair of parentheses. @@ -3011,7 +3352,7 @@ def ensure_visible(leaf: Leaf) -> None: """Make sure parentheses are visible. They could be invisible as part of some statements (see - :func:`normalize_invible_parens` and :func:`visit_import_from`). + :func:`normalize_invisible_parens` and :func:`visit_import_from`). """ if leaf.type == token.LPAR: leaf.value = "(" @@ -3044,8 +3385,9 @@ def get_features_used(node: Node) -> Set[Feature]: Currently looking for: - f-strings; - - underscores in numeric literals; and - - trailing commas after * or ** in function signatures and calls. + - underscores in numeric literals; + - trailing commas after * or ** in function signatures and calls; + - positional only arguments in function signatures and lambdas; """ features: Set[Feature] = set() for n in node.pre_order(): @@ -3058,19 +3400,31 @@ def get_features_used(node: Node) -> Set[Feature]: if "_" in n.value: # type: ignore features.add(Feature.NUMERIC_UNDERSCORES) + elif n.type == token.SLASH: + if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}: + features.add(Feature.POS_ONLY_ARGUMENTS) + + elif n.type == token.COLONEQUAL: + features.add(Feature.ASSIGNMENT_EXPRESSIONS) + elif ( n.type in {syms.typedargslist, syms.arglist} and n.children and n.children[-1].type == token.COMMA ): + if n.type == syms.typedargslist: + feature = Feature.TRAILING_COMMA_IN_DEF + else: + feature = Feature.TRAILING_COMMA_IN_CALL + for ch in n.children: if ch.type in STARS: - features.add(Feature.TRAILING_COMMA) + features.add(feature) if ch.type == syms.argument: for argch in ch.children: if argch.type in STARS: - features.add(Feature.TRAILING_COMMA) + features.add(feature) return features @@ -3097,8 +3451,8 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf yield omit length = 4 * line.depth - opening_bracket = None - closing_bracket = None + opening_bracket: Optional[Leaf] = None + closing_bracket: Optional[Leaf] = None inner_brackets: Set[LeafID] = set() for index, leaf, leaf_length in enumerate_with_length(line, reversed=True): length += leaf_length @@ -3142,19 +3496,23 @@ def get_future_imports(node: Node) -> Set[str]: if isinstance(child, Leaf): if child.type == token.NAME: yield child.value + elif child.type == syms.import_as_name: orig_name = child.children[0] assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports" assert orig_name.type == token.NAME, "Invalid syntax parsing imports" yield orig_name.value + elif child.type == syms.import_as_names: yield from get_imports_from_children(child.children) + else: - assert False, "Invalid syntax parsing imports" + raise AssertionError("Invalid syntax parsing imports") for child in node.children: if child.type != syms.simple_stmt: break + first_child = child.children[0] if isinstance(first_child, Leaf): # Continue looking if we see a docstring; otherwise stop. @@ -3164,24 +3522,39 @@ def get_future_imports(node: Node) -> Set[str]: and child.children[1].type == token.NEWLINE ): continue - else: - break + + break + elif first_child.type == syms.import_from: module_name = first_child.children[1] if not isinstance(module_name, Leaf) or module_name.value != "__future__": break + imports |= set(get_imports_from_children(first_child.children[3:])) else: break + return imports +@lru_cache() +def get_gitignore(root: Path) -> PathSpec: + """ Return a PathSpec matching gitignore content if present.""" + gitignore = root / ".gitignore" + lines: List[str] = [] + if gitignore.is_file(): + with gitignore.open() as gf: + lines = gf.readlines() + return PathSpec.from_lines("gitwildmatch", lines) + + def gen_python_files_in_dir( path: Path, root: Path, include: Pattern[str], exclude: Pattern[str], report: "Report", + gitignore: PathSpec, ) -> Iterator[Path]: """Generate all files under `path` whose paths are not excluded by the `exclude` regex, but are included by the `include` regex. @@ -3192,8 +3565,18 @@ def gen_python_files_in_dir( """ assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}" for child in path.iterdir(): + # First ignore files matching .gitignore + if gitignore.match_file(child.as_posix()): + report.path_ignored(child, f"matches the .gitignore file content") + continue + + # Then ignore with `exclude` option. try: normalized_path = "/" + child.resolve().relative_to(root).as_posix() + except OSError as e: + report.path_ignored(child, f"cannot be read because {e}") + continue + except ValueError: if child.is_symlink(): report.path_ignored( @@ -3205,13 +3588,16 @@ def gen_python_files_in_dir( if child.is_dir(): normalized_path += "/" + exclude_match = exclude.search(normalized_path) if exclude_match and exclude_match.group(0): report.path_ignored(child, f"matches the --exclude regular expression") continue if child.is_dir(): - yield from gen_python_files_in_dir(child, root, include, exclude, report) + yield from gen_python_files_in_dir( + child, root, include, exclude, report, gitignore + ) elif child.is_file(): include_match = include.search(normalized_path) @@ -3237,7 +3623,7 @@ def find_project_root(srcs: Iterable[str]) -> Path: # Append a fake file so `parents` below returns `common_base_dir`, too. common_base /= "fake-file" for directory in common_base.parents: - if (directory / ".git").is_dir(): + if (directory / ".git").exists(): return directory if (directory / ".hg").is_dir(): @@ -3254,6 +3640,7 @@ class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" check: bool = False + diff: bool = False quiet: bool = False verbose: bool = False change_count: int = 0 @@ -3263,7 +3650,7 @@ class Report: def done(self, src: Path, changed: Changed) -> None: """Increment the counter for successful reformatting. Write out a message.""" if changed is Changed.YES: - reformatted = "would reformat" if self.check else "reformatted" + reformatted = "would reformat" if self.check or self.diff else "reformatted" if self.verbose or not self.quiet: out(f"{reformatted} {src}") self.change_count += 1 @@ -3309,7 +3696,7 @@ class Report: Use `click.unstyle` to remove colors. """ - if self.check: + if self.check or self.diff: reformatted = "would be reformatted" unchanged = "would be left unchanged" failed = "would fail to reformat" @@ -3334,17 +3721,59 @@ class Report: return ", ".join(report) + "." +def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]: + filename = "" + if sys.version_info >= (3, 8): + # TODO: support Python 4+ ;) + for minor_version in range(sys.version_info[1], 4, -1): + try: + return ast.parse(src, filename, feature_version=(3, minor_version)) + except SyntaxError: + continue + else: + for feature_version in (7, 6): + try: + return ast3.parse(src, filename, feature_version=feature_version) + except SyntaxError: + continue + + return ast27.parse(src) + + +def _fixup_ast_constants( + node: Union[ast.AST, ast3.AST, ast27.AST] +) -> Union[ast.AST, ast3.AST, ast27.AST]: + """Map ast nodes deprecated in 3.8 to Constant.""" + if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)): + return ast.Constant(value=node.s) + + if isinstance(node, (ast.Num, ast3.Num, ast27.Num)): + return ast.Constant(value=node.n) + + if isinstance(node, (ast.NameConstant, ast3.NameConstant)): + return ast.Constant(value=node.value) + + return node + + def assert_equivalent(src: str, dst: str) -> None: """Raise AssertionError if `src` and `dst` aren't equivalent.""" - import ast - import traceback - - def _v(node: ast.AST, depth: int = 0) -> Iterator[str]: + def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]: """Simple visitor generating strings to compare ASTs by content.""" + + node = _fixup_ast_constants(node) + yield f"{' ' * depth}{node.__class__.__name__}(" - for field in sorted(node._fields): + for field in sorted(node._fields): # noqa: F402 + # TypeIgnore has only one field 'lineno' which breaks this comparison + type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore) + if sys.version_info >= (3, 8): + type_ignore_classes += (ast.TypeIgnore,) + if isinstance(node, type_ignore_classes): + break + try: value = getattr(node, field) except AttributeError: @@ -3358,15 +3787,16 @@ def assert_equivalent(src: str, dst: str) -> None: # parentheses and they change the AST. if ( field == "targets" - and isinstance(node, ast.Delete) - and isinstance(item, ast.Tuple) + and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete)) + and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple)) ): for item in item.elts: yield from _v(item, depth + 2) - elif isinstance(item, ast.AST): + + elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)): yield from _v(item, depth + 2) - elif isinstance(value, ast.AST): + elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)): yield from _v(value, depth + 2) else: @@ -3375,22 +3805,20 @@ def assert_equivalent(src: str, dst: str) -> None: yield f"{' ' * depth}) # /{node.__class__.__name__}" try: - src_ast = ast.parse(src) + src_ast = parse_ast(src) except Exception as exc: - major, minor = sys.version_info[:2] raise AssertionError( - f"cannot use --safe with this file; failed to parse source file " - f"with Python {major}.{minor}'s builtin AST. Re-run with --fast " - f"or stop using deprecated Python 2 syntax. AST error message: {exc}" + f"cannot use --safe with this file; failed to parse source file. " + f"AST error message: {exc}" ) try: - dst_ast = ast.parse(dst) + dst_ast = parse_ast(dst) except Exception as exc: log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst) raise AssertionError( f"INTERNAL ERROR: Black produced invalid code: {exc}. " - f"Please report a bug on https://github.com/ambv/black/issues. " + f"Please report a bug on https://github.com/psf/black/issues. " f"This invalid output might be helpful: {log}" ) from None @@ -3401,12 +3829,12 @@ def assert_equivalent(src: str, dst: str) -> None: raise AssertionError( f"INTERNAL ERROR: Black produced code that is not equivalent to " f"the source. " - f"Please report a bug on https://github.com/ambv/black/issues. " + f"Please report a bug on https://github.com/psf/black/issues. " f"This diff might be helpful: {log}" ) from None -def assert_stable(src: str, dst: str, mode: FileMode) -> None: +def assert_stable(src: str, dst: str, mode: Mode) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" newdst = format_str(dst, mode=mode) if dst != newdst: @@ -3417,15 +3845,14 @@ def assert_stable(src: str, dst: str, mode: FileMode) -> None: raise AssertionError( f"INTERNAL ERROR: Black produced different code on the second pass " f"of the formatter. " - f"Please report a bug on https://github.com/ambv/black/issues. " + f"Please report a bug on https://github.com/psf/black/issues. " f"This diff might be helpful: {log}" ) from None +@mypyc_attr(patchable=True) def dump_to_file(*output: str) -> str: """Dump `output` to a temporary file. Return path to the file.""" - import tempfile - with tempfile.NamedTemporaryFile( mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8" ) as f: @@ -3436,6 +3863,15 @@ def dump_to_file(*output: str) -> str: return f.name +@contextmanager +def nullcontext() -> Iterator[None]: + """Return an empty context manager. + + To be used like `nullcontext` in Python 3.7. + """ + yield + + def diff(a: str, b: str, a_name: str, b_name: str) -> str: """Return a unified diff string between strings `a` and `b`.""" import difflib @@ -3447,14 +3883,14 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: ) -def cancel(tasks: Iterable[asyncio.Task]) -> None: +def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None: """asyncio signal handler that cancels all `tasks` and reports to stderr.""" err("Aborted!") for task in tasks: task.cancel() -def shutdown(loop: BaseEventLoop) -> None: +def shutdown(loop: asyncio.AbstractEventLoop) -> None: """Cancel all pending tasks on `loop`, wait for them, and close the loop.""" try: if sys.version_info[:2] >= (3, 7): @@ -3496,7 +3932,8 @@ def re_compile_maybe_verbose(regex: str) -> Pattern[str]: """ if "\n" in regex: regex = "(?x)" + regex - return re.compile(regex) + compiled: Pattern[str] = re.compile(regex) + return compiled def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]: @@ -3523,7 +3960,6 @@ def enumerate_with_length( if "\n" in leaf.value: return # Multiline strings, we can't continue. - comment: Optional[Leaf] for comment in line.comments_after(leaf): length += len(comment.value) @@ -3669,11 +4105,11 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool: return False -def get_cache_file(mode: FileMode) -> Path: +def get_cache_file(mode: Mode) -> Path: return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle" -def read_cache(mode: FileMode) -> Cache: +def read_cache(mode: Mode) -> Cache: """Read the cache if it exists and is well formed. If it is not well formed, the call to write_cache later should resolve the issue. @@ -3685,7 +4121,7 @@ def read_cache(mode: FileMode) -> Cache: with cache_file.open("rb") as fobj: try: cache: Cache = pickle.load(fobj) - except pickle.UnpicklingError: + except (pickle.UnpicklingError, ValueError): return {} return cache @@ -3713,14 +4149,14 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set return todo, done -def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None: +def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None: """Update the cache file.""" cache_file = get_cache_file(mode) try: CACHE_DIR.mkdir(parents=True, exist_ok=True) new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}} with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f: - pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL) + pickle.dump(new_cache, f, protocol=4) os.replace(f.name, cache_file) except OSError: pass