X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/2fa31ff31469e587ea23cb86308495c4673b5ddd..c55d08d0b96c8de8bd867ca315e380d9e9d2d7ec:/black.py diff --git a/black.py b/black.py index b2b2db7..6499b22 100644 --- a/black.py +++ b/black.py @@ -1,15 +1,27 @@ #!/usr/bin/env python3 + import asyncio from asyncio.base_events import BaseEventLoop from concurrent.futures import Executor, ProcessPoolExecutor -from functools import partial +from functools import partial, wraps import keyword import os from pathlib import Path import tokenize import sys from typing import ( - Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union + Callable, + Dict, + Generic, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, ) from attr import dataclass, Factory @@ -21,7 +33,7 @@ from blib2to3 import pygram, pytree from blib2to3.pgen2 import driver, token from blib2to3.pgen2.parse import ParseError -__version__ = "18.3a2" +__version__ = "18.3a4" DEFAULT_LINE_LENGTH = 88 # types syms = pygram.python_symbols @@ -31,21 +43,52 @@ Depth = int NodeType = int LeafID = int Priority = int +Index = int LN = Union[Leaf, Node] +SplitFunc = Callable[['Line', bool], Iterator['Line']] out = partial(click.secho, bold=True, err=True) err = partial(click.secho, fg='red', err=True) class NothingChanged(UserWarning): - """Raised by `format_file` when the reformatted code is the same as source.""" + """Raised by :func:`format_file` when reformatted code is the same as source.""" class CannotSplit(Exception): """A readable split that fits the allotted line length is impossible. - Raised by `left_hand_split()` and `right_hand_split()`. + Raised by :func:`left_hand_split`, :func:`right_hand_split`, and + :func:`delimiter_split`. + """ + + +class FormatError(Exception): + """Base exception for `# fmt: on` and `# fmt: off` handling. + + It holds the number of bytes of the prefix consumed before the format + control comment appeared. """ + def __init__(self, consumed: int) -> None: + super().__init__(consumed) + self.consumed = consumed + + def trim_prefix(self, leaf: Leaf) -> None: + leaf.prefix = leaf.prefix[self.consumed:] + + def leaf_from_consumed(self, leaf: Leaf) -> Leaf: + """Returns a new Leaf from the consumed part of the prefix.""" + unformatted_prefix = leaf.prefix[:self.consumed] + return Leaf(token.NEWLINE, unformatted_prefix) + + +class FormatOn(FormatError): + """Found a comment like `# fmt: on` in the file.""" + + +class FormatOff(FormatError): + """Found a comment like `# fmt: off` in the file.""" + @click.command() @click.option( @@ -61,7 +104,7 @@ class CannotSplit(Exception): is_flag=True, help=( "Don't write back the files, just return the status. Return code 0 " - "means nothing changed. Return code 1 means some files were " + "means nothing would change. Return code 1 means some files would be " "reformatted. Return code 123 means there was an internal error." ), ) @@ -74,7 +117,9 @@ class CannotSplit(Exception): @click.argument( 'src', nargs=-1, - type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True), + type=click.Path( + exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True + ), ) @click.pass_context def main( @@ -89,17 +134,24 @@ def main( elif p.is_file(): # if a file was explicitly given, we don't care about its extension sources.append(p) + elif s == '-': + sources.append(Path('-')) else: err(f'invalid path: {s}') if len(sources) == 0: ctx.exit(0) elif len(sources) == 1: p = sources[0] - report = Report() + report = Report(check=check) try: - changed = format_file_in_place( - p, line_length=line_length, fast=fast, write_back=not check - ) + if not p.is_file() and str(p) == '-': + changed = format_stdin_to_stdout( + line_length=line_length, fast=fast, write_back=not check + ) + else: + changed = format_file_in_place( + p, line_length=line_length, fast=fast, write_back=not check + ) report.done(p, changed) except Exception as exc: report.failed(p, str(exc)) @@ -127,6 +179,13 @@ async def schedule_formatting( loop: BaseEventLoop, executor: Executor, ) -> int: + """Run formatting of `sources` in parallel using the provided `executor`. + + (Use ProcessPoolExecutors for actual parallelism.) + + `line_length`, `write_back`, and `fast` options are passed to + :func:`format_file_in_place`. + """ tasks = { src: loop.run_in_executor( executor, format_file_in_place, src, line_length, fast, write_back @@ -135,7 +194,7 @@ async def schedule_formatting( } await asyncio.wait(tasks.values()) cancelled = [] - report = Report() + report = Report(check=not write_back) for src, task in tasks.items(): if not task.done(): report.failed(src, 'timed out, cancelling') @@ -155,42 +214,76 @@ async def schedule_formatting( def format_file_in_place( src: Path, line_length: int, fast: bool, write_back: bool = False ) -> bool: - """Format the file and rewrite if changed. Return True if changed.""" + """Format file under `src` path. Return True if changed. + + If `write_back` is True, write reformatted code back to stdout. + `line_length` and `fast` options are passed to :func:`format_file_contents`. + """ + with tokenize.open(src) as src_buffer: + src_contents = src_buffer.read() try: - contents, encoding = format_file(src, line_length=line_length, fast=fast) + contents = format_file_contents( + src_contents, line_length=line_length, fast=fast + ) except NothingChanged: return False if write_back: - with open(src, "w", encoding=encoding) as f: + with open(src, "w", encoding=src_buffer.encoding) as f: f.write(contents) return True -def format_file( - src: Path, line_length: int, fast: bool -) -> Tuple[FileContent, Encoding]: - """Reformats a file and returns its contents and encoding.""" - with tokenize.open(src) as src_buffer: - src_contents = src_buffer.read() +def format_stdin_to_stdout( + line_length: int, fast: bool, write_back: bool = False +) -> bool: + """Format file on stdin. Return True if changed. + + If `write_back` is True, write reformatted code back to stdout. + `line_length` and `fast` arguments are passed to :func:`format_file_contents`. + """ + contents = sys.stdin.read() + try: + contents = format_file_contents(contents, line_length=line_length, fast=fast) + return True + + except NothingChanged: + return False + + finally: + if write_back: + sys.stdout.write(contents) + + +def format_file_contents( + src_contents: str, line_length: int, fast: bool +) -> FileContent: + """Reformat contents a file and return new contents. + + If `fast` is False, additionally confirm that the reformatted code is + valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. + `line_length` is passed to :func:`format_str`. + """ if src_contents.strip() == '': - raise NothingChanged(src) + raise NothingChanged dst_contents = format_str(src_contents, line_length=line_length) if src_contents == dst_contents: - raise NothingChanged(src) + raise NothingChanged if not fast: assert_equivalent(src_contents, dst_contents) assert_stable(src_contents, dst_contents, line_length=line_length) - return dst_contents, src_buffer.encoding + return dst_contents def format_str(src_contents: str, line_length: int) -> FileContent: - """Reformats a string and returns new contents.""" + """Reformat a string and return new contents. + + `line_length` determines how many characters per line are allowed. + """ src_node = lib2to3_parse(src_contents) dst_contents = "" - comments: List[Line] = [] lines = LineGenerator() elt = EmptyLineTracker() py36 = is_python36(src_node) @@ -202,41 +295,41 @@ def format_str(src_contents: str, line_length: int) -> FileContent: before, after = elt.maybe_empty_lines(current_line) for _ in range(before): dst_contents += str(empty_line) - if not current_line.is_comment: - for comment in comments: - dst_contents += str(comment) - comments = [] - for line in split_line(current_line, line_length=line_length, py36=py36): - dst_contents += str(line) - else: - comments.append(current_line) - if comments: - if elt.previous_defs: - # Separate postscriptum comments from the last module-level def. - dst_contents += str(empty_line) - dst_contents += str(empty_line) - for comment in comments: - dst_contents += str(comment) + for line in split_line(current_line, line_length=line_length, py36=py36): + dst_contents += str(line) return dst_contents +GRAMMARS = [ + pygram.python_grammar_no_print_statement_no_exec_statement, + pygram.python_grammar_no_print_statement, + pygram.python_grammar_no_exec_statement, + pygram.python_grammar, +] + + def lib2to3_parse(src_txt: str) -> Node: """Given a string with source, return the lib2to3 Node.""" grammar = pygram.python_grammar_no_print_statement - drv = driver.Driver(grammar, pytree.convert) if src_txt[-1] != '\n': nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n' src_txt += nl - try: - result = drv.parse_string(src_txt, True) - except ParseError as pe: - lineno, column = pe.context[1] - lines = src_txt.splitlines() + for grammar in GRAMMARS: + drv = driver.Driver(grammar, pytree.convert) try: - faulty_line = lines[lineno - 1] - except IndexError: - faulty_line = "" - raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None + result = drv.parse_string(src_txt, True) + break + + except ParseError as pe: + lineno, column = pe.context[1] + lines = src_txt.splitlines() + try: + faulty_line = lines[lineno - 1] + except IndexError: + faulty_line = "" + exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") + else: + raise exc from None if isinstance(result, Leaf): result = Node(syms.file_input, [result]) @@ -253,9 +346,18 @@ T = TypeVar('T') class Visitor(Generic[T]): - """Basic lib2to3 visitor that yields things on visiting.""" + """Basic lib2to3 visitor that yields things of type `T` on `visit()`.""" def visit(self, node: LN) -> Iterator[T]: + """Main method to visit `node` and its children. + + It tries to find a `visit_*()` method for the given `node.type`, like + `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects. + If no dedicated `visit_*()` method is found, chooses `visit_default()` + instead. + + Then yields objects of type `T` from the selected visitor. + """ if node.type < 256: name = token.tok_name[node.type] else: @@ -263,6 +365,7 @@ class Visitor(Generic[T]): yield from getattr(self, f'visit_{name}', self.visit_default)(node) def visit_default(self, node: LN) -> Iterator[T]: + """Default `visit_*()` implementation. Recurses to children of `node`.""" if isinstance(node, Node): for child in node.children: yield from self.visit(child) @@ -292,6 +395,15 @@ class DebugVisitor(Visitor[T]): out(f' {node.prefix!r}', fg='green', bold=False, nl=False) out(f' {node.value!r}', fg='blue', bold=False) + @classmethod + def show(cls, code: str) -> None: + """Pretty-print the lib2to3 AST of a given string of `code`. + + Convenience method for debugging. + """ + v: DebugVisitor[None] = DebugVisitor() + list(v.visit(lib2to3_parse(code))) + KEYWORDS = set(keyword.kwlist) WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE} @@ -325,6 +437,7 @@ MATH_OPERATORS = { token.AMPER, token.PERCENT, token.CIRCUMFLEX, + token.TILDE, token.LEFTSHIFT, token.RIGHTSHIFT, token.DOUBLESTAR, @@ -340,12 +453,28 @@ MATH_PRIORITY = 1 @dataclass class BracketTracker: + """Keeps track of brackets on a line.""" + depth: int = 0 bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict) delimiters: Dict[LeafID, Priority] = Factory(dict) previous: Optional[Leaf] = None def mark(self, leaf: Leaf) -> None: + """Mark `leaf` with bracket-related metadata. Keep track of delimiters. + + All leaves receive an int `bracket_depth` field that stores how deep + within brackets a given leaf is. 0 means there are no enclosing brackets + that started on this line. + + If a leaf is itself a closing bracket, it receives an `opening_bracket` + field that it forms a pair with. This is a one-directional link to + avoid reference cycles. + + If a leaf is a delimiter (a token on which Black can split the line if + needed) and it's on depth 0, its `id()` is stored in the tracker's + `delimiters` field. + """ if leaf.type == token.COMMENT: return @@ -387,11 +516,11 @@ class BracketTracker: self.previous = leaf def any_open_brackets(self) -> bool: - """Returns True if there is an yet unmatched open bracket on the line.""" + """Return True if there is an yet unmatched open bracket on the line.""" return bool(self.bracket_match) - def max_priority(self, exclude: Iterable[LeafID] =()) -> int: - """Returns the highest priority of a delimiter found on the line. + def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int: + """Return the highest priority of a delimiter found on the line. Values are consistent with what `is_delimiter()` returns. """ @@ -400,15 +529,26 @@ class BracketTracker: @dataclass class Line: + """Holds leaves and comments. Can be printed with `str(line)`.""" + depth: int = 0 leaves: List[Leaf] = Factory(list) - comments: Dict[LeafID, Leaf] = Factory(dict) + comments: List[Tuple[Index, Leaf]] = Factory(list) bracket_tracker: BracketTracker = Factory(BracketTracker) inside_brackets: bool = False has_for: bool = False _for_loop_variable: bool = False def append(self, leaf: Leaf, preformatted: bool = False) -> None: + """Add a new `leaf` to the end of the line. + + Unless `preformatted` is True, the `leaf` will receive a new consistent + whitespace prefix and metadata applied by :class:`BracketTracker`. + Trailing commas are maybe removed, unpacked for loop variables are + demoted from being delimiters. + + Inline comments are put aside. + """ has_value = leaf.value.strip() if not has_value: return @@ -422,26 +562,45 @@ class Line: self.bracket_tracker.mark(leaf) self.maybe_remove_trailing_comma(leaf) self.maybe_increment_for_loop_variable(leaf) - if self.maybe_adapt_standalone_comment(leaf): - return if not self.append_comment(leaf): self.leaves.append(leaf) + def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None: + """Like :func:`append()` but disallow invalid standalone comment structure. + + Raises ValueError when any `leaf` is appended after a standalone comment + or when a standalone comment is not the first leaf on the line. + """ + if self.bracket_tracker.depth == 0: + if self.is_comment: + raise ValueError("cannot append to standalone comments") + + if self.leaves and leaf.type == STANDALONE_COMMENT: + raise ValueError( + "cannot append standalone comments to a populated line" + ) + + self.append(leaf, preformatted=preformatted) + @property def is_comment(self) -> bool: - return bool(self) and self.leaves[0].type == STANDALONE_COMMENT + """Is this line a standalone comment?""" + return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT @property def is_decorator(self) -> bool: + """Is this line a decorator?""" return bool(self) and self.leaves[0].type == token.AT @property def is_import(self) -> bool: + """Is this an import line?""" return bool(self) and is_import(self.leaves[0]) @property def is_class(self) -> bool: + """Is this line a class definition?""" return ( bool(self) and self.leaves[0].type == token.NAME @@ -450,7 +609,7 @@ class Line: @property def is_def(self) -> bool: - """Also returns True for async defs.""" + """Is this a function definition? (Also returns True for async defs.)""" try: first_leaf = self.leaves[0] except IndexError: @@ -463,8 +622,7 @@ class Line: return ( (first_leaf.type == token.NAME and first_leaf.value == 'def') or ( - first_leaf.type == token.NAME - and first_leaf.value == 'async' + first_leaf.type == token.ASYNC and second_leaf is not None and second_leaf.type == token.NAME and second_leaf.value == 'def' @@ -473,6 +631,10 @@ class Line: @property def is_flow_control(self) -> bool: + """Is this line a flow control statement? + + Those are `return`, `raise`, `break`, and `continue`. + """ return ( bool(self) and self.leaves[0].type == token.NAME @@ -481,13 +643,24 @@ class Line: @property def is_yield(self) -> bool: + """Is this line a yield statement?""" return ( bool(self) and self.leaves[0].type == token.NAME and self.leaves[0].value == 'yield' ) + @property + def contains_standalone_comments(self) -> bool: + """If so, needs to be split before emitting.""" + for leaf in self.leaves: + if leaf.type == STANDALONE_COMMENT: + return True + + return False + def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: + """Remove trailing comma if there is one and it's safe.""" if not ( self.leaves and self.leaves[-1].type == token.COMMA @@ -495,10 +668,16 @@ class Line: ): return False - if closing.type == token.RSQB or closing.type == token.RBRACE: - self.leaves.pop() + if closing.type == token.RBRACE: + self.remove_trailing_comma() return True + if closing.type == token.RSQB: + comma = self.leaves[-1] + if comma.parent and comma.parent.type == syms.listmaker: + self.remove_trailing_comma() + return True + # For parens let's check if it's safe to remove the comma. If the # trailing one is the only one, we might mistakenly change a tuple # into a different type by removing the comma. @@ -524,7 +703,7 @@ class Line: break if commas > 1: - self.leaves.pop() + self.remove_trailing_comma() return True return False @@ -532,8 +711,8 @@ class Line: def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: """In a for loop, or comprehension, the variables are often unpacks. - To avoid splitting on the comma in this situation, we will increase - the depth of tokens between `for` and `in`. + To avoid splitting on the comma in this situation, increase the depth of + tokens between `for` and `in`. """ if leaf.type == token.NAME and leaf.value == 'for': self.has_for = True @@ -544,7 +723,7 @@ class Line: return False def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool: - # See `maybe_increment_for_loop_variable` above for explanation. + """See `maybe_increment_for_loop_variable` above for explanation.""" if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in': self.bracket_tracker.depth -= 1 self._for_loop_variable = False @@ -552,52 +731,52 @@ class Line: return False - def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool: - """Hack a standalone comment to act as a trailing comment for line splitting. - - If this line has brackets and a standalone `comment`, we need to adapt - it to be able to still reformat the line. - - This is not perfect, the line to which the standalone comment gets - appended will appear "too long" when splitting. - """ - if not ( + def append_comment(self, comment: Leaf) -> bool: + """Add an inline or standalone comment to the line.""" + if ( comment.type == STANDALONE_COMMENT and self.bracket_tracker.any_open_brackets() ): + comment.prefix = '' return False - comment.type = token.COMMENT - comment.prefix = '\n' + ' ' * (self.depth + 1) - return self.append_comment(comment) - - def append_comment(self, comment: Leaf) -> bool: if comment.type != token.COMMENT: return False - try: - after = id(self.last_non_delimiter()) - except LookupError: + after = len(self.leaves) - 1 + if after == -1: comment.type = STANDALONE_COMMENT comment.prefix = '' return False else: - if after in self.comments: - self.comments[after].value += str(comment) - else: - self.comments[after] = comment + self.comments.append((after, comment)) return True - def last_non_delimiter(self) -> Leaf: - for i in range(len(self.leaves)): - last = self.leaves[-i - 1] - if not is_delimiter(last): - return last + def comments_after(self, leaf: Leaf) -> Iterator[Leaf]: + """Generate comments that should appear directly after `leaf`.""" + for _leaf_index, _leaf in enumerate(self.leaves): + if leaf is _leaf: + break + + else: + return - raise LookupError("No non-delimiters found") + for index, comment_after in self.comments: + if _leaf_index == index: + yield comment_after + + def remove_trailing_comma(self) -> None: + """Remove the trailing comma and moves the comments attached to it.""" + comma_index = len(self.leaves) - 1 + for i in range(len(self.comments)): + comment_index, comment = self.comments[i] + if comment_index == comma_index: + self.comments[i] = (comma_index - 1, comment) + self.leaves.pop() def __str__(self) -> str: + """Render the line.""" if not self: return '\n' @@ -607,14 +786,64 @@ class Line: res = f'{first.prefix}{indent}{first.value}' for leaf in leaves: res += str(leaf) - for comment in self.comments.values(): + for _, comment in self.comments: res += str(comment) return res + '\n' def __bool__(self) -> bool: + """Return True if the line has leaves or comments.""" return bool(self.leaves or self.comments) +class UnformattedLines(Line): + """Just like :class:`Line` but stores lines which aren't reformatted.""" + + def append(self, leaf: Leaf, preformatted: bool = True) -> None: + """Just add a new `leaf` to the end of the lines. + + The `preformatted` argument is ignored. + + Keeps track of indentation `depth`, which is useful when the user + says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`. + """ + try: + list(generate_comments(leaf)) + except FormatOn as f_on: + self.leaves.append(f_on.leaf_from_consumed(leaf)) + raise + + self.leaves.append(leaf) + if leaf.type == token.INDENT: + self.depth += 1 + elif leaf.type == token.DEDENT: + self.depth -= 1 + + def __str__(self) -> str: + """Render unformatted lines from leaves which were added with `append()`. + + `depth` is not used for indentation in this case. + """ + if not self: + return '\n' + + res = '' + for leaf in self.leaves: + res += str(leaf) + return res + + def append_comment(self, comment: Leaf) -> bool: + """Not implemented in this class. Raises `NotImplementedError`.""" + raise NotImplementedError("Unformatted lines don't store comments separately.") + + def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: + """Does nothing and returns False.""" + return False + + def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: + """Does nothing and returns False.""" + return False + + @dataclass class EmptyLineTracker: """Provides a stateful method that returns the number of potential extra @@ -629,14 +858,13 @@ class EmptyLineTracker: previous_defs: List[int] = Factory(list) def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: - """Returns the number of extra empty lines before and after the `current_line`. + """Return the number of extra empty lines before and after the `current_line`. - This is for separating `def`, `async def` and `class` with extra empty lines - (two on module-level), as well as providing an extra empty line after flow - control keywords to make them more prominent. + This is for separating `def`, `async def` and `class` with extra empty + lines (two on module-level), as well as providing an extra empty line + after flow control keywords to make them more prominent. """ - if current_line.is_comment: - # Don't count standalone comments towards previous empty lines. + if isinstance(current_line, UnformattedLines): return 0, 0 before, after = self._maybe_empty_lines(current_line) @@ -646,10 +874,14 @@ class EmptyLineTracker: return before, after def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: + max_allowed = 1 + if current_line.depth == 0: + max_allowed = 2 if current_line.leaves: # Consume the first leaf's extra newlines. first_leaf = current_line.leaves[0] - before = int('\n' in first_leaf.prefix) + before = first_leaf.prefix.count('\n') + before = min(before, max_allowed) first_leaf.prefix = '' else: before = 0 @@ -703,9 +935,8 @@ class LineGenerator(Visitor[Line]): in ways that will no longer stringify to valid Python code on the tree. """ current_line: Line = Factory(Line) - standalone_comments: List[Leaf] = Factory(list) - def line(self, indent: int = 0) -> Iterator[Line]: + def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]: """Generate a line. If the line is empty, only emit if it makes sense. @@ -714,59 +945,85 @@ class LineGenerator(Visitor[Line]): If any lines were generated, set up a new current_line. """ if not self.current_line: - self.current_line.depth += indent + if self.current_line.__class__ == type: + self.current_line.depth += indent + else: + self.current_line = type(depth=self.current_line.depth + indent) return # Line is empty, don't emit. Creating a new one unnecessary. complete_line = self.current_line - self.current_line = Line(depth=complete_line.depth + indent) + self.current_line = type(depth=complete_line.depth + indent) yield complete_line + def visit(self, node: LN) -> Iterator[Line]: + """Main method to visit `node` and its children. + + Yields :class:`Line` objects. + """ + if isinstance(self.current_line, UnformattedLines): + # File contained `# fmt: off` + yield from self.visit_unformatted(node) + + else: + yield from super().visit(node) + def visit_default(self, node: LN) -> Iterator[Line]: + """Default `visit_*()` implementation. Recurses to children of `node`.""" if isinstance(node, Leaf): - for comment in generate_comments(node): - if self.current_line.bracket_tracker.any_open_brackets(): - # any comment within brackets is subject to splitting - self.current_line.append(comment) - elif comment.type == token.COMMENT: - # regular trailing comment - self.current_line.append(comment) - yield from self.line() - - else: - # regular standalone comment, to be processed later (see - # docstring in `generate_comments()` - self.standalone_comments.append(comment) - normalize_prefix(node) - if node.type not in WHITESPACE: - for comment in self.standalone_comments: - yield from self.line() - - self.current_line.append(comment) - yield from self.line() - - self.standalone_comments = [] - self.current_line.append(node) + any_open_brackets = self.current_line.bracket_tracker.any_open_brackets() + try: + for comment in generate_comments(node): + if any_open_brackets: + # any comment within brackets is subject to splitting + self.current_line.append(comment) + elif comment.type == token.COMMENT: + # regular trailing comment + self.current_line.append(comment) + yield from self.line() + + else: + # regular standalone comment + yield from self.line() + + self.current_line.append(comment) + yield from self.line() + + except FormatOff as f_off: + f_off.trim_prefix(node) + yield from self.line(type=UnformattedLines) + yield from self.visit(node) + + except FormatOn as f_on: + # This only happens here if somebody says "fmt: on" multiple + # times in a row. + f_on.trim_prefix(node) + yield from self.visit_default(node) + + else: + normalize_prefix(node, inside_brackets=any_open_brackets) + if node.type not in WHITESPACE: + self.current_line.append(node) yield from super().visit_default(node) - def visit_suite(self, node: Node) -> Iterator[Line]: - """Body of a statement after a colon.""" - children = iter(node.children) - # Process newline before indenting. It might contain an inline - # comment that should go right after the colon. - newline = next(children) - yield from self.visit(newline) + def visit_INDENT(self, node: Node) -> Iterator[Line]: + """Increase indentation level, maybe yield a line.""" + # In blib2to3 INDENT never holds comments. yield from self.line(+1) + yield from self.visit_default(node) - for child in children: - yield from self.visit(child) - + def visit_DEDENT(self, node: Node) -> Iterator[Line]: + """Decrease indentation level, maybe yield a line.""" + # DEDENT has no value. Additionally, in blib2to3 it never holds comments. yield from self.line(-1) def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]: """Visit a statement. - The relevant Python language keywords for this statement are NAME leaves - within it. + This implementation is shared for `if`, `while`, `for`, `try`, `except`, + `def`, `with`, and `class`. + + The relevant Python language `keywords` for a given statement will be NAME + leaves within it. This methods puts those on a separate line. """ for child in node.children: if child.type == token.NAME and child.value in keywords: # type: ignore @@ -775,7 +1032,7 @@ class LineGenerator(Visitor[Line]): yield from self.visit(child) def visit_simple_stmt(self, node: Node) -> Iterator[Line]: - """A statement without nested statements.""" + """Visit a statement without nested statements.""" is_suite_like = node.parent and node.parent.type in STATEMENT if is_suite_like: yield from self.line(+1) @@ -787,13 +1044,14 @@ class LineGenerator(Visitor[Line]): yield from self.visit_default(node) def visit_async_stmt(self, node: Node) -> Iterator[Line]: + """Visit `async def`, `async for`, `async with`.""" yield from self.line() children = iter(node.children) for child in children: yield from self.visit(child) - if child.type == token.NAME and child.value == 'async': # type: ignore + if child.type == token.ASYNC: break internal_stmt = next(children) @@ -801,17 +1059,34 @@ class LineGenerator(Visitor[Line]): yield from self.visit(child) def visit_decorators(self, node: Node) -> Iterator[Line]: + """Visit decorators.""" for child in node.children: yield from self.line() yield from self.visit(child) def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]: + """Remove a semicolon and put the other statement on a separate line.""" yield from self.line() def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]: + """End of file. Process outstanding comments and end with a newline.""" yield from self.visit_default(leaf) yield from self.line() + def visit_unformatted(self, node: LN) -> Iterator[Line]: + """Used when file contained a `# fmt: off`.""" + if isinstance(node, Node): + for child in node.children: + yield from self.visit(child) + + else: + try: + self.current_line.append(node) + except FormatOn as f_on: + f_on.trim_prefix(node) + yield from self.line() + yield from self.visit(node) + def __attrs_post_init__(self) -> None: """You are in a twisty little maze of passages.""" v = self.visit_stmt @@ -849,7 +1124,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return DOUBLESPACE assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}" - if t == token.COLON and p.type != syms.subscript: + if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}: return NO prev = leaf.prev_sibling @@ -862,30 +1137,49 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return SPACE if prevp.type == token.COMMA else NO if prevp.type == token.EQUAL: - if prevp.parent and prevp.parent.type in { - syms.typedargslist, - syms.varargslist, - syms.parameters, - syms.arglist, - syms.argument, - }: - return NO + if prevp.parent: + if prevp.parent.type in { + syms.arglist, syms.argument, syms.parameters, syms.varargslist + }: + return NO + + elif prevp.parent.type == syms.typedargslist: + # A bit hacky: if the equal sign has whitespace, it means we + # previously found it's a typed argument. So, we're using + # that, too. + return prevp.prefix elif prevp.type == token.DOUBLESTAR: if prevp.parent and prevp.parent.type in { - syms.typedargslist, - syms.varargslist, - syms.parameters, syms.arglist, + syms.argument, syms.dictsetmaker, + syms.parameters, + syms.typedargslist, + syms.varargslist, }: return NO elif prevp.type == token.COLON: - if prevp.parent and prevp.parent.type == syms.subscript: + if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}: return NO - elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}: + elif ( + prevp.parent + and prevp.parent.type in {syms.factor, syms.star_expr} + and prevp.type in MATH_OPERATORS + ): + return NO + + elif ( + prevp.type == token.RIGHTSHIFT + and prevp.parent + and prevp.parent.type == syms.shift_expr + and prevp.prev_sibling + and prevp.prev_sibling.type == token.NAME + and prevp.prev_sibling.value == 'print' # type: ignore + ): + # Python 2 print chevron return NO elif prev.type in OPENING_BRACKETS: @@ -899,7 +1193,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 if not prev or prev.type != token.COMMA: return NO - if p.type == syms.varargslist: + elif p.type == syms.varargslist: # lambdas if t == token.RPAR: return NO @@ -1053,7 +1347,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: - """Returns the first leaf that precedes `node`, if any.""" + """Return the first leaf that precedes `node`, if any.""" while node: res = node.prev_sibling if res: @@ -1071,7 +1365,7 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: def is_delimiter(leaf: Leaf) -> int: - """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter. + """Return the priority of the `leaf` delimiter. Return 0 if not delimiter. Higher numbers are higher priority. """ @@ -1092,7 +1386,7 @@ def is_delimiter(leaf: Leaf) -> int: def generate_comments(leaf: Leaf) -> Iterator[Leaf]: - """Cleans the prefix of the `leaf` and generates comments from it, if any. + """Clean the prefix of the `leaf` and generate comments from it, if any. Comments in lib2to3 are shoved into the whitespace prefix. This happens in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation @@ -1110,36 +1404,62 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]: Inline comments are emitted as regular token.COMMENT leaves. Standalone are emitted with a fake STANDALONE_COMMENT token identifier. """ - if not leaf.prefix: - return - - if '#' not in leaf.prefix: + p = leaf.prefix + if not p: return - before_comment, content = leaf.prefix.split('#', 1) - content = content.rstrip() - if content and (content[0] not in {' ', '!', '#'}): - content = ' ' + content - is_standalone_comment = ( - '\n' in before_comment or '\n' in content or leaf.type == token.ENDMARKER - ) - if not is_standalone_comment: - # simple trailing comment - yield Leaf(token.COMMENT, value='#' + content) + if '#' not in p: return - for line in ('#' + content).split('\n'): + consumed = 0 + nlines = 0 + for index, line in enumerate(p.split('\n')): + consumed += len(line) + 1 # adding the length of the split '\n' line = line.lstrip() + if not line: + nlines += 1 if not line.startswith('#'): continue - yield Leaf(STANDALONE_COMMENT, line) + if index == 0 and leaf.type != token.ENDMARKER: + comment_type = token.COMMENT # simple trailing comment + else: + comment_type = STANDALONE_COMMENT + comment = make_comment(line) + yield Leaf(comment_type, comment, prefix='\n' * nlines) + + if comment in {'# fmt: on', '# yapf: enable'}: + raise FormatOn(consumed) + + if comment in {'# fmt: off', '# yapf: disable'}: + raise FormatOff(consumed) + + nlines = 0 + + +def make_comment(content: str) -> str: + """Return a consistently formatted comment from the given `content` string. + + All comments (except for "##", "#!", "#:") should have a single space between + the hash sign and the content. + + If `content` didn't start with a hash sign, one is provided. + """ + content = content.rstrip() + if not content: + return '#' + + if content[0] == '#': + content = content[1:] + if content and content[0] not in ' !:#': + content = ' ' + content + return '#' + content def split_line( line: Line, line_length: int, inner: bool = False, py36: bool = False ) -> Iterator[Line]: - """Splits a `line` into potentially many lines. + """Split a `line` into potentially many lines. They should fit in the allotted `line_length` but might not be able to. `inner` signifies that there were a pair of brackets somewhere around the @@ -1149,19 +1469,24 @@ def split_line( If `py36` is True, splitting may generate syntax that is only compatible with Python 3.6 and later. """ + if isinstance(line, UnformattedLines) or line.is_comment: + yield line + return + line_str = str(line).strip('\n') - if len(line_str) <= line_length and '\n' not in line_str: + if ( + len(line_str) <= line_length + and '\n' not in line_str # multiline strings + and not line.contains_standalone_comments + ): yield line return + split_funcs: List[SplitFunc] if line.is_def: split_funcs = [left_hand_split] elif line.inside_brackets: - split_funcs = [delimiter_split] - if '\n' not in line_str: - # Only attempt RHS if we don't have multiline strings or comments - # on this line. - split_funcs.append(right_hand_split) + split_funcs = [delimiter_split, standalone_comment_split, right_hand_split] else: split_funcs = [right_hand_split] for split_func in split_funcs: @@ -1170,7 +1495,7 @@ def split_line( # split altogether. result: List[Line] = [] try: - for l in split_func(line, py36=py36): + for l in split_func(line, py36): if str(l).strip('\n') == line_str: raise CannotSplit("Split function returned an unchanged result") @@ -1216,17 +1541,16 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: current_leaves = body_leaves # Since body is a new indent level, remove spurious leading whitespace. if body_leaves: - normalize_prefix(body_leaves[0]) + normalize_prefix(body_leaves[0], inside_brackets=True) # Build the new lines. for result, leaves in ( (head, head_leaves), (body, body_leaves), (tail, tail_leaves) ): for leaf in leaves: result.append(leaf, preformatted=True) - comment_after = line.comments.get(id(leaf)) - if comment_after: + for comment_after in line.comments_after(leaf): result.append(comment_after, preformatted=True) - split_succeeded_or_raise(head, body, tail) + bracket_split_succeeded_or_raise(head, body, tail) for result in (head, body, tail): if result: yield result @@ -1256,23 +1580,35 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: head_leaves.reverse() # Since body is a new indent level, remove spurious leading whitespace. if body_leaves: - normalize_prefix(body_leaves[0]) + normalize_prefix(body_leaves[0], inside_brackets=True) # Build the new lines. for result, leaves in ( (head, head_leaves), (body, body_leaves), (tail, tail_leaves) ): for leaf in leaves: result.append(leaf, preformatted=True) - comment_after = line.comments.get(id(leaf)) - if comment_after: + for comment_after in line.comments_after(leaf): result.append(comment_after, preformatted=True) - split_succeeded_or_raise(head, body, tail) + bracket_split_succeeded_or_raise(head, body, tail) for result in (head, body, tail): if result: yield result -def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None: +def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None: + """Raise :exc:`CannotSplit` if the last left- or right-hand split failed. + + Do nothing otherwise. + + A left- or right-hand split is based on a pair of brackets. Content before + (and including) the opening bracket is left on one line, content inside the + brackets is put on a separate line, and finally content starting with and + following the closing bracket is put on a separate line. + + Those are called `head`, `body`, and `tail`, respectively. If the split + produced the same line (all content in `head`) or ended up with an empty `body` + and the `tail` is just the closing bracket, then it's considered failed. + """ tail_len = len(str(tail).strip()) if not body: if tail_len == 0: @@ -1285,12 +1621,27 @@ def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None: ) +def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc: + """Normalize prefix of the first leaf in every line returned by `split_func`. + + This is a decorator over relevant split functions. + """ + + @wraps(split_func) + def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]: + for l in split_func(line, py36): + normalize_prefix(l.leaves[0], inside_brackets=True) + yield l + + return split_wrapper + + +@dont_increase_indentation def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split according to delimiters of the highest priority. - This kind of split doesn't increase indentation. If `py36` is True, the split will add trailing commas also in function - signatures that contain * and **. + signatures that contain `*` and `**`. """ try: last_leaf = line.leaves[-1] @@ -1299,18 +1650,33 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: delimiters = line.bracket_tracker.delimiters try: - delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)}) + delimiter_priority = line.bracket_tracker.max_delimiter_priority( + exclude={id(last_leaf)} + ) except ValueError: raise CannotSplit("No delimiters found") current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) lowest_depth = sys.maxsize trailing_comma_safe = True + + def append_to_line(leaf: Leaf) -> Iterator[Line]: + """Append `leaf` to current line or to new line if appending impossible.""" + nonlocal current_line + try: + current_line.append_safe(leaf, preformatted=True) + except ValueError as ve: + yield current_line + + current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line.append(leaf) + for leaf in line.leaves: - current_line.append(leaf, preformatted=True) - comment_after = line.comments.get(id(leaf)) - if comment_after: - current_line.append(comment_after, preformatted=True) + yield from append_to_line(leaf) + + for comment_after in line.comments_after(leaf): + yield from append_to_line(comment_after) + lowest_depth = min(lowest_depth, leaf.bracket_depth) if ( leaf.bracket_depth == lowest_depth @@ -1320,7 +1686,6 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: trailing_comma_safe = trailing_comma_safe and py36 leaf_priority = delimiters.get(id(leaf)) if leaf_priority == delimiter_priority: - normalize_prefix(current_line.leaves[0]) yield current_line current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) @@ -1331,12 +1696,45 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: and trailing_comma_safe ): current_line.append(Leaf(token.COMMA, ',')) - normalize_prefix(current_line.leaves[0]) + yield current_line + + +@dont_increase_indentation +def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: + """Split standalone comments from the rest of the line.""" + for leaf in line.leaves: + if leaf.type == STANDALONE_COMMENT: + if leaf.bracket_depth == 0: + break + + else: + raise CannotSplit("Line does not have any standalone comments") + + current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + + def append_to_line(leaf: Leaf) -> Iterator[Line]: + """Append `leaf` to current line or to new line if appending impossible.""" + nonlocal current_line + try: + current_line.append_safe(leaf, preformatted=True) + except ValueError as ve: + yield current_line + + current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line.append(leaf) + + for leaf in line.leaves: + yield from append_to_line(leaf) + + for comment_after in line.comments_after(leaf): + yield from append_to_line(comment_after) + + if current_line: yield current_line def is_import(leaf: Leaf) -> bool: - """Returns True if the given leaf starts an import statement.""" + """Return True if the given leaf starts an import statement.""" p = leaf.parent t = leaf.type v = leaf.value @@ -1349,19 +1747,26 @@ def is_import(leaf: Leaf) -> bool: ) -def normalize_prefix(leaf: Leaf) -> None: - """Leave existing extra newlines for imports. Remove everything else.""" - if is_import(leaf): - spl = leaf.prefix.split('#', 1) - nl_count = spl[0].count('\n') - leaf.prefix = '\n' * nl_count - return +def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: + """Leave existing extra newlines if not `inside_brackets`. Remove everything + else. + + Note: don't use backslashes for formatting or you'll lose your voting rights. + """ + if not inside_brackets: + spl = leaf.prefix.split('#') + if '\\' not in spl[0]: + nl_count = spl[-1].count('\n') + if len(spl) > 1: + nl_count -= 1 + leaf.prefix = '\n' * nl_count + return leaf.prefix = '' def is_python36(node: Node) -> bool: - """Returns True if the current file is using Python 3.6+ features. + """Return True if the current file is using Python 3.6+ features. Currently looking for: - f-strings; and @@ -1392,6 +1797,9 @@ BLACKLISTED_DIRECTORIES = { def gen_python_files_in_dir(path: Path) -> Iterator[Path]: + """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES + and have one of the PYTHON_EXTENSIONS. + """ for child in path.iterdir(): if child.is_dir(): if child.name in BLACKLISTED_DIRECTORIES: @@ -1405,7 +1813,8 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]: @dataclass class Report: - """Provides a reformatting counter.""" + """Provides a reformatting counter. Can be rendered with `str(report)`.""" + check: bool = False change_count: int = 0 same_count: int = 0 failure_count: int = 0 @@ -1413,7 +1822,8 @@ class Report: def done(self, src: Path, changed: bool) -> None: """Increment the counter for successful reformatting. Write out a message.""" if changed: - out(f'reformatted {src}') + reformatted = 'would reformat' if self.check else 'reformatted' + out(f'{reformatted} {src}') self.change_count += 1 else: out(f'{src} already well formatted, good job.', bold=False) @@ -1426,46 +1836,55 @@ class Report: @property def return_code(self) -> int: - """Which return code should the app use considering the current state.""" + """Return the exit code that the app should use. + + This considers the current state of changed files and failures: + - if there were any failures, return 123; + - if any files were changed and --check is being used, return 1; + - otherwise return 0. + """ # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with # 126 we have special returncodes reserved by the shell. if self.failure_count: return 123 - elif self.change_count: + elif self.change_count and self.check: return 1 return 0 def __str__(self) -> str: - """A color report of the current state. + """Render a color report of the current state. Use `click.unstyle` to remove colors. """ + if self.check: + reformatted = "would be reformatted" + unchanged = "would be left unchanged" + failed = "would fail to reformat" + else: + reformatted = "reformatted" + unchanged = "left unchanged" + failed = "failed to reformat" report = [] if self.change_count: s = 's' if self.change_count > 1 else '' report.append( - click.style(f'{self.change_count} file{s} reformatted', bold=True) + click.style(f'{self.change_count} file{s} {reformatted}', bold=True) ) if self.same_count: s = 's' if self.same_count > 1 else '' - report.append(f'{self.same_count} file{s} left unchanged') + report.append(f'{self.same_count} file{s} {unchanged}') if self.failure_count: s = 's' if self.failure_count > 1 else '' report.append( - click.style( - f'{self.failure_count} file{s} failed to reformat', fg='red' - ) + click.style(f'{self.failure_count} file{s} {failed}', fg='red') ) return ', '.join(report) + '.' def assert_equivalent(src: str, dst: str) -> None: - """Raises AssertionError if `src` and `dst` aren't equivalent. - - This is a temporary sanity check until Black becomes stable. - """ + """Raise AssertionError if `src` and `dst` aren't equivalent.""" import ast import traceback @@ -1498,7 +1917,12 @@ def assert_equivalent(src: str, dst: str) -> None: try: src_ast = ast.parse(src) except Exception as exc: - raise AssertionError(f"cannot parse source: {exc}") from None + major, minor = sys.version_info[:2] + raise AssertionError( + f"cannot use --safe with this file; failed to parse source file " + f"with Python {major}.{minor}'s builtin AST. Re-run with --fast " + f"or stop using deprecated Python 2 syntax. AST error message: {exc}" + ) try: dst_ast = ast.parse(dst) @@ -1523,10 +1947,7 @@ def assert_equivalent(src: str, dst: str) -> None: def assert_stable(src: str, dst: str, line_length: int) -> None: - """Raises AssertionError if `dst` reformats differently the second time. - - This is a temporary sanity check until Black becomes stable. - """ + """Raise AssertionError if `dst` reformats differently the second time.""" newdst = format_str(dst, line_length=line_length) if dst != newdst: log = dump_to_file( @@ -1542,7 +1963,7 @@ def assert_stable(src: str, dst: str, line_length: int) -> None: def dump_to_file(*output: str) -> str: - """Dumps `output` to a temporary file. Returns path to the file.""" + """Dump `output` to a temporary file. Return path to the file.""" import tempfile with tempfile.NamedTemporaryFile( @@ -1555,7 +1976,7 @@ def dump_to_file(*output: str) -> str: def diff(a: str, b: str, a_name: str, b_name: str) -> str: - """Returns a udiff string between strings `a` and `b`.""" + """Return a unified diff string between strings `a` and `b`.""" import difflib a_lines = [line + '\n' for line in a.split('\n')]