X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/ba61bfe3865d8a8fc29abe7f8a94740b618e80ba..639b62dcd32cde3645e9f9a633eee33c04d23901:/black.py diff --git a/black.py b/black.py index 82fe5d1..e795fa3 100644 --- a/black.py +++ b/black.py @@ -1,24 +1,31 @@ #!/usr/bin/env python3 import asyncio +import pickle from asyncio.base_events import BaseEventLoop from concurrent.futures import Executor, ProcessPoolExecutor +from enum import Enum from functools import partial, wraps import keyword import logging +from multiprocessing import Manager import os from pathlib import Path +import re import tokenize import signal import sys from typing import ( + Any, Callable, + Collection, Dict, Generic, Iterable, Iterator, List, Optional, + Pattern, Set, Tuple, Type, @@ -26,6 +33,7 @@ from typing import ( Union, ) +from appdirs import user_cache_dir from attr import dataclass, Factory import click @@ -35,7 +43,7 @@ from blib2to3 import pygram, pytree from blib2to3.pgen2 import driver, token from blib2to3.pgen2.parse import ParseError -__version__ = "18.3a4" +__version__ = "18.4a2" DEFAULT_LINE_LENGTH = 88 # types syms = pygram.python_symbols @@ -47,9 +55,13 @@ LeafID = int Priority = int Index = int LN = Union[Leaf, Node] -SplitFunc = Callable[['Line', bool], Iterator['Line']] +SplitFunc = Callable[["Line", bool], Iterator["Line"]] +Timestamp = float +FileSize = int +CacheInfo = Tuple[Timestamp, FileSize] +Cache = Dict[Path, CacheInfo] out = partial(click.secho, bold=True, err=True) -err = partial(click.secho, fg='red', err=True) +err = partial(click.secho, fg="red", err=True) class NothingChanged(UserWarning): @@ -92,32 +104,58 @@ class FormatOff(FormatError): """Found a comment like `# fmt: off` in the file.""" +class WriteBack(Enum): + NO = 0 + YES = 1 + DIFF = 2 + + +class Changed(Enum): + NO = 0 + CACHED = 1 + YES = 2 + + @click.command() @click.option( - '-l', - '--line-length', + "-l", + "--line-length", type=int, default=DEFAULT_LINE_LENGTH, - help='How many character per line to allow.', + help="How many character per line to allow.", show_default=True, ) @click.option( - '--check', + "--check", is_flag=True, help=( - "Don't write back the files, just return the status. Return code 0 " + "Don't write the files back, just return the status. Return code 0 " "means nothing would change. Return code 1 means some files would be " "reformatted. Return code 123 means there was an internal error." ), ) @click.option( - '--fast/--safe', + "--diff", + is_flag=True, + help="Don't write the files back, just output a diff for each file on stdout.", +) +@click.option( + "--fast/--safe", is_flag=True, - help='If --fast given, skip temporary sanity checks. [default: --safe]', + help="If --fast given, skip temporary sanity checks. [default: --safe]", +) +@click.option( + "-q", + "--quiet", + is_flag=True, + help=( + "Don't emit non-error messages to stderr. Errors are still emitted, " + "silence those with 2>/dev/null." + ), ) @click.version_option(version=__version__) @click.argument( - 'src', + "src", nargs=-1, type=click.Path( exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True @@ -125,7 +163,13 @@ class FormatOff(FormatError): ) @click.pass_context def main( - ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str] + ctx: click.Context, + line_length: int, + check: bool, + diff: bool, + fast: bool, + quiet: bool, + src: List[str], ) -> None: """The uncompromising code formatter.""" sources: List[Path] = [] @@ -136,48 +180,95 @@ def main( elif p.is_file(): # if a file was explicitly given, we don't care about its extension sources.append(p) - elif s == '-': - sources.append(Path('-')) + elif s == "-": + sources.append(Path("-")) else: - err(f'invalid path: {s}') + err(f"invalid path: {s}") + if check and diff: + exc = click.ClickException("Options --check and --diff are mutually exclusive") + exc.exit_code = 2 + raise exc + + if check: + write_back = WriteBack.NO + elif diff: + write_back = WriteBack.DIFF + else: + write_back = WriteBack.YES if len(sources) == 0: ctx.exit(0) + return + elif len(sources) == 1: - p = sources[0] - report = Report(check=check) - try: - if not p.is_file() and str(p) == '-': - changed = format_stdin_to_stdout( - line_length=line_length, fast=fast, write_back=not check - ) - else: - changed = format_file_in_place( - p, line_length=line_length, fast=fast, write_back=not check - ) - report.done(p, changed) - except Exception as exc: - report.failed(p, str(exc)) - ctx.exit(report.return_code) + return_code = run_single_file_mode( + line_length, check, fast, quiet, write_back, sources[0] + ) else: - loop = asyncio.get_event_loop() - executor = ProcessPoolExecutor(max_workers=os.cpu_count()) - return_code = 1 - try: - return_code = loop.run_until_complete( - schedule_formatting( - sources, line_length, not check, fast, loop, executor + return_code = run_multi_file_mode(line_length, fast, quiet, write_back, sources) + ctx.exit(return_code) + + +def run_single_file_mode( + line_length: int, + check: bool, + fast: bool, + quiet: bool, + write_back: WriteBack, + src: Path, +) -> int: + report = Report(check=check, quiet=quiet) + try: + if not src.is_file() and str(src) == "-": + changed = format_stdin_to_stdout( + line_length=line_length, fast=fast, write_back=write_back + ) + else: + changed = Changed.NO + cache: Cache = {} + if write_back != WriteBack.DIFF: + cache = read_cache() + src = src.resolve() + if src in cache and cache[src] == get_cache_info(src): + changed = Changed.CACHED + if changed is not Changed.CACHED: + changed = format_file_in_place( + src, line_length=line_length, fast=fast, write_back=write_back ) + if write_back != WriteBack.DIFF and changed is not Changed.NO: + write_cache(cache, [src]) + report.done(src, changed) + except Exception as exc: + report.failed(src, str(exc)) + return report.return_code + + +def run_multi_file_mode( + line_length: int, + fast: bool, + quiet: bool, + write_back: WriteBack, + sources: List[Path], +) -> int: + loop = asyncio.get_event_loop() + executor = ProcessPoolExecutor(max_workers=os.cpu_count()) + return_code = 1 + try: + return_code = loop.run_until_complete( + schedule_formatting( + sources, line_length, write_back, fast, quiet, loop, executor ) - finally: - shutdown(loop) - ctx.exit(return_code) + ) + finally: + shutdown(loop) + return return_code async def schedule_formatting( sources: List[Path], line_length: int, - write_back: bool, + write_back: WriteBack, fast: bool, + quiet: bool, loop: BaseEventLoop, executor: Executor, ) -> int: @@ -188,79 +279,121 @@ async def schedule_formatting( `line_length`, `write_back`, and `fast` options are passed to :func:`format_file_in_place`. """ - tasks = { - src: loop.run_in_executor( - executor, format_file_in_place, src, line_length, fast, write_back - ) - for src in sources - } - _task_values = list(tasks.values()) - loop.add_signal_handler(signal.SIGINT, cancel, _task_values) - loop.add_signal_handler(signal.SIGTERM, cancel, _task_values) - await asyncio.wait(tasks.values()) + report = Report(check=write_back is WriteBack.NO, quiet=quiet) + cache: Cache = {} + if write_back != WriteBack.DIFF: + cache = read_cache() + sources, cached = filter_cached(cache, sources) + for src in cached: + report.done(src, Changed.CACHED) cancelled = [] - report = Report(check=not write_back) - for src, task in tasks.items(): - if not task.done(): - report.failed(src, 'timed out, cancelling') - task.cancel() - cancelled.append(task) - elif task.cancelled(): - cancelled.append(task) - elif task.exception(): - report.failed(src, str(task.exception())) - else: - report.done(src, task.result()) + formatted = [] + if sources: + lock = None + if write_back == WriteBack.DIFF: + # For diff output, we need locks to ensure we don't interleave output + # from different processes. + manager = Manager() + lock = manager.Lock() + tasks = { + src: loop.run_in_executor( + executor, format_file_in_place, src, line_length, fast, write_back, lock + ) + for src in sources + } + _task_values = list(tasks.values()) + loop.add_signal_handler(signal.SIGINT, cancel, _task_values) + loop.add_signal_handler(signal.SIGTERM, cancel, _task_values) + await asyncio.wait(_task_values) + for src, task in tasks.items(): + if not task.done(): + report.failed(src, "timed out, cancelling") + task.cancel() + cancelled.append(task) + elif task.cancelled(): + cancelled.append(task) + elif task.exception(): + report.failed(src, str(task.exception())) + else: + formatted.append(src) + report.done(src, task.result()) + if cancelled: await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) - else: - out('All done! ✨ 🍰 ✨') - click.echo(str(report)) + elif not quiet: + out("All done! ✨ 🍰 ✨") + if not quiet: + click.echo(str(report)) + + if write_back != WriteBack.DIFF and formatted: + write_cache(cache, formatted) + return report.return_code def format_file_in_place( - src: Path, line_length: int, fast: bool, write_back: bool = False -) -> bool: + src: Path, + line_length: int, + fast: bool, + write_back: WriteBack = WriteBack.NO, + lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy +) -> Changed: """Format file under `src` path. Return True if changed. If `write_back` is True, write reformatted code back to stdout. `line_length` and `fast` options are passed to :func:`format_file_contents`. """ + with tokenize.open(src) as src_buffer: src_contents = src_buffer.read() try: - contents = format_file_contents( + dst_contents = format_file_contents( src_contents, line_length=line_length, fast=fast ) except NothingChanged: - return False + return Changed.NO - if write_back: + if write_back == write_back.YES: with open(src, "w", encoding=src_buffer.encoding) as f: - f.write(contents) - return True + f.write(dst_contents) + elif write_back == write_back.DIFF: + src_name = f"{src.name} (original)" + dst_name = f"{src.name} (formatted)" + diff_contents = diff(src_contents, dst_contents, src_name, dst_name) + if lock: + lock.acquire() + try: + sys.stdout.write(diff_contents) + finally: + if lock: + lock.release() + return Changed.YES def format_stdin_to_stdout( - line_length: int, fast: bool, write_back: bool = False -) -> bool: + line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO +) -> Changed: """Format file on stdin. Return True if changed. If `write_back` is True, write reformatted code back to stdout. `line_length` and `fast` arguments are passed to :func:`format_file_contents`. """ - contents = sys.stdin.read() + src = sys.stdin.read() + dst = src try: - contents = format_file_contents(contents, line_length=line_length, fast=fast) - return True + dst = format_file_contents(src, line_length=line_length, fast=fast) + return Changed.YES except NothingChanged: - return False + return Changed.NO finally: - if write_back: - sys.stdout.write(contents) + if write_back == WriteBack.YES: + sys.stdout.write(dst) + elif write_back == WriteBack.DIFF: + src_name = " (original)" + dst_name = " (formatted)" + sys.stdout.write(diff(src, dst, src_name, dst_name)) def format_file_contents( @@ -272,7 +405,7 @@ def format_file_contents( valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. `line_length` is passed to :func:`format_str`. """ - if src_contents.strip() == '': + if src_contents.strip() == "": raise NothingChanged dst_contents = format_str(src_contents, line_length=line_length) @@ -319,8 +452,8 @@ GRAMMARS = [ def lib2to3_parse(src_txt: str) -> Node: """Given a string with source, return the lib2to3 Node.""" grammar = pygram.python_grammar_no_print_statement - if src_txt[-1] != '\n': - nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n' + if src_txt[-1] != "\n": + nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n" src_txt += nl for grammar in GRAMMARS: drv = driver.Driver(grammar, pytree.convert) @@ -350,7 +483,7 @@ def lib2to3_unparse(node: Node) -> str: return code -T = TypeVar('T') +T = TypeVar("T") class Visitor(Generic[T]): @@ -370,7 +503,7 @@ class Visitor(Generic[T]): name = token.tok_name[node.type] else: name = type_repr(node.type) - yield from getattr(self, f'visit_{name}', self.visit_default)(node) + yield from getattr(self, f"visit_{name}", self.visit_default)(node) def visit_default(self, node: LN) -> Iterator[T]: """Default `visit_*()` implementation. Recurses to children of `node`.""" @@ -384,24 +517,24 @@ class DebugVisitor(Visitor[T]): tree_depth: int = 0 def visit_default(self, node: LN) -> Iterator[T]: - indent = ' ' * (2 * self.tree_depth) + indent = " " * (2 * self.tree_depth) if isinstance(node, Node): _type = type_repr(node.type) - out(f'{indent}{_type}', fg='yellow') + out(f"{indent}{_type}", fg="yellow") self.tree_depth += 1 for child in node.children: yield from self.visit(child) self.tree_depth -= 1 - out(f'{indent}/{_type}', fg='yellow', bold=False) + out(f"{indent}/{_type}", fg="yellow", bold=False) else: _type = token.tok_name.get(node.type, str(node.type)) - out(f'{indent}{_type}', fg='blue', nl=False) + out(f"{indent}{_type}", fg="blue", nl=False) if node.prefix: # We don't have to handle prefixes for `Node` objects since # that delegates to the first child anyway. - out(f' {node.prefix!r}', fg='green', bold=False, nl=False) - out(f' {node.value!r}', fg='blue', bold=False) + out(f" {node.prefix!r}", fg="green", bold=False, nl=False) + out(f" {node.value!r}", fg="blue", bold=False) @classmethod def show(cls, code: str) -> None: @@ -415,7 +548,7 @@ class DebugVisitor(Visitor[T]): KEYWORDS = set(keyword.kwlist) WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE} -FLOW_CONTROL = {'return', 'raise', 'break', 'continue'} +FLOW_CONTROL = {"return", "raise", "break", "continue"} STATEMENT = { syms.if_stmt, syms.while_stmt, @@ -427,7 +560,7 @@ STATEMENT = { syms.classdef, } STANDALONE_COMMENT = 153 -LOGIC_OPERATORS = {'and', 'or'} +LOGIC_OPERATORS = {"and", "or"} COMPARATORS = { token.LESS, token.GREATER, @@ -451,6 +584,20 @@ MATH_OPERATORS = { token.DOUBLESTAR, token.DOUBLESLASH, } +STARS = {token.STAR, token.DOUBLESTAR} +VARARGS_PARENTS = { + syms.arglist, + syms.argument, # double star in arglist + syms.trailer, # single argument to call + syms.typedargslist, + syms.varargslist, # lambdas +} +UNPACKING_PARENTS = { + syms.atom, # single element of a list or set literal + syms.dictsetmaker, + syms.listmaker, + syms.testlist_gexp, +} COMPREHENSION_PRIORITY = 20 COMMA_PRIORITY = 10 LOGIC_PRIORITY = 5 @@ -492,32 +639,13 @@ class BracketTracker: leaf.opening_bracket = opening_bracket leaf.bracket_depth = self.depth if self.depth == 0: - delim = is_delimiter(leaf) - if delim: - self.delimiters[id(leaf)] = delim - elif self.previous is not None: - if leaf.type == token.STRING and self.previous.type == token.STRING: - self.delimiters[id(self.previous)] = STRING_PRIORITY - elif ( - leaf.type == token.NAME - and leaf.value == 'for' - and leaf.parent - and leaf.parent.type in {syms.comp_for, syms.old_comp_for} - ): - self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY - elif ( - leaf.type == token.NAME - and leaf.value == 'if' - and leaf.parent - and leaf.parent.type in {syms.comp_if, syms.old_comp_if} - ): - self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY - elif ( - leaf.type == token.NAME - and leaf.value in LOGIC_OPERATORS - and leaf.parent - ): - self.delimiters[id(self.previous)] = LOGIC_PRIORITY + delim = is_split_before_delimiter(leaf, self.previous) + if delim and self.previous is not None: + self.delimiters[id(self.previous)] = delim + else: + delim = is_split_after_delimiter(leaf, self.previous) + if delim: + self.delimiters[id(leaf)] = delim if leaf.type in OPENING_BRACKETS: self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf self.depth += 1 @@ -531,6 +659,7 @@ class BracketTracker: """Return the highest priority of a delimiter found on the line. Values are consistent with what `is_delimiter()` returns. + Raises ValueError on no delimiters. """ return max(v for k, v in self.delimiters.items() if k not in exclude) @@ -557,7 +686,7 @@ class Line: Inline comments are put aside. """ - has_value = leaf.value.strip() + has_value = leaf.type in BRACKETS or bool(leaf.value.strip()) if not has_value: return @@ -612,7 +741,7 @@ class Line: return ( bool(self) and self.leaves[0].type == token.NAME - and self.leaves[0].value == 'class' + and self.leaves[0].value == "class" ) @property @@ -628,12 +757,12 @@ class Line: except IndexError: second_leaf = None return ( - (first_leaf.type == token.NAME and first_leaf.value == 'def') + (first_leaf.type == token.NAME and first_leaf.value == "def") or ( first_leaf.type == token.ASYNC and second_leaf is not None and second_leaf.type == token.NAME - and second_leaf.value == 'def' + and second_leaf.value == "def" ) ) @@ -655,15 +784,15 @@ class Line: return ( bool(self) and self.leaves[0].type == token.NAME - and self.leaves[0].value == 'yield' + and self.leaves[0].value == "yield" ) - @property - def contains_standalone_comments(self) -> bool: + def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool: """If so, needs to be split before emitting.""" for leaf in self.leaves: if leaf.type == STANDALONE_COMMENT: - return True + if leaf.bracket_depth <= depth_limit: + return True return False @@ -722,7 +851,7 @@ class Line: To avoid splitting on the comma in this situation, increase the depth of tokens between `for` and `in`. """ - if leaf.type == token.NAME and leaf.value == 'for': + if leaf.type == token.NAME and leaf.value == "for": self.has_for = True self.bracket_tracker.depth += 1 self._for_loop_variable = True @@ -732,7 +861,7 @@ class Line: def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool: """See `maybe_increment_for_loop_variable` above for explanation.""" - if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in': + if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in": self.bracket_tracker.depth -= 1 self._for_loop_variable = False return True @@ -745,7 +874,7 @@ class Line: comment.type == STANDALONE_COMMENT and self.bracket_tracker.any_open_brackets() ): - comment.prefix = '' + comment.prefix = "" return False if comment.type != token.COMMENT: @@ -754,7 +883,7 @@ class Line: after = len(self.leaves) - 1 if after == -1: comment.type = STANDALONE_COMMENT - comment.prefix = '' + comment.prefix = "" return False else: @@ -786,17 +915,17 @@ class Line: def __str__(self) -> str: """Render the line.""" if not self: - return '\n' + return "\n" - indent = ' ' * self.depth + indent = " " * self.depth leaves = iter(self.leaves) first = next(leaves) - res = f'{first.prefix}{indent}{first.value}' + res = f"{first.prefix}{indent}{first.value}" for leaf in leaves: res += str(leaf) for _, comment in self.comments: res += str(comment) - return res + '\n' + return res + "\n" def __bool__(self) -> bool: """Return True if the line has leaves or comments.""" @@ -832,9 +961,9 @@ class UnformattedLines(Line): `depth` is not used for indentation in this case. """ if not self: - return '\n' + return "\n" - res = '' + res = "" for leaf in self.leaves: res += str(leaf) return res @@ -888,9 +1017,9 @@ class EmptyLineTracker: if current_line.leaves: # Consume the first leaf's extra newlines. first_leaf = current_line.leaves[0] - before = first_leaf.prefix.count('\n') + before = first_leaf.prefix.count("\n") before = min(before, max_allowed) - first_leaf.prefix = '' + first_leaf.prefix = "" else: before = 0 depth = current_line.depth @@ -1009,6 +1138,8 @@ class LineGenerator(Visitor[Line]): else: normalize_prefix(node, inside_brackets=any_open_brackets) + if node.type == token.STRING: + normalize_string_quotes(node) if node.type not in WHITESPACE: self.current_line.append(node) yield from super().visit_default(node) @@ -1024,15 +1155,22 @@ class LineGenerator(Visitor[Line]): # DEDENT has no value. Additionally, in blib2to3 it never holds comments. yield from self.line(-1) - def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]: + def visit_stmt( + self, node: Node, keywords: Set[str], parens: Set[str] + ) -> Iterator[Line]: """Visit a statement. This implementation is shared for `if`, `while`, `for`, `try`, `except`, - `def`, `with`, and `class`. + `def`, `with`, `class`, and `assert`. + + The relevant Python language `keywords` for a given statement will be + NAME leaves within it. This methods puts those on a separate line. - The relevant Python language `keywords` for a given statement will be NAME - leaves within it. This methods puts those on a separate line. + `parens` holds pairs of nodes where invisible parentheses should be put. + Keys hold nodes after which opening parentheses should be put, values + hold nodes before which closing parentheses should be put. """ + normalize_invisible_parens(node, parens_after=parens) for child in node.children: if child.type == token.NAME and child.value in keywords: # type: ignore yield from self.line() @@ -1072,6 +1210,32 @@ class LineGenerator(Visitor[Line]): yield from self.line() yield from self.visit(child) + def visit_import_from(self, node: Node) -> Iterator[Line]: + """Visit import_from and maybe put invisible parentheses. + + This is separate from `visit_stmt` because import statements don't + support arbitrary atoms and thus handling of parentheses is custom. + """ + check_lpar = False + for index, child in enumerate(node.children): + if check_lpar: + if child.type == token.LPAR: + # make parentheses invisible + child.value = "" # type: ignore + node.children[-1].value = "" # type: ignore + else: + # insert invisible parentheses + node.insert_child(index, Leaf(token.LPAR, "")) + node.append_child(Leaf(token.RPAR, "")) + break + + check_lpar = ( + child.type == token.NAME and child.value == "import" # type: ignore + ) + + for child in node.children: + yield from self.visit(child) + def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]: """Remove a semicolon and put the other statement on a separate line.""" yield from self.line() @@ -1095,21 +1259,30 @@ class LineGenerator(Visitor[Line]): yield from self.line() yield from self.visit(node) + if node.type == token.ENDMARKER: + # somebody decided not to put a final `# fmt: on` + yield from self.line() + def __attrs_post_init__(self) -> None: """You are in a twisty little maze of passages.""" v = self.visit_stmt - self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'}) - self.visit_while_stmt = partial(v, keywords={'while', 'else'}) - self.visit_for_stmt = partial(v, keywords={'for', 'else'}) - self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'}) - self.visit_except_clause = partial(v, keywords={'except'}) - self.visit_funcdef = partial(v, keywords={'def'}) - self.visit_with_stmt = partial(v, keywords={'with'}) - self.visit_classdef = partial(v, keywords={'class'}) + Ø: Set[str] = set() + self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","}) + self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"}) + self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"}) + self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"}) + self.visit_try_stmt = partial( + v, keywords={"try", "except", "else", "finally"}, parens=Ø + ) + self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø) + self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø) + self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø) + self.visit_classdef = partial(v, keywords={"class"}, parens=Ø) self.visit_async_funcdef = self.visit_async_stmt self.visit_decorated = self.visit_decorators +IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist} BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE} OPENING_BRACKETS = set(BRACKET.keys()) CLOSING_BRACKETS = set(BRACKET.values()) @@ -1119,9 +1292,9 @@ ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT} def whitespace(leaf: Leaf) -> str: # noqa C901 """Return whitespace prefix if needed for the given `leaf`.""" - NO = '' - SPACE = ' ' - DOUBLESPACE = ' ' + NO = "" + SPACE = " " + DOUBLESPACE = " " t = leaf.type p = leaf.parent v = leaf.value @@ -1157,15 +1330,8 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 # that, too. return prevp.prefix - elif prevp.type == token.DOUBLESTAR: - if prevp.parent and prevp.parent.type in { - syms.arglist, - syms.argument, - syms.dictsetmaker, - syms.parameters, - syms.typedargslist, - syms.varargslist, - }: + elif prevp.type in STARS: + if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS): return NO elif prevp.type == token.COLON: @@ -1174,7 +1340,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 elif ( prevp.parent - and prevp.parent.type in {syms.factor, syms.star_expr} + and prevp.parent.type == syms.factor and prevp.type in MATH_OPERATORS ): return NO @@ -1185,7 +1351,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 and prevp.parent.type == syms.shift_expr and prevp.prev_sibling and prevp.prev_sibling.type == token.NAME - and prevp.prev_sibling.value == 'print' # type: ignore + and prevp.prev_sibling.value == "print" # type: ignore ): # Python 2 print chevron return NO @@ -1260,7 +1426,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 if not prevp or prevp.type == token.LPAR: return NO - elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR: + elif prev.type in {token.EQUAL} | STARS: return NO elif p.type == syms.decorator: @@ -1325,9 +1491,10 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 prevp_parent = prevp.parent assert prevp_parent is not None - if prevp.type == token.COLON and prevp_parent.type in { - syms.subscript, syms.sliceop - }: + if ( + prevp.type == token.COLON + and prevp_parent.type in {syms.subscript, syms.sliceop} + ): return NO elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument: @@ -1342,7 +1509,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return NO elif t == token.NAME: - if v == 'import': + if v == "import": return SPACE if prev and prev.type == token.DOT: @@ -1372,16 +1539,32 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: return None -def is_delimiter(leaf: Leaf) -> int: - """Return the priority of the `leaf` delimiter. Return 0 if not delimiter. +def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int: + """Return the priority of the `leaf` delimiter, given a line break after it. + + The delimiter priorities returned here are from those delimiters that would + cause a line break after themselves. Higher numbers are higher priority. """ if leaf.type == token.COMMA: return COMMA_PRIORITY - if leaf.type in COMPARATORS: - return COMPARATOR_PRIORITY + return 0 + + +def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int: + """Return the priority of the `leaf` delimiter, given a line before after it. + + The delimiter priorities returned here are from those delimiters that would + cause a line break before themselves. + + Higher numbers are higher priority. + """ + if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS): + # * and ** might also be MATH_OPERATORS but in this case they are not. + # Don't treat them as a delimiter. + return 0 if ( leaf.type in MATH_OPERATORS @@ -1390,9 +1573,49 @@ def is_delimiter(leaf: Leaf) -> int: ): return MATH_PRIORITY + if leaf.type in COMPARATORS: + return COMPARATOR_PRIORITY + + if ( + leaf.type == token.STRING + and previous is not None + and previous.type == token.STRING + ): + return STRING_PRIORITY + + if ( + leaf.type == token.NAME + and leaf.value == "for" + and leaf.parent + and leaf.parent.type in {syms.comp_for, syms.old_comp_for} + ): + return COMPREHENSION_PRIORITY + + if ( + leaf.type == token.NAME + and leaf.value == "if" + and leaf.parent + and leaf.parent.type in {syms.comp_if, syms.old_comp_if} + ): + return COMPREHENSION_PRIORITY + + if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent: + return LOGIC_PRIORITY + return 0 +def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int: + """Return the priority of the `leaf` delimiter. Return 0 if not delimiter. + + Higher numbers are higher priority. + """ + return max( + is_split_before_delimiter(leaf, previous), + is_split_after_delimiter(leaf, previous), + ) + + def generate_comments(leaf: Leaf) -> Iterator[Leaf]: """Clean the prefix of the `leaf` and generate comments from it, if any. @@ -1416,17 +1639,17 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]: if not p: return - if '#' not in p: + if "#" not in p: return consumed = 0 nlines = 0 - for index, line in enumerate(p.split('\n')): + for index, line in enumerate(p.split("\n")): consumed += len(line) + 1 # adding the length of the split '\n' line = line.lstrip() if not line: nlines += 1 - if not line.startswith('#'): + if not line.startswith("#"): continue if index == 0 and leaf.type != token.ENDMARKER: @@ -1434,13 +1657,18 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]: else: comment_type = STANDALONE_COMMENT comment = make_comment(line) - yield Leaf(comment_type, comment, prefix='\n' * nlines) + yield Leaf(comment_type, comment, prefix="\n" * nlines) - if comment in {'# fmt: on', '# yapf: enable'}: + if comment in {"# fmt: on", "# yapf: enable"}: raise FormatOn(consumed) - if comment in {'# fmt: off', '# yapf: disable'}: - raise FormatOff(consumed) + if comment in {"# fmt: off", "# yapf: disable"}: + if comment_type == STANDALONE_COMMENT: + raise FormatOff(consumed) + + prev = preceding_leaf(leaf) + if not prev or prev.type in WHITESPACE: # standalone comment in disguise + raise FormatOff(consumed) nlines = 0 @@ -1455,13 +1683,13 @@ def make_comment(content: str) -> str: """ content = content.rstrip() if not content: - return '#' + return "#" - if content[0] == '#': + if content[0] == "#": content = content[1:] - if content and content[0] not in ' !:#': - content = ' ' + content - return '#' + content + if content and content[0] not in " !:#": + content = " " + content + return "#" + content def split_line( @@ -1481,11 +1709,11 @@ def split_line( yield line return - line_str = str(line).strip('\n') + line_str = str(line).strip("\n") if ( len(line_str) <= line_length - and '\n' not in line_str # multiline strings - and not line.contains_standalone_comments + and "\n" not in line_str # multiline strings + and not line.contains_standalone_comments() ): yield line return @@ -1504,7 +1732,7 @@ def split_line( result: List[Line] = [] try: for l in split_func(line, py36): - if str(l).strip('\n') == line_str: + if str(l).strip("\n") == line_str: raise CannotSplit("Split function returned an unchanged result") result.extend( @@ -1551,9 +1779,7 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: if body_leaves: normalize_prefix(body_leaves[0], inside_brackets=True) # Build the new lines. - for result, leaves in ( - (head, head_leaves), (body, body_leaves), (tail, tail_leaves) - ): + for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves): for leaf in leaves: result.append(leaf, preformatted=True) for comment_after in line.comments_after(leaf): @@ -1564,7 +1790,9 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: yield result -def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: +def right_hand_split( + line: Line, py36: bool = False, omit: Collection[LeafID] = () +) -> Iterator[Line]: """Split line into many lines, starting with the last matching bracket pair.""" head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) @@ -1574,14 +1802,16 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: head_leaves: List[Leaf] = [] current_leaves = tail_leaves opening_bracket = None + closing_bracket = None for leaf in reversed(line.leaves): if current_leaves is body_leaves: if leaf is opening_bracket: current_leaves = head_leaves if body_leaves else tail_leaves current_leaves.append(leaf) if current_leaves is tail_leaves: - if leaf.type in CLOSING_BRACKETS: + if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit: opening_bracket = leaf.opening_bracket + closing_bracket = leaf current_leaves = body_leaves tail_leaves.reverse() body_leaves.reverse() @@ -1589,15 +1819,36 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: # Since body is a new indent level, remove spurious leading whitespace. if body_leaves: normalize_prefix(body_leaves[0], inside_brackets=True) + elif not head_leaves: + # No `head` and no `body` means the split failed. `tail` has all content. + raise CannotSplit("No brackets found") + # Build the new lines. - for result, leaves in ( - (head, head_leaves), (body, body_leaves), (tail, tail_leaves) - ): + for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves): for leaf in leaves: result.append(leaf, preformatted=True) for comment_after in line.comments_after(leaf): result.append(comment_after, preformatted=True) bracket_split_succeeded_or_raise(head, body, tail) + assert opening_bracket and closing_bracket + if ( + opening_bracket.type == token.LPAR + and not opening_bracket.value + and closing_bracket.type == token.RPAR + and not closing_bracket.value + ): + # These parens were optional. If there aren't any delimiters or standalone + # comments in the body, they were unnecessary and another split without + # them should be attempted. + if not ( + body.bracket_tracker.delimiters or line.contains_standalone_comments(0) + ): + omit = {id(closing_bracket), *omit} + yield from right_hand_split(line, py36=py36, omit=omit) + return + + ensure_visible(opening_bracket) + ensure_visible(closing_bracket) for result in (head, body, tail): if result: yield result @@ -1688,8 +1939,7 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: lowest_depth = min(lowest_depth, leaf.bracket_depth) if ( leaf.bracket_depth == lowest_depth - and leaf.type == token.STAR - or leaf.type == token.DOUBLESTAR + and is_vararg(leaf, within=VARARGS_PARENTS) ): trailing_comma_safe = trailing_comma_safe and py36 leaf_priority = delimiters.get(id(leaf)) @@ -1699,23 +1949,19 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) if current_line: if ( - delimiter_priority == COMMA_PRIORITY + trailing_comma_safe + and delimiter_priority == COMMA_PRIORITY and current_line.leaves[-1].type != token.COMMA - and trailing_comma_safe + and current_line.leaves[-1].type != STANDALONE_COMMENT ): - current_line.append(Leaf(token.COMMA, ',')) + current_line.append(Leaf(token.COMMA, ",")) yield current_line @dont_increase_indentation def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split standalone comments from the rest of the line.""" - for leaf in line.leaves: - if leaf.type == STANDALONE_COMMENT: - if leaf.bracket_depth == 0: - break - - else: + if not line.contains_standalone_comments(0): raise CannotSplit("Line does not have any standalone comments") current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) @@ -1749,8 +1995,8 @@ def is_import(leaf: Leaf) -> bool: return bool( t == token.NAME and ( - (v == 'import' and p and p.type == syms.import_name) - or (v == 'from' and p and p.type == syms.import_from) + (v == "import" and p and p.type == syms.import_name) + or (v == "from" and p and p.type == syms.import_from) ) ) @@ -1762,15 +2008,204 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: Note: don't use backslashes for formatting or you'll lose your voting rights. """ if not inside_brackets: - spl = leaf.prefix.split('#') - if '\\' not in spl[0]: - nl_count = spl[-1].count('\n') + spl = leaf.prefix.split("#") + if "\\" not in spl[0]: + nl_count = spl[-1].count("\n") if len(spl) > 1: nl_count -= 1 - leaf.prefix = '\n' * nl_count + leaf.prefix = "\n" * nl_count + return + + leaf.prefix = "" + + +def normalize_string_quotes(leaf: Leaf) -> None: + """Prefer double quotes but only if it doesn't cause more escaping. + + Adds or removes backslashes as appropriate. Doesn't parse and fix + strings nested in f-strings (yet). + + Note: Mutates its argument. + """ + value = leaf.value.lstrip("furbFURB") + if value[:3] == '"""': + return + + elif value[:3] == "'''": + orig_quote = "'''" + new_quote = '"""' + elif value[0] == '"': + orig_quote = '"' + new_quote = "'" + else: + orig_quote = "'" + new_quote = '"' + first_quote_pos = leaf.value.find(orig_quote) + if first_quote_pos == -1: + return # There's an internal error + + prefix = leaf.value[:first_quote_pos] + unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}") + escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}") + escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}") + body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)] + if "r" in prefix.casefold(): + if unescaped_new_quote.search(body): + # There's at least one unescaped new_quote in this raw string + # so converting is impossible return - leaf.prefix = '' + # Do not introduce or remove backslashes in raw strings + new_body = body + else: + # remove unnecessary quotes + new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body) + if body != new_body: + # Consider the string without unnecessary quotes as the original + body = new_body + leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}" + new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body) + new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body) + if new_quote == '"""' and new_body[-1] == '"': + # edge case: + new_body = new_body[:-1] + '\\"' + orig_escape_count = body.count("\\") + new_escape_count = new_body.count("\\") + if new_escape_count > orig_escape_count: + return # Do not introduce more escaping + + if new_escape_count == orig_escape_count and orig_quote == '"': + return # Prefer double quotes + + leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}" + + +def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: + """Make existing optional parentheses invisible or create new ones. + + Standardizes on visible parentheses for single-element tuples, and keeps + existing visible parentheses for other tuples and generator expressions. + """ + check_lpar = False + for child in list(node.children): + if check_lpar: + if child.type == syms.atom: + if not ( + is_empty_tuple(child) + or is_one_tuple(child) + or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY + ): + first = child.children[0] + last = child.children[-1] + if first.type == token.LPAR and last.type == token.RPAR: + # make parentheses invisible + first.value = "" # type: ignore + last.value = "" # type: ignore + elif is_one_tuple(child): + # wrap child in visible parentheses + lpar = Leaf(token.LPAR, "(") + rpar = Leaf(token.RPAR, ")") + index = child.remove() or 0 + node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + else: + # wrap child in invisible parentheses + lpar = Leaf(token.LPAR, "") + rpar = Leaf(token.RPAR, "") + index = child.remove() or 0 + node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + + check_lpar = isinstance(child, Leaf) and child.value in parens_after + + +def is_empty_tuple(node: LN) -> bool: + """Return True if `node` holds an empty tuple.""" + return ( + node.type == syms.atom + and len(node.children) == 2 + and node.children[0].type == token.LPAR + and node.children[1].type == token.RPAR + ) + + +def is_one_tuple(node: LN) -> bool: + """Return True if `node` holds a tuple with one element, with or without parens.""" + if node.type == syms.atom: + if len(node.children) != 3: + return False + + lpar, gexp, rpar = node.children + if not ( + lpar.type == token.LPAR + and gexp.type == syms.testlist_gexp + and rpar.type == token.RPAR + ): + return False + + return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA + + return ( + node.type in IMPLICIT_TUPLE + and len(node.children) == 2 + and node.children[1].type == token.COMMA + ) + + +def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool: + """Return True if `leaf` is a star or double star in a vararg or kwarg. + + If `within` includes VARARGS_PARENTS, this applies to function signatures. + If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right + hand-side extended iterable unpacking (PEP 3132) and additional unpacking + generalizations (PEP 448). + """ + if leaf.type not in STARS or not leaf.parent: + return False + + p = leaf.parent + if p.type == syms.star_expr: + # Star expressions are also used as assignment targets in extended + # iterable unpacking (PEP 3132). See what its parent is instead. + if not p.parent: + return False + + p = p.parent + + return p.type in within + + +def max_delimiter_priority_in_atom(node: LN) -> int: + if node.type != syms.atom: + return 0 + + first = node.children[0] + last = node.children[-1] + if not (first.type == token.LPAR and last.type == token.RPAR): + return 0 + + bt = BracketTracker() + for c in node.children[1:-1]: + if isinstance(c, Leaf): + bt.mark(c) + else: + for leaf in c.leaves(): + bt.mark(leaf) + try: + return bt.max_delimiter_priority() + + except ValueError: + return 0 + + +def ensure_visible(leaf: Leaf) -> None: + """Make sure parentheses are visible. + + They could be invisible as part of some statements (see + :func:`normalize_invible_parens` and :func:`visit_import_from`). + """ + if leaf.type == token.LPAR: + leaf.value = "(" + elif leaf.type == token.RPAR: + leaf.value = ")" def is_python36(node: Node) -> bool: @@ -1783,7 +2218,7 @@ def is_python36(node: Node) -> bool: for n in node.pre_order(): if n.type == token.STRING: value_head = n.value[:2] # type: ignore - if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}: + if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}: return True elif ( @@ -1792,15 +2227,15 @@ def is_python36(node: Node) -> bool: and n.children[-1].type == token.COMMA ): for ch in n.children: - if ch.type == token.STAR or ch.type == token.DOUBLESTAR: + if ch.type in STARS: return True return False -PYTHON_EXTENSIONS = {'.py'} +PYTHON_EXTENSIONS = {".py"} BLACKLISTED_DIRECTORIES = { - 'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv' + "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv" } @@ -1823,23 +2258,30 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]: class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" check: bool = False + quiet: bool = False change_count: int = 0 same_count: int = 0 failure_count: int = 0 - def done(self, src: Path, changed: bool) -> None: + def done(self, src: Path, changed: Changed) -> None: """Increment the counter for successful reformatting. Write out a message.""" - if changed: - reformatted = 'would reformat' if self.check else 'reformatted' - out(f'{reformatted} {src}') + if changed is Changed.YES: + reformatted = "would reformat" if self.check else "reformatted" + if not self.quiet: + out(f"{reformatted} {src}") self.change_count += 1 else: - out(f'{src} already well formatted, good job.', bold=False) + if not self.quiet: + if changed is Changed.NO: + msg = f"{src} already well formatted, good job." + else: + msg = f"{src} wasn't modified on disk since last run." + out(msg, bold=False) self.same_count += 1 def failed(self, src: Path, message: str) -> None: """Increment the counter for failed reformatting. Write out a message.""" - err(f'error: cannot format {src}: {message}') + err(f"error: cannot format {src}: {message}") self.failure_count += 1 @property @@ -1876,19 +2318,19 @@ class Report: failed = "failed to reformat" report = [] if self.change_count: - s = 's' if self.change_count > 1 else '' + s = "s" if self.change_count > 1 else "" report.append( - click.style(f'{self.change_count} file{s} {reformatted}', bold=True) + click.style(f"{self.change_count} file{s} {reformatted}", bold=True) ) if self.same_count: - s = 's' if self.same_count > 1 else '' - report.append(f'{self.same_count} file{s} {unchanged}') + s = "s" if self.same_count > 1 else "" + report.append(f"{self.same_count} file{s} {unchanged}") if self.failure_count: - s = 's' if self.failure_count > 1 else '' + s = "s" if self.failure_count > 1 else "" report.append( - click.style(f'{self.failure_count} file{s} {failed}', fg='red') + click.style(f"{self.failure_count} file{s} {failed}", fg="red") ) - return ', '.join(report) + '.' + return ", ".join(report) + "." def assert_equivalent(src: str, dst: str) -> None: @@ -1935,17 +2377,17 @@ def assert_equivalent(src: str, dst: str) -> None: try: dst_ast = ast.parse(dst) except Exception as exc: - log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst) + log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst) raise AssertionError( f"INTERNAL ERROR: Black produced invalid code: {exc}. " f"Please report a bug on https://github.com/ambv/black/issues. " f"This invalid output might be helpful: {log}" ) from None - src_ast_str = '\n'.join(_v(src_ast)) - dst_ast_str = '\n'.join(_v(dst_ast)) + src_ast_str = "\n".join(_v(src_ast)) + dst_ast_str = "\n".join(_v(dst_ast)) if src_ast_str != dst_ast_str: - log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst')) + log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst")) raise AssertionError( f"INTERNAL ERROR: Black produced code that is not equivalent to " f"the source. " @@ -1959,8 +2401,8 @@ def assert_stable(src: str, dst: str, line_length: int) -> None: newdst = format_str(dst, line_length=line_length) if dst != newdst: log = dump_to_file( - diff(src, dst, 'source', 'first pass'), - diff(dst, newdst, 'first pass', 'second pass'), + diff(src, dst, "source", "first pass"), + diff(dst, newdst, "first pass", "second pass"), ) raise AssertionError( f"INTERNAL ERROR: Black produced different code on the second pass " @@ -1975,11 +2417,12 @@ def dump_to_file(*output: str) -> str: import tempfile with tempfile.NamedTemporaryFile( - mode='w', prefix='blk_', suffix='.log', delete=False + mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8" ) as f: for lines in output: f.write(lines) - f.write('\n') + if lines and lines[-1] != "\n": + f.write("\n") return f.name @@ -1987,9 +2430,9 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: """Return a unified diff string between strings `a` and `b`.""" import difflib - a_lines = [line + '\n' for line in a.split('\n')] - b_lines = [line + '\n' for line in b.split('\n')] - return ''.join( + a_lines = [line + "\n" for line in a.split("\n")] + b_lines = [line + "\n" for line in b.split("\n")] + return "".join( difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5) ) @@ -2023,5 +2466,71 @@ def shutdown(loop: BaseEventLoop) -> None: loop.close() -if __name__ == '__main__': +def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str: + """Replace `regex` with `replacement` twice on `original`. + + This is used by string normalization to perform replaces on + overlapping matches. + """ + return regex.sub(replacement, regex.sub(replacement, original)) + + +CACHE_DIR = Path(user_cache_dir("black", version=__version__)) +CACHE_FILE = CACHE_DIR / "cache.pickle" + + +def read_cache() -> Cache: + """Read the cache if it exists and is well formed. + + If it is not well formed, the call to write_cache later should resolve the issue. + """ + if not CACHE_FILE.exists(): + return {} + + with CACHE_FILE.open("rb") as fobj: + try: + cache: Cache = pickle.load(fobj) + except pickle.UnpicklingError: + return {} + + return cache + + +def get_cache_info(path: Path) -> CacheInfo: + """Return the information used to check if a file is already formatted or not.""" + stat = path.stat() + return stat.st_mtime, stat.st_size + + +def filter_cached( + cache: Cache, sources: Iterable[Path] +) -> Tuple[List[Path], List[Path]]: + """Split a list of paths into two. + + The first list contains paths of files that modified on disk or are not in the + cache. The other list contains paths to non-modified files. + """ + todo, done = [], [] + for src in sources: + src = src.resolve() + if cache.get(src) != get_cache_info(src): + todo.append(src) + else: + done.append(src) + return todo, done + + +def write_cache(cache: Cache, sources: List[Path]) -> None: + """Update the cache file.""" + try: + if not CACHE_DIR.exists(): + CACHE_DIR.mkdir(parents=True) + new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}} + with CACHE_FILE.open("wb") as fobj: + pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL) + except OSError: + pass + + +if __name__ == "__main__": main()