X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/174fc47b7893b98d4a4393da849c09b6bd415764..907dc6c35e23a0111b140962265d08c23f97082b:/black.py?ds=inline diff --git a/black.py b/black.py index 64df7f7..51b5b21 100644 --- a/black.py +++ b/black.py @@ -7,6 +7,7 @@ import keyword import os from pathlib import Path import tokenize +import sys from typing import ( Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union ) @@ -55,6 +56,15 @@ class CannotSplit(Exception): help='How many character per line to allow.', show_default=True, ) +@click.option( + '--check', + is_flag=True, + help=( + "Don't write back the files, just return the status. Return code 0 " + "means nothing changed. Return code 1 means some files were " + "reformatted. Return code 123 means there was an internal error." + ), +) @click.option( '--fast/--safe', is_flag=True, @@ -67,7 +77,9 @@ class CannotSplit(Exception): type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True), ) @click.pass_context -def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> None: +def main( + ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str] +) -> None: """The uncompromising code formatter.""" sources: List[Path] = [] for s in src: @@ -85,7 +97,9 @@ def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> No p = sources[0] report = Report() try: - changed = format_file_in_place(p, line_length=line_length, fast=fast) + changed = format_file_in_place( + p, line_length=line_length, fast=fast, write_back=not check + ) report.done(p, changed) except Exception as exc: report.failed(p, str(exc)) @@ -96,7 +110,9 @@ def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> No return_code = 1 try: return_code = loop.run_until_complete( - schedule_formatting(sources, line_length, fast, loop, executor) + schedule_formatting( + sources, line_length, not check, fast, loop, executor + ) ) finally: loop.close() @@ -106,13 +122,14 @@ def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> No async def schedule_formatting( sources: List[Path], line_length: int, + write_back: bool, fast: bool, loop: BaseEventLoop, executor: Executor, ) -> int: tasks = { src: loop.run_in_executor( - executor, format_file_in_place, src, line_length, fast + executor, format_file_in_place, src, line_length, fast, write_back ) for src in sources } @@ -135,15 +152,18 @@ async def schedule_formatting( return report.return_code -def format_file_in_place(src: Path, line_length: int, fast: bool) -> bool: +def format_file_in_place( + src: Path, line_length: int, fast: bool, write_back: bool = False +) -> bool: """Format the file and rewrite if changed. Return True if changed.""" try: contents, encoding = format_file(src, line_length=line_length, fast=fast) except NothingChanged: return False - with open(src, "w", encoding=encoding) as f: - f.write(contents) + if write_back: + with open(src, "w", encoding=encoding) as f: + f.write(contents) return True @@ -173,6 +193,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent: comments: List[Line] = [] lines = LineGenerator() elt = EmptyLineTracker() + py36 = is_python36(src_node) empty_line = Line() after = 0 for current_line in lines.visit(src_node): @@ -185,7 +206,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent: for comment in comments: dst_contents += str(comment) comments = [] - for line in split_line(current_line, line_length=line_length): + for line in split_line(current_line, line_length=line_length, py36=py36): dst_contents += str(line) else: comments.append(current_line) @@ -326,8 +347,8 @@ class BracketTracker: if leaf.type in CLOSING_BRACKETS: self.depth -= 1 opening_bracket = self.bracket_match.pop((self.depth, leaf.type)) - leaf.opening_bracket = opening_bracket # type: ignore - leaf.bracket_depth = self.depth # type: ignore + leaf.opening_bracket = opening_bracket + leaf.bracket_depth = self.depth if self.depth == 0: delim = is_delimiter(leaf) if delim: @@ -358,7 +379,7 @@ class BracketTracker: """Returns True if there is an yet unmatched open bracket on the line.""" return bool(self.bracket_match) - def max_priority(self, exclude: Iterable[LeafID] = ()) -> int: + def max_priority(self, exclude: Iterable[LeafID] =()) -> int: """Returns the highest priority of a delimiter found on the line. Values are consistent with what `is_delimiter()` returns. @@ -373,6 +394,8 @@ class Line: comments: Dict[LeafID, Leaf] = attrib(default=Factory(dict)) bracket_tracker: BracketTracker = attrib(default=Factory(BracketTracker)) inside_brackets: bool = attrib(default=False) + has_for: bool = attrib(default=False) + _for_loop_variable: bool = attrib(default=False, init=False) def append(self, leaf: Leaf, preformatted: bool = False) -> None: has_value = leaf.value.strip() @@ -384,8 +407,10 @@ class Line: # imports, for which we only preserve newlines. leaf.prefix += whitespace(leaf) if self.inside_brackets or not preformatted: + self.maybe_decrement_after_for_loop_variable(leaf) self.bracket_tracker.mark(leaf) self.maybe_remove_trailing_comma(leaf) + self.maybe_increment_for_loop_variable(leaf) if self.maybe_adapt_standalone_comment(leaf): return @@ -466,9 +491,9 @@ class Line: # For parens let's check if it's safe to remove the comma. If the # trailing one is the only one, we might mistakenly change a tuple # into a different type by removing the comma. - depth = closing.bracket_depth + 1 # type: ignore + depth = closing.bracket_depth + 1 commas = 0 - opening = closing.opening_bracket # type: ignore + opening = closing.opening_bracket for _opening_index, leaf in enumerate(self.leaves): if leaf is opening: break @@ -480,7 +505,7 @@ class Line: if leaf is closing: break - bracket_depth = leaf.bracket_depth # type: ignore + bracket_depth = leaf.bracket_depth if bracket_depth == depth and leaf.type == token.COMMA: commas += 1 if commas > 1: @@ -489,6 +514,29 @@ class Line: return False + def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: + """In a for loop, or comprehension, the variables are often unpacks. + + To avoid splitting on the comma in this situation, we will increase + the depth of tokens between `for` and `in`. + """ + if leaf.type == token.NAME and leaf.value == 'for': + self.has_for = True + self.bracket_tracker.depth += 1 + self._for_loop_variable = True + return True + + return False + + def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool: + # See `maybe_increment_for_loop_variable` above for explanation. + if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in': + self.bracket_tracker.depth -= 1 + self._for_loop_variable = False + return True + + return False + def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool: """Hack a standalone comment to act as a trailing comment for line splitting. @@ -781,13 +829,51 @@ def whitespace(leaf: Leaf) -> str: if t == STANDALONE_COMMENT: return NO + if t in CLOSING_BRACKETS: + return NO + assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}" + prev = leaf.prev_sibling + if not prev: + prevp = preceding_leaf(p) + if not prevp or prevp.type in OPENING_BRACKETS: + return NO + + if prevp.type == token.EQUAL: + if prevp.parent and prevp.parent.type in { + syms.typedargslist, + syms.varargslist, + syms.parameters, + syms.arglist, + syms.argument, + }: + return NO + + elif prevp.type == token.DOUBLESTAR: + if prevp.parent and prevp.parent.type in { + syms.typedargslist, + syms.varargslist, + syms.parameters, + syms.arglist, + syms.dictsetmaker, + }: + return NO + + elif prevp.type == token.COLON: + if prevp.parent and prevp.parent.type == syms.subscript: + return NO + + elif prevp.parent and prevp.parent.type == syms.factor: + return NO + + elif prev.type in OPENING_BRACKETS: + return NO + if p.type in {syms.parameters, syms.arglist}: # untyped function signatures or calls if t == token.RPAR: return NO - prev = leaf.prev_sibling if not prev or prev.type != token.COMMA: return NO @@ -796,13 +882,11 @@ def whitespace(leaf: Leaf) -> str: if t == token.RPAR: return NO - prev = leaf.prev_sibling if prev and prev.type != token.COMMA: return NO elif p.type == syms.typedargslist: # typed function signatures - prev = leaf.prev_sibling if not prev: return NO @@ -820,7 +904,6 @@ def whitespace(leaf: Leaf) -> str: elif p.type == syms.tname: # type names - prev = leaf.prev_sibling if not prev: prevp = preceding_leaf(p) if not prevp or prevp.type != token.COMMA: @@ -831,7 +914,6 @@ def whitespace(leaf: Leaf) -> str: if t == token.LPAR or t == token.RPAR: return NO - prev = leaf.prev_sibling if not prev: if t == token.DOT: prevp = preceding_leaf(p) @@ -849,7 +931,6 @@ def whitespace(leaf: Leaf) -> str: if t == token.EQUAL: return NO - prev = leaf.prev_sibling if not prev: prevp = preceding_leaf(p) if not prevp or prevp.type == token.LPAR: @@ -863,7 +944,6 @@ def whitespace(leaf: Leaf) -> str: return NO elif p.type == syms.dotted_name: - prev = leaf.prev_sibling if prev: return NO @@ -875,77 +955,16 @@ def whitespace(leaf: Leaf) -> str: if t == token.LPAR: return NO - prev = leaf.prev_sibling if prev and prev.type == token.LPAR: return NO elif p.type == syms.subscript: # indexing - if t == token.COLON: - return NO - - prev = leaf.prev_sibling if not prev or prev.type == token.COLON: return NO - elif p.type in { - syms.test, - syms.not_test, - syms.xor_expr, - syms.or_test, - syms.and_test, - syms.arith_expr, - syms.expr, - syms.shift_expr, - syms.yield_expr, - syms.term, - syms.power, - syms.comparison, - }: - # various arithmetic and logic expressions - prev = leaf.prev_sibling - if not prev: - prevp = preceding_leaf(p) - if not prevp or prevp.type in OPENING_BRACKETS: - return NO - - if prevp.type == token.EQUAL: - if prevp.parent and prevp.parent.type in { - syms.varargslist, syms.parameters, syms.arglist, syms.argument - }: - return NO - - return SPACE - elif p.type == syms.atom: - if t in CLOSING_BRACKETS: - return NO - - prev = leaf.prev_sibling - if not prev: - prevp = preceding_leaf(p) - if not prevp: - return NO - - if prevp.type in OPENING_BRACKETS: - return NO - - if prevp.type == token.EQUAL: - if prevp.parent and prevp.parent.type in { - syms.varargslist, syms.parameters, syms.arglist, syms.argument - }: - return NO - - if prevp.type == token.DOUBLESTAR: - if prevp.parent and prevp.parent.type in { - syms.varargslist, syms.parameters, syms.arglist, syms.dictsetmaker - }: - return NO - - elif prev.type in OPENING_BRACKETS: - return NO - - elif t == token.DOT: + if prev and t == token.DOT: # dots, but not the first one. return NO @@ -955,13 +974,11 @@ def whitespace(leaf: Leaf) -> str: p.type == syms.subscriptlist ): # list interior, including unpacking - prev = leaf.prev_sibling if not prev: return NO elif p.type == syms.dictsetmaker: # dict and set interior, including unpacking - prev = leaf.prev_sibling if not prev: return NO @@ -970,7 +987,6 @@ def whitespace(leaf: Leaf) -> str: elif p.type == syms.factor or p.type == syms.star_expr: # unary ops - prev = leaf.prev_sibling if not prev: prevp = preceding_leaf(p) if not prevp or prevp.type in OPENING_BRACKETS: @@ -991,7 +1007,6 @@ def whitespace(leaf: Leaf) -> str: elif p.type == syms.import_from: if t == token.DOT: - prev = leaf.prev_sibling if prev and prev.type == token.DOT: return NO @@ -999,7 +1014,6 @@ def whitespace(leaf: Leaf) -> str: if v == 'import': return SPACE - prev = leaf.prev_sibling if prev and prev.type == token.DOT: return NO @@ -1096,13 +1110,18 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]: yield Leaf(STANDALONE_COMMENT, line) -def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]: +def split_line( + line: Line, line_length: int, inner: bool = False, py36: bool = False +) -> Iterator[Line]: """Splits a `line` into potentially many lines. They should fit in the allotted `line_length` but might not be able to. `inner` signifies that there were a pair of brackets somewhere around the current `line`, possibly transitively. This means we can fallback to splitting by delimiters if the LHS/RHS don't yield any results. + + If `py36` is True, splitting may generate syntax that is only compatible + with Python 3.6 and later. """ line_str = str(line).strip('\n') if len(line_str) <= line_length and '\n' not in line_str: @@ -1125,11 +1144,13 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li # split altogether. result: List[Line] = [] try: - for l in split_func(line): + for l in split_func(line, py36=py36): if str(l).strip('\n') == line_str: raise CannotSplit("Split function returned an unchanged result") - result.extend(split_line(l, line_length=line_length, inner=True)) + result.extend( + split_line(l, line_length=line_length, inner=True, py36=py36) + ) except CannotSplit as cs: continue @@ -1141,7 +1162,7 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li yield line -def left_hand_split(line: Line) -> Iterator[Line]: +def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split line into many lines, starting with the first matching bracket pair. Note: this usually looks weird, only use this for function definitions. @@ -1159,7 +1180,7 @@ def left_hand_split(line: Line) -> Iterator[Line]: if ( current_leaves is body_leaves and leaf.type in CLOSING_BRACKETS and - leaf.opening_bracket is matching_bracket # type: ignore + leaf.opening_bracket is matching_bracket ): current_leaves = tail_leaves current_leaves.append(leaf) @@ -1196,7 +1217,7 @@ def left_hand_split(line: Line) -> Iterator[Line]: yield result -def right_hand_split(line: Line) -> Iterator[Line]: +def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split line into many lines, starting with the last matching bracket pair.""" head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) @@ -1213,7 +1234,7 @@ def right_hand_split(line: Line) -> Iterator[Line]: current_leaves.append(leaf) if current_leaves is tail_leaves: if leaf.type in CLOSING_BRACKETS: - opening_bracket = leaf.opening_bracket # type: ignore + opening_bracket = leaf.opening_bracket current_leaves = body_leaves tail_leaves.reverse() body_leaves.reverse() @@ -1247,10 +1268,12 @@ def right_hand_split(line: Line) -> Iterator[Line]: yield result -def delimiter_split(line: Line) -> Iterator[Line]: +def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split according to delimiters of the highest priority. This kind of split doesn't increase indentation. + If `py36` is True, the split will add trailing commas also in function + signatures that contain * and **. """ try: last_leaf = line.leaves[-1] @@ -1264,11 +1287,20 @@ def delimiter_split(line: Line) -> Iterator[Line]: raise CannotSplit("No delimiters found") current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + lowest_depth = sys.maxsize + trailing_comma_safe = True for leaf in line.leaves: current_line.append(leaf, preformatted=True) comment_after = line.comments.get(id(leaf)) if comment_after: current_line.append(comment_after, preformatted=True) + lowest_depth = min(lowest_depth, leaf.bracket_depth) + if ( + leaf.bracket_depth == lowest_depth and + leaf.type == token.STAR or + leaf.type == token.DOUBLESTAR + ): + trailing_comma_safe = trailing_comma_safe and py36 leaf_priority = delimiters.get(id(leaf)) if leaf_priority == delimiter_priority: normalize_prefix(current_line.leaves[0]) @@ -1278,7 +1310,8 @@ def delimiter_split(line: Line) -> Iterator[Line]: if current_line: if ( delimiter_priority == COMMA_PRIORITY and - current_line.leaves[-1].type != token.COMMA + current_line.leaves[-1].type != token.COMMA and + trailing_comma_safe ): current_line.append(Leaf(token.COMMA, ',')) normalize_prefix(current_line.leaves[0]) @@ -1313,6 +1346,31 @@ def normalize_prefix(leaf: Leaf) -> None: leaf.prefix = '' +def is_python36(node: Node) -> bool: + """Returns True if the current file is using Python 3.6+ features. + + Currently looking for: + - f-strings; and + - trailing commas after * or ** in function signatures. + """ + for n in node.pre_order(): + if n.type == token.STRING: + value_head = n.value[:2] # type: ignore + if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}: + return True + + elif ( + n.type == syms.typedargslist and + n.children and + n.children[-1].type == token.COMMA + ): + for ch in n.children: + if ch.type == token.STAR or ch.type == token.DOUBLESTAR: + return True + + return False + + PYTHON_EXTENSIONS = {'.py'} BLACKLISTED_DIRECTORIES = { 'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv' @@ -1355,7 +1413,15 @@ class Report: @property def return_code(self) -> int: """Which return code should the app use considering the current state.""" - return 1 if self.failure_count else 0 + # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with + # 126 we have special returncodes reserved by the shell. + if self.failure_count: + return 123 + + elif self.change_count: + return 1 + + return 0 def __str__(self) -> str: """A color report of the current state.