X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/3455389e486e0bb1d8a8318cb5f266b7ec8964dd..0967dfcbeba8aceaacd468b279cc23089d697878:/black.py?ds=sidebyside diff --git a/black.py b/black.py index 3d5ea14..d2d23c8 100644 --- a/black.py +++ b/black.py @@ -1,24 +1,29 @@ -#!/usr/bin/env python3 - import asyncio +import pickle from asyncio.base_events import BaseEventLoop from concurrent.futures import Executor, ProcessPoolExecutor +from enum import Enum from functools import partial, wraps import keyword import logging +from multiprocessing import Manager import os from pathlib import Path +import re import tokenize import signal import sys from typing import ( + Any, Callable, + Collection, Dict, Generic, Iterable, Iterator, List, Optional, + Pattern, Set, Tuple, Type, @@ -26,6 +31,7 @@ from typing import ( Union, ) +from appdirs import user_cache_dir from attr import dataclass, Factory import click @@ -35,8 +41,9 @@ from blib2to3 import pygram, pytree from blib2to3.pgen2 import driver, token from blib2to3.pgen2.parse import ParseError -__version__ = "18.3a4" +__version__ = "18.4a6" DEFAULT_LINE_LENGTH = 88 + # types syms = pygram.python_symbols FileContent = str @@ -48,6 +55,10 @@ Priority = int Index = int LN = Union[Leaf, Node] SplitFunc = Callable[["Line", bool], Iterator["Line"]] +Timestamp = float +FileSize = int +CacheInfo = Tuple[Timestamp, FileSize] +Cache = Dict[Path, CacheInfo] out = partial(click.secho, bold=True, err=True) err = partial(click.secho, fg="red", err=True) @@ -76,11 +87,11 @@ class FormatError(Exception): self.consumed = consumed def trim_prefix(self, leaf: Leaf) -> None: - leaf.prefix = leaf.prefix[self.consumed:] + leaf.prefix = leaf.prefix[self.consumed :] def leaf_from_consumed(self, leaf: Leaf) -> Leaf: """Returns a new Leaf from the consumed part of the prefix.""" - unformatted_prefix = leaf.prefix[:self.consumed] + unformatted_prefix = leaf.prefix[: self.consumed] return Leaf(token.NEWLINE, unformatted_prefix) @@ -92,6 +103,18 @@ class FormatOff(FormatError): """Found a comment like `# fmt: off` in the file.""" +class WriteBack(Enum): + NO = 0 + YES = 1 + DIFF = 2 + + +class Changed(Enum): + NO = 0 + CACHED = 1 + YES = 2 + + @click.command() @click.option( "-l", @@ -105,16 +128,30 @@ class FormatOff(FormatError): "--check", is_flag=True, help=( - "Don't write back the files, just return the status. Return code 0 " + "Don't write the files back, just return the status. Return code 0 " "means nothing would change. Return code 1 means some files would be " "reformatted. Return code 123 means there was an internal error." ), ) +@click.option( + "--diff", + is_flag=True, + help="Don't write the files back, just output a diff for each file on stdout.", +) @click.option( "--fast/--safe", is_flag=True, help="If --fast given, skip temporary sanity checks. [default: --safe]", ) +@click.option( + "-q", + "--quiet", + is_flag=True, + help=( + "Don't emit non-error messages to stderr. Errors are still emitted, " + "silence those with 2>/dev/null." + ), +) @click.version_option(version=__version__) @click.argument( "src", @@ -125,7 +162,13 @@ class FormatOff(FormatError): ) @click.pass_context def main( - ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str] + ctx: click.Context, + line_length: int, + check: bool, + diff: bool, + fast: bool, + quiet: bool, + src: List[str], ) -> None: """The uncompromising code formatter.""" sources: List[Path] = [] @@ -140,47 +183,83 @@ def main( sources.append(Path("-")) else: err(f"invalid path: {s}") + + if check and not diff: + write_back = WriteBack.NO + elif diff: + write_back = WriteBack.DIFF + else: + write_back = WriteBack.YES + report = Report(check=check, quiet=quiet) if len(sources) == 0: + out("No paths given. Nothing to do 😴") ctx.exit(0) + return + elif len(sources) == 1: - p = sources[0] - report = Report(check=check) - try: - if not p.is_file() and str(p) == "-": - changed = format_stdin_to_stdout( - line_length=line_length, fast=fast, write_back=not check - ) - else: - changed = format_file_in_place( - p, line_length=line_length, fast=fast, write_back=not check - ) - report.done(p, changed) - except Exception as exc: - report.failed(p, str(exc)) - ctx.exit(report.return_code) + reformat_one(sources[0], line_length, fast, write_back, report) else: loop = asyncio.get_event_loop() executor = ProcessPoolExecutor(max_workers=os.cpu_count()) - return_code = 1 try: - return_code = loop.run_until_complete( + loop.run_until_complete( schedule_formatting( - sources, line_length, not check, fast, loop, executor + sources, line_length, fast, write_back, report, loop, executor ) ) finally: shutdown(loop) - ctx.exit(return_code) + if not quiet: + out("All done! ✨ 🍰 ✨") + click.echo(str(report)) + ctx.exit(report.return_code) + + +def reformat_one( + src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report" +) -> None: + """Reformat a single file under `src` without spawning child processes. + + If `quiet` is True, non-error messages are not output. `line_length`, + `write_back`, and `fast` options are passed to :func:`format_file_in_place`. + """ + try: + changed = Changed.NO + if not src.is_file() and str(src) == "-": + if format_stdin_to_stdout( + line_length=line_length, fast=fast, write_back=write_back + ): + changed = Changed.YES + else: + cache: Cache = {} + if write_back != WriteBack.DIFF: + cache = read_cache(line_length) + src = src.resolve() + if src in cache and cache[src] == get_cache_info(src): + changed = Changed.CACHED + if ( + changed is not Changed.CACHED + and format_file_in_place( + src, line_length=line_length, fast=fast, write_back=write_back + ) + ): + changed = Changed.YES + if write_back == WriteBack.YES and changed is not Changed.NO: + write_cache(cache, [src], line_length) + report.done(src, changed) + except Exception as exc: + report.failed(src, str(exc)) async def schedule_formatting( sources: List[Path], line_length: int, - write_back: bool, fast: bool, + write_back: WriteBack, + report: "Report", loop: BaseEventLoop, executor: Executor, -) -> int: +) -> None: """Run formatting of `sources` in parallel using the provided `executor`. (Use ProcessPoolExecutors for actual parallelism.) @@ -188,79 +267,117 @@ async def schedule_formatting( `line_length`, `write_back`, and `fast` options are passed to :func:`format_file_in_place`. """ - tasks = { - src: loop.run_in_executor( - executor, format_file_in_place, src, line_length, fast, write_back - ) - for src in sources - } - _task_values = list(tasks.values()) - loop.add_signal_handler(signal.SIGINT, cancel, _task_values) - loop.add_signal_handler(signal.SIGTERM, cancel, _task_values) - await asyncio.wait(tasks.values()) + cache: Cache = {} + if write_back != WriteBack.DIFF: + cache = read_cache(line_length) + sources, cached = filter_cached(cache, sources) + for src in cached: + report.done(src, Changed.CACHED) cancelled = [] - report = Report(check=not write_back) - for src, task in tasks.items(): - if not task.done(): - report.failed(src, "timed out, cancelling") - task.cancel() - cancelled.append(task) - elif task.cancelled(): - cancelled.append(task) - elif task.exception(): - report.failed(src, str(task.exception())) - else: - report.done(src, task.result()) + formatted = [] + if sources: + lock = None + if write_back == WriteBack.DIFF: + # For diff output, we need locks to ensure we don't interleave output + # from different processes. + manager = Manager() + lock = manager.Lock() + tasks = { + src: loop.run_in_executor( + executor, format_file_in_place, src, line_length, fast, write_back, lock + ) + for src in sources + } + _task_values = list(tasks.values()) + try: + loop.add_signal_handler(signal.SIGINT, cancel, _task_values) + loop.add_signal_handler(signal.SIGTERM, cancel, _task_values) + except NotImplementedError: + # There are no good alternatives for these on Windows + pass + await asyncio.wait(_task_values) + for src, task in tasks.items(): + if not task.done(): + report.failed(src, "timed out, cancelling") + task.cancel() + cancelled.append(task) + elif task.cancelled(): + cancelled.append(task) + elif task.exception(): + report.failed(src, str(task.exception())) + else: + formatted.append(src) + report.done(src, Changed.YES if task.result() else Changed.NO) + if cancelled: await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) - else: - out("All done! ✨ 🍰 ✨") - click.echo(str(report)) - return report.return_code + if write_back == WriteBack.YES and formatted: + write_cache(cache, formatted, line_length) def format_file_in_place( - src: Path, line_length: int, fast: bool, write_back: bool = False + src: Path, + line_length: int, + fast: bool, + write_back: WriteBack = WriteBack.NO, + lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy ) -> bool: """Format file under `src` path. Return True if changed. If `write_back` is True, write reformatted code back to stdout. `line_length` and `fast` options are passed to :func:`format_file_contents`. """ + with tokenize.open(src) as src_buffer: src_contents = src_buffer.read() try: - contents = format_file_contents( + dst_contents = format_file_contents( src_contents, line_length=line_length, fast=fast ) except NothingChanged: return False - if write_back: + if write_back == write_back.YES: with open(src, "w", encoding=src_buffer.encoding) as f: - f.write(contents) + f.write(dst_contents) + elif write_back == write_back.DIFF: + src_name = f"{src} (original)" + dst_name = f"{src} (formatted)" + diff_contents = diff(src_contents, dst_contents, src_name, dst_name) + if lock: + lock.acquire() + try: + sys.stdout.write(diff_contents) + finally: + if lock: + lock.release() return True def format_stdin_to_stdout( - line_length: int, fast: bool, write_back: bool = False + line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO ) -> bool: """Format file on stdin. Return True if changed. If `write_back` is True, write reformatted code back to stdout. `line_length` and `fast` arguments are passed to :func:`format_file_contents`. """ - contents = sys.stdin.read() + src = sys.stdin.read() + dst = src try: - contents = format_file_contents(contents, line_length=line_length, fast=fast) + dst = format_file_contents(src, line_length=line_length, fast=fast) return True except NothingChanged: return False finally: - if write_back: - sys.stdout.write(contents) + if write_back == WriteBack.YES: + sys.stdout.write(dst) + elif write_back == WriteBack.DIFF: + src_name = " (original)" + dst_name = " (formatted)" + sys.stdout.write(diff(src, dst, src_name, dst_name)) def format_file_contents( @@ -311,7 +428,6 @@ def format_str(src_contents: str, line_length: int) -> FileContent: GRAMMARS = [ pygram.python_grammar_no_print_statement_no_exec_statement, pygram.python_grammar_no_print_statement, - pygram.python_grammar_no_exec_statement, pygram.python_grammar, ] @@ -451,9 +567,40 @@ MATH_OPERATORS = { token.DOUBLESTAR, token.DOUBLESLASH, } -VARARGS = {token.STAR, token.DOUBLESTAR} +STARS = {token.STAR, token.DOUBLESTAR} +VARARGS_PARENTS = { + syms.arglist, + syms.argument, # double star in arglist + syms.trailer, # single argument to call + syms.typedargslist, + syms.varargslist, # lambdas +} +UNPACKING_PARENTS = { + syms.atom, # single element of a list or set literal + syms.dictsetmaker, + syms.listmaker, + syms.testlist_gexp, +} +TEST_DESCENDANTS = { + syms.test, + syms.lambdef, + syms.or_test, + syms.and_test, + syms.not_test, + syms.comparison, + syms.star_expr, + syms.expr, + syms.xor_expr, + syms.and_expr, + syms.shift_expr, + syms.arith_expr, + syms.trailer, + syms.term, + syms.power, +} COMPREHENSION_PRIORITY = 20 COMMA_PRIORITY = 10 +TERNARY_PRIORITY = 7 LOGIC_PRIORITY = 5 STRING_PRIORITY = 4 COMPARATOR_PRIORITY = 3 @@ -468,6 +615,8 @@ class BracketTracker: bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict) delimiters: Dict[LeafID, Priority] = Factory(dict) previous: Optional[Leaf] = None + _for_loop_variable: bool = False + _lambda_arguments: bool = False def mark(self, leaf: Leaf) -> None: """Mark `leaf` with bracket-related metadata. Keep track of delimiters. @@ -487,22 +636,27 @@ class BracketTracker: if leaf.type == token.COMMENT: return + self.maybe_decrement_after_for_loop_variable(leaf) + self.maybe_decrement_after_lambda_arguments(leaf) if leaf.type in CLOSING_BRACKETS: self.depth -= 1 opening_bracket = self.bracket_match.pop((self.depth, leaf.type)) leaf.opening_bracket = opening_bracket leaf.bracket_depth = self.depth if self.depth == 0: - after_delim = is_split_after_delimiter(leaf, self.previous) - before_delim = is_split_before_delimiter(leaf, self.previous) - if after_delim > before_delim: - self.delimiters[id(leaf)] = after_delim - elif before_delim > after_delim and self.previous is not None: - self.delimiters[id(self.previous)] = before_delim + delim = is_split_before_delimiter(leaf, self.previous) + if delim and self.previous is not None: + self.delimiters[id(self.previous)] = delim + else: + delim = is_split_after_delimiter(leaf, self.previous) + if delim: + self.delimiters[id(leaf)] = delim if leaf.type in OPENING_BRACKETS: self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf self.depth += 1 self.previous = leaf + self.maybe_increment_lambda_arguments(leaf) + self.maybe_increment_for_loop_variable(leaf) def any_open_brackets(self) -> bool: """Return True if there is an yet unmatched open bracket on the line.""" @@ -511,10 +665,59 @@ class BracketTracker: def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int: """Return the highest priority of a delimiter found on the line. - Values are consistent with what `is_delimiter()` returns. + Values are consistent with what `is_split_*_delimiter()` return. + Raises ValueError on no delimiters. """ return max(v for k, v in self.delimiters.items() if k not in exclude) + def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: + """In a for loop, or comprehension, the variables are often unpacks. + + To avoid splitting on the comma in this situation, increase the depth of + tokens between `for` and `in`. + """ + if leaf.type == token.NAME and leaf.value == "for": + self.depth += 1 + self._for_loop_variable = True + return True + + return False + + def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool: + """See `maybe_increment_for_loop_variable` above for explanation.""" + if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in": + self.depth -= 1 + self._for_loop_variable = False + return True + + return False + + def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool: + """In a lambda expression, there might be more than one argument. + + To avoid splitting on the comma in this situation, increase the depth of + tokens between `lambda` and `:`. + """ + if leaf.type == token.NAME and leaf.value == "lambda": + self.depth += 1 + self._lambda_arguments = True + return True + + return False + + def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool: + """See `maybe_increment_lambda_arguments` above for explanation.""" + if self._lambda_arguments and leaf.type == token.COLON: + self.depth -= 1 + self._lambda_arguments = False + return True + + return False + + def get_open_lsqb(self) -> Optional[Leaf]: + """Return the most recent opening square bracket (if any).""" + return self.bracket_match.get((self.depth - 1, token.RSQB)) + @dataclass class Line: @@ -525,8 +728,6 @@ class Line: comments: List[Tuple[Index, Leaf]] = Factory(list) bracket_tracker: BracketTracker = Factory(BracketTracker) inside_brackets: bool = False - has_for: bool = False - _for_loop_variable: bool = False def append(self, leaf: Leaf, preformatted: bool = False) -> None: """Add a new `leaf` to the end of the line. @@ -538,20 +739,21 @@ class Line: Inline comments are put aside. """ - has_value = leaf.value.strip() + has_value = leaf.type in BRACKETS or bool(leaf.value.strip()) if not has_value: return + if token.COLON == leaf.type and self.is_class_paren_empty: + del self.leaves[-2:] if self.leaves and not preformatted: # Note: at this point leaf.prefix should be empty except for # imports, for which we only preserve newlines. - leaf.prefix += whitespace(leaf) + leaf.prefix += whitespace( + leaf, complex_subscript=self.is_complex_subscript(leaf) + ) if self.inside_brackets or not preformatted: - self.maybe_decrement_after_for_loop_variable(leaf) self.bracket_tracker.mark(leaf) self.maybe_remove_trailing_comma(leaf) - self.maybe_increment_for_loop_variable(leaf) - if not self.append_comment(leaf): self.leaves.append(leaf) @@ -640,11 +842,27 @@ class Line: ) @property - def contains_standalone_comments(self) -> bool: + def is_class_paren_empty(self) -> bool: + """Is this a class with no base classes but using parentheses? + + Those are unnecessary and should be removed. + """ + return ( + bool(self) + and len(self.leaves) == 4 + and self.is_class + and self.leaves[2].type == token.LPAR + and self.leaves[2].value == "(" + and self.leaves[3].type == token.RPAR + and self.leaves[3].value == ")" + ) + + def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool: """If so, needs to be split before emitting.""" for leaf in self.leaves: if leaf.type == STANDALONE_COMMENT: - return True + if leaf.bracket_depth <= depth_limit: + return True return False @@ -667,9 +885,14 @@ class Line: self.remove_trailing_comma() return True - # For parens let's check if it's safe to remove the comma. If the - # trailing one is the only one, we might mistakenly change a tuple - # into a different type by removing the comma. + # For parens let's check if it's safe to remove the comma. + # Imports are always safe. + if self.is_import: + self.remove_trailing_comma() + return True + + # Otheriwsse, if the trailing one is the only one, we might mistakenly + # change a tuple into a different type by removing the comma. depth = closing.bracket_depth + 1 commas = 0 opening = closing.opening_bracket @@ -680,7 +903,7 @@ class Line: else: return False - for leaf in self.leaves[_opening_index + 1:]: + for leaf in self.leaves[_opening_index + 1 :]: if leaf is closing: break @@ -697,29 +920,6 @@ class Line: return False - def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: - """In a for loop, or comprehension, the variables are often unpacks. - - To avoid splitting on the comma in this situation, increase the depth of - tokens between `for` and `in`. - """ - if leaf.type == token.NAME and leaf.value == "for": - self.has_for = True - self.bracket_tracker.depth += 1 - self._for_loop_variable = True - return True - - return False - - def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool: - """See `maybe_increment_for_loop_variable` above for explanation.""" - if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in": - self.bracket_tracker.depth -= 1 - self._for_loop_variable = False - return True - - return False - def append_comment(self, comment: Leaf) -> bool: """Add an inline or standalone comment to the line.""" if ( @@ -764,6 +964,24 @@ class Line: self.comments[i] = (comma_index - 1, comment) self.leaves.pop() + def is_complex_subscript(self, leaf: Leaf) -> bool: + """Return True iff `leaf` is part of a slice with non-trivial exprs.""" + open_lsqb = ( + leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb() + ) + if open_lsqb is None: + return False + + subscript_start = open_lsqb.next_sibling + if ( + isinstance(subscript_start, Node) + and subscript_start.type == syms.subscriptlist + ): + subscript_start = child_towards(subscript_start, leaf) + return subscript_start is not None and any( + n.type in TEST_DESCENDANTS for n in subscript_start.pre_order() + ) + def __str__(self) -> str: """Render the line.""" if not self: @@ -886,8 +1104,14 @@ class EmptyLineTracker: # Don't insert empty lines before the first line in the file. return 0, 0 - if self.previous_line and self.previous_line.is_decorator: - # Don't insert empty lines between decorators. + if self.previous_line.is_decorator: + return 0, 0 + + if ( + self.previous_line.is_comment + and self.previous_line.depth == current_line.depth + and before == 0 + ): return 0, 0 newlines = 2 @@ -895,9 +1119,6 @@ class EmptyLineTracker: newlines -= 1 return newlines, 0 - if current_line.is_flow_control: - return before, 1 - if ( self.previous_line and self.previous_line.is_import @@ -906,13 +1127,6 @@ class EmptyLineTracker: ): return (before or 1), 0 - if ( - self.previous_line - and self.previous_line.is_yield - and (not current_line.is_yield or depth != self.previous_line.depth) - ): - return (before or 1), 0 - return before, 0 @@ -1004,18 +1218,34 @@ class LineGenerator(Visitor[Line]): def visit_DEDENT(self, node: Node) -> Iterator[Line]: """Decrease indentation level, maybe yield a line.""" - # DEDENT has no value. Additionally, in blib2to3 it never holds comments. + # The current line might still wait for trailing comments. At DEDENT time + # there won't be any (they would be prefixes on the preceding NEWLINE). + # Emit the line then. + yield from self.line() + + # While DEDENT has no value, its prefix may contain standalone comments + # that belong to the current indentation level. Get 'em. + yield from self.visit_default(node) + + # Finally, emit the dedent. yield from self.line(-1) - def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]: + def visit_stmt( + self, node: Node, keywords: Set[str], parens: Set[str] + ) -> Iterator[Line]: """Visit a statement. This implementation is shared for `if`, `while`, `for`, `try`, `except`, - `def`, `with`, and `class`. + `def`, `with`, `class`, and `assert`. - The relevant Python language `keywords` for a given statement will be NAME - leaves within it. This methods puts those on a separate line. + The relevant Python language `keywords` for a given statement will be + NAME leaves within it. This methods puts those on a separate line. + + `parens` holds pairs of nodes where invisible parentheses should be put. + Keys hold nodes after which opening parentheses should be put, values + hold nodes before which closing parentheses should be put. """ + normalize_invisible_parens(node, parens_after=parens) for child in node.children: if child.type == token.NAME and child.value in keywords: # type: ignore yield from self.line() @@ -1055,6 +1285,32 @@ class LineGenerator(Visitor[Line]): yield from self.line() yield from self.visit(child) + def visit_import_from(self, node: Node) -> Iterator[Line]: + """Visit import_from and maybe put invisible parentheses. + + This is separate from `visit_stmt` because import statements don't + support arbitrary atoms and thus handling of parentheses is custom. + """ + check_lpar = False + for index, child in enumerate(node.children): + if check_lpar: + if child.type == token.LPAR: + # make parentheses invisible + child.value = "" # type: ignore + node.children[-1].value = "" # type: ignore + else: + # insert invisible parentheses + node.insert_child(index, Leaf(token.LPAR, "")) + node.append_child(Leaf(token.RPAR, "")) + break + + check_lpar = ( + child.type == token.NAME and child.value == "import" # type: ignore + ) + + for child in node.children: + yield from self.visit(child) + def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]: """Remove a semicolon and put the other statement on a separate line.""" yield from self.line() @@ -1078,21 +1334,30 @@ class LineGenerator(Visitor[Line]): yield from self.line() yield from self.visit(node) + if node.type == token.ENDMARKER: + # somebody decided not to put a final `# fmt: on` + yield from self.line() + def __attrs_post_init__(self) -> None: """You are in a twisty little maze of passages.""" v = self.visit_stmt - self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}) - self.visit_while_stmt = partial(v, keywords={"while", "else"}) - self.visit_for_stmt = partial(v, keywords={"for", "else"}) - self.visit_try_stmt = partial(v, keywords={"try", "except", "else", "finally"}) - self.visit_except_clause = partial(v, keywords={"except"}) - self.visit_funcdef = partial(v, keywords={"def"}) - self.visit_with_stmt = partial(v, keywords={"with"}) - self.visit_classdef = partial(v, keywords={"class"}) + Ø: Set[str] = set() + self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","}) + self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"}) + self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"}) + self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"}) + self.visit_try_stmt = partial( + v, keywords={"try", "except", "else", "finally"}, parens=Ø + ) + self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø) + self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø) + self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø) + self.visit_classdef = partial(v, keywords={"class"}, parens=Ø) self.visit_async_funcdef = self.visit_async_stmt self.visit_decorated = self.visit_decorators +IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist} BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE} OPENING_BRACKETS = set(BRACKET.keys()) CLOSING_BRACKETS = set(BRACKET.values()) @@ -1100,8 +1365,12 @@ BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT} -def whitespace(leaf: Leaf) -> str: # noqa C901 - """Return whitespace prefix if needed for the given `leaf`.""" +def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa C901 + """Return whitespace prefix if needed for the given `leaf`. + + `complex_subscript` signals whether the given leaf is part of a subscription + which has non-trivial arguments, like arithmetic expressions or function calls. + """ NO = "" SPACE = " " DOUBLESPACE = " " @@ -1115,7 +1384,10 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return DOUBLESPACE assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}" - if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}: + if ( + t == token.COLON + and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop} + ): return NO prev = leaf.prev_sibling @@ -1125,7 +1397,13 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return NO if t == token.COLON: - return SPACE if prevp.type == token.COMMA else NO + if prevp.type == token.COLON: + return NO + + elif prevp.type != token.COMMA and not complex_subscript: + return NO + + return SPACE if prevp.type == token.EQUAL: if prevp.parent: @@ -1140,24 +1418,17 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 # that, too. return prevp.prefix - elif prevp.type == token.DOUBLESTAR: - if prevp.parent and prevp.parent.type in { - syms.arglist, - syms.argument, - syms.dictsetmaker, - syms.parameters, - syms.typedargslist, - syms.varargslist, - }: + elif prevp.type in STARS: + if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS): return NO elif prevp.type == token.COLON: if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}: - return NO + return SPACE if complex_subscript else NO elif ( prevp.parent - and prevp.parent.type in {syms.factor, syms.star_expr} + and prevp.parent.type == syms.factor and prevp.type in MATH_OPERATORS ): return NO @@ -1178,17 +1449,11 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 if p.type in {syms.parameters, syms.arglist}: # untyped function signatures or calls - if t == token.RPAR: - return NO - if not prev or prev.type != token.COMMA: return NO elif p.type == syms.varargslist: # lambdas - if t == token.RPAR: - return NO - if prev and prev.type != token.COMMA: return NO @@ -1243,7 +1508,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 if not prevp or prevp.type == token.LPAR: return NO - elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR: + elif prev.type in {token.EQUAL} | STARS: return NO elif p.type == syms.decorator: @@ -1265,7 +1530,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 if prev and prev.type == token.LPAR: return NO - elif p.type == syms.subscript: + elif p.type in {syms.subscript, syms.sliceop}: # indexing if not prev: assert p.parent is not None, "subscripts are always parented" @@ -1274,7 +1539,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return NO - else: + elif not complex_subscript: return NO elif p.type == syms.atom: @@ -1282,21 +1547,9 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 # dots, but not the first one. return NO - elif ( - p.type == syms.listmaker - or p.type == syms.testlist_gexp - or p.type == syms.subscriptlist - ): - # list interior, including unpacking - if not prev: - return NO - elif p.type == syms.dictsetmaker: - # dict and set interior, including unpacking - if not prev: - return NO - - if prev.type == token.DOUBLESTAR: + # dict unpacking + if prev and prev.type == token.DOUBLESTAR: return NO elif p.type in {syms.factor, syms.star_expr}: @@ -1308,9 +1561,10 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 prevp_parent = prevp.parent assert prevp_parent is not None - if prevp.type == token.COLON and prevp_parent.type in { - syms.subscript, syms.sliceop - }: + if ( + prevp.type == token.COLON + and prevp_parent.type in {syms.subscript, syms.sliceop} + ): return NO elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument: @@ -1355,6 +1609,14 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: return None +def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]: + """Return the child of `ancestor` that contains `descendant`.""" + node: Optional[LN] = descendant + while node and node.parent != ancestor: + node = node.parent + return node + + def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int: """Return the priority of the `leaf` delimiter, given a line break after it. @@ -1366,13 +1628,6 @@ def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int: if leaf.type == token.COMMA: return COMMA_PRIORITY - if ( - leaf.type in VARARGS - and leaf.parent - and leaf.parent.type in {syms.argument, syms.typedargslist} - ): - return MATH_PRIORITY - return 0 @@ -1384,6 +1639,11 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int: Higher numbers are higher priority. """ + if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS): + # * and ** might also be MATH_OPERATORS but in this case they are not. + # Don't treat them as a delimiter. + return 0 + if ( leaf.type in MATH_OPERATORS and leaf.parent @@ -1417,23 +1677,20 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int: ): return COMPREHENSION_PRIORITY + if ( + leaf.type == token.NAME + and leaf.value in {"if", "else"} + and leaf.parent + and leaf.parent.type == syms.test + ): + return TERNARY_PRIORITY + if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent: return LOGIC_PRIORITY return 0 -def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int: - """Return the priority of the `leaf` delimiter. Return 0 if not delimiter. - - Higher numbers are higher priority. - """ - return max( - is_split_before_delimiter(leaf, previous), - is_split_after_delimiter(leaf, previous), - ) - - def generate_comments(leaf: Leaf) -> Iterator[Leaf]: """Clean the prefix of the `leaf` and generate comments from it, if any. @@ -1481,7 +1738,12 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]: raise FormatOn(consumed) if comment in {"# fmt: off", "# yapf: disable"}: - raise FormatOff(consumed) + if comment_type == STANDALONE_COMMENT: + raise FormatOff(consumed) + + prev = preceding_leaf(leaf) + if not prev or prev.type in WHITESPACE: # standalone comment in disguise + raise FormatOff(consumed) nlines = 0 @@ -1526,7 +1788,7 @@ def split_line( if ( len(line_str) <= line_length and "\n" not in line_str # multiline strings - and not line.contains_standalone_comments + and not line.contains_standalone_comments() ): yield line return @@ -1534,6 +1796,8 @@ def split_line( split_funcs: List[SplitFunc] if line.is_def: split_funcs = [left_hand_split] + elif line.is_import: + split_funcs = [explode_split] elif line.inside_brackets: split_funcs = [delimiter_split, standalone_comment_split, right_hand_split] else: @@ -1566,7 +1830,8 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split line into many lines, starting with the first matching bracket pair. Note: this usually looks weird, only use this for function definitions. - Prefer RHS otherwise. + Prefer RHS otherwise. This is why this function is not symmetrical with + :func:`right_hand_split` which also handles optional parentheses. """ head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) @@ -1592,9 +1857,7 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: if body_leaves: normalize_prefix(body_leaves[0], inside_brackets=True) # Build the new lines. - for result, leaves in ( - (head, head_leaves), (body, body_leaves), (tail, tail_leaves) - ): + for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves): for leaf in leaves: result.append(leaf, preformatted=True) for comment_after in line.comments_after(leaf): @@ -1605,8 +1868,13 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: yield result -def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: - """Split line into many lines, starting with the last matching bracket pair.""" +def right_hand_split( + line: Line, py36: bool = False, omit: Collection[LeafID] = () +) -> Iterator[Line]: + """Split line into many lines, starting with the last matching bracket pair. + + If the split was by optional parentheses, attempt splitting without them, too. + """ head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) tail = Line(depth=line.depth) @@ -1615,14 +1883,16 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: head_leaves: List[Leaf] = [] current_leaves = tail_leaves opening_bracket = None + closing_bracket = None for leaf in reversed(line.leaves): if current_leaves is body_leaves: if leaf is opening_bracket: current_leaves = head_leaves if body_leaves else tail_leaves current_leaves.append(leaf) if current_leaves is tail_leaves: - if leaf.type in CLOSING_BRACKETS: + if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit: opening_bracket = leaf.opening_bracket + closing_bracket = leaf current_leaves = body_leaves tail_leaves.reverse() body_leaves.reverse() @@ -1630,15 +1900,41 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: # Since body is a new indent level, remove spurious leading whitespace. if body_leaves: normalize_prefix(body_leaves[0], inside_brackets=True) + elif not head_leaves: + # No `head` and no `body` means the split failed. `tail` has all content. + raise CannotSplit("No brackets found") + # Build the new lines. - for result, leaves in ( - (head, head_leaves), (body, body_leaves), (tail, tail_leaves) - ): + for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves): for leaf in leaves: result.append(leaf, preformatted=True) for comment_after in line.comments_after(leaf): result.append(comment_after, preformatted=True) bracket_split_succeeded_or_raise(head, body, tail) + assert opening_bracket and closing_bracket + if ( + # the opening bracket is an optional paren + opening_bracket.type == token.LPAR + and not opening_bracket.value + # the closing bracket is an optional paren + and closing_bracket.type == token.RPAR + and not closing_bracket.value + # there are no delimiters or standalone comments in the body + and not body.bracket_tracker.delimiters + and not line.contains_standalone_comments(0) + # and it's not an import (optional parens are the only thing we can split + # on in this case; attempting a split without them is a waste of time) + and not line.is_import + ): + omit = {id(closing_bracket), *omit} + try: + yield from right_hand_split(line, py36=py36, omit=omit) + return + except CannotSplit: + pass + + ensure_visible(opening_bracket) + ensure_visible(closing_bracket) for result in (head, body, tail): if result: yield result @@ -1729,8 +2025,7 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: lowest_depth = min(lowest_depth, leaf.bracket_depth) if ( leaf.bracket_depth == lowest_depth - and leaf.type == token.STAR - or leaf.type == token.DOUBLESTAR + and is_vararg(leaf, within=VARARGS_PARENTS) ): trailing_comma_safe = trailing_comma_safe and py36 leaf_priority = delimiters.get(id(leaf)) @@ -1740,9 +2035,10 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) if current_line: if ( - delimiter_priority == COMMA_PRIORITY + trailing_comma_safe + and delimiter_priority == COMMA_PRIORITY and current_line.leaves[-1].type != token.COMMA - and trailing_comma_safe + and current_line.leaves[-1].type != STANDALONE_COMMENT ): current_line.append(Leaf(token.COMMA, ",")) yield current_line @@ -1751,12 +2047,7 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: @dont_increase_indentation def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split standalone comments from the rest of the line.""" - for leaf in line.leaves: - if leaf.type == STANDALONE_COMMENT: - if leaf.bracket_depth == 0: - break - - else: + if not line.contains_standalone_comments(0): raise CannotSplit("Line does not have any standalone comments") current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) @@ -1782,6 +2073,26 @@ def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: yield current_line +def explode_split( + line: Line, py36: bool = False, omit: Collection[LeafID] = () +) -> Iterator[Line]: + """Split by rightmost bracket and immediately split contents by a delimiter.""" + new_lines = list(right_hand_split(line, py36, omit)) + if len(new_lines) != 3: + yield from new_lines + return + + yield new_lines[0] + + try: + yield from delimiter_split(new_lines[1], py36) + + except CannotSplit: + yield new_lines[1] + + yield new_lines[2] + + def is_import(leaf: Leaf) -> bool: """Return True if the given leaf starts an import statement.""" p = leaf.parent @@ -1815,6 +2126,13 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: def normalize_string_quotes(leaf: Leaf) -> None: + """Prefer double quotes but only if it doesn't cause more escaping. + + Adds or removes backslashes as appropriate. Doesn't parse and fix + strings nested in f-strings (yet). + + Note: Mutates its argument. + """ value = leaf.value.lstrip("furbFURB") if value[:3] == '"""': return @@ -1832,10 +2150,28 @@ def normalize_string_quotes(leaf: Leaf) -> None: if first_quote_pos == -1: return # There's an internal error - body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)] - new_body = body.replace(f"\\{orig_quote}", orig_quote).replace( - new_quote, f"\\{new_quote}" - ) + prefix = leaf.value[:first_quote_pos] + unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}") + escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}") + escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}") + body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)] + if "r" in prefix.casefold(): + if unescaped_new_quote.search(body): + # There's at least one unescaped new_quote in this raw string + # so converting is impossible + return + + # Do not introduce or remove backslashes in raw strings + new_body = body + else: + # remove unnecessary quotes + new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body) + if body != new_body: + # Consider the string without unnecessary quotes as the original + body = new_body + leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}" + new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body) + new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body) if new_quote == '"""' and new_body[-1] == '"': # edge case: new_body = new_body[:-1] + '\\"' @@ -1847,10 +2183,155 @@ def normalize_string_quotes(leaf: Leaf) -> None: if new_escape_count == orig_escape_count and orig_quote == '"': return # Prefer double quotes - prefix = leaf.value[:first_quote_pos] leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}" +def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: + """Make existing optional parentheses invisible or create new ones. + + Standardizes on visible parentheses for single-element tuples, and keeps + existing visible parentheses for other tuples and generator expressions. + """ + check_lpar = False + for child in list(node.children): + if check_lpar: + if child.type == syms.atom: + maybe_make_parens_invisible_in_atom(child) + elif is_one_tuple(child): + # wrap child in visible parentheses + lpar = Leaf(token.LPAR, "(") + rpar = Leaf(token.RPAR, ")") + index = child.remove() or 0 + node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + else: + # wrap child in invisible parentheses + lpar = Leaf(token.LPAR, "") + rpar = Leaf(token.RPAR, "") + index = child.remove() or 0 + node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + + check_lpar = isinstance(child, Leaf) and child.value in parens_after + + +def maybe_make_parens_invisible_in_atom(node: LN) -> bool: + """If it's safe, make the parens in the atom `node` invisible, recusively.""" + if ( + node.type != syms.atom + or is_empty_tuple(node) + or is_one_tuple(node) + or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY + ): + return False + + first = node.children[0] + last = node.children[-1] + if first.type == token.LPAR and last.type == token.RPAR: + # make parentheses invisible + first.value = "" # type: ignore + last.value = "" # type: ignore + if len(node.children) > 1: + maybe_make_parens_invisible_in_atom(node.children[1]) + return True + + return False + + +def is_empty_tuple(node: LN) -> bool: + """Return True if `node` holds an empty tuple.""" + return ( + node.type == syms.atom + and len(node.children) == 2 + and node.children[0].type == token.LPAR + and node.children[1].type == token.RPAR + ) + + +def is_one_tuple(node: LN) -> bool: + """Return True if `node` holds a tuple with one element, with or without parens.""" + if node.type == syms.atom: + if len(node.children) != 3: + return False + + lpar, gexp, rpar = node.children + if not ( + lpar.type == token.LPAR + and gexp.type == syms.testlist_gexp + and rpar.type == token.RPAR + ): + return False + + return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA + + return ( + node.type in IMPLICIT_TUPLE + and len(node.children) == 2 + and node.children[1].type == token.COMMA + ) + + +def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool: + """Return True if `leaf` is a star or double star in a vararg or kwarg. + + If `within` includes VARARGS_PARENTS, this applies to function signatures. + If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right + hand-side extended iterable unpacking (PEP 3132) and additional unpacking + generalizations (PEP 448). + """ + if leaf.type not in STARS or not leaf.parent: + return False + + p = leaf.parent + if p.type == syms.star_expr: + # Star expressions are also used as assignment targets in extended + # iterable unpacking (PEP 3132). See what its parent is instead. + if not p.parent: + return False + + p = p.parent + + return p.type in within + + +def max_delimiter_priority_in_atom(node: LN) -> int: + """Return maximum delimiter priority inside `node`. + + This is specific to atoms with contents contained in a pair of parentheses. + If `node` isn't an atom or there are no enclosing parentheses, returns 0. + """ + if node.type != syms.atom: + return 0 + + first = node.children[0] + last = node.children[-1] + if not (first.type == token.LPAR and last.type == token.RPAR): + return 0 + + bt = BracketTracker() + for c in node.children[1:-1]: + if isinstance(c, Leaf): + bt.mark(c) + else: + for leaf in c.leaves(): + bt.mark(leaf) + try: + return bt.max_delimiter_priority() + + except ValueError: + return 0 + + +def ensure_visible(leaf: Leaf) -> None: + """Make sure parentheses are visible. + + They could be invisible as part of some statements (see + :func:`normalize_invible_parens` and :func:`visit_import_from`). + """ + if leaf.type == token.LPAR: + leaf.value = "(" + elif leaf.type == token.RPAR: + leaf.value = ")" + + def is_python36(node: Node) -> bool: """Return True if the current file is using Python 3.6+ features. @@ -1870,7 +2351,7 @@ def is_python36(node: Node) -> bool: and n.children[-1].type == token.COMMA ): for ch in n.children: - if ch.type == token.STAR or ch.type == token.DOUBLESTAR: + if ch.type in STARS: return True return False @@ -1901,18 +2382,25 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]: class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" check: bool = False + quiet: bool = False change_count: int = 0 same_count: int = 0 failure_count: int = 0 - def done(self, src: Path, changed: bool) -> None: + def done(self, src: Path, changed: Changed) -> None: """Increment the counter for successful reformatting. Write out a message.""" - if changed: + if changed is Changed.YES: reformatted = "would reformat" if self.check else "reformatted" - out(f"{reformatted} {src}") + if not self.quiet: + out(f"{reformatted} {src}") self.change_count += 1 else: - out(f"{src} already well formatted, good job.", bold=False) + if not self.quiet: + if changed is Changed.NO: + msg = f"{src} already well formatted, good job." + else: + msg = f"{src} wasn't modified on disk since last run." + out(msg, bold=False) self.same_count += 1 def failed(self, src: Path, message: str) -> None: @@ -2053,11 +2541,12 @@ def dump_to_file(*output: str) -> str: import tempfile with tempfile.NamedTemporaryFile( - mode="w", prefix="blk_", suffix=".log", delete=False + mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8" ) as f: for lines in output: f.write(lines) - f.write("\n") + if lines and lines[-1] != "\n": + f.write("\n") return f.name @@ -2101,5 +2590,76 @@ def shutdown(loop: BaseEventLoop) -> None: loop.close() +def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str: + """Replace `regex` with `replacement` twice on `original`. + + This is used by string normalization to perform replaces on + overlapping matches. + """ + return regex.sub(replacement, regex.sub(replacement, original)) + + +CACHE_DIR = Path(user_cache_dir("black", version=__version__)) + + +def get_cache_file(line_length: int) -> Path: + return CACHE_DIR / f"cache.{line_length}.pickle" + + +def read_cache(line_length: int) -> Cache: + """Read the cache if it exists and is well formed. + + If it is not well formed, the call to write_cache later should resolve the issue. + """ + cache_file = get_cache_file(line_length) + if not cache_file.exists(): + return {} + + with cache_file.open("rb") as fobj: + try: + cache: Cache = pickle.load(fobj) + except pickle.UnpicklingError: + return {} + + return cache + + +def get_cache_info(path: Path) -> CacheInfo: + """Return the information used to check if a file is already formatted or not.""" + stat = path.stat() + return stat.st_mtime, stat.st_size + + +def filter_cached( + cache: Cache, sources: Iterable[Path] +) -> Tuple[List[Path], List[Path]]: + """Split a list of paths into two. + + The first list contains paths of files that modified on disk or are not in the + cache. The other list contains paths to non-modified files. + """ + todo, done = [], [] + for src in sources: + src = src.resolve() + if cache.get(src) != get_cache_info(src): + todo.append(src) + else: + done.append(src) + return todo, done + + +def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None: + """Update the cache file.""" + cache_file = get_cache_file(line_length) + try: + if not CACHE_DIR.exists(): + CACHE_DIR.mkdir(parents=True) + new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}} + with cache_file.open("wb") as fobj: + pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL) + except OSError: + pass + + if __name__ == "__main__": main()