X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/65c52a655fd67480a2017a79c99094039dcaffa3..8c74d7901fe8de0abd72a182d775b639b4202577:/black.py diff --git a/black.py b/black.py index 6bb29fe..4b5f19e 100644 --- a/black.py +++ b/black.py @@ -1,6 +1,5 @@ -#!/usr/bin/env python3 - import asyncio +import pickle from asyncio.base_events import BaseEventLoop from concurrent.futures import Executor, ProcessPoolExecutor from enum import Enum @@ -10,18 +9,22 @@ import logging from multiprocessing import Manager import os from pathlib import Path +import re import tokenize import signal import sys from typing import ( Any, Callable, + Collection, Dict, Generic, Iterable, Iterator, List, Optional, + Pattern, + Sequence, Set, Tuple, Type, @@ -29,6 +32,7 @@ from typing import ( Union, ) +from appdirs import user_cache_dir from attr import dataclass, Factory import click @@ -38,8 +42,10 @@ from blib2to3 import pygram, pytree from blib2to3.pgen2 import driver, token from blib2to3.pgen2.parse import ParseError -__version__ = "18.4a0" + +__version__ = "18.4a6" DEFAULT_LINE_LENGTH = 88 + # types syms = pygram.python_symbols FileContent = str @@ -51,6 +57,10 @@ Priority = int Index = int LN = Union[Leaf, Node] SplitFunc = Callable[["Line", bool], Iterator["Line"]] +Timestamp = float +FileSize = int +CacheInfo = Tuple[Timestamp, FileSize] +Cache = Dict[Path, CacheInfo] out = partial(click.secho, bold=True, err=True) err = partial(click.secho, fg="red", err=True) @@ -79,11 +89,11 @@ class FormatError(Exception): self.consumed = consumed def trim_prefix(self, leaf: Leaf) -> None: - leaf.prefix = leaf.prefix[self.consumed:] + leaf.prefix = leaf.prefix[self.consumed :] def leaf_from_consumed(self, leaf: Leaf) -> Leaf: """Returns a new Leaf from the consumed part of the prefix.""" - unformatted_prefix = leaf.prefix[:self.consumed] + unformatted_prefix = leaf.prefix[: self.consumed] return Leaf(token.NEWLINE, unformatted_prefix) @@ -101,6 +111,12 @@ class WriteBack(Enum): DIFF = 2 +class Changed(Enum): + NO = 0 + CACHED = 1 + YES = 2 + + @click.command() @click.option( "-l", @@ -129,6 +145,15 @@ class WriteBack(Enum): is_flag=True, help="If --fast given, skip temporary sanity checks. [default: --safe]", ) +@click.option( + "-q", + "--quiet", + is_flag=True, + help=( + "Don't emit non-error messages to stderr. Errors are still emitted, " + "silence those with 2>/dev/null." + ), +) @click.version_option(version=__version__) @click.argument( "src", @@ -144,6 +169,7 @@ def main( check: bool, diff: bool, fast: bool, + quiet: bool, src: List[str], ) -> None: """The uncompromising code formatter.""" @@ -159,58 +185,80 @@ def main( sources.append(Path("-")) else: err(f"invalid path: {s}") - if check and diff: - exc = click.ClickException("Options --check and --diff are mutually exclusive") - exc.exit_code = 2 - raise exc - if check: + if check and not diff: write_back = WriteBack.NO elif diff: write_back = WriteBack.DIFF else: write_back = WriteBack.YES + report = Report(check=check, quiet=quiet) if len(sources) == 0: + out("No paths given. Nothing to do 😴") ctx.exit(0) + return + elif len(sources) == 1: - p = sources[0] - report = Report(check=check) - try: - if not p.is_file() and str(p) == "-": - changed = format_stdin_to_stdout( - line_length=line_length, fast=fast, write_back=write_back - ) - else: - changed = format_file_in_place( - p, line_length=line_length, fast=fast, write_back=write_back - ) - report.done(p, changed) - except Exception as exc: - report.failed(p, str(exc)) - ctx.exit(report.return_code) + reformat_one(sources[0], line_length, fast, write_back, report) else: loop = asyncio.get_event_loop() executor = ProcessPoolExecutor(max_workers=os.cpu_count()) - return_code = 1 try: - return_code = loop.run_until_complete( + loop.run_until_complete( schedule_formatting( - sources, line_length, write_back, fast, loop, executor + sources, line_length, fast, write_back, report, loop, executor ) ) finally: shutdown(loop) - ctx.exit(return_code) + if not quiet: + out("All done! ✨ 🍰 ✨") + click.echo(str(report)) + ctx.exit(report.return_code) + + +def reformat_one( + src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report" +) -> None: + """Reformat a single file under `src` without spawning child processes. + + If `quiet` is True, non-error messages are not output. `line_length`, + `write_back`, and `fast` options are passed to :func:`format_file_in_place`. + """ + try: + changed = Changed.NO + if not src.is_file() and str(src) == "-": + if format_stdin_to_stdout( + line_length=line_length, fast=fast, write_back=write_back + ): + changed = Changed.YES + else: + cache: Cache = {} + if write_back != WriteBack.DIFF: + cache = read_cache(line_length) + src = src.resolve() + if src in cache and cache[src] == get_cache_info(src): + changed = Changed.CACHED + if changed is not Changed.CACHED and format_file_in_place( + src, line_length=line_length, fast=fast, write_back=write_back + ): + changed = Changed.YES + if write_back == WriteBack.YES and changed is not Changed.NO: + write_cache(cache, [src], line_length) + report.done(src, changed) + except Exception as exc: + report.failed(src, str(exc)) async def schedule_formatting( sources: List[Path], line_length: int, - write_back: WriteBack, fast: bool, + write_back: WriteBack, + report: "Report", loop: BaseEventLoop, executor: Executor, -) -> int: +) -> None: """Run formatting of `sources` in parallel using the provided `executor`. (Use ProcessPoolExecutors for actual parallelism.) @@ -218,41 +266,49 @@ async def schedule_formatting( `line_length`, `write_back`, and `fast` options are passed to :func:`format_file_in_place`. """ - lock = None - if write_back == WriteBack.DIFF: - # For diff output, we need locks to ensure we don't interleave output - # from different processes. - manager = Manager() - lock = manager.Lock() - tasks = { - src: loop.run_in_executor( - executor, format_file_in_place, src, line_length, fast, write_back, lock - ) - for src in sources - } - _task_values = list(tasks.values()) - loop.add_signal_handler(signal.SIGINT, cancel, _task_values) - loop.add_signal_handler(signal.SIGTERM, cancel, _task_values) - await asyncio.wait(tasks.values()) + cache: Cache = {} + if write_back != WriteBack.DIFF: + cache = read_cache(line_length) + sources, cached = filter_cached(cache, sources) + for src in cached: + report.done(src, Changed.CACHED) cancelled = [] - report = Report(check=not write_back) - for src, task in tasks.items(): - if not task.done(): - report.failed(src, "timed out, cancelling") - task.cancel() - cancelled.append(task) - elif task.cancelled(): - cancelled.append(task) - elif task.exception(): - report.failed(src, str(task.exception())) - else: - report.done(src, task.result()) + formatted = [] + if sources: + lock = None + if write_back == WriteBack.DIFF: + # For diff output, we need locks to ensure we don't interleave output + # from different processes. + manager = Manager() + lock = manager.Lock() + tasks = { + loop.run_in_executor( + executor, format_file_in_place, src, line_length, fast, write_back, lock + ): src + for src in sorted(sources) + } + pending: Iterable[asyncio.Task] = tasks.keys() + try: + loop.add_signal_handler(signal.SIGINT, cancel, pending) + loop.add_signal_handler(signal.SIGTERM, cancel, pending) + except NotImplementedError: + # There are no good alternatives for these on Windows + pass + while pending: + done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + src = tasks.pop(task) + if task.cancelled(): + cancelled.append(task) + elif task.exception(): + report.failed(src, str(task.exception())) + else: + formatted.append(src) + report.done(src, Changed.YES if task.result() else Changed.NO) if cancelled: await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) - else: - out("All done! ✨ 🍰 ✨") - click.echo(str(report)) - return report.return_code + if write_back == WriteBack.YES and formatted: + write_cache(cache, formatted, line_length) def format_file_in_place( @@ -267,11 +323,13 @@ def format_file_in_place( If `write_back` is True, write reformatted code back to stdout. `line_length` and `fast` options are passed to :func:`format_file_contents`. """ + is_pyi = src.suffix == ".pyi" + with tokenize.open(src) as src_buffer: src_contents = src_buffer.read() try: dst_contents = format_file_contents( - src_contents, line_length=line_length, fast=fast + src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi ) except NothingChanged: return False @@ -280,8 +338,8 @@ def format_file_in_place( with open(src, "w", encoding=src_buffer.encoding) as f: f.write(dst_contents) elif write_back == write_back.DIFF: - src_name = f"{src.name} (original)" - dst_name = f"{src.name} (formatted)" + src_name = f"{src} (original)" + dst_name = f"{src} (formatted)" diff_contents = diff(src_contents, dst_contents, src_name, dst_name) if lock: lock.acquire() @@ -302,12 +360,12 @@ def format_stdin_to_stdout( `line_length` and `fast` arguments are passed to :func:`format_file_contents`. """ src = sys.stdin.read() + dst = src try: dst = format_file_contents(src, line_length=line_length, fast=fast) return True except NothingChanged: - dst = src return False finally: @@ -320,7 +378,7 @@ def format_stdin_to_stdout( def format_file_contents( - src_contents: str, line_length: int, fast: bool + src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False ) -> FileContent: """Reformat contents a file and return new contents. @@ -331,26 +389,33 @@ def format_file_contents( if src_contents.strip() == "": raise NothingChanged - dst_contents = format_str(src_contents, line_length=line_length) + dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi) if src_contents == dst_contents: raise NothingChanged if not fast: assert_equivalent(src_contents, dst_contents) - assert_stable(src_contents, dst_contents, line_length=line_length) + assert_stable( + src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi + ) return dst_contents -def format_str(src_contents: str, line_length: int) -> FileContent: +def format_str( + src_contents: str, line_length: int, *, is_pyi: bool = False +) -> FileContent: """Reformat a string and return new contents. `line_length` determines how many characters per line are allowed. """ src_node = lib2to3_parse(src_contents) dst_contents = "" - lines = LineGenerator() - elt = EmptyLineTracker() + future_imports = get_future_imports(src_node) + elt = EmptyLineTracker(is_pyi=is_pyi) py36 = is_python36(src_node) + lines = LineGenerator( + remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi + ) empty_line = Line() after = 0 for current_line in lines.visit(src_node): @@ -367,7 +432,6 @@ def format_str(src_contents: str, line_length: int) -> FileContent: GRAMMARS = [ pygram.python_grammar_no_print_statement_no_exec_statement, pygram.python_grammar_no_print_statement, - pygram.python_grammar_no_exec_statement, pygram.python_grammar, ] @@ -493,27 +557,91 @@ COMPARATORS = { token.GREATEREQUAL, } MATH_OPERATORS = { + token.VBAR, + token.CIRCUMFLEX, + token.AMPER, + token.LEFTSHIFT, + token.RIGHTSHIFT, token.PLUS, token.MINUS, token.STAR, token.SLASH, - token.VBAR, - token.AMPER, + token.DOUBLESLASH, token.PERCENT, - token.CIRCUMFLEX, + token.AT, token.TILDE, - token.LEFTSHIFT, - token.RIGHTSHIFT, token.DOUBLESTAR, - token.DOUBLESLASH, } -VARARGS = {token.STAR, token.DOUBLESTAR} +STARS = {token.STAR, token.DOUBLESTAR} +VARARGS_PARENTS = { + syms.arglist, + syms.argument, # double star in arglist + syms.trailer, # single argument to call + syms.typedargslist, + syms.varargslist, # lambdas +} +UNPACKING_PARENTS = { + syms.atom, # single element of a list or set literal + syms.dictsetmaker, + syms.listmaker, + syms.testlist_gexp, +} +TEST_DESCENDANTS = { + syms.test, + syms.lambdef, + syms.or_test, + syms.and_test, + syms.not_test, + syms.comparison, + syms.star_expr, + syms.expr, + syms.xor_expr, + syms.and_expr, + syms.shift_expr, + syms.arith_expr, + syms.trailer, + syms.term, + syms.power, +} +ASSIGNMENTS = { + "=", + "+=", + "-=", + "*=", + "@=", + "/=", + "%=", + "&=", + "|=", + "^=", + "<<=", + ">>=", + "**=", + "//=", +} COMPREHENSION_PRIORITY = 20 -COMMA_PRIORITY = 10 -LOGIC_PRIORITY = 5 -STRING_PRIORITY = 4 -COMPARATOR_PRIORITY = 3 -MATH_PRIORITY = 1 +COMMA_PRIORITY = 18 +TERNARY_PRIORITY = 16 +LOGIC_PRIORITY = 14 +STRING_PRIORITY = 12 +COMPARATOR_PRIORITY = 10 +MATH_PRIORITIES = { + token.VBAR: 9, + token.CIRCUMFLEX: 8, + token.AMPER: 7, + token.LEFTSHIFT: 6, + token.RIGHTSHIFT: 6, + token.PLUS: 5, + token.MINUS: 5, + token.STAR: 4, + token.SLASH: 4, + token.DOUBLESLASH: 4, + token.PERCENT: 4, + token.AT: 4, + token.TILDE: 3, + token.DOUBLESTAR: 2, +} +DOT_PRIORITY = 1 @dataclass @@ -524,6 +652,8 @@ class BracketTracker: bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict) delimiters: Dict[LeafID, Priority] = Factory(dict) previous: Optional[Leaf] = None + _for_loop_variable: int = 0 + _lambda_arguments: int = 0 def mark(self, leaf: Leaf) -> None: """Mark `leaf` with bracket-related metadata. Keep track of delimiters. @@ -543,22 +673,27 @@ class BracketTracker: if leaf.type == token.COMMENT: return + self.maybe_decrement_after_for_loop_variable(leaf) + self.maybe_decrement_after_lambda_arguments(leaf) if leaf.type in CLOSING_BRACKETS: self.depth -= 1 opening_bracket = self.bracket_match.pop((self.depth, leaf.type)) leaf.opening_bracket = opening_bracket leaf.bracket_depth = self.depth if self.depth == 0: - after_delim = is_split_after_delimiter(leaf, self.previous) - before_delim = is_split_before_delimiter(leaf, self.previous) - if after_delim > before_delim: - self.delimiters[id(leaf)] = after_delim - elif before_delim > after_delim and self.previous is not None: - self.delimiters[id(self.previous)] = before_delim + delim = is_split_before_delimiter(leaf, self.previous) + if delim and self.previous is not None: + self.delimiters[id(self.previous)] = delim + else: + delim = is_split_after_delimiter(leaf, self.previous) + if delim: + self.delimiters[id(leaf)] = delim if leaf.type in OPENING_BRACKETS: self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf self.depth += 1 self.previous = leaf + self.maybe_increment_lambda_arguments(leaf) + self.maybe_increment_for_loop_variable(leaf) def any_open_brackets(self) -> bool: """Return True if there is an yet unmatched open bracket on the line.""" @@ -567,10 +702,70 @@ class BracketTracker: def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int: """Return the highest priority of a delimiter found on the line. - Values are consistent with what `is_delimiter()` returns. + Values are consistent with what `is_split_*_delimiter()` return. + Raises ValueError on no delimiters. """ return max(v for k, v in self.delimiters.items() if k not in exclude) + def delimiter_count_with_priority(self, priority: int = 0) -> int: + """Return the number of delimiters with the given `priority`. + + If no `priority` is passed, defaults to max priority on the line. + """ + if not self.delimiters: + return 0 + + priority = priority or self.max_delimiter_priority() + return sum(1 for p in self.delimiters.values() if p == priority) + + def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: + """In a for loop, or comprehension, the variables are often unpacks. + + To avoid splitting on the comma in this situation, increase the depth of + tokens between `for` and `in`. + """ + if leaf.type == token.NAME and leaf.value == "for": + self.depth += 1 + self._for_loop_variable += 1 + return True + + return False + + def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool: + """See `maybe_increment_for_loop_variable` above for explanation.""" + if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in": + self.depth -= 1 + self._for_loop_variable -= 1 + return True + + return False + + def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool: + """In a lambda expression, there might be more than one argument. + + To avoid splitting on the comma in this situation, increase the depth of + tokens between `lambda` and `:`. + """ + if leaf.type == token.NAME and leaf.value == "lambda": + self.depth += 1 + self._lambda_arguments += 1 + return True + + return False + + def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool: + """See `maybe_increment_lambda_arguments` above for explanation.""" + if self._lambda_arguments and leaf.type == token.COLON: + self.depth -= 1 + self._lambda_arguments -= 1 + return True + + return False + + def get_open_lsqb(self) -> Optional[Leaf]: + """Return the most recent opening square bracket (if any).""" + return self.bracket_match.get((self.depth - 1, token.RSQB)) + @dataclass class Line: @@ -581,8 +776,6 @@ class Line: comments: List[Tuple[Index, Leaf]] = Factory(list) bracket_tracker: BracketTracker = Factory(BracketTracker) inside_brackets: bool = False - has_for: bool = False - _for_loop_variable: bool = False def append(self, leaf: Leaf, preformatted: bool = False) -> None: """Add a new `leaf` to the end of the line. @@ -594,20 +787,21 @@ class Line: Inline comments are put aside. """ - has_value = leaf.value.strip() + has_value = leaf.type in BRACKETS or bool(leaf.value.strip()) if not has_value: return + if token.COLON == leaf.type and self.is_class_paren_empty: + del self.leaves[-2:] if self.leaves and not preformatted: # Note: at this point leaf.prefix should be empty except for # imports, for which we only preserve newlines. - leaf.prefix += whitespace(leaf) + leaf.prefix += whitespace( + leaf, complex_subscript=self.is_complex_subscript(leaf) + ) if self.inside_brackets or not preformatted: - self.maybe_decrement_after_for_loop_variable(leaf) self.bracket_tracker.mark(leaf) self.maybe_remove_trailing_comma(leaf) - self.maybe_increment_for_loop_variable(leaf) - if not self.append_comment(leaf): self.leaves.append(leaf) @@ -652,6 +846,13 @@ class Line: and self.leaves[0].value == "class" ) + @property + def is_stub_class(self) -> bool: + """Is this line a class definition with a body consisting only of "..."?""" + return self.is_class and self.leaves[-3:] == [ + Leaf(token.DOT, ".") for _ in range(3) + ] + @property def is_def(self) -> bool: """Is this a function definition? (Also returns True for async defs.)""" @@ -664,14 +865,11 @@ class Line: second_leaf: Optional[Leaf] = self.leaves[1] except IndexError: second_leaf = None - return ( - (first_leaf.type == token.NAME and first_leaf.value == "def") - or ( - first_leaf.type == token.ASYNC - and second_leaf is not None - and second_leaf.type == token.NAME - and second_leaf.value == "def" - ) + return (first_leaf.type == token.NAME and first_leaf.value == "def") or ( + first_leaf.type == token.ASYNC + and second_leaf is not None + and second_leaf.type == token.NAME + and second_leaf.value == "def" ) @property @@ -696,11 +894,27 @@ class Line: ) @property - def contains_standalone_comments(self) -> bool: + def is_class_paren_empty(self) -> bool: + """Is this a class with no base classes but using parentheses? + + Those are unnecessary and should be removed. + """ + return ( + bool(self) + and len(self.leaves) == 4 + and self.is_class + and self.leaves[2].type == token.LPAR + and self.leaves[2].value == "(" + and self.leaves[3].type == token.RPAR + and self.leaves[3].value == ")" + ) + + def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool: """If so, needs to be split before emitting.""" for leaf in self.leaves: if leaf.type == STANDALONE_COMMENT: - return True + if leaf.bracket_depth <= depth_limit: + return True return False @@ -723,9 +937,14 @@ class Line: self.remove_trailing_comma() return True - # For parens let's check if it's safe to remove the comma. If the - # trailing one is the only one, we might mistakenly change a tuple - # into a different type by removing the comma. + # For parens let's check if it's safe to remove the comma. + # Imports are always safe. + if self.is_import: + self.remove_trailing_comma() + return True + + # Otheriwsse, if the trailing one is the only one, we might mistakenly + # change a tuple into a different type by removing the comma. depth = closing.bracket_depth + 1 commas = 0 opening = closing.opening_bracket @@ -736,7 +955,7 @@ class Line: else: return False - for leaf in self.leaves[_opening_index + 1:]: + for leaf in self.leaves[_opening_index + 1 :]: if leaf is closing: break @@ -753,29 +972,6 @@ class Line: return False - def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: - """In a for loop, or comprehension, the variables are often unpacks. - - To avoid splitting on the comma in this situation, increase the depth of - tokens between `for` and `in`. - """ - if leaf.type == token.NAME and leaf.value == "for": - self.has_for = True - self.bracket_tracker.depth += 1 - self._for_loop_variable = True - return True - - return False - - def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool: - """See `maybe_increment_for_loop_variable` above for explanation.""" - if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in": - self.bracket_tracker.depth -= 1 - self._for_loop_variable = False - return True - - return False - def append_comment(self, comment: Leaf) -> bool: """Add an inline or standalone comment to the line.""" if ( @@ -798,17 +994,21 @@ class Line: self.comments.append((after, comment)) return True - def comments_after(self, leaf: Leaf) -> Iterator[Leaf]: - """Generate comments that should appear directly after `leaf`.""" - for _leaf_index, _leaf in enumerate(self.leaves): - if leaf is _leaf: - break + def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]: + """Generate comments that should appear directly after `leaf`. - else: - return + Provide a non-negative leaf `_index` to speed up the function. + """ + if _index == -1: + for _index, _leaf in enumerate(self.leaves): + if leaf is _leaf: + break + + else: + return for index, comment_after in self.comments: - if _leaf_index == index: + if _index == index: yield comment_after def remove_trailing_comma(self) -> None: @@ -820,6 +1020,24 @@ class Line: self.comments[i] = (comma_index - 1, comment) self.leaves.pop() + def is_complex_subscript(self, leaf: Leaf) -> bool: + """Return True iff `leaf` is part of a slice with non-trivial exprs.""" + open_lsqb = ( + leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb() + ) + if open_lsqb is None: + return False + + subscript_start = open_lsqb.next_sibling + if ( + isinstance(subscript_start, Node) + and subscript_start.type == syms.subscriptlist + ): + subscript_start = child_towards(subscript_start, leaf) + return subscript_start is not None and any( + n.type in TEST_DESCENDANTS for n in subscript_start.pre_order() + ) + def __str__(self) -> str: """Render the line.""" if not self: @@ -898,6 +1116,7 @@ class EmptyLineTracker: the prefix of the first leaf consists of optional newlines. Those newlines are consumed by `maybe_empty_lines()` and included in the computation. """ + is_pyi: bool = False previous_line: Optional[Line] = None previous_after: int = 0 previous_defs: List[int] = Factory(list) @@ -921,7 +1140,7 @@ class EmptyLineTracker: def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: max_allowed = 1 if current_line.depth == 0: - max_allowed = 2 + max_allowed = 1 if self.is_pyi else 2 if current_line.leaves: # Consume the first leaf's extra newlines. first_leaf = current_line.leaves[0] @@ -933,7 +1152,10 @@ class EmptyLineTracker: depth = current_line.depth while self.previous_defs and self.previous_defs[-1] >= depth: self.previous_defs.pop() - before = 1 if depth else 2 + if self.is_pyi: + before = 0 if depth else 1 + else: + before = 1 if depth else 2 is_decorator = current_line.is_decorator if is_decorator or current_line.is_def or current_line.is_class: if not is_decorator: @@ -942,18 +1164,32 @@ class EmptyLineTracker: # Don't insert empty lines before the first line in the file. return 0, 0 - if self.previous_line and self.previous_line.is_decorator: - # Don't insert empty lines between decorators. + if self.previous_line.is_decorator: + return 0, 0 + + if ( + self.previous_line.is_comment + and self.previous_line.depth == current_line.depth + and before == 0 + ): return 0, 0 - newlines = 2 - if current_line.depth: + if self.is_pyi: + if self.previous_line.depth > current_line.depth: + newlines = 1 + elif current_line.is_class or self.previous_line.is_class: + if current_line.is_stub_class and self.previous_line.is_stub_class: + newlines = 0 + else: + newlines = 1 + else: + newlines = 0 + else: + newlines = 2 + if current_line.depth and newlines: newlines -= 1 return newlines, 0 - if current_line.is_flow_control: - return before, 1 - if ( self.previous_line and self.previous_line.is_import @@ -962,13 +1198,6 @@ class EmptyLineTracker: ): return (before or 1), 0 - if ( - self.previous_line - and self.previous_line.is_yield - and (not current_line.is_yield or depth != self.previous_line.depth) - ): - return (before or 1), 0 - return before, 0 @@ -979,7 +1208,9 @@ class LineGenerator(Visitor[Line]): Note: destroys the tree it's visiting by mutating prefixes of its leaves in ways that will no longer stringify to valid Python code on the tree. """ + is_pyi: bool = False current_line: Line = Factory(Line) + remove_u_prefix: bool = False def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]: """Generate a line. @@ -1047,6 +1278,7 @@ class LineGenerator(Visitor[Line]): else: normalize_prefix(node, inside_brackets=any_open_brackets) if node.type == token.STRING: + normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix) normalize_string_quotes(node) if node.type not in WHITESPACE: self.current_line.append(node) @@ -1060,34 +1292,60 @@ class LineGenerator(Visitor[Line]): def visit_DEDENT(self, node: Node) -> Iterator[Line]: """Decrease indentation level, maybe yield a line.""" - # DEDENT has no value. Additionally, in blib2to3 it never holds comments. + # The current line might still wait for trailing comments. At DEDENT time + # there won't be any (they would be prefixes on the preceding NEWLINE). + # Emit the line then. + yield from self.line() + + # While DEDENT has no value, its prefix may contain standalone comments + # that belong to the current indentation level. Get 'em. + yield from self.visit_default(node) + + # Finally, emit the dedent. yield from self.line(-1) - def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]: + def visit_stmt( + self, node: Node, keywords: Set[str], parens: Set[str] + ) -> Iterator[Line]: """Visit a statement. This implementation is shared for `if`, `while`, `for`, `try`, `except`, - `def`, `with`, and `class`. + `def`, `with`, `class`, `assert` and assignments. + + The relevant Python language `keywords` for a given statement will be + NAME leaves within it. This methods puts those on a separate line. - The relevant Python language `keywords` for a given statement will be NAME - leaves within it. This methods puts those on a separate line. + `parens` holds a set of string leaf values immeditely after which + invisible parens should be put. """ + normalize_invisible_parens(node, parens_after=parens) for child in node.children: if child.type == token.NAME and child.value in keywords: # type: ignore yield from self.line() yield from self.visit(child) + def visit_suite(self, node: Node) -> Iterator[Line]: + """Visit a suite.""" + if self.is_pyi and is_stub_suite(node): + yield from self.visit(node.children[2]) + else: + yield from self.visit_default(node) + def visit_simple_stmt(self, node: Node) -> Iterator[Line]: """Visit a statement without nested statements.""" is_suite_like = node.parent and node.parent.type in STATEMENT if is_suite_like: - yield from self.line(+1) - yield from self.visit_default(node) - yield from self.line(-1) + if self.is_pyi and is_stub_body(node): + yield from self.visit_default(node) + else: + yield from self.line(+1) + yield from self.visit_default(node) + yield from self.line(-1) else: - yield from self.line() + if not self.is_pyi or not node.parent or not is_stub_suite(node.parent): + yield from self.line() yield from self.visit_default(node) def visit_async_stmt(self, node: Node) -> Iterator[Line]: @@ -1111,6 +1369,32 @@ class LineGenerator(Visitor[Line]): yield from self.line() yield from self.visit(child) + def visit_import_from(self, node: Node) -> Iterator[Line]: + """Visit import_from and maybe put invisible parentheses. + + This is separate from `visit_stmt` because import statements don't + support arbitrary atoms and thus handling of parentheses is custom. + """ + check_lpar = False + for index, child in enumerate(node.children): + if check_lpar: + if child.type == token.LPAR: + # make parentheses invisible + child.value = "" # type: ignore + node.children[-1].value = "" # type: ignore + else: + # insert invisible parentheses + node.insert_child(index, Leaf(token.LPAR, "")) + node.append_child(Leaf(token.RPAR, "")) + break + + check_lpar = ( + child.type == token.NAME and child.value == "import" # type: ignore + ) + + for child in node.children: + yield from self.visit(child) + def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]: """Remove a semicolon and put the other statement on a separate line.""" yield from self.line() @@ -1141,18 +1425,27 @@ class LineGenerator(Visitor[Line]): def __attrs_post_init__(self) -> None: """You are in a twisty little maze of passages.""" v = self.visit_stmt - self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}) - self.visit_while_stmt = partial(v, keywords={"while", "else"}) - self.visit_for_stmt = partial(v, keywords={"for", "else"}) - self.visit_try_stmt = partial(v, keywords={"try", "except", "else", "finally"}) - self.visit_except_clause = partial(v, keywords={"except"}) - self.visit_funcdef = partial(v, keywords={"def"}) - self.visit_with_stmt = partial(v, keywords={"with"}) - self.visit_classdef = partial(v, keywords={"class"}) + Ø: Set[str] = set() + self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","}) + self.visit_if_stmt = partial( + v, keywords={"if", "else", "elif"}, parens={"if", "elif"} + ) + self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"}) + self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"}) + self.visit_try_stmt = partial( + v, keywords={"try", "except", "else", "finally"}, parens=Ø + ) + self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø) + self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø) + self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø) + self.visit_classdef = partial(v, keywords={"class"}, parens=Ø) + self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS) + self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"}) self.visit_async_funcdef = self.visit_async_stmt self.visit_decorated = self.visit_decorators +IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist} BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE} OPENING_BRACKETS = set(BRACKET.keys()) CLOSING_BRACKETS = set(BRACKET.values()) @@ -1160,8 +1453,12 @@ BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT} -def whitespace(leaf: Leaf) -> str: # noqa C901 - """Return whitespace prefix if needed for the given `leaf`.""" +def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa C901 + """Return whitespace prefix if needed for the given `leaf`. + + `complex_subscript` signals whether the given leaf is part of a subscription + which has non-trivial arguments, like arithmetic expressions or function calls. + """ NO = "" SPACE = " " DOUBLESPACE = " " @@ -1175,7 +1472,9 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return DOUBLESPACE assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}" - if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}: + if t == token.COLON and p.type not in { + syms.subscript, syms.subscriptlist, syms.sliceop + }: return NO prev = leaf.prev_sibling @@ -1185,7 +1484,13 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return NO if t == token.COLON: - return SPACE if prevp.type == token.COMMA else NO + if prevp.type == token.COLON: + return NO + + elif prevp.type != token.COMMA and not complex_subscript: + return NO + + return SPACE if prevp.type == token.EQUAL: if prevp.parent: @@ -1200,24 +1505,17 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 # that, too. return prevp.prefix - elif prevp.type == token.DOUBLESTAR: - if prevp.parent and prevp.parent.type in { - syms.arglist, - syms.argument, - syms.dictsetmaker, - syms.parameters, - syms.typedargslist, - syms.varargslist, - }: + elif prevp.type in STARS: + if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS): return NO elif prevp.type == token.COLON: if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}: - return NO + return SPACE if complex_subscript else NO elif ( prevp.parent - and prevp.parent.type in {syms.factor, syms.star_expr} + and prevp.parent.type == syms.factor and prevp.type in MATH_OPERATORS ): return NO @@ -1238,17 +1536,11 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 if p.type in {syms.parameters, syms.arglist}: # untyped function signatures or calls - if t == token.RPAR: - return NO - if not prev or prev.type != token.COMMA: return NO elif p.type == syms.varargslist: # lambdas - if t == token.RPAR: - return NO - if prev and prev.type != token.COMMA: return NO @@ -1303,7 +1595,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 if not prevp or prevp.type == token.LPAR: return NO - elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR: + elif prev.type in {token.EQUAL} | STARS: return NO elif p.type == syms.decorator: @@ -1325,7 +1617,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 if prev and prev.type == token.LPAR: return NO - elif p.type == syms.subscript: + elif p.type in {syms.subscript, syms.sliceop}: # indexing if not prev: assert p.parent is not None, "subscripts are always parented" @@ -1334,7 +1626,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return NO - else: + elif not complex_subscript: return NO elif p.type == syms.atom: @@ -1342,21 +1634,9 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 # dots, but not the first one. return NO - elif ( - p.type == syms.listmaker - or p.type == syms.testlist_gexp - or p.type == syms.subscriptlist - ): - # list interior, including unpacking - if not prev: - return NO - elif p.type == syms.dictsetmaker: - # dict and set interior, including unpacking - if not prev: - return NO - - if prev.type == token.DOUBLESTAR: + # dict unpacking + if prev and prev.type == token.DOUBLESTAR: return NO elif p.type in {syms.factor, syms.star_expr}: @@ -1415,6 +1695,14 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: return None +def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]: + """Return the child of `ancestor` that contains `descendant`.""" + node: Optional[LN] = descendant + while node and node.parent != ancestor: + node = node.parent + return node + + def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int: """Return the priority of the `leaf` delimiter, given a line break after it. @@ -1426,13 +1714,6 @@ def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int: if leaf.type == token.COMMA: return COMMA_PRIORITY - if ( - leaf.type in VARARGS - and leaf.parent - and leaf.parent.type in {syms.argument, syms.typedargslist} - ): - return MATH_PRIORITY - return 0 @@ -1444,12 +1725,25 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int: Higher numbers are higher priority. """ + if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS): + # * and ** might also be MATH_OPERATORS but in this case they are not. + # Don't treat them as a delimiter. + return 0 + + if ( + leaf.type == token.DOT + and leaf.parent + and leaf.parent.type not in {syms.import_from, syms.dotted_name} + and (previous is None or previous.type != token.NAME) + ): + return DOT_PRIORITY + if ( leaf.type in MATH_OPERATORS and leaf.parent and leaf.parent.type not in {syms.factor, syms.star_expr} ): - return MATH_PRIORITY + return MATH_PRIORITIES[leaf.type] if leaf.type in COMPARATORS: return COMPARATOR_PRIORITY @@ -1461,37 +1755,57 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int: ): return STRING_PRIORITY + if leaf.type != token.NAME: + return 0 + if ( - leaf.type == token.NAME - and leaf.value == "for" + leaf.value == "for" and leaf.parent and leaf.parent.type in {syms.comp_for, syms.old_comp_for} ): return COMPREHENSION_PRIORITY if ( - leaf.type == token.NAME - and leaf.value == "if" + leaf.value == "if" and leaf.parent and leaf.parent.type in {syms.comp_if, syms.old_comp_if} ): return COMPREHENSION_PRIORITY - if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent: - return LOGIC_PRIORITY + if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test: + return TERNARY_PRIORITY - return 0 + if leaf.value == "is": + return COMPARATOR_PRIORITY + if ( + leaf.value == "in" + and leaf.parent + and leaf.parent.type in {syms.comp_op, syms.comparison} + and not ( + previous is not None + and previous.type == token.NAME + and previous.value == "not" + ) + ): + return COMPARATOR_PRIORITY -def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int: - """Return the priority of the `leaf` delimiter. Return 0 if not delimiter. + if ( + leaf.value == "not" + and leaf.parent + and leaf.parent.type == syms.comp_op + and not ( + previous is not None + and previous.type == token.NAME + and previous.value == "is" + ) + ): + return COMPARATOR_PRIORITY - Higher numbers are higher priority. - """ - return max( - is_split_before_delimiter(leaf, previous), - is_split_after_delimiter(leaf, previous), - ) + if leaf.value in LOGIC_OPERATORS and leaf.parent: + return LOGIC_PRIORITY + + return 0 def generate_comments(leaf: Leaf) -> Iterator[Leaf]: @@ -1588,21 +1902,31 @@ def split_line( return line_str = str(line).strip("\n") - if ( - len(line_str) <= line_length - and "\n" not in line_str # multiline strings - and not line.contains_standalone_comments - ): + if is_line_short_enough(line, line_length=line_length, line_str=line_str): yield line return split_funcs: List[SplitFunc] if line.is_def: split_funcs = [left_hand_split] - elif line.inside_brackets: - split_funcs = [delimiter_split, standalone_comment_split, right_hand_split] + elif line.is_import: + split_funcs = [explode_split] else: - split_funcs = [right_hand_split] + + def rhs(line: Line, py36: bool = False) -> Iterator[Line]: + for omit in generate_trailers_to_omit(line, line_length): + lines = list(right_hand_split(line, py36, omit=omit)) + if is_line_short_enough(lines[0], line_length=line_length): + yield from lines + return + + # All splits failed, best effort split with no omits. + yield from right_hand_split(line, py36) + + if line.inside_brackets: + split_funcs = [delimiter_split, standalone_comment_split, rhs] + else: + split_funcs = [rhs] for split_func in split_funcs: # We are accumulating lines in `result` because we might want to abort # mission and return the original line in the end, or attempt a different @@ -1631,7 +1955,8 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split line into many lines, starting with the first matching bracket pair. Note: this usually looks weird, only use this for function definitions. - Prefer RHS otherwise. + Prefer RHS otherwise. This is why this function is not symmetrical with + :func:`right_hand_split` which also handles optional parentheses. """ head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) @@ -1657,9 +1982,7 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: if body_leaves: normalize_prefix(body_leaves[0], inside_brackets=True) # Build the new lines. - for result, leaves in ( - (head, head_leaves), (body, body_leaves), (tail, tail_leaves) - ): + for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves): for leaf in leaves: result.append(leaf, preformatted=True) for comment_after in line.comments_after(leaf): @@ -1670,8 +1993,17 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: yield result -def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: - """Split line into many lines, starting with the last matching bracket pair.""" +def right_hand_split( + line: Line, py36: bool = False, omit: Collection[LeafID] = () +) -> Iterator[Line]: + """Split line into many lines, starting with the last matching bracket pair. + + If the split was by optional parentheses, attempt splitting without them, too. + `omit` is a collection of closing bracket IDs that shouldn't be considered for + this split. + + Note: running this function modifies `bracket_depth` on the leaves of `line`. + """ head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) tail = Line(depth=line.depth) @@ -1680,14 +2012,16 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: head_leaves: List[Leaf] = [] current_leaves = tail_leaves opening_bracket = None + closing_bracket = None for leaf in reversed(line.leaves): if current_leaves is body_leaves: if leaf is opening_bracket: current_leaves = head_leaves if body_leaves else tail_leaves current_leaves.append(leaf) if current_leaves is tail_leaves: - if leaf.type in CLOSING_BRACKETS: + if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit: opening_bracket = leaf.opening_bracket + closing_bracket = leaf current_leaves = body_leaves tail_leaves.reverse() body_leaves.reverse() @@ -1695,15 +2029,50 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: # Since body is a new indent level, remove spurious leading whitespace. if body_leaves: normalize_prefix(body_leaves[0], inside_brackets=True) + if not head_leaves: + # No `head` means the split failed. Either `tail` has all content or + # the matching `opening_bracket` wasn't available on `line` anymore. + raise CannotSplit("No brackets found") + # Build the new lines. - for result, leaves in ( - (head, head_leaves), (body, body_leaves), (tail, tail_leaves) - ): + for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves): for leaf in leaves: result.append(leaf, preformatted=True) for comment_after in line.comments_after(leaf): result.append(comment_after, preformatted=True) bracket_split_succeeded_or_raise(head, body, tail) + assert opening_bracket and closing_bracket + if ( + # the opening bracket is an optional paren + opening_bracket.type == token.LPAR + and not opening_bracket.value + # the closing bracket is an optional paren + and closing_bracket.type == token.RPAR + and not closing_bracket.value + # there are no standalone comments in the body + and not line.contains_standalone_comments(0) + # and it's not an import (optional parens are the only thing we can split + # on in this case; attempting a split without them is a waste of time) + and not line.is_import + ): + omit = {id(closing_bracket), *omit} + delimiter_count = body.bracket_tracker.delimiter_count_with_priority() + if ( + delimiter_count == 0 + or delimiter_count == 1 + and ( + body.leaves[0].type in OPENING_BRACKETS + or body.leaves[-1].type in CLOSING_BRACKETS + ) + ): + try: + yield from right_hand_split(line, py36=py36, omit=omit) + return + except CannotSplit: + pass + + ensure_visible(opening_bracket) + ensure_visible(closing_bracket) for result in (head, body, tail): if result: yield result @@ -1762,14 +2131,16 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: except IndexError: raise CannotSplit("Line empty") - delimiters = line.bracket_tracker.delimiters + bt = line.bracket_tracker try: - delimiter_priority = line.bracket_tracker.max_delimiter_priority( - exclude={id(last_leaf)} - ) + delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)}) except ValueError: raise CannotSplit("No delimiters found") + if delimiter_priority == DOT_PRIORITY: + if bt.delimiter_count_with_priority(delimiter_priority) == 1: + raise CannotSplit("Splitting a single attribute from its owner looks wrong") + current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) lowest_depth = sys.maxsize trailing_comma_safe = True @@ -1785,20 +2156,18 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) current_line.append(leaf) - for leaf in line.leaves: + for index, leaf in enumerate(line.leaves): yield from append_to_line(leaf) - for comment_after in line.comments_after(leaf): + for comment_after in line.comments_after(leaf, index): yield from append_to_line(comment_after) lowest_depth = min(lowest_depth, leaf.bracket_depth) - if ( - leaf.bracket_depth == lowest_depth - and leaf.type == token.STAR - or leaf.type == token.DOUBLESTAR + if leaf.bracket_depth == lowest_depth and is_vararg( + leaf, within=VARARGS_PARENTS ): trailing_comma_safe = trailing_comma_safe and py36 - leaf_priority = delimiters.get(id(leaf)) + leaf_priority = bt.delimiters.get(id(leaf)) if leaf_priority == delimiter_priority: yield current_line @@ -1817,12 +2186,7 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: @dont_increase_indentation def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split standalone comments from the rest of the line.""" - for leaf in line.leaves: - if leaf.type == STANDALONE_COMMENT: - if leaf.bracket_depth == 0: - break - - else: + if not line.contains_standalone_comments(0): raise CannotSplit("Line does not have any standalone comments") current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) @@ -1838,16 +2202,36 @@ def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) current_line.append(leaf) - for leaf in line.leaves: + for index, leaf in enumerate(line.leaves): yield from append_to_line(leaf) - for comment_after in line.comments_after(leaf): + for comment_after in line.comments_after(leaf, index): yield from append_to_line(comment_after) if current_line: yield current_line +def explode_split( + line: Line, py36: bool = False, omit: Collection[LeafID] = () +) -> Iterator[Line]: + """Split by rightmost bracket and immediately split contents by a delimiter.""" + new_lines = list(right_hand_split(line, py36, omit)) + if len(new_lines) != 3: + yield from new_lines + return + + yield new_lines[0] + + try: + yield from delimiter_split(new_lines[1], py36) + + except CannotSplit: + yield new_lines[1] + + yield new_lines[2] + + def is_import(leaf: Leaf) -> bool: """Return True if the given leaf starts an import statement.""" p = leaf.parent @@ -1880,6 +2264,22 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: leaf.prefix = "" +def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None: + """Make all string prefixes lowercase. + + If remove_u_prefix is given, also removes any u prefix from the string. + + Note: Mutates its argument. + """ + match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL) + assert match is not None, f"failed to match string {leaf.value!r}" + orig_prefix = match.group(1) + new_prefix = orig_prefix.lower() + if remove_u_prefix: + new_prefix = new_prefix.replace("u", "") + leaf.value = f"{new_prefix}{match.group(2)}" + + def normalize_string_quotes(leaf: Leaf) -> None: """Prefer double quotes but only if it doesn't cause more escaping. @@ -1905,10 +2305,28 @@ def normalize_string_quotes(leaf: Leaf) -> None: if first_quote_pos == -1: return # There's an internal error - body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)] - new_body = body.replace(f"\\{orig_quote}", orig_quote).replace( - new_quote, f"\\{new_quote}" - ) + prefix = leaf.value[:first_quote_pos] + unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}") + escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}") + escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}") + body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)] + if "r" in prefix.casefold(): + if unescaped_new_quote.search(body): + # There's at least one unescaped new_quote in this raw string + # so converting is impossible + return + + # Do not introduce or remove backslashes in raw strings + new_body = body + else: + # remove unnecessary quotes + new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body) + if body != new_body: + # Consider the string without unnecessary quotes as the original + body = new_body + leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}" + new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body) + new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body) if new_quote == '"""' and new_body[-1] == '"': # edge case: new_body = new_body[:-1] + '\\"' @@ -1920,16 +2338,215 @@ def normalize_string_quotes(leaf: Leaf) -> None: if new_escape_count == orig_escape_count and orig_quote == '"': return # Prefer double quotes - prefix = leaf.value[:first_quote_pos] leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}" +def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: + """Make existing optional parentheses invisible or create new ones. + + `parens_after` is a set of string leaf values immeditely after which parens + should be put. + + Standardizes on visible parentheses for single-element tuples, and keeps + existing visible parentheses for other tuples and generator expressions. + """ + check_lpar = False + for child in list(node.children): + if check_lpar: + if child.type == syms.atom: + maybe_make_parens_invisible_in_atom(child) + elif is_one_tuple(child): + # wrap child in visible parentheses + lpar = Leaf(token.LPAR, "(") + rpar = Leaf(token.RPAR, ")") + index = child.remove() or 0 + node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + else: + # wrap child in invisible parentheses + lpar = Leaf(token.LPAR, "") + rpar = Leaf(token.RPAR, "") + index = child.remove() or 0 + node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + + check_lpar = isinstance(child, Leaf) and child.value in parens_after + + +def maybe_make_parens_invisible_in_atom(node: LN) -> bool: + """If it's safe, make the parens in the atom `node` invisible, recusively.""" + if ( + node.type != syms.atom + or is_empty_tuple(node) + or is_one_tuple(node) + or is_yield(node) + or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY + ): + return False + + first = node.children[0] + last = node.children[-1] + if first.type == token.LPAR and last.type == token.RPAR: + # make parentheses invisible + first.value = "" # type: ignore + last.value = "" # type: ignore + if len(node.children) > 1: + maybe_make_parens_invisible_in_atom(node.children[1]) + return True + + return False + + +def is_empty_tuple(node: LN) -> bool: + """Return True if `node` holds an empty tuple.""" + return ( + node.type == syms.atom + and len(node.children) == 2 + and node.children[0].type == token.LPAR + and node.children[1].type == token.RPAR + ) + + +def is_one_tuple(node: LN) -> bool: + """Return True if `node` holds a tuple with one element, with or without parens.""" + if node.type == syms.atom: + if len(node.children) != 3: + return False + + lpar, gexp, rpar = node.children + if not ( + lpar.type == token.LPAR + and gexp.type == syms.testlist_gexp + and rpar.type == token.RPAR + ): + return False + + return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA + + return ( + node.type in IMPLICIT_TUPLE + and len(node.children) == 2 + and node.children[1].type == token.COMMA + ) + + +def is_yield(node: LN) -> bool: + """Return True if `node` holds a `yield` or `yield from` expression.""" + if node.type == syms.yield_expr: + return True + + if node.type == token.NAME and node.value == "yield": # type: ignore + return True + + if node.type != syms.atom: + return False + + if len(node.children) != 3: + return False + + lpar, expr, rpar = node.children + if lpar.type == token.LPAR and rpar.type == token.RPAR: + return is_yield(expr) + + return False + + +def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool: + """Return True if `leaf` is a star or double star in a vararg or kwarg. + + If `within` includes VARARGS_PARENTS, this applies to function signatures. + If `within` includes UNPACKING_PARENTS, it applies to right hand-side + extended iterable unpacking (PEP 3132) and additional unpacking + generalizations (PEP 448). + """ + if leaf.type not in STARS or not leaf.parent: + return False + + p = leaf.parent + if p.type == syms.star_expr: + # Star expressions are also used as assignment targets in extended + # iterable unpacking (PEP 3132). See what its parent is instead. + if not p.parent: + return False + + p = p.parent + + return p.type in within + + +def is_stub_suite(node: Node) -> bool: + """Return True if `node` is a suite with a stub body.""" + if ( + len(node.children) != 4 + or node.children[0].type != token.NEWLINE + or node.children[1].type != token.INDENT + or node.children[3].type != token.DEDENT + ): + return False + + return is_stub_body(node.children[2]) + + +def is_stub_body(node: LN) -> bool: + """Return True if `node` is a simple statement containing an ellipsis.""" + if not isinstance(node, Node) or node.type != syms.simple_stmt: + return False + + if len(node.children) != 2: + return False + + child = node.children[0] + return ( + child.type == syms.atom + and len(child.children) == 3 + and all(leaf == Leaf(token.DOT, ".") for leaf in child.children) + ) + + +def max_delimiter_priority_in_atom(node: LN) -> int: + """Return maximum delimiter priority inside `node`. + + This is specific to atoms with contents contained in a pair of parentheses. + If `node` isn't an atom or there are no enclosing parentheses, returns 0. + """ + if node.type != syms.atom: + return 0 + + first = node.children[0] + last = node.children[-1] + if not (first.type == token.LPAR and last.type == token.RPAR): + return 0 + + bt = BracketTracker() + for c in node.children[1:-1]: + if isinstance(c, Leaf): + bt.mark(c) + else: + for leaf in c.leaves(): + bt.mark(leaf) + try: + return bt.max_delimiter_priority() + + except ValueError: + return 0 + + +def ensure_visible(leaf: Leaf) -> None: + """Make sure parentheses are visible. + + They could be invisible as part of some statements (see + :func:`normalize_invible_parens` and :func:`visit_import_from`). + """ + if leaf.type == token.LPAR: + leaf.value = "(" + elif leaf.type == token.RPAR: + leaf.value = ")" + + def is_python36(node: Node) -> bool: """Return True if the current file is using Python 3.6+ features. Currently looking for: - f-strings; and - - trailing commas after * or ** in function signatures. + - trailing commas after * or ** in function signatures and calls. """ for n in node.pre_order(): if n.type == token.STRING: @@ -1938,18 +2555,119 @@ def is_python36(node: Node) -> bool: return True elif ( - n.type == syms.typedargslist + n.type in {syms.typedargslist, syms.arglist} and n.children and n.children[-1].type == token.COMMA ): for ch in n.children: - if ch.type == token.STAR or ch.type == token.DOUBLESTAR: + if ch.type in STARS: return True + if ch.type == syms.argument: + for argch in ch.children: + if argch.type in STARS: + return True + return False -PYTHON_EXTENSIONS = {".py"} +def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]: + """Generate sets of closing bracket IDs that should be omitted in a RHS. + + Brackets can be omitted if the entire trailer up to and including + a preceding closing bracket fits in one line. + + Yielded sets are cumulative (contain results of previous yields, too). First + set is empty. + """ + + omit: Set[LeafID] = set() + yield omit + + length = 4 * line.depth + opening_bracket = None + closing_bracket = None + optional_brackets: Set[LeafID] = set() + inner_brackets: Set[LeafID] = set() + for index, leaf in enumerate_reversed(line.leaves): + length += len(leaf.prefix) + len(leaf.value) + if length > line_length: + break + + comment: Optional[Leaf] + for comment in line.comments_after(leaf, index): + if "\n" in comment.prefix: + break # Oops, standalone comment! + + length += len(comment.value) + else: + comment = None + if comment is not None: + break # There was a standalone comment, we can't continue. + + optional_brackets.discard(id(leaf)) + if opening_bracket: + if leaf is opening_bracket: + opening_bracket = None + elif leaf.type in CLOSING_BRACKETS: + inner_brackets.add(id(leaf)) + elif leaf.type in CLOSING_BRACKETS: + if not leaf.value: + optional_brackets.add(id(opening_bracket)) + continue + + if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS: + # Empty brackets would fail a split so treat them as "inner" + # brackets (e.g. only add them to the `omit` set if another + # pair of brackets was good enough. + inner_brackets.add(id(leaf)) + continue + + opening_bracket = leaf.opening_bracket + if closing_bracket: + omit.add(id(closing_bracket)) + omit.update(inner_brackets) + inner_brackets.clear() + yield omit + closing_bracket = leaf + + +def get_future_imports(node: Node) -> Set[str]: + """Return a set of __future__ imports in the file.""" + imports = set() + for child in node.children: + if child.type != syms.simple_stmt: + break + first_child = child.children[0] + if isinstance(first_child, Leaf): + # Continue looking if we see a docstring; otherwise stop. + if ( + len(child.children) == 2 + and first_child.type == token.STRING + and child.children[1].type == token.NEWLINE + ): + continue + else: + break + elif first_child.type == syms.import_from: + module_name = first_child.children[1] + if not isinstance(module_name, Leaf) or module_name.value != "__future__": + break + for import_from_child in first_child.children[3:]: + if isinstance(import_from_child, Leaf): + if import_from_child.type == token.NAME: + imports.add(import_from_child.value) + else: + assert import_from_child.type == syms.import_as_names + for leaf in import_from_child.children: + if isinstance(leaf, Leaf) and leaf.type == token.NAME: + imports.add(leaf.value) + else: + break + return imports + + +PYTHON_EXTENSIONS = {".py", ".pyi"} BLACKLISTED_DIRECTORIES = { "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv" } @@ -1966,7 +2684,7 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]: yield from gen_python_files_in_dir(child) - elif child.suffix in PYTHON_EXTENSIONS: + elif child.is_file() and child.suffix in PYTHON_EXTENSIONS: yield child @@ -1974,18 +2692,25 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]: class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" check: bool = False + quiet: bool = False change_count: int = 0 same_count: int = 0 failure_count: int = 0 - def done(self, src: Path, changed: bool) -> None: + def done(self, src: Path, changed: Changed) -> None: """Increment the counter for successful reformatting. Write out a message.""" - if changed: + if changed is Changed.YES: reformatted = "would reformat" if self.check else "reformatted" - out(f"{reformatted} {src}") + if not self.quiet: + out(f"{reformatted} {src}") self.change_count += 1 else: - out(f"{src} already well formatted, good job.", bold=False) + if not self.quiet: + if changed is Changed.NO: + msg = f"{src} already well formatted, good job." + else: + msg = f"{src} wasn't modified on disk since last run." + out(msg, bold=False) self.same_count += 1 def failed(self, src: Path, message: str) -> None: @@ -2105,9 +2830,9 @@ def assert_equivalent(src: str, dst: str) -> None: ) from None -def assert_stable(src: str, dst: str, line_length: int) -> None: +def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" - newdst = format_str(dst, line_length=line_length) + newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi) if dst != newdst: log = dump_to_file( diff(src, dst, "source", "first pass"), @@ -2126,7 +2851,7 @@ def dump_to_file(*output: str) -> str: import tempfile with tempfile.NamedTemporaryFile( - mode="w", prefix="blk_", suffix=".log", delete=False + mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8" ) as f: for lines in output: f.write(lines) @@ -2146,7 +2871,7 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: ) -def cancel(tasks: List[asyncio.Task]) -> None: +def cancel(tasks: Iterable[asyncio.Task]) -> None: """asyncio signal handler that cancels all `tasks` and reports to stderr.""" err("Aborted!") for task in tasks: @@ -2175,5 +2900,98 @@ def shutdown(loop: BaseEventLoop) -> None: loop.close() +def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str: + """Replace `regex` with `replacement` twice on `original`. + + This is used by string normalization to perform replaces on + overlapping matches. + """ + return regex.sub(replacement, regex.sub(replacement, original)) + + +def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]: + """Like `reversed(enumerate(sequence))` if that were possible.""" + index = len(sequence) - 1 + for element in reversed(sequence): + yield (index, element) + index -= 1 + + +def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool: + """Return True if `line` is no longer than `line_length`. + + Uses the provided `line_str` rendering, if any, otherwise computes a new one. + """ + if not line_str: + line_str = str(line).strip("\n") + return ( + len(line_str) <= line_length + and "\n" not in line_str # multiline strings + and not line.contains_standalone_comments() + ) + + +CACHE_DIR = Path(user_cache_dir("black", version=__version__)) + + +def get_cache_file(line_length: int) -> Path: + return CACHE_DIR / f"cache.{line_length}.pickle" + + +def read_cache(line_length: int) -> Cache: + """Read the cache if it exists and is well formed. + + If it is not well formed, the call to write_cache later should resolve the issue. + """ + cache_file = get_cache_file(line_length) + if not cache_file.exists(): + return {} + + with cache_file.open("rb") as fobj: + try: + cache: Cache = pickle.load(fobj) + except pickle.UnpicklingError: + return {} + + return cache + + +def get_cache_info(path: Path) -> CacheInfo: + """Return the information used to check if a file is already formatted or not.""" + stat = path.stat() + return stat.st_mtime, stat.st_size + + +def filter_cached( + cache: Cache, sources: Iterable[Path] +) -> Tuple[List[Path], List[Path]]: + """Split a list of paths into two. + + The first list contains paths of files that modified on disk or are not in the + cache. The other list contains paths to non-modified files. + """ + todo, done = [], [] + for src in sources: + src = src.resolve() + if cache.get(src) != get_cache_info(src): + todo.append(src) + else: + done.append(src) + return todo, done + + +def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None: + """Update the cache file.""" + cache_file = get_cache_file(line_length) + try: + if not CACHE_DIR.exists(): + CACHE_DIR.mkdir(parents=True) + new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}} + with cache_file.open("wb") as fobj: + pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL) + except OSError: + pass + + if __name__ == "__main__": main()