git.madduck.net Git - etc/vim.git/blob - src/black/brackets.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Remove `$`, `>>>` and other prompt prefixes when code copied from the… (#3884)
[etc/vim.git] / src / black / brackets.py
1 """Builds on top of nodes.py to track brackets."""
2
3 from dataclasses import dataclass, field
4 from typing import Dict, Final, Iterable, List, Optional, Sequence, Set, Tuple, Union
5
6 from black.nodes import (
7     BRACKET,
8     CLOSING_BRACKETS,
9     COMPARATORS,
10     LOGIC_OPERATORS,
11     MATH_OPERATORS,
12     OPENING_BRACKETS,
13     UNPACKING_PARENTS,
14     VARARGS_PARENTS,
15     is_vararg,
16     syms,
17 )
18 from blib2to3.pgen2 import token
19 from blib2to3.pytree import Leaf, Node
20
# Type aliases shared by the bracket-tracking helpers below.
LN = Union[Leaf, Node]
Depth = int
LeafID = int
NodeType = int
Priority = int


# Delimiter priorities. Higher numbers are higher priority (this is what
# `BracketTracker.max_delimiter_priority()` maximizes over).
COMPREHENSION_PRIORITY: Final = 20
COMMA_PRIORITY: Final = 18
TERNARY_PRIORITY: Final = 16
LOGIC_PRIORITY: Final = 14
STRING_PRIORITY: Final = 12
COMPARATOR_PRIORITY: Final = 10
# Binary/unary math operators, grouped by shared priority (highest first).
MATH_PRIORITIES: Final = {
    operator_token: priority
    for priority, operator_tokens in (
        (9, (token.VBAR,)),
        (8, (token.CIRCUMFLEX,)),
        (7, (token.AMPER,)),
        (6, (token.LEFTSHIFT, token.RIGHTSHIFT)),
        (5, (token.PLUS, token.MINUS)),
        (4, (token.STAR, token.SLASH, token.DOUBLESLASH, token.PERCENT, token.AT)),
        (3, (token.TILDE,)),
        (2, (token.DOUBLESTAR,)),
    )
    for operator_token in operator_tokens
}
DOT_PRIORITY: Final = 1
52
53
class BracketMatchError(Exception):
    """Signals that a closing bracket has no corresponding opening bracket."""
56
57
@dataclass
class BracketTracker:
    """Keeps track of brackets on a line.

    Besides bracket nesting, it records the delimiters found at depth 0 and
    temporarily raises the depth between `for` ... `in` and between
    `lambda` ... `:` so commas in those spans are not seen at depth 0.
    """

    # Current nesting depth, including the artificial for/lambda levels.
    depth: int = 0
    # Currently open brackets, keyed by (depth, expected closing bracket type).
    bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
    # Priority of each depth-0 delimiter, keyed by the `id()` of the leaf after
    # which the line would be split.
    delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
    # The most recent leaf fully processed by `mark()`.
    previous: Optional[Leaf] = None
    # Depths pushed by `for` / `lambda`, popped at the matching `in` / `:`.
    _for_loop_depths: List[int] = field(default_factory=list)
    _lambda_argument_depths: List[int] = field(default_factory=list)
    # Opening/closing brackets whose value is empty ("invisible" brackets).
    invisible: List[Leaf] = field(default_factory=list)

    def mark(self, leaf: Leaf) -> None:
        """Mark `leaf` with bracket-related metadata. Keep track of delimiters.

        Every leaf except comments and skipped closing brackets (see below)
        receives an int `bracket_depth` field that stores how deep within
        brackets a given leaf is. 0 means there are no enclosing brackets that
        started on this line.

        If a leaf is itself a closing bracket and there is a matching opening
        bracket earlier, it receives an `opening_bracket` field with which it
        forms a pair. This is a one-directional link to avoid reference cycles.
        A closing bracket at depth 0 has no opening bracket on this line; that
        happens on lines continued from previous breaks, e.g. `) -> "ReturnType":`
        as part of a funcdef where the closing RPAR starts a new line together
        with the return type annotation. Such a bracket is skipped entirely.

        If a leaf is a delimiter (a token on which Black can split the line if
        needed) and it's on depth 0, its priority is stored in the tracker's
        `delimiters` field, keyed by the `id()` of the leaf after which the
        split would happen: the previous leaf for split-before delimiters, the
        leaf itself for split-after delimiters.

        Raises BracketMatchError if a closing bracket at depth > 0 does not
        match the bracket opened at that depth.
        """
        if leaf.type == token.COMMENT:
            return

        if (
            self.depth == 0
            and leaf.type in CLOSING_BRACKETS
            and (self.depth, leaf.type) not in self.bracket_match
        ):
            # Closing bracket whose opening bracket is not on this line.
            return

        # Leave any artificial for/lambda level before computing this depth.
        self.maybe_decrement_after_for_loop_variable(leaf)
        self.maybe_decrement_after_lambda_arguments(leaf)
        if leaf.type in CLOSING_BRACKETS:
            self.depth -= 1
            try:
                opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
            except KeyError as e:
                raise BracketMatchError(
                    "Unable to match a closing bracket to the following opening"
                    f" bracket: {leaf}"
                ) from e
            leaf.opening_bracket = opening_bracket
            if not leaf.value:
                self.invisible.append(leaf)
        leaf.bracket_depth = self.depth
        if self.depth == 0:
            delim = is_split_before_delimiter(leaf, self.previous)
            if delim and self.previous is not None:
                # Splitting before `leaf` means splitting after the previous leaf.
                self.delimiters[id(self.previous)] = delim
            else:
                delim = is_split_after_delimiter(leaf, self.previous)
                if delim:
                    self.delimiters[id(leaf)] = delim
        if leaf.type in OPENING_BRACKETS:
            self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
            self.depth += 1
            if not leaf.value:
                self.invisible.append(leaf)
        self.previous = leaf
        self.maybe_increment_lambda_arguments(leaf)
        self.maybe_increment_for_loop_variable(leaf)

    def any_open_brackets(self) -> bool:
        """Return True if there is a yet unmatched open bracket on the line."""
        return bool(self.bracket_match)

    def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
        """Return the highest priority of a delimiter found on the line.

        Delimiters whose leaf id is in `exclude` are ignored. Values are
        consistent with what `is_split_*_delimiter()` return.
        Raises ValueError on no (non-excluded) delimiters.
        """
        # Materialize `exclude` once: repeated `in` tests would consume a
        # generator/iterator and cost O(len(exclude)) each on a list.
        excluded = set(exclude)
        return max(v for k, v in self.delimiters.items() if k not in excluded)

    def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
        """Return the number of delimiters with the given `priority`.

        If no `priority` is passed, defaults to max priority on the line.
        Returns 0 when the line has no delimiters at all.
        """
        if not self.delimiters:
            return 0

        priority = priority or self.max_delimiter_priority()
        return sum(1 for p in self.delimiters.values() if p == priority)

    def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
        """In a for loop, or comprehension, the variables are often unpacks.

        To avoid splitting on the comma in this situation, increase the depth of
        tokens between `for` and `in`. Returns True if the depth was increased.
        """
        if leaf.type == token.NAME and leaf.value == "for":
            self.depth += 1
            self._for_loop_depths.append(self.depth)
            return True

        return False

    def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
        """Undo `maybe_increment_for_loop_variable` at the matching `in`.

        Returns True if the depth was decreased.
        """
        if (
            self._for_loop_depths
            and self._for_loop_depths[-1] == self.depth
            and leaf.type == token.NAME
            and leaf.value == "in"
        ):
            self.depth -= 1
            self._for_loop_depths.pop()
            return True

        return False

    def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
        """In a lambda expression, there might be more than one argument.

        To avoid splitting on the comma in this situation, increase the depth of
        tokens between `lambda` and `:`. Returns True if the depth was increased.
        """
        if leaf.type == token.NAME and leaf.value == "lambda":
            self.depth += 1
            self._lambda_argument_depths.append(self.depth)
            return True

        return False

    def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
        """Undo `maybe_increment_lambda_arguments` at the matching `:`.

        Returns True if the depth was decreased.
        """
        if (
            self._lambda_argument_depths
            and self._lambda_argument_depths[-1] == self.depth
            and leaf.type == token.COLON
        ):
            self.depth -= 1
            self._lambda_argument_depths.pop()
            return True

        return False

    def get_open_lsqb(self) -> Optional[Leaf]:
        """Return the most recent opening square bracket (if any)."""
        # Open brackets are keyed by their expected *closing* type.
        return self.bracket_match.get((self.depth - 1, token.RSQB))
209
210
def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
    """Return the delimiter priority of `leaf` for a line break placed after it.

    Only delimiters that the line is broken *after* are considered here;
    currently that is just the comma. `previous` is unused in this function.

    Higher numbers are higher priority; 0 means `leaf` is not such a delimiter.
    """
    return COMMA_PRIORITY if leaf.type == token.COMMA else 0
223
224
def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
    """Return the delimiter priority of `leaf` for a line break placed before it.

    Only delimiters that the line is broken *before* are considered here:
    attribute dots after a closing bracket, math operators, comparators,
    implicitly concatenated strings, comprehension `for`/`async`/`if`,
    ternary `if`/`else`, and logic operators. `previous` is the leaf that
    precedes `leaf` on the line, if any.

    Higher numbers are higher priority; 0 means `leaf` is not such a delimiter.
    """
    if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
        # `*` and `**` used for (un)packing share token types with math
        # operators, but here they are not operators: never a delimiter.
        return 0

    parent = leaf.parent
    parent_type = parent.type if parent else None
    # Value of `previous` when it is a NAME token, else None.
    previous_name = (
        previous.value
        if previous is not None and previous.type == token.NAME
        else None
    )

    if leaf.type == token.DOT:
        if (
            parent_type is not None
            and parent_type not in {syms.import_from, syms.dotted_name}
            and (previous is None or previous.type in CLOSING_BRACKETS)
        ):
            return DOT_PRIORITY

    if (
        leaf.type in MATH_OPERATORS
        and parent_type is not None
        and parent_type not in {syms.factor, syms.star_expr}
    ):
        return MATH_PRIORITIES[leaf.type]

    if leaf.type in COMPARATORS:
        return COMPARATOR_PRIORITY

    if leaf.type == token.STRING:
        if previous is not None and previous.type == token.STRING:
            return STRING_PRIORITY

    if leaf.type not in {token.NAME, token.ASYNC}:
        return 0

    value = leaf.value

    is_comprehension_for = value == "for" and parent_type in {
        syms.comp_for,
        syms.old_comp_for,
    }
    if is_comprehension_for or leaf.type == token.ASYNC:
        sibling = leaf.prev_sibling
        # In `async for`, the split goes before `async`, not before `for`.
        if not (isinstance(sibling, Leaf) and sibling.value == "async"):
            return COMPREHENSION_PRIORITY

    if value == "if" and parent_type in {syms.comp_if, syms.old_comp_if}:
        return COMPREHENSION_PRIORITY

    if value in {"if", "else"} and parent_type == syms.test:
        return TERNARY_PRIORITY

    if value == "is":
        return COMPARATOR_PRIORITY

    # `not in` is split before `not`, so the trailing `in` is not a delimiter.
    if (
        value == "in"
        and parent_type in {syms.comp_op, syms.comparison}
        and previous_name != "not"
    ):
        return COMPARATOR_PRIORITY

    # `is not` is split before `is`, so the trailing `not` is not a delimiter.
    if value == "not" and parent_type == syms.comp_op and previous_name != "is":
        return COMPARATOR_PRIORITY

    if value in LOGIC_OPERATORS and parent_type is not None:
        return LOGIC_PRIORITY

    return 0
319
320
def max_delimiter_priority_in_atom(node: LN) -> Priority:
    """Return the highest delimiter priority found inside a parenthesized atom.

    Only `atom` nodes whose first child is `(` and last child is `)` are
    inspected: the leaves between the parentheses are fed, in order, through
    a fresh `BracketTracker`. Any other node yields 0, as does an atom whose
    contents contain no delimiters.
    """
    if node.type != syms.atom:
        return 0

    children = node.children
    if children[0].type != token.LPAR or children[-1].type != token.RPAR:
        return 0

    tracker = BracketTracker()
    for child in children[1:-1]:
        child_leaves = (child,) if isinstance(child, Leaf) else child.leaves()
        for leaf in child_leaves:
            tracker.mark(leaf)

    try:
        return tracker.max_delimiter_priority()
    except ValueError:
        # Raised by `max_delimiter_priority` when nothing was recorded.
        return 0
347
348
def get_leaves_inside_matching_brackets(leaves: Sequence[Leaf]) -> Set[LeafID]:
    """Return the ids of leaves enclosed by matching bracket pairs.

    `leaves` may start or end with brackets that have no partner: closing
    brackets before the first opening bracket are ignored, and scanning stops
    at the first closing bracket that does not match the innermost open one.
    The matching brackets themselves are included in the result.
    """
    # Begin at the first opening bracket; earlier closing brackets are ignored.
    first_open = next(
        (
            position
            for position, candidate in enumerate(leaves)
            if candidate.type in OPENING_BRACKETS
        ),
        None,
    )
    if first_open is None:
        return set()

    # Stack of (expected closing type, index of the opening bracket).
    pending: List[Tuple[int, int]] = []
    inside: Set[LeafID] = set()
    for position in range(first_open, len(leaves)):
        kind = leaves[position].type
        if kind in OPENING_BRACKETS:
            pending.append((BRACKET[kind], position))
        if kind in CLOSING_BRACKETS:
            if not pending or pending[-1][0] != kind:
                break
            _, opened_at = pending.pop()
            inside.update(id(leaves[j]) for j in range(opened_at, position + 1))
    return inside
359         )
360     except StopIteration:
361         return set()
362     bracket_stack = []
363     ids = set()
364     for i in range(start_index, len(leaves)):
365         leaf = leaves[i]
366         if leaf.type in OPENING_BRACKETS:
367             bracket_stack.append((BRACKET[leaf.type], i))
368         if leaf.type in CLOSING_BRACKETS:
369             if bracket_stack and leaf.type == bracket_stack[-1][0]:
370                 _, start = bracket_stack.pop()
371                 for j in range(start, i + 1):
372                     ids.add(id(leaves[j]))
373             else:
374                 break
375     return ids