X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/7f7b31058af65b245bfc1c35fd37f2ff6e78e43d..4dfec562ed3332212cb938d6d4da9671b503ac93:/black.py?ds=sidebyside diff --git a/black.py b/black.py index de86156..82fe5d1 100644 --- a/black.py +++ b/black.py @@ -3,14 +3,27 @@ import asyncio from asyncio.base_events import BaseEventLoop from concurrent.futures import Executor, ProcessPoolExecutor -from functools import partial +from functools import partial, wraps import keyword +import logging import os from pathlib import Path import tokenize +import signal import sys from typing import ( - Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union + Callable, + Dict, + Generic, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, ) from attr import dataclass, Factory @@ -32,19 +45,22 @@ Depth = int NodeType = int LeafID = int Priority = int +Index = int LN = Union[Leaf, Node] +SplitFunc = Callable[['Line', bool], Iterator['Line']] out = partial(click.secho, bold=True, err=True) err = partial(click.secho, fg='red', err=True) class NothingChanged(UserWarning): - """Raised by `format_file()` when the reformatted code is the same as source.""" + """Raised by :func:`format_file` when reformatted code is the same as source.""" class CannotSplit(Exception): """A readable split that fits the allotted line length is impossible. - Raised by `left_hand_split()`, `right_hand_split()`, and `delimiter_split()`. + Raised by :func:`left_hand_split`, :func:`right_hand_split`, and + :func:`delimiter_split`. """ @@ -153,7 +169,7 @@ def main( ) ) finally: - loop.close() + shutdown(loop) ctx.exit(return_code) @@ -178,21 +194,27 @@ async def schedule_formatting( ) for src in sources } + _task_values = list(tasks.values()) + loop.add_signal_handler(signal.SIGINT, cancel, _task_values) + loop.add_signal_handler(signal.SIGTERM, cancel, _task_values) await asyncio.wait(tasks.values()) cancelled = [] - report = Report() + report = Report(check=not write_back) for src, task in tasks.items(): if not task.done(): report.failed(src, 'timed out, cancelling') task.cancel() cancelled.append(task) + elif task.cancelled(): + cancelled.append(task) elif task.exception(): report.failed(src, str(task.exception())) else: report.done(src, task.result()) if cancelled: - await asyncio.wait(cancelled, timeout=2) - out('All done! ✨ 🍰 ✨') + await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) + else: + out('All done! ✨ 🍰 ✨') click.echo(str(report)) return report.return_code @@ -244,7 +266,7 @@ def format_stdin_to_stdout( def format_file_contents( src_contents: str, line_length: int, fast: bool ) -> FileContent: - """Reformats a file and returns its contents and encoding. + """Reformat contents a file and return new contents. If `fast` is False, additionally confirm that the reformatted code is valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. @@ -264,7 +286,7 @@ def format_file_contents( def format_str(src_contents: str, line_length: int) -> FileContent: - """Reformats a string and returns new contents. + """Reformat a string and return new contents. `line_length` determines how many characters per line are allowed. """ @@ -383,7 +405,7 @@ class DebugVisitor(Visitor[T]): @classmethod def show(cls, code: str) -> None: - """Pretty-prints a given string of `code`. + """Pretty-print the lib2to3 AST of a given string of `code`. Convenience method for debugging. """ @@ -447,7 +469,7 @@ class BracketTracker: previous: Optional[Leaf] = None def mark(self, leaf: Leaf) -> None: - """Marks `leaf` with bracket-related metadata. Keeps track of delimiters. + """Mark `leaf` with bracket-related metadata. Keep track of delimiters. All leaves receive an int `bracket_depth` field that stores how deep within brackets a given leaf is. 0 means there are no enclosing brackets @@ -502,11 +524,11 @@ class BracketTracker: self.previous = leaf def any_open_brackets(self) -> bool: - """Returns True if there is an yet unmatched open bracket on the line.""" + """Return True if there is an yet unmatched open bracket on the line.""" return bool(self.bracket_match) def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int: - """Returns the highest priority of a delimiter found on the line. + """Return the highest priority of a delimiter found on the line. Values are consistent with what `is_delimiter()` returns. """ @@ -519,7 +541,7 @@ class Line: depth: int = 0 leaves: List[Leaf] = Factory(list) - comments: Dict[LeafID, Leaf] = Factory(dict) + comments: List[Tuple[Index, Leaf]] = Factory(list) bracket_tracker: BracketTracker = Factory(BracketTracker) inside_brackets: bool = False has_for: bool = False @@ -548,16 +570,31 @@ class Line: self.bracket_tracker.mark(leaf) self.maybe_remove_trailing_comma(leaf) self.maybe_increment_for_loop_variable(leaf) - if self.maybe_adapt_standalone_comment(leaf): - return if not self.append_comment(leaf): self.leaves.append(leaf) + def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None: + """Like :func:`append()` but disallow invalid standalone comment structure. + + Raises ValueError when any `leaf` is appended after a standalone comment + or when a standalone comment is not the first leaf on the line. + """ + if self.bracket_tracker.depth == 0: + if self.is_comment: + raise ValueError("cannot append to standalone comments") + + if self.leaves and leaf.type == STANDALONE_COMMENT: + raise ValueError( + "cannot append standalone comments to a populated line" + ) + + self.append(leaf, preformatted=preformatted) + @property def is_comment(self) -> bool: """Is this line a standalone comment?""" - return bool(self) and self.leaves[0].type == STANDALONE_COMMENT + return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT @property def is_decorator(self) -> bool: @@ -571,7 +608,7 @@ class Line: @property def is_class(self) -> bool: - """Is this a class definition?""" + """Is this line a class definition?""" return ( bool(self) and self.leaves[0].type == token.NAME @@ -602,7 +639,7 @@ class Line: @property def is_flow_control(self) -> bool: - """Is this a flow control statement? + """Is this line a flow control statement? Those are `return`, `raise`, `break`, and `continue`. """ @@ -614,13 +651,22 @@ class Line: @property def is_yield(self) -> bool: - """Is this a yield statement?""" + """Is this line a yield statement?""" return ( bool(self) and self.leaves[0].type == token.NAME and self.leaves[0].value == 'yield' ) + @property + def contains_standalone_comments(self) -> bool: + """If so, needs to be split before emitting.""" + for leaf in self.leaves: + if leaf.type == STANDALONE_COMMENT: + return True + + return False + def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: """Remove trailing comma if there is one and it's safe.""" if not ( @@ -631,13 +677,13 @@ class Line: return False if closing.type == token.RBRACE: - self.leaves.pop() + self.remove_trailing_comma() return True if closing.type == token.RSQB: comma = self.leaves[-1] if comma.parent and comma.parent.type == syms.listmaker: - self.leaves.pop() + self.remove_trailing_comma() return True # For parens let's check if it's safe to remove the comma. If the @@ -665,7 +711,7 @@ class Line: break if commas > 1: - self.leaves.pop() + self.remove_trailing_comma() return True return False @@ -673,8 +719,8 @@ class Line: def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: """In a for loop, or comprehension, the variables are often unpacks. - To avoid splitting on the comma in this situation, we will increase - the depth of tokens between `for` and `in`. + To avoid splitting on the comma in this situation, increase the depth of + tokens between `for` and `in`. """ if leaf.type == token.NAME and leaf.value == 'for': self.has_for = True @@ -693,52 +739,49 @@ class Line: return False - def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool: - """Hack a standalone comment to act as a trailing comment for line splitting. - - If this line has brackets and a standalone `comment`, we need to adapt - it to be able to still reformat the line. - - This is not perfect, the line to which the standalone comment gets - appended will appear "too long" when splitting. - """ - if not ( + def append_comment(self, comment: Leaf) -> bool: + """Add an inline or standalone comment to the line.""" + if ( comment.type == STANDALONE_COMMENT and self.bracket_tracker.any_open_brackets() ): + comment.prefix = '' return False - comment.type = token.COMMENT - comment.prefix = '\n' + ' ' * (self.depth + 1) - return self.append_comment(comment) - - def append_comment(self, comment: Leaf) -> bool: - """Add an inline comment to the line.""" if comment.type != token.COMMENT: return False - try: - after = id(self.last_non_delimiter()) - except LookupError: + after = len(self.leaves) - 1 + if after == -1: comment.type = STANDALONE_COMMENT comment.prefix = '' return False else: - if after in self.comments: - self.comments[after].value += str(comment) - else: - self.comments[after] = comment + self.comments.append((after, comment)) return True - def last_non_delimiter(self) -> Leaf: - """Returns the last non-delimiter on the line. Raises LookupError otherwise.""" - for i in range(len(self.leaves)): - last = self.leaves[-i - 1] - if not is_delimiter(last): - return last + def comments_after(self, leaf: Leaf) -> Iterator[Leaf]: + """Generate comments that should appear directly after `leaf`.""" + for _leaf_index, _leaf in enumerate(self.leaves): + if leaf is _leaf: + break + + else: + return - raise LookupError("No non-delimiters found") + for index, comment_after in self.comments: + if _leaf_index == index: + yield comment_after + + def remove_trailing_comma(self) -> None: + """Remove the trailing comma and moves the comments attached to it.""" + comma_index = len(self.leaves) - 1 + for i in range(len(self.comments)): + comment_index, comment = self.comments[i] + if comment_index == comma_index: + self.comments[i] = (comma_index - 1, comment) + self.leaves.pop() def __str__(self) -> str: """Render the line.""" @@ -751,12 +794,12 @@ class Line: res = f'{first.prefix}{indent}{first.value}' for leaf in leaves: res += str(leaf) - for comment in self.comments.values(): + for _, comment in self.comments: res += str(comment) return res + '\n' def __bool__(self) -> bool: - """Returns True if the line has leaves or comments.""" + """Return True if the line has leaves or comments.""" return bool(self.leaves or self.comments) @@ -783,24 +826,8 @@ class UnformattedLines(Line): elif leaf.type == token.DEDENT: self.depth -= 1 - def append_comment(self, comment: Leaf) -> bool: - """Not implemented in this class.""" - raise NotImplementedError("Unformatted lines don't store comments separately.") - - def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: - """Does nothing and returns False.""" - return False - - def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: - """Does nothing and returns False.""" - return False - - def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool: - """Does nothing and returns False.""" - return False - def __str__(self) -> str: - """Renders unformatted lines from leaves which were added with `append()`. + """Render unformatted lines from leaves which were added with `append()`. `depth` is not used for indentation in this case. """ @@ -812,6 +839,18 @@ class UnformattedLines(Line): res += str(leaf) return res + def append_comment(self, comment: Leaf) -> bool: + """Not implemented in this class. Raises `NotImplementedError`.""" + raise NotImplementedError("Unformatted lines don't store comments separately.") + + def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: + """Does nothing and returns False.""" + return False + + def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: + """Does nothing and returns False.""" + return False + @dataclass class EmptyLineTracker: @@ -827,11 +866,11 @@ class EmptyLineTracker: previous_defs: List[int] = Factory(list) def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: - """Returns the number of extra empty lines before and after the `current_line`. + """Return the number of extra empty lines before and after the `current_line`. - This is for separating `def`, `async def` and `class` with extra empty lines - (two on module-level), as well as providing an extra empty line after flow - control keywords to make them more prominent. + This is for separating `def`, `async def` and `class` with extra empty + lines (two on module-level), as well as providing an extra empty line + after flow control keywords to make them more prominent. """ if isinstance(current_line, UnformattedLines): return 0, 0 @@ -925,7 +964,10 @@ class LineGenerator(Visitor[Line]): yield complete_line def visit(self, node: LN) -> Iterator[Line]: - """Main method to start the visit process. Yields :class:`Line` objects.""" + """Main method to visit `node` and its children. + + Yields :class:`Line` objects. + """ if isinstance(self.current_line, UnformattedLines): # File contained `# fmt: off` yield from self.visit_unformatted(node) @@ -972,18 +1014,18 @@ class LineGenerator(Visitor[Line]): yield from super().visit_default(node) def visit_INDENT(self, node: Node) -> Iterator[Line]: - """Increases indentation level, maybe yields a line.""" + """Increase indentation level, maybe yield a line.""" # In blib2to3 INDENT never holds comments. yield from self.line(+1) yield from self.visit_default(node) def visit_DEDENT(self, node: Node) -> Iterator[Line]: - """Decreases indentation level, maybe yields a line.""" + """Decrease indentation level, maybe yield a line.""" # DEDENT has no value. Additionally, in blib2to3 it never holds comments. yield from self.line(-1) def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]: - """Visits a statement. + """Visit a statement. This implementation is shared for `if`, `while`, `for`, `try`, `except`, `def`, `with`, and `class`. @@ -998,7 +1040,7 @@ class LineGenerator(Visitor[Line]): yield from self.visit(child) def visit_simple_stmt(self, node: Node) -> Iterator[Line]: - """Visits a statement without nested statements.""" + """Visit a statement without nested statements.""" is_suite_like = node.parent and node.parent.type in STATEMENT if is_suite_like: yield from self.line(+1) @@ -1010,7 +1052,7 @@ class LineGenerator(Visitor[Line]): yield from self.visit_default(node) def visit_async_stmt(self, node: Node) -> Iterator[Line]: - """Visits `async def`, `async for`, `async with`.""" + """Visit `async def`, `async for`, `async with`.""" yield from self.line() children = iter(node.children) @@ -1025,23 +1067,17 @@ class LineGenerator(Visitor[Line]): yield from self.visit(child) def visit_decorators(self, node: Node) -> Iterator[Line]: - """Visits decorators.""" + """Visit decorators.""" for child in node.children: yield from self.line() yield from self.visit(child) def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]: - """Semicolons are always removed. - - Statements between them are put on separate lines. - """ + """Remove a semicolon and put the other statement on a separate line.""" yield from self.line() def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]: - """End of file. - - Process outstanding comments and end with a newline. - """ + """End of file. Process outstanding comments and end with a newline.""" yield from self.visit_default(leaf) yield from self.line() @@ -1319,7 +1355,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: - """Returns the first leaf that precedes `node`, if any.""" + """Return the first leaf that precedes `node`, if any.""" while node: res = node.prev_sibling if res: @@ -1337,7 +1373,7 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: def is_delimiter(leaf: Leaf) -> int: - """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter. + """Return the priority of the `leaf` delimiter. Return 0 if not delimiter. Higher numbers are higher priority. """ @@ -1358,7 +1394,7 @@ def is_delimiter(leaf: Leaf) -> int: def generate_comments(leaf: Leaf) -> Iterator[Leaf]: - """Cleans the prefix of the `leaf` and generates comments from it, if any. + """Clean the prefix of the `leaf` and generate comments from it, if any. Comments in lib2to3 are shoved into the whitespace prefix. This happens in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation @@ -1410,7 +1446,7 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]: def make_comment(content: str) -> str: - """Returns a consistently formatted comment from the given `content` string. + """Return a consistently formatted comment from the given `content` string. All comments (except for "##", "#!", "#:") should have a single space between the hash sign and the content. @@ -1431,7 +1467,7 @@ def make_comment(content: str) -> str: def split_line( line: Line, line_length: int, inner: bool = False, py36: bool = False ) -> Iterator[Line]: - """Splits a `line` into potentially many lines. + """Split a `line` into potentially many lines. They should fit in the allotted `line_length` but might not be able to. `inner` signifies that there were a pair of brackets somewhere around the @@ -1441,23 +1477,24 @@ def split_line( If `py36` is True, splitting may generate syntax that is only compatible with Python 3.6 and later. """ - if isinstance(line, UnformattedLines): + if isinstance(line, UnformattedLines) or line.is_comment: yield line return line_str = str(line).strip('\n') - if len(line_str) <= line_length and '\n' not in line_str: + if ( + len(line_str) <= line_length + and '\n' not in line_str # multiline strings + and not line.contains_standalone_comments + ): yield line return + split_funcs: List[SplitFunc] if line.is_def: split_funcs = [left_hand_split] elif line.inside_brackets: - split_funcs = [delimiter_split] - if '\n' not in line_str: - # Only attempt RHS if we don't have multiline strings or comments - # on this line. - split_funcs.append(right_hand_split) + split_funcs = [delimiter_split, standalone_comment_split, right_hand_split] else: split_funcs = [right_hand_split] for split_func in split_funcs: @@ -1466,7 +1503,7 @@ def split_line( # split altogether. result: List[Line] = [] try: - for l in split_func(line, py36=py36): + for l in split_func(line, py36): if str(l).strip('\n') == line_str: raise CannotSplit("Split function returned an unchanged result") @@ -1485,7 +1522,7 @@ def split_line( def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: - """Splits line into many lines, starting with the first matching bracket pair. + """Split line into many lines, starting with the first matching bracket pair. Note: this usually looks weird, only use this for function definitions. Prefer RHS otherwise. @@ -1519,8 +1556,7 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: ): for leaf in leaves: result.append(leaf, preformatted=True) - comment_after = line.comments.get(id(leaf)) - if comment_after: + for comment_after in line.comments_after(leaf): result.append(comment_after, preformatted=True) bracket_split_succeeded_or_raise(head, body, tail) for result in (head, body, tail): @@ -1529,7 +1565,7 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: - """Splits line into many lines, starting with the last matching bracket pair.""" + """Split line into many lines, starting with the last matching bracket pair.""" head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) tail = Line(depth=line.depth) @@ -1559,8 +1595,7 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: ): for leaf in leaves: result.append(leaf, preformatted=True) - comment_after = line.comments.get(id(leaf)) - if comment_after: + for comment_after in line.comments_after(leaf): result.append(comment_after, preformatted=True) bracket_split_succeeded_or_raise(head, body, tail) for result in (head, body, tail): @@ -1594,10 +1629,25 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None ) +def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc: + """Normalize prefix of the first leaf in every line returned by `split_func`. + + This is a decorator over relevant split functions. + """ + + @wraps(split_func) + def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]: + for l in split_func(line, py36): + normalize_prefix(l.leaves[0], inside_brackets=True) + yield l + + return split_wrapper + + +@dont_increase_indentation def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: - """Splits according to delimiters of the highest priority. + """Split according to delimiters of the highest priority. - This kind of split doesn't increase indentation. If `py36` is True, the split will add trailing commas also in function signatures that contain `*` and `**`. """ @@ -1617,11 +1667,24 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) lowest_depth = sys.maxsize trailing_comma_safe = True + + def append_to_line(leaf: Leaf) -> Iterator[Line]: + """Append `leaf` to current line or to new line if appending impossible.""" + nonlocal current_line + try: + current_line.append_safe(leaf, preformatted=True) + except ValueError as ve: + yield current_line + + current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line.append(leaf) + for leaf in line.leaves: - current_line.append(leaf, preformatted=True) - comment_after = line.comments.get(id(leaf)) - if comment_after: - current_line.append(comment_after, preformatted=True) + yield from append_to_line(leaf) + + for comment_after in line.comments_after(leaf): + yield from append_to_line(comment_after) + lowest_depth = min(lowest_depth, leaf.bracket_depth) if ( leaf.bracket_depth == lowest_depth @@ -1631,7 +1694,6 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: trailing_comma_safe = trailing_comma_safe and py36 leaf_priority = delimiters.get(id(leaf)) if leaf_priority == delimiter_priority: - normalize_prefix(current_line.leaves[0], inside_brackets=True) yield current_line current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) @@ -1642,12 +1704,45 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: and trailing_comma_safe ): current_line.append(Leaf(token.COMMA, ',')) - normalize_prefix(current_line.leaves[0], inside_brackets=True) + yield current_line + + +@dont_increase_indentation +def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: + """Split standalone comments from the rest of the line.""" + for leaf in line.leaves: + if leaf.type == STANDALONE_COMMENT: + if leaf.bracket_depth == 0: + break + + else: + raise CannotSplit("Line does not have any standalone comments") + + current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + + def append_to_line(leaf: Leaf) -> Iterator[Line]: + """Append `leaf` to current line or to new line if appending impossible.""" + nonlocal current_line + try: + current_line.append_safe(leaf, preformatted=True) + except ValueError as ve: + yield current_line + + current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line.append(leaf) + + for leaf in line.leaves: + yield from append_to_line(leaf) + + for comment_after in line.comments_after(leaf): + yield from append_to_line(comment_after) + + if current_line: yield current_line def is_import(leaf: Leaf) -> bool: - """Returns True if the given leaf starts an import statement.""" + """Return True if the given leaf starts an import statement.""" p = leaf.parent t = leaf.type v = leaf.value @@ -1661,10 +1756,10 @@ def is_import(leaf: Leaf) -> bool: def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: - """Leaves existing extra newlines if not `inside_brackets`. + """Leave existing extra newlines if not `inside_brackets`. Remove everything + else. - Removes everything else. Note: don't use backslashes for formatting or - you'll lose your voting rights. + Note: don't use backslashes for formatting or you'll lose your voting rights. """ if not inside_brackets: spl = leaf.prefix.split('#') @@ -1679,7 +1774,7 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: def is_python36(node: Node) -> bool: - """Returns True if the current file is using Python 3.6+ features. + """Return True if the current file is using Python 3.6+ features. Currently looking for: - f-strings; and @@ -1710,7 +1805,7 @@ BLACKLISTED_DIRECTORIES = { def gen_python_files_in_dir(path: Path) -> Iterator[Path]: - """Generates all files under `path` which aren't under BLACKLISTED_DIRECTORIES + """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES and have one of the PYTHON_EXTENSIONS. """ for child in path.iterdir(): @@ -1749,7 +1844,13 @@ class Report: @property def return_code(self) -> int: - """Which return code should the app use considering the current state.""" + """Return the exit code that the app should use. + + This considers the current state of changed files and failures: + - if there were any failures, return 123; + - if any files were changed and --check is being used, return 1; + - otherwise return 0. + """ # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with # 126 we have special returncodes reserved by the shell. if self.failure_count: @@ -1761,7 +1862,7 @@ class Report: return 0 def __str__(self) -> str: - """A color report of the current state. + """Render a color report of the current state. Use `click.unstyle` to remove colors. """ @@ -1791,10 +1892,7 @@ class Report: def assert_equivalent(src: str, dst: str) -> None: - """Raises AssertionError if `src` and `dst` aren't equivalent. - - This is a temporary sanity check until Black becomes stable. - """ + """Raise AssertionError if `src` and `dst` aren't equivalent.""" import ast import traceback @@ -1857,10 +1955,7 @@ def assert_equivalent(src: str, dst: str) -> None: def assert_stable(src: str, dst: str, line_length: int) -> None: - """Raises AssertionError if `dst` reformats differently the second time. - - This is a temporary sanity check until Black becomes stable. - """ + """Raise AssertionError if `dst` reformats differently the second time.""" newdst = format_str(dst, line_length=line_length) if dst != newdst: log = dump_to_file( @@ -1876,7 +1971,7 @@ def assert_stable(src: str, dst: str, line_length: int) -> None: def dump_to_file(*output: str) -> str: - """Dumps `output` to a temporary file. Returns path to the file.""" + """Dump `output` to a temporary file. Return path to the file.""" import tempfile with tempfile.NamedTemporaryFile( @@ -1889,7 +1984,7 @@ def dump_to_file(*output: str) -> str: def diff(a: str, b: str, a_name: str, b_name: str) -> str: - """Returns a udiff string between strings `a` and `b`.""" + """Return a unified diff string between strings `a` and `b`.""" import difflib a_lines = [line + '\n' for line in a.split('\n')] @@ -1899,5 +1994,34 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: ) +def cancel(tasks: List[asyncio.Task]) -> None: + """asyncio signal handler that cancels all `tasks` and reports to stderr.""" + err("Aborted!") + for task in tasks: + task.cancel() + + +def shutdown(loop: BaseEventLoop) -> None: + """Cancel all pending tasks on `loop`, wait for them, and close the loop.""" + try: + # This part is borrowed from asyncio/runners.py in Python 3.7b2. + to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()] + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + loop.run_until_complete( + asyncio.gather(*to_cancel, loop=loop, return_exceptions=True) + ) + finally: + # `concurrent.futures.Future` objects cannot be cancelled once they + # are already running. There might be some when the `shutdown()` happened. + # Silence their logger's spew about the event loop being closed. + cf_logger = logging.getLogger("concurrent.futures") + cf_logger.setLevel(logging.CRITICAL) + loop.close() + + if __name__ == '__main__': main()