"""Builds on top of nodes.py to track brackets.""" from dataclasses import dataclass, field import sys from typing import Dict, Iterable, List, Optional, Tuple, Union if sys.version_info < (3, 8): from typing_extensions import Final else: from typing import Final from blib2to3.pytree import Leaf, Node from blib2to3.pgen2 import token from black.nodes import syms, is_vararg, VARARGS_PARENTS, UNPACKING_PARENTS from black.nodes import BRACKET, OPENING_BRACKETS, CLOSING_BRACKETS from black.nodes import MATH_OPERATORS, COMPARATORS, LOGIC_OPERATORS # types LN = Union[Leaf, Node] Depth = int LeafID = int NodeType = int Priority = int COMPREHENSION_PRIORITY: Final = 20 COMMA_PRIORITY: Final = 18 TERNARY_PRIORITY: Final = 16 LOGIC_PRIORITY: Final = 14 STRING_PRIORITY: Final = 12 COMPARATOR_PRIORITY: Final = 10 MATH_PRIORITIES: Final = { token.VBAR: 9, token.CIRCUMFLEX: 8, token.AMPER: 7, token.LEFTSHIFT: 6, token.RIGHTSHIFT: 6, token.PLUS: 5, token.MINUS: 5, token.STAR: 4, token.SLASH: 4, token.DOUBLESLASH: 4, token.PERCENT: 4, token.AT: 4, token.TILDE: 3, token.DOUBLESTAR: 2, } DOT_PRIORITY: Final = 1 class BracketMatchError(KeyError): """Raised when an opening bracket is unable to be matched to a closing bracket.""" @dataclass class BracketTracker: """Keeps track of brackets on a line.""" depth: int = 0 bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict) delimiters: Dict[LeafID, Priority] = field(default_factory=dict) previous: Optional[Leaf] = None _for_loop_depths: List[int] = field(default_factory=list) _lambda_argument_depths: List[int] = field(default_factory=list) invisible: List[Leaf] = field(default_factory=list) def mark(self, leaf: Leaf) -> None: """Mark `leaf` with bracket-related metadata. Keep track of delimiters. All leaves receive an int `bracket_depth` field that stores how deep within brackets a given leaf is. 0 means there are no enclosing brackets that started on this line. If a leaf is itself a closing bracket, it receives an `opening_bracket` field that it forms a pair with. This is a one-directional link to avoid reference cycles. If a leaf is a delimiter (a token on which Black can split the line if needed) and it's on depth 0, its `id()` is stored in the tracker's `delimiters` field. """ if leaf.type == token.COMMENT: return self.maybe_decrement_after_for_loop_variable(leaf) self.maybe_decrement_after_lambda_arguments(leaf) if leaf.type in CLOSING_BRACKETS: self.depth -= 1 try: opening_bracket = self.bracket_match.pop((self.depth, leaf.type)) except KeyError as e: raise BracketMatchError( "Unable to match a closing bracket to the following opening" f" bracket: {leaf}" ) from e leaf.opening_bracket = opening_bracket if not leaf.value: self.invisible.append(leaf) leaf.bracket_depth = self.depth if self.depth == 0: delim = is_split_before_delimiter(leaf, self.previous) if delim and self.previous is not None: self.delimiters[id(self.previous)] = delim else: delim = is_split_after_delimiter(leaf, self.previous) if delim: self.delimiters[id(leaf)] = delim if leaf.type in OPENING_BRACKETS: self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf self.depth += 1 if not leaf.value: self.invisible.append(leaf) self.previous = leaf self.maybe_increment_lambda_arguments(leaf) self.maybe_increment_for_loop_variable(leaf) def any_open_brackets(self) -> bool: """Return True if there is an yet unmatched open bracket on the line.""" return bool(self.bracket_match) def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority: """Return the highest priority of a delimiter found on the line. Values are consistent with what `is_split_*_delimiter()` return. Raises ValueError on no delimiters. """ return max(v for k, v in self.delimiters.items() if k not in exclude) def delimiter_count_with_priority(self, priority: Priority = 0) -> int: """Return the number of delimiters with the given `priority`. If no `priority` is passed, defaults to max priority on the line. """ if not self.delimiters: return 0 priority = priority or self.max_delimiter_priority() return sum(1 for p in self.delimiters.values() if p == priority) def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool: """In a for loop, or comprehension, the variables are often unpacks. To avoid splitting on the comma in this situation, increase the depth of tokens between `for` and `in`. """ if leaf.type == token.NAME and leaf.value == "for": self.depth += 1 self._for_loop_depths.append(self.depth) return True return False def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool: """See `maybe_increment_for_loop_variable` above for explanation.""" if ( self._for_loop_depths and self._for_loop_depths[-1] == self.depth and leaf.type == token.NAME and leaf.value == "in" ): self.depth -= 1 self._for_loop_depths.pop() return True return False def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool: """In a lambda expression, there might be more than one argument. To avoid splitting on the comma in this situation, increase the depth of tokens between `lambda` and `:`. """ if leaf.type == token.NAME and leaf.value == "lambda": self.depth += 1 self._lambda_argument_depths.append(self.depth) return True return False def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool: """See `maybe_increment_lambda_arguments` above for explanation.""" if ( self._lambda_argument_depths and self._lambda_argument_depths[-1] == self.depth and leaf.type == token.COLON ): self.depth -= 1 self._lambda_argument_depths.pop() return True return False def get_open_lsqb(self) -> Optional[Leaf]: """Return the most recent opening square bracket (if any).""" return self.bracket_match.get((self.depth - 1, token.RSQB)) def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority: """Return the priority of the `leaf` delimiter, given a line break after it. The delimiter priorities returned here are from those delimiters that would cause a line break after themselves. Higher numbers are higher priority. """ if leaf.type == token.COMMA: return COMMA_PRIORITY return 0 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority: """Return the priority of the `leaf` delimiter, given a line break before it. The delimiter priorities returned here are from those delimiters that would cause a line break before themselves. Higher numbers are higher priority. """ if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS): # * and ** might also be MATH_OPERATORS but in this case they are not. # Don't treat them as a delimiter. return 0 if ( leaf.type == token.DOT and leaf.parent and leaf.parent.type not in {syms.import_from, syms.dotted_name} and (previous is None or previous.type in CLOSING_BRACKETS) ): return DOT_PRIORITY if ( leaf.type in MATH_OPERATORS and leaf.parent and leaf.parent.type not in {syms.factor, syms.star_expr} ): return MATH_PRIORITIES[leaf.type] if leaf.type in COMPARATORS: return COMPARATOR_PRIORITY if ( leaf.type == token.STRING and previous is not None and previous.type == token.STRING ): return STRING_PRIORITY if leaf.type not in {token.NAME, token.ASYNC}: return 0 if ( leaf.value == "for" and leaf.parent and leaf.parent.type in {syms.comp_for, syms.old_comp_for} or leaf.type == token.ASYNC ): if ( not isinstance(leaf.prev_sibling, Leaf) or leaf.prev_sibling.value != "async" ): return COMPREHENSION_PRIORITY if ( leaf.value == "if" and leaf.parent and leaf.parent.type in {syms.comp_if, syms.old_comp_if} ): return COMPREHENSION_PRIORITY if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test: return TERNARY_PRIORITY if leaf.value == "is": return COMPARATOR_PRIORITY if ( leaf.value == "in" and leaf.parent and leaf.parent.type in {syms.comp_op, syms.comparison} and not ( previous is not None and previous.type == token.NAME and previous.value == "not" ) ): return COMPARATOR_PRIORITY if ( leaf.value == "not" and leaf.parent and leaf.parent.type == syms.comp_op and not ( previous is not None and previous.type == token.NAME and previous.value == "is" ) ): return COMPARATOR_PRIORITY if leaf.value in LOGIC_OPERATORS and leaf.parent: return LOGIC_PRIORITY return 0 def max_delimiter_priority_in_atom(node: LN) -> Priority: """Return maximum delimiter priority inside `node`. This is specific to atoms with contents contained in a pair of parentheses. If `node` isn't an atom or there are no enclosing parentheses, returns 0. """ if node.type != syms.atom: return 0 first = node.children[0] last = node.children[-1] if not (first.type == token.LPAR and last.type == token.RPAR): return 0 bt = BracketTracker() for c in node.children[1:-1]: if isinstance(c, Leaf): bt.mark(c) else: for leaf in c.leaves(): bt.mark(leaf) try: return bt.max_delimiter_priority() except ValueError: return 0