X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/d01460d9393d611c0673723320f7a1e50c424e21..a7ff469869d3ccc8e88461543d37264f79bbad2d:/black.py diff --git a/black.py b/black.py index a61a40f..2c61880 100644 --- a/black.py +++ b/black.py @@ -3,14 +3,30 @@ import asyncio from asyncio.base_events import BaseEventLoop from concurrent.futures import Executor, ProcessPoolExecutor -from functools import partial +from enum import Enum +from functools import partial, wraps import keyword +import logging +from multiprocessing import Manager import os from pathlib import Path import tokenize +import signal import sys from typing import ( - Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union + Any, + Callable, + Dict, + Generic, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, ) from attr import dataclass, Factory @@ -32,19 +48,22 @@ Depth = int NodeType = int LeafID = int Priority = int +Index = int LN = Union[Leaf, Node] +SplitFunc = Callable[["Line", bool], Iterator["Line"]] out = partial(click.secho, bold=True, err=True) -err = partial(click.secho, fg='red', err=True) +err = partial(click.secho, fg="red", err=True) class NothingChanged(UserWarning): - """Raised by `format_file` when the reformatted code is the same as source.""" + """Raised by :func:`format_file` when reformatted code is the same as source.""" class CannotSplit(Exception): """A readable split that fits the allotted line length is impossible. - Raised by `left_hand_split()`, `right_hand_split()`, and `delimiter_split()`. + Raised by :func:`left_hand_split`, :func:`right_hand_split`, and + :func:`delimiter_split`. """ @@ -76,32 +95,43 @@ class FormatOff(FormatError): """Found a comment like `# fmt: off` in the file.""" +class WriteBack(Enum): + NO = 0 + YES = 1 + DIFF = 2 + + @click.command() @click.option( - '-l', - '--line-length', + "-l", + "--line-length", type=int, default=DEFAULT_LINE_LENGTH, - help='How many character per line to allow.', + help="How many character per line to allow.", show_default=True, ) @click.option( - '--check', + "--check", is_flag=True, help=( - "Don't write back the files, just return the status. Return code 0 " + "Don't write the files back, just return the status. Return code 0 " "means nothing would change. Return code 1 means some files would be " "reformatted. Return code 123 means there was an internal error." ), ) @click.option( - '--fast/--safe', + "--diff", + is_flag=True, + help="Don't write the files back, just output a diff for each file on stdout.", +) +@click.option( + "--fast/--safe", is_flag=True, - help='If --fast given, skip temporary sanity checks. [default: --safe]', + help="If --fast given, skip temporary sanity checks. [default: --safe]", ) @click.version_option(version=__version__) @click.argument( - 'src', + "src", nargs=-1, type=click.Path( exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True @@ -109,7 +139,12 @@ class FormatOff(FormatError): ) @click.pass_context def main( - ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str] + ctx: click.Context, + line_length: int, + check: bool, + diff: bool, + fast: bool, + src: List[str], ) -> None: """The uncompromising code formatter.""" sources: List[Path] = [] @@ -120,23 +155,34 @@ def main( elif p.is_file(): # if a file was explicitly given, we don't care about its extension sources.append(p) - elif s == '-': - sources.append(Path('-')) + elif s == "-": + sources.append(Path("-")) else: - err(f'invalid path: {s}') + err(f"invalid path: {s}") + if check and diff: + exc = click.ClickException("Options --check and --diff are mutually exclusive") + exc.exit_code = 2 + raise exc + + if check: + write_back = WriteBack.NO + elif diff: + write_back = WriteBack.DIFF + else: + write_back = WriteBack.YES if len(sources) == 0: ctx.exit(0) elif len(sources) == 1: p = sources[0] report = Report(check=check) try: - if not p.is_file() and str(p) == '-': + if not p.is_file() and str(p) == "-": changed = format_stdin_to_stdout( - line_length=line_length, fast=fast, write_back=not check + line_length=line_length, fast=fast, write_back=write_back ) else: changed = format_file_in_place( - p, line_length=line_length, fast=fast, write_back=not check + p, line_length=line_length, fast=fast, write_back=write_back ) report.done(p, changed) except Exception as exc: @@ -149,88 +195,140 @@ def main( try: return_code = loop.run_until_complete( schedule_formatting( - sources, line_length, not check, fast, loop, executor + sources, line_length, write_back, fast, loop, executor ) ) finally: - loop.close() + shutdown(loop) ctx.exit(return_code) async def schedule_formatting( sources: List[Path], line_length: int, - write_back: bool, + write_back: WriteBack, fast: bool, loop: BaseEventLoop, executor: Executor, ) -> int: + """Run formatting of `sources` in parallel using the provided `executor`. + + (Use ProcessPoolExecutors for actual parallelism.) + + `line_length`, `write_back`, and `fast` options are passed to + :func:`format_file_in_place`. + """ + lock = None + if write_back == WriteBack.DIFF: + # For diff output, we need locks to ensure we don't interleave output + # from different processes. + manager = Manager() + lock = manager.Lock() tasks = { src: loop.run_in_executor( - executor, format_file_in_place, src, line_length, fast, write_back + executor, format_file_in_place, src, line_length, fast, write_back, lock ) for src in sources } + _task_values = list(tasks.values()) + loop.add_signal_handler(signal.SIGINT, cancel, _task_values) + loop.add_signal_handler(signal.SIGTERM, cancel, _task_values) await asyncio.wait(tasks.values()) cancelled = [] - report = Report() + report = Report(check=not write_back) for src, task in tasks.items(): if not task.done(): - report.failed(src, 'timed out, cancelling') + report.failed(src, "timed out, cancelling") task.cancel() cancelled.append(task) + elif task.cancelled(): + cancelled.append(task) elif task.exception(): report.failed(src, str(task.exception())) else: report.done(src, task.result()) if cancelled: - await asyncio.wait(cancelled, timeout=2) - out('All done! ✨ 🍰 ✨') + await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) + else: + out("All done! ✨ 🍰 ✨") click.echo(str(report)) return report.return_code def format_file_in_place( - src: Path, line_length: int, fast: bool, write_back: bool = False + src: Path, + line_length: int, + fast: bool, + write_back: WriteBack = WriteBack.NO, + lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy ) -> bool: - """Format the file and rewrite if changed. Return True if changed.""" + """Format file under `src` path. Return True if changed. + + If `write_back` is True, write reformatted code back to stdout. + `line_length` and `fast` options are passed to :func:`format_file_contents`. + """ with tokenize.open(src) as src_buffer: src_contents = src_buffer.read() try: - contents = format_file_contents( + dst_contents = format_file_contents( src_contents, line_length=line_length, fast=fast ) except NothingChanged: return False - if write_back: + if write_back == write_back.YES: with open(src, "w", encoding=src_buffer.encoding) as f: - f.write(contents) + f.write(dst_contents) + elif write_back == write_back.DIFF: + src_name = f"{src.name} (original)" + dst_name = f"{src.name} (formatted)" + diff_contents = diff(src_contents, dst_contents, src_name, dst_name) + if lock: + lock.acquire() + try: + sys.stdout.write(diff_contents) + finally: + if lock: + lock.release() return True def format_stdin_to_stdout( - line_length: int, fast: bool, write_back: bool = False + line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO ) -> bool: - """Format file on stdin and pipe output to stdout. Return True if changed.""" - contents = sys.stdin.read() + """Format file on stdin. Return True if changed. + + If `write_back` is True, write reformatted code back to stdout. + `line_length` and `fast` arguments are passed to :func:`format_file_contents`. + """ + src = sys.stdin.read() try: - contents = format_file_contents(contents, line_length=line_length, fast=fast) + dst = format_file_contents(src, line_length=line_length, fast=fast) return True except NothingChanged: + dst = src return False finally: - if write_back: - sys.stdout.write(contents) + if write_back == WriteBack.YES: + sys.stdout.write(dst) + elif write_back == WriteBack.DIFF: + src_name = " (original)" + dst_name = " (formatted)" + sys.stdout.write(diff(src, dst, src_name, dst_name)) def format_file_contents( src_contents: str, line_length: int, fast: bool ) -> FileContent: - """Reformats a file and returns its contents and encoding.""" - if src_contents.strip() == '': + """Reformat contents a file and return new contents. + + If `fast` is False, additionally confirm that the reformatted code is + valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. + `line_length` is passed to :func:`format_str`. + """ + if src_contents.strip() == "": raise NothingChanged dst_contents = format_str(src_contents, line_length=line_length) @@ -244,7 +342,10 @@ def format_file_contents( def format_str(src_contents: str, line_length: int) -> FileContent: - """Reformats a string and returns new contents.""" + """Reformat a string and return new contents. + + `line_length` determines how many characters per line are allowed. + """ src_node = lib2to3_parse(src_contents) dst_contents = "" lines = LineGenerator() @@ -274,8 +375,8 @@ GRAMMARS = [ def lib2to3_parse(src_txt: str) -> Node: """Given a string with source, return the lib2to3 Node.""" grammar = pygram.python_grammar_no_print_statement - if src_txt[-1] != '\n': - nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n' + if src_txt[-1] != "\n": + nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n" src_txt += nl for grammar in GRAMMARS: drv = driver.Driver(grammar, pytree.convert) @@ -305,14 +406,14 @@ def lib2to3_unparse(node: Node) -> str: return code -T = TypeVar('T') +T = TypeVar("T") class Visitor(Generic[T]): """Basic lib2to3 visitor that yields things of type `T` on `visit()`.""" def visit(self, node: LN) -> Iterator[T]: - """Main method to start the visit process. Yields objects of type `T`. + """Main method to visit `node` and its children. It tries to find a `visit_*()` method for the given `node.type`, like `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects. @@ -325,7 +426,7 @@ class Visitor(Generic[T]): name = token.tok_name[node.type] else: name = type_repr(node.type) - yield from getattr(self, f'visit_{name}', self.visit_default)(node) + yield from getattr(self, f"visit_{name}", self.visit_default)(node) def visit_default(self, node: LN) -> Iterator[T]: """Default `visit_*()` implementation. Recurses to children of `node`.""" @@ -339,28 +440,28 @@ class DebugVisitor(Visitor[T]): tree_depth: int = 0 def visit_default(self, node: LN) -> Iterator[T]: - indent = ' ' * (2 * self.tree_depth) + indent = " " * (2 * self.tree_depth) if isinstance(node, Node): _type = type_repr(node.type) - out(f'{indent}{_type}', fg='yellow') + out(f"{indent}{_type}", fg="yellow") self.tree_depth += 1 for child in node.children: yield from self.visit(child) self.tree_depth -= 1 - out(f'{indent}/{_type}', fg='yellow', bold=False) + out(f"{indent}/{_type}", fg="yellow", bold=False) else: _type = token.tok_name.get(node.type, str(node.type)) - out(f'{indent}{_type}', fg='blue', nl=False) + out(f"{indent}{_type}", fg="blue", nl=False) if node.prefix: # We don't have to handle prefixes for `Node` objects since # that delegates to the first child anyway. - out(f' {node.prefix!r}', fg='green', bold=False, nl=False) - out(f' {node.value!r}', fg='blue', bold=False) + out(f" {node.prefix!r}", fg="green", bold=False, nl=False) + out(f" {node.value!r}", fg="blue", bold=False) @classmethod def show(cls, code: str) -> None: - """Pretty-prints a given string of `code`. + """Pretty-print the lib2to3 AST of a given string of `code`. Convenience method for debugging. """ @@ -370,7 +471,7 @@ class DebugVisitor(Visitor[T]): KEYWORDS = set(keyword.kwlist) WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE} -FLOW_CONTROL = {'return', 'raise', 'break', 'continue'} +FLOW_CONTROL = {"return", "raise", "break", "continue"} STATEMENT = { syms.if_stmt, syms.while_stmt, @@ -382,7 +483,7 @@ STATEMENT = { syms.classdef, } STANDALONE_COMMENT = 153 -LOGIC_OPERATORS = {'and', 'or'} +LOGIC_OPERATORS = {"and", "or"} COMPARATORS = { token.LESS, token.GREATER, @@ -406,6 +507,7 @@ MATH_OPERATORS = { token.DOUBLESTAR, token.DOUBLESLASH, } +VARARGS = {token.STAR, token.DOUBLESTAR} COMPREHENSION_PRIORITY = 20 COMMA_PRIORITY = 10 LOGIC_PRIORITY = 5 @@ -418,17 +520,13 @@ MATH_PRIORITY = 1 class BracketTracker: """Keeps track of brackets on a line.""" - #: Current bracket depth. depth: int = 0 - #: All currently unclosed brackets. bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict) - #: All current delimiters with their assigned priority. delimiters: Dict[LeafID, Priority] = Factory(dict) - #: Last processed leaf, if any. previous: Optional[Leaf] = None def mark(self, leaf: Leaf) -> None: - """Marks `leaf` with bracket-related metadata. Keeps track of delimiters. + """Mark `leaf` with bracket-related metadata. Keep track of delimiters. All leaves receive an int `bracket_depth` field that stores how deep within brackets a given leaf is. 0 means there are no enclosing brackets @@ -451,43 +549,23 @@ class BracketTracker: leaf.opening_bracket = opening_bracket leaf.bracket_depth = self.depth if self.depth == 0: - delim = is_delimiter(leaf) - if delim: - self.delimiters[id(leaf)] = delim - elif self.previous is not None: - if leaf.type == token.STRING and self.previous.type == token.STRING: - self.delimiters[id(self.previous)] = STRING_PRIORITY - elif ( - leaf.type == token.NAME - and leaf.value == 'for' - and leaf.parent - and leaf.parent.type in {syms.comp_for, syms.old_comp_for} - ): - self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY - elif ( - leaf.type == token.NAME - and leaf.value == 'if' - and leaf.parent - and leaf.parent.type in {syms.comp_if, syms.old_comp_if} - ): - self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY - elif ( - leaf.type == token.NAME - and leaf.value in LOGIC_OPERATORS - and leaf.parent - ): - self.delimiters[id(self.previous)] = LOGIC_PRIORITY + after_delim = is_split_after_delimiter(leaf, self.previous) + before_delim = is_split_before_delimiter(leaf, self.previous) + if after_delim > before_delim: + self.delimiters[id(leaf)] = after_delim + elif before_delim > after_delim and self.previous is not None: + self.delimiters[id(self.previous)] = before_delim if leaf.type in OPENING_BRACKETS: self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf self.depth += 1 self.previous = leaf def any_open_brackets(self) -> bool: - """Returns True if there is an yet unmatched open bracket on the line.""" + """Return True if there is an yet unmatched open bracket on the line.""" return bool(self.bracket_match) def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int: - """Returns the highest priority of a delimiter found on the line. + """Return the highest priority of a delimiter found on the line. Values are consistent with what `is_delimiter()` returns. """ @@ -498,12 +576,9 @@ class BracketTracker: class Line: """Holds leaves and comments. Can be printed with `str(line)`.""" - #: indentation level depth: int = 0 - #: list of leaves leaves: List[Leaf] = Factory(list) - #: inline comments that belong on this line - comments: Dict[LeafID, Leaf] = Factory(dict) + comments: List[Tuple[Index, Leaf]] = Factory(list) bracket_tracker: BracketTracker = Factory(BracketTracker) inside_brackets: bool = False has_for: bool = False @@ -532,16 +607,31 @@ class Line: self.bracket_tracker.mark(leaf) self.maybe_remove_trailing_comma(leaf) self.maybe_increment_for_loop_variable(leaf) - if self.maybe_adapt_standalone_comment(leaf): - return if not self.append_comment(leaf): self.leaves.append(leaf) + def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None: + """Like :func:`append()` but disallow invalid standalone comment structure. + + Raises ValueError when any `leaf` is appended after a standalone comment + or when a standalone comment is not the first leaf on the line. + """ + if self.bracket_tracker.depth == 0: + if self.is_comment: + raise ValueError("cannot append to standalone comments") + + if self.leaves and leaf.type == STANDALONE_COMMENT: + raise ValueError( + "cannot append standalone comments to a populated line" + ) + + self.append(leaf, preformatted=preformatted) + @property def is_comment(self) -> bool: """Is this line a standalone comment?""" - return bool(self) and self.leaves[0].type == STANDALONE_COMMENT + return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT @property def is_decorator(self) -> bool: @@ -555,11 +645,11 @@ class Line: @property def is_class(self) -> bool: - """Is this a class definition?""" + """Is this line a class definition?""" return ( bool(self) and self.leaves[0].type == token.NAME - and self.leaves[0].value == 'class' + and self.leaves[0].value == "class" ) @property @@ -575,18 +665,18 @@ class Line: except IndexError: second_leaf = None return ( - (first_leaf.type == token.NAME and first_leaf.value == 'def') + (first_leaf.type == token.NAME and first_leaf.value == "def") or ( first_leaf.type == token.ASYNC and second_leaf is not None and second_leaf.type == token.NAME - and second_leaf.value == 'def' + and second_leaf.value == "def" ) ) @property def is_flow_control(self) -> bool: - """Is this a flow control statement? + """Is this line a flow control statement? Those are `return`, `raise`, `break`, and `continue`. """ @@ -598,13 +688,22 @@ class Line: @property def is_yield(self) -> bool: - """Is this a yield statement?""" + """Is this line a yield statement?""" return ( bool(self) and self.leaves[0].type == token.NAME - and self.leaves[0].value == 'yield' + and self.leaves[0].value == "yield" ) + @property + def contains_standalone_comments(self) -> bool: + """If so, needs to be split before emitting.""" + for leaf in self.leaves: + if leaf.type == STANDALONE_COMMENT: + return True + + return False + def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: """Remove trailing comma if there is one and it's safe.""" if not ( @@ -615,13 +714,13 @@ class Line: return False if closing.type == token.RBRACE: - self.leaves.pop() + self.remove_trailing_comma() return True if closing.type == token.RSQB: comma = self.leaves[-1] if comma.parent and comma.parent.type == syms.listmaker: - self.leaves.pop() + self.remove_trailing_comma() return True # For parens let's check if it's safe to remove the comma. If the @@ -649,7 +748,7 @@ class Line: break if commas > 1: - self.leaves.pop() + self.remove_trailing_comma() return True return False @@ -657,10 +756,10 @@ class Line: def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: """In a for loop, or comprehension, the variables are often unpacks. - To avoid splitting on the comma in this situation, we will increase - the depth of tokens between `for` and `in`. + To avoid splitting on the comma in this situation, increase the depth of + tokens between `for` and `in`. """ - if leaf.type == token.NAME and leaf.value == 'for': + if leaf.type == token.NAME and leaf.value == "for": self.has_for = True self.bracket_tracker.depth += 1 self._for_loop_variable = True @@ -670,77 +769,74 @@ class Line: def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool: """See `maybe_increment_for_loop_variable` above for explanation.""" - if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in': + if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in": self.bracket_tracker.depth -= 1 self._for_loop_variable = False return True return False - def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool: - """Hack a standalone comment to act as a trailing comment for line splitting. - - If this line has brackets and a standalone `comment`, we need to adapt - it to be able to still reformat the line. - - This is not perfect, the line to which the standalone comment gets - appended will appear "too long" when splitting. - """ - if not ( + def append_comment(self, comment: Leaf) -> bool: + """Add an inline or standalone comment to the line.""" + if ( comment.type == STANDALONE_COMMENT and self.bracket_tracker.any_open_brackets() ): + comment.prefix = "" return False - comment.type = token.COMMENT - comment.prefix = '\n' + ' ' * (self.depth + 1) - return self.append_comment(comment) - - def append_comment(self, comment: Leaf) -> bool: - """Add an inline comment to the line.""" if comment.type != token.COMMENT: return False - try: - after = id(self.last_non_delimiter()) - except LookupError: + after = len(self.leaves) - 1 + if after == -1: comment.type = STANDALONE_COMMENT - comment.prefix = '' + comment.prefix = "" return False else: - if after in self.comments: - self.comments[after].value += str(comment) - else: - self.comments[after] = comment + self.comments.append((after, comment)) return True - def last_non_delimiter(self) -> Leaf: - """Returns the last non-delimiter on the line. Raises LookupError otherwise.""" - for i in range(len(self.leaves)): - last = self.leaves[-i - 1] - if not is_delimiter(last): - return last + def comments_after(self, leaf: Leaf) -> Iterator[Leaf]: + """Generate comments that should appear directly after `leaf`.""" + for _leaf_index, _leaf in enumerate(self.leaves): + if leaf is _leaf: + break + + else: + return + + for index, comment_after in self.comments: + if _leaf_index == index: + yield comment_after - raise LookupError("No non-delimiters found") + def remove_trailing_comma(self) -> None: + """Remove the trailing comma and moves the comments attached to it.""" + comma_index = len(self.leaves) - 1 + for i in range(len(self.comments)): + comment_index, comment = self.comments[i] + if comment_index == comma_index: + self.comments[i] = (comma_index - 1, comment) + self.leaves.pop() def __str__(self) -> str: """Render the line.""" if not self: - return '\n' + return "\n" - indent = ' ' * self.depth + indent = " " * self.depth leaves = iter(self.leaves) first = next(leaves) - res = f'{first.prefix}{indent}{first.value}' + res = f"{first.prefix}{indent}{first.value}" for leaf in leaves: res += str(leaf) - for comment in self.comments.values(): + for _, comment in self.comments: res += str(comment) - return res + '\n' + return res + "\n" def __bool__(self) -> bool: - """Returns True if the line has leaves or comments.""" + """Return True if the line has leaves or comments.""" return bool(self.leaves or self.comments) @@ -767,8 +863,21 @@ class UnformattedLines(Line): elif leaf.type == token.DEDENT: self.depth -= 1 + def __str__(self) -> str: + """Render unformatted lines from leaves which were added with `append()`. + + `depth` is not used for indentation in this case. + """ + if not self: + return "\n" + + res = "" + for leaf in self.leaves: + res += str(leaf) + return res + def append_comment(self, comment: Leaf) -> bool: - """Not implemented in this class.""" + """Not implemented in this class. Raises `NotImplementedError`.""" raise NotImplementedError("Unformatted lines don't store comments separately.") def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: @@ -779,23 +888,6 @@ class UnformattedLines(Line): """Does nothing and returns False.""" return False - def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool: - """Does nothing and returns False.""" - return False - - def __str__(self) -> str: - """Renders unformatted lines from leaves which were added with `append()`. - - `depth` is not used for indentation in this case. - """ - if not self: - return '\n' - - res = '' - for leaf in self.leaves: - res += str(leaf) - return res - @dataclass class EmptyLineTracker: @@ -811,11 +903,11 @@ class EmptyLineTracker: previous_defs: List[int] = Factory(list) def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: - """Returns the number of extra empty lines before and after the `current_line`. + """Return the number of extra empty lines before and after the `current_line`. - This is for separating `def`, `async def` and `class` with extra empty lines - (two on module-level), as well as providing an extra empty line after flow - control keywords to make them more prominent. + This is for separating `def`, `async def` and `class` with extra empty + lines (two on module-level), as well as providing an extra empty line + after flow control keywords to make them more prominent. """ if isinstance(current_line, UnformattedLines): return 0, 0 @@ -833,9 +925,9 @@ class EmptyLineTracker: if current_line.leaves: # Consume the first leaf's extra newlines. first_leaf = current_line.leaves[0] - before = first_leaf.prefix.count('\n') + before = first_leaf.prefix.count("\n") before = min(before, max_allowed) - first_leaf.prefix = '' + first_leaf.prefix = "" else: before = 0 depth = current_line.depth @@ -909,7 +1001,10 @@ class LineGenerator(Visitor[Line]): yield complete_line def visit(self, node: LN) -> Iterator[Line]: - """Main method to start the visit process. Yields :class:`Line` objects.""" + """Main method to visit `node` and its children. + + Yields :class:`Line` objects. + """ if isinstance(self.current_line, UnformattedLines): # File contained `# fmt: off` yield from self.visit_unformatted(node) @@ -951,23 +1046,25 @@ class LineGenerator(Visitor[Line]): else: normalize_prefix(node, inside_brackets=any_open_brackets) + if node.type == token.STRING: + normalize_string_quotes(node) if node.type not in WHITESPACE: self.current_line.append(node) yield from super().visit_default(node) def visit_INDENT(self, node: Node) -> Iterator[Line]: - """Increases indentation level, maybe yields a line.""" + """Increase indentation level, maybe yield a line.""" # In blib2to3 INDENT never holds comments. yield from self.line(+1) yield from self.visit_default(node) def visit_DEDENT(self, node: Node) -> Iterator[Line]: - """Decreases indentation level, maybe yields a line.""" + """Decrease indentation level, maybe yield a line.""" # DEDENT has no value. Additionally, in blib2to3 it never holds comments. yield from self.line(-1) def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]: - """Visits a statement. + """Visit a statement. This implementation is shared for `if`, `while`, `for`, `try`, `except`, `def`, `with`, and `class`. @@ -982,7 +1079,7 @@ class LineGenerator(Visitor[Line]): yield from self.visit(child) def visit_simple_stmt(self, node: Node) -> Iterator[Line]: - """Visits a statement without nested statements.""" + """Visit a statement without nested statements.""" is_suite_like = node.parent and node.parent.type in STATEMENT if is_suite_like: yield from self.line(+1) @@ -994,7 +1091,7 @@ class LineGenerator(Visitor[Line]): yield from self.visit_default(node) def visit_async_stmt(self, node: Node) -> Iterator[Line]: - """Visits `async def`, `async for`, `async with`.""" + """Visit `async def`, `async for`, `async with`.""" yield from self.line() children = iter(node.children) @@ -1009,23 +1106,17 @@ class LineGenerator(Visitor[Line]): yield from self.visit(child) def visit_decorators(self, node: Node) -> Iterator[Line]: - """Visits decorators.""" + """Visit decorators.""" for child in node.children: yield from self.line() yield from self.visit(child) def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]: - """Semicolons are always removed. - - Statements between them are put on separate lines. - """ + """Remove a semicolon and put the other statement on a separate line.""" yield from self.line() def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]: - """End of file. - - Process outstanding comments and end with a newline. - """ + """End of file. Process outstanding comments and end with a newline.""" yield from self.visit_default(leaf) yield from self.line() @@ -1043,17 +1134,21 @@ class LineGenerator(Visitor[Line]): yield from self.line() yield from self.visit(node) + if node.type == token.ENDMARKER: + # somebody decided not to put a final `# fmt: on` + yield from self.line() + def __attrs_post_init__(self) -> None: """You are in a twisty little maze of passages.""" v = self.visit_stmt - self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'}) - self.visit_while_stmt = partial(v, keywords={'while', 'else'}) - self.visit_for_stmt = partial(v, keywords={'for', 'else'}) - self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'}) - self.visit_except_clause = partial(v, keywords={'except'}) - self.visit_funcdef = partial(v, keywords={'def'}) - self.visit_with_stmt = partial(v, keywords={'with'}) - self.visit_classdef = partial(v, keywords={'class'}) + self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}) + self.visit_while_stmt = partial(v, keywords={"while", "else"}) + self.visit_for_stmt = partial(v, keywords={"for", "else"}) + self.visit_try_stmt = partial(v, keywords={"try", "except", "else", "finally"}) + self.visit_except_clause = partial(v, keywords={"except"}) + self.visit_funcdef = partial(v, keywords={"def"}) + self.visit_with_stmt = partial(v, keywords={"with"}) + self.visit_classdef = partial(v, keywords={"class"}) self.visit_async_funcdef = self.visit_async_stmt self.visit_decorated = self.visit_decorators @@ -1067,9 +1162,9 @@ ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT} def whitespace(leaf: Leaf) -> str: # noqa C901 """Return whitespace prefix if needed for the given `leaf`.""" - NO = '' - SPACE = ' ' - DOUBLESPACE = ' ' + NO = "" + SPACE = " " + DOUBLESPACE = " " t = leaf.type p = leaf.parent v = leaf.value @@ -1133,7 +1228,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 and prevp.parent.type == syms.shift_expr and prevp.prev_sibling and prevp.prev_sibling.type == token.NAME - and prevp.prev_sibling.value == 'print' # type: ignore + and prevp.prev_sibling.value == "print" # type: ignore ): # Python 2 print chevron return NO @@ -1290,7 +1385,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return NO elif t == token.NAME: - if v == 'import': + if v == "import": return SPACE if prev and prev.type == token.DOT: @@ -1303,7 +1398,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: - """Returns the first leaf that precedes `node`, if any.""" + """Return the first leaf that precedes `node`, if any.""" while node: res = node.prev_sibling if res: @@ -1320,17 +1415,35 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: return None -def is_delimiter(leaf: Leaf) -> int: - """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter. +def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int: + """Return the priority of the `leaf` delimiter, given a line break after it. + + The delimiter priorities returned here are from those delimiters that would + cause a line break after themselves. Higher numbers are higher priority. """ if leaf.type == token.COMMA: return COMMA_PRIORITY - if leaf.type in COMPARATORS: - return COMPARATOR_PRIORITY + if ( + leaf.type in VARARGS + and leaf.parent + and leaf.parent.type in {syms.argument, syms.typedargslist} + ): + return MATH_PRIORITY + return 0 + + +def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int: + """Return the priority of the `leaf` delimiter, given a line before after it. + + The delimiter priorities returned here are from those delimiters that would + cause a line break before themselves. + + Higher numbers are higher priority. + """ if ( leaf.type in MATH_OPERATORS and leaf.parent @@ -1338,11 +1451,51 @@ def is_delimiter(leaf: Leaf) -> int: ): return MATH_PRIORITY + if leaf.type in COMPARATORS: + return COMPARATOR_PRIORITY + + if ( + leaf.type == token.STRING + and previous is not None + and previous.type == token.STRING + ): + return STRING_PRIORITY + + if ( + leaf.type == token.NAME + and leaf.value == "for" + and leaf.parent + and leaf.parent.type in {syms.comp_for, syms.old_comp_for} + ): + return COMPREHENSION_PRIORITY + + if ( + leaf.type == token.NAME + and leaf.value == "if" + and leaf.parent + and leaf.parent.type in {syms.comp_if, syms.old_comp_if} + ): + return COMPREHENSION_PRIORITY + + if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent: + return LOGIC_PRIORITY + return 0 +def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int: + """Return the priority of the `leaf` delimiter. Return 0 if not delimiter. + + Higher numbers are higher priority. + """ + return max( + is_split_before_delimiter(leaf, previous), + is_split_after_delimiter(leaf, previous), + ) + + def generate_comments(leaf: Leaf) -> Iterator[Leaf]: - """Cleans the prefix of the `leaf` and generates comments from it, if any. + """Clean the prefix of the `leaf` and generate comments from it, if any. Comments in lib2to3 are shoved into the whitespace prefix. This happens in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation @@ -1364,17 +1517,17 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]: if not p: return - if '#' not in p: + if "#" not in p: return consumed = 0 nlines = 0 - for index, line in enumerate(p.split('\n')): + for index, line in enumerate(p.split("\n")): consumed += len(line) + 1 # adding the length of the split '\n' line = line.lstrip() if not line: nlines += 1 - if not line.startswith('#'): + if not line.startswith("#"): continue if index == 0 and leaf.type != token.ENDMARKER: @@ -1382,19 +1535,24 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]: else: comment_type = STANDALONE_COMMENT comment = make_comment(line) - yield Leaf(comment_type, comment, prefix='\n' * nlines) + yield Leaf(comment_type, comment, prefix="\n" * nlines) - if comment in {'# fmt: on', '# yapf: enable'}: + if comment in {"# fmt: on", "# yapf: enable"}: raise FormatOn(consumed) - if comment in {'# fmt: off', '# yapf: disable'}: - raise FormatOff(consumed) + if comment in {"# fmt: off", "# yapf: disable"}: + if comment_type == STANDALONE_COMMENT: + raise FormatOff(consumed) + + prev = preceding_leaf(leaf) + if not prev or prev.type in WHITESPACE: # standalone comment in disguise + raise FormatOff(consumed) nlines = 0 def make_comment(content: str) -> str: - """Returns a consistently formatted comment from the given `content` string. + """Return a consistently formatted comment from the given `content` string. All comments (except for "##", "#!", "#:") should have a single space between the hash sign and the content. @@ -1403,19 +1561,19 @@ def make_comment(content: str) -> str: """ content = content.rstrip() if not content: - return '#' + return "#" - if content[0] == '#': + if content[0] == "#": content = content[1:] - if content and content[0] not in ' !:#': - content = ' ' + content - return '#' + content + if content and content[0] not in " !:#": + content = " " + content + return "#" + content def split_line( line: Line, line_length: int, inner: bool = False, py36: bool = False ) -> Iterator[Line]: - """Splits a `line` into potentially many lines. + """Split a `line` into potentially many lines. They should fit in the allotted `line_length` but might not be able to. `inner` signifies that there were a pair of brackets somewhere around the @@ -1425,23 +1583,24 @@ def split_line( If `py36` is True, splitting may generate syntax that is only compatible with Python 3.6 and later. """ - if isinstance(line, UnformattedLines): + if isinstance(line, UnformattedLines) or line.is_comment: yield line return - line_str = str(line).strip('\n') - if len(line_str) <= line_length and '\n' not in line_str: + line_str = str(line).strip("\n") + if ( + len(line_str) <= line_length + and "\n" not in line_str # multiline strings + and not line.contains_standalone_comments + ): yield line return + split_funcs: List[SplitFunc] if line.is_def: split_funcs = [left_hand_split] elif line.inside_brackets: - split_funcs = [delimiter_split] - if '\n' not in line_str: - # Only attempt RHS if we don't have multiline strings or comments - # on this line. - split_funcs.append(right_hand_split) + split_funcs = [delimiter_split, standalone_comment_split, right_hand_split] else: split_funcs = [right_hand_split] for split_func in split_funcs: @@ -1450,8 +1609,8 @@ def split_line( # split altogether. result: List[Line] = [] try: - for l in split_func(line, py36=py36): - if str(l).strip('\n') == line_str: + for l in split_func(line, py36): + if str(l).strip("\n") == line_str: raise CannotSplit("Split function returned an unchanged result") result.extend( @@ -1469,7 +1628,7 @@ def split_line( def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: - """Splits line into many lines, starting with the first matching bracket pair. + """Split line into many lines, starting with the first matching bracket pair. Note: this usually looks weird, only use this for function definitions. Prefer RHS otherwise. @@ -1503,8 +1662,7 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: ): for leaf in leaves: result.append(leaf, preformatted=True) - comment_after = line.comments.get(id(leaf)) - if comment_after: + for comment_after in line.comments_after(leaf): result.append(comment_after, preformatted=True) bracket_split_succeeded_or_raise(head, body, tail) for result in (head, body, tail): @@ -1513,7 +1671,7 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: - """Splits line into many lines, starting with the last matching bracket pair.""" + """Split line into many lines, starting with the last matching bracket pair.""" head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) tail = Line(depth=line.depth) @@ -1543,8 +1701,7 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: ): for leaf in leaves: result.append(leaf, preformatted=True) - comment_after = line.comments.get(id(leaf)) - if comment_after: + for comment_after in line.comments_after(leaf): result.append(comment_after, preformatted=True) bracket_split_succeeded_or_raise(head, body, tail) for result in (head, body, tail): @@ -1578,10 +1735,25 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None ) +def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc: + """Normalize prefix of the first leaf in every line returned by `split_func`. + + This is a decorator over relevant split functions. + """ + + @wraps(split_func) + def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]: + for l in split_func(line, py36): + normalize_prefix(l.leaves[0], inside_brackets=True) + yield l + + return split_wrapper + + +@dont_increase_indentation def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: - """Splits according to delimiters of the highest priority. + """Split according to delimiters of the highest priority. - This kind of split doesn't increase indentation. If `py36` is True, the split will add trailing commas also in function signatures that contain `*` and `**`. """ @@ -1601,11 +1773,24 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) lowest_depth = sys.maxsize trailing_comma_safe = True + + def append_to_line(leaf: Leaf) -> Iterator[Line]: + """Append `leaf` to current line or to new line if appending impossible.""" + nonlocal current_line + try: + current_line.append_safe(leaf, preformatted=True) + except ValueError as ve: + yield current_line + + current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line.append(leaf) + for leaf in line.leaves: - current_line.append(leaf, preformatted=True) - comment_after = line.comments.get(id(leaf)) - if comment_after: - current_line.append(comment_after, preformatted=True) + yield from append_to_line(leaf) + + for comment_after in line.comments_after(leaf): + yield from append_to_line(comment_after) + lowest_depth = min(lowest_depth, leaf.bracket_depth) if ( leaf.bracket_depth == lowest_depth @@ -1615,55 +1800,132 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: trailing_comma_safe = trailing_comma_safe and py36 leaf_priority = delimiters.get(id(leaf)) if leaf_priority == delimiter_priority: - normalize_prefix(current_line.leaves[0], inside_brackets=True) yield current_line current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) if current_line: if ( - delimiter_priority == COMMA_PRIORITY + trailing_comma_safe + and delimiter_priority == COMMA_PRIORITY and current_line.leaves[-1].type != token.COMMA - and trailing_comma_safe + and current_line.leaves[-1].type != STANDALONE_COMMENT ): - current_line.append(Leaf(token.COMMA, ',')) - normalize_prefix(current_line.leaves[0], inside_brackets=True) + current_line.append(Leaf(token.COMMA, ",")) + yield current_line + + +@dont_increase_indentation +def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: + """Split standalone comments from the rest of the line.""" + for leaf in line.leaves: + if leaf.type == STANDALONE_COMMENT: + if leaf.bracket_depth == 0: + break + + else: + raise CannotSplit("Line does not have any standalone comments") + + current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + + def append_to_line(leaf: Leaf) -> Iterator[Line]: + """Append `leaf` to current line or to new line if appending impossible.""" + nonlocal current_line + try: + current_line.append_safe(leaf, preformatted=True) + except ValueError as ve: + yield current_line + + current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line.append(leaf) + + for leaf in line.leaves: + yield from append_to_line(leaf) + + for comment_after in line.comments_after(leaf): + yield from append_to_line(comment_after) + + if current_line: yield current_line def is_import(leaf: Leaf) -> bool: - """Returns True if the given leaf starts an import statement.""" + """Return True if the given leaf starts an import statement.""" p = leaf.parent t = leaf.type v = leaf.value return bool( t == token.NAME and ( - (v == 'import' and p and p.type == syms.import_name) - or (v == 'from' and p and p.type == syms.import_from) + (v == "import" and p and p.type == syms.import_name) + or (v == "from" and p and p.type == syms.import_from) ) ) def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: - """Leaves existing extra newlines if not `inside_brackets`. + """Leave existing extra newlines if not `inside_brackets`. Remove everything + else. - Removes everything else. Note: don't use backslashes for formatting or - you'll lose your voting rights. + Note: don't use backslashes for formatting or you'll lose your voting rights. """ if not inside_brackets: - spl = leaf.prefix.split('#') - if '\\' not in spl[0]: - nl_count = spl[-1].count('\n') + spl = leaf.prefix.split("#") + if "\\" not in spl[0]: + nl_count = spl[-1].count("\n") if len(spl) > 1: nl_count -= 1 - leaf.prefix = '\n' * nl_count + leaf.prefix = "\n" * nl_count return - leaf.prefix = '' + leaf.prefix = "" + + +def normalize_string_quotes(leaf: Leaf) -> None: + """Prefer double quotes but only if it doesn't cause more escaping. + + Adds or removes backslashes as appropriate. Doesn't parse and fix + strings nested in f-strings (yet). + + Note: Mutates its argument. + """ + value = leaf.value.lstrip("furbFURB") + if value[:3] == '"""': + return + + elif value[:3] == "'''": + orig_quote = "'''" + new_quote = '"""' + elif value[0] == '"': + orig_quote = '"' + new_quote = "'" + else: + orig_quote = "'" + new_quote = '"' + first_quote_pos = leaf.value.find(orig_quote) + if first_quote_pos == -1: + return # There's an internal error + + body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)] + new_body = body.replace(f"\\{orig_quote}", orig_quote).replace( + new_quote, f"\\{new_quote}" + ) + if new_quote == '"""' and new_body[-1] == '"': + # edge case: + new_body = new_body[:-1] + '\\"' + orig_escape_count = body.count("\\") + new_escape_count = new_body.count("\\") + if new_escape_count > orig_escape_count: + return # Do not introduce more escaping + + if new_escape_count == orig_escape_count and orig_quote == '"': + return # Prefer double quotes + + prefix = leaf.value[:first_quote_pos] + leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}" def is_python36(node: Node) -> bool: - """Returns True if the current file is using Python 3.6+ features. + """Return True if the current file is using Python 3.6+ features. Currently looking for: - f-strings; and @@ -1672,7 +1934,7 @@ def is_python36(node: Node) -> bool: for n in node.pre_order(): if n.type == token.STRING: value_head = n.value[:2] # type: ignore - if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}: + if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}: return True elif ( @@ -1687,14 +1949,14 @@ def is_python36(node: Node) -> bool: return False -PYTHON_EXTENSIONS = {'.py'} +PYTHON_EXTENSIONS = {".py"} BLACKLISTED_DIRECTORIES = { - 'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv' + "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv" } def gen_python_files_in_dir(path: Path) -> Iterator[Path]: - """Generates all files under `path` which aren't under BLACKLISTED_DIRECTORIES + """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES and have one of the PYTHON_EXTENSIONS. """ for child in path.iterdir(): @@ -1719,21 +1981,27 @@ class Report: def done(self, src: Path, changed: bool) -> None: """Increment the counter for successful reformatting. Write out a message.""" if changed: - reformatted = 'would reformat' if self.check else 'reformatted' - out(f'{reformatted} {src}') + reformatted = "would reformat" if self.check else "reformatted" + out(f"{reformatted} {src}") self.change_count += 1 else: - out(f'{src} already well formatted, good job.', bold=False) + out(f"{src} already well formatted, good job.", bold=False) self.same_count += 1 def failed(self, src: Path, message: str) -> None: """Increment the counter for failed reformatting. Write out a message.""" - err(f'error: cannot format {src}: {message}') + err(f"error: cannot format {src}: {message}") self.failure_count += 1 @property def return_code(self) -> int: - """Which return code should the app use considering the current state.""" + """Return the exit code that the app should use. + + This considers the current state of changed files and failures: + - if there were any failures, return 123; + - if any files were changed and --check is being used, return 1; + - otherwise return 0. + """ # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with # 126 we have special returncodes reserved by the shell. if self.failure_count: @@ -1745,7 +2013,7 @@ class Report: return 0 def __str__(self) -> str: - """A color report of the current state. + """Render a color report of the current state. Use `click.unstyle` to remove colors. """ @@ -1759,26 +2027,23 @@ class Report: failed = "failed to reformat" report = [] if self.change_count: - s = 's' if self.change_count > 1 else '' + s = "s" if self.change_count > 1 else "" report.append( - click.style(f'{self.change_count} file{s} {reformatted}', bold=True) + click.style(f"{self.change_count} file{s} {reformatted}", bold=True) ) if self.same_count: - s = 's' if self.same_count > 1 else '' - report.append(f'{self.same_count} file{s} {unchanged}') + s = "s" if self.same_count > 1 else "" + report.append(f"{self.same_count} file{s} {unchanged}") if self.failure_count: - s = 's' if self.failure_count > 1 else '' + s = "s" if self.failure_count > 1 else "" report.append( - click.style(f'{self.failure_count} file{s} {failed}', fg='red') + click.style(f"{self.failure_count} file{s} {failed}", fg="red") ) - return ', '.join(report) + '.' + return ", ".join(report) + "." def assert_equivalent(src: str, dst: str) -> None: - """Raises AssertionError if `src` and `dst` aren't equivalent. - - This is a temporary sanity check until Black becomes stable. - """ + """Raise AssertionError if `src` and `dst` aren't equivalent.""" import ast import traceback @@ -1821,17 +2086,17 @@ def assert_equivalent(src: str, dst: str) -> None: try: dst_ast = ast.parse(dst) except Exception as exc: - log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst) + log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst) raise AssertionError( f"INTERNAL ERROR: Black produced invalid code: {exc}. " f"Please report a bug on https://github.com/ambv/black/issues. " f"This invalid output might be helpful: {log}" ) from None - src_ast_str = '\n'.join(_v(src_ast)) - dst_ast_str = '\n'.join(_v(dst_ast)) + src_ast_str = "\n".join(_v(src_ast)) + dst_ast_str = "\n".join(_v(dst_ast)) if src_ast_str != dst_ast_str: - log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst')) + log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst")) raise AssertionError( f"INTERNAL ERROR: Black produced code that is not equivalent to " f"the source. " @@ -1841,15 +2106,12 @@ def assert_equivalent(src: str, dst: str) -> None: def assert_stable(src: str, dst: str, line_length: int) -> None: - """Raises AssertionError if `dst` reformats differently the second time. - - This is a temporary sanity check until Black becomes stable. - """ + """Raise AssertionError if `dst` reformats differently the second time.""" newdst = format_str(dst, line_length=line_length) if dst != newdst: log = dump_to_file( - diff(src, dst, 'source', 'first pass'), - diff(dst, newdst, 'first pass', 'second pass'), + diff(src, dst, "source", "first pass"), + diff(dst, newdst, "first pass", "second pass"), ) raise AssertionError( f"INTERNAL ERROR: Black produced different code on the second pass " @@ -1860,28 +2122,58 @@ def assert_stable(src: str, dst: str, line_length: int) -> None: def dump_to_file(*output: str) -> str: - """Dumps `output` to a temporary file. Returns path to the file.""" + """Dump `output` to a temporary file. Return path to the file.""" import tempfile with tempfile.NamedTemporaryFile( - mode='w', prefix='blk_', suffix='.log', delete=False + mode="w", prefix="blk_", suffix=".log", delete=False ) as f: for lines in output: f.write(lines) - f.write('\n') + if lines and lines[-1] != "\n": + f.write("\n") return f.name def diff(a: str, b: str, a_name: str, b_name: str) -> str: - """Returns a udiff string between strings `a` and `b`.""" + """Return a unified diff string between strings `a` and `b`.""" import difflib - a_lines = [line + '\n' for line in a.split('\n')] - b_lines = [line + '\n' for line in b.split('\n')] - return ''.join( + a_lines = [line + "\n" for line in a.split("\n")] + b_lines = [line + "\n" for line in b.split("\n")] + return "".join( difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5) ) -if __name__ == '__main__': +def cancel(tasks: List[asyncio.Task]) -> None: + """asyncio signal handler that cancels all `tasks` and reports to stderr.""" + err("Aborted!") + for task in tasks: + task.cancel() + + +def shutdown(loop: BaseEventLoop) -> None: + """Cancel all pending tasks on `loop`, wait for them, and close the loop.""" + try: + # This part is borrowed from asyncio/runners.py in Python 3.7b2. + to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()] + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + loop.run_until_complete( + asyncio.gather(*to_cancel, loop=loop, return_exceptions=True) + ) + finally: + # `concurrent.futures.Future` objects cannot be cancelled once they + # are already running. There might be some when the `shutdown()` happened. + # Silence their logger's spew about the event loop being closed. + cf_logger = logging.getLogger("concurrent.futures") + cf_logger.setLevel(logging.CRITICAL) + loop.close() + + +if __name__ == "__main__": main()