X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/e74117f172e29e8a980e2c9de929ad50d3769150..591bedc2be0cec92c5f253fd473864c876233114:/black.py?ds=sidebyside diff --git a/black.py b/black.py index 24c57ca..54044a5 100644 --- a/black.py +++ b/black.py @@ -7,11 +7,12 @@ import keyword import os from pathlib import Path import tokenize +import sys from typing import ( Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union ) -from attr import attrib, dataclass, Factory +from attr import dataclass, Factory import click # lib2to3 fork @@ -20,7 +21,7 @@ from blib2to3 import pygram, pytree from blib2to3.pgen2 import driver, token from blib2to3.pgen2.parse import ParseError -__version__ = "18.3a0" +__version__ = "18.3a2" DEFAULT_LINE_LENGTH = 88 # types syms = pygram.python_symbols @@ -55,6 +56,15 @@ class CannotSplit(Exception): help='How many character per line to allow.', show_default=True, ) +@click.option( + '--check', + is_flag=True, + help=( + "Don't write back the files, just return the status. Return code 0 " + "means nothing changed. Return code 1 means some files were " + "reformatted. Return code 123 means there was an internal error." + ), +) @click.option( '--fast/--safe', is_flag=True, @@ -64,10 +74,14 @@ class CannotSplit(Exception): @click.argument( 'src', nargs=-1, - type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True), + type=click.Path( + exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True + ), ) @click.pass_context -def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> None: +def main( + ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str] +) -> None: """The uncompromising code formatter.""" sources: List[Path] = [] for s in src: @@ -77,6 +91,8 @@ def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> No elif p.is_file(): # if a file was explicitly given, we don't care about its extension sources.append(p) + elif s == '-': + sources.append(Path('-')) else: err(f'invalid path: {s}') if len(sources) == 0: @@ -85,7 +101,14 @@ def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> No p = sources[0] report = Report() try: - changed = format_file_in_place(p, line_length=line_length, fast=fast) + if not p.is_file() and str(p) == '-': + changed = format_stdin_to_stdout( + line_length=line_length, fast=fast, write_back=not check + ) + else: + changed = format_file_in_place( + p, line_length=line_length, fast=fast, write_back=not check + ) report.done(p, changed) except Exception as exc: report.failed(p, str(exc)) @@ -96,7 +119,9 @@ def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> No return_code = 1 try: return_code = loop.run_until_complete( - schedule_formatting(sources, line_length, fast, loop, executor) + schedule_formatting( + sources, line_length, not check, fast, loop, executor + ) ) finally: loop.close() @@ -106,13 +131,14 @@ def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> No async def schedule_formatting( sources: List[Path], line_length: int, + write_back: bool, fast: bool, loop: BaseEventLoop, executor: Executor, ) -> int: tasks = { src: loop.run_in_executor( - executor, format_file_in_place, src, line_length, fast + executor, format_file_in_place, src, line_length, fast, write_back ) for src in sources } @@ -135,35 +161,57 @@ async def schedule_formatting( return report.return_code -def format_file_in_place(src: Path, line_length: int, fast: bool) -> bool: +def format_file_in_place( + src: Path, line_length: int, fast: bool, write_back: bool = False +) -> bool: """Format the file and rewrite if changed. Return True if changed.""" + with tokenize.open(src) as src_buffer: + src_contents = src_buffer.read() try: - contents, encoding = format_file(src, line_length=line_length, fast=fast) + contents = format_file_contents( + src_contents, line_length=line_length, fast=fast + ) except NothingChanged: return False - with open(src, "w", encoding=encoding) as f: - f.write(contents) + if write_back: + with open(src, "w", encoding=src_buffer.encoding) as f: + f.write(contents) return True -def format_file( - src: Path, line_length: int, fast: bool -) -> Tuple[FileContent, Encoding]: +def format_stdin_to_stdout( + line_length: int, fast: bool, write_back: bool = False +) -> bool: + """Format file on stdin and pipe output to stdout. Return True if changed.""" + contents = sys.stdin.read() + try: + contents = format_file_contents(contents, line_length=line_length, fast=fast) + return True + + except NothingChanged: + return False + + finally: + if write_back: + sys.stdout.write(contents) + + +def format_file_contents( + src_contents: str, line_length: int, fast: bool +) -> FileContent: """Reformats a file and returns its contents and encoding.""" - with tokenize.open(src) as src_buffer: - src_contents = src_buffer.read() if src_contents.strip() == '': - raise NothingChanged(src) + raise NothingChanged dst_contents = format_str(src_contents, line_length=line_length) if src_contents == dst_contents: - raise NothingChanged(src) + raise NothingChanged if not fast: assert_equivalent(src_contents, dst_contents) assert_stable(src_contents, dst_contents, line_length=line_length) - return dst_contents, src_buffer.encoding + return dst_contents def format_str(src_contents: str, line_length: int) -> FileContent: @@ -173,6 +221,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent: comments: List[Line] = [] lines = LineGenerator() elt = EmptyLineTracker() + py36 = is_python36(src_node) empty_line = Line() after = 0 for current_line in lines.visit(src_node): @@ -185,12 +234,17 @@ def format_str(src_contents: str, line_length: int) -> FileContent: for comment in comments: dst_contents += str(comment) comments = [] - for line in split_line(current_line, line_length=line_length): + for line in split_line(current_line, line_length=line_length, py36=py36): dst_contents += str(line) else: comments.append(current_line) - for comment in comments: - dst_contents += str(comment) + if comments: + if elt.previous_defs: + # Separate postscriptum comments from the last module-level def. + dst_contents += str(empty_line) + dst_contents += str(empty_line) + for comment in comments: + dst_contents += str(comment) return dst_contents @@ -244,7 +298,7 @@ class Visitor(Generic[T]): @dataclass class DebugVisitor(Visitor[T]): - tree_depth: int = attrib(default=0) + tree_depth: int = 0 def visit_default(self, node: LN) -> Iterator[T]: indent = ' ' * (2 * self.tree_depth) @@ -314,10 +368,10 @@ MATH_PRIORITY = 1 @dataclass class BracketTracker: - depth: int = attrib(default=0) - bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = attrib(default=Factory(dict)) - delimiters: Dict[LeafID, Priority] = attrib(default=Factory(dict)) - previous: Optional[Leaf] = attrib(default=None) + depth: int = 0 + bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict) + delimiters: Dict[LeafID, Priority] = Factory(dict) + previous: Optional[Leaf] = None def mark(self, leaf: Leaf) -> None: if leaf.type == token.COMMENT: @@ -326,8 +380,8 @@ class BracketTracker: if leaf.type in CLOSING_BRACKETS: self.depth -= 1 opening_bracket = self.bracket_match.pop((self.depth, leaf.type)) - leaf.opening_bracket = opening_bracket # type: ignore - leaf.bracket_depth = self.depth # type: ignore + leaf.opening_bracket = opening_bracket + leaf.bracket_depth = self.depth if self.depth == 0: delim = is_delimiter(leaf) if delim: @@ -336,19 +390,25 @@ class BracketTracker: if leaf.type == token.STRING and self.previous.type == token.STRING: self.delimiters[id(self.previous)] = STRING_PRIORITY elif ( - leaf.type == token.NAME and - leaf.value == 'for' and - leaf.parent and - leaf.parent.type in {syms.comp_for, syms.old_comp_for} + leaf.type == token.NAME + and leaf.value == 'for' + and leaf.parent + and leaf.parent.type in {syms.comp_for, syms.old_comp_for} ): self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY elif ( - leaf.type == token.NAME and - leaf.value == 'if' and - leaf.parent and - leaf.parent.type in {syms.comp_if, syms.old_comp_if} + leaf.type == token.NAME + and leaf.value == 'if' + and leaf.parent + and leaf.parent.type in {syms.comp_if, syms.old_comp_if} ): self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY + elif ( + leaf.type == token.NAME + and leaf.value in LOGIC_OPERATORS + and leaf.parent + ): + self.delimiters[id(self.previous)] = LOGIC_PRIORITY if leaf.type in OPENING_BRACKETS: self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf self.depth += 1 @@ -358,7 +418,7 @@ class BracketTracker: """Returns True if there is an yet unmatched open bracket on the line.""" return bool(self.bracket_match) - def max_priority(self, exclude: Iterable[LeafID] = ()) -> int: + def max_priority(self, exclude: Iterable[LeafID] =()) -> int: """Returns the highest priority of a delimiter found on the line. Values are consistent with what `is_delimiter()` returns. @@ -368,11 +428,13 @@ class BracketTracker: @dataclass class Line: - depth: int = attrib(default=0) - leaves: List[Leaf] = attrib(default=Factory(list)) - comments: Dict[LeafID, Leaf] = attrib(default=Factory(dict)) - bracket_tracker: BracketTracker = attrib(default=Factory(BracketTracker)) - inside_brackets: bool = attrib(default=False) + depth: int = 0 + leaves: List[Leaf] = Factory(list) + comments: Dict[LeafID, Leaf] = Factory(dict) + bracket_tracker: BracketTracker = Factory(BracketTracker) + inside_brackets: bool = False + has_for: bool = False + _for_loop_variable: bool = False def append(self, leaf: Leaf, preformatted: bool = False) -> None: has_value = leaf.value.strip() @@ -384,8 +446,10 @@ class Line: # imports, for which we only preserve newlines. leaf.prefix += whitespace(leaf) if self.inside_brackets or not preformatted: + self.maybe_decrement_after_for_loop_variable(leaf) self.bracket_tracker.mark(leaf) self.maybe_remove_trailing_comma(leaf) + self.maybe_increment_for_loop_variable(leaf) if self.maybe_adapt_standalone_comment(leaf): return @@ -407,9 +471,9 @@ class Line: @property def is_class(self) -> bool: return ( - bool(self) and - self.leaves[0].type == token.NAME and - self.leaves[0].value == 'class' + bool(self) + and self.leaves[0].type == token.NAME + and self.leaves[0].value == 'class' ) @property @@ -425,37 +489,36 @@ class Line: except IndexError: second_leaf = None return ( - (first_leaf.type == token.NAME and first_leaf.value == 'def') or - ( - first_leaf.type == token.NAME and - first_leaf.value == 'async' and - second_leaf is not None and - second_leaf.type == token.NAME and - second_leaf.value == 'def' + (first_leaf.type == token.NAME and first_leaf.value == 'def') + or ( + first_leaf.type == token.ASYNC + and second_leaf is not None + and second_leaf.type == token.NAME + and second_leaf.value == 'def' ) ) @property def is_flow_control(self) -> bool: return ( - bool(self) and - self.leaves[0].type == token.NAME and - self.leaves[0].value in FLOW_CONTROL + bool(self) + and self.leaves[0].type == token.NAME + and self.leaves[0].value in FLOW_CONTROL ) @property def is_yield(self) -> bool: return ( - bool(self) and - self.leaves[0].type == token.NAME and - self.leaves[0].value == 'yield' + bool(self) + and self.leaves[0].type == token.NAME + and self.leaves[0].value == 'yield' ) def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: if not ( - self.leaves and - self.leaves[-1].type == token.COMMA and - closing.type in CLOSING_BRACKETS + self.leaves + and self.leaves[-1].type == token.COMMA + and closing.type in CLOSING_BRACKETS ): return False @@ -466,9 +529,9 @@ class Line: # For parens let's check if it's safe to remove the comma. If the # trailing one is the only one, we might mistakenly change a tuple # into a different type by removing the comma. - depth = closing.bracket_depth + 1 # type: ignore + depth = closing.bracket_depth + 1 commas = 0 - opening = closing.opening_bracket # type: ignore + opening = closing.opening_bracket for _opening_index, leaf in enumerate(self.leaves): if leaf is opening: break @@ -480,15 +543,42 @@ class Line: if leaf is closing: break - bracket_depth = leaf.bracket_depth # type: ignore + bracket_depth = leaf.bracket_depth if bracket_depth == depth and leaf.type == token.COMMA: commas += 1 + if leaf.parent and leaf.parent.type == syms.arglist: + commas += 1 + break + if commas > 1: self.leaves.pop() return True return False + def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: + """In a for loop, or comprehension, the variables are often unpacks. + + To avoid splitting on the comma in this situation, we will increase + the depth of tokens between `for` and `in`. + """ + if leaf.type == token.NAME and leaf.value == 'for': + self.has_for = True + self.bracket_tracker.depth += 1 + self._for_loop_variable = True + return True + + return False + + def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool: + # See `maybe_increment_for_loop_variable` above for explanation. + if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in': + self.bracket_tracker.depth -= 1 + self._for_loop_variable = False + return True + + return False + def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool: """Hack a standalone comment to act as a trailing comment for line splitting. @@ -499,8 +589,8 @@ class Line: appended will appear "too long" when splitting. """ if not ( - comment.type == STANDALONE_COMMENT and - self.bracket_tracker.any_open_brackets() + comment.type == STANDALONE_COMMENT + and self.bracket_tracker.any_open_brackets() ): return False @@ -557,11 +647,13 @@ class EmptyLineTracker: """Provides a stateful method that returns the number of potential extra empty lines needed before and after the currently processed line. - Note: this tracker works on lines that haven't been split yet. + Note: this tracker works on lines that haven't been split yet. It assumes + the prefix of the first leaf consists of optional newlines. Those newlines + are consumed by `maybe_empty_lines()` and included in the computation. """ - previous_line: Optional[Line] = attrib(default=None) - previous_after: int = attrib(default=0) - previous_defs: List[int] = attrib(default=Factory(list)) + previous_line: Optional[Line] = None + previous_after: int = 0 + previous_defs: List[int] = Factory(list) def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: """Returns the number of extra empty lines before and after the `current_line`. @@ -570,17 +662,28 @@ class EmptyLineTracker: (two on module-level), as well as providing an extra empty line after flow control keywords to make them more prominent. """ + if current_line.is_comment: + # Don't count standalone comments towards previous empty lines. + return 0, 0 + before, after = self._maybe_empty_lines(current_line) + before -= self.previous_after self.previous_after = after self.previous_line = current_line return before, after def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: - before = 0 + if current_line.leaves: + # Consume the first leaf's extra newlines. + first_leaf = current_line.leaves[0] + before = int('\n' in first_leaf.prefix) + first_leaf.prefix = '' + else: + before = 0 depth = current_line.depth while self.previous_defs and self.previous_defs[-1] >= depth: self.previous_defs.pop() - before = (1 if depth else 2) - self.previous_after + before = 1 if depth else 2 is_decorator = current_line.is_decorator if is_decorator or current_line.is_def or current_line.is_class: if not is_decorator: @@ -596,24 +699,23 @@ class EmptyLineTracker: newlines = 2 if current_line.depth: newlines -= 1 - newlines -= self.previous_after return newlines, 0 if current_line.is_flow_control: return before, 1 if ( - self.previous_line and - self.previous_line.is_import and - not current_line.is_import and - depth == self.previous_line.depth + self.previous_line + and self.previous_line.is_import + and not current_line.is_import + and depth == self.previous_line.depth ): return (before or 1), 0 if ( - self.previous_line and - self.previous_line.is_yield and - (not current_line.is_yield or depth != self.previous_line.depth) + self.previous_line + and self.previous_line.is_yield + and (not current_line.is_yield or depth != self.previous_line.depth) ): return (before or 1), 0 @@ -627,8 +729,8 @@ class LineGenerator(Visitor[Line]): Note: destroys the tree it's visiting by mutating prefixes of its leaves in ways that will no longer stringify to valid Python code on the tree. """ - current_line: Line = attrib(default=Factory(Line)) - standalone_comments: List[Leaf] = attrib(default=Factory(list)) + current_line: Line = Factory(Line) + standalone_comments: List[Leaf] = Factory(list) def line(self, indent: int = 0) -> Iterator[Line]: """Generate a line. @@ -648,8 +750,9 @@ class LineGenerator(Visitor[Line]): def visit_default(self, node: LN) -> Iterator[Line]: if isinstance(node, Leaf): + any_open_brackets = self.current_line.bracket_tracker.any_open_brackets() for comment in generate_comments(node): - if self.current_line.bracket_tracker.any_open_brackets(): + if any_open_brackets: # any comment within brackets is subject to splitting self.current_line.append(comment) elif comment.type == token.COMMENT: @@ -661,7 +764,7 @@ class LineGenerator(Visitor[Line]): # regular standalone comment, to be processed later (see # docstring in `generate_comments()` self.standalone_comments.append(comment) - normalize_prefix(node) + normalize_prefix(node, inside_brackets=any_open_brackets) if node.type not in WHITESPACE: for comment in self.standalone_comments: yield from self.line() @@ -718,7 +821,7 @@ class LineGenerator(Visitor[Line]): for child in children: yield from self.visit(child) - if child.type == token.NAME and child.value == 'async': # type: ignore + if child.type == token.ASYNC: break internal_stmt = next(children) @@ -756,37 +859,71 @@ BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.R OPENING_BRACKETS = set(BRACKET.keys()) CLOSING_BRACKETS = set(BRACKET.values()) BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS +ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT} -def whitespace(leaf: Leaf) -> str: +def whitespace(leaf: Leaf) -> str: # noqa C901 """Return whitespace prefix if needed for the given `leaf`.""" NO = '' SPACE = ' ' DOUBLESPACE = ' ' t = leaf.type p = leaf.parent - if t == token.COLON: + v = leaf.value + if t in ALWAYS_NO_SPACE: return NO - if t == token.COMMA: - return NO + if t == token.COMMENT: + return DOUBLESPACE - if t == token.RPAR: + assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}" + if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}: return NO - if t == token.COMMENT: - return DOUBLESPACE + prev = leaf.prev_sibling + if not prev: + prevp = preceding_leaf(p) + if not prevp or prevp.type in OPENING_BRACKETS: + return NO - if t == STANDALONE_COMMENT: + if t == token.COLON: + return SPACE if prevp.type == token.COMMA else NO + + if prevp.type == token.EQUAL: + if prevp.parent and prevp.parent.type in { + syms.typedargslist, + syms.varargslist, + syms.parameters, + syms.arglist, + syms.argument, + }: + return NO + + elif prevp.type == token.DOUBLESTAR: + if prevp.parent and prevp.parent.type in { + syms.typedargslist, + syms.varargslist, + syms.parameters, + syms.arglist, + syms.dictsetmaker, + }: + return NO + + elif prevp.type == token.COLON: + if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}: + return NO + + elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}: + return NO + + elif prev.type in OPENING_BRACKETS: return NO - assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}" if p.type in {syms.parameters, syms.arglist}: # untyped function signatures or calls if t == token.RPAR: return NO - prev = leaf.prev_sibling if not prev or prev.type != token.COMMA: return NO @@ -795,13 +932,11 @@ def whitespace(leaf: Leaf) -> str: if t == token.RPAR: return NO - prev = leaf.prev_sibling if prev and prev.type != token.COMMA: return NO elif p.type == syms.typedargslist: # typed function signatures - prev = leaf.prev_sibling if not prev: return NO @@ -819,7 +954,6 @@ def whitespace(leaf: Leaf) -> str: elif p.type == syms.tname: # type names - prev = leaf.prev_sibling if not prev: prevp = preceding_leaf(p) if not prevp or prevp.type != token.COMMA: @@ -830,7 +964,6 @@ def whitespace(leaf: Leaf) -> str: if t == token.LPAR or t == token.RPAR: return NO - prev = leaf.prev_sibling if not prev: if t == token.DOT: prevp = preceding_leaf(p) @@ -848,7 +981,6 @@ def whitespace(leaf: Leaf) -> str: if t == token.EQUAL: return NO - prev = leaf.prev_sibling if not prev: prevp = preceding_leaf(p) if not prevp or prevp.type == token.LPAR: @@ -862,113 +994,56 @@ def whitespace(leaf: Leaf) -> str: return NO elif p.type == syms.dotted_name: - prev = leaf.prev_sibling if prev: return NO prevp = preceding_leaf(p) - if not prevp or prevp.type == token.AT: + if not prevp or prevp.type == token.AT or prevp.type == token.DOT: return NO elif p.type == syms.classdef: if t == token.LPAR: return NO - prev = leaf.prev_sibling if prev and prev.type == token.LPAR: return NO elif p.type == syms.subscript: # indexing - if t == token.COLON: - return NO - - prev = leaf.prev_sibling - if not prev or prev.type == token.COLON: - return NO - - elif p.type in { - syms.test, - syms.not_test, - syms.xor_expr, - syms.or_test, - syms.and_test, - syms.arith_expr, - syms.shift_expr, - syms.yield_expr, - syms.term, - syms.power, - syms.comparison, - }: - # various arithmetic and logic expressions - prev = leaf.prev_sibling if not prev: - prevp = preceding_leaf(p) - if not prevp or prevp.type in OPENING_BRACKETS: - return NO + assert p.parent is not None, "subscripts are always parented" + if p.parent.type == syms.subscriptlist: + return SPACE - if prevp.type == token.EQUAL: - if prevp.parent and prevp.parent.type in { - syms.varargslist, syms.parameters, syms.arglist, syms.argument - }: - return NO - - return SPACE - - elif p.type == syms.atom: - if t in CLOSING_BRACKETS: return NO - prev = leaf.prev_sibling - if not prev: - prevp = preceding_leaf(p) - if not prevp: - return NO - - if prevp.type in OPENING_BRACKETS: - return NO - - if prevp.type == token.EQUAL: - if prevp.parent and prevp.parent.type in { - syms.varargslist, syms.parameters, syms.arglist, syms.argument - }: - return NO - - if prevp.type == token.DOUBLESTAR: - if prevp.parent and prevp.parent.type in { - syms.varargslist, syms.parameters, syms.arglist, syms.dictsetmaker - }: - return NO - - elif prev.type in OPENING_BRACKETS: + else: return NO - elif t == token.DOT: + elif p.type == syms.atom: + if prev and t == token.DOT: # dots, but not the first one. return NO elif ( - p.type == syms.listmaker or - p.type == syms.testlist_gexp or - p.type == syms.subscriptlist + p.type == syms.listmaker + or p.type == syms.testlist_gexp + or p.type == syms.subscriptlist ): # list interior, including unpacking - prev = leaf.prev_sibling if not prev: return NO elif p.type == syms.dictsetmaker: # dict and set interior, including unpacking - prev = leaf.prev_sibling if not prev: return NO if prev.type == token.DOUBLESTAR: return NO - elif p.type == syms.factor or p.type == syms.star_expr: + elif p.type in {syms.factor, syms.star_expr}: # unary ops - prev = leaf.prev_sibling if not prev: prevp = preceding_leaf(p) if not prevp or prevp.type in OPENING_BRACKETS: @@ -987,10 +1062,17 @@ def whitespace(leaf: Leaf) -> str: elif t == token.NAME or t == token.NUMBER: return NO - elif p.type == syms.import_from and t == token.NAME: - prev = leaf.prev_sibling - if prev and prev.type == token.DOT: - return NO + elif p.type == syms.import_from: + if t == token.DOT: + if prev and prev.type == token.DOT: + return NO + + elif t == token.NAME: + if v == 'import': + return SPACE + + if prev and prev.type == token.DOT: + return NO elif p.type == syms.sliceop: return NO @@ -1024,16 +1106,13 @@ def is_delimiter(leaf: Leaf) -> int: if leaf.type == token.COMMA: return COMMA_PRIORITY - if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS: - return LOGIC_PRIORITY - if leaf.type in COMPARATORS: return COMPARATOR_PRIORITY if ( - leaf.type in MATH_OPERATORS and - leaf.parent and - leaf.parent.type not in {syms.factor, syms.star_expr} + leaf.type in MATH_OPERATORS + and leaf.parent + and leaf.parent.type not in {syms.factor, syms.star_expr} ): return MATH_PRIORITY @@ -1070,7 +1149,7 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]: if content and (content[0] not in {' ', '!', '#'}): content = ' ' + content is_standalone_comment = ( - '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT + '\n' in before_comment or '\n' in content or leaf.type == token.ENDMARKER ) if not is_standalone_comment: # simple trailing comment @@ -1085,13 +1164,18 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]: yield Leaf(STANDALONE_COMMENT, line) -def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]: +def split_line( + line: Line, line_length: int, inner: bool = False, py36: bool = False +) -> Iterator[Line]: """Splits a `line` into potentially many lines. They should fit in the allotted `line_length` but might not be able to. `inner` signifies that there were a pair of brackets somewhere around the current `line`, possibly transitively. This means we can fallback to splitting by delimiters if the LHS/RHS don't yield any results. + + If `py36` is True, splitting may generate syntax that is only compatible + with Python 3.6 and later. """ line_str = str(line).strip('\n') if len(line_str) <= line_length and '\n' not in line_str: @@ -1114,11 +1198,13 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li # split altogether. result: List[Line] = [] try: - for l in split_func(line): + for l in split_func(line, py36=py36): if str(l).strip('\n') == line_str: raise CannotSplit("Split function returned an unchanged result") - result.extend(split_line(l, line_length=line_length, inner=True)) + result.extend( + split_line(l, line_length=line_length, inner=True, py36=py36) + ) except CannotSplit as cs: continue @@ -1130,7 +1216,7 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li yield line -def left_hand_split(line: Line) -> Iterator[Line]: +def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split line into many lines, starting with the first matching bracket pair. Note: this usually looks weird, only use this for function definitions. @@ -1146,11 +1232,11 @@ def left_hand_split(line: Line) -> Iterator[Line]: matching_bracket = None for leaf in line.leaves: if ( - current_leaves is body_leaves and - leaf.type in CLOSING_BRACKETS and - leaf.opening_bracket is matching_bracket # type: ignore + current_leaves is body_leaves + and leaf.type in CLOSING_BRACKETS + and leaf.opening_bracket is matching_bracket ): - current_leaves = tail_leaves + current_leaves = tail_leaves if body_leaves else head_leaves current_leaves.append(leaf) if current_leaves is head_leaves: if leaf.type in OPENING_BRACKETS: @@ -1158,7 +1244,7 @@ def left_hand_split(line: Line) -> Iterator[Line]: current_leaves = body_leaves # Since body is a new indent level, remove spurious leading whitespace. if body_leaves: - normalize_prefix(body_leaves[0]) + normalize_prefix(body_leaves[0], inside_brackets=True) # Build the new lines. for result, leaves in ( (head, head_leaves), (body, body_leaves), (tail, tail_leaves) @@ -1168,24 +1254,13 @@ def left_hand_split(line: Line) -> Iterator[Line]: comment_after = line.comments.get(id(leaf)) if comment_after: result.append(comment_after, preformatted=True) - # Check if the split succeeded. - tail_len = len(str(tail)) - if not body: - if tail_len == 0: - raise CannotSplit("Splitting brackets produced the same line") - - elif tail_len < 3: - raise CannotSplit( - f"Splitting brackets on an empty body to save " - f"{tail_len} characters is not worth it" - ) - + split_succeeded_or_raise(head, body, tail) for result in (head, body, tail): if result: yield result -def right_hand_split(line: Line) -> Iterator[Line]: +def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split line into many lines, starting with the last matching bracket pair.""" head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) @@ -1198,18 +1273,18 @@ def right_hand_split(line: Line) -> Iterator[Line]: for leaf in reversed(line.leaves): if current_leaves is body_leaves: if leaf is opening_bracket: - current_leaves = head_leaves + current_leaves = head_leaves if body_leaves else tail_leaves current_leaves.append(leaf) if current_leaves is tail_leaves: if leaf.type in CLOSING_BRACKETS: - opening_bracket = leaf.opening_bracket # type: ignore + opening_bracket = leaf.opening_bracket current_leaves = body_leaves tail_leaves.reverse() body_leaves.reverse() head_leaves.reverse() # Since body is a new indent level, remove spurious leading whitespace. if body_leaves: - normalize_prefix(body_leaves[0]) + normalize_prefix(body_leaves[0], inside_brackets=True) # Build the new lines. for result, leaves in ( (head, head_leaves), (body, body_leaves), (tail, tail_leaves) @@ -1219,8 +1294,14 @@ def right_hand_split(line: Line) -> Iterator[Line]: comment_after = line.comments.get(id(leaf)) if comment_after: result.append(comment_after, preformatted=True) - # Check if the split succeeded. - tail_len = len(str(tail).strip('\n')) + split_succeeded_or_raise(head, body, tail) + for result in (head, body, tail): + if result: + yield result + + +def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None: + tail_len = len(str(tail).strip()) if not body: if tail_len == 0: raise CannotSplit("Splitting brackets produced the same line") @@ -1231,15 +1312,13 @@ def right_hand_split(line: Line) -> Iterator[Line]: f"{tail_len} characters is not worth it" ) - for result in (head, body, tail): - if result: - yield result - -def delimiter_split(line: Line) -> Iterator[Line]: +def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split according to delimiters of the highest priority. This kind of split doesn't increase indentation. + If `py36` is True, the split will add trailing commas also in function + signatures that contain * and **. """ try: last_leaf = line.leaves[-1] @@ -1253,24 +1332,34 @@ def delimiter_split(line: Line) -> Iterator[Line]: raise CannotSplit("No delimiters found") current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + lowest_depth = sys.maxsize + trailing_comma_safe = True for leaf in line.leaves: current_line.append(leaf, preformatted=True) comment_after = line.comments.get(id(leaf)) if comment_after: current_line.append(comment_after, preformatted=True) + lowest_depth = min(lowest_depth, leaf.bracket_depth) + if ( + leaf.bracket_depth == lowest_depth + and leaf.type == token.STAR + or leaf.type == token.DOUBLESTAR + ): + trailing_comma_safe = trailing_comma_safe and py36 leaf_priority = delimiters.get(id(leaf)) if leaf_priority == delimiter_priority: - normalize_prefix(current_line.leaves[0]) + normalize_prefix(current_line.leaves[0], inside_brackets=True) yield current_line current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) if current_line: if ( - delimiter_priority == COMMA_PRIORITY and - current_line.leaves[-1].type != token.COMMA + delimiter_priority == COMMA_PRIORITY + and current_line.leaves[-1].type != token.COMMA + and trailing_comma_safe ): current_line.append(Leaf(token.COMMA, ',')) - normalize_prefix(current_line.leaves[0]) + normalize_prefix(current_line.leaves[0], inside_brackets=True) yield current_line @@ -1280,28 +1369,55 @@ def is_import(leaf: Leaf) -> bool: t = leaf.type v = leaf.value return bool( - t == token.NAME and - ( - (v == 'import' and p and p.type == syms.import_name) or - (v == 'from' and p and p.type == syms.import_from) + t == token.NAME + and ( + (v == 'import' and p and p.type == syms.import_name) + or (v == 'from' and p and p.type == syms.import_from) ) ) -def normalize_prefix(leaf: Leaf) -> None: - """Leave existing extra newlines for imports. Remove everything else.""" - if is_import(leaf): +def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: + """Leave existing extra newlines if not `inside_brackets`. + + Remove everything else. Note: don't use backslashes for formatting or + you'll lose your voting rights. + """ + if not inside_brackets: spl = leaf.prefix.split('#', 1) - nl_count = spl[0].count('\n') - if len(spl) > 1: - # Skip one newline since it was for a standalone comment. - nl_count -= 1 - leaf.prefix = '\n' * nl_count - return + if '\\' not in spl[0]: + nl_count = spl[0].count('\n') + leaf.prefix = '\n' * nl_count + return leaf.prefix = '' +def is_python36(node: Node) -> bool: + """Returns True if the current file is using Python 3.6+ features. + + Currently looking for: + - f-strings; and + - trailing commas after * or ** in function signatures. + """ + for n in node.pre_order(): + if n.type == token.STRING: + value_head = n.value[:2] # type: ignore + if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}: + return True + + elif ( + n.type == syms.typedargslist + and n.children + and n.children[-1].type == token.COMMA + ): + for ch in n.children: + if ch.type == token.STAR or ch.type == token.DOUBLESTAR: + return True + + return False + + PYTHON_EXTENSIONS = {'.py'} BLACKLISTED_DIRECTORIES = { 'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv' @@ -1323,9 +1439,9 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]: @dataclass class Report: """Provides a reformatting counter.""" - change_count: int = attrib(default=0) - same_count: int = attrib(default=0) - failure_count: int = attrib(default=0) + change_count: int = 0 + same_count: int = 0 + failure_count: int = 0 def done(self, src: Path, changed: bool) -> None: """Increment the counter for successful reformatting. Write out a message.""" @@ -1344,7 +1460,15 @@ class Report: @property def return_code(self) -> int: """Which return code should the app use considering the current state.""" - return 1 if self.failure_count else 0 + # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with + # 126 we have special returncodes reserved by the shell. + if self.failure_count: + return 123 + + elif self.change_count: + return 1 + + return 0 def __str__(self) -> str: """A color report of the current state. @@ -1416,7 +1540,7 @@ def assert_equivalent(src: str, dst: str) -> None: raise AssertionError( f"INTERNAL ERROR: Black produced invalid code: {exc}. " f"Please report a bug on https://github.com/ambv/black/issues. " - f"This invalid output might be helpful: {log}", + f"This invalid output might be helpful: {log}" ) from None src_ast_str = '\n'.join(_v(src_ast)) @@ -1427,7 +1551,7 @@ def assert_equivalent(src: str, dst: str) -> None: f"INTERNAL ERROR: Black produced code that is not equivalent to " f"the source. " f"Please report a bug on https://github.com/ambv/black/issues. " - f"This diff might be helpful: {log}", + f"This diff might be helpful: {log}" ) from None @@ -1446,7 +1570,7 @@ def assert_stable(src: str, dst: str, line_length: int) -> None: f"INTERNAL ERROR: Black produced different code on the second pass " f"of the formatter. " f"Please report a bug on https://github.com/ambv/black/issues. " - f"This diff might be helpful: {log}", + f"This diff might be helpful: {log}" ) from None