X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/09f5ee3a19f4274bb848324867bd8e68724cf851..3eab6d3131acd14b5900519d8447c8a1152e6d87:/black.py?ds=inline diff --git a/black.py b/black.py index 21e3743..7823ae0 100644 --- a/black.py +++ b/black.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - import asyncio import pickle from asyncio.base_events import BaseEventLoop @@ -26,6 +24,7 @@ from typing import ( List, Optional, Pattern, + Sequence, Set, Tuple, Type, @@ -43,8 +42,10 @@ from blib2to3 import pygram, pytree from blib2to3.pgen2 import driver, token from blib2to3.pgen2.parse import ParseError -__version__ = "18.4a2" + +__version__ = "18.4a6" DEFAULT_LINE_LENGTH = 88 + # types syms = pygram.python_symbols FileContent = str @@ -88,11 +89,11 @@ class FormatError(Exception): self.consumed = consumed def trim_prefix(self, leaf: Leaf) -> None: - leaf.prefix = leaf.prefix[self.consumed:] + leaf.prefix = leaf.prefix[self.consumed :] def leaf_from_consumed(self, leaf: Leaf) -> Leaf: """Returns a new Leaf from the consumed part of the prefix.""" - unformatted_prefix = leaf.prefix[:self.consumed] + unformatted_prefix = leaf.prefix[: self.consumed] return Leaf(token.NEWLINE, unformatted_prefix) @@ -193,6 +194,7 @@ def main( write_back = WriteBack.YES report = Report(check=check, quiet=quiet) if len(sources) == 0: + out("No paths given. 
Nothing to do 😴") ctx.exit(0) return @@ -244,7 +246,7 @@ def reformat_one( ) ): changed = Changed.YES - if write_back != WriteBack.DIFF and changed is not Changed.NO: + if write_back == WriteBack.YES and changed is not Changed.NO: write_cache(cache, [src], line_length) report.done(src, changed) except Exception as exc: @@ -311,7 +313,7 @@ async def schedule_formatting( if cancelled: await asyncio.gather(*cancelled, loop=loop, return_exceptions=True) - if write_back != WriteBack.DIFF and formatted: + if write_back == WriteBack.YES and formatted: write_cache(cache, formatted, line_length) @@ -409,9 +411,10 @@ def format_str(src_contents: str, line_length: int) -> FileContent: """ src_node = lib2to3_parse(src_contents) dst_contents = "" - lines = LineGenerator() - elt = EmptyLineTracker() + future_imports = get_future_imports(src_node) py36 = is_python36(src_node) + lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports) + elt = EmptyLineTracker() empty_line = Line() after = 0 for current_line in lines.visit(src_node): @@ -553,19 +556,20 @@ COMPARATORS = { token.GREATEREQUAL, } MATH_OPERATORS = { + token.VBAR, + token.CIRCUMFLEX, + token.AMPER, + token.LEFTSHIFT, + token.RIGHTSHIFT, token.PLUS, token.MINUS, token.STAR, token.SLASH, - token.VBAR, - token.AMPER, + token.DOUBLESLASH, token.PERCENT, - token.CIRCUMFLEX, + token.AT, token.TILDE, - token.LEFTSHIFT, - token.RIGHTSHIFT, token.DOUBLESTAR, - token.DOUBLESLASH, } STARS = {token.STAR, token.DOUBLESTAR} VARARGS_PARENTS = { @@ -581,13 +585,61 @@ UNPACKING_PARENTS = { syms.listmaker, syms.testlist_gexp, } +TEST_DESCENDANTS = { + syms.test, + syms.lambdef, + syms.or_test, + syms.and_test, + syms.not_test, + syms.comparison, + syms.star_expr, + syms.expr, + syms.xor_expr, + syms.and_expr, + syms.shift_expr, + syms.arith_expr, + syms.trailer, + syms.term, + syms.power, +} +ASSIGNMENTS = { + "=", + "+=", + "-=", + "*=", + "@=", + "/=", + "%=", + "&=", + "|=", + "^=", + "<<=", + ">>=", + "**=", 
+ "//=", +} COMPREHENSION_PRIORITY = 20 -COMMA_PRIORITY = 10 -TERNARY_PRIORITY = 7 -LOGIC_PRIORITY = 5 -STRING_PRIORITY = 4 -COMPARATOR_PRIORITY = 3 -MATH_PRIORITY = 1 +COMMA_PRIORITY = 18 +TERNARY_PRIORITY = 16 +LOGIC_PRIORITY = 14 +STRING_PRIORITY = 12 +COMPARATOR_PRIORITY = 10 +MATH_PRIORITIES = { + token.VBAR: 8, + token.CIRCUMFLEX: 7, + token.AMPER: 6, + token.LEFTSHIFT: 5, + token.RIGHTSHIFT: 5, + token.PLUS: 4, + token.MINUS: 4, + token.STAR: 3, + token.SLASH: 3, + token.DOUBLESLASH: 3, + token.PERCENT: 3, + token.AT: 3, + token.TILDE: 2, + token.DOUBLESTAR: 1, +} @dataclass @@ -598,8 +650,8 @@ class BracketTracker: bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict) delimiters: Dict[LeafID, Priority] = Factory(dict) previous: Optional[Leaf] = None - _for_loop_variable: bool = False - _lambda_arguments: bool = False + _for_loop_variable: int = 0 + _lambda_arguments: int = 0 def mark(self, leaf: Leaf) -> None: """Mark `leaf` with bracket-related metadata. Keep track of delimiters. 
@@ -661,7 +713,7 @@ class BracketTracker: """ if leaf.type == token.NAME and leaf.value == "for": self.depth += 1 - self._for_loop_variable = True + self._for_loop_variable += 1 return True return False @@ -670,7 +722,7 @@ class BracketTracker: """See `maybe_increment_for_loop_variable` above for explanation.""" if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in": self.depth -= 1 - self._for_loop_variable = False + self._for_loop_variable -= 1 return True return False @@ -683,7 +735,7 @@ class BracketTracker: """ if leaf.type == token.NAME and leaf.value == "lambda": self.depth += 1 - self._lambda_arguments = True + self._lambda_arguments += 1 return True return False @@ -692,11 +744,15 @@ class BracketTracker: """See `maybe_increment_lambda_arguments` above for explanation.""" if self._lambda_arguments and leaf.type == token.COLON: self.depth -= 1 - self._lambda_arguments = False + self._lambda_arguments -= 1 return True return False + def get_open_lsqb(self) -> Optional[Leaf]: + """Return the most recent opening square bracket (if any).""" + return self.bracket_match.get((self.depth - 1, token.RSQB)) + @dataclass class Line: @@ -722,14 +778,17 @@ class Line: if not has_value: return + if token.COLON == leaf.type and self.is_class_paren_empty: + del self.leaves[-2:] if self.leaves and not preformatted: # Note: at this point leaf.prefix should be empty except for # imports, for which we only preserve newlines. - leaf.prefix += whitespace(leaf) + leaf.prefix += whitespace( + leaf, complex_subscript=self.is_complex_subscript(leaf) + ) if self.inside_brackets or not preformatted: self.bracket_tracker.mark(leaf) self.maybe_remove_trailing_comma(leaf) - if not self.append_comment(leaf): self.leaves.append(leaf) @@ -817,6 +876,22 @@ class Line: and self.leaves[0].value == "yield" ) + @property + def is_class_paren_empty(self) -> bool: + """Is this a class with no base classes but using parentheses? 
+ + Those are unnecessary and should be removed. + """ + return ( + bool(self) + and len(self.leaves) == 4 + and self.is_class + and self.leaves[2].type == token.LPAR + and self.leaves[2].value == "(" + and self.leaves[3].type == token.RPAR + and self.leaves[3].value == ")" + ) + def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool: """If so, needs to be split before emitting.""" for leaf in self.leaves: @@ -845,9 +920,14 @@ class Line: self.remove_trailing_comma() return True - # For parens let's check if it's safe to remove the comma. If the - # trailing one is the only one, we might mistakenly change a tuple - # into a different type by removing the comma. + # For parens let's check if it's safe to remove the comma. + # Imports are always safe. + if self.is_import: + self.remove_trailing_comma() + return True + + # Otherwise, if the trailing one is the only one, we might mistakenly + # change a tuple into a different type by removing the comma. depth = closing.bracket_depth + 1 commas = 0 opening = closing.opening_bracket @@ -858,7 +938,7 @@ class Line: else: return False - for leaf in self.leaves[_opening_index + 1:]: + for leaf in self.leaves[_opening_index + 1 :]: if leaf is closing: break @@ -897,17 +977,21 @@ class Line: self.comments.append((after, comment)) return True - def comments_after(self, leaf: Leaf) -> Iterator[Leaf]: - """Generate comments that should appear directly after `leaf`.""" - for _leaf_index, _leaf in enumerate(self.leaves): - if leaf is _leaf: - break + def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]: + """Generate comments that should appear directly after `leaf`. 
+ """ + if _index == -1: + for _index, _leaf in enumerate(self.leaves): + if leaf is _leaf: + break + + else: + return for index, comment_after in self.comments: - if _leaf_index == index: + if _index == index: yield comment_after def remove_trailing_comma(self) -> None: @@ -919,6 +1003,25 @@ class Line: self.comments[i] = (comma_index - 1, comment) self.leaves.pop() + def is_complex_subscript(self, leaf: Leaf) -> bool: + """Return True iff `leaf` is part of a slice with non-trivial exprs.""" + open_lsqb = ( + leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb() + ) + if open_lsqb is None: + return False + + subscript_start = open_lsqb.next_sibling + if ( + isinstance(subscript_start, Node) + and subscript_start.type == syms.subscriptlist + ): + subscript_start = child_towards(subscript_start, leaf) + return ( + subscript_start is not None + and any(n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()) + ) + def __str__(self) -> str: """Render the line.""" if not self: @@ -1075,6 +1178,7 @@ class LineGenerator(Visitor[Line]): in ways that will no longer stringify to valid Python code on the tree. """ current_line: Line = Factory(Line) + remove_u_prefix: bool = False def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]: """Generate a line. @@ -1142,6 +1246,7 @@ class LineGenerator(Visitor[Line]): else: normalize_prefix(node, inside_brackets=any_open_brackets) if node.type == token.STRING: + normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix) normalize_string_quotes(node) if node.type not in WHITESPACE: self.current_line.append(node) @@ -1173,14 +1278,13 @@ class LineGenerator(Visitor[Line]): """Visit a statement. This implementation is shared for `if`, `while`, `for`, `try`, `except`, - `def`, `with`, `class`, and `assert`. + `def`, `with`, `class`, `assert` and assignments. The relevant Python language `keywords` for a given statement will be NAME leaves within it. 
This methods puts those on a separate line. - `parens` holds pairs of nodes where invisible parentheses should be put. - Keys hold nodes after which opening parentheses should be put, values - hold nodes before which closing parentheses should be put. + `parens` holds a set of string leaf values immediately after which + invisible parens should be put. """ normalize_invisible_parens(node, parens_after=parens) for child in node.children: @@ -1280,7 +1384,9 @@ class LineGenerator(Visitor[Line]): v = self.visit_stmt Ø: Set[str] = set() self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","}) - self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"}) + self.visit_if_stmt = partial( + v, keywords={"if", "else", "elif"}, parens={"if", "elif"} + ) self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"}) self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"}) self.visit_try_stmt = partial( @@ -1290,6 +1396,8 @@ class LineGenerator(Visitor[Line]): self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø) self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø) self.visit_classdef = partial(v, keywords={"class"}, parens=Ø) + self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS) + self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"}) self.visit_async_funcdef = self.visit_async_stmt self.visit_decorated = self.visit_decorators @@ -1302,8 +1410,12 @@ BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT} -def whitespace(leaf: Leaf) -> str: # noqa C901 - """Return whitespace prefix if needed for the given `leaf`.""" +def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa C901 + """Return whitespace prefix if needed for the given `leaf`. 
+ + `complex_subscript` signals whether the given leaf is part of a subscription + which has non-trivial arguments, like arithmetic expressions or function calls. + """ NO = "" SPACE = " " DOUBLESPACE = " " @@ -1317,7 +1429,10 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return DOUBLESPACE assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}" - if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}: + if ( + t == token.COLON + and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop} + ): return NO prev = leaf.prev_sibling @@ -1327,7 +1442,13 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return NO if t == token.COLON: - return SPACE if prevp.type == token.COMMA else NO + if prevp.type == token.COLON: + return NO + + elif prevp.type != token.COMMA and not complex_subscript: + return NO + + return SPACE if prevp.type == token.EQUAL: if prevp.parent: @@ -1348,7 +1469,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 elif prevp.type == token.COLON: if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}: - return NO + return SPACE if complex_subscript else NO elif ( prevp.parent @@ -1454,7 +1575,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 if prev and prev.type == token.LPAR: return NO - elif p.type == syms.subscript: + elif p.type in {syms.subscript, syms.sliceop}: # indexing if not prev: assert p.parent is not None, "subscripts are always parented" @@ -1463,7 +1584,7 @@ def whitespace(leaf: Leaf) -> str: # noqa C901 return NO - else: + elif not complex_subscript: return NO elif p.type == syms.atom: @@ -1533,6 +1654,14 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: return None +def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]: + """Return the child of `ancestor` that contains `descendant`.""" + node: Optional[LN] = descendant + while node and node.parent != ancestor: + node = node.parent + return node + + def is_split_after_delimiter(leaf: Leaf, 
previous: Leaf = None) -> int: """Return the priority of the `leaf` delimiter, given a line break after it. @@ -1565,7 +1694,7 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int: and leaf.parent and leaf.parent.type not in {syms.factor, syms.star_expr} ): - return MATH_PRIORITY + return MATH_PRIORITIES[leaf.type] if leaf.type in COMPARATORS: return COMPARATOR_PRIORITY @@ -1701,11 +1830,7 @@ def split_line( return line_str = str(line).strip("\n") - if ( - len(line_str) <= line_length - and "\n" not in line_str # multiline strings - and not line.contains_standalone_comments() - ): + if is_line_short_enough(line, line_length=line_length, line_str=line_str): yield line return @@ -1714,10 +1839,22 @@ def split_line( split_funcs = [left_hand_split] elif line.is_import: split_funcs = [explode_split] - elif line.inside_brackets: - split_funcs = [delimiter_split, standalone_comment_split, right_hand_split] else: - split_funcs = [right_hand_split] + + def rhs(line: Line, py36: bool = False) -> Iterator[Line]: + for omit in generate_trailers_to_omit(line, line_length): + lines = list(right_hand_split(line, py36, omit=omit)) + if is_line_short_enough(lines[0], line_length=line_length): + yield from lines + return + + # All splits failed, best effort split with no omits. + yield from right_hand_split(line, py36) + + if line.inside_brackets: + split_funcs = [delimiter_split, standalone_comment_split, rhs] + else: + split_funcs = [rhs] for split_func in split_funcs: # We are accumulating lines in `result` because we might want to abort # mission and return the original line in the end, or attempt a different @@ -1746,7 +1883,8 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: """Split line into many lines, starting with the first matching bracket pair. Note: this usually looks weird, only use this for function definitions. - Prefer RHS otherwise. + Prefer RHS otherwise. 
This is why this function is not symmetrical with + :func:`right_hand_split` which also handles optional parentheses. """ head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) @@ -1786,7 +1924,12 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]: def right_hand_split( line: Line, py36: bool = False, omit: Collection[LeafID] = () ) -> Iterator[Line]: - """Split line into many lines, starting with the last matching bracket pair.""" + """Split line into many lines, starting with the last matching bracket pair. + + If the split was by optional parentheses, attempt splitting without them, too. + `omit` is a collection of closing bracket IDs that shouldn't be considered for + this split. + """ head = Line(depth=line.depth) body = Line(depth=line.depth + 1, inside_brackets=True) tail = Line(depth=line.depth) @@ -1825,20 +1968,25 @@ def right_hand_split( bracket_split_succeeded_or_raise(head, body, tail) assert opening_bracket and closing_bracket if ( + # the opening bracket is an optional paren opening_bracket.type == token.LPAR and not opening_bracket.value + # the closing bracket is an optional paren and closing_bracket.type == token.RPAR and not closing_bracket.value + # there are no delimiters or standalone comments in the body + and not body.bracket_tracker.delimiters + and not line.contains_standalone_comments(0) + # and it's not an import (optional parens are the only thing we can split + # on in this case; attempting a split without them is a waste of time) + and not line.is_import ): - # These parens were optional. If there aren't any delimiters or standalone - # comments in the body, they were unnecessary and another split without - # them should be attempted. 
- if not ( - body.bracket_tracker.delimiters or line.contains_standalone_comments(0) - ): - omit = {id(closing_bracket), *omit} + omit = {id(closing_bracket), *omit} + try: yield from right_hand_split(line, py36=py36, omit=omit) return + except CannotSplit: + pass ensure_visible(opening_bracket) ensure_visible(closing_bracket) @@ -1923,10 +2071,10 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]: current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) current_line.append(leaf) - for leaf in line.leaves: + for index, leaf in enumerate(line.leaves): yield from append_to_line(leaf) - for comment_after in line.comments_after(leaf): + for comment_after in line.comments_after(leaf, index): yield from append_to_line(comment_after) lowest_depth = min(lowest_depth, leaf.bracket_depth) @@ -1970,10 +2118,10 @@ def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) current_line.append(leaf) - for leaf in line.leaves: + for index, leaf in enumerate(line.leaves): yield from append_to_line(leaf) - for comment_after in line.comments_after(leaf): + for comment_after in line.comments_after(leaf, index): yield from append_to_line(comment_after) if current_line: @@ -1983,15 +2131,17 @@ def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]: def explode_split( line: Line, py36: bool = False, omit: Collection[LeafID] = () ) -> Iterator[Line]: - """Split by RHS and immediately split contents by a delimiter.""" + """Split by rightmost bracket and immediately split contents by a delimiter.""" new_lines = list(right_hand_split(line, py36, omit)) if len(new_lines) != 3: yield from new_lines return yield new_lines[0] + try: yield from delimiter_split(new_lines[1], py36) + except CannotSplit: yield new_lines[1] @@ -2030,6 +2180,22 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: leaf.prefix = "" +def 
normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None: + """Make all string prefixes lowercase. + + If remove_u_prefix is given, also removes any u prefix from the string. + + Note: Mutates its argument. + """ + match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL) + assert match is not None, f"failed to match string {leaf.value!r}" + orig_prefix = match.group(1) + new_prefix = orig_prefix.lower() + if remove_u_prefix: + new_prefix = new_prefix.replace("u", "") + leaf.value = f"{new_prefix}{match.group(2)}" + + def normalize_string_quotes(leaf: Leaf) -> None: """Prefer double quotes but only if it doesn't cause more escaping. @@ -2059,7 +2225,7 @@ def normalize_string_quotes(leaf: Leaf) -> None: unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}") escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}") escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}") - body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)] + body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)] if "r" in prefix.casefold(): if unescaped_new_quote.search(body): # There's at least one unescaped new_quote in this raw string @@ -2094,6 +2260,9 @@ def normalize_string_quotes(leaf: Leaf) -> None: def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: """Make existing optional parentheses invisible or create new ones. + `parens_after` is a set of string leaf values immediately after which parens + should be put. + Standardizes on visible parentheses for single-element tuples, and keeps existing visible parentheses for other tuples and generator expressions. 
""" @@ -2101,17 +2270,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: for child in list(node.children): if check_lpar: if child.type == syms.atom: - if not ( - is_empty_tuple(child) - or is_one_tuple(child) - or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY - ): - first = child.children[0] - last = child.children[-1] - if first.type == token.LPAR and last.type == token.RPAR: - # make parentheses invisible - first.value = "" # type: ignore - last.value = "" # type: ignore + maybe_make_parens_invisible_in_atom(child) elif is_one_tuple(child): # wrap child in visible parentheses lpar = Leaf(token.LPAR, "(") @@ -2128,6 +2287,30 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: check_lpar = isinstance(child, Leaf) and child.value in parens_after +def maybe_make_parens_invisible_in_atom(node: LN) -> bool: + """If it's safe, make the parens in the atom `node` invisible, recursively.""" + if ( + node.type != syms.atom + or is_empty_tuple(node) + or is_one_tuple(node) + or is_yield(node) + or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY + ): + return False + + first = node.children[0] + last = node.children[-1] + if first.type == token.LPAR and last.type == token.RPAR: + # make parentheses invisible + first.value = "" # type: ignore + last.value = "" # type: ignore + if len(node.children) > 1: + maybe_make_parens_invisible_in_atom(node.children[1]) + return True + + return False + + def is_empty_tuple(node: LN) -> bool: """Return True if `node` holds an empty tuple.""" return ( @@ -2161,12 +2344,33 @@ def is_one_tuple(node: LN) -> bool: ) +def is_yield(node: LN) -> bool: + """Return True if `node` holds a `yield` or `yield from` expression.""" + if node.type == syms.yield_expr: + return True + + if node.type == token.NAME and node.value == "yield": # type: ignore + return True + + if node.type != syms.atom: + return False + + if len(node.children) != 3: + return False + + lpar, expr, rpar = 
node.children + if lpar.type == token.LPAR and rpar.type == token.RPAR: + return is_yield(expr) + + return False + + def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool: """Return True if `leaf` is a star or double star in a vararg or kwarg. If `within` includes VARARGS_PARENTS, this applies to function signatures. - If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right - hand-side extended iterable unpacking (PEP 3132) and additional unpacking + If `within` includes UNPACKING_PARENTS, it applies to right hand-side + extended iterable unpacking (PEP 3132) and additional unpacking generalizations (PEP 448). """ if leaf.type not in STARS or not leaf.parent: @@ -2229,7 +2433,7 @@ def is_python36(node: Node) -> bool: Currently looking for: - f-strings; and - - trailing commas after * or ** in function signatures. + - trailing commas after * or ** in function signatures and calls. """ for n in node.pre_order(): if n.type == token.STRING: @@ -2238,7 +2442,7 @@ def is_python36(node: Node) -> bool: return True elif ( - n.type == syms.typedargslist + n.type in {syms.typedargslist, syms.arglist} and n.children and n.children[-1].type == token.COMMA ): @@ -2246,9 +2450,110 @@ def is_python36(node: Node) -> bool: if ch.type in STARS: return True + if ch.type == syms.argument: + for argch in ch.children: + if argch.type in STARS: + return True + return False +def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]: + """Generate sets of closing bracket IDs that should be omitted in a RHS. + + Brackets can be omitted if the entire trailer up to and including + a preceding closing bracket fits in one line. + + Yielded sets are cumulative (contain results of previous yields, too). First + set is empty. 
+ """ + + omit: Set[LeafID] = set() + yield omit + + length = 4 * line.depth + opening_bracket = None + closing_bracket = None + optional_brackets: Set[LeafID] = set() + inner_brackets: Set[LeafID] = set() + for index, leaf in enumerate_reversed(line.leaves): + length += len(leaf.prefix) + len(leaf.value) + if length > line_length: + break + + comment: Optional[Leaf] + for comment in line.comments_after(leaf, index): + if "\n" in comment.prefix: + break # Oops, standalone comment! + + length += len(comment.value) + else: + comment = None + if comment is not None: + break # There was a standalone comment, we can't continue. + + optional_brackets.discard(id(leaf)) + if opening_bracket: + if leaf is opening_bracket: + opening_bracket = None + elif leaf.type in CLOSING_BRACKETS: + inner_brackets.add(id(leaf)) + elif leaf.type in CLOSING_BRACKETS: + if not leaf.value: + optional_brackets.add(id(opening_bracket)) + continue + + if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS: + # Empty brackets would fail a split so treat them as "inner" + # brackets (i.e. only add them to the `omit` set if another + # pair of brackets was good enough). + inner_brackets.add(id(leaf)) + continue + + opening_bracket = leaf.opening_bracket + if closing_bracket: + omit.add(id(closing_bracket)) + omit.update(inner_brackets) + inner_brackets.clear() + yield omit + closing_bracket = leaf + + +def get_future_imports(node: Node) -> Set[str]: + """Return a set of __future__ imports in the file.""" + imports = set() + for child in node.children: + if child.type != syms.simple_stmt: + break + first_child = child.children[0] + if isinstance(first_child, Leaf): + # Continue looking if we see a docstring; otherwise stop. 
+ if ( + len(child.children) == 2 + and first_child.type == token.STRING + and child.children[1].type == token.NEWLINE + ): + continue + else: + break + elif first_child.type == syms.import_from: + module_name = first_child.children[1] + if not isinstance(module_name, Leaf) or module_name.value != "__future__": + break + for import_from_child in first_child.children[3:]: + if isinstance(import_from_child, Leaf): + if import_from_child.type == token.NAME: + imports.add(import_from_child.value) + else: + assert import_from_child.type == syms.import_as_names + for leaf in import_from_child.children: + if isinstance(leaf, Leaf) and leaf.type == token.NAME: + imports.add(leaf.value) + else: + break + return imports + + PYTHON_EXTENSIONS = {".py"} BLACKLISTED_DIRECTORIES = { "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv" @@ -2266,7 +2571,7 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]: yield from gen_python_files_in_dir(child) - elif child.suffix in PYTHON_EXTENSIONS: + elif child.is_file() and child.suffix in PYTHON_EXTENSIONS: yield child @@ -2491,6 +2796,28 @@ def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str: return regex.sub(replacement, regex.sub(replacement, original)) +def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]: + """Like `reversed(enumerate(sequence))` if that were possible.""" + index = len(sequence) - 1 + for element in reversed(sequence): + yield (index, element) + index -= 1 + + +def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool: + """Return True if `line` is no longer than `line_length`. + + Uses the provided `line_str` rendering, if any, otherwise computes a new one. 
+ """ + if not line_str: + line_str = str(line).strip("\n") + return ( + len(line_str) <= line_length + and "\n" not in line_str # multiline strings + and not line.contains_standalone_comments() + ) + + CACHE_DIR = Path(user_cache_dir("black", version=__version__))