git.madduck.net Git - etc/vim.git/blob - src/black/brackets.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Document black-jupyter hook (#3650)
[etc/vim.git] / src / black / brackets.py
1 """Builds on top of nodes.py to track brackets."""
2
3 import sys
4 from dataclasses import dataclass, field
5 from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
6
7 if sys.version_info < (3, 8):
8     from typing_extensions import Final
9 else:
10     from typing import Final
11
12 from black.nodes import (
13     BRACKET,
14     CLOSING_BRACKETS,
15     COMPARATORS,
16     LOGIC_OPERATORS,
17     MATH_OPERATORS,
18     OPENING_BRACKETS,
19     UNPACKING_PARENTS,
20     VARARGS_PARENTS,
21     is_vararg,
22     syms,
23 )
24 from blib2to3.pgen2 import token
25 from blib2to3.pytree import Leaf, Node
26
# Type aliases used throughout the bracket-tracking code.
LN = Union[Leaf, Node]
Depth = int
LeafID = int
NodeType = int
Priority = int


# Delimiter priorities returned by `is_split_*_delimiter()`; higher numbers are
# higher priority, and 0 (not listed) means "not a delimiter".
COMPREHENSION_PRIORITY: Final = 20
COMMA_PRIORITY: Final = 18
TERNARY_PRIORITY: Final = 16
LOGIC_PRIORITY: Final = 14
STRING_PRIORITY: Final = 12
COMPARATOR_PRIORITY: Final = 10
# Binary operator token -> priority. Operators that share a level (e.g. `+`
# and `-`) are grouped on one row; rows go from highest to lowest priority.
MATH_PRIORITIES: Final = {
    tok: level
    for level, tokens in (
        (9, (token.VBAR,)),
        (8, (token.CIRCUMFLEX,)),
        (7, (token.AMPER,)),
        (6, (token.LEFTSHIFT, token.RIGHTSHIFT)),
        (5, (token.PLUS, token.MINUS)),
        (4, (token.STAR, token.SLASH, token.DOUBLESLASH, token.PERCENT, token.AT)),
        (3, (token.TILDE,)),
        (2, (token.DOUBLESTAR,)),
    )
    for tok in tokens
}
DOT_PRIORITY: Final = 1
58
59
class BracketMatchError(Exception):
    """Signals a closing bracket that cannot be paired with an opening bracket.

    `BracketTracker.mark()` raises this when it meets a closing bracket for
    which no opening bracket of the matching kind was recorded at that depth.
    """
62
63
@dataclass
class BracketTracker:
    """Keeps track of brackets on a line.

    Leaves are fed in order through `mark()`, which annotates them with bracket
    metadata and records the depth-0 split points of the line in `delimiters`.
    """

    # Current nesting depth. Besides real brackets, it is also raised between
    # `for` ... `in` and between `lambda` ... `:` (see the maybe_* helpers).
    depth: int = 0
    # Opening brackets still awaiting their partner, keyed by
    # (depth at which the bracket was opened, type of the expected closer).
    bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
    # Depth-0 split points: id() of the leaf after which a split may happen
    # -> the delimiter priority of that split.
    delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
    # The most recent leaf processed by `mark()` (skipped leaves excluded).
    previous: Optional[Leaf] = None
    # Depths pushed by `for` / `lambda`, popped again on `in` / `:`.
    _for_loop_depths: List[int] = field(default_factory=list)
    _lambda_argument_depths: List[int] = field(default_factory=list)
    # Opening and closing brackets whose value is empty (invisible brackets).
    invisible: List[Leaf] = field(default_factory=list)

    def mark(self, leaf: Leaf) -> None:
        """Mark `leaf` with bracket-related metadata. Keep track of delimiters.

        Comments are ignored entirely. A closing bracket at depth 0 with no
        recorded opening bracket is skipped as well; this happens on lines
        continued from previous splits, e.g. `) -> "ReturnType":` as part of a
        funcdef whose opening parenthesis is on an earlier line. Skipped leaves
        receive no metadata and do not become `previous`.

        Every other leaf receives an int `bracket_depth` field that stores how
        deep within brackets it is. 0 means there are no enclosing brackets
        that started on this line.

        A closing bracket additionally receives an `opening_bracket` field
        pointing at its matching opening bracket. This is a one-directional
        link to avoid reference cycles. Brackets with an empty value are also
        appended to `invisible`.

        If a leaf is a delimiter (a token on which Black can split the line if
        needed) and it's at depth 0, the split is recorded in `delimiters`,
        keyed by the `id()` of the leaf the split would come after: the leaf
        itself for split-after delimiters (commas), the preceding leaf for
        split-before delimiters (operators, keywords, ...).

        Raises BracketMatchError if a closing bracket has no opening bracket of
        the matching kind recorded at its depth.
        """
        if leaf.type == token.COMMENT:
            return

        if (
            self.depth == 0
            and leaf.type in CLOSING_BRACKETS
            and (self.depth, leaf.type) not in self.bracket_match
        ):
            # Closer whose opener lives on an earlier line; nothing to pair.
            return

        # Undo the artificial depth of `for ... in` / `lambda ... :` first, so
        # `in` and `:` themselves are measured at the surrounding depth.
        self.maybe_decrement_after_for_loop_variable(leaf)
        self.maybe_decrement_after_lambda_arguments(leaf)
        if leaf.type in CLOSING_BRACKETS:
            self.depth -= 1
            try:
                opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
            except KeyError as e:
                raise BracketMatchError(
                    "Unable to match the following closing bracket to an"
                    f" opening bracket: {leaf}"
                ) from e
            leaf.opening_bracket = opening_bracket
            if not leaf.value:
                self.invisible.append(leaf)
        leaf.bracket_depth = self.depth
        if self.depth == 0:
            # A split-before delimiter is stored on the previous leaf so every
            # entry in `delimiters` reads "split after this leaf".
            delim = is_split_before_delimiter(leaf, self.previous)
            if delim and self.previous is not None:
                self.delimiters[id(self.previous)] = delim
            else:
                delim = is_split_after_delimiter(leaf, self.previous)
                if delim:
                    self.delimiters[id(leaf)] = delim
        if leaf.type in OPENING_BRACKETS:
            # Remember the opener under the closer type we expect to see next.
            self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
            self.depth += 1
            if not leaf.value:
                self.invisible.append(leaf)
        self.previous = leaf
        self.maybe_increment_lambda_arguments(leaf)
        self.maybe_increment_for_loop_variable(leaf)

    def any_open_brackets(self) -> bool:
        """Return True if there is a yet unmatched open bracket on the line."""
        return bool(self.bracket_match)

    def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
        """Return the highest priority of a delimiter found on the line.

        Delimiters whose leaf id is in `exclude` are ignored. Values are
        consistent with what `is_split_*_delimiter()` return.
        Raises ValueError if no (non-excluded) delimiters remain.
        """
        return max(v for k, v in self.delimiters.items() if k not in exclude)

    def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
        """Return the number of delimiters with the given `priority`.

        If no `priority` is passed (or it is 0), defaults to the max priority
        on the line. Returns 0 when the line has no delimiters.
        """
        if not self.delimiters:
            return 0

        priority = priority or self.max_delimiter_priority()
        return sum(1 for p in self.delimiters.values() if p == priority)

    def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
        """In a for loop, or comprehension, the variables are often unpacks.

        To avoid splitting on the comma in this situation, increase the depth of
        tokens between `for` and `in`. Returns True if the depth was increased.
        """
        if leaf.type == token.NAME and leaf.value == "for":
            self.depth += 1
            self._for_loop_depths.append(self.depth)
            return True

        return False

    def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
        """See `maybe_increment_for_loop_variable` above for explanation.

        Only an `in` seen at exactly the depth pushed by the latest `for`
        undoes it. Returns True if the depth was decreased.
        """
        if (
            self._for_loop_depths
            and self._for_loop_depths[-1] == self.depth
            and leaf.type == token.NAME
            and leaf.value == "in"
        ):
            self.depth -= 1
            self._for_loop_depths.pop()
            return True

        return False

    def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
        """In a lambda expression, there might be more than one argument.

        To avoid splitting on the comma in this situation, increase the depth of
        tokens between `lambda` and `:`. Returns True if the depth was increased.
        """
        if leaf.type == token.NAME and leaf.value == "lambda":
            self.depth += 1
            self._lambda_argument_depths.append(self.depth)
            return True

        return False

    def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
        """See `maybe_increment_lambda_arguments` above for explanation.

        Only a `:` seen at exactly the depth pushed by the latest `lambda`
        undoes it. Returns True if the depth was decreased.
        """
        if (
            self._lambda_argument_depths
            and self._lambda_argument_depths[-1] == self.depth
            and leaf.type == token.COLON
        ):
            self.depth -= 1
            self._lambda_argument_depths.pop()
            return True

        return False

    def get_open_lsqb(self) -> Optional[Leaf]:
        """Return the most recent opening square bracket (if any).

        That is the `[` opened at `depth - 1`, i.e. the innermost open bracket
        when that bracket is a square one; otherwise None.
        """
        return self.bracket_match.get((self.depth - 1, token.RSQB))
215
216
def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
    """Return the priority of `leaf` as a delimiter followed by a line break.

    Only commas qualify and they get COMMA_PRIORITY; any other leaf yields 0,
    meaning "not a split-after delimiter". Higher numbers are higher priority.
    `previous` is unused here (presumably kept so both `is_split_*_delimiter`
    functions share one signature).
    """
    return COMMA_PRIORITY if leaf.type == token.COMMA else 0
229
230
def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
    """Return the priority of `leaf` as a delimiter preceded by a line break.

    Covers attribute dots after a closing bracket, binary math operators,
    comparison operators and keywords, implicitly concatenated strings,
    comprehension `for`/`if`, ternary `if`/`else` and boolean operators.
    `previous` is the leaf just before `leaf` on the line, if any. Returns 0
    when `leaf` is not a split-before delimiter; higher numbers are higher
    priority.
    """
    if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
        # `*` / `**` used for (un)packing share token types with the math
        # operators, but they are not binary operators here: never a delimiter.
        return 0

    parent = leaf.parent
    kind = leaf.type

    # --- Token-type based delimiters -------------------------------------
    if (
        kind == token.DOT
        and parent
        and parent.type not in {syms.import_from, syms.dotted_name}
        and (previous is None or previous.type in CLOSING_BRACKETS)
    ):
        return DOT_PRIORITY

    if (
        kind in MATH_OPERATORS
        and parent
        and parent.type not in {syms.factor, syms.star_expr}
    ):
        return MATH_PRIORITIES[kind]

    if kind in COMPARATORS:
        return COMPARATOR_PRIORITY

    if kind == token.STRING and previous is not None and previous.type == token.STRING:
        return STRING_PRIORITY

    # --- Keyword based delimiters ----------------------------------------
    if kind not in {token.NAME, token.ASYNC}:
        return 0

    value = leaf.value
    # Value of the preceding leaf when it is a NAME, used to detect the
    # two-word operators `not in` and `is not`.
    preceding_name = (
        previous.value
        if previous is not None and previous.type == token.NAME
        else None
    )

    is_comprehension_for = (
        value == "for"
        and parent
        and parent.type in {syms.comp_for, syms.old_comp_for}
    )
    if is_comprehension_for or kind == token.ASYNC:
        sibling = leaf.prev_sibling
        # In `async for`, the `async` leaf carries the priority instead.
        if not (isinstance(sibling, Leaf) and sibling.value == "async"):
            return COMPREHENSION_PRIORITY

    if value == "if" and parent and parent.type in {syms.comp_if, syms.old_comp_if}:
        return COMPREHENSION_PRIORITY

    if value in {"if", "else"} and parent and parent.type == syms.test:
        return TERNARY_PRIORITY

    if value == "is":
        return COMPARATOR_PRIORITY

    if (
        value == "in"
        and parent
        and parent.type in {syms.comp_op, syms.comparison}
        and preceding_name != "not"
    ):
        # In `not in`, the split goes before `not`, not before `in`.
        return COMPARATOR_PRIORITY

    if (
        value == "not"
        and parent
        and parent.type == syms.comp_op
        and preceding_name != "is"
    ):
        # In `is not`, the split goes before `is`, not before `not`.
        return COMPARATOR_PRIORITY

    if value in LOGIC_OPERATORS and parent:
        return LOGIC_PRIORITY

    return 0
308
309     if (
310         leaf.value == "not"
311         and leaf.parent
312         and leaf.parent.type == syms.comp_op
313         and not (
314             previous is not None
315             and previous.type == token.NAME
316             and previous.value == "is"
317         )
318     ):
319         return COMPARATOR_PRIORITY
320
321     if leaf.value in LOGIC_OPERATORS and leaf.parent:
322         return LOGIC_PRIORITY
323
324     return 0
325
326
327 def max_delimiter_priority_in_atom(node: LN) -> Priority:
328     """Return maximum delimiter priority inside `node`.
329
330     This is specific to atoms with contents contained in a pair of parentheses.
331     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
332     """
333     if node.type != syms.atom:
334         return 0
335
336     first = node.children[0]
337     last = node.children[-1]
338     if not (first.type == token.LPAR and last.type == token.RPAR):
339         return 0
340
341     bt = BracketTracker()
342     for c in node.children[1:-1]:
343         if isinstance(c, Leaf):
344             bt.mark(c)
345         else:
346             for leaf in c.leaves():
347                 bt.mark(leaf)
348     try:
349         return bt.max_delimiter_priority()
350
351     except ValueError:
352         return 0
353
354
355 def get_leaves_inside_matching_brackets(leaves: Sequence[Leaf]) -> Set[LeafID]:
356     """Return leaves that are inside matching brackets.
357
358     The input `leaves` can have non-matching brackets at the head or tail parts.
359     Matching brackets are included.
360     """
361     try:
362         # Start with the first opening bracket and ignore closing brackets before.
363         start_index = next(
364             i for i, l in enumerate(leaves) if l.type in OPENING_BRACKETS
365         )
366     except StopIteration:
367         return set()
368     bracket_stack = []
369     ids = set()
370     for i in range(start_index, len(leaves)):
371         leaf = leaves[i]
372         if leaf.type in OPENING_BRACKETS:
373             bracket_stack.append((BRACKET[leaf.type], i))
374         if leaf.type in CLOSING_BRACKETS:
375             if bracket_stack and leaf.type == bracket_stack[-1][0]:
376                 _, start = bracket_stack.pop()
377                 for j in range(start, i + 1):
378                     ids.add(id(leaves[j]))
379             else:
380                 break
381     return ids