]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Any logo you like
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial
7 import keyword
8 import os
9 from pathlib import Path
10 import tokenize
11 import sys
12 from typing import (
13     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union
14 )
15
16 from attr import dataclass, Factory
17 import click
18
19 # lib2to3 fork
20 from blib2to3.pytree import Node, Leaf, type_repr
21 from blib2to3 import pygram, pytree
22 from blib2to3.pgen2 import driver, token
23 from blib2to3.pgen2.parse import ParseError
24
25 __version__ = "18.3a4"
26 DEFAULT_LINE_LENGTH = 88
27 # types
28 syms = pygram.python_symbols
29 FileContent = str
30 Encoding = str
31 Depth = int
32 NodeType = int
33 LeafID = int
34 Priority = int
35 LN = Union[Leaf, Node]
36 out = partial(click.secho, bold=True, err=True)
37 err = partial(click.secho, fg='red', err=True)
38
39
40 class NothingChanged(UserWarning):
41     """Raised by `format_file` when the reformatted code is the same as source."""
42
43
44 class CannotSplit(Exception):
45     """A readable split that fits the allotted line length is impossible.
46
47     Raised by `left_hand_split()`, `right_hand_split()`, and `delimiter_split()`.
48     """
49
50
51 class FormatError(Exception):
52     """Base fmt: on/off error.
53
54     It holds the number of bytes of the prefix consumed before the format
55     control comment appeared.
56     """
57
58     def __init__(self, consumed: int) -> None:
59         super().__init__(consumed)
60         self.consumed = consumed
61
62     def trim_prefix(self, leaf: Leaf) -> None:
63         leaf.prefix = leaf.prefix[self.consumed:]
64
65     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
66         """Returns a new Leaf from the consumed part of the prefix."""
67         unformatted_prefix = leaf.prefix[:self.consumed]
68         return Leaf(token.NEWLINE, unformatted_prefix)
69
70
71 class FormatOn(FormatError):
72     """Found a comment like `# fmt: on` in the file."""
73
74
75 class FormatOff(FormatError):
76     """Found a comment like `# fmt: off` in the file."""
77
78
79 @click.command()
80 @click.option(
81     '-l',
82     '--line-length',
83     type=int,
84     default=DEFAULT_LINE_LENGTH,
85     help='How many character per line to allow.',
86     show_default=True,
87 )
88 @click.option(
89     '--check',
90     is_flag=True,
91     help=(
92         "Don't write back the files, just return the status.  Return code 0 "
93         "means nothing would change.  Return code 1 means some files would be "
94         "reformatted.  Return code 123 means there was an internal error."
95     ),
96 )
97 @click.option(
98     '--fast/--safe',
99     is_flag=True,
100     help='If --fast given, skip temporary sanity checks. [default: --safe]',
101 )
102 @click.version_option(version=__version__)
103 @click.argument(
104     'src',
105     nargs=-1,
106     type=click.Path(
107         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
108     ),
109 )
110 @click.pass_context
111 def main(
112     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
113 ) -> None:
114     """The uncompromising code formatter."""
115     sources: List[Path] = []
116     for s in src:
117         p = Path(s)
118         if p.is_dir():
119             sources.extend(gen_python_files_in_dir(p))
120         elif p.is_file():
121             # if a file was explicitly given, we don't care about its extension
122             sources.append(p)
123         elif s == '-':
124             sources.append(Path('-'))
125         else:
126             err(f'invalid path: {s}')
127     if len(sources) == 0:
128         ctx.exit(0)
129     elif len(sources) == 1:
130         p = sources[0]
131         report = Report(check=check)
132         try:
133             if not p.is_file() and str(p) == '-':
134                 changed = format_stdin_to_stdout(
135                     line_length=line_length, fast=fast, write_back=not check
136                 )
137             else:
138                 changed = format_file_in_place(
139                     p, line_length=line_length, fast=fast, write_back=not check
140                 )
141             report.done(p, changed)
142         except Exception as exc:
143             report.failed(p, str(exc))
144         ctx.exit(report.return_code)
145     else:
146         loop = asyncio.get_event_loop()
147         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
148         return_code = 1
149         try:
150             return_code = loop.run_until_complete(
151                 schedule_formatting(
152                     sources, line_length, not check, fast, loop, executor
153                 )
154             )
155         finally:
156             loop.close()
157             ctx.exit(return_code)
158
159
160 async def schedule_formatting(
161     sources: List[Path],
162     line_length: int,
163     write_back: bool,
164     fast: bool,
165     loop: BaseEventLoop,
166     executor: Executor,
167 ) -> int:
168     tasks = {
169         src: loop.run_in_executor(
170             executor, format_file_in_place, src, line_length, fast, write_back
171         )
172         for src in sources
173     }
174     await asyncio.wait(tasks.values())
175     cancelled = []
176     report = Report()
177     for src, task in tasks.items():
178         if not task.done():
179             report.failed(src, 'timed out, cancelling')
180             task.cancel()
181             cancelled.append(task)
182         elif task.exception():
183             report.failed(src, str(task.exception()))
184         else:
185             report.done(src, task.result())
186     if cancelled:
187         await asyncio.wait(cancelled, timeout=2)
188     out('All done! ✨ 🍰 ✨')
189     click.echo(str(report))
190     return report.return_code
191
192
193 def format_file_in_place(
194     src: Path, line_length: int, fast: bool, write_back: bool = False
195 ) -> bool:
196     """Format the file and rewrite if changed. Return True if changed."""
197     with tokenize.open(src) as src_buffer:
198         src_contents = src_buffer.read()
199     try:
200         contents = format_file_contents(
201             src_contents, line_length=line_length, fast=fast
202         )
203     except NothingChanged:
204         return False
205
206     if write_back:
207         with open(src, "w", encoding=src_buffer.encoding) as f:
208             f.write(contents)
209     return True
210
211
212 def format_stdin_to_stdout(
213     line_length: int, fast: bool, write_back: bool = False
214 ) -> bool:
215     """Format file on stdin and pipe output to stdout. Return True if changed."""
216     contents = sys.stdin.read()
217     try:
218         contents = format_file_contents(contents, line_length=line_length, fast=fast)
219         return True
220
221     except NothingChanged:
222         return False
223
224     finally:
225         if write_back:
226             sys.stdout.write(contents)
227
228
229 def format_file_contents(
230     src_contents: str, line_length: int, fast: bool
231 ) -> FileContent:
232     """Reformats a file and returns its contents and encoding."""
233     if src_contents.strip() == '':
234         raise NothingChanged
235
236     dst_contents = format_str(src_contents, line_length=line_length)
237     if src_contents == dst_contents:
238         raise NothingChanged
239
240     if not fast:
241         assert_equivalent(src_contents, dst_contents)
242         assert_stable(src_contents, dst_contents, line_length=line_length)
243     return dst_contents
244
245
246 def format_str(src_contents: str, line_length: int) -> FileContent:
247     """Reformats a string and returns new contents."""
248     src_node = lib2to3_parse(src_contents)
249     dst_contents = ""
250     lines = LineGenerator()
251     elt = EmptyLineTracker()
252     py36 = is_python36(src_node)
253     empty_line = Line()
254     after = 0
255     for current_line in lines.visit(src_node):
256         for _ in range(after):
257             dst_contents += str(empty_line)
258         before, after = elt.maybe_empty_lines(current_line)
259         for _ in range(before):
260             dst_contents += str(empty_line)
261         for line in split_line(current_line, line_length=line_length, py36=py36):
262             dst_contents += str(line)
263     return dst_contents
264
265
266 GRAMMARS = [
267     pygram.python_grammar_no_print_statement_no_exec_statement,
268     pygram.python_grammar_no_print_statement,
269     pygram.python_grammar_no_exec_statement,
270     pygram.python_grammar,
271 ]
272
273
274 def lib2to3_parse(src_txt: str) -> Node:
275     """Given a string with source, return the lib2to3 Node."""
276     grammar = pygram.python_grammar_no_print_statement
277     if src_txt[-1] != '\n':
278         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
279         src_txt += nl
280     for grammar in GRAMMARS:
281         drv = driver.Driver(grammar, pytree.convert)
282         try:
283             result = drv.parse_string(src_txt, True)
284             break
285
286         except ParseError as pe:
287             lineno, column = pe.context[1]
288             lines = src_txt.splitlines()
289             try:
290                 faulty_line = lines[lineno - 1]
291             except IndexError:
292                 faulty_line = "<line number missing in source>"
293             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
294     else:
295         raise exc from None
296
297     if isinstance(result, Leaf):
298         result = Node(syms.file_input, [result])
299     return result
300
301
302 def lib2to3_unparse(node: Node) -> str:
303     """Given a lib2to3 node, return its string representation."""
304     code = str(node)
305     return code
306
307
308 T = TypeVar('T')
309
310
311 class Visitor(Generic[T]):
312     """Basic lib2to3 visitor that yields things on visiting."""
313
314     def visit(self, node: LN) -> Iterator[T]:
315         if node.type < 256:
316             name = token.tok_name[node.type]
317         else:
318             name = type_repr(node.type)
319         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
320
321     def visit_default(self, node: LN) -> Iterator[T]:
322         if isinstance(node, Node):
323             for child in node.children:
324                 yield from self.visit(child)
325
326
327 @dataclass
328 class DebugVisitor(Visitor[T]):
329     tree_depth: int = 0
330
331     def visit_default(self, node: LN) -> Iterator[T]:
332         indent = ' ' * (2 * self.tree_depth)
333         if isinstance(node, Node):
334             _type = type_repr(node.type)
335             out(f'{indent}{_type}', fg='yellow')
336             self.tree_depth += 1
337             for child in node.children:
338                 yield from self.visit(child)
339
340             self.tree_depth -= 1
341             out(f'{indent}/{_type}', fg='yellow', bold=False)
342         else:
343             _type = token.tok_name.get(node.type, str(node.type))
344             out(f'{indent}{_type}', fg='blue', nl=False)
345             if node.prefix:
346                 # We don't have to handle prefixes for `Node` objects since
347                 # that delegates to the first child anyway.
348                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
349             out(f' {node.value!r}', fg='blue', bold=False)
350
351     @classmethod
352     def show(cls, code: str) -> None:
353         """Pretty-prints a given string of `code`.
354
355         Convenience method for debugging.
356         """
357         v: DebugVisitor[None] = DebugVisitor()
358         list(v.visit(lib2to3_parse(code)))
359
360
361 KEYWORDS = set(keyword.kwlist)
362 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
363 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
364 STATEMENT = {
365     syms.if_stmt,
366     syms.while_stmt,
367     syms.for_stmt,
368     syms.try_stmt,
369     syms.except_clause,
370     syms.with_stmt,
371     syms.funcdef,
372     syms.classdef,
373 }
374 STANDALONE_COMMENT = 153
375 LOGIC_OPERATORS = {'and', 'or'}
376 COMPARATORS = {
377     token.LESS,
378     token.GREATER,
379     token.EQEQUAL,
380     token.NOTEQUAL,
381     token.LESSEQUAL,
382     token.GREATEREQUAL,
383 }
384 MATH_OPERATORS = {
385     token.PLUS,
386     token.MINUS,
387     token.STAR,
388     token.SLASH,
389     token.VBAR,
390     token.AMPER,
391     token.PERCENT,
392     token.CIRCUMFLEX,
393     token.TILDE,
394     token.LEFTSHIFT,
395     token.RIGHTSHIFT,
396     token.DOUBLESTAR,
397     token.DOUBLESLASH,
398 }
399 COMPREHENSION_PRIORITY = 20
400 COMMA_PRIORITY = 10
401 LOGIC_PRIORITY = 5
402 STRING_PRIORITY = 4
403 COMPARATOR_PRIORITY = 3
404 MATH_PRIORITY = 1
405
406
407 @dataclass
408 class BracketTracker:
409     depth: int = 0
410     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
411     delimiters: Dict[LeafID, Priority] = Factory(dict)
412     previous: Optional[Leaf] = None
413
414     def mark(self, leaf: Leaf) -> None:
415         if leaf.type == token.COMMENT:
416             return
417
418         if leaf.type in CLOSING_BRACKETS:
419             self.depth -= 1
420             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
421             leaf.opening_bracket = opening_bracket
422         leaf.bracket_depth = self.depth
423         if self.depth == 0:
424             delim = is_delimiter(leaf)
425             if delim:
426                 self.delimiters[id(leaf)] = delim
427             elif self.previous is not None:
428                 if leaf.type == token.STRING and self.previous.type == token.STRING:
429                     self.delimiters[id(self.previous)] = STRING_PRIORITY
430                 elif (
431                     leaf.type == token.NAME
432                     and leaf.value == 'for'
433                     and leaf.parent
434                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
435                 ):
436                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
437                 elif (
438                     leaf.type == token.NAME
439                     and leaf.value == 'if'
440                     and leaf.parent
441                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
442                 ):
443                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
444                 elif (
445                     leaf.type == token.NAME
446                     and leaf.value in LOGIC_OPERATORS
447                     and leaf.parent
448                 ):
449                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
450         if leaf.type in OPENING_BRACKETS:
451             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
452             self.depth += 1
453         self.previous = leaf
454
455     def any_open_brackets(self) -> bool:
456         """Returns True if there is an yet unmatched open bracket on the line."""
457         return bool(self.bracket_match)
458
459     def max_priority(self, exclude: Iterable[LeafID] = ()) -> int:
460         """Returns the highest priority of a delimiter found on the line.
461
462         Values are consistent with what `is_delimiter()` returns.
463         """
464         return max(v for k, v in self.delimiters.items() if k not in exclude)
465
466
467 @dataclass
468 class Line:
469     depth: int = 0
470     leaves: List[Leaf] = Factory(list)
471     comments: Dict[LeafID, Leaf] = Factory(dict)
472     bracket_tracker: BracketTracker = Factory(BracketTracker)
473     inside_brackets: bool = False
474     has_for: bool = False
475     _for_loop_variable: bool = False
476
477     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
478         has_value = leaf.value.strip()
479         if not has_value:
480             return
481
482         if self.leaves and not preformatted:
483             # Note: at this point leaf.prefix should be empty except for
484             # imports, for which we only preserve newlines.
485             leaf.prefix += whitespace(leaf)
486         if self.inside_brackets or not preformatted:
487             self.maybe_decrement_after_for_loop_variable(leaf)
488             self.bracket_tracker.mark(leaf)
489             self.maybe_remove_trailing_comma(leaf)
490             self.maybe_increment_for_loop_variable(leaf)
491             if self.maybe_adapt_standalone_comment(leaf):
492                 return
493
494         if not self.append_comment(leaf):
495             self.leaves.append(leaf)
496
497     @property
498     def is_comment(self) -> bool:
499         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
500
501     @property
502     def is_decorator(self) -> bool:
503         return bool(self) and self.leaves[0].type == token.AT
504
505     @property
506     def is_import(self) -> bool:
507         return bool(self) and is_import(self.leaves[0])
508
509     @property
510     def is_class(self) -> bool:
511         return (
512             bool(self)
513             and self.leaves[0].type == token.NAME
514             and self.leaves[0].value == 'class'
515         )
516
517     @property
518     def is_def(self) -> bool:
519         """Also returns True for async defs."""
520         try:
521             first_leaf = self.leaves[0]
522         except IndexError:
523             return False
524
525         try:
526             second_leaf: Optional[Leaf] = self.leaves[1]
527         except IndexError:
528             second_leaf = None
529         return (
530             (first_leaf.type == token.NAME and first_leaf.value == 'def')
531             or (
532                 first_leaf.type == token.ASYNC
533                 and second_leaf is not None
534                 and second_leaf.type == token.NAME
535                 and second_leaf.value == 'def'
536             )
537         )
538
539     @property
540     def is_flow_control(self) -> bool:
541         return (
542             bool(self)
543             and self.leaves[0].type == token.NAME
544             and self.leaves[0].value in FLOW_CONTROL
545         )
546
547     @property
548     def is_yield(self) -> bool:
549         return (
550             bool(self)
551             and self.leaves[0].type == token.NAME
552             and self.leaves[0].value == 'yield'
553         )
554
555     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
556         if not (
557             self.leaves
558             and self.leaves[-1].type == token.COMMA
559             and closing.type in CLOSING_BRACKETS
560         ):
561             return False
562
563         if closing.type == token.RBRACE:
564             self.leaves.pop()
565             return True
566
567         if closing.type == token.RSQB:
568             comma = self.leaves[-1]
569             if comma.parent and comma.parent.type == syms.listmaker:
570                 self.leaves.pop()
571                 return True
572
573         # For parens let's check if it's safe to remove the comma.  If the
574         # trailing one is the only one, we might mistakenly change a tuple
575         # into a different type by removing the comma.
576         depth = closing.bracket_depth + 1
577         commas = 0
578         opening = closing.opening_bracket
579         for _opening_index, leaf in enumerate(self.leaves):
580             if leaf is opening:
581                 break
582
583         else:
584             return False
585
586         for leaf in self.leaves[_opening_index + 1:]:
587             if leaf is closing:
588                 break
589
590             bracket_depth = leaf.bracket_depth
591             if bracket_depth == depth and leaf.type == token.COMMA:
592                 commas += 1
593                 if leaf.parent and leaf.parent.type == syms.arglist:
594                     commas += 1
595                     break
596
597         if commas > 1:
598             self.leaves.pop()
599             return True
600
601         return False
602
603     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
604         """In a for loop, or comprehension, the variables are often unpacks.
605
606         To avoid splitting on the comma in this situation, we will increase
607         the depth of tokens between `for` and `in`.
608         """
609         if leaf.type == token.NAME and leaf.value == 'for':
610             self.has_for = True
611             self.bracket_tracker.depth += 1
612             self._for_loop_variable = True
613             return True
614
615         return False
616
617     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
618         # See `maybe_increment_for_loop_variable` above for explanation.
619         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
620             self.bracket_tracker.depth -= 1
621             self._for_loop_variable = False
622             return True
623
624         return False
625
626     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
627         """Hack a standalone comment to act as a trailing comment for line splitting.
628
629         If this line has brackets and a standalone `comment`, we need to adapt
630         it to be able to still reformat the line.
631
632         This is not perfect, the line to which the standalone comment gets
633         appended will appear "too long" when splitting.
634         """
635         if not (
636             comment.type == STANDALONE_COMMENT
637             and self.bracket_tracker.any_open_brackets()
638         ):
639             return False
640
641         comment.type = token.COMMENT
642         comment.prefix = '\n' + '    ' * (self.depth + 1)
643         return self.append_comment(comment)
644
645     def append_comment(self, comment: Leaf) -> bool:
646         if comment.type != token.COMMENT:
647             return False
648
649         try:
650             after = id(self.last_non_delimiter())
651         except LookupError:
652             comment.type = STANDALONE_COMMENT
653             comment.prefix = ''
654             return False
655
656         else:
657             if after in self.comments:
658                 self.comments[after].value += str(comment)
659             else:
660                 self.comments[after] = comment
661             return True
662
663     def last_non_delimiter(self) -> Leaf:
664         for i in range(len(self.leaves)):
665             last = self.leaves[-i - 1]
666             if not is_delimiter(last):
667                 return last
668
669         raise LookupError("No non-delimiters found")
670
671     def __str__(self) -> str:
672         if not self:
673             return '\n'
674
675         indent = '    ' * self.depth
676         leaves = iter(self.leaves)
677         first = next(leaves)
678         res = f'{first.prefix}{indent}{first.value}'
679         for leaf in leaves:
680             res += str(leaf)
681         for comment in self.comments.values():
682             res += str(comment)
683         return res + '\n'
684
685     def __bool__(self) -> bool:
686         return bool(self.leaves or self.comments)
687
688
689 class UnformattedLines(Line):
690
691     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
692         try:
693             list(generate_comments(leaf))
694         except FormatOn as f_on:
695             self.leaves.append(f_on.leaf_from_consumed(leaf))
696             raise
697
698         self.leaves.append(leaf)
699         if leaf.type == token.INDENT:
700             self.depth += 1
701         elif leaf.type == token.DEDENT:
702             self.depth -= 1
703
704     def append_comment(self, comment: Leaf) -> bool:
705         raise NotImplementedError("Unformatted lines don't store comments separately.")
706
707     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
708         return False
709
710     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
711         return False
712
713     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
714         return False
715
716     def __str__(self) -> str:
717         if not self:
718             return '\n'
719
720         res = ''
721         for leaf in self.leaves:
722             res += str(leaf)
723         return res
724
725
726 @dataclass
727 class EmptyLineTracker:
728     """Provides a stateful method that returns the number of potential extra
729     empty lines needed before and after the currently processed line.
730
731     Note: this tracker works on lines that haven't been split yet.  It assumes
732     the prefix of the first leaf consists of optional newlines.  Those newlines
733     are consumed by `maybe_empty_lines()` and included in the computation.
734     """
735     previous_line: Optional[Line] = None
736     previous_after: int = 0
737     previous_defs: List[int] = Factory(list)
738
739     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
740         """Returns the number of extra empty lines before and after the `current_line`.
741
742         This is for separating `def`, `async def` and `class` with extra empty lines
743         (two on module-level), as well as providing an extra empty line after flow
744         control keywords to make them more prominent.
745         """
746         if isinstance(current_line, UnformattedLines):
747             return 0, 0
748
749         before, after = self._maybe_empty_lines(current_line)
750         before -= self.previous_after
751         self.previous_after = after
752         self.previous_line = current_line
753         return before, after
754
755     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
756         max_allowed = 1
757         if current_line.is_comment and current_line.depth == 0:
758             max_allowed = 2
759         if current_line.leaves:
760             # Consume the first leaf's extra newlines.
761             first_leaf = current_line.leaves[0]
762             before = first_leaf.prefix.count('\n')
763             before = min(before, max(before, max_allowed))
764             first_leaf.prefix = ''
765         else:
766             before = 0
767         depth = current_line.depth
768         while self.previous_defs and self.previous_defs[-1] >= depth:
769             self.previous_defs.pop()
770             before = 1 if depth else 2
771         is_decorator = current_line.is_decorator
772         if is_decorator or current_line.is_def or current_line.is_class:
773             if not is_decorator:
774                 self.previous_defs.append(depth)
775             if self.previous_line is None:
776                 # Don't insert empty lines before the first line in the file.
777                 return 0, 0
778
779             if self.previous_line and self.previous_line.is_decorator:
780                 # Don't insert empty lines between decorators.
781                 return 0, 0
782
783             newlines = 2
784             if current_line.depth:
785                 newlines -= 1
786             return newlines, 0
787
788         if current_line.is_flow_control:
789             return before, 1
790
791         if (
792             self.previous_line
793             and self.previous_line.is_import
794             and not current_line.is_import
795             and depth == self.previous_line.depth
796         ):
797             return (before or 1), 0
798
799         if (
800             self.previous_line
801             and self.previous_line.is_yield
802             and (not current_line.is_yield or depth != self.previous_line.depth)
803         ):
804             return (before or 1), 0
805
806         return before, 0
807
808
809 @dataclass
810 class LineGenerator(Visitor[Line]):
811     """Generates reformatted Line objects.  Empty lines are not emitted.
812
813     Note: destroys the tree it's visiting by mutating prefixes of its leaves
814     in ways that will no longer stringify to valid Python code on the tree.
815     """
816     current_line: Line = Factory(Line)
817
818     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
819         """Generate a line.
820
821         If the line is empty, only emit if it makes sense.
822         If the line is too long, split it first and then generate.
823
824         If any lines were generated, set up a new current_line.
825         """
826         if not self.current_line:
827             if self.current_line.__class__ == type:
828                 self.current_line.depth += indent
829             else:
830                 self.current_line = type(depth=self.current_line.depth + indent)
831             return  # Line is empty, don't emit. Creating a new one unnecessary.
832
833         complete_line = self.current_line
834         self.current_line = type(depth=complete_line.depth + indent)
835         yield complete_line
836
837     def visit(self, node: LN) -> Iterator[Line]:
838         """High-level entry point to the visitor."""
839         if isinstance(self.current_line, UnformattedLines):
840             # File contained `# fmt: off`
841             yield from self.visit_unformatted(node)
842
843         else:
844             yield from super().visit(node)
845
846     def visit_default(self, node: LN) -> Iterator[Line]:
847         if isinstance(node, Leaf):
848             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
849             try:
850                 for comment in generate_comments(node):
851                     if any_open_brackets:
852                         # any comment within brackets is subject to splitting
853                         self.current_line.append(comment)
854                     elif comment.type == token.COMMENT:
855                         # regular trailing comment
856                         self.current_line.append(comment)
857                         yield from self.line()
858
859                     else:
860                         # regular standalone comment
861                         yield from self.line()
862
863                         self.current_line.append(comment)
864                         yield from self.line()
865
866             except FormatOff as f_off:
867                 f_off.trim_prefix(node)
868                 yield from self.line(type=UnformattedLines)
869                 yield from self.visit(node)
870
871             except FormatOn as f_on:
872                 # This only happens here if somebody says "fmt: on" multiple
873                 # times in a row.
874                 f_on.trim_prefix(node)
875                 yield from self.visit_default(node)
876
877             else:
878                 normalize_prefix(node, inside_brackets=any_open_brackets)
879                 if node.type not in WHITESPACE:
880                     self.current_line.append(node)
881         yield from super().visit_default(node)
882
883     def visit_INDENT(self, node: Node) -> Iterator[Line]:
884         yield from self.line(+1)
885         yield from self.visit_default(node)
886
887     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
888         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
889         yield from self.line(-1)
890
891     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
892         """Visit a statement.
893
894         The relevant Python language keywords for this statement are NAME leaves
895         within it.
896         """
897         for child in node.children:
898             if child.type == token.NAME and child.value in keywords:  # type: ignore
899                 yield from self.line()
900
901             yield from self.visit(child)
902
903     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
904         """A statement without nested statements."""
905         is_suite_like = node.parent and node.parent.type in STATEMENT
906         if is_suite_like:
907             yield from self.line(+1)
908             yield from self.visit_default(node)
909             yield from self.line(-1)
910
911         else:
912             yield from self.line()
913             yield from self.visit_default(node)
914
915     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
916         yield from self.line()
917
918         children = iter(node.children)
919         for child in children:
920             yield from self.visit(child)
921
922             if child.type == token.ASYNC:
923                 break
924
925         internal_stmt = next(children)
926         for child in internal_stmt.children:
927             yield from self.visit(child)
928
929     def visit_decorators(self, node: Node) -> Iterator[Line]:
930         for child in node.children:
931             yield from self.line()
932             yield from self.visit(child)
933
934     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
935         yield from self.line()
936
937     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
938         yield from self.visit_default(leaf)
939         yield from self.line()
940
941     def visit_unformatted(self, node: LN) -> Iterator[Line]:
942         if isinstance(node, Node):
943             for child in node.children:
944                 yield from self.visit(child)
945
946         else:
947             try:
948                 self.current_line.append(node)
949             except FormatOn as f_on:
950                 f_on.trim_prefix(node)
951                 yield from self.line()
952                 yield from self.visit(node)
953
954     def __attrs_post_init__(self) -> None:
955         """You are in a twisty little maze of passages."""
956         v = self.visit_stmt
957         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
958         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
959         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
960         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
961         self.visit_except_clause = partial(v, keywords={'except'})
962         self.visit_funcdef = partial(v, keywords={'def'})
963         self.visit_with_stmt = partial(v, keywords={'with'})
964         self.visit_classdef = partial(v, keywords={'class'})
965         self.visit_async_funcdef = self.visit_async_stmt
966         self.visit_decorated = self.visit_decorators
967
968
969 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
970 OPENING_BRACKETS = set(BRACKET.keys())
971 CLOSING_BRACKETS = set(BRACKET.values())
972 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
973 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
974
975
976 def whitespace(leaf: Leaf) -> str:  # noqa C901
977     """Return whitespace prefix if needed for the given `leaf`."""
978     NO = ''
979     SPACE = ' '
980     DOUBLESPACE = '  '
981     t = leaf.type
982     p = leaf.parent
983     v = leaf.value
984     if t in ALWAYS_NO_SPACE:
985         return NO
986
987     if t == token.COMMENT:
988         return DOUBLESPACE
989
990     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
991     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
992         return NO
993
994     prev = leaf.prev_sibling
995     if not prev:
996         prevp = preceding_leaf(p)
997         if not prevp or prevp.type in OPENING_BRACKETS:
998             return NO
999
1000         if t == token.COLON:
1001             return SPACE if prevp.type == token.COMMA else NO
1002
1003         if prevp.type == token.EQUAL:
1004             if prevp.parent:
1005                 if prevp.parent.type in {
1006                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1007                 }:
1008                     return NO
1009
1010                 elif prevp.parent.type == syms.typedargslist:
1011                     # A bit hacky: if the equal sign has whitespace, it means we
1012                     # previously found it's a typed argument.  So, we're using
1013                     # that, too.
1014                     return prevp.prefix
1015
1016         elif prevp.type == token.DOUBLESTAR:
1017             if prevp.parent and prevp.parent.type in {
1018                 syms.arglist,
1019                 syms.argument,
1020                 syms.dictsetmaker,
1021                 syms.parameters,
1022                 syms.typedargslist,
1023                 syms.varargslist,
1024             }:
1025                 return NO
1026
1027         elif prevp.type == token.COLON:
1028             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1029                 return NO
1030
1031         elif (
1032             prevp.parent
1033             and prevp.parent.type in {syms.factor, syms.star_expr}
1034             and prevp.type in MATH_OPERATORS
1035         ):
1036             return NO
1037
1038         elif (
1039             prevp.type == token.RIGHTSHIFT
1040             and prevp.parent
1041             and prevp.parent.type == syms.shift_expr
1042             and prevp.prev_sibling
1043             and prevp.prev_sibling.type == token.NAME
1044             and prevp.prev_sibling.value == 'print'  # type: ignore
1045         ):
1046             # Python 2 print chevron
1047             return NO
1048
1049     elif prev.type in OPENING_BRACKETS:
1050         return NO
1051
1052     if p.type in {syms.parameters, syms.arglist}:
1053         # untyped function signatures or calls
1054         if t == token.RPAR:
1055             return NO
1056
1057         if not prev or prev.type != token.COMMA:
1058             return NO
1059
1060     elif p.type == syms.varargslist:
1061         # lambdas
1062         if t == token.RPAR:
1063             return NO
1064
1065         if prev and prev.type != token.COMMA:
1066             return NO
1067
1068     elif p.type == syms.typedargslist:
1069         # typed function signatures
1070         if not prev:
1071             return NO
1072
1073         if t == token.EQUAL:
1074             if prev.type != syms.tname:
1075                 return NO
1076
1077         elif prev.type == token.EQUAL:
1078             # A bit hacky: if the equal sign has whitespace, it means we
1079             # previously found it's a typed argument.  So, we're using that, too.
1080             return prev.prefix
1081
1082         elif prev.type != token.COMMA:
1083             return NO
1084
1085     elif p.type == syms.tname:
1086         # type names
1087         if not prev:
1088             prevp = preceding_leaf(p)
1089             if not prevp or prevp.type != token.COMMA:
1090                 return NO
1091
1092     elif p.type == syms.trailer:
1093         # attributes and calls
1094         if t == token.LPAR or t == token.RPAR:
1095             return NO
1096
1097         if not prev:
1098             if t == token.DOT:
1099                 prevp = preceding_leaf(p)
1100                 if not prevp or prevp.type != token.NUMBER:
1101                     return NO
1102
1103             elif t == token.LSQB:
1104                 return NO
1105
1106         elif prev.type != token.COMMA:
1107             return NO
1108
1109     elif p.type == syms.argument:
1110         # single argument
1111         if t == token.EQUAL:
1112             return NO
1113
1114         if not prev:
1115             prevp = preceding_leaf(p)
1116             if not prevp or prevp.type == token.LPAR:
1117                 return NO
1118
1119         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1120             return NO
1121
1122     elif p.type == syms.decorator:
1123         # decorators
1124         return NO
1125
1126     elif p.type == syms.dotted_name:
1127         if prev:
1128             return NO
1129
1130         prevp = preceding_leaf(p)
1131         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1132             return NO
1133
1134     elif p.type == syms.classdef:
1135         if t == token.LPAR:
1136             return NO
1137
1138         if prev and prev.type == token.LPAR:
1139             return NO
1140
1141     elif p.type == syms.subscript:
1142         # indexing
1143         if not prev:
1144             assert p.parent is not None, "subscripts are always parented"
1145             if p.parent.type == syms.subscriptlist:
1146                 return SPACE
1147
1148             return NO
1149
1150         else:
1151             return NO
1152
1153     elif p.type == syms.atom:
1154         if prev and t == token.DOT:
1155             # dots, but not the first one.
1156             return NO
1157
1158     elif (
1159         p.type == syms.listmaker
1160         or p.type == syms.testlist_gexp
1161         or p.type == syms.subscriptlist
1162     ):
1163         # list interior, including unpacking
1164         if not prev:
1165             return NO
1166
1167     elif p.type == syms.dictsetmaker:
1168         # dict and set interior, including unpacking
1169         if not prev:
1170             return NO
1171
1172         if prev.type == token.DOUBLESTAR:
1173             return NO
1174
1175     elif p.type in {syms.factor, syms.star_expr}:
1176         # unary ops
1177         if not prev:
1178             prevp = preceding_leaf(p)
1179             if not prevp or prevp.type in OPENING_BRACKETS:
1180                 return NO
1181
1182             prevp_parent = prevp.parent
1183             assert prevp_parent is not None
1184             if prevp.type == token.COLON and prevp_parent.type in {
1185                 syms.subscript, syms.sliceop
1186             }:
1187                 return NO
1188
1189             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1190                 return NO
1191
1192         elif t == token.NAME or t == token.NUMBER:
1193             return NO
1194
1195     elif p.type == syms.import_from:
1196         if t == token.DOT:
1197             if prev and prev.type == token.DOT:
1198                 return NO
1199
1200         elif t == token.NAME:
1201             if v == 'import':
1202                 return SPACE
1203
1204             if prev and prev.type == token.DOT:
1205                 return NO
1206
1207     elif p.type == syms.sliceop:
1208         return NO
1209
1210     return SPACE
1211
1212
1213 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1214     """Returns the first leaf that precedes `node`, if any."""
1215     while node:
1216         res = node.prev_sibling
1217         if res:
1218             if isinstance(res, Leaf):
1219                 return res
1220
1221             try:
1222                 return list(res.leaves())[-1]
1223
1224             except IndexError:
1225                 return None
1226
1227         node = node.parent
1228     return None
1229
1230
1231 def is_delimiter(leaf: Leaf) -> int:
1232     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1233
1234     Higher numbers are higher priority.
1235     """
1236     if leaf.type == token.COMMA:
1237         return COMMA_PRIORITY
1238
1239     if leaf.type in COMPARATORS:
1240         return COMPARATOR_PRIORITY
1241
1242     if (
1243         leaf.type in MATH_OPERATORS
1244         and leaf.parent
1245         and leaf.parent.type not in {syms.factor, syms.star_expr}
1246     ):
1247         return MATH_PRIORITY
1248
1249     return 0
1250
1251
1252 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1253     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1254
1255     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1256     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1257     move because it does away with modifying the grammar to include all the
1258     possible places in which comments can be placed.
1259
1260     The sad consequence for us though is that comments don't "belong" anywhere.
1261     This is why this function generates simple parentless Leaf objects for
1262     comments.  We simply don't know what the correct parent should be.
1263
1264     No matter though, we can live without this.  We really only need to
1265     differentiate between inline and standalone comments.  The latter don't
1266     share the line with any code.
1267
1268     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1269     are emitted with a fake STANDALONE_COMMENT token identifier.
1270     """
1271     p = leaf.prefix
1272     if not p:
1273         return
1274
1275     if '#' not in p:
1276         return
1277
1278     consumed = 0
1279     nlines = 0
1280     for index, line in enumerate(p.split('\n')):
1281         consumed += len(line) + 1  # adding the length of the split '\n'
1282         line = line.lstrip()
1283         if not line:
1284             nlines += 1
1285         if not line.startswith('#'):
1286             continue
1287
1288         if index == 0 and leaf.type != token.ENDMARKER:
1289             comment_type = token.COMMENT  # simple trailing comment
1290         else:
1291             comment_type = STANDALONE_COMMENT
1292         comment = make_comment(line)
1293         yield Leaf(comment_type, comment, prefix='\n' * nlines)
1294
1295         if comment in {'# fmt: on', '# yapf: enable'}:
1296             raise FormatOn(consumed)
1297
1298         if comment in {'# fmt: off', '# yapf: disable'}:
1299             raise FormatOff(consumed)
1300
1301         nlines = 0
1302
1303
1304 def make_comment(content: str) -> str:
1305     content = content.rstrip()
1306     if not content:
1307         return '#'
1308
1309     if content[0] == '#':
1310         content = content[1:]
1311     if content and content[0] not in ' !:#':
1312         content = ' ' + content
1313     return '#' + content
1314
1315
1316 def split_line(
1317     line: Line, line_length: int, inner: bool = False, py36: bool = False
1318 ) -> Iterator[Line]:
1319     """Splits a `line` into potentially many lines.
1320
1321     They should fit in the allotted `line_length` but might not be able to.
1322     `inner` signifies that there were a pair of brackets somewhere around the
1323     current `line`, possibly transitively. This means we can fallback to splitting
1324     by delimiters if the LHS/RHS don't yield any results.
1325
1326     If `py36` is True, splitting may generate syntax that is only compatible
1327     with Python 3.6 and later.
1328     """
1329     if isinstance(line, UnformattedLines):
1330         yield line
1331         return
1332
1333     line_str = str(line).strip('\n')
1334     if len(line_str) <= line_length and '\n' not in line_str:
1335         yield line
1336         return
1337
1338     if line.is_def:
1339         split_funcs = [left_hand_split]
1340     elif line.inside_brackets:
1341         split_funcs = [delimiter_split]
1342         if '\n' not in line_str:
1343             # Only attempt RHS if we don't have multiline strings or comments
1344             # on this line.
1345             split_funcs.append(right_hand_split)
1346     else:
1347         split_funcs = [right_hand_split]
1348     for split_func in split_funcs:
1349         # We are accumulating lines in `result` because we might want to abort
1350         # mission and return the original line in the end, or attempt a different
1351         # split altogether.
1352         result: List[Line] = []
1353         try:
1354             for l in split_func(line, py36=py36):
1355                 if str(l).strip('\n') == line_str:
1356                     raise CannotSplit("Split function returned an unchanged result")
1357
1358                 result.extend(
1359                     split_line(l, line_length=line_length, inner=True, py36=py36)
1360                 )
1361         except CannotSplit as cs:
1362             continue
1363
1364         else:
1365             yield from result
1366             break
1367
1368     else:
1369         yield line
1370
1371
1372 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1373     """Split line into many lines, starting with the first matching bracket pair.
1374
1375     Note: this usually looks weird, only use this for function definitions.
1376     Prefer RHS otherwise.
1377     """
1378     head = Line(depth=line.depth)
1379     body = Line(depth=line.depth + 1, inside_brackets=True)
1380     tail = Line(depth=line.depth)
1381     tail_leaves: List[Leaf] = []
1382     body_leaves: List[Leaf] = []
1383     head_leaves: List[Leaf] = []
1384     current_leaves = head_leaves
1385     matching_bracket = None
1386     for leaf in line.leaves:
1387         if (
1388             current_leaves is body_leaves
1389             and leaf.type in CLOSING_BRACKETS
1390             and leaf.opening_bracket is matching_bracket
1391         ):
1392             current_leaves = tail_leaves if body_leaves else head_leaves
1393         current_leaves.append(leaf)
1394         if current_leaves is head_leaves:
1395             if leaf.type in OPENING_BRACKETS:
1396                 matching_bracket = leaf
1397                 current_leaves = body_leaves
1398     # Since body is a new indent level, remove spurious leading whitespace.
1399     if body_leaves:
1400         normalize_prefix(body_leaves[0], inside_brackets=True)
1401     # Build the new lines.
1402     for result, leaves in (
1403         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1404     ):
1405         for leaf in leaves:
1406             result.append(leaf, preformatted=True)
1407             comment_after = line.comments.get(id(leaf))
1408             if comment_after:
1409                 result.append(comment_after, preformatted=True)
1410     split_succeeded_or_raise(head, body, tail)
1411     for result in (head, body, tail):
1412         if result:
1413             yield result
1414
1415
1416 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1417     """Split line into many lines, starting with the last matching bracket pair."""
1418     head = Line(depth=line.depth)
1419     body = Line(depth=line.depth + 1, inside_brackets=True)
1420     tail = Line(depth=line.depth)
1421     tail_leaves: List[Leaf] = []
1422     body_leaves: List[Leaf] = []
1423     head_leaves: List[Leaf] = []
1424     current_leaves = tail_leaves
1425     opening_bracket = None
1426     for leaf in reversed(line.leaves):
1427         if current_leaves is body_leaves:
1428             if leaf is opening_bracket:
1429                 current_leaves = head_leaves if body_leaves else tail_leaves
1430         current_leaves.append(leaf)
1431         if current_leaves is tail_leaves:
1432             if leaf.type in CLOSING_BRACKETS:
1433                 opening_bracket = leaf.opening_bracket
1434                 current_leaves = body_leaves
1435     tail_leaves.reverse()
1436     body_leaves.reverse()
1437     head_leaves.reverse()
1438     # Since body is a new indent level, remove spurious leading whitespace.
1439     if body_leaves:
1440         normalize_prefix(body_leaves[0], inside_brackets=True)
1441     # Build the new lines.
1442     for result, leaves in (
1443         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1444     ):
1445         for leaf in leaves:
1446             result.append(leaf, preformatted=True)
1447             comment_after = line.comments.get(id(leaf))
1448             if comment_after:
1449                 result.append(comment_after, preformatted=True)
1450     split_succeeded_or_raise(head, body, tail)
1451     for result in (head, body, tail):
1452         if result:
1453             yield result
1454
1455
1456 def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1457     tail_len = len(str(tail).strip())
1458     if not body:
1459         if tail_len == 0:
1460             raise CannotSplit("Splitting brackets produced the same line")
1461
1462         elif tail_len < 3:
1463             raise CannotSplit(
1464                 f"Splitting brackets on an empty body to save "
1465                 f"{tail_len} characters is not worth it"
1466             )
1467
1468
1469 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1470     """Split according to delimiters of the highest priority.
1471
1472     This kind of split doesn't increase indentation.
1473     If `py36` is True, the split will add trailing commas also in function
1474     signatures that contain * and **.
1475     """
1476     try:
1477         last_leaf = line.leaves[-1]
1478     except IndexError:
1479         raise CannotSplit("Line empty")
1480
1481     delimiters = line.bracket_tracker.delimiters
1482     try:
1483         delimiter_priority = line.bracket_tracker.max_priority(exclude={id(last_leaf)})
1484     except ValueError:
1485         raise CannotSplit("No delimiters found")
1486
1487     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1488     lowest_depth = sys.maxsize
1489     trailing_comma_safe = True
1490     for leaf in line.leaves:
1491         current_line.append(leaf, preformatted=True)
1492         comment_after = line.comments.get(id(leaf))
1493         if comment_after:
1494             current_line.append(comment_after, preformatted=True)
1495         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1496         if (
1497             leaf.bracket_depth == lowest_depth
1498             and leaf.type == token.STAR
1499             or leaf.type == token.DOUBLESTAR
1500         ):
1501             trailing_comma_safe = trailing_comma_safe and py36
1502         leaf_priority = delimiters.get(id(leaf))
1503         if leaf_priority == delimiter_priority:
1504             normalize_prefix(current_line.leaves[0], inside_brackets=True)
1505             yield current_line
1506
1507             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1508     if current_line:
1509         if (
1510             delimiter_priority == COMMA_PRIORITY
1511             and current_line.leaves[-1].type != token.COMMA
1512             and trailing_comma_safe
1513         ):
1514             current_line.append(Leaf(token.COMMA, ','))
1515         normalize_prefix(current_line.leaves[0], inside_brackets=True)
1516         yield current_line
1517
1518
1519 def is_import(leaf: Leaf) -> bool:
1520     """Returns True if the given leaf starts an import statement."""
1521     p = leaf.parent
1522     t = leaf.type
1523     v = leaf.value
1524     return bool(
1525         t == token.NAME
1526         and (
1527             (v == 'import' and p and p.type == syms.import_name)
1528             or (v == 'from' and p and p.type == syms.import_from)
1529         )
1530     )
1531
1532
1533 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1534     """Leave existing extra newlines if not `inside_brackets`.
1535
1536     Remove everything else.  Note: don't use backslashes for formatting or
1537     you'll lose your voting rights.
1538     """
1539     if not inside_brackets:
1540         spl = leaf.prefix.split('#')
1541         if '\\' not in spl[0]:
1542             nl_count = spl[-1].count('\n')
1543             if len(spl) > 1:
1544                 nl_count -= 1
1545             leaf.prefix = '\n' * nl_count
1546             return
1547
1548     leaf.prefix = ''
1549
1550
1551 def is_python36(node: Node) -> bool:
1552     """Returns True if the current file is using Python 3.6+ features.
1553
1554     Currently looking for:
1555     - f-strings; and
1556     - trailing commas after * or ** in function signatures.
1557     """
1558     for n in node.pre_order():
1559         if n.type == token.STRING:
1560             value_head = n.value[:2]  # type: ignore
1561             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1562                 return True
1563
1564         elif (
1565             n.type == syms.typedargslist
1566             and n.children
1567             and n.children[-1].type == token.COMMA
1568         ):
1569             for ch in n.children:
1570                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1571                     return True
1572
1573     return False
1574
1575
1576 PYTHON_EXTENSIONS = {'.py'}
1577 BLACKLISTED_DIRECTORIES = {
1578     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1579 }
1580
1581
1582 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1583     for child in path.iterdir():
1584         if child.is_dir():
1585             if child.name in BLACKLISTED_DIRECTORIES:
1586                 continue
1587
1588             yield from gen_python_files_in_dir(child)
1589
1590         elif child.suffix in PYTHON_EXTENSIONS:
1591             yield child
1592
1593
1594 @dataclass
1595 class Report:
1596     """Provides a reformatting counter."""
1597     check: bool = False
1598     change_count: int = 0
1599     same_count: int = 0
1600     failure_count: int = 0
1601
1602     def done(self, src: Path, changed: bool) -> None:
1603         """Increment the counter for successful reformatting. Write out a message."""
1604         if changed:
1605             reformatted = 'would reformat' if self.check else 'reformatted'
1606             out(f'{reformatted} {src}')
1607             self.change_count += 1
1608         else:
1609             out(f'{src} already well formatted, good job.', bold=False)
1610             self.same_count += 1
1611
1612     def failed(self, src: Path, message: str) -> None:
1613         """Increment the counter for failed reformatting. Write out a message."""
1614         err(f'error: cannot format {src}: {message}')
1615         self.failure_count += 1
1616
1617     @property
1618     def return_code(self) -> int:
1619         """Which return code should the app use considering the current state."""
1620         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1621         # 126 we have special returncodes reserved by the shell.
1622         if self.failure_count:
1623             return 123
1624
1625         elif self.change_count and self.check:
1626             return 1
1627
1628         return 0
1629
1630     def __str__(self) -> str:
1631         """A color report of the current state.
1632
1633         Use `click.unstyle` to remove colors.
1634         """
1635         if self.check:
1636             reformatted = "would be reformatted"
1637             unchanged = "would be left unchanged"
1638             failed = "would fail to reformat"
1639         else:
1640             reformatted = "reformatted"
1641             unchanged = "left unchanged"
1642             failed = "failed to reformat"
1643         report = []
1644         if self.change_count:
1645             s = 's' if self.change_count > 1 else ''
1646             report.append(
1647                 click.style(f'{self.change_count} file{s} {reformatted}', bold=True)
1648             )
1649         if self.same_count:
1650             s = 's' if self.same_count > 1 else ''
1651             report.append(f'{self.same_count} file{s} {unchanged}')
1652         if self.failure_count:
1653             s = 's' if self.failure_count > 1 else ''
1654             report.append(
1655                 click.style(f'{self.failure_count} file{s} {failed}', fg='red')
1656             )
1657         return ', '.join(report) + '.'
1658
1659
1660 def assert_equivalent(src: str, dst: str) -> None:
1661     """Raises AssertionError if `src` and `dst` aren't equivalent.
1662
1663     This is a temporary sanity check until Black becomes stable.
1664     """
1665
1666     import ast
1667     import traceback
1668
1669     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1670         """Simple visitor generating strings to compare ASTs by content."""
1671         yield f"{'  ' * depth}{node.__class__.__name__}("
1672
1673         for field in sorted(node._fields):
1674             try:
1675                 value = getattr(node, field)
1676             except AttributeError:
1677                 continue
1678
1679             yield f"{'  ' * (depth+1)}{field}="
1680
1681             if isinstance(value, list):
1682                 for item in value:
1683                     if isinstance(item, ast.AST):
1684                         yield from _v(item, depth + 2)
1685
1686             elif isinstance(value, ast.AST):
1687                 yield from _v(value, depth + 2)
1688
1689             else:
1690                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1691
1692         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1693
1694     try:
1695         src_ast = ast.parse(src)
1696     except Exception as exc:
1697         major, minor = sys.version_info[:2]
1698         raise AssertionError(
1699             f"cannot use --safe with this file; failed to parse source file "
1700             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
1701             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
1702         )
1703
1704     try:
1705         dst_ast = ast.parse(dst)
1706     except Exception as exc:
1707         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1708         raise AssertionError(
1709             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1710             f"Please report a bug on https://github.com/ambv/black/issues.  "
1711             f"This invalid output might be helpful: {log}"
1712         ) from None
1713
1714     src_ast_str = '\n'.join(_v(src_ast))
1715     dst_ast_str = '\n'.join(_v(dst_ast))
1716     if src_ast_str != dst_ast_str:
1717         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1718         raise AssertionError(
1719             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1720             f"the source.  "
1721             f"Please report a bug on https://github.com/ambv/black/issues.  "
1722             f"This diff might be helpful: {log}"
1723         ) from None
1724
1725
1726 def assert_stable(src: str, dst: str, line_length: int) -> None:
1727     """Raises AssertionError if `dst` reformats differently the second time.
1728
1729     This is a temporary sanity check until Black becomes stable.
1730     """
1731     newdst = format_str(dst, line_length=line_length)
1732     if dst != newdst:
1733         log = dump_to_file(
1734             diff(src, dst, 'source', 'first pass'),
1735             diff(dst, newdst, 'first pass', 'second pass'),
1736         )
1737         raise AssertionError(
1738             f"INTERNAL ERROR: Black produced different code on the second pass "
1739             f"of the formatter.  "
1740             f"Please report a bug on https://github.com/ambv/black/issues.  "
1741             f"This diff might be helpful: {log}"
1742         ) from None
1743
1744
1745 def dump_to_file(*output: str) -> str:
1746     """Dumps `output` to a temporary file. Returns path to the file."""
1747     import tempfile
1748
1749     with tempfile.NamedTemporaryFile(
1750         mode='w', prefix='blk_', suffix='.log', delete=False
1751     ) as f:
1752         for lines in output:
1753             f.write(lines)
1754             f.write('\n')
1755     return f.name
1756
1757
1758 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1759     """Returns a udiff string between strings `a` and `b`."""
1760     import difflib
1761
1762     a_lines = [line + '\n' for line in a.split('\n')]
1763     b_lines = [line + '\n' for line in b.split('\n')]
1764     return ''.join(
1765         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1766     )
1767
1768
1769 if __name__ == '__main__':
1770     main()