X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/1747c388bba0c87f75a6239d56e3b51f7455e93d..f24635635eef3a998ea02bfcbc663d3dbe129851:/black.py?ds=inline diff --git a/black.py b/black.py index 151dc8c..f335ebd 100644 --- a/black.py +++ b/black.py @@ -24,6 +24,7 @@ from typing import ( List, Optional, Pattern, + Sequence, Set, Tuple, Type, @@ -41,6 +42,7 @@ from blib2to3 import pygram, pytree from blib2to3.pgen2 import driver, token from blib2to3.pgen2.parse import ParseError + __version__ = "18.4a6" DEFAULT_LINE_LENGTH = 88 @@ -237,11 +239,8 @@ def reformat_one( src = src.resolve() if src in cache and cache[src] == get_cache_info(src): changed = Changed.CACHED - if ( - changed is not Changed.CACHED - and format_file_in_place( - src, line_length=line_length, fast=fast, write_back=write_back - ) + if changed is not Changed.CACHED and format_file_in_place( + src, line_length=line_length, fast=fast, write_back=write_back ): changed = Changed.YES if write_back == WriteBack.YES and changed is not Changed.NO: @@ -283,32 +282,29 @@ async def schedule_formatting( manager = Manager() lock = manager.Lock() tasks = { - src: loop.run_in_executor( + loop.run_in_executor( executor, format_file_in_place, src, line_length, fast, write_back, lock - ) - for src in sources + ): src + for src in sorted(sources) } - _task_values = list(tasks.values()) + pending: Iterable[asyncio.Task] = tasks.keys() try: - loop.add_signal_handler(signal.SIGINT, cancel, _task_values) - loop.add_signal_handler(signal.SIGTERM, cancel, _task_values) + loop.add_signal_handler(signal.SIGINT, cancel, pending) + loop.add_signal_handler(signal.SIGTERM, cancel, pending) except NotImplementedError: # There are no good alternatives for these on Windows pass - await asyncio.wait(_task_values) - for src, task in tasks.items(): - if not task.done(): - report.failed(src, "timed out, cancelling") - task.cancel() - cancelled.append(task) - elif task.cancelled(): - cancelled.append(task) - elif task.exception(): - report.failed(src, str(task.exception())) - else: - formatted.append(src) - report.done(src, Changed.YES if task.result() else Changed.NO) - + while pending: + done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + src = tasks.pop(task) + if task.cancelled(): + cancelled.append(task) + elif task.exception(): + report.failed(src, str(task.exception())) + else: + formatted.append(src) + report.done(src, Changed.YES if task.result() else Changed.NO) if cancelled: await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) if write_back == WriteBack.YES and formatted: @@ -327,12 +323,13 @@ def format_file_in_place( If `write_back` is True, write reformatted code back to stdout. `line_length` and `fast` options are passed to :func:`format_file_contents`. """ + is_pyi = src.suffix == ".pyi" with tokenize.open(src) as src_buffer: src_contents = src_buffer.read() try: dst_contents = format_file_contents( - src_contents, line_length=line_length, fast=fast + src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi ) except NothingChanged: return False @@ -381,7 +378,7 @@ def format_stdin_to_stdout( def format_file_contents( - src_contents: str, line_length: int, fast: bool + src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False ) -> FileContent: """Reformat contents a file and return new contents. @@ -392,26 +389,33 @@ def format_file_contents( if src_contents.strip() == "": raise NothingChanged - dst_contents = format_str(src_contents, line_length=line_length) + dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi) if src_contents == dst_contents: raise NothingChanged if not fast: assert_equivalent(src_contents, dst_contents) - assert_stable(src_contents, dst_contents, line_length=line_length) + assert_stable( + src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi + ) return dst_contents -def format_str(src_contents: str, line_length: int) -> FileContent: +def format_str( + src_contents: str, line_length: int, *, is_pyi: bool = False +) -> FileContent: """Reformat a string and return new contents. `line_length` determines how many characters per line are allowed. """ src_node = lib2to3_parse(src_contents) dst_contents = "" - lines = LineGenerator() - elt = EmptyLineTracker() + future_imports = get_future_imports(src_node) + elt = EmptyLineTracker(is_pyi=is_pyi) py36 = is_python36(src_node) + lines = LineGenerator( + remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi + ) empty_line = Line() after = 0 for current_line in lines.visit(src_node): @@ -553,19 +557,20 @@ COMPARATORS = { token.GREATEREQUAL, } MATH_OPERATORS = { + token.VBAR, + token.CIRCUMFLEX, + token.AMPER, + token.LEFTSHIFT, + token.RIGHTSHIFT, token.PLUS, token.MINUS, token.STAR, token.SLASH, - token.VBAR, - token.AMPER, + token.DOUBLESLASH, token.PERCENT, - token.CIRCUMFLEX, + token.AT, token.TILDE, - token.LEFTSHIFT, - token.RIGHTSHIFT, token.DOUBLESTAR, - token.DOUBLESLASH, } STARS = {token.STAR, token.DOUBLESTAR} VARARGS_PARENTS = { @@ -598,13 +603,44 @@ TEST_DESCENDANTS = { syms.term, syms.power, } +ASSIGNMENTS = { + "=", + "+=", + "-=", + "*=", + "@=", + "/=", + "%=", + "&=", + "|=", + "^=", + "<<=", + ">>=", + "**=", + "//=", +} COMPREHENSION_PRIORITY = 20 -COMMA_PRIORITY = 10 -TERNARY_PRIORITY = 7 -LOGIC_PRIORITY = 5 -STRING_PRIORITY = 4 -COMPARATOR_PRIORITY = 3 -MATH_PRIORITY = 1 +COMMA_PRIORITY = 18 +TERNARY_PRIORITY = 16 +LOGIC_PRIORITY = 14 +STRING_PRIORITY = 12 +COMPARATOR_PRIORITY = 10 +MATH_PRIORITIES = { + token.VBAR: 8, + token.CIRCUMFLEX: 7, + token.AMPER: 6, + token.LEFTSHIFT: 5, + token.RIGHTSHIFT: 5, + token.PLUS: 4, + token.MINUS: 4, + token.STAR: 3, + token.SLASH: 3, + token.DOUBLESLASH: 3, + token.PERCENT: 3, + token.AT: 3, + token.TILDE: 2, + token.DOUBLESTAR: 1, +} @dataclass @@ -615,8 +651,8 @@ class BracketTracker: bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict) delimiters: Dict[LeafID, Priority] = Factory(dict) previous: Optional[Leaf] = None - _for_loop_variable: bool = False - _lambda_arguments: bool = False + _for_loop_variable: int = 0 + _lambda_arguments: int = 0 def mark(self, leaf: Leaf) -> None: """Mark `leaf` with bracket-related metadata. Keep track of delimiters. @@ -670,6 +706,17 @@ class BracketTracker: """ return max(v for k, v in self.delimiters.items() if k not in exclude) + def delimiter_count_with_priority(self, priority: int = 0) -> int: + """Return the number of delimiters with the given `priority`. + + If no `priority` is passed, defaults to max priority on the line. + """ + if not self.delimiters: + return 0 + + priority = priority or self.max_delimiter_priority() + return sum(1 for p in self.delimiters.values() if p == priority) + def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: """In a for loop, or comprehension, the variables are often unpacks. @@ -678,7 +725,7 @@ class BracketTracker: """ if leaf.type == token.NAME and leaf.value == "for": self.depth += 1 - self._for_loop_variable = True + self._for_loop_variable += 1 return True return False @@ -687,7 +734,7 @@ class BracketTracker: """See `maybe_increment_for_loop_variable` above for explanation.""" if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in": self.depth -= 1 - self._for_loop_variable = False + self._for_loop_variable -= 1 return True return False @@ -700,7 +747,7 @@ class BracketTracker: """ if leaf.type == token.NAME and leaf.value == "lambda": self.depth += 1 - self._lambda_arguments = True + self._lambda_arguments += 1 return True return False @@ -709,7 +756,7 @@ class BracketTracker: """See `maybe_increment_lambda_arguments` above for explanation.""" if self._lambda_arguments and leaf.type == token.COLON: self.depth -= 1 - self._lambda_arguments = False + self._lambda_arguments -= 1 return True return False @@ -798,6 +845,13 @@ class Line: and self.leaves[0].value == "class" ) + @property + def is_stub_class(self) -> bool: + """Is this line a class definition with a body consisting only of "..."?""" + return self.is_class and self.leaves[-3:] == [ + Leaf(token.DOT, ".") for _ in range(3) + ] + @property def is_def(self) -> bool: """Is this a function definition? (Also returns True for async defs.)""" @@ -810,14 +864,11 @@ class Line: second_leaf: Optional[Leaf] = self.leaves[1] except IndexError: second_leaf = None - return ( - (first_leaf.type == token.NAME and first_leaf.value == "def") - or ( - first_leaf.type == token.ASYNC - and second_leaf is not None - and second_leaf.type == token.NAME - and second_leaf.value == "def" - ) + return (first_leaf.type == token.NAME and first_leaf.value == "def") or ( + first_leaf.type == token.ASYNC + and second_leaf is not None + and second_leaf.type == token.NAME + and second_leaf.value == "def" ) @property @@ -942,17 +993,21 @@ class Line: self.comments.append((after, comment)) return True - def comments_after(self, leaf: Leaf) -> Iterator[Leaf]: - """Generate comments that should appear directly after `leaf`.""" - for _leaf_index, _leaf in enumerate(self.leaves): - if leaf is _leaf: - break + def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]: + """Generate comments that should appear directly after `leaf`. - else: - return + Provide a non-negative leaf `_index` to speed up the function. + """ + if _index == -1: + for _index, _leaf in enumerate(self.leaves): + if leaf is _leaf: + break + + else: + return for index, comment_after in self.comments: - if _leaf_index == index: + if _index == index: yield comment_after def remove_trailing_comma(self) -> None: @@ -1060,6 +1115,7 @@ class EmptyLineTracker: the prefix of the first leaf consists of optional newlines. Those newlines are consumed by `maybe_empty_lines()` and included in the computation. """ + is_pyi: bool = False previous_line: Optional[Line] = None previous_after: int = 0 previous_defs: List[int] = Factory(list) @@ -1083,7 +1139,7 @@ class EmptyLineTracker: def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: max_allowed = 1 if current_line.depth == 0: - max_allowed = 2 + max_allowed = 1 if self.is_pyi else 2 if current_line.leaves: # Consume the first leaf's extra newlines. first_leaf = current_line.leaves[0] @@ -1095,7 +1151,10 @@ class EmptyLineTracker: depth = current_line.depth while self.previous_defs and self.previous_defs[-1] >= depth: self.previous_defs.pop() - before = 1 if depth else 2 + if self.is_pyi: + before = 0 if depth else 1 + else: + before = 1 if depth else 2 is_decorator = current_line.is_decorator if is_decorator or current_line.is_def or current_line.is_class: if not is_decorator: @@ -1114,8 +1173,19 @@ class EmptyLineTracker: ): return 0, 0 - newlines = 2 - if current_line.depth: + if self.is_pyi: + if self.previous_line.depth > current_line.depth: + newlines = 1 + elif current_line.is_class or self.previous_line.is_class: + if current_line.is_stub_class and self.previous_line.is_stub_class: + newlines = 0 + else: + newlines = 1 + else: + newlines = 0 + else: + newlines = 2 + if current_line.depth and newlines: newlines -= 1 return newlines, 0 @@ -1137,7 +1207,9 @@ class LineGenerator(Visitor[Line]): Note: destroys the tree it's visiting by mutating prefixes of its leaves in ways that will no longer stringify to valid Python code on the tree. """ + is_pyi: bool = False current_line: Line = Factory(Line) + remove_u_prefix: bool = False def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]: """Generate a line. @@ -1205,6 +1277,7 @@ class LineGenerator(Visitor[Line]): else: normalize_prefix(node, inside_brackets=any_open_brackets) if node.type == token.STRING: + normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix) normalize_string_quotes(node) if node.type not in WHITESPACE: self.current_line.append(node) @@ -1236,14 +1309,13 @@ class LineGenerator(Visitor[Line]): """Visit a statement. This implementation is shared for `if`, `while`, `for`, `try`, `except`, - `def`, `with`, `class`, and `assert`. + `def`, `with`, `class`, `assert` and assignments. The relevant Python language `keywords` for a given statement will be NAME leaves within it. This methods puts those on a separate line. - `parens` holds pairs of nodes where invisible parentheses should be put. - Keys hold nodes after which opening parentheses should be put, values - hold nodes before which closing parentheses should be put. + `parens` holds a set of string leaf values immeditely after which + invisible parens should be put. """ normalize_invisible_parens(node, parens_after=parens) for child in node.children: @@ -1252,16 +1324,27 @@ class LineGenerator(Visitor[Line]): yield from self.visit(child) + def visit_suite(self, node: Node) -> Iterator[Line]: + """Visit a suite.""" + if self.is_pyi and is_stub_suite(node): + yield from self.visit(node.children[2]) + else: + yield from self.visit_default(node) + def visit_simple_stmt(self, node: Node) -> Iterator[Line]: """Visit a statement without nested statements.""" is_suite_like = node.parent and node.parent.type in STATEMENT if is_suite_like: - yield from self.line(+1) - yield from self.visit_default(node) - yield from self.line(-1) + if self.is_pyi and is_stub_body(node): + yield from self.visit_default(node) + else: + yield from self.line(+1) + yield from self.visit_default(node) + yield from self.line(-1) else: - yield from self.line() + if not self.is_pyi or not node.parent or not is_stub_suite(node.parent): + yield from self.line() yield from self.visit_default(node) def visit_async_stmt(self, node: Node) -> Iterator[Line]: @@ -1343,7 +1426,9 @@ class LineGenerator(Visitor[Line]): v = self.visit_stmt Ø: Set[str] = set() self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","}) - self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"}) + self.visit_if_stmt = partial( + v, keywords={"if", "else", "elif"}, parens={"if", "elif"} + ) self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"}) self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"}) self.visit_try_stmt = partial( @@ -1353,6 +1438,8 @@ class LineGenerator(Visitor[Line]): self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø) self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø) self.visit_classdef = partial(v, keywords={"class"}, parens=Ø) + self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS) + self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"}) self.visit_async_funcdef = self.visit_async_stmt self.visit_decorated = self.visit_decorators @@ -1384,10 +1471,9 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa C901 return DOUBLESPACE assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}" - if ( - t == token.COLON - and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop} - ): + if t == token.COLON and p.type not in { + syms.subscript, syms.subscriptlist, syms.sliceop + }: return NO prev = leaf.prev_sibling @@ -1561,10 +1647,9 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa C901 prevp_parent = prevp.parent assert prevp_parent is not None - if ( - prevp.type == token.COLON - and prevp_parent.type in {syms.subscript, syms.sliceop} - ): + if prevp.type == token.COLON and prevp_parent.type in { + syms.subscript, syms.sliceop + }: return NO elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument: @@ -1649,7 +1734,7 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int: and leaf.parent and leaf.parent.type not in {syms.factor, syms.star_expr} ): - return MATH_PRIORITY + return MATH_PRIORITIES[leaf.type] if leaf.type in COMPARATORS: return COMPARATOR_PRIORITY @@ -1785,11 +1870,7 @@ def split_line( return line_str = str(line).strip("\n") - if ( - len(line_str) <= line_length - and "\n" not in line_str # multiline strings - and not line.contains_standalone_comments() - ): + if is_line_short_enough(line, line_length=line_length, line_str=line_str): yield line return @@ -1798,10 +1879,22 @@ def split_line( split_funcs = [left_hand_split] elif line.is_import: split_funcs = [explode_split] - elif line.inside_brackets: - split_funcs = [delimiter_split, standalone_comment_split, right_hand_split] else: - split_funcs = [right_hand_split] + + def rhs(line: Line, py36: bool = False) -> Iterator[Line]: + for omit in generate_trailers_to_omit(line, line_length): + lines = list(right_hand_split(line, py36, omit=omit)) + if is_line_short_enough(lines[0], line_length=line_length): + yield from lines + return + + # All splits failed, best effort split with no omits. + yield from right_hand_split(line, py36) + + if line.inside_brackets: + split_funcs = [delimiter_split, standalone_comment_split, rhs] + else: + split_funcs = [rhs] for split_func in split_funcs: # We are accumulating lines in `result` because we might want to abort # mission and return the original line in the end, or attempt a different @@ -1874,6 +1967,10 @@ def right_hand_split( """Split line into many lines, starting with the last matching bracket pair. If the split was by optional parentheses, attempt splitting without them, too. + `omit` is a collection of closing bracket IDs that shouldn't be considered for + this split. + + Note: running this function modifies `bracket_depth` on the leaves of `line`. """ head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) @@ -1900,8 +1997,9 @@ def right_hand_split( # Since body is a new indent level, remove spurious leading whitespace. if body_leaves: normalize_prefix(body_leaves[0], inside_brackets=True) - elif not head_leaves: - # No `head` and no `body` means the split failed. `tail` has all content. + if not head_leaves: + # No `head` means the split failed. Either `tail` has all content or + # the matching `opening_bracket` wasn't available on `line` anymore. raise CannotSplit("No brackets found") # Build the new lines. @@ -1919,19 +2017,27 @@ def right_hand_split( # the closing bracket is an optional paren and closing_bracket.type == token.RPAR and not closing_bracket.value - # there are no delimiters or standalone comments in the body - and not body.bracket_tracker.delimiters + # there are no standalone comments in the body and not line.contains_standalone_comments(0) # and it's not an import (optional parens are the only thing we can split # on in this case; attempting a split without them is a waste of time) and not line.is_import ): omit = {id(closing_bracket), *omit} - try: - yield from right_hand_split(line, py36=py36, omit=omit) - return - except CannotSplit: - pass + delimiter_count = body.bracket_tracker.delimiter_count_with_priority() + if ( + delimiter_count == 0 + or delimiter_count == 1 + and ( + body.leaves[0].type in OPENING_BRACKETS + or body.leaves[-1].type in CLOSING_BRACKETS + ) + ): + try: + yield from right_hand_split(line, py36=py36, omit=omit) + return + except CannotSplit: + pass ensure_visible(opening_bracket) ensure_visible(closing_bracket) @@ -1993,11 +2099,9 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: except IndexError: raise CannotSplit("Line empty") - delimiters = line.bracket_tracker.delimiters + bt = line.bracket_tracker try: - delimiter_priority = line.bracket_tracker.max_delimiter_priority( - exclude={id(last_leaf)} - ) + delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)}) except ValueError: raise CannotSplit("No delimiters found") @@ -2016,19 +2120,18 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) current_line.append(leaf) - for leaf in line.leaves: + for index, leaf in enumerate(line.leaves): yield from append_to_line(leaf) - for comment_after in line.comments_after(leaf): + for comment_after in line.comments_after(leaf, index): yield from append_to_line(comment_after) lowest_depth = min(lowest_depth, leaf.bracket_depth) - if ( - leaf.bracket_depth == lowest_depth - and is_vararg(leaf, within=VARARGS_PARENTS) + if leaf.bracket_depth == lowest_depth and is_vararg( + leaf, within=VARARGS_PARENTS ): trailing_comma_safe = trailing_comma_safe and py36 - leaf_priority = delimiters.get(id(leaf)) + leaf_priority = bt.delimiters.get(id(leaf)) if leaf_priority == delimiter_priority: yield current_line @@ -2063,10 +2166,10 @@ def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) current_line.append(leaf) - for leaf in line.leaves: + for index, leaf in enumerate(line.leaves): yield from append_to_line(leaf) - for comment_after in line.comments_after(leaf): + for comment_after in line.comments_after(leaf, index): yield from append_to_line(comment_after) if current_line: @@ -2125,6 +2228,22 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: leaf.prefix = "" +def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None: + """Make all string prefixes lowercase. + + If remove_u_prefix is given, also removes any u prefix from the string. + + Note: Mutates its argument. + """ + match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL) + assert match is not None, f"failed to match string {leaf.value!r}" + orig_prefix = match.group(1) + new_prefix = orig_prefix.lower() + if remove_u_prefix: + new_prefix = new_prefix.replace("u", "") + leaf.value = f"{new_prefix}{match.group(2)}" + + def normalize_string_quotes(leaf: Leaf) -> None: """Prefer double quotes but only if it doesn't cause more escaping. @@ -2189,6 +2308,9 @@ def normalize_string_quotes(leaf: Leaf) -> None: def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: """Make existing optional parentheses invisible or create new ones. + `parens_after` is a set of string leaf values immeditely after which parens + should be put. + Standardizes on visible parentheses for single-element tuples, and keeps existing visible parentheses for other tuples and generator expressions. """ @@ -2219,6 +2341,7 @@ def maybe_make_parens_invisible_in_atom(node: LN) -> bool: node.type != syms.atom or is_empty_tuple(node) or is_one_tuple(node) + or is_yield(node) or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY ): return False @@ -2269,12 +2392,33 @@ def is_one_tuple(node: LN) -> bool: ) +def is_yield(node: LN) -> bool: + """Return True if `node` holds a `yield` or `yield from` expression.""" + if node.type == syms.yield_expr: + return True + + if node.type == token.NAME and node.value == "yield": # type: ignore + return True + + if node.type != syms.atom: + return False + + if len(node.children) != 3: + return False + + lpar, expr, rpar = node.children + if lpar.type == token.LPAR and rpar.type == token.RPAR: + return is_yield(expr) + + return False + + def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool: """Return True if `leaf` is a star or double star in a vararg or kwarg. If `within` includes VARARGS_PARENTS, this applies to function signatures. - If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right - hand-side extended iterable unpacking (PEP 3132) and additional unpacking + If `within` includes UNPACKING_PARENTS, it applies to right hand-side + extended iterable unpacking (PEP 3132) and additional unpacking generalizations (PEP 448). """ if leaf.type not in STARS or not leaf.parent: @@ -2292,6 +2436,35 @@ def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool: return p.type in within +def is_stub_suite(node: Node) -> bool: + """Return True if `node` is a suite with a stub body.""" + if ( + len(node.children) != 4 + or node.children[0].type != token.NEWLINE + or node.children[1].type != token.INDENT + or node.children[3].type != token.DEDENT + ): + return False + + return is_stub_body(node.children[2]) + + +def is_stub_body(node: LN) -> bool: + """Return True if `node` is a simple statement containing an ellipsis.""" + if not isinstance(node, Node) or node.type != syms.simple_stmt: + return False + + if len(node.children) != 2: + return False + + child = node.children[0] + return ( + child.type == syms.atom + and len(child.children) == 3 + and all(leaf == Leaf(token.DOT, ".") for leaf in child.children) + ) + + def max_delimiter_priority_in_atom(node: LN) -> int: """Return maximum delimiter priority inside `node`. @@ -2362,7 +2535,103 @@ def is_python36(node: Node) -> bool: return False -PYTHON_EXTENSIONS = {".py"} +def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]: + """Generate sets of closing bracket IDs that should be omitted in a RHS. + + Brackets can be omitted if the entire trailer up to and including + a preceding closing bracket fits in one line. + + Yielded sets are cumulative (contain results of previous yields, too). First + set is empty. + """ + + omit: Set[LeafID] = set() + yield omit + + length = 4 * line.depth + opening_bracket = None + closing_bracket = None + optional_brackets: Set[LeafID] = set() + inner_brackets: Set[LeafID] = set() + for index, leaf in enumerate_reversed(line.leaves): + length += len(leaf.prefix) + len(leaf.value) + if length > line_length: + break + + comment: Optional[Leaf] + for comment in line.comments_after(leaf, index): + if "\n" in comment.prefix: + break # Oops, standalone comment! + + length += len(comment.value) + else: + comment = None + if comment is not None: + break # There was a standalone comment, we can't continue. + + optional_brackets.discard(id(leaf)) + if opening_bracket: + if leaf is opening_bracket: + opening_bracket = None + elif leaf.type in CLOSING_BRACKETS: + inner_brackets.add(id(leaf)) + elif leaf.type in CLOSING_BRACKETS: + if not leaf.value: + optional_brackets.add(id(opening_bracket)) + continue + + if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS: + # Empty brackets would fail a split so treat them as "inner" + # brackets (e.g. only add them to the `omit` set if another + # pair of brackets was good enough. + inner_brackets.add(id(leaf)) + continue + + opening_bracket = leaf.opening_bracket + if closing_bracket: + omit.add(id(closing_bracket)) + omit.update(inner_brackets) + inner_brackets.clear() + yield omit + closing_bracket = leaf + + +def get_future_imports(node: Node) -> Set[str]: + """Return a set of __future__ imports in the file.""" + imports = set() + for child in node.children: + if child.type != syms.simple_stmt: + break + first_child = child.children[0] + if isinstance(first_child, Leaf): + # Continue looking if we see a docstring; otherwise stop. + if ( + len(child.children) == 2 + and first_child.type == token.STRING + and child.children[1].type == token.NEWLINE + ): + continue + else: + break + elif first_child.type == syms.import_from: + module_name = first_child.children[1] + if not isinstance(module_name, Leaf) or module_name.value != "__future__": + break + for import_from_child in first_child.children[3:]: + if isinstance(import_from_child, Leaf): + if import_from_child.type == token.NAME: + imports.add(import_from_child.value) + else: + assert import_from_child.type == syms.import_as_names + for leaf in import_from_child.children: + if isinstance(leaf, Leaf) and leaf.type == token.NAME: + imports.add(leaf.value) + else: + break + return imports + + +PYTHON_EXTENSIONS = {".py", ".pyi"} BLACKLISTED_DIRECTORIES = { "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv" } @@ -2379,7 +2648,7 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]: yield from gen_python_files_in_dir(child) - elif child.suffix in PYTHON_EXTENSIONS: + elif child.is_file() and child.suffix in PYTHON_EXTENSIONS: yield child @@ -2525,9 +2794,9 @@ def assert_equivalent(src: str, dst: str) -> None: ) from None -def assert_stable(src: str, dst: str, line_length: int) -> None: +def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" - newdst = format_str(dst, line_length=line_length) + newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi) if dst != newdst: log = dump_to_file( diff(src, dst, "source", "first pass"), @@ -2566,7 +2835,7 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: ) -def cancel(tasks: List[asyncio.Task]) -> None: +def cancel(tasks: Iterable[asyncio.Task]) -> None: """asyncio signal handler that cancels all `tasks` and reports to stderr.""" err("Aborted!") for task in tasks: @@ -2604,6 +2873,28 @@ def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str: return regex.sub(replacement, regex.sub(replacement, original)) +def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]: + """Like `reversed(enumerate(sequence))` if that were possible.""" + index = len(sequence) - 1 + for element in reversed(sequence): + yield (index, element) + index -= 1 + + +def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool: + """Return True if `line` is no longer than `line_length`. + + Uses the provided `line_str` rendering, if any, otherwise computes a new one. + """ + if not line_str: + line_str = str(line).strip("\n") + return ( + len(line_str) <= line_length + and "\n" not in line_str # multiline strings + and not line.contains_standalone_comments() + ) + + CACHE_DIR = Path(user_cache_dir("black", version=__version__))