X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/c26daa4fd50e2db13f38279ce27b51b0ae4479fe..5fb5cc8c2bd5a0bb1359fb69cdb705b55afade52:/black.py

diff --git a/black.py b/black.py
index 774d91d..0a9d3ea 100644
--- a/black.py
+++ b/black.py
@@ -7,6 +7,7 @@ import keyword
 import os
 from pathlib import Path
 import tokenize
+import sys
 from typing import (
     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
 )
@@ -192,6 +193,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
     comments: List[Line] = []
     lines = LineGenerator()
     elt = EmptyLineTracker()
+    py36 = is_python36(src_node)
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
@@ -204,7 +206,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
             for comment in comments:
                 dst_contents += str(comment)
             comments = []
-            for line in split_line(current_line, line_length=line_length):
+            for line in split_line(current_line, line_length=line_length, py36=py36):
                 dst_contents += str(line)
         else:
             comments.append(current_line)
@@ -1108,13 +1110,18 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
             yield Leaf(STANDALONE_COMMENT, line)
 
 
-def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]:
+def split_line(
+    line: Line, line_length: int, inner: bool = False, py36: bool = False
+) -> Iterator[Line]:
     """Splits a `line` into potentially many lines.
 
     They should fit in the allotted `line_length` but might not be able to.
     `inner` signifies that there were a pair of brackets somewhere around the
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
+
+    If `py36` is True, splitting may generate syntax that is only compatible
+    with Python 3.6 and later.
     """
     line_str = str(line).strip('\n')
     if len(line_str) <= line_length and '\n' not in line_str:
@@ -1137,11 +1144,13 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li
         # split altogether.
         result: List[Line] = []
         try:
-            for l in split_func(line):
+            for l in split_func(line, py36=py36):
                 if str(l).strip('\n') == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
-                result.extend(split_line(l, line_length=line_length, inner=True))
+                result.extend(
+                    split_line(l, line_length=line_length, inner=True, py36=py36)
+                )
         except CannotSplit as cs:
             continue
 
@@ -1153,7 +1162,7 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li
         yield line
 
 
-def left_hand_split(line: Line) -> Iterator[Line]:
+def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
@@ -1208,7 +1217,7 @@ def left_hand_split(line: Line) -> Iterator[Line]:
             yield result
 
 
-def right_hand_split(line: Line) -> Iterator[Line]:
+def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair."""
     head = Line(depth=line.depth)
     body = Line(depth=line.depth + 1, inside_brackets=True)
@@ -1259,10 +1268,12 @@ def right_hand_split(line: Line) -> Iterator[Line]:
             yield result
 
 
-def delimiter_split(line: Line) -> Iterator[Line]:
+def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
     """Split according to delimiters of the highest priority.
 
     This kind of split doesn't increase indentation.
+    If `py36` is True, the split will add trailing commas also in function
+    signatures that contain * and **.
     """
     try:
         last_leaf = line.leaves[-1]
@@ -1276,11 +1287,20 @@ def delimiter_split(line: Line) -> Iterator[Line]:
         raise CannotSplit("No delimiters found")
 
     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+    lowest_depth = sys.maxsize
+    trailing_comma_safe = True
     for leaf in line.leaves:
         current_line.append(leaf, preformatted=True)
         comment_after = line.comments.get(id(leaf))
         if comment_after:
             current_line.append(comment_after, preformatted=True)
+        lowest_depth = min(lowest_depth, leaf.bracket_depth)
+        # The depth check must guard both * and **, hence the membership test.
+        if (
+            leaf.bracket_depth == lowest_depth and  # type: ignore
+            leaf.type in {token.STAR, token.DOUBLESTAR}
+        ):
+            trailing_comma_safe = trailing_comma_safe and py36
         leaf_priority = delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             normalize_prefix(current_line.leaves[0])
@@ -1290,7 +1310,8 @@ def delimiter_split(line: Line) -> Iterator[Line]:
     if current_line:
         if (
             delimiter_priority == COMMA_PRIORITY and
-            current_line.leaves[-1].type != token.COMMA
+            current_line.leaves[-1].type != token.COMMA and
+            trailing_comma_safe
         ):
             current_line.append(Leaf(token.COMMA, ','))
         normalize_prefix(current_line.leaves[0])
@@ -1325,6 +1346,31 @@ def normalize_prefix(leaf: Leaf) -> None:
     leaf.prefix = ''
 
 
+def is_python36(node: Node) -> bool:
+    """Returns True if the current file is using Python 3.6+ features.
+
+    Currently looking for:
+    - f-strings; and
+    - trailing commas after * or ** in function signatures.
+    """
+    for n in node.pre_order():
+        if n.type == token.STRING:
+            assert isinstance(n, Leaf)
+            if n.value[:2].lower() in {'f"', "f'", 'rf', 'fr'}:
+                return True
+
+        elif (
+            n.type == syms.typedargslist and
+            n.children and
+            n.children[-1].type == token.COMMA
+        ):
+            for ch in n.children:
+                if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
+                    return True
+
+    return False
+
+
 PYTHON_EXTENSIONS = {'.py'}
 BLACKLISTED_DIRECTORIES = {
     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'