]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Remove standalone comment hacks
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial, wraps
7 import keyword
8 import os
9 from pathlib import Path
10 import tokenize
11 import sys
12 from typing import (
13     Callable,
14     Dict,
15     Generic,
16     Iterable,
17     Iterator,
18     List,
19     Optional,
20     Set,
21     Tuple,
22     Type,
23     TypeVar,
24     Union,
25 )
26
27 from attr import dataclass, Factory
28 import click
29
30 # lib2to3 fork
31 from blib2to3.pytree import Node, Leaf, type_repr
32 from blib2to3 import pygram, pytree
33 from blib2to3.pgen2 import driver, token
34 from blib2to3.pgen2.parse import ParseError
35
36 __version__ = "18.3a4"
37 DEFAULT_LINE_LENGTH = 88
38 # types
39 syms = pygram.python_symbols
40 FileContent = str
41 Encoding = str
42 Depth = int
43 NodeType = int
44 LeafID = int
45 Priority = int
46 Index = int
47 LN = Union[Leaf, Node]
48 SplitFunc = Callable[['Line', bool], Iterator['Line']]
49 out = partial(click.secho, bold=True, err=True)
50 err = partial(click.secho, fg='red', err=True)
51
52
53 class NothingChanged(UserWarning):
54     """Raised by :func:`format_file` when reformatted code is the same as source."""
55
56
57 class CannotSplit(Exception):
58     """A readable split that fits the allotted line length is impossible.
59
60     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
61     :func:`delimiter_split`.
62     """
63
64
65 class FormatError(Exception):
66     """Base exception for `# fmt: on` and `# fmt: off` handling.
67
68     It holds the number of bytes of the prefix consumed before the format
69     control comment appeared.
70     """
71
72     def __init__(self, consumed: int) -> None:
73         super().__init__(consumed)
74         self.consumed = consumed
75
76     def trim_prefix(self, leaf: Leaf) -> None:
77         leaf.prefix = leaf.prefix[self.consumed:]
78
79     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
80         """Returns a new Leaf from the consumed part of the prefix."""
81         unformatted_prefix = leaf.prefix[:self.consumed]
82         return Leaf(token.NEWLINE, unformatted_prefix)
83
84
85 class FormatOn(FormatError):
86     """Found a comment like `# fmt: on` in the file."""
87
88
89 class FormatOff(FormatError):
90     """Found a comment like `# fmt: off` in the file."""
91
92
93 @click.command()
94 @click.option(
95     '-l',
96     '--line-length',
97     type=int,
98     default=DEFAULT_LINE_LENGTH,
99     help='How many character per line to allow.',
100     show_default=True,
101 )
102 @click.option(
103     '--check',
104     is_flag=True,
105     help=(
106         "Don't write back the files, just return the status.  Return code 0 "
107         "means nothing would change.  Return code 1 means some files would be "
108         "reformatted.  Return code 123 means there was an internal error."
109     ),
110 )
111 @click.option(
112     '--fast/--safe',
113     is_flag=True,
114     help='If --fast given, skip temporary sanity checks. [default: --safe]',
115 )
116 @click.version_option(version=__version__)
117 @click.argument(
118     'src',
119     nargs=-1,
120     type=click.Path(
121         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
122     ),
123 )
124 @click.pass_context
125 def main(
126     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
127 ) -> None:
128     """The uncompromising code formatter."""
129     sources: List[Path] = []
130     for s in src:
131         p = Path(s)
132         if p.is_dir():
133             sources.extend(gen_python_files_in_dir(p))
134         elif p.is_file():
135             # if a file was explicitly given, we don't care about its extension
136             sources.append(p)
137         elif s == '-':
138             sources.append(Path('-'))
139         else:
140             err(f'invalid path: {s}')
141     if len(sources) == 0:
142         ctx.exit(0)
143     elif len(sources) == 1:
144         p = sources[0]
145         report = Report(check=check)
146         try:
147             if not p.is_file() and str(p) == '-':
148                 changed = format_stdin_to_stdout(
149                     line_length=line_length, fast=fast, write_back=not check
150                 )
151             else:
152                 changed = format_file_in_place(
153                     p, line_length=line_length, fast=fast, write_back=not check
154                 )
155             report.done(p, changed)
156         except Exception as exc:
157             report.failed(p, str(exc))
158         ctx.exit(report.return_code)
159     else:
160         loop = asyncio.get_event_loop()
161         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
162         return_code = 1
163         try:
164             return_code = loop.run_until_complete(
165                 schedule_formatting(
166                     sources, line_length, not check, fast, loop, executor
167                 )
168             )
169         finally:
170             loop.close()
171             ctx.exit(return_code)
172
173
174 async def schedule_formatting(
175     sources: List[Path],
176     line_length: int,
177     write_back: bool,
178     fast: bool,
179     loop: BaseEventLoop,
180     executor: Executor,
181 ) -> int:
182     """Run formatting of `sources` in parallel using the provided `executor`.
183
184     (Use ProcessPoolExecutors for actual parallelism.)
185
186     `line_length`, `write_back`, and `fast` options are passed to
187     :func:`format_file_in_place`.
188     """
189     tasks = {
190         src: loop.run_in_executor(
191             executor, format_file_in_place, src, line_length, fast, write_back
192         )
193         for src in sources
194     }
195     await asyncio.wait(tasks.values())
196     cancelled = []
197     report = Report(check=not write_back)
198     for src, task in tasks.items():
199         if not task.done():
200             report.failed(src, 'timed out, cancelling')
201             task.cancel()
202             cancelled.append(task)
203         elif task.exception():
204             report.failed(src, str(task.exception()))
205         else:
206             report.done(src, task.result())
207     if cancelled:
208         await asyncio.wait(cancelled, timeout=2)
209     out('All done! ✨ 🍰 ✨')
210     click.echo(str(report))
211     return report.return_code
212
213
214 def format_file_in_place(
215     src: Path, line_length: int, fast: bool, write_back: bool = False
216 ) -> bool:
217     """Format file under `src` path. Return True if changed.
218
219     If `write_back` is True, write reformatted code back to stdout.
220     `line_length` and `fast` options are passed to :func:`format_file_contents`.
221     """
222     with tokenize.open(src) as src_buffer:
223         src_contents = src_buffer.read()
224     try:
225         contents = format_file_contents(
226             src_contents, line_length=line_length, fast=fast
227         )
228     except NothingChanged:
229         return False
230
231     if write_back:
232         with open(src, "w", encoding=src_buffer.encoding) as f:
233             f.write(contents)
234     return True
235
236
237 def format_stdin_to_stdout(
238     line_length: int, fast: bool, write_back: bool = False
239 ) -> bool:
240     """Format file on stdin. Return True if changed.
241
242     If `write_back` is True, write reformatted code back to stdout.
243     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
244     """
245     contents = sys.stdin.read()
246     try:
247         contents = format_file_contents(contents, line_length=line_length, fast=fast)
248         return True
249
250     except NothingChanged:
251         return False
252
253     finally:
254         if write_back:
255             sys.stdout.write(contents)
256
257
258 def format_file_contents(
259     src_contents: str, line_length: int, fast: bool
260 ) -> FileContent:
261     """Reformat contents a file and return new contents.
262
263     If `fast` is False, additionally confirm that the reformatted code is
264     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
265     `line_length` is passed to :func:`format_str`.
266     """
267     if src_contents.strip() == '':
268         raise NothingChanged
269
270     dst_contents = format_str(src_contents, line_length=line_length)
271     if src_contents == dst_contents:
272         raise NothingChanged
273
274     if not fast:
275         assert_equivalent(src_contents, dst_contents)
276         assert_stable(src_contents, dst_contents, line_length=line_length)
277     return dst_contents
278
279
280 def format_str(src_contents: str, line_length: int) -> FileContent:
281     """Reformat a string and return new contents.
282
283     `line_length` determines how many characters per line are allowed.
284     """
285     src_node = lib2to3_parse(src_contents)
286     dst_contents = ""
287     lines = LineGenerator()
288     elt = EmptyLineTracker()
289     py36 = is_python36(src_node)
290     empty_line = Line()
291     after = 0
292     for current_line in lines.visit(src_node):
293         for _ in range(after):
294             dst_contents += str(empty_line)
295         before, after = elt.maybe_empty_lines(current_line)
296         for _ in range(before):
297             dst_contents += str(empty_line)
298         for line in split_line(current_line, line_length=line_length, py36=py36):
299             dst_contents += str(line)
300     return dst_contents
301
302
303 GRAMMARS = [
304     pygram.python_grammar_no_print_statement_no_exec_statement,
305     pygram.python_grammar_no_print_statement,
306     pygram.python_grammar_no_exec_statement,
307     pygram.python_grammar,
308 ]
309
310
311 def lib2to3_parse(src_txt: str) -> Node:
312     """Given a string with source, return the lib2to3 Node."""
313     grammar = pygram.python_grammar_no_print_statement
314     if src_txt[-1] != '\n':
315         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
316         src_txt += nl
317     for grammar in GRAMMARS:
318         drv = driver.Driver(grammar, pytree.convert)
319         try:
320             result = drv.parse_string(src_txt, True)
321             break
322
323         except ParseError as pe:
324             lineno, column = pe.context[1]
325             lines = src_txt.splitlines()
326             try:
327                 faulty_line = lines[lineno - 1]
328             except IndexError:
329                 faulty_line = "<line number missing in source>"
330             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
331     else:
332         raise exc from None
333
334     if isinstance(result, Leaf):
335         result = Node(syms.file_input, [result])
336     return result
337
338
339 def lib2to3_unparse(node: Node) -> str:
340     """Given a lib2to3 node, return its string representation."""
341     code = str(node)
342     return code
343
344
345 T = TypeVar('T')
346
347
348 class Visitor(Generic[T]):
349     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
350
351     def visit(self, node: LN) -> Iterator[T]:
352         """Main method to visit `node` and its children.
353
354         It tries to find a `visit_*()` method for the given `node.type`, like
355         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
356         If no dedicated `visit_*()` method is found, chooses `visit_default()`
357         instead.
358
359         Then yields objects of type `T` from the selected visitor.
360         """
361         if node.type < 256:
362             name = token.tok_name[node.type]
363         else:
364             name = type_repr(node.type)
365         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
366
367     def visit_default(self, node: LN) -> Iterator[T]:
368         """Default `visit_*()` implementation. Recurses to children of `node`."""
369         if isinstance(node, Node):
370             for child in node.children:
371                 yield from self.visit(child)
372
373
374 @dataclass
375 class DebugVisitor(Visitor[T]):
376     tree_depth: int = 0
377
378     def visit_default(self, node: LN) -> Iterator[T]:
379         indent = ' ' * (2 * self.tree_depth)
380         if isinstance(node, Node):
381             _type = type_repr(node.type)
382             out(f'{indent}{_type}', fg='yellow')
383             self.tree_depth += 1
384             for child in node.children:
385                 yield from self.visit(child)
386
387             self.tree_depth -= 1
388             out(f'{indent}/{_type}', fg='yellow', bold=False)
389         else:
390             _type = token.tok_name.get(node.type, str(node.type))
391             out(f'{indent}{_type}', fg='blue', nl=False)
392             if node.prefix:
393                 # We don't have to handle prefixes for `Node` objects since
394                 # that delegates to the first child anyway.
395                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
396             out(f' {node.value!r}', fg='blue', bold=False)
397
398     @classmethod
399     def show(cls, code: str) -> None:
400         """Pretty-print the lib2to3 AST of a given string of `code`.
401
402         Convenience method for debugging.
403         """
404         v: DebugVisitor[None] = DebugVisitor()
405         list(v.visit(lib2to3_parse(code)))
406
407
408 KEYWORDS = set(keyword.kwlist)
409 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
410 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
411 STATEMENT = {
412     syms.if_stmt,
413     syms.while_stmt,
414     syms.for_stmt,
415     syms.try_stmt,
416     syms.except_clause,
417     syms.with_stmt,
418     syms.funcdef,
419     syms.classdef,
420 }
421 STANDALONE_COMMENT = 153
422 LOGIC_OPERATORS = {'and', 'or'}
423 COMPARATORS = {
424     token.LESS,
425     token.GREATER,
426     token.EQEQUAL,
427     token.NOTEQUAL,
428     token.LESSEQUAL,
429     token.GREATEREQUAL,
430 }
431 MATH_OPERATORS = {
432     token.PLUS,
433     token.MINUS,
434     token.STAR,
435     token.SLASH,
436     token.VBAR,
437     token.AMPER,
438     token.PERCENT,
439     token.CIRCUMFLEX,
440     token.TILDE,
441     token.LEFTSHIFT,
442     token.RIGHTSHIFT,
443     token.DOUBLESTAR,
444     token.DOUBLESLASH,
445 }
446 COMPREHENSION_PRIORITY = 20
447 COMMA_PRIORITY = 10
448 LOGIC_PRIORITY = 5
449 STRING_PRIORITY = 4
450 COMPARATOR_PRIORITY = 3
451 MATH_PRIORITY = 1
452
453
454 @dataclass
455 class BracketTracker:
456     """Keeps track of brackets on a line."""
457
458     depth: int = 0
459     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
460     delimiters: Dict[LeafID, Priority] = Factory(dict)
461     previous: Optional[Leaf] = None
462
463     def mark(self, leaf: Leaf) -> None:
464         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
465
466         All leaves receive an int `bracket_depth` field that stores how deep
467         within brackets a given leaf is. 0 means there are no enclosing brackets
468         that started on this line.
469
470         If a leaf is itself a closing bracket, it receives an `opening_bracket`
471         field that it forms a pair with. This is a one-directional link to
472         avoid reference cycles.
473
474         If a leaf is a delimiter (a token on which Black can split the line if
475         needed) and it's on depth 0, its `id()` is stored in the tracker's
476         `delimiters` field.
477         """
478         if leaf.type == token.COMMENT:
479             return
480
481         if leaf.type in CLOSING_BRACKETS:
482             self.depth -= 1
483             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
484             leaf.opening_bracket = opening_bracket
485         leaf.bracket_depth = self.depth
486         if self.depth == 0:
487             delim = is_delimiter(leaf)
488             if delim:
489                 self.delimiters[id(leaf)] = delim
490             elif self.previous is not None:
491                 if leaf.type == token.STRING and self.previous.type == token.STRING:
492                     self.delimiters[id(self.previous)] = STRING_PRIORITY
493                 elif (
494                     leaf.type == token.NAME
495                     and leaf.value == 'for'
496                     and leaf.parent
497                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
498                 ):
499                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
500                 elif (
501                     leaf.type == token.NAME
502                     and leaf.value == 'if'
503                     and leaf.parent
504                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
505                 ):
506                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
507                 elif (
508                     leaf.type == token.NAME
509                     and leaf.value in LOGIC_OPERATORS
510                     and leaf.parent
511                 ):
512                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
513         if leaf.type in OPENING_BRACKETS:
514             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
515             self.depth += 1
516         self.previous = leaf
517
518     def any_open_brackets(self) -> bool:
519         """Return True if there is an yet unmatched open bracket on the line."""
520         return bool(self.bracket_match)
521
522     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
523         """Return the highest priority of a delimiter found on the line.
524
525         Values are consistent with what `is_delimiter()` returns.
526         """
527         return max(v for k, v in self.delimiters.items() if k not in exclude)
528
529
530 @dataclass
531 class Line:
532     """Holds leaves and comments. Can be printed with `str(line)`."""
533
534     depth: int = 0
535     leaves: List[Leaf] = Factory(list)
536     comments: List[Tuple[Index, Leaf]] = Factory(list)
537     bracket_tracker: BracketTracker = Factory(BracketTracker)
538     inside_brackets: bool = False
539     has_for: bool = False
540     _for_loop_variable: bool = False
541
542     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
543         """Add a new `leaf` to the end of the line.
544
545         Unless `preformatted` is True, the `leaf` will receive a new consistent
546         whitespace prefix and metadata applied by :class:`BracketTracker`.
547         Trailing commas are maybe removed, unpacked for loop variables are
548         demoted from being delimiters.
549
550         Inline comments are put aside.
551         """
552         has_value = leaf.value.strip()
553         if not has_value:
554             return
555
556         if self.leaves and not preformatted:
557             # Note: at this point leaf.prefix should be empty except for
558             # imports, for which we only preserve newlines.
559             leaf.prefix += whitespace(leaf)
560         if self.inside_brackets or not preformatted:
561             self.maybe_decrement_after_for_loop_variable(leaf)
562             self.bracket_tracker.mark(leaf)
563             self.maybe_remove_trailing_comma(leaf)
564             self.maybe_increment_for_loop_variable(leaf)
565
566         if not self.append_comment(leaf):
567             self.leaves.append(leaf)
568
569     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
570         """Like :func:`append()` but disallow invalid standalone comment structure.
571
572         Raises ValueError when any `leaf` is appended after a standalone comment
573         or when a standalone comment is not the first leaf on the line.
574         """
575         if self.bracket_tracker.depth == 0:
576             if self.is_comment:
577                 raise ValueError("cannot append to standalone comments")
578
579             if self.leaves and leaf.type == STANDALONE_COMMENT:
580                 raise ValueError(
581                     "cannot append standalone comments to a populated line"
582                 )
583
584         self.append(leaf, preformatted=preformatted)
585
586     @property
587     def is_comment(self) -> bool:
588         """Is this line a standalone comment?"""
589         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
590
591     @property
592     def is_decorator(self) -> bool:
593         """Is this line a decorator?"""
594         return bool(self) and self.leaves[0].type == token.AT
595
596     @property
597     def is_import(self) -> bool:
598         """Is this an import line?"""
599         return bool(self) and is_import(self.leaves[0])
600
601     @property
602     def is_class(self) -> bool:
603         """Is this line a class definition?"""
604         return (
605             bool(self)
606             and self.leaves[0].type == token.NAME
607             and self.leaves[0].value == 'class'
608         )
609
610     @property
611     def is_def(self) -> bool:
612         """Is this a function definition? (Also returns True for async defs.)"""
613         try:
614             first_leaf = self.leaves[0]
615         except IndexError:
616             return False
617
618         try:
619             second_leaf: Optional[Leaf] = self.leaves[1]
620         except IndexError:
621             second_leaf = None
622         return (
623             (first_leaf.type == token.NAME and first_leaf.value == 'def')
624             or (
625                 first_leaf.type == token.ASYNC
626                 and second_leaf is not None
627                 and second_leaf.type == token.NAME
628                 and second_leaf.value == 'def'
629             )
630         )
631
632     @property
633     def is_flow_control(self) -> bool:
634         """Is this line a flow control statement?
635
636         Those are `return`, `raise`, `break`, and `continue`.
637         """
638         return (
639             bool(self)
640             and self.leaves[0].type == token.NAME
641             and self.leaves[0].value in FLOW_CONTROL
642         )
643
644     @property
645     def is_yield(self) -> bool:
646         """Is this line a yield statement?"""
647         return (
648             bool(self)
649             and self.leaves[0].type == token.NAME
650             and self.leaves[0].value == 'yield'
651         )
652
653     @property
654     def contains_standalone_comments(self) -> bool:
655         """If so, needs to be split before emitting."""
656         for leaf in self.leaves:
657             if leaf.type == STANDALONE_COMMENT:
658                 return True
659
660         return False
661
662     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
663         """Remove trailing comma if there is one and it's safe."""
664         if not (
665             self.leaves
666             and self.leaves[-1].type == token.COMMA
667             and closing.type in CLOSING_BRACKETS
668         ):
669             return False
670
671         if closing.type == token.RBRACE:
672             self.remove_trailing_comma()
673             return True
674
675         if closing.type == token.RSQB:
676             comma = self.leaves[-1]
677             if comma.parent and comma.parent.type == syms.listmaker:
678                 self.remove_trailing_comma()
679                 return True
680
681         # For parens let's check if it's safe to remove the comma.  If the
682         # trailing one is the only one, we might mistakenly change a tuple
683         # into a different type by removing the comma.
684         depth = closing.bracket_depth + 1
685         commas = 0
686         opening = closing.opening_bracket
687         for _opening_index, leaf in enumerate(self.leaves):
688             if leaf is opening:
689                 break
690
691         else:
692             return False
693
694         for leaf in self.leaves[_opening_index + 1:]:
695             if leaf is closing:
696                 break
697
698             bracket_depth = leaf.bracket_depth
699             if bracket_depth == depth and leaf.type == token.COMMA:
700                 commas += 1
701                 if leaf.parent and leaf.parent.type == syms.arglist:
702                     commas += 1
703                     break
704
705         if commas > 1:
706             self.remove_trailing_comma()
707             return True
708
709         return False
710
711     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
712         """In a for loop, or comprehension, the variables are often unpacks.
713
714         To avoid splitting on the comma in this situation, increase the depth of
715         tokens between `for` and `in`.
716         """
717         if leaf.type == token.NAME and leaf.value == 'for':
718             self.has_for = True
719             self.bracket_tracker.depth += 1
720             self._for_loop_variable = True
721             return True
722
723         return False
724
725     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
726         """See `maybe_increment_for_loop_variable` above for explanation."""
727         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
728             self.bracket_tracker.depth -= 1
729             self._for_loop_variable = False
730             return True
731
732         return False
733
734     def append_comment(self, comment: Leaf) -> bool:
735         """Add an inline or standalone comment to the line."""
736         if (
737             comment.type == STANDALONE_COMMENT
738             and self.bracket_tracker.any_open_brackets()
739         ):
740             comment.prefix = ''
741             return False
742
743         if comment.type != token.COMMENT:
744             return False
745
746         after = len(self.leaves) - 1
747         if after == -1:
748             comment.type = STANDALONE_COMMENT
749             comment.prefix = ''
750             return False
751
752         else:
753             self.comments.append((after, comment))
754             return True
755
756     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
757         """Generate comments that should appear directly after `leaf`."""
758         for _leaf_index, _leaf in enumerate(self.leaves):
759             if leaf is _leaf:
760                 break
761
762         else:
763             return
764
765         for index, comment_after in self.comments:
766             if _leaf_index == index:
767                 yield comment_after
768
769     def remove_trailing_comma(self) -> None:
770         """Remove the trailing comma and moves the comments attached to it."""
771         comma_index = len(self.leaves) - 1
772         for i in range(len(self.comments)):
773             comment_index, comment = self.comments[i]
774             if comment_index == comma_index:
775                 self.comments[i] = (comma_index - 1, comment)
776         self.leaves.pop()
777
778     def __str__(self) -> str:
779         """Render the line."""
780         if not self:
781             return '\n'
782
783         indent = '    ' * self.depth
784         leaves = iter(self.leaves)
785         first = next(leaves)
786         res = f'{first.prefix}{indent}{first.value}'
787         for leaf in leaves:
788             res += str(leaf)
789         for _, comment in self.comments:
790             res += str(comment)
791         return res + '\n'
792
793     def __bool__(self) -> bool:
794         """Return True if the line has leaves or comments."""
795         return bool(self.leaves or self.comments)
796
797
798 class UnformattedLines(Line):
799     """Just like :class:`Line` but stores lines which aren't reformatted."""
800
801     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
802         """Just add a new `leaf` to the end of the lines.
803
804         The `preformatted` argument is ignored.
805
806         Keeps track of indentation `depth`, which is useful when the user
807         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
808         """
809         try:
810             list(generate_comments(leaf))
811         except FormatOn as f_on:
812             self.leaves.append(f_on.leaf_from_consumed(leaf))
813             raise
814
815         self.leaves.append(leaf)
816         if leaf.type == token.INDENT:
817             self.depth += 1
818         elif leaf.type == token.DEDENT:
819             self.depth -= 1
820
821     def __str__(self) -> str:
822         """Render unformatted lines from leaves which were added with `append()`.
823
824         `depth` is not used for indentation in this case.
825         """
826         if not self:
827             return '\n'
828
829         res = ''
830         for leaf in self.leaves:
831             res += str(leaf)
832         return res
833
834     def append_comment(self, comment: Leaf) -> bool:
835         """Not implemented in this class. Raises `NotImplementedError`."""
836         raise NotImplementedError("Unformatted lines don't store comments separately.")
837
838     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
839         """Does nothing and returns False."""
840         return False
841
842     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
843         """Does nothing and returns False."""
844         return False
845
846
847 @dataclass
848 class EmptyLineTracker:
849     """Provides a stateful method that returns the number of potential extra
850     empty lines needed before and after the currently processed line.
851
852     Note: this tracker works on lines that haven't been split yet.  It assumes
853     the prefix of the first leaf consists of optional newlines.  Those newlines
854     are consumed by `maybe_empty_lines()` and included in the computation.
855     """
856     previous_line: Optional[Line] = None
857     previous_after: int = 0
858     previous_defs: List[int] = Factory(list)
859
860     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
861         """Return the number of extra empty lines before and after the `current_line`.
862
863         This is for separating `def`, `async def` and `class` with extra empty
864         lines (two on module-level), as well as providing an extra empty line
865         after flow control keywords to make them more prominent.
866         """
867         if isinstance(current_line, UnformattedLines):
868             return 0, 0
869
870         before, after = self._maybe_empty_lines(current_line)
871         before -= self.previous_after
872         self.previous_after = after
873         self.previous_line = current_line
874         return before, after
875
876     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
877         max_allowed = 1
878         if current_line.depth == 0:
879             max_allowed = 2
880         if current_line.leaves:
881             # Consume the first leaf's extra newlines.
882             first_leaf = current_line.leaves[0]
883             before = first_leaf.prefix.count('\n')
884             before = min(before, max_allowed)
885             first_leaf.prefix = ''
886         else:
887             before = 0
888         depth = current_line.depth
889         while self.previous_defs and self.previous_defs[-1] >= depth:
890             self.previous_defs.pop()
891             before = 1 if depth else 2
892         is_decorator = current_line.is_decorator
893         if is_decorator or current_line.is_def or current_line.is_class:
894             if not is_decorator:
895                 self.previous_defs.append(depth)
896             if self.previous_line is None:
897                 # Don't insert empty lines before the first line in the file.
898                 return 0, 0
899
900             if self.previous_line and self.previous_line.is_decorator:
901                 # Don't insert empty lines between decorators.
902                 return 0, 0
903
904             newlines = 2
905             if current_line.depth:
906                 newlines -= 1
907             return newlines, 0
908
909         if current_line.is_flow_control:
910             return before, 1
911
912         if (
913             self.previous_line
914             and self.previous_line.is_import
915             and not current_line.is_import
916             and depth == self.previous_line.depth
917         ):
918             return (before or 1), 0
919
920         if (
921             self.previous_line
922             and self.previous_line.is_yield
923             and (not current_line.is_yield or depth != self.previous_line.depth)
924         ):
925             return (before or 1), 0
926
927         return before, 0
928
929
930 @dataclass
931 class LineGenerator(Visitor[Line]):
932     """Generates reformatted Line objects.  Empty lines are not emitted.
933
934     Note: destroys the tree it's visiting by mutating prefixes of its leaves
935     in ways that will no longer stringify to valid Python code on the tree.
936     """
937     current_line: Line = Factory(Line)
938
939     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
940         """Generate a line.
941
942         If the line is empty, only emit if it makes sense.
943         If the line is too long, split it first and then generate.
944
945         If any lines were generated, set up a new current_line.
946         """
947         if not self.current_line:
948             if self.current_line.__class__ == type:
949                 self.current_line.depth += indent
950             else:
951                 self.current_line = type(depth=self.current_line.depth + indent)
952             return  # Line is empty, don't emit. Creating a new one unnecessary.
953
954         complete_line = self.current_line
955         self.current_line = type(depth=complete_line.depth + indent)
956         yield complete_line
957
958     def visit(self, node: LN) -> Iterator[Line]:
959         """Main method to visit `node` and its children.
960
961         Yields :class:`Line` objects.
962         """
963         if isinstance(self.current_line, UnformattedLines):
964             # File contained `# fmt: off`
965             yield from self.visit_unformatted(node)
966
967         else:
968             yield from super().visit(node)
969
970     def visit_default(self, node: LN) -> Iterator[Line]:
971         """Default `visit_*()` implementation. Recurses to children of `node`."""
972         if isinstance(node, Leaf):
973             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
974             try:
975                 for comment in generate_comments(node):
976                     if any_open_brackets:
977                         # any comment within brackets is subject to splitting
978                         self.current_line.append(comment)
979                     elif comment.type == token.COMMENT:
980                         # regular trailing comment
981                         self.current_line.append(comment)
982                         yield from self.line()
983
984                     else:
985                         # regular standalone comment
986                         yield from self.line()
987
988                         self.current_line.append(comment)
989                         yield from self.line()
990
991             except FormatOff as f_off:
992                 f_off.trim_prefix(node)
993                 yield from self.line(type=UnformattedLines)
994                 yield from self.visit(node)
995
996             except FormatOn as f_on:
997                 # This only happens here if somebody says "fmt: on" multiple
998                 # times in a row.
999                 f_on.trim_prefix(node)
1000                 yield from self.visit_default(node)
1001
1002             else:
1003                 normalize_prefix(node, inside_brackets=any_open_brackets)
1004                 if node.type not in WHITESPACE:
1005                     self.current_line.append(node)
1006         yield from super().visit_default(node)
1007
1008     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1009         """Increase indentation level, maybe yield a line."""
1010         # In blib2to3 INDENT never holds comments.
1011         yield from self.line(+1)
1012         yield from self.visit_default(node)
1013
1014     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1015         """Decrease indentation level, maybe yield a line."""
1016         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1017         yield from self.line(-1)
1018
1019     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
1020         """Visit a statement.
1021
1022         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1023         `def`, `with`, and `class`.
1024
1025         The relevant Python language `keywords` for a given statement will be NAME
1026         leaves within it. This methods puts those on a separate line.
1027         """
1028         for child in node.children:
1029             if child.type == token.NAME and child.value in keywords:  # type: ignore
1030                 yield from self.line()
1031
1032             yield from self.visit(child)
1033
1034     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1035         """Visit a statement without nested statements."""
1036         is_suite_like = node.parent and node.parent.type in STATEMENT
1037         if is_suite_like:
1038             yield from self.line(+1)
1039             yield from self.visit_default(node)
1040             yield from self.line(-1)
1041
1042         else:
1043             yield from self.line()
1044             yield from self.visit_default(node)
1045
1046     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1047         """Visit `async def`, `async for`, `async with`."""
1048         yield from self.line()
1049
1050         children = iter(node.children)
1051         for child in children:
1052             yield from self.visit(child)
1053
1054             if child.type == token.ASYNC:
1055                 break
1056
1057         internal_stmt = next(children)
1058         for child in internal_stmt.children:
1059             yield from self.visit(child)
1060
1061     def visit_decorators(self, node: Node) -> Iterator[Line]:
1062         """Visit decorators."""
1063         for child in node.children:
1064             yield from self.line()
1065             yield from self.visit(child)
1066
1067     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1068         """Remove a semicolon and put the other statement on a separate line."""
1069         yield from self.line()
1070
1071     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1072         """End of file. Process outstanding comments and end with a newline."""
1073         yield from self.visit_default(leaf)
1074         yield from self.line()
1075
1076     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1077         """Used when file contained a `# fmt: off`."""
1078         if isinstance(node, Node):
1079             for child in node.children:
1080                 yield from self.visit(child)
1081
1082         else:
1083             try:
1084                 self.current_line.append(node)
1085             except FormatOn as f_on:
1086                 f_on.trim_prefix(node)
1087                 yield from self.line()
1088                 yield from self.visit(node)
1089
1090     def __attrs_post_init__(self) -> None:
1091         """You are in a twisty little maze of passages."""
1092         v = self.visit_stmt
1093         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
1094         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
1095         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
1096         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
1097         self.visit_except_clause = partial(v, keywords={'except'})
1098         self.visit_funcdef = partial(v, keywords={'def'})
1099         self.visit_with_stmt = partial(v, keywords={'with'})
1100         self.visit_classdef = partial(v, keywords={'class'})
1101         self.visit_async_funcdef = self.visit_async_stmt
1102         self.visit_decorated = self.visit_decorators
1103
1104
1105 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1106 OPENING_BRACKETS = set(BRACKET.keys())
1107 CLOSING_BRACKETS = set(BRACKET.values())
1108 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1109 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1110
1111
1112 def whitespace(leaf: Leaf) -> str:  # noqa C901
1113     """Return whitespace prefix if needed for the given `leaf`."""
1114     NO = ''
1115     SPACE = ' '
1116     DOUBLESPACE = '  '
1117     t = leaf.type
1118     p = leaf.parent
1119     v = leaf.value
1120     if t in ALWAYS_NO_SPACE:
1121         return NO
1122
1123     if t == token.COMMENT:
1124         return DOUBLESPACE
1125
1126     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1127     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1128         return NO
1129
1130     prev = leaf.prev_sibling
1131     if not prev:
1132         prevp = preceding_leaf(p)
1133         if not prevp or prevp.type in OPENING_BRACKETS:
1134             return NO
1135
1136         if t == token.COLON:
1137             return SPACE if prevp.type == token.COMMA else NO
1138
1139         if prevp.type == token.EQUAL:
1140             if prevp.parent:
1141                 if prevp.parent.type in {
1142                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1143                 }:
1144                     return NO
1145
1146                 elif prevp.parent.type == syms.typedargslist:
1147                     # A bit hacky: if the equal sign has whitespace, it means we
1148                     # previously found it's a typed argument.  So, we're using
1149                     # that, too.
1150                     return prevp.prefix
1151
1152         elif prevp.type == token.DOUBLESTAR:
1153             if prevp.parent and prevp.parent.type in {
1154                 syms.arglist,
1155                 syms.argument,
1156                 syms.dictsetmaker,
1157                 syms.parameters,
1158                 syms.typedargslist,
1159                 syms.varargslist,
1160             }:
1161                 return NO
1162
1163         elif prevp.type == token.COLON:
1164             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1165                 return NO
1166
1167         elif (
1168             prevp.parent
1169             and prevp.parent.type in {syms.factor, syms.star_expr}
1170             and prevp.type in MATH_OPERATORS
1171         ):
1172             return NO
1173
1174         elif (
1175             prevp.type == token.RIGHTSHIFT
1176             and prevp.parent
1177             and prevp.parent.type == syms.shift_expr
1178             and prevp.prev_sibling
1179             and prevp.prev_sibling.type == token.NAME
1180             and prevp.prev_sibling.value == 'print'  # type: ignore
1181         ):
1182             # Python 2 print chevron
1183             return NO
1184
1185     elif prev.type in OPENING_BRACKETS:
1186         return NO
1187
1188     if p.type in {syms.parameters, syms.arglist}:
1189         # untyped function signatures or calls
1190         if t == token.RPAR:
1191             return NO
1192
1193         if not prev or prev.type != token.COMMA:
1194             return NO
1195
1196     elif p.type == syms.varargslist:
1197         # lambdas
1198         if t == token.RPAR:
1199             return NO
1200
1201         if prev and prev.type != token.COMMA:
1202             return NO
1203
1204     elif p.type == syms.typedargslist:
1205         # typed function signatures
1206         if not prev:
1207             return NO
1208
1209         if t == token.EQUAL:
1210             if prev.type != syms.tname:
1211                 return NO
1212
1213         elif prev.type == token.EQUAL:
1214             # A bit hacky: if the equal sign has whitespace, it means we
1215             # previously found it's a typed argument.  So, we're using that, too.
1216             return prev.prefix
1217
1218         elif prev.type != token.COMMA:
1219             return NO
1220
1221     elif p.type == syms.tname:
1222         # type names
1223         if not prev:
1224             prevp = preceding_leaf(p)
1225             if not prevp or prevp.type != token.COMMA:
1226                 return NO
1227
1228     elif p.type == syms.trailer:
1229         # attributes and calls
1230         if t == token.LPAR or t == token.RPAR:
1231             return NO
1232
1233         if not prev:
1234             if t == token.DOT:
1235                 prevp = preceding_leaf(p)
1236                 if not prevp or prevp.type != token.NUMBER:
1237                     return NO
1238
1239             elif t == token.LSQB:
1240                 return NO
1241
1242         elif prev.type != token.COMMA:
1243             return NO
1244
1245     elif p.type == syms.argument:
1246         # single argument
1247         if t == token.EQUAL:
1248             return NO
1249
1250         if not prev:
1251             prevp = preceding_leaf(p)
1252             if not prevp or prevp.type == token.LPAR:
1253                 return NO
1254
1255         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1256             return NO
1257
1258     elif p.type == syms.decorator:
1259         # decorators
1260         return NO
1261
1262     elif p.type == syms.dotted_name:
1263         if prev:
1264             return NO
1265
1266         prevp = preceding_leaf(p)
1267         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1268             return NO
1269
1270     elif p.type == syms.classdef:
1271         if t == token.LPAR:
1272             return NO
1273
1274         if prev and prev.type == token.LPAR:
1275             return NO
1276
1277     elif p.type == syms.subscript:
1278         # indexing
1279         if not prev:
1280             assert p.parent is not None, "subscripts are always parented"
1281             if p.parent.type == syms.subscriptlist:
1282                 return SPACE
1283
1284             return NO
1285
1286         else:
1287             return NO
1288
1289     elif p.type == syms.atom:
1290         if prev and t == token.DOT:
1291             # dots, but not the first one.
1292             return NO
1293
1294     elif (
1295         p.type == syms.listmaker
1296         or p.type == syms.testlist_gexp
1297         or p.type == syms.subscriptlist
1298     ):
1299         # list interior, including unpacking
1300         if not prev:
1301             return NO
1302
1303     elif p.type == syms.dictsetmaker:
1304         # dict and set interior, including unpacking
1305         if not prev:
1306             return NO
1307
1308         if prev.type == token.DOUBLESTAR:
1309             return NO
1310
1311     elif p.type in {syms.factor, syms.star_expr}:
1312         # unary ops
1313         if not prev:
1314             prevp = preceding_leaf(p)
1315             if not prevp or prevp.type in OPENING_BRACKETS:
1316                 return NO
1317
1318             prevp_parent = prevp.parent
1319             assert prevp_parent is not None
1320             if prevp.type == token.COLON and prevp_parent.type in {
1321                 syms.subscript, syms.sliceop
1322             }:
1323                 return NO
1324
1325             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1326                 return NO
1327
1328         elif t == token.NAME or t == token.NUMBER:
1329             return NO
1330
1331     elif p.type == syms.import_from:
1332         if t == token.DOT:
1333             if prev and prev.type == token.DOT:
1334                 return NO
1335
1336         elif t == token.NAME:
1337             if v == 'import':
1338                 return SPACE
1339
1340             if prev and prev.type == token.DOT:
1341                 return NO
1342
1343     elif p.type == syms.sliceop:
1344         return NO
1345
1346     return SPACE
1347
1348
1349 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1350     """Return the first leaf that precedes `node`, if any."""
1351     while node:
1352         res = node.prev_sibling
1353         if res:
1354             if isinstance(res, Leaf):
1355                 return res
1356
1357             try:
1358                 return list(res.leaves())[-1]
1359
1360             except IndexError:
1361                 return None
1362
1363         node = node.parent
1364     return None
1365
1366
1367 def is_delimiter(leaf: Leaf) -> int:
1368     """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
1369
1370     Higher numbers are higher priority.
1371     """
1372     if leaf.type == token.COMMA:
1373         return COMMA_PRIORITY
1374
1375     if leaf.type in COMPARATORS:
1376         return COMPARATOR_PRIORITY
1377
1378     if (
1379         leaf.type in MATH_OPERATORS
1380         and leaf.parent
1381         and leaf.parent.type not in {syms.factor, syms.star_expr}
1382     ):
1383         return MATH_PRIORITY
1384
1385     return 0
1386
1387
1388 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1389     """Clean the prefix of the `leaf` and generate comments from it, if any.
1390
1391     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1392     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1393     move because it does away with modifying the grammar to include all the
1394     possible places in which comments can be placed.
1395
1396     The sad consequence for us though is that comments don't "belong" anywhere.
1397     This is why this function generates simple parentless Leaf objects for
1398     comments.  We simply don't know what the correct parent should be.
1399
1400     No matter though, we can live without this.  We really only need to
1401     differentiate between inline and standalone comments.  The latter don't
1402     share the line with any code.
1403
1404     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1405     are emitted with a fake STANDALONE_COMMENT token identifier.
1406     """
1407     p = leaf.prefix
1408     if not p:
1409         return
1410
1411     if '#' not in p:
1412         return
1413
1414     consumed = 0
1415     nlines = 0
1416     for index, line in enumerate(p.split('\n')):
1417         consumed += len(line) + 1  # adding the length of the split '\n'
1418         line = line.lstrip()
1419         if not line:
1420             nlines += 1
1421         if not line.startswith('#'):
1422             continue
1423
1424         if index == 0 and leaf.type != token.ENDMARKER:
1425             comment_type = token.COMMENT  # simple trailing comment
1426         else:
1427             comment_type = STANDALONE_COMMENT
1428         comment = make_comment(line)
1429         yield Leaf(comment_type, comment, prefix='\n' * nlines)
1430
1431         if comment in {'# fmt: on', '# yapf: enable'}:
1432             raise FormatOn(consumed)
1433
1434         if comment in {'# fmt: off', '# yapf: disable'}:
1435             raise FormatOff(consumed)
1436
1437         nlines = 0
1438
1439
1440 def make_comment(content: str) -> str:
1441     """Return a consistently formatted comment from the given `content` string.
1442
1443     All comments (except for "##", "#!", "#:") should have a single space between
1444     the hash sign and the content.
1445
1446     If `content` didn't start with a hash sign, one is provided.
1447     """
1448     content = content.rstrip()
1449     if not content:
1450         return '#'
1451
1452     if content[0] == '#':
1453         content = content[1:]
1454     if content and content[0] not in ' !:#':
1455         content = ' ' + content
1456     return '#' + content
1457
1458
1459 def split_line(
1460     line: Line, line_length: int, inner: bool = False, py36: bool = False
1461 ) -> Iterator[Line]:
1462     """Split a `line` into potentially many lines.
1463
1464     They should fit in the allotted `line_length` but might not be able to.
1465     `inner` signifies that there were a pair of brackets somewhere around the
1466     current `line`, possibly transitively. This means we can fallback to splitting
1467     by delimiters if the LHS/RHS don't yield any results.
1468
1469     If `py36` is True, splitting may generate syntax that is only compatible
1470     with Python 3.6 and later.
1471     """
1472     if isinstance(line, UnformattedLines) or line.is_comment:
1473         yield line
1474         return
1475
1476     line_str = str(line).strip('\n')
1477     if (
1478         len(line_str) <= line_length
1479         and '\n' not in line_str  # multiline strings
1480         and not line.contains_standalone_comments
1481     ):
1482         yield line
1483         return
1484
1485     split_funcs: List[SplitFunc]
1486     if line.is_def:
1487         split_funcs = [left_hand_split]
1488     elif line.inside_brackets:
1489         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1490     else:
1491         split_funcs = [right_hand_split]
1492     for split_func in split_funcs:
1493         # We are accumulating lines in `result` because we might want to abort
1494         # mission and return the original line in the end, or attempt a different
1495         # split altogether.
1496         result: List[Line] = []
1497         try:
1498             for l in split_func(line, py36):
1499                 if str(l).strip('\n') == line_str:
1500                     raise CannotSplit("Split function returned an unchanged result")
1501
1502                 result.extend(
1503                     split_line(l, line_length=line_length, inner=True, py36=py36)
1504                 )
1505         except CannotSplit as cs:
1506             continue
1507
1508         else:
1509             yield from result
1510             break
1511
1512     else:
1513         yield line
1514
1515
1516 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1517     """Split line into many lines, starting with the first matching bracket pair.
1518
1519     Note: this usually looks weird, only use this for function definitions.
1520     Prefer RHS otherwise.
1521     """
1522     head = Line(depth=line.depth)
1523     body = Line(depth=line.depth + 1, inside_brackets=True)
1524     tail = Line(depth=line.depth)
1525     tail_leaves: List[Leaf] = []
1526     body_leaves: List[Leaf] = []
1527     head_leaves: List[Leaf] = []
1528     current_leaves = head_leaves
1529     matching_bracket = None
1530     for leaf in line.leaves:
1531         if (
1532             current_leaves is body_leaves
1533             and leaf.type in CLOSING_BRACKETS
1534             and leaf.opening_bracket is matching_bracket
1535         ):
1536             current_leaves = tail_leaves if body_leaves else head_leaves
1537         current_leaves.append(leaf)
1538         if current_leaves is head_leaves:
1539             if leaf.type in OPENING_BRACKETS:
1540                 matching_bracket = leaf
1541                 current_leaves = body_leaves
1542     # Since body is a new indent level, remove spurious leading whitespace.
1543     if body_leaves:
1544         normalize_prefix(body_leaves[0], inside_brackets=True)
1545     # Build the new lines.
1546     for result, leaves in (
1547         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1548     ):
1549         for leaf in leaves:
1550             result.append(leaf, preformatted=True)
1551             for comment_after in line.comments_after(leaf):
1552                 result.append(comment_after, preformatted=True)
1553     bracket_split_succeeded_or_raise(head, body, tail)
1554     for result in (head, body, tail):
1555         if result:
1556             yield result
1557
1558
1559 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1560     """Split line into many lines, starting with the last matching bracket pair."""
1561     head = Line(depth=line.depth)
1562     body = Line(depth=line.depth + 1, inside_brackets=True)
1563     tail = Line(depth=line.depth)
1564     tail_leaves: List[Leaf] = []
1565     body_leaves: List[Leaf] = []
1566     head_leaves: List[Leaf] = []
1567     current_leaves = tail_leaves
1568     opening_bracket = None
1569     for leaf in reversed(line.leaves):
1570         if current_leaves is body_leaves:
1571             if leaf is opening_bracket:
1572                 current_leaves = head_leaves if body_leaves else tail_leaves
1573         current_leaves.append(leaf)
1574         if current_leaves is tail_leaves:
1575             if leaf.type in CLOSING_BRACKETS:
1576                 opening_bracket = leaf.opening_bracket
1577                 current_leaves = body_leaves
1578     tail_leaves.reverse()
1579     body_leaves.reverse()
1580     head_leaves.reverse()
1581     # Since body is a new indent level, remove spurious leading whitespace.
1582     if body_leaves:
1583         normalize_prefix(body_leaves[0], inside_brackets=True)
1584     # Build the new lines.
1585     for result, leaves in (
1586         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1587     ):
1588         for leaf in leaves:
1589             result.append(leaf, preformatted=True)
1590             for comment_after in line.comments_after(leaf):
1591                 result.append(comment_after, preformatted=True)
1592     bracket_split_succeeded_or_raise(head, body, tail)
1593     for result in (head, body, tail):
1594         if result:
1595             yield result
1596
1597
1598 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1599     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1600
1601     Do nothing otherwise.
1602
1603     A left- or right-hand split is based on a pair of brackets. Content before
1604     (and including) the opening bracket is left on one line, content inside the
1605     brackets is put on a separate line, and finally content starting with and
1606     following the closing bracket is put on a separate line.
1607
1608     Those are called `head`, `body`, and `tail`, respectively. If the split
1609     produced the same line (all content in `head`) or ended up with an empty `body`
1610     and the `tail` is just the closing bracket, then it's considered failed.
1611     """
1612     tail_len = len(str(tail).strip())
1613     if not body:
1614         if tail_len == 0:
1615             raise CannotSplit("Splitting brackets produced the same line")
1616
1617         elif tail_len < 3:
1618             raise CannotSplit(
1619                 f"Splitting brackets on an empty body to save "
1620                 f"{tail_len} characters is not worth it"
1621             )
1622
1623
1624 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1625     """Normalize prefix of the first leaf in every line returned by `split_func`.
1626
1627     This is a decorator over relevant split functions.
1628     """
1629
1630     @wraps(split_func)
1631     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1632         for l in split_func(line, py36):
1633             normalize_prefix(l.leaves[0], inside_brackets=True)
1634             yield l
1635
1636     return split_wrapper
1637
1638
1639 @dont_increase_indentation
1640 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1641     """Split according to delimiters of the highest priority.
1642
1643     If `py36` is True, the split will add trailing commas also in function
1644     signatures that contain `*` and `**`.
1645     """
1646     try:
1647         last_leaf = line.leaves[-1]
1648     except IndexError:
1649         raise CannotSplit("Line empty")
1650
1651     delimiters = line.bracket_tracker.delimiters
1652     try:
1653         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1654             exclude={id(last_leaf)}
1655         )
1656     except ValueError:
1657         raise CannotSplit("No delimiters found")
1658
1659     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1660     lowest_depth = sys.maxsize
1661     trailing_comma_safe = True
1662
1663     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1664         """Append `leaf` to current line or to new line if appending impossible."""
1665         nonlocal current_line
1666         try:
1667             current_line.append_safe(leaf, preformatted=True)
1668         except ValueError as ve:
1669             yield current_line
1670
1671             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1672             current_line.append(leaf)
1673
1674     for leaf in line.leaves:
1675         yield from append_to_line(leaf)
1676
1677         for comment_after in line.comments_after(leaf):
1678             yield from append_to_line(comment_after)
1679
1680         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1681         if (
1682             leaf.bracket_depth == lowest_depth
1683             and leaf.type == token.STAR
1684             or leaf.type == token.DOUBLESTAR
1685         ):
1686             trailing_comma_safe = trailing_comma_safe and py36
1687         leaf_priority = delimiters.get(id(leaf))
1688         if leaf_priority == delimiter_priority:
1689             yield current_line
1690
1691             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1692     if current_line:
1693         if (
1694             delimiter_priority == COMMA_PRIORITY
1695             and current_line.leaves[-1].type != token.COMMA
1696             and trailing_comma_safe
1697         ):
1698             current_line.append(Leaf(token.COMMA, ','))
1699         yield current_line
1700
1701
1702 @dont_increase_indentation
1703 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1704     """Split standalone comments from the rest of the line."""
1705     for leaf in line.leaves:
1706         if leaf.type == STANDALONE_COMMENT:
1707             if leaf.bracket_depth == 0:
1708                 break
1709
1710     else:
1711         raise CannotSplit("Line does not have any standalone comments")
1712
1713     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1714
1715     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1716         """Append `leaf` to current line or to new line if appending impossible."""
1717         nonlocal current_line
1718         try:
1719             current_line.append_safe(leaf, preformatted=True)
1720         except ValueError as ve:
1721             yield current_line
1722
1723             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1724             current_line.append(leaf)
1725
1726     for leaf in line.leaves:
1727         yield from append_to_line(leaf)
1728
1729         for comment_after in line.comments_after(leaf):
1730             yield from append_to_line(comment_after)
1731
1732     if current_line:
1733         yield current_line
1734
1735
1736 def is_import(leaf: Leaf) -> bool:
1737     """Return True if the given leaf starts an import statement."""
1738     p = leaf.parent
1739     t = leaf.type
1740     v = leaf.value
1741     return bool(
1742         t == token.NAME
1743         and (
1744             (v == 'import' and p and p.type == syms.import_name)
1745             or (v == 'from' and p and p.type == syms.import_from)
1746         )
1747     )
1748
1749
1750 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1751     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1752     else.
1753
1754     Note: don't use backslashes for formatting or you'll lose your voting rights.
1755     """
1756     if not inside_brackets:
1757         spl = leaf.prefix.split('#')
1758         if '\\' not in spl[0]:
1759             nl_count = spl[-1].count('\n')
1760             if len(spl) > 1:
1761                 nl_count -= 1
1762             leaf.prefix = '\n' * nl_count
1763             return
1764
1765     leaf.prefix = ''
1766
1767
1768 def is_python36(node: Node) -> bool:
1769     """Return True if the current file is using Python 3.6+ features.
1770
1771     Currently looking for:
1772     - f-strings; and
1773     - trailing commas after * or ** in function signatures.
1774     """
1775     for n in node.pre_order():
1776         if n.type == token.STRING:
1777             value_head = n.value[:2]  # type: ignore
1778             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1779                 return True
1780
1781         elif (
1782             n.type == syms.typedargslist
1783             and n.children
1784             and n.children[-1].type == token.COMMA
1785         ):
1786             for ch in n.children:
1787                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1788                     return True
1789
1790     return False
1791
1792
1793 PYTHON_EXTENSIONS = {'.py'}
1794 BLACKLISTED_DIRECTORIES = {
1795     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1796 }
1797
1798
1799 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1800     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
1801     and have one of the PYTHON_EXTENSIONS.
1802     """
1803     for child in path.iterdir():
1804         if child.is_dir():
1805             if child.name in BLACKLISTED_DIRECTORIES:
1806                 continue
1807
1808             yield from gen_python_files_in_dir(child)
1809
1810         elif child.suffix in PYTHON_EXTENSIONS:
1811             yield child
1812
1813
1814 @dataclass
1815 class Report:
1816     """Provides a reformatting counter. Can be rendered with `str(report)`."""
1817     check: bool = False
1818     change_count: int = 0
1819     same_count: int = 0
1820     failure_count: int = 0
1821
1822     def done(self, src: Path, changed: bool) -> None:
1823         """Increment the counter for successful reformatting. Write out a message."""
1824         if changed:
1825             reformatted = 'would reformat' if self.check else 'reformatted'
1826             out(f'{reformatted} {src}')
1827             self.change_count += 1
1828         else:
1829             out(f'{src} already well formatted, good job.', bold=False)
1830             self.same_count += 1
1831
1832     def failed(self, src: Path, message: str) -> None:
1833         """Increment the counter for failed reformatting. Write out a message."""
1834         err(f'error: cannot format {src}: {message}')
1835         self.failure_count += 1
1836
1837     @property
1838     def return_code(self) -> int:
1839         """Return the exit code that the app should use.
1840
1841         This considers the current state of changed files and failures:
1842         - if there were any failures, return 123;
1843         - if any files were changed and --check is being used, return 1;
1844         - otherwise return 0.
1845         """
1846         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1847         # 126 we have special returncodes reserved by the shell.
1848         if self.failure_count:
1849             return 123
1850
1851         elif self.change_count and self.check:
1852             return 1
1853
1854         return 0
1855
1856     def __str__(self) -> str:
1857         """Render a color report of the current state.
1858
1859         Use `click.unstyle` to remove colors.
1860         """
1861         if self.check:
1862             reformatted = "would be reformatted"
1863             unchanged = "would be left unchanged"
1864             failed = "would fail to reformat"
1865         else:
1866             reformatted = "reformatted"
1867             unchanged = "left unchanged"
1868             failed = "failed to reformat"
1869         report = []
1870         if self.change_count:
1871             s = 's' if self.change_count > 1 else ''
1872             report.append(
1873                 click.style(f'{self.change_count} file{s} {reformatted}', bold=True)
1874             )
1875         if self.same_count:
1876             s = 's' if self.same_count > 1 else ''
1877             report.append(f'{self.same_count} file{s} {unchanged}')
1878         if self.failure_count:
1879             s = 's' if self.failure_count > 1 else ''
1880             report.append(
1881                 click.style(f'{self.failure_count} file{s} {failed}', fg='red')
1882             )
1883         return ', '.join(report) + '.'
1884
1885
1886 def assert_equivalent(src: str, dst: str) -> None:
1887     """Raise AssertionError if `src` and `dst` aren't equivalent."""
1888
1889     import ast
1890     import traceback
1891
1892     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1893         """Simple visitor generating strings to compare ASTs by content."""
1894         yield f"{'  ' * depth}{node.__class__.__name__}("
1895
1896         for field in sorted(node._fields):
1897             try:
1898                 value = getattr(node, field)
1899             except AttributeError:
1900                 continue
1901
1902             yield f"{'  ' * (depth+1)}{field}="
1903
1904             if isinstance(value, list):
1905                 for item in value:
1906                     if isinstance(item, ast.AST):
1907                         yield from _v(item, depth + 2)
1908
1909             elif isinstance(value, ast.AST):
1910                 yield from _v(value, depth + 2)
1911
1912             else:
1913                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1914
1915         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1916
1917     try:
1918         src_ast = ast.parse(src)
1919     except Exception as exc:
1920         major, minor = sys.version_info[:2]
1921         raise AssertionError(
1922             f"cannot use --safe with this file; failed to parse source file "
1923             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
1924             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
1925         )
1926
1927     try:
1928         dst_ast = ast.parse(dst)
1929     except Exception as exc:
1930         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1931         raise AssertionError(
1932             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1933             f"Please report a bug on https://github.com/ambv/black/issues.  "
1934             f"This invalid output might be helpful: {log}"
1935         ) from None
1936
1937     src_ast_str = '\n'.join(_v(src_ast))
1938     dst_ast_str = '\n'.join(_v(dst_ast))
1939     if src_ast_str != dst_ast_str:
1940         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1941         raise AssertionError(
1942             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1943             f"the source.  "
1944             f"Please report a bug on https://github.com/ambv/black/issues.  "
1945             f"This diff might be helpful: {log}"
1946         ) from None
1947
1948
1949 def assert_stable(src: str, dst: str, line_length: int) -> None:
1950     """Raise AssertionError if `dst` reformats differently the second time."""
1951     newdst = format_str(dst, line_length=line_length)
1952     if dst != newdst:
1953         log = dump_to_file(
1954             diff(src, dst, 'source', 'first pass'),
1955             diff(dst, newdst, 'first pass', 'second pass'),
1956         )
1957         raise AssertionError(
1958             f"INTERNAL ERROR: Black produced different code on the second pass "
1959             f"of the formatter.  "
1960             f"Please report a bug on https://github.com/ambv/black/issues.  "
1961             f"This diff might be helpful: {log}"
1962         ) from None
1963
1964
1965 def dump_to_file(*output: str) -> str:
1966     """Dump `output` to a temporary file. Return path to the file."""
1967     import tempfile
1968
1969     with tempfile.NamedTemporaryFile(
1970         mode='w', prefix='blk_', suffix='.log', delete=False
1971     ) as f:
1972         for lines in output:
1973             f.write(lines)
1974             f.write('\n')
1975     return f.name
1976
1977
1978 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1979     """Return a unified diff string between strings `a` and `b`."""
1980     import difflib
1981
1982     a_lines = [line + '\n' for line in a.split('\n')]
1983     b_lines = [line + '\n' for line in b.split('\n')]
1984     return ''.join(
1985         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1986     )
1987
1988
1989 if __name__ == '__main__':
1990     main()