X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/639b62dcd32cde3645e9f9a633eee33c04d23901..9f096d55365cb63548eef97e254c2793ae2776a0:/black.py diff --git a/black.py b/black.py index e795fa3..5e087d1 100644 --- a/black.py +++ b/black.py @@ -43,8 +43,9 @@ from blib2to3 import pygram, pytree from blib2to3.pgen2 import driver, token from blib2to3.pgen2.parse import ParseError -__version__ = "18.4a2" +__version__ = "18.4a4" DEFAULT_LINE_LENGTH = 88 + # types syms = pygram.python_symbols FileContent = str @@ -88,11 +89,11 @@ class FormatError(Exception): self.consumed = consumed def trim_prefix(self, leaf: Leaf) -> None: - leaf.prefix = leaf.prefix[self.consumed:] + leaf.prefix = leaf.prefix[self.consumed :] def leaf_from_consumed(self, leaf: Leaf) -> Leaf: """Returns a new Leaf from the consumed part of the prefix.""" - unformatted_prefix = leaf.prefix[:self.consumed] + unformatted_prefix = leaf.prefix[: self.consumed] return Leaf(token.NEWLINE, unformatted_prefix) @@ -184,94 +185,82 @@ def main( sources.append(Path("-")) else: err(f"invalid path: {s}") - if check and diff: - exc = click.ClickException("Options --check and --diff are mutually exclusive") - exc.exit_code = 2 - raise exc - if check: + if check and not diff: write_back = WriteBack.NO elif diff: write_back = WriteBack.DIFF else: write_back = WriteBack.YES + report = Report(check=check, quiet=quiet) if len(sources) == 0: ctx.exit(0) return elif len(sources) == 1: - return_code = run_single_file_mode( - line_length, check, fast, quiet, write_back, sources[0] - ) + reformat_one(sources[0], line_length, fast, write_back, report) else: - return_code = run_multi_file_mode(line_length, fast, quiet, write_back, sources) - ctx.exit(return_code) + loop = asyncio.get_event_loop() + executor = ProcessPoolExecutor(max_workers=os.cpu_count()) + try: + loop.run_until_complete( + schedule_formatting( + sources, line_length, fast, write_back, report, loop, executor + ) + ) + finally: + shutdown(loop) + if not quiet: + out("All done! ✨ 🍰 ✨") + click.echo(str(report)) + ctx.exit(report.return_code) -def run_single_file_mode( - line_length: int, - check: bool, - fast: bool, - quiet: bool, - write_back: WriteBack, - src: Path, -) -> int: - report = Report(check=check, quiet=quiet) +def reformat_one( + src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report" +) -> None: + """Reformat a single file under `src` without spawning child processes. + + If `quiet` is True, non-error messages are not output. `line_length`, + `write_back`, and `fast` options are passed to :func:`format_file_in_place`. 
+ """ try: + changed = Changed.NO if not src.is_file() and str(src) == "-": - changed = format_stdin_to_stdout( + if format_stdin_to_stdout( line_length=line_length, fast=fast, write_back=write_back - ) + ): + changed = Changed.YES else: - changed = Changed.NO cache: Cache = {} if write_back != WriteBack.DIFF: - cache = read_cache() + cache = read_cache(line_length) src = src.resolve() if src in cache and cache[src] == get_cache_info(src): changed = Changed.CACHED - if changed is not Changed.CACHED: - changed = format_file_in_place( + if ( + changed is not Changed.CACHED + and format_file_in_place( src, line_length=line_length, fast=fast, write_back=write_back ) - if write_back != WriteBack.DIFF and changed is not Changed.NO: - write_cache(cache, [src]) + ): + changed = Changed.YES + if write_back == WriteBack.YES and changed is not Changed.NO: + write_cache(cache, [src], line_length) report.done(src, changed) except Exception as exc: report.failed(src, str(exc)) - return report.return_code - - -def run_multi_file_mode( - line_length: int, - fast: bool, - quiet: bool, - write_back: WriteBack, - sources: List[Path], -) -> int: - loop = asyncio.get_event_loop() - executor = ProcessPoolExecutor(max_workers=os.cpu_count()) - return_code = 1 - try: - return_code = loop.run_until_complete( - schedule_formatting( - sources, line_length, write_back, fast, quiet, loop, executor - ) - ) - finally: - shutdown(loop) - return return_code async def schedule_formatting( sources: List[Path], line_length: int, - write_back: WriteBack, fast: bool, - quiet: bool, + write_back: WriteBack, + report: "Report", loop: BaseEventLoop, executor: Executor, -) -> int: +) -> None: """Run formatting of `sources` in parallel using the provided `executor`. (Use ProcessPoolExecutors for actual parallelism.) @@ -279,10 +268,9 @@ async def schedule_formatting( `line_length`, `write_back`, and `fast` options are passed to :func:`format_file_in_place`. """ - report = Report(check=write_back is WriteBack.NO, quiet=quiet) cache: Cache = {} if write_back != WriteBack.DIFF: - cache = read_cache() + cache = read_cache(line_length) sources, cached = filter_cached(cache, sources) for src in cached: report.done(src, Changed.CACHED) @@ -302,8 +290,12 @@ async def schedule_formatting( for src in sources } _task_values = list(tasks.values()) - loop.add_signal_handler(signal.SIGINT, cancel, _task_values) - loop.add_signal_handler(signal.SIGTERM, cancel, _task_values) + try: + loop.add_signal_handler(signal.SIGINT, cancel, _task_values) + loop.add_signal_handler(signal.SIGTERM, cancel, _task_values) + except NotImplementedError: + # There are no good alternatives for these on Windows + pass await asyncio.wait(_task_values) for src, task in tasks.items(): if not task.done(): @@ -316,19 +308,12 @@ async def schedule_formatting( report.failed(src, str(task.exception())) else: formatted.append(src) - report.done(src, task.result()) + report.done(src, Changed.YES if task.result() else Changed.NO) if cancelled: await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) - elif not quiet: - out("All done! 
✨ 🍰 ✨") - if not quiet: - click.echo(str(report)) - - if write_back != WriteBack.DIFF and formatted: - write_cache(cache, formatted) - - return report.return_code + if write_back == WriteBack.YES and formatted: + write_cache(cache, formatted, line_length) def format_file_in_place( @@ -337,7 +322,7 @@ def format_file_in_place( fast: bool, write_back: WriteBack = WriteBack.NO, lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy -) -> Changed: +) -> bool: """Format file under `src` path. Return True if changed. If `write_back` is True, write reformatted code back to stdout. @@ -351,14 +336,14 @@ def format_file_in_place( src_contents, line_length=line_length, fast=fast ) except NothingChanged: - return Changed.NO + return False if write_back == write_back.YES: with open(src, "w", encoding=src_buffer.encoding) as f: f.write(dst_contents) elif write_back == write_back.DIFF: - src_name = f"{src.name} (original)" - dst_name = f"{src.name} (formatted)" + src_name = f"{src} (original)" + dst_name = f"{src} (formatted)" diff_contents = diff(src_contents, dst_contents, src_name, dst_name) if lock: lock.acquire() @@ -367,12 +352,12 @@ def format_file_in_place( finally: if lock: lock.release() - return Changed.YES + return True def format_stdin_to_stdout( line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO -) -> Changed: +) -> bool: """Format file on stdin. Return True if changed. If `write_back` is True, write reformatted code back to stdout. @@ -382,10 +367,10 @@ def format_stdin_to_stdout( dst = src try: dst = format_file_contents(src, line_length=line_length, fast=fast) - return Changed.YES + return True except NothingChanged: - return Changed.NO + return False finally: if write_back == WriteBack.YES: @@ -444,7 +429,6 @@ def format_str(src_contents: str, line_length: int) -> FileContent: GRAMMARS = [ pygram.python_grammar_no_print_statement_no_exec_statement, pygram.python_grammar_no_print_statement, - pygram.python_grammar_no_exec_statement, pygram.python_grammar, ] @@ -598,8 +582,26 @@ UNPACKING_PARENTS = { syms.listmaker, syms.testlist_gexp, } +TEST_DESCENDANTS = { + syms.test, + syms.lambdef, + syms.or_test, + syms.and_test, + syms.not_test, + syms.comparison, + syms.star_expr, + syms.expr, + syms.xor_expr, + syms.and_expr, + syms.shift_expr, + syms.arith_expr, + syms.trailer, + syms.term, + syms.power, +} COMPREHENSION_PRIORITY = 20 COMMA_PRIORITY = 10 +TERNARY_PRIORITY = 7 LOGIC_PRIORITY = 5 STRING_PRIORITY = 4 COMPARATOR_PRIORITY = 3 @@ -614,6 +616,8 @@ class BracketTracker: bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict) delimiters: Dict[LeafID, Priority] = Factory(dict) previous: Optional[Leaf] = None + _for_loop_variable: bool = False + _lambda_arguments: bool = False def mark(self, leaf: Leaf) -> None: """Mark `leaf` with bracket-related metadata. Keep track of delimiters. 
@@ -633,6 +637,8 @@ class BracketTracker: if leaf.type == token.COMMENT: return + self.maybe_decrement_after_for_loop_variable(leaf) + self.maybe_decrement_after_lambda_arguments(leaf) if leaf.type in CLOSING_BRACKETS: self.depth -= 1 opening_bracket = self.bracket_match.pop((self.depth, leaf.type)) @@ -650,6 +656,8 @@ class BracketTracker: self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf self.depth += 1 self.previous = leaf + self.maybe_increment_lambda_arguments(leaf) + self.maybe_increment_for_loop_variable(leaf) def any_open_brackets(self) -> bool: """Return True if there is an yet unmatched open bracket on the line.""" @@ -658,11 +666,59 @@ class BracketTracker: def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int: """Return the highest priority of a delimiter found on the line. - Values are consistent with what `is_delimiter()` returns. + Values are consistent with what `is_split_*_delimiter()` return. Raises ValueError on no delimiters. """ return max(v for k, v in self.delimiters.items() if k not in exclude) + def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: + """In a for loop, or comprehension, the variables are often unpacks. + + To avoid splitting on the comma in this situation, increase the depth of + tokens between `for` and `in`. + """ + if leaf.type == token.NAME and leaf.value == "for": + self.depth += 1 + self._for_loop_variable = True + return True + + return False + + def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool: + """See `maybe_increment_for_loop_variable` above for explanation.""" + if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in": + self.depth -= 1 + self._for_loop_variable = False + return True + + return False + + def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool: + """In a lambda expression, there might be more than one argument. + + To avoid splitting on the comma in this situation, increase the depth of + tokens between `lambda` and `:`. + """ + if leaf.type == token.NAME and leaf.value == "lambda": + self.depth += 1 + self._lambda_arguments = True + return True + + return False + + def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool: + """See `maybe_increment_lambda_arguments` above for explanation.""" + if self._lambda_arguments and leaf.type == token.COLON: + self.depth -= 1 + self._lambda_arguments = False + return True + + return False + + def get_open_lsqb(self) -> Optional[Leaf]: + """Return the most recent opening square bracket (if any).""" + return self.bracket_match.get((self.depth - 1, token.RSQB)) + @dataclass class Line: @@ -673,8 +729,6 @@ class Line: comments: List[Tuple[Index, Leaf]] = Factory(list) bracket_tracker: BracketTracker = Factory(BracketTracker) inside_brackets: bool = False - has_for: bool = False - _for_loop_variable: bool = False def append(self, leaf: Leaf, preformatted: bool = False) -> None: """Add a new `leaf` to the end of the line. @@ -693,12 +747,12 @@ class Line: if self.leaves and not preformatted: # Note: at this point leaf.prefix should be empty except for # imports, for which we only preserve newlines. 
- leaf.prefix += whitespace(leaf) + leaf.prefix += whitespace( + leaf, complex_subscript=self.is_complex_subscript(leaf) + ) if self.inside_brackets or not preformatted: - self.maybe_decrement_after_for_loop_variable(leaf) self.bracket_tracker.mark(leaf) self.maybe_remove_trailing_comma(leaf) - self.maybe_increment_for_loop_variable(leaf) if not self.append_comment(leaf): self.leaves.append(leaf) @@ -828,7 +882,7 @@ class Line: else: return False - for leaf in self.leaves[_opening_index + 1:]: + for leaf in self.leaves[_opening_index + 1 :]: if leaf is closing: break @@ -845,29 +899,6 @@ class Line: return False - def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: - """In a for loop, or comprehension, the variables are often unpacks. - - To avoid splitting on the comma in this situation, increase the depth of - tokens between `for` and `in`. - """ - if leaf.type == token.NAME and leaf.value == "for": - self.has_for = True - self.bracket_tracker.depth += 1 - self._for_loop_variable = True - return True - - return False - - def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool: - """See `maybe_increment_for_loop_variable` above for explanation.""" - if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in": - self.bracket_tracker.depth -= 1 - self._for_loop_variable = False - return True - - return False - def append_comment(self, comment: Leaf) -> bool: """Add an inline or standalone comment to the line.""" if ( @@ -912,6 +943,24 @@ class Line: self.comments[i] = (comma_index - 1, comment) self.leaves.pop() + def is_complex_subscript(self, leaf: Leaf) -> bool: + """Return True iff `leaf` is part of a slice with non-trivial exprs.""" + open_lsqb = ( + leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb() + ) + if open_lsqb is None: + return False + + subscript_start = open_lsqb.next_sibling + if ( + isinstance(subscript_start, Node) + and subscript_start.type == syms.subscriptlist + ): + subscript_start = child_towards(subscript_start, leaf) + return subscript_start is not None and any( + n.type in TEST_DESCENDANTS for n in subscript_start.pre_order() + ) + def __str__(self) -> str: """Render the line.""" if not self: @@ -1034,8 +1083,14 @@ class EmptyLineTracker: # Don't insert empty lines before the first line in the file. return 0, 0 - if self.previous_line and self.previous_line.is_decorator: - # Don't insert empty lines between decorators. + if self.previous_line.is_decorator: + return 0, 0 + + if ( + self.previous_line.is_comment + and self.previous_line.depth == current_line.depth + and before == 0 + ): return 0, 0 newlines = 2 @@ -1043,9 +1098,6 @@ class EmptyLineTracker: newlines -= 1 return newlines, 0 - if current_line.is_flow_control: - return before, 1 - if ( self.previous_line and self.previous_line.is_import @@ -1054,13 +1106,6 @@ class EmptyLineTracker: ): return (before or 1), 0 - if ( - self.previous_line - and self.previous_line.is_yield - and (not current_line.is_yield or depth != self.previous_line.depth) - ): - return (before or 1), 0 - return before, 0 @@ -1152,7 +1197,16 @@ class LineGenerator(Visitor[Line]): def visit_DEDENT(self, node: Node) -> Iterator[Line]: """Decrease indentation level, maybe yield a line.""" - # DEDENT has no value. Additionally, in blib2to3 it never holds comments. + # The current line might still wait for trailing comments. At DEDENT time + # there won't be any (they would be prefixes on the preceding NEWLINE). + # Emit the line then. 
+ yield from self.line() + + # While DEDENT has no value, its prefix may contain standalone comments + # that belong to the current indentation level. Get 'em. + yield from self.visit_default(node) + + # Finally, emit the dedent. yield from self.line(-1) def visit_stmt( @@ -1290,8 +1344,12 @@ BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT} -def whitespace(leaf: Leaf) -> str: # noqa C901 - """Return whitespace prefix if needed for the given `leaf`.""" +def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa C901 + """Return whitespace prefix if needed for the given `leaf`. + + `complex_subscript` signals whether the given leaf is part of a subscription + which has non-trivial arguments, like arithmetic expressions or function calls. + """ NO = "" SPACE = " " DOUBLESPACE = " " @@ -1305,7 +1363,10 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return DOUBLESPACE assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}" - if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}: + if ( + t == token.COLON + and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop} + ): return NO prev = leaf.prev_sibling @@ -1315,7 +1376,13 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return NO if t == token.COLON: - return SPACE if prevp.type == token.COMMA else NO + if prevp.type == token.COLON: + return NO + + elif prevp.type != token.COMMA and not complex_subscript: + return NO + + return SPACE if prevp.type == token.EQUAL: if prevp.parent: @@ -1336,7 +1403,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 elif prevp.type == token.COLON: if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}: - return NO + return SPACE if complex_subscript else NO elif ( prevp.parent @@ -1361,17 +1428,11 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 if p.type in {syms.parameters, syms.arglist}: # untyped function signatures or calls - if t == token.RPAR: - return NO - if not prev or prev.type != token.COMMA: return NO elif p.type == syms.varargslist: # lambdas - if t == token.RPAR: - return NO - if prev and prev.type != token.COMMA: return NO @@ -1448,7 +1509,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 if prev and prev.type == token.LPAR: return NO - elif p.type == syms.subscript: + elif p.type in {syms.subscript, syms.sliceop}: # indexing if not prev: assert p.parent is not None, "subscripts are always parented" @@ -1457,7 +1518,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return NO - else: + elif not complex_subscript: return NO elif p.type == syms.atom: @@ -1465,21 +1526,9 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 # dots, but not the first one. 
return NO - elif ( - p.type == syms.listmaker - or p.type == syms.testlist_gexp - or p.type == syms.subscriptlist - ): - # list interior, including unpacking - if not prev: - return NO - elif p.type == syms.dictsetmaker: - # dict and set interior, including unpacking - if not prev: - return NO - - if prev.type == token.DOUBLESTAR: + # dict unpacking + if prev and prev.type == token.DOUBLESTAR: return NO elif p.type in {syms.factor, syms.star_expr}: @@ -1539,6 +1588,14 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: return None +def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]: + """Return the child of `ancestor` that contains `descendant`.""" + node: Optional[LN] = descendant + while node and node.parent != ancestor: + node = node.parent + return node + + def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int: """Return the priority of the `leaf` delimiter, given a line break after it. @@ -1599,23 +1656,20 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int: ): return COMPREHENSION_PRIORITY + if ( + leaf.type == token.NAME + and leaf.value in {"if", "else"} + and leaf.parent + and leaf.parent.type == syms.test + ): + return TERNARY_PRIORITY + if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent: return LOGIC_PRIORITY return 0 -def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int: - """Return the priority of the `leaf` delimiter. Return 0 if not delimiter. - - Higher numbers are higher priority. - """ - return max( - is_split_before_delimiter(leaf, previous), - is_split_after_delimiter(leaf, previous), - ) - - def generate_comments(leaf: Leaf) -> Iterator[Leaf]: """Clean the prefix of the `leaf` and generate comments from it, if any. @@ -1721,6 +1775,8 @@ def split_line( split_funcs: List[SplitFunc] if line.is_def: split_funcs = [left_hand_split] + elif line.is_import: + split_funcs = [explode_split] elif line.inside_brackets: split_funcs = [delimiter_split, standalone_comment_split, right_hand_split] else: @@ -1987,6 +2043,26 @@ def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: yield current_line +def explode_split( + line: Line, py36: bool = False, omit: Collection[LeafID] = () +) -> Iterator[Line]: + """Split by rightmost bracket and immediately split contents by a delimiter.""" + new_lines = list(right_hand_split(line, py36, omit)) + if len(new_lines) != 3: + yield from new_lines + return + + yield new_lines[0] + + try: + yield from delimiter_split(new_lines[1], py36) + + except CannotSplit: + yield new_lines[1] + + yield new_lines[2] + + def is_import(leaf: Leaf) -> bool: """Return True if the given leaf starts an import statement.""" p = leaf.parent @@ -2048,7 +2124,7 @@ def normalize_string_quotes(leaf: Leaf) -> None: unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}") escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}") escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}") - body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)] + body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)] if "r" in prefix.casefold(): if unescaped_new_quote.search(body): # There's at least one unescaped new_quote in this raw string @@ -2174,6 +2250,11 @@ def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool: def max_delimiter_priority_in_atom(node: LN) -> int: + """Return maximum delimiter priority inside `node`. + + This is specific to atoms with contents contained in a pair of parentheses. 
+ If `node` isn't an atom or there are no enclosing parentheses, returns 0. + """ if node.type != syms.atom: return 0 @@ -2476,18 +2557,22 @@ def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str: CACHE_DIR = Path(user_cache_dir("black", version=__version__)) -CACHE_FILE = CACHE_DIR / "cache.pickle" -def read_cache() -> Cache: +def get_cache_file(line_length: int) -> Path: + return CACHE_DIR / f"cache.{line_length}.pickle" + + +def read_cache(line_length: int) -> Cache: """Read the cache if it exists and is well formed. If it is not well formed, the call to write_cache later should resolve the issue. """ - if not CACHE_FILE.exists(): + cache_file = get_cache_file(line_length) + if not cache_file.exists(): return {} - with CACHE_FILE.open("rb") as fobj: + with cache_file.open("rb") as fobj: try: cache: Cache = pickle.load(fobj) except pickle.UnpicklingError: @@ -2520,13 +2605,14 @@ def filter_cached( return todo, done -def write_cache(cache: Cache, sources: List[Path]) -> None: +def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None: """Update the cache file.""" + cache_file = get_cache_file(line_length) try: if not CACHE_DIR.exists(): CACHE_DIR.mkdir(parents=True) new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}} - with CACHE_FILE.open("wb") as fobj: + with cache_file.open("wb") as fobj: pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL) except OSError: pass
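
One user-visible consequence of the cache changes in the final hunks: the pickle is now keyed by line length, so runs with different `-l` values no longer overwrite each other's cache entries. A small sketch of how this can be inspected from Python (not part of the patch; assumes black 18.4a4 importable as the single-module `black`, and the exact cache directory is platform dependent):

import black

# Each line length gets its own pickle under the versioned cache directory,
# e.g. ~/.cache/black/18.4a4/cache.88.pickle on Linux (platform dependent).
print(black.get_cache_file(79))
print(black.get_cache_file(88))

# read_cache() now takes the line length; a missing or malformed pickle just
# yields an empty cache instead of raising.
cache = black.read_cache(88)
print(f"{len(cache)} cached files for -l 88")
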