]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Show __str__ in UnformattedLines
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial
7 import keyword
8 import os
9 from pathlib import Path
10 import tokenize
11 import sys
12 from typing import (
13     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union
14 )
15
16 from attr import dataclass, Factory
17 import click
18
19 # lib2to3 fork
20 from blib2to3.pytree import Node, Leaf, type_repr
21 from blib2to3 import pygram, pytree
22 from blib2to3.pgen2 import driver, token
23 from blib2to3.pgen2.parse import ParseError
24
25 __version__ = "18.3a4"
26 DEFAULT_LINE_LENGTH = 88
27 # types
28 syms = pygram.python_symbols
29 FileContent = str
30 Encoding = str
31 Depth = int
32 NodeType = int
33 LeafID = int
34 Priority = int
35 LN = Union[Leaf, Node]
36 out = partial(click.secho, bold=True, err=True)
37 err = partial(click.secho, fg='red', err=True)
38
39
40 class NothingChanged(UserWarning):
41     """Raised by `format_file()` when the reformatted code is the same as source."""
42
43
44 class CannotSplit(Exception):
45     """A readable split that fits the allotted line length is impossible.
46
47     Raised by `left_hand_split()`, `right_hand_split()`, and `delimiter_split()`.
48     """
49
50
51 class FormatError(Exception):
52     """Base exception for `# fmt: on` and `# fmt: off` handling.
53
54     It holds the number of bytes of the prefix consumed before the format
55     control comment appeared.
56     """
57
58     def __init__(self, consumed: int) -> None:
59         super().__init__(consumed)
60         self.consumed = consumed
61
62     def trim_prefix(self, leaf: Leaf) -> None:
63         leaf.prefix = leaf.prefix[self.consumed:]
64
65     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
66         """Returns a new Leaf from the consumed part of the prefix."""
67         unformatted_prefix = leaf.prefix[:self.consumed]
68         return Leaf(token.NEWLINE, unformatted_prefix)
69
70
71 class FormatOn(FormatError):
72     """Found a comment like `# fmt: on` in the file."""
73
74
75 class FormatOff(FormatError):
76     """Found a comment like `# fmt: off` in the file."""
77
78
79 @click.command()
80 @click.option(
81     '-l',
82     '--line-length',
83     type=int,
84     default=DEFAULT_LINE_LENGTH,
85     help='How many character per line to allow.',
86     show_default=True,
87 )
88 @click.option(
89     '--check',
90     is_flag=True,
91     help=(
92         "Don't write back the files, just return the status.  Return code 0 "
93         "means nothing would change.  Return code 1 means some files would be "
94         "reformatted.  Return code 123 means there was an internal error."
95     ),
96 )
97 @click.option(
98     '--fast/--safe',
99     is_flag=True,
100     help='If --fast given, skip temporary sanity checks. [default: --safe]',
101 )
102 @click.version_option(version=__version__)
103 @click.argument(
104     'src',
105     nargs=-1,
106     type=click.Path(
107         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
108     ),
109 )
110 @click.pass_context
111 def main(
112     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
113 ) -> None:
114     """The uncompromising code formatter."""
115     sources: List[Path] = []
116     for s in src:
117         p = Path(s)
118         if p.is_dir():
119             sources.extend(gen_python_files_in_dir(p))
120         elif p.is_file():
121             # if a file was explicitly given, we don't care about its extension
122             sources.append(p)
123         elif s == '-':
124             sources.append(Path('-'))
125         else:
126             err(f'invalid path: {s}')
127     if len(sources) == 0:
128         ctx.exit(0)
129     elif len(sources) == 1:
130         p = sources[0]
131         report = Report(check=check)
132         try:
133             if not p.is_file() and str(p) == '-':
134                 changed = format_stdin_to_stdout(
135                     line_length=line_length, fast=fast, write_back=not check
136                 )
137             else:
138                 changed = format_file_in_place(
139                     p, line_length=line_length, fast=fast, write_back=not check
140                 )
141             report.done(p, changed)
142         except Exception as exc:
143             report.failed(p, str(exc))
144         ctx.exit(report.return_code)
145     else:
146         loop = asyncio.get_event_loop()
147         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
148         return_code = 1
149         try:
150             return_code = loop.run_until_complete(
151                 schedule_formatting(
152                     sources, line_length, not check, fast, loop, executor
153                 )
154             )
155         finally:
156             loop.close()
157             ctx.exit(return_code)
158
159
160 async def schedule_formatting(
161     sources: List[Path],
162     line_length: int,
163     write_back: bool,
164     fast: bool,
165     loop: BaseEventLoop,
166     executor: Executor,
167 ) -> int:
168     """Run formatting of `sources` in parallel using the provided `executor`.
169
170     (Use ProcessPoolExecutors for actual parallelism.)
171
172     `line_length`, `write_back`, and `fast` options are passed to
173     :func:`format_file_in_place`.
174     """
175     tasks = {
176         src: loop.run_in_executor(
177             executor, format_file_in_place, src, line_length, fast, write_back
178         )
179         for src in sources
180     }
181     await asyncio.wait(tasks.values())
182     cancelled = []
183     report = Report()
184     for src, task in tasks.items():
185         if not task.done():
186             report.failed(src, 'timed out, cancelling')
187             task.cancel()
188             cancelled.append(task)
189         elif task.exception():
190             report.failed(src, str(task.exception()))
191         else:
192             report.done(src, task.result())
193     if cancelled:
194         await asyncio.wait(cancelled, timeout=2)
195     out('All done! ✨ 🍰 ✨')
196     click.echo(str(report))
197     return report.return_code
198
199
200 def format_file_in_place(
201     src: Path, line_length: int, fast: bool, write_back: bool = False
202 ) -> bool:
203     """Format file under `src` path. Return True if changed.
204
205     If `write_back` is True, write reformatted code back to stdout.
206     `line_length` and `fast` options are passed to :func:`format_file_contents`.
207     """
208     with tokenize.open(src) as src_buffer:
209         src_contents = src_buffer.read()
210     try:
211         contents = format_file_contents(
212             src_contents, line_length=line_length, fast=fast
213         )
214     except NothingChanged:
215         return False
216
217     if write_back:
218         with open(src, "w", encoding=src_buffer.encoding) as f:
219             f.write(contents)
220     return True
221
222
223 def format_stdin_to_stdout(
224     line_length: int, fast: bool, write_back: bool = False
225 ) -> bool:
226     """Format file on stdin. Return True if changed.
227
228     If `write_back` is True, write reformatted code back to stdout.
229     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
230     """
231     contents = sys.stdin.read()
232     try:
233         contents = format_file_contents(contents, line_length=line_length, fast=fast)
234         return True
235
236     except NothingChanged:
237         return False
238
239     finally:
240         if write_back:
241             sys.stdout.write(contents)
242
243
244 def format_file_contents(
245     src_contents: str, line_length: int, fast: bool
246 ) -> FileContent:
247     """Reformats a file and returns its contents and encoding.
248
249     If `fast` is False, additionally confirm that the reformatted code is
250     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
251     `line_length` is passed to :func:`format_str`.
252     """
253     if src_contents.strip() == '':
254         raise NothingChanged
255
256     dst_contents = format_str(src_contents, line_length=line_length)
257     if src_contents == dst_contents:
258         raise NothingChanged
259
260     if not fast:
261         assert_equivalent(src_contents, dst_contents)
262         assert_stable(src_contents, dst_contents, line_length=line_length)
263     return dst_contents
264
265
266 def format_str(src_contents: str, line_length: int) -> FileContent:
267     """Reformats a string and returns new contents.
268
269     `line_length` determines how many characters per line are allowed.
270     """
271     src_node = lib2to3_parse(src_contents)
272     dst_contents = ""
273     lines = LineGenerator()
274     elt = EmptyLineTracker()
275     py36 = is_python36(src_node)
276     empty_line = Line()
277     after = 0
278     for current_line in lines.visit(src_node):
279         for _ in range(after):
280             dst_contents += str(empty_line)
281         before, after = elt.maybe_empty_lines(current_line)
282         for _ in range(before):
283             dst_contents += str(empty_line)
284         for line in split_line(current_line, line_length=line_length, py36=py36):
285             dst_contents += str(line)
286     return dst_contents
287
288
289 GRAMMARS = [
290     pygram.python_grammar_no_print_statement_no_exec_statement,
291     pygram.python_grammar_no_print_statement,
292     pygram.python_grammar_no_exec_statement,
293     pygram.python_grammar,
294 ]
295
296
297 def lib2to3_parse(src_txt: str) -> Node:
298     """Given a string with source, return the lib2to3 Node."""
299     grammar = pygram.python_grammar_no_print_statement
300     if src_txt[-1] != '\n':
301         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
302         src_txt += nl
303     for grammar in GRAMMARS:
304         drv = driver.Driver(grammar, pytree.convert)
305         try:
306             result = drv.parse_string(src_txt, True)
307             break
308
309         except ParseError as pe:
310             lineno, column = pe.context[1]
311             lines = src_txt.splitlines()
312             try:
313                 faulty_line = lines[lineno - 1]
314             except IndexError:
315                 faulty_line = "<line number missing in source>"
316             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
317     else:
318         raise exc from None
319
320     if isinstance(result, Leaf):
321         result = Node(syms.file_input, [result])
322     return result
323
324
325 def lib2to3_unparse(node: Node) -> str:
326     """Given a lib2to3 node, return its string representation."""
327     code = str(node)
328     return code
329
330
331 T = TypeVar('T')
332
333
334 class Visitor(Generic[T]):
335     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
336
337     def visit(self, node: LN) -> Iterator[T]:
338         """Main method to visit `node` and its children.
339
340         It tries to find a `visit_*()` method for the given `node.type`, like
341         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
342         If no dedicated `visit_*()` method is found, chooses `visit_default()`
343         instead.
344
345         Then yields objects of type `T` from the selected visitor.
346         """
347         if node.type < 256:
348             name = token.tok_name[node.type]
349         else:
350             name = type_repr(node.type)
351         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
352
353     def visit_default(self, node: LN) -> Iterator[T]:
354         """Default `visit_*()` implementation. Recurses to children of `node`."""
355         if isinstance(node, Node):
356             for child in node.children:
357                 yield from self.visit(child)
358
359
360 @dataclass
361 class DebugVisitor(Visitor[T]):
362     tree_depth: int = 0
363
364     def visit_default(self, node: LN) -> Iterator[T]:
365         indent = ' ' * (2 * self.tree_depth)
366         if isinstance(node, Node):
367             _type = type_repr(node.type)
368             out(f'{indent}{_type}', fg='yellow')
369             self.tree_depth += 1
370             for child in node.children:
371                 yield from self.visit(child)
372
373             self.tree_depth -= 1
374             out(f'{indent}/{_type}', fg='yellow', bold=False)
375         else:
376             _type = token.tok_name.get(node.type, str(node.type))
377             out(f'{indent}{_type}', fg='blue', nl=False)
378             if node.prefix:
379                 # We don't have to handle prefixes for `Node` objects since
380                 # that delegates to the first child anyway.
381                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
382             out(f' {node.value!r}', fg='blue', bold=False)
383
384     @classmethod
385     def show(cls, code: str) -> None:
386         """Pretty-prints a given string of `code`.
387
388         Convenience method for debugging.
389         """
390         v: DebugVisitor[None] = DebugVisitor()
391         list(v.visit(lib2to3_parse(code)))
392
393
394 KEYWORDS = set(keyword.kwlist)
395 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
396 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
397 STATEMENT = {
398     syms.if_stmt,
399     syms.while_stmt,
400     syms.for_stmt,
401     syms.try_stmt,
402     syms.except_clause,
403     syms.with_stmt,
404     syms.funcdef,
405     syms.classdef,
406 }
407 STANDALONE_COMMENT = 153
408 LOGIC_OPERATORS = {'and', 'or'}
409 COMPARATORS = {
410     token.LESS,
411     token.GREATER,
412     token.EQEQUAL,
413     token.NOTEQUAL,
414     token.LESSEQUAL,
415     token.GREATEREQUAL,
416 }
417 MATH_OPERATORS = {
418     token.PLUS,
419     token.MINUS,
420     token.STAR,
421     token.SLASH,
422     token.VBAR,
423     token.AMPER,
424     token.PERCENT,
425     token.CIRCUMFLEX,
426     token.TILDE,
427     token.LEFTSHIFT,
428     token.RIGHTSHIFT,
429     token.DOUBLESTAR,
430     token.DOUBLESLASH,
431 }
432 COMPREHENSION_PRIORITY = 20
433 COMMA_PRIORITY = 10
434 LOGIC_PRIORITY = 5
435 STRING_PRIORITY = 4
436 COMPARATOR_PRIORITY = 3
437 MATH_PRIORITY = 1
438
439
440 @dataclass
441 class BracketTracker:
442     """Keeps track of brackets on a line."""
443
444     depth: int = 0
445     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
446     delimiters: Dict[LeafID, Priority] = Factory(dict)
447     previous: Optional[Leaf] = None
448
449     def mark(self, leaf: Leaf) -> None:
450         """Marks `leaf` with bracket-related metadata. Keeps track of delimiters.
451
452         All leaves receive an int `bracket_depth` field that stores how deep
453         within brackets a given leaf is. 0 means there are no enclosing brackets
454         that started on this line.
455
456         If a leaf is itself a closing bracket, it receives an `opening_bracket`
457         field that it forms a pair with. This is a one-directional link to
458         avoid reference cycles.
459
460         If a leaf is a delimiter (a token on which Black can split the line if
461         needed) and it's on depth 0, its `id()` is stored in the tracker's
462         `delimiters` field.
463         """
464         if leaf.type == token.COMMENT:
465             return
466
467         if leaf.type in CLOSING_BRACKETS:
468             self.depth -= 1
469             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
470             leaf.opening_bracket = opening_bracket
471         leaf.bracket_depth = self.depth
472         if self.depth == 0:
473             delim = is_delimiter(leaf)
474             if delim:
475                 self.delimiters[id(leaf)] = delim
476             elif self.previous is not None:
477                 if leaf.type == token.STRING and self.previous.type == token.STRING:
478                     self.delimiters[id(self.previous)] = STRING_PRIORITY
479                 elif (
480                     leaf.type == token.NAME
481                     and leaf.value == 'for'
482                     and leaf.parent
483                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
484                 ):
485                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
486                 elif (
487                     leaf.type == token.NAME
488                     and leaf.value == 'if'
489                     and leaf.parent
490                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
491                 ):
492                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
493                 elif (
494                     leaf.type == token.NAME
495                     and leaf.value in LOGIC_OPERATORS
496                     and leaf.parent
497                 ):
498                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
499         if leaf.type in OPENING_BRACKETS:
500             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
501             self.depth += 1
502         self.previous = leaf
503
504     def any_open_brackets(self) -> bool:
505         """Returns True if there is an yet unmatched open bracket on the line."""
506         return bool(self.bracket_match)
507
508     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
509         """Returns the highest priority of a delimiter found on the line.
510
511         Values are consistent with what `is_delimiter()` returns.
512         """
513         return max(v for k, v in self.delimiters.items() if k not in exclude)
514
515
516 @dataclass
517 class Line:
518     """Holds leaves and comments. Can be printed with `str(line)`."""
519
520     depth: int = 0
521     leaves: List[Leaf] = Factory(list)
522     comments: Dict[LeafID, Leaf] = Factory(dict)
523     bracket_tracker: BracketTracker = Factory(BracketTracker)
524     inside_brackets: bool = False
525     has_for: bool = False
526     _for_loop_variable: bool = False
527
528     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
529         """Add a new `leaf` to the end of the line.
530
531         Unless `preformatted` is True, the `leaf` will receive a new consistent
532         whitespace prefix and metadata applied by :class:`BracketTracker`.
533         Trailing commas are maybe removed, unpacked for loop variables are
534         demoted from being delimiters.
535
536         Inline comments are put aside.
537         """
538         has_value = leaf.value.strip()
539         if not has_value:
540             return
541
542         if self.leaves and not preformatted:
543             # Note: at this point leaf.prefix should be empty except for
544             # imports, for which we only preserve newlines.
545             leaf.prefix += whitespace(leaf)
546         if self.inside_brackets or not preformatted:
547             self.maybe_decrement_after_for_loop_variable(leaf)
548             self.bracket_tracker.mark(leaf)
549             self.maybe_remove_trailing_comma(leaf)
550             self.maybe_increment_for_loop_variable(leaf)
551             if self.maybe_adapt_standalone_comment(leaf):
552                 return
553
554         if not self.append_comment(leaf):
555             self.leaves.append(leaf)
556
557     @property
558     def is_comment(self) -> bool:
559         """Is this line a standalone comment?"""
560         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
561
562     @property
563     def is_decorator(self) -> bool:
564         """Is this line a decorator?"""
565         return bool(self) and self.leaves[0].type == token.AT
566
567     @property
568     def is_import(self) -> bool:
569         """Is this an import line?"""
570         return bool(self) and is_import(self.leaves[0])
571
572     @property
573     def is_class(self) -> bool:
574         """Is this a class definition?"""
575         return (
576             bool(self)
577             and self.leaves[0].type == token.NAME
578             and self.leaves[0].value == 'class'
579         )
580
581     @property
582     def is_def(self) -> bool:
583         """Is this a function definition? (Also returns True for async defs.)"""
584         try:
585             first_leaf = self.leaves[0]
586         except IndexError:
587             return False
588
589         try:
590             second_leaf: Optional[Leaf] = self.leaves[1]
591         except IndexError:
592             second_leaf = None
593         return (
594             (first_leaf.type == token.NAME and first_leaf.value == 'def')
595             or (
596                 first_leaf.type == token.ASYNC
597                 and second_leaf is not None
598                 and second_leaf.type == token.NAME
599                 and second_leaf.value == 'def'
600             )
601         )
602
603     @property
604     def is_flow_control(self) -> bool:
605         """Is this a flow control statement?
606
607         Those are `return`, `raise`, `break`, and `continue`.
608         """
609         return (
610             bool(self)
611             and self.leaves[0].type == token.NAME
612             and self.leaves[0].value in FLOW_CONTROL
613         )
614
615     @property
616     def is_yield(self) -> bool:
617         """Is this a yield statement?"""
618         return (
619             bool(self)
620             and self.leaves[0].type == token.NAME
621             and self.leaves[0].value == 'yield'
622         )
623
624     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
625         """Remove trailing comma if there is one and it's safe."""
626         if not (
627             self.leaves
628             and self.leaves[-1].type == token.COMMA
629             and closing.type in CLOSING_BRACKETS
630         ):
631             return False
632
633         if closing.type == token.RBRACE:
634             self.leaves.pop()
635             return True
636
637         if closing.type == token.RSQB:
638             comma = self.leaves[-1]
639             if comma.parent and comma.parent.type == syms.listmaker:
640                 self.leaves.pop()
641                 return True
642
643         # For parens let's check if it's safe to remove the comma.  If the
644         # trailing one is the only one, we might mistakenly change a tuple
645         # into a different type by removing the comma.
646         depth = closing.bracket_depth + 1
647         commas = 0
648         opening = closing.opening_bracket
649         for _opening_index, leaf in enumerate(self.leaves):
650             if leaf is opening:
651                 break
652
653         else:
654             return False
655
656         for leaf in self.leaves[_opening_index + 1:]:
657             if leaf is closing:
658                 break
659
660             bracket_depth = leaf.bracket_depth
661             if bracket_depth == depth and leaf.type == token.COMMA:
662                 commas += 1
663                 if leaf.parent and leaf.parent.type == syms.arglist:
664                     commas += 1
665                     break
666
667         if commas > 1:
668             self.leaves.pop()
669             return True
670
671         return False
672
673     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
674         """In a for loop, or comprehension, the variables are often unpacks.
675
676         To avoid splitting on the comma in this situation, we will increase
677         the depth of tokens between `for` and `in`.
678         """
679         if leaf.type == token.NAME and leaf.value == 'for':
680             self.has_for = True
681             self.bracket_tracker.depth += 1
682             self._for_loop_variable = True
683             return True
684
685         return False
686
687     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
688         """See `maybe_increment_for_loop_variable` above for explanation."""
689         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
690             self.bracket_tracker.depth -= 1
691             self._for_loop_variable = False
692             return True
693
694         return False
695
696     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
697         """Hack a standalone comment to act as a trailing comment for line splitting.
698
699         If this line has brackets and a standalone `comment`, we need to adapt
700         it to be able to still reformat the line.
701
702         This is not perfect, the line to which the standalone comment gets
703         appended will appear "too long" when splitting.
704         """
705         if not (
706             comment.type == STANDALONE_COMMENT
707             and self.bracket_tracker.any_open_brackets()
708         ):
709             return False
710
711         comment.type = token.COMMENT
712         comment.prefix = '\n' + '    ' * (self.depth + 1)
713         return self.append_comment(comment)
714
715     def append_comment(self, comment: Leaf) -> bool:
716         """Add an inline comment to the line."""
717         if comment.type != token.COMMENT:
718             return False
719
720         try:
721             after = id(self.last_non_delimiter())
722         except LookupError:
723             comment.type = STANDALONE_COMMENT
724             comment.prefix = ''
725             return False
726
727         else:
728             if after in self.comments:
729                 self.comments[after].value += str(comment)
730             else:
731                 self.comments[after] = comment
732             return True
733
734     def last_non_delimiter(self) -> Leaf:
735         """Returns the last non-delimiter on the line. Raises LookupError otherwise."""
736         for i in range(len(self.leaves)):
737             last = self.leaves[-i - 1]
738             if not is_delimiter(last):
739                 return last
740
741         raise LookupError("No non-delimiters found")
742
743     def __str__(self) -> str:
744         """Render the line."""
745         if not self:
746             return '\n'
747
748         indent = '    ' * self.depth
749         leaves = iter(self.leaves)
750         first = next(leaves)
751         res = f'{first.prefix}{indent}{first.value}'
752         for leaf in leaves:
753             res += str(leaf)
754         for comment in self.comments.values():
755             res += str(comment)
756         return res + '\n'
757
758     def __bool__(self) -> bool:
759         """Returns True if the line has leaves or comments."""
760         return bool(self.leaves or self.comments)
761
762
763 class UnformattedLines(Line):
764     """Just like :class:`Line` but stores lines which aren't reformatted."""
765
766     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
767         """Just add a new `leaf` to the end of the lines.
768
769         The `preformatted` argument is ignored.
770
771         Keeps track of indentation `depth`, which is useful when the user
772         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
773         """
774         try:
775             list(generate_comments(leaf))
776         except FormatOn as f_on:
777             self.leaves.append(f_on.leaf_from_consumed(leaf))
778             raise
779
780         self.leaves.append(leaf)
781         if leaf.type == token.INDENT:
782             self.depth += 1
783         elif leaf.type == token.DEDENT:
784             self.depth -= 1
785
786     def append_comment(self, comment: Leaf) -> bool:
787         """Not implemented in this class."""
788         raise NotImplementedError("Unformatted lines don't store comments separately.")
789
790     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
791         """Does nothing and returns False."""
792         return False
793
794     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
795         """Does nothing and returns False."""
796         return False
797
798     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
799         """Does nothing and returns False."""
800         return False
801
802     def __str__(self) -> str:
803         """Renders unformatted lines from leaves which were added with `append()`.
804
805         `depth` is not used for indentation in this case.
806         """
807         if not self:
808             return '\n'
809
810         res = ''
811         for leaf in self.leaves:
812             res += str(leaf)
813         return res
814
815
816 @dataclass
817 class EmptyLineTracker:
818     """Provides a stateful method that returns the number of potential extra
819     empty lines needed before and after the currently processed line.
820
821     Note: this tracker works on lines that haven't been split yet.  It assumes
822     the prefix of the first leaf consists of optional newlines.  Those newlines
823     are consumed by `maybe_empty_lines()` and included in the computation.
824     """
825     previous_line: Optional[Line] = None
826     previous_after: int = 0
827     previous_defs: List[int] = Factory(list)
828
829     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
830         """Returns the number of extra empty lines before and after the `current_line`.
831
832         This is for separating `def`, `async def` and `class` with extra empty lines
833         (two on module-level), as well as providing an extra empty line after flow
834         control keywords to make them more prominent.
835         """
836         if isinstance(current_line, UnformattedLines):
837             return 0, 0
838
839         before, after = self._maybe_empty_lines(current_line)
840         before -= self.previous_after
841         self.previous_after = after
842         self.previous_line = current_line
843         return before, after
844
845     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
846         max_allowed = 1
847         if current_line.depth == 0:
848             max_allowed = 2
849         if current_line.leaves:
850             # Consume the first leaf's extra newlines.
851             first_leaf = current_line.leaves[0]
852             before = first_leaf.prefix.count('\n')
853             before = min(before, max_allowed)
854             first_leaf.prefix = ''
855         else:
856             before = 0
857         depth = current_line.depth
858         while self.previous_defs and self.previous_defs[-1] >= depth:
859             self.previous_defs.pop()
860             before = 1 if depth else 2
861         is_decorator = current_line.is_decorator
862         if is_decorator or current_line.is_def or current_line.is_class:
863             if not is_decorator:
864                 self.previous_defs.append(depth)
865             if self.previous_line is None:
866                 # Don't insert empty lines before the first line in the file.
867                 return 0, 0
868
869             if self.previous_line and self.previous_line.is_decorator:
870                 # Don't insert empty lines between decorators.
871                 return 0, 0
872
873             newlines = 2
874             if current_line.depth:
875                 newlines -= 1
876             return newlines, 0
877
878         if current_line.is_flow_control:
879             return before, 1
880
881         if (
882             self.previous_line
883             and self.previous_line.is_import
884             and not current_line.is_import
885             and depth == self.previous_line.depth
886         ):
887             return (before or 1), 0
888
889         if (
890             self.previous_line
891             and self.previous_line.is_yield
892             and (not current_line.is_yield or depth != self.previous_line.depth)
893         ):
894             return (before or 1), 0
895
896         return before, 0
897
898
899 @dataclass
900 class LineGenerator(Visitor[Line]):
901     """Generates reformatted Line objects.  Empty lines are not emitted.
902
903     Note: destroys the tree it's visiting by mutating prefixes of its leaves
904     in ways that will no longer stringify to valid Python code on the tree.
905     """
906     current_line: Line = Factory(Line)
907
908     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
909         """Generate a line.
910
911         If the line is empty, only emit if it makes sense.
912         If the line is too long, split it first and then generate.
913
914         If any lines were generated, set up a new current_line.
915         """
916         if not self.current_line:
917             if self.current_line.__class__ == type:
918                 self.current_line.depth += indent
919             else:
920                 self.current_line = type(depth=self.current_line.depth + indent)
921             return  # Line is empty, don't emit. Creating a new one unnecessary.
922
923         complete_line = self.current_line
924         self.current_line = type(depth=complete_line.depth + indent)
925         yield complete_line
926
927     def visit(self, node: LN) -> Iterator[Line]:
928         """Main method to start the visit process. Yields :class:`Line` objects."""
929         if isinstance(self.current_line, UnformattedLines):
930             # File contained `# fmt: off`
931             yield from self.visit_unformatted(node)
932
933         else:
934             yield from super().visit(node)
935
936     def visit_default(self, node: LN) -> Iterator[Line]:
937         """Default `visit_*()` implementation. Recurses to children of `node`."""
938         if isinstance(node, Leaf):
939             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
940             try:
941                 for comment in generate_comments(node):
942                     if any_open_brackets:
943                         # any comment within brackets is subject to splitting
944                         self.current_line.append(comment)
945                     elif comment.type == token.COMMENT:
946                         # regular trailing comment
947                         self.current_line.append(comment)
948                         yield from self.line()
949
950                     else:
951                         # regular standalone comment
952                         yield from self.line()
953
954                         self.current_line.append(comment)
955                         yield from self.line()
956
957             except FormatOff as f_off:
958                 f_off.trim_prefix(node)
959                 yield from self.line(type=UnformattedLines)
960                 yield from self.visit(node)
961
962             except FormatOn as f_on:
963                 # This only happens here if somebody says "fmt: on" multiple
964                 # times in a row.
965                 f_on.trim_prefix(node)
966                 yield from self.visit_default(node)
967
968             else:
969                 normalize_prefix(node, inside_brackets=any_open_brackets)
970                 if node.type not in WHITESPACE:
971                     self.current_line.append(node)
972         yield from super().visit_default(node)
973
974     def visit_INDENT(self, node: Node) -> Iterator[Line]:
975         """Increases indentation level, maybe yields a line."""
976         # In blib2to3 INDENT never holds comments.
977         yield from self.line(+1)
978         yield from self.visit_default(node)
979
980     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
981         """Decreases indentation level, maybe yields a line."""
982         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
983         yield from self.line(-1)
984
985     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
986         """Visits a statement.
987
988         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
989         `def`, `with`, and `class`.
990
991         The relevant Python language `keywords` for a given statement will be NAME
992         leaves within it. This methods puts those on a separate line.
993         """
994         for child in node.children:
995             if child.type == token.NAME and child.value in keywords:  # type: ignore
996                 yield from self.line()
997
998             yield from self.visit(child)
999
1000     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1001         """Visits a statement without nested statements."""
1002         is_suite_like = node.parent and node.parent.type in STATEMENT
1003         if is_suite_like:
1004             yield from self.line(+1)
1005             yield from self.visit_default(node)
1006             yield from self.line(-1)
1007
1008         else:
1009             yield from self.line()
1010             yield from self.visit_default(node)
1011
1012     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1013         """Visits `async def`, `async for`, `async with`."""
1014         yield from self.line()
1015
1016         children = iter(node.children)
1017         for child in children:
1018             yield from self.visit(child)
1019
1020             if child.type == token.ASYNC:
1021                 break
1022
1023         internal_stmt = next(children)
1024         for child in internal_stmt.children:
1025             yield from self.visit(child)
1026
1027     def visit_decorators(self, node: Node) -> Iterator[Line]:
1028         """Visits decorators."""
1029         for child in node.children:
1030             yield from self.line()
1031             yield from self.visit(child)
1032
1033     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1034         """Semicolons are always removed.
1035
1036         Statements between them are put on separate lines.
1037         """
1038         yield from self.line()
1039
1040     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1041         """End of file.
1042
1043         Process outstanding comments and end with a newline.
1044         """
1045         yield from self.visit_default(leaf)
1046         yield from self.line()
1047
1048     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1049         """Used when file contained a `# fmt: off`."""
1050         if isinstance(node, Node):
1051             for child in node.children:
1052                 yield from self.visit(child)
1053
1054         else:
1055             try:
1056                 self.current_line.append(node)
1057             except FormatOn as f_on:
1058                 f_on.trim_prefix(node)
1059                 yield from self.line()
1060                 yield from self.visit(node)
1061
1062     def __attrs_post_init__(self) -> None:
1063         """You are in a twisty little maze of passages."""
1064         v = self.visit_stmt
1065         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
1066         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
1067         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
1068         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
1069         self.visit_except_clause = partial(v, keywords={'except'})
1070         self.visit_funcdef = partial(v, keywords={'def'})
1071         self.visit_with_stmt = partial(v, keywords={'with'})
1072         self.visit_classdef = partial(v, keywords={'class'})
1073         self.visit_async_funcdef = self.visit_async_stmt
1074         self.visit_decorated = self.visit_decorators
1075
1076
1077 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1078 OPENING_BRACKETS = set(BRACKET.keys())
1079 CLOSING_BRACKETS = set(BRACKET.values())
1080 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1081 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1082
1083
1084 def whitespace(leaf: Leaf) -> str:  # noqa C901
1085     """Return whitespace prefix if needed for the given `leaf`."""
1086     NO = ''
1087     SPACE = ' '
1088     DOUBLESPACE = '  '
1089     t = leaf.type
1090     p = leaf.parent
1091     v = leaf.value
1092     if t in ALWAYS_NO_SPACE:
1093         return NO
1094
1095     if t == token.COMMENT:
1096         return DOUBLESPACE
1097
1098     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1099     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1100         return NO
1101
1102     prev = leaf.prev_sibling
1103     if not prev:
1104         prevp = preceding_leaf(p)
1105         if not prevp or prevp.type in OPENING_BRACKETS:
1106             return NO
1107
1108         if t == token.COLON:
1109             return SPACE if prevp.type == token.COMMA else NO
1110
1111         if prevp.type == token.EQUAL:
1112             if prevp.parent:
1113                 if prevp.parent.type in {
1114                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1115                 }:
1116                     return NO
1117
1118                 elif prevp.parent.type == syms.typedargslist:
1119                     # A bit hacky: if the equal sign has whitespace, it means we
1120                     # previously found it's a typed argument.  So, we're using
1121                     # that, too.
1122                     return prevp.prefix
1123
1124         elif prevp.type == token.DOUBLESTAR:
1125             if prevp.parent and prevp.parent.type in {
1126                 syms.arglist,
1127                 syms.argument,
1128                 syms.dictsetmaker,
1129                 syms.parameters,
1130                 syms.typedargslist,
1131                 syms.varargslist,
1132             }:
1133                 return NO
1134
1135         elif prevp.type == token.COLON:
1136             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1137                 return NO
1138
1139         elif (
1140             prevp.parent
1141             and prevp.parent.type in {syms.factor, syms.star_expr}
1142             and prevp.type in MATH_OPERATORS
1143         ):
1144             return NO
1145
1146         elif (
1147             prevp.type == token.RIGHTSHIFT
1148             and prevp.parent
1149             and prevp.parent.type == syms.shift_expr
1150             and prevp.prev_sibling
1151             and prevp.prev_sibling.type == token.NAME
1152             and prevp.prev_sibling.value == 'print'  # type: ignore
1153         ):
1154             # Python 2 print chevron
1155             return NO
1156
1157     elif prev.type in OPENING_BRACKETS:
1158         return NO
1159
1160     if p.type in {syms.parameters, syms.arglist}:
1161         # untyped function signatures or calls
1162         if t == token.RPAR:
1163             return NO
1164
1165         if not prev or prev.type != token.COMMA:
1166             return NO
1167
1168     elif p.type == syms.varargslist:
1169         # lambdas
1170         if t == token.RPAR:
1171             return NO
1172
1173         if prev and prev.type != token.COMMA:
1174             return NO
1175
1176     elif p.type == syms.typedargslist:
1177         # typed function signatures
1178         if not prev:
1179             return NO
1180
1181         if t == token.EQUAL:
1182             if prev.type != syms.tname:
1183                 return NO
1184
1185         elif prev.type == token.EQUAL:
1186             # A bit hacky: if the equal sign has whitespace, it means we
1187             # previously found it's a typed argument.  So, we're using that, too.
1188             return prev.prefix
1189
1190         elif prev.type != token.COMMA:
1191             return NO
1192
1193     elif p.type == syms.tname:
1194         # type names
1195         if not prev:
1196             prevp = preceding_leaf(p)
1197             if not prevp or prevp.type != token.COMMA:
1198                 return NO
1199
1200     elif p.type == syms.trailer:
1201         # attributes and calls
1202         if t == token.LPAR or t == token.RPAR:
1203             return NO
1204
1205         if not prev:
1206             if t == token.DOT:
1207                 prevp = preceding_leaf(p)
1208                 if not prevp or prevp.type != token.NUMBER:
1209                     return NO
1210
1211             elif t == token.LSQB:
1212                 return NO
1213
1214         elif prev.type != token.COMMA:
1215             return NO
1216
1217     elif p.type == syms.argument:
1218         # single argument
1219         if t == token.EQUAL:
1220             return NO
1221
1222         if not prev:
1223             prevp = preceding_leaf(p)
1224             if not prevp or prevp.type == token.LPAR:
1225                 return NO
1226
1227         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1228             return NO
1229
1230     elif p.type == syms.decorator:
1231         # decorators
1232         return NO
1233
1234     elif p.type == syms.dotted_name:
1235         if prev:
1236             return NO
1237
1238         prevp = preceding_leaf(p)
1239         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1240             return NO
1241
1242     elif p.type == syms.classdef:
1243         if t == token.LPAR:
1244             return NO
1245
1246         if prev and prev.type == token.LPAR:
1247             return NO
1248
1249     elif p.type == syms.subscript:
1250         # indexing
1251         if not prev:
1252             assert p.parent is not None, "subscripts are always parented"
1253             if p.parent.type == syms.subscriptlist:
1254                 return SPACE
1255
1256             return NO
1257
1258         else:
1259             return NO
1260
1261     elif p.type == syms.atom:
1262         if prev and t == token.DOT:
1263             # dots, but not the first one.
1264             return NO
1265
1266     elif (
1267         p.type == syms.listmaker
1268         or p.type == syms.testlist_gexp
1269         or p.type == syms.subscriptlist
1270     ):
1271         # list interior, including unpacking
1272         if not prev:
1273             return NO
1274
1275     elif p.type == syms.dictsetmaker:
1276         # dict and set interior, including unpacking
1277         if not prev:
1278             return NO
1279
1280         if prev.type == token.DOUBLESTAR:
1281             return NO
1282
1283     elif p.type in {syms.factor, syms.star_expr}:
1284         # unary ops
1285         if not prev:
1286             prevp = preceding_leaf(p)
1287             if not prevp or prevp.type in OPENING_BRACKETS:
1288                 return NO
1289
1290             prevp_parent = prevp.parent
1291             assert prevp_parent is not None
1292             if prevp.type == token.COLON and prevp_parent.type in {
1293                 syms.subscript, syms.sliceop
1294             }:
1295                 return NO
1296
1297             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1298                 return NO
1299
1300         elif t == token.NAME or t == token.NUMBER:
1301             return NO
1302
1303     elif p.type == syms.import_from:
1304         if t == token.DOT:
1305             if prev and prev.type == token.DOT:
1306                 return NO
1307
1308         elif t == token.NAME:
1309             if v == 'import':
1310                 return SPACE
1311
1312             if prev and prev.type == token.DOT:
1313                 return NO
1314
1315     elif p.type == syms.sliceop:
1316         return NO
1317
1318     return SPACE
1319
1320
1321 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1322     """Returns the first leaf that precedes `node`, if any."""
1323     while node:
1324         res = node.prev_sibling
1325         if res:
1326             if isinstance(res, Leaf):
1327                 return res
1328
1329             try:
1330                 return list(res.leaves())[-1]
1331
1332             except IndexError:
1333                 return None
1334
1335         node = node.parent
1336     return None
1337
1338
1339 def is_delimiter(leaf: Leaf) -> int:
1340     """Returns the priority of the `leaf` delimiter. Returns 0 if not delimiter.
1341
1342     Higher numbers are higher priority.
1343     """
1344     if leaf.type == token.COMMA:
1345         return COMMA_PRIORITY
1346
1347     if leaf.type in COMPARATORS:
1348         return COMPARATOR_PRIORITY
1349
1350     if (
1351         leaf.type in MATH_OPERATORS
1352         and leaf.parent
1353         and leaf.parent.type not in {syms.factor, syms.star_expr}
1354     ):
1355         return MATH_PRIORITY
1356
1357     return 0
1358
1359
1360 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1361     """Cleans the prefix of the `leaf` and generates comments from it, if any.
1362
1363     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1364     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1365     move because it does away with modifying the grammar to include all the
1366     possible places in which comments can be placed.
1367
1368     The sad consequence for us though is that comments don't "belong" anywhere.
1369     This is why this function generates simple parentless Leaf objects for
1370     comments.  We simply don't know what the correct parent should be.
1371
1372     No matter though, we can live without this.  We really only need to
1373     differentiate between inline and standalone comments.  The latter don't
1374     share the line with any code.
1375
1376     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1377     are emitted with a fake STANDALONE_COMMENT token identifier.
1378     """
1379     p = leaf.prefix
1380     if not p:
1381         return
1382
1383     if '#' not in p:
1384         return
1385
1386     consumed = 0
1387     nlines = 0
1388     for index, line in enumerate(p.split('\n')):
1389         consumed += len(line) + 1  # adding the length of the split '\n'
1390         line = line.lstrip()
1391         if not line:
1392             nlines += 1
1393         if not line.startswith('#'):
1394             continue
1395
1396         if index == 0 and leaf.type != token.ENDMARKER:
1397             comment_type = token.COMMENT  # simple trailing comment
1398         else:
1399             comment_type = STANDALONE_COMMENT
1400         comment = make_comment(line)
1401         yield Leaf(comment_type, comment, prefix='\n' * nlines)
1402
1403         if comment in {'# fmt: on', '# yapf: enable'}:
1404             raise FormatOn(consumed)
1405
1406         if comment in {'# fmt: off', '# yapf: disable'}:
1407             raise FormatOff(consumed)
1408
1409         nlines = 0
1410
1411
1412 def make_comment(content: str) -> str:
1413     """Returns a consistently formatted comment from the given `content` string.
1414
1415     All comments (except for "##", "#!", "#:") should have a single space between
1416     the hash sign and the content.
1417
1418     If `content` didn't start with a hash sign, one is provided.
1419     """
1420     content = content.rstrip()
1421     if not content:
1422         return '#'
1423
1424     if content[0] == '#':
1425         content = content[1:]
1426     if content and content[0] not in ' !:#':
1427         content = ' ' + content
1428     return '#' + content
1429
1430
1431 def split_line(
1432     line: Line, line_length: int, inner: bool = False, py36: bool = False
1433 ) -> Iterator[Line]:
1434     """Splits a `line` into potentially many lines.
1435
1436     They should fit in the allotted `line_length` but might not be able to.
1437     `inner` signifies that there were a pair of brackets somewhere around the
1438     current `line`, possibly transitively. This means we can fallback to splitting
1439     by delimiters if the LHS/RHS don't yield any results.
1440
1441     If `py36` is True, splitting may generate syntax that is only compatible
1442     with Python 3.6 and later.
1443     """
1444     if isinstance(line, UnformattedLines):
1445         yield line
1446         return
1447
1448     line_str = str(line).strip('\n')
1449     if len(line_str) <= line_length and '\n' not in line_str:
1450         yield line
1451         return
1452
1453     if line.is_def:
1454         split_funcs = [left_hand_split]
1455     elif line.inside_brackets:
1456         split_funcs = [delimiter_split]
1457         if '\n' not in line_str:
1458             # Only attempt RHS if we don't have multiline strings or comments
1459             # on this line.
1460             split_funcs.append(right_hand_split)
1461     else:
1462         split_funcs = [right_hand_split]
1463     for split_func in split_funcs:
1464         # We are accumulating lines in `result` because we might want to abort
1465         # mission and return the original line in the end, or attempt a different
1466         # split altogether.
1467         result: List[Line] = []
1468         try:
1469             for l in split_func(line, py36=py36):
1470                 if str(l).strip('\n') == line_str:
1471                     raise CannotSplit("Split function returned an unchanged result")
1472
1473                 result.extend(
1474                     split_line(l, line_length=line_length, inner=True, py36=py36)
1475                 )
1476         except CannotSplit as cs:
1477             continue
1478
1479         else:
1480             yield from result
1481             break
1482
1483     else:
1484         yield line
1485
1486
1487 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1488     """Splits line into many lines, starting with the first matching bracket pair.
1489
1490     Note: this usually looks weird, only use this for function definitions.
1491     Prefer RHS otherwise.
1492     """
1493     head = Line(depth=line.depth)
1494     body = Line(depth=line.depth + 1, inside_brackets=True)
1495     tail = Line(depth=line.depth)
1496     tail_leaves: List[Leaf] = []
1497     body_leaves: List[Leaf] = []
1498     head_leaves: List[Leaf] = []
1499     current_leaves = head_leaves
1500     matching_bracket = None
1501     for leaf in line.leaves:
1502         if (
1503             current_leaves is body_leaves
1504             and leaf.type in CLOSING_BRACKETS
1505             and leaf.opening_bracket is matching_bracket
1506         ):
1507             current_leaves = tail_leaves if body_leaves else head_leaves
1508         current_leaves.append(leaf)
1509         if current_leaves is head_leaves:
1510             if leaf.type in OPENING_BRACKETS:
1511                 matching_bracket = leaf
1512                 current_leaves = body_leaves
1513     # Since body is a new indent level, remove spurious leading whitespace.
1514     if body_leaves:
1515         normalize_prefix(body_leaves[0], inside_brackets=True)
1516     # Build the new lines.
1517     for result, leaves in (
1518         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1519     ):
1520         for leaf in leaves:
1521             result.append(leaf, preformatted=True)
1522             comment_after = line.comments.get(id(leaf))
1523             if comment_after:
1524                 result.append(comment_after, preformatted=True)
1525     bracket_split_succeeded_or_raise(head, body, tail)
1526     for result in (head, body, tail):
1527         if result:
1528             yield result
1529
1530
1531 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1532     """Splits line into many lines, starting with the last matching bracket pair."""
1533     head = Line(depth=line.depth)
1534     body = Line(depth=line.depth + 1, inside_brackets=True)
1535     tail = Line(depth=line.depth)
1536     tail_leaves: List[Leaf] = []
1537     body_leaves: List[Leaf] = []
1538     head_leaves: List[Leaf] = []
1539     current_leaves = tail_leaves
1540     opening_bracket = None
1541     for leaf in reversed(line.leaves):
1542         if current_leaves is body_leaves:
1543             if leaf is opening_bracket:
1544                 current_leaves = head_leaves if body_leaves else tail_leaves
1545         current_leaves.append(leaf)
1546         if current_leaves is tail_leaves:
1547             if leaf.type in CLOSING_BRACKETS:
1548                 opening_bracket = leaf.opening_bracket
1549                 current_leaves = body_leaves
1550     tail_leaves.reverse()
1551     body_leaves.reverse()
1552     head_leaves.reverse()
1553     # Since body is a new indent level, remove spurious leading whitespace.
1554     if body_leaves:
1555         normalize_prefix(body_leaves[0], inside_brackets=True)
1556     # Build the new lines.
1557     for result, leaves in (
1558         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1559     ):
1560         for leaf in leaves:
1561             result.append(leaf, preformatted=True)
1562             comment_after = line.comments.get(id(leaf))
1563             if comment_after:
1564                 result.append(comment_after, preformatted=True)
1565     bracket_split_succeeded_or_raise(head, body, tail)
1566     for result in (head, body, tail):
1567         if result:
1568             yield result
1569
1570
1571 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1572     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1573
1574     Do nothing otherwise.
1575
1576     A left- or right-hand split is based on a pair of brackets. Content before
1577     (and including) the opening bracket is left on one line, content inside the
1578     brackets is put on a separate line, and finally content starting with and
1579     following the closing bracket is put on a separate line.
1580
1581     Those are called `head`, `body`, and `tail`, respectively. If the split
1582     produced the same line (all content in `head`) or ended up with an empty `body`
1583     and the `tail` is just the closing bracket, then it's considered failed.
1584     """
1585     tail_len = len(str(tail).strip())
1586     if not body:
1587         if tail_len == 0:
1588             raise CannotSplit("Splitting brackets produced the same line")
1589
1590         elif tail_len < 3:
1591             raise CannotSplit(
1592                 f"Splitting brackets on an empty body to save "
1593                 f"{tail_len} characters is not worth it"
1594             )
1595
1596
1597 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1598     """Splits according to delimiters of the highest priority.
1599
1600     This kind of split doesn't increase indentation.
1601     If `py36` is True, the split will add trailing commas also in function
1602     signatures that contain `*` and `**`.
1603     """
1604     try:
1605         last_leaf = line.leaves[-1]
1606     except IndexError:
1607         raise CannotSplit("Line empty")
1608
1609     delimiters = line.bracket_tracker.delimiters
1610     try:
1611         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1612             exclude={id(last_leaf)}
1613         )
1614     except ValueError:
1615         raise CannotSplit("No delimiters found")
1616
1617     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1618     lowest_depth = sys.maxsize
1619     trailing_comma_safe = True
1620     for leaf in line.leaves:
1621         current_line.append(leaf, preformatted=True)
1622         comment_after = line.comments.get(id(leaf))
1623         if comment_after:
1624             current_line.append(comment_after, preformatted=True)
1625         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1626         if (
1627             leaf.bracket_depth == lowest_depth
1628             and leaf.type == token.STAR
1629             or leaf.type == token.DOUBLESTAR
1630         ):
1631             trailing_comma_safe = trailing_comma_safe and py36
1632         leaf_priority = delimiters.get(id(leaf))
1633         if leaf_priority == delimiter_priority:
1634             normalize_prefix(current_line.leaves[0], inside_brackets=True)
1635             yield current_line
1636
1637             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1638     if current_line:
1639         if (
1640             delimiter_priority == COMMA_PRIORITY
1641             and current_line.leaves[-1].type != token.COMMA
1642             and trailing_comma_safe
1643         ):
1644             current_line.append(Leaf(token.COMMA, ','))
1645         normalize_prefix(current_line.leaves[0], inside_brackets=True)
1646         yield current_line
1647
1648
1649 def is_import(leaf: Leaf) -> bool:
1650     """Returns True if the given leaf starts an import statement."""
1651     p = leaf.parent
1652     t = leaf.type
1653     v = leaf.value
1654     return bool(
1655         t == token.NAME
1656         and (
1657             (v == 'import' and p and p.type == syms.import_name)
1658             or (v == 'from' and p and p.type == syms.import_from)
1659         )
1660     )
1661
1662
1663 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1664     """Leaves existing extra newlines if not `inside_brackets`.
1665
1666     Removes everything else.  Note: don't use backslashes for formatting or
1667     you'll lose your voting rights.
1668     """
1669     if not inside_brackets:
1670         spl = leaf.prefix.split('#')
1671         if '\\' not in spl[0]:
1672             nl_count = spl[-1].count('\n')
1673             if len(spl) > 1:
1674                 nl_count -= 1
1675             leaf.prefix = '\n' * nl_count
1676             return
1677
1678     leaf.prefix = ''
1679
1680
1681 def is_python36(node: Node) -> bool:
1682     """Returns True if the current file is using Python 3.6+ features.
1683
1684     Currently looking for:
1685     - f-strings; and
1686     - trailing commas after * or ** in function signatures.
1687     """
1688     for n in node.pre_order():
1689         if n.type == token.STRING:
1690             value_head = n.value[:2]  # type: ignore
1691             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1692                 return True
1693
1694         elif (
1695             n.type == syms.typedargslist
1696             and n.children
1697             and n.children[-1].type == token.COMMA
1698         ):
1699             for ch in n.children:
1700                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1701                     return True
1702
1703     return False
1704
1705
1706 PYTHON_EXTENSIONS = {'.py'}
1707 BLACKLISTED_DIRECTORIES = {
1708     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1709 }
1710
1711
1712 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1713     """Generates all files under `path` which aren't under BLACKLISTED_DIRECTORIES
1714     and have one of the PYTHON_EXTENSIONS.
1715     """
1716     for child in path.iterdir():
1717         if child.is_dir():
1718             if child.name in BLACKLISTED_DIRECTORIES:
1719                 continue
1720
1721             yield from gen_python_files_in_dir(child)
1722
1723         elif child.suffix in PYTHON_EXTENSIONS:
1724             yield child
1725
1726
1727 @dataclass
1728 class Report:
1729     """Provides a reformatting counter. Can be rendered with `str(report)`."""
1730     check: bool = False
1731     change_count: int = 0
1732     same_count: int = 0
1733     failure_count: int = 0
1734
1735     def done(self, src: Path, changed: bool) -> None:
1736         """Increment the counter for successful reformatting. Write out a message."""
1737         if changed:
1738             reformatted = 'would reformat' if self.check else 'reformatted'
1739             out(f'{reformatted} {src}')
1740             self.change_count += 1
1741         else:
1742             out(f'{src} already well formatted, good job.', bold=False)
1743             self.same_count += 1
1744
1745     def failed(self, src: Path, message: str) -> None:
1746         """Increment the counter for failed reformatting. Write out a message."""
1747         err(f'error: cannot format {src}: {message}')
1748         self.failure_count += 1
1749
1750     @property
1751     def return_code(self) -> int:
1752         """Which return code should the app use considering the current state."""
1753         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1754         # 126 we have special returncodes reserved by the shell.
1755         if self.failure_count:
1756             return 123
1757
1758         elif self.change_count and self.check:
1759             return 1
1760
1761         return 0
1762
1763     def __str__(self) -> str:
1764         """A color report of the current state.
1765
1766         Use `click.unstyle` to remove colors.
1767         """
1768         if self.check:
1769             reformatted = "would be reformatted"
1770             unchanged = "would be left unchanged"
1771             failed = "would fail to reformat"
1772         else:
1773             reformatted = "reformatted"
1774             unchanged = "left unchanged"
1775             failed = "failed to reformat"
1776         report = []
1777         if self.change_count:
1778             s = 's' if self.change_count > 1 else ''
1779             report.append(
1780                 click.style(f'{self.change_count} file{s} {reformatted}', bold=True)
1781             )
1782         if self.same_count:
1783             s = 's' if self.same_count > 1 else ''
1784             report.append(f'{self.same_count} file{s} {unchanged}')
1785         if self.failure_count:
1786             s = 's' if self.failure_count > 1 else ''
1787             report.append(
1788                 click.style(f'{self.failure_count} file{s} {failed}', fg='red')
1789             )
1790         return ', '.join(report) + '.'
1791
1792
1793 def assert_equivalent(src: str, dst: str) -> None:
1794     """Raises AssertionError if `src` and `dst` aren't equivalent.
1795
1796     This is a temporary sanity check until Black becomes stable.
1797     """
1798
1799     import ast
1800     import traceback
1801
1802     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1803         """Simple visitor generating strings to compare ASTs by content."""
1804         yield f"{'  ' * depth}{node.__class__.__name__}("
1805
1806         for field in sorted(node._fields):
1807             try:
1808                 value = getattr(node, field)
1809             except AttributeError:
1810                 continue
1811
1812             yield f"{'  ' * (depth+1)}{field}="
1813
1814             if isinstance(value, list):
1815                 for item in value:
1816                     if isinstance(item, ast.AST):
1817                         yield from _v(item, depth + 2)
1818
1819             elif isinstance(value, ast.AST):
1820                 yield from _v(value, depth + 2)
1821
1822             else:
1823                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1824
1825         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1826
1827     try:
1828         src_ast = ast.parse(src)
1829     except Exception as exc:
1830         major, minor = sys.version_info[:2]
1831         raise AssertionError(
1832             f"cannot use --safe with this file; failed to parse source file "
1833             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
1834             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
1835         )
1836
1837     try:
1838         dst_ast = ast.parse(dst)
1839     except Exception as exc:
1840         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1841         raise AssertionError(
1842             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1843             f"Please report a bug on https://github.com/ambv/black/issues.  "
1844             f"This invalid output might be helpful: {log}"
1845         ) from None
1846
1847     src_ast_str = '\n'.join(_v(src_ast))
1848     dst_ast_str = '\n'.join(_v(dst_ast))
1849     if src_ast_str != dst_ast_str:
1850         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1851         raise AssertionError(
1852             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1853             f"the source.  "
1854             f"Please report a bug on https://github.com/ambv/black/issues.  "
1855             f"This diff might be helpful: {log}"
1856         ) from None
1857
1858
1859 def assert_stable(src: str, dst: str, line_length: int) -> None:
1860     """Raises AssertionError if `dst` reformats differently the second time.
1861
1862     This is a temporary sanity check until Black becomes stable.
1863     """
1864     newdst = format_str(dst, line_length=line_length)
1865     if dst != newdst:
1866         log = dump_to_file(
1867             diff(src, dst, 'source', 'first pass'),
1868             diff(dst, newdst, 'first pass', 'second pass'),
1869         )
1870         raise AssertionError(
1871             f"INTERNAL ERROR: Black produced different code on the second pass "
1872             f"of the formatter.  "
1873             f"Please report a bug on https://github.com/ambv/black/issues.  "
1874             f"This diff might be helpful: {log}"
1875         ) from None
1876
1877
1878 def dump_to_file(*output: str) -> str:
1879     """Dumps `output` to a temporary file. Returns path to the file."""
1880     import tempfile
1881
1882     with tempfile.NamedTemporaryFile(
1883         mode='w', prefix='blk_', suffix='.log', delete=False
1884     ) as f:
1885         for lines in output:
1886             f.write(lines)
1887             f.write('\n')
1888     return f.name
1889
1890
1891 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1892     """Returns a udiff string between strings `a` and `b`."""
1893     import difflib
1894
1895     a_lines = [line + '\n' for line in a.split('\n')]
1896     b_lines = [line + '\n' for line in b.split('\n')]
1897     return ''.join(
1898         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1899     )
1900
1901
1902 if __name__ == '__main__':
1903     main()