]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Remove the test-specific .flake8 file
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial
7 import keyword
8 import os
9 from pathlib import Path
10 import tokenize
11 import sys
12 from typing import (
13     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union
14 )
15
16 from attr import dataclass, Factory
17 import click
18
19 # lib2to3 fork
20 from blib2to3.pytree import Node, Leaf, type_repr
21 from blib2to3 import pygram, pytree
22 from blib2to3.pgen2 import driver, token
23 from blib2to3.pgen2.parse import ParseError
24
25 __version__ = "18.3a4"
26 DEFAULT_LINE_LENGTH = 88
27 # types
28 syms = pygram.python_symbols
29 FileContent = str
30 Encoding = str
31 Depth = int
32 NodeType = int
33 LeafID = int
34 Priority = int
35 LN = Union[Leaf, Node]
36 out = partial(click.secho, bold=True, err=True)
37 err = partial(click.secho, fg='red', err=True)
38
39
40 class NothingChanged(UserWarning):
41     """Raised by :func:`format_file` when reformatted code is the same as source."""
42
43
44 class CannotSplit(Exception):
45     """A readable split that fits the allotted line length is impossible.
46
47     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
48     :func:`delimiter_split`.
49     """
50
51
52 class FormatError(Exception):
53     """Base exception for `# fmt: on` and `# fmt: off` handling.
54
55     It holds the number of bytes of the prefix consumed before the format
56     control comment appeared.
57     """
58
59     def __init__(self, consumed: int) -> None:
60         super().__init__(consumed)
61         self.consumed = consumed
62
63     def trim_prefix(self, leaf: Leaf) -> None:
64         leaf.prefix = leaf.prefix[self.consumed:]
65
66     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
67         """Returns a new Leaf from the consumed part of the prefix."""
68         unformatted_prefix = leaf.prefix[:self.consumed]
69         return Leaf(token.NEWLINE, unformatted_prefix)
70
71
72 class FormatOn(FormatError):
73     """Found a comment like `# fmt: on` in the file."""
74
75
76 class FormatOff(FormatError):
77     """Found a comment like `# fmt: off` in the file."""
78
79
80 @click.command()
81 @click.option(
82     '-l',
83     '--line-length',
84     type=int,
85     default=DEFAULT_LINE_LENGTH,
86     help='How many character per line to allow.',
87     show_default=True,
88 )
89 @click.option(
90     '--check',
91     is_flag=True,
92     help=(
93         "Don't write back the files, just return the status.  Return code 0 "
94         "means nothing would change.  Return code 1 means some files would be "
95         "reformatted.  Return code 123 means there was an internal error."
96     ),
97 )
98 @click.option(
99     '--fast/--safe',
100     is_flag=True,
101     help='If --fast given, skip temporary sanity checks. [default: --safe]',
102 )
103 @click.version_option(version=__version__)
104 @click.argument(
105     'src',
106     nargs=-1,
107     type=click.Path(
108         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
109     ),
110 )
111 @click.pass_context
112 def main(
113     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
114 ) -> None:
115     """The uncompromising code formatter."""
116     sources: List[Path] = []
117     for s in src:
118         p = Path(s)
119         if p.is_dir():
120             sources.extend(gen_python_files_in_dir(p))
121         elif p.is_file():
122             # if a file was explicitly given, we don't care about its extension
123             sources.append(p)
124         elif s == '-':
125             sources.append(Path('-'))
126         else:
127             err(f'invalid path: {s}')
128     if len(sources) == 0:
129         ctx.exit(0)
130     elif len(sources) == 1:
131         p = sources[0]
132         report = Report(check=check)
133         try:
134             if not p.is_file() and str(p) == '-':
135                 changed = format_stdin_to_stdout(
136                     line_length=line_length, fast=fast, write_back=not check
137                 )
138             else:
139                 changed = format_file_in_place(
140                     p, line_length=line_length, fast=fast, write_back=not check
141                 )
142             report.done(p, changed)
143         except Exception as exc:
144             report.failed(p, str(exc))
145         ctx.exit(report.return_code)
146     else:
147         loop = asyncio.get_event_loop()
148         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
149         return_code = 1
150         try:
151             return_code = loop.run_until_complete(
152                 schedule_formatting(
153                     sources, line_length, not check, fast, loop, executor
154                 )
155             )
156         finally:
157             loop.close()
158             ctx.exit(return_code)
159
160
161 async def schedule_formatting(
162     sources: List[Path],
163     line_length: int,
164     write_back: bool,
165     fast: bool,
166     loop: BaseEventLoop,
167     executor: Executor,
168 ) -> int:
169     """Run formatting of `sources` in parallel using the provided `executor`.
170
171     (Use ProcessPoolExecutors for actual parallelism.)
172
173     `line_length`, `write_back`, and `fast` options are passed to
174     :func:`format_file_in_place`.
175     """
176     tasks = {
177         src: loop.run_in_executor(
178             executor, format_file_in_place, src, line_length, fast, write_back
179         )
180         for src in sources
181     }
182     await asyncio.wait(tasks.values())
183     cancelled = []
184     report = Report(check=not write_back)
185     for src, task in tasks.items():
186         if not task.done():
187             report.failed(src, 'timed out, cancelling')
188             task.cancel()
189             cancelled.append(task)
190         elif task.exception():
191             report.failed(src, str(task.exception()))
192         else:
193             report.done(src, task.result())
194     if cancelled:
195         await asyncio.wait(cancelled, timeout=2)
196     out('All done! ✨ 🍰 ✨')
197     click.echo(str(report))
198     return report.return_code
199
200
201 def format_file_in_place(
202     src: Path, line_length: int, fast: bool, write_back: bool = False
203 ) -> bool:
204     """Format file under `src` path. Return True if changed.
205
206     If `write_back` is True, write reformatted code back to stdout.
207     `line_length` and `fast` options are passed to :func:`format_file_contents`.
208     """
209     with tokenize.open(src) as src_buffer:
210         src_contents = src_buffer.read()
211     try:
212         contents = format_file_contents(
213             src_contents, line_length=line_length, fast=fast
214         )
215     except NothingChanged:
216         return False
217
218     if write_back:
219         with open(src, "w", encoding=src_buffer.encoding) as f:
220             f.write(contents)
221     return True
222
223
224 def format_stdin_to_stdout(
225     line_length: int, fast: bool, write_back: bool = False
226 ) -> bool:
227     """Format file on stdin. Return True if changed.
228
229     If `write_back` is True, write reformatted code back to stdout.
230     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
231     """
232     contents = sys.stdin.read()
233     try:
234         contents = format_file_contents(contents, line_length=line_length, fast=fast)
235         return True
236
237     except NothingChanged:
238         return False
239
240     finally:
241         if write_back:
242             sys.stdout.write(contents)
243
244
245 def format_file_contents(
246     src_contents: str, line_length: int, fast: bool
247 ) -> FileContent:
248     """Reformat contents a file and return new contents.
249
250     If `fast` is False, additionally confirm that the reformatted code is
251     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
252     `line_length` is passed to :func:`format_str`.
253     """
254     if src_contents.strip() == '':
255         raise NothingChanged
256
257     dst_contents = format_str(src_contents, line_length=line_length)
258     if src_contents == dst_contents:
259         raise NothingChanged
260
261     if not fast:
262         assert_equivalent(src_contents, dst_contents)
263         assert_stable(src_contents, dst_contents, line_length=line_length)
264     return dst_contents
265
266
267 def format_str(src_contents: str, line_length: int) -> FileContent:
268     """Reformat a string and return new contents.
269
270     `line_length` determines how many characters per line are allowed.
271     """
272     src_node = lib2to3_parse(src_contents)
273     dst_contents = ""
274     lines = LineGenerator()
275     elt = EmptyLineTracker()
276     py36 = is_python36(src_node)
277     empty_line = Line()
278     after = 0
279     for current_line in lines.visit(src_node):
280         for _ in range(after):
281             dst_contents += str(empty_line)
282         before, after = elt.maybe_empty_lines(current_line)
283         for _ in range(before):
284             dst_contents += str(empty_line)
285         for line in split_line(current_line, line_length=line_length, py36=py36):
286             dst_contents += str(line)
287     return dst_contents
288
289
290 GRAMMARS = [
291     pygram.python_grammar_no_print_statement_no_exec_statement,
292     pygram.python_grammar_no_print_statement,
293     pygram.python_grammar_no_exec_statement,
294     pygram.python_grammar,
295 ]
296
297
298 def lib2to3_parse(src_txt: str) -> Node:
299     """Given a string with source, return the lib2to3 Node."""
300     grammar = pygram.python_grammar_no_print_statement
301     if src_txt[-1] != '\n':
302         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
303         src_txt += nl
304     for grammar in GRAMMARS:
305         drv = driver.Driver(grammar, pytree.convert)
306         try:
307             result = drv.parse_string(src_txt, True)
308             break
309
310         except ParseError as pe:
311             lineno, column = pe.context[1]
312             lines = src_txt.splitlines()
313             try:
314                 faulty_line = lines[lineno - 1]
315             except IndexError:
316                 faulty_line = "<line number missing in source>"
317             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
318     else:
319         raise exc from None
320
321     if isinstance(result, Leaf):
322         result = Node(syms.file_input, [result])
323     return result
324
325
326 def lib2to3_unparse(node: Node) -> str:
327     """Given a lib2to3 node, return its string representation."""
328     code = str(node)
329     return code
330
331
332 T = TypeVar('T')
333
334
335 class Visitor(Generic[T]):
336     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
337
338     def visit(self, node: LN) -> Iterator[T]:
339         """Main method to visit `node` and its children.
340
341         It tries to find a `visit_*()` method for the given `node.type`, like
342         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
343         If no dedicated `visit_*()` method is found, chooses `visit_default()`
344         instead.
345
346         Then yields objects of type `T` from the selected visitor.
347         """
348         if node.type < 256:
349             name = token.tok_name[node.type]
350         else:
351             name = type_repr(node.type)
352         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
353
354     def visit_default(self, node: LN) -> Iterator[T]:
355         """Default `visit_*()` implementation. Recurses to children of `node`."""
356         if isinstance(node, Node):
357             for child in node.children:
358                 yield from self.visit(child)
359
360
361 @dataclass
362 class DebugVisitor(Visitor[T]):
363     tree_depth: int = 0
364
365     def visit_default(self, node: LN) -> Iterator[T]:
366         indent = ' ' * (2 * self.tree_depth)
367         if isinstance(node, Node):
368             _type = type_repr(node.type)
369             out(f'{indent}{_type}', fg='yellow')
370             self.tree_depth += 1
371             for child in node.children:
372                 yield from self.visit(child)
373
374             self.tree_depth -= 1
375             out(f'{indent}/{_type}', fg='yellow', bold=False)
376         else:
377             _type = token.tok_name.get(node.type, str(node.type))
378             out(f'{indent}{_type}', fg='blue', nl=False)
379             if node.prefix:
380                 # We don't have to handle prefixes for `Node` objects since
381                 # that delegates to the first child anyway.
382                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
383             out(f' {node.value!r}', fg='blue', bold=False)
384
385     @classmethod
386     def show(cls, code: str) -> None:
387         """Pretty-print the lib2to3 AST of a given string of `code`.
388
389         Convenience method for debugging.
390         """
391         v: DebugVisitor[None] = DebugVisitor()
392         list(v.visit(lib2to3_parse(code)))
393
394
395 KEYWORDS = set(keyword.kwlist)
396 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
397 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
398 STATEMENT = {
399     syms.if_stmt,
400     syms.while_stmt,
401     syms.for_stmt,
402     syms.try_stmt,
403     syms.except_clause,
404     syms.with_stmt,
405     syms.funcdef,
406     syms.classdef,
407 }
408 STANDALONE_COMMENT = 153
409 LOGIC_OPERATORS = {'and', 'or'}
410 COMPARATORS = {
411     token.LESS,
412     token.GREATER,
413     token.EQEQUAL,
414     token.NOTEQUAL,
415     token.LESSEQUAL,
416     token.GREATEREQUAL,
417 }
418 MATH_OPERATORS = {
419     token.PLUS,
420     token.MINUS,
421     token.STAR,
422     token.SLASH,
423     token.VBAR,
424     token.AMPER,
425     token.PERCENT,
426     token.CIRCUMFLEX,
427     token.TILDE,
428     token.LEFTSHIFT,
429     token.RIGHTSHIFT,
430     token.DOUBLESTAR,
431     token.DOUBLESLASH,
432 }
433 COMPREHENSION_PRIORITY = 20
434 COMMA_PRIORITY = 10
435 LOGIC_PRIORITY = 5
436 STRING_PRIORITY = 4
437 COMPARATOR_PRIORITY = 3
438 MATH_PRIORITY = 1
439
440
441 @dataclass
442 class BracketTracker:
443     """Keeps track of brackets on a line."""
444
445     depth: int = 0
446     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
447     delimiters: Dict[LeafID, Priority] = Factory(dict)
448     previous: Optional[Leaf] = None
449
450     def mark(self, leaf: Leaf) -> None:
451         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
452
453         All leaves receive an int `bracket_depth` field that stores how deep
454         within brackets a given leaf is. 0 means there are no enclosing brackets
455         that started on this line.
456
457         If a leaf is itself a closing bracket, it receives an `opening_bracket`
458         field that it forms a pair with. This is a one-directional link to
459         avoid reference cycles.
460
461         If a leaf is a delimiter (a token on which Black can split the line if
462         needed) and it's on depth 0, its `id()` is stored in the tracker's
463         `delimiters` field.
464         """
465         if leaf.type == token.COMMENT:
466             return
467
468         if leaf.type in CLOSING_BRACKETS:
469             self.depth -= 1
470             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
471             leaf.opening_bracket = opening_bracket
472         leaf.bracket_depth = self.depth
473         if self.depth == 0:
474             delim = is_delimiter(leaf)
475             if delim:
476                 self.delimiters[id(leaf)] = delim
477             elif self.previous is not None:
478                 if leaf.type == token.STRING and self.previous.type == token.STRING:
479                     self.delimiters[id(self.previous)] = STRING_PRIORITY
480                 elif (
481                     leaf.type == token.NAME
482                     and leaf.value == 'for'
483                     and leaf.parent
484                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
485                 ):
486                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
487                 elif (
488                     leaf.type == token.NAME
489                     and leaf.value == 'if'
490                     and leaf.parent
491                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
492                 ):
493                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
494                 elif (
495                     leaf.type == token.NAME
496                     and leaf.value in LOGIC_OPERATORS
497                     and leaf.parent
498                 ):
499                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
500         if leaf.type in OPENING_BRACKETS:
501             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
502             self.depth += 1
503         self.previous = leaf
504
505     def any_open_brackets(self) -> bool:
506         """Return True if there is an yet unmatched open bracket on the line."""
507         return bool(self.bracket_match)
508
509     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
510         """Return the highest priority of a delimiter found on the line.
511
512         Values are consistent with what `is_delimiter()` returns.
513         """
514         return max(v for k, v in self.delimiters.items() if k not in exclude)
515
516
517 @dataclass
518 class Line:
519     """Holds leaves and comments. Can be printed with `str(line)`."""
520
521     depth: int = 0
522     leaves: List[Leaf] = Factory(list)
523     comments: Dict[LeafID, Leaf] = Factory(dict)
524     bracket_tracker: BracketTracker = Factory(BracketTracker)
525     inside_brackets: bool = False
526     has_for: bool = False
527     _for_loop_variable: bool = False
528
529     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
530         """Add a new `leaf` to the end of the line.
531
532         Unless `preformatted` is True, the `leaf` will receive a new consistent
533         whitespace prefix and metadata applied by :class:`BracketTracker`.
534         Trailing commas are maybe removed, unpacked for loop variables are
535         demoted from being delimiters.
536
537         Inline comments are put aside.
538         """
539         has_value = leaf.value.strip()
540         if not has_value:
541             return
542
543         if self.leaves and not preformatted:
544             # Note: at this point leaf.prefix should be empty except for
545             # imports, for which we only preserve newlines.
546             leaf.prefix += whitespace(leaf)
547         if self.inside_brackets or not preformatted:
548             self.maybe_decrement_after_for_loop_variable(leaf)
549             self.bracket_tracker.mark(leaf)
550             self.maybe_remove_trailing_comma(leaf)
551             self.maybe_increment_for_loop_variable(leaf)
552             if self.maybe_adapt_standalone_comment(leaf):
553                 return
554
555         if not self.append_comment(leaf):
556             self.leaves.append(leaf)
557
558     @property
559     def is_comment(self) -> bool:
560         """Is this line a standalone comment?"""
561         return bool(self) and self.leaves[0].type == STANDALONE_COMMENT
562
563     @property
564     def is_decorator(self) -> bool:
565         """Is this line a decorator?"""
566         return bool(self) and self.leaves[0].type == token.AT
567
568     @property
569     def is_import(self) -> bool:
570         """Is this an import line?"""
571         return bool(self) and is_import(self.leaves[0])
572
573     @property
574     def is_class(self) -> bool:
575         """Is this line a class definition?"""
576         return (
577             bool(self)
578             and self.leaves[0].type == token.NAME
579             and self.leaves[0].value == 'class'
580         )
581
582     @property
583     def is_def(self) -> bool:
584         """Is this a function definition? (Also returns True for async defs.)"""
585         try:
586             first_leaf = self.leaves[0]
587         except IndexError:
588             return False
589
590         try:
591             second_leaf: Optional[Leaf] = self.leaves[1]
592         except IndexError:
593             second_leaf = None
594         return (
595             (first_leaf.type == token.NAME and first_leaf.value == 'def')
596             or (
597                 first_leaf.type == token.ASYNC
598                 and second_leaf is not None
599                 and second_leaf.type == token.NAME
600                 and second_leaf.value == 'def'
601             )
602         )
603
604     @property
605     def is_flow_control(self) -> bool:
606         """Is this line a flow control statement?
607
608         Those are `return`, `raise`, `break`, and `continue`.
609         """
610         return (
611             bool(self)
612             and self.leaves[0].type == token.NAME
613             and self.leaves[0].value in FLOW_CONTROL
614         )
615
616     @property
617     def is_yield(self) -> bool:
618         """Is this line a yield statement?"""
619         return (
620             bool(self)
621             and self.leaves[0].type == token.NAME
622             and self.leaves[0].value == 'yield'
623         )
624
625     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
626         """Remove trailing comma if there is one and it's safe."""
627         if not (
628             self.leaves
629             and self.leaves[-1].type == token.COMMA
630             and closing.type in CLOSING_BRACKETS
631         ):
632             return False
633
634         if closing.type == token.RBRACE:
635             self.leaves.pop()
636             return True
637
638         if closing.type == token.RSQB:
639             comma = self.leaves[-1]
640             if comma.parent and comma.parent.type == syms.listmaker:
641                 self.leaves.pop()
642                 return True
643
644         # For parens let's check if it's safe to remove the comma.  If the
645         # trailing one is the only one, we might mistakenly change a tuple
646         # into a different type by removing the comma.
647         depth = closing.bracket_depth + 1
648         commas = 0
649         opening = closing.opening_bracket
650         for _opening_index, leaf in enumerate(self.leaves):
651             if leaf is opening:
652                 break
653
654         else:
655             return False
656
657         for leaf in self.leaves[_opening_index + 1:]:
658             if leaf is closing:
659                 break
660
661             bracket_depth = leaf.bracket_depth
662             if bracket_depth == depth and leaf.type == token.COMMA:
663                 commas += 1
664                 if leaf.parent and leaf.parent.type == syms.arglist:
665                     commas += 1
666                     break
667
668         if commas > 1:
669             self.leaves.pop()
670             return True
671
672         return False
673
674     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
675         """In a for loop, or comprehension, the variables are often unpacks.
676
677         To avoid splitting on the comma in this situation, increase the depth of
678         tokens between `for` and `in`.
679         """
680         if leaf.type == token.NAME and leaf.value == 'for':
681             self.has_for = True
682             self.bracket_tracker.depth += 1
683             self._for_loop_variable = True
684             return True
685
686         return False
687
688     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
689         """See `maybe_increment_for_loop_variable` above for explanation."""
690         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
691             self.bracket_tracker.depth -= 1
692             self._for_loop_variable = False
693             return True
694
695         return False
696
697     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
698         """Hack a standalone comment to act as a trailing comment for line splitting.
699
700         If this line has brackets and a standalone `comment`, we need to adapt
701         it to be able to still reformat the line.
702
703         This is not perfect, the line to which the standalone comment gets
704         appended will appear "too long" when splitting.
705         """
706         if not (
707             comment.type == STANDALONE_COMMENT
708             and self.bracket_tracker.any_open_brackets()
709         ):
710             return False
711
712         comment.type = token.COMMENT
713         comment.prefix = '\n' + '    ' * (self.depth + 1)
714         return self.append_comment(comment)
715
716     def append_comment(self, comment: Leaf) -> bool:
717         """Add an inline comment to the line."""
718         if comment.type != token.COMMENT:
719             return False
720
721         try:
722             after = id(self.last_non_delimiter())
723         except LookupError:
724             comment.type = STANDALONE_COMMENT
725             comment.prefix = ''
726             return False
727
728         else:
729             if after in self.comments:
730                 self.comments[after].value += str(comment)
731             else:
732                 self.comments[after] = comment
733             return True
734
735     def last_non_delimiter(self) -> Leaf:
736         """Return the last non-delimiter on the line. Raise LookupError otherwise."""
737         for i in range(len(self.leaves)):
738             last = self.leaves[-i - 1]
739             if not is_delimiter(last):
740                 return last
741
742         raise LookupError("No non-delimiters found")
743
744     def __str__(self) -> str:
745         """Render the line."""
746         if not self:
747             return '\n'
748
749         indent = '    ' * self.depth
750         leaves = iter(self.leaves)
751         first = next(leaves)
752         res = f'{first.prefix}{indent}{first.value}'
753         for leaf in leaves:
754             res += str(leaf)
755         for comment in self.comments.values():
756             res += str(comment)
757         return res + '\n'
758
759     def __bool__(self) -> bool:
760         """Return True if the line has leaves or comments."""
761         return bool(self.leaves or self.comments)
762
763
764 class UnformattedLines(Line):
765     """Just like :class:`Line` but stores lines which aren't reformatted."""
766
767     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
768         """Just add a new `leaf` to the end of the lines.
769
770         The `preformatted` argument is ignored.
771
772         Keeps track of indentation `depth`, which is useful when the user
773         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
774         """
775         try:
776             list(generate_comments(leaf))
777         except FormatOn as f_on:
778             self.leaves.append(f_on.leaf_from_consumed(leaf))
779             raise
780
781         self.leaves.append(leaf)
782         if leaf.type == token.INDENT:
783             self.depth += 1
784         elif leaf.type == token.DEDENT:
785             self.depth -= 1
786
787     def __str__(self) -> str:
788         """Render unformatted lines from leaves which were added with `append()`.
789
790         `depth` is not used for indentation in this case.
791         """
792         if not self:
793             return '\n'
794
795         res = ''
796         for leaf in self.leaves:
797             res += str(leaf)
798         return res
799
800     def append_comment(self, comment: Leaf) -> bool:
801         """Not implemented in this class. Raises `NotImplementedError`."""
802         raise NotImplementedError("Unformatted lines don't store comments separately.")
803
804     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
805         """Does nothing and returns False."""
806         return False
807
808     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
809         """Does nothing and returns False."""
810         return False
811
812     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
813         """Does nothing and returns False."""
814         return False
815
816
817 @dataclass
818 class EmptyLineTracker:
819     """Provides a stateful method that returns the number of potential extra
820     empty lines needed before and after the currently processed line.
821
822     Note: this tracker works on lines that haven't been split yet.  It assumes
823     the prefix of the first leaf consists of optional newlines.  Those newlines
824     are consumed by `maybe_empty_lines()` and included in the computation.
825     """
826     previous_line: Optional[Line] = None
827     previous_after: int = 0
828     previous_defs: List[int] = Factory(list)
829
830     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
831         """Return the number of extra empty lines before and after the `current_line`.
832
833         This is for separating `def`, `async def` and `class` with extra empty
834         lines (two on module-level), as well as providing an extra empty line
835         after flow control keywords to make them more prominent.
836         """
837         if isinstance(current_line, UnformattedLines):
838             return 0, 0
839
840         before, after = self._maybe_empty_lines(current_line)
841         before -= self.previous_after
842         self.previous_after = after
843         self.previous_line = current_line
844         return before, after
845
846     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
847         max_allowed = 1
848         if current_line.depth == 0:
849             max_allowed = 2
850         if current_line.leaves:
851             # Consume the first leaf's extra newlines.
852             first_leaf = current_line.leaves[0]
853             before = first_leaf.prefix.count('\n')
854             before = min(before, max_allowed)
855             first_leaf.prefix = ''
856         else:
857             before = 0
858         depth = current_line.depth
859         while self.previous_defs and self.previous_defs[-1] >= depth:
860             self.previous_defs.pop()
861             before = 1 if depth else 2
862         is_decorator = current_line.is_decorator
863         if is_decorator or current_line.is_def or current_line.is_class:
864             if not is_decorator:
865                 self.previous_defs.append(depth)
866             if self.previous_line is None:
867                 # Don't insert empty lines before the first line in the file.
868                 return 0, 0
869
870             if self.previous_line and self.previous_line.is_decorator:
871                 # Don't insert empty lines between decorators.
872                 return 0, 0
873
874             newlines = 2
875             if current_line.depth:
876                 newlines -= 1
877             return newlines, 0
878
879         if current_line.is_flow_control:
880             return before, 1
881
882         if (
883             self.previous_line
884             and self.previous_line.is_import
885             and not current_line.is_import
886             and depth == self.previous_line.depth
887         ):
888             return (before or 1), 0
889
890         if (
891             self.previous_line
892             and self.previous_line.is_yield
893             and (not current_line.is_yield or depth != self.previous_line.depth)
894         ):
895             return (before or 1), 0
896
897         return before, 0
898
899
900 @dataclass
901 class LineGenerator(Visitor[Line]):
902     """Generates reformatted Line objects.  Empty lines are not emitted.
903
904     Note: destroys the tree it's visiting by mutating prefixes of its leaves
905     in ways that will no longer stringify to valid Python code on the tree.
906     """
907     current_line: Line = Factory(Line)
908
909     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
910         """Generate a line.
911
912         If the line is empty, only emit if it makes sense.
913         If the line is too long, split it first and then generate.
914
915         If any lines were generated, set up a new current_line.
916         """
917         if not self.current_line:
918             if self.current_line.__class__ == type:
919                 self.current_line.depth += indent
920             else:
921                 self.current_line = type(depth=self.current_line.depth + indent)
922             return  # Line is empty, don't emit. Creating a new one unnecessary.
923
924         complete_line = self.current_line
925         self.current_line = type(depth=complete_line.depth + indent)
926         yield complete_line
927
928     def visit(self, node: LN) -> Iterator[Line]:
929         """Main method to visit `node` and its children.
930
931         Yields :class:`Line` objects.
932         """
933         if isinstance(self.current_line, UnformattedLines):
934             # File contained `# fmt: off`
935             yield from self.visit_unformatted(node)
936
937         else:
938             yield from super().visit(node)
939
940     def visit_default(self, node: LN) -> Iterator[Line]:
941         """Default `visit_*()` implementation. Recurses to children of `node`."""
942         if isinstance(node, Leaf):
943             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
944             try:
945                 for comment in generate_comments(node):
946                     if any_open_brackets:
947                         # any comment within brackets is subject to splitting
948                         self.current_line.append(comment)
949                     elif comment.type == token.COMMENT:
950                         # regular trailing comment
951                         self.current_line.append(comment)
952                         yield from self.line()
953
954                     else:
955                         # regular standalone comment
956                         yield from self.line()
957
958                         self.current_line.append(comment)
959                         yield from self.line()
960
961             except FormatOff as f_off:
962                 f_off.trim_prefix(node)
963                 yield from self.line(type=UnformattedLines)
964                 yield from self.visit(node)
965
966             except FormatOn as f_on:
967                 # This only happens here if somebody says "fmt: on" multiple
968                 # times in a row.
969                 f_on.trim_prefix(node)
970                 yield from self.visit_default(node)
971
972             else:
973                 normalize_prefix(node, inside_brackets=any_open_brackets)
974                 if node.type not in WHITESPACE:
975                     self.current_line.append(node)
976         yield from super().visit_default(node)
977
978     def visit_INDENT(self, node: Node) -> Iterator[Line]:
979         """Increase indentation level, maybe yield a line."""
980         # In blib2to3 INDENT never holds comments.
981         yield from self.line(+1)
982         yield from self.visit_default(node)
983
984     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
985         """Decrease indentation level, maybe yield a line."""
986         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
987         yield from self.line(-1)
988
989     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
990         """Visit a statement.
991
992         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
993         `def`, `with`, and `class`.
994
995         The relevant Python language `keywords` for a given statement will be NAME
996         leaves within it. This methods puts those on a separate line.
997         """
998         for child in node.children:
999             if child.type == token.NAME and child.value in keywords:  # type: ignore
1000                 yield from self.line()
1001
1002             yield from self.visit(child)
1003
1004     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1005         """Visit a statement without nested statements."""
1006         is_suite_like = node.parent and node.parent.type in STATEMENT
1007         if is_suite_like:
1008             yield from self.line(+1)
1009             yield from self.visit_default(node)
1010             yield from self.line(-1)
1011
1012         else:
1013             yield from self.line()
1014             yield from self.visit_default(node)
1015
1016     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1017         """Visit `async def`, `async for`, `async with`."""
1018         yield from self.line()
1019
1020         children = iter(node.children)
1021         for child in children:
1022             yield from self.visit(child)
1023
1024             if child.type == token.ASYNC:
1025                 break
1026
1027         internal_stmt = next(children)
1028         for child in internal_stmt.children:
1029             yield from self.visit(child)
1030
1031     def visit_decorators(self, node: Node) -> Iterator[Line]:
1032         """Visit decorators."""
1033         for child in node.children:
1034             yield from self.line()
1035             yield from self.visit(child)
1036
1037     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1038         """Remove a semicolon and put the other statement on a separate line."""
1039         yield from self.line()
1040
1041     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1042         """End of file. Process outstanding comments and end with a newline."""
1043         yield from self.visit_default(leaf)
1044         yield from self.line()
1045
1046     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1047         """Used when file contained a `# fmt: off`."""
1048         if isinstance(node, Node):
1049             for child in node.children:
1050                 yield from self.visit(child)
1051
1052         else:
1053             try:
1054                 self.current_line.append(node)
1055             except FormatOn as f_on:
1056                 f_on.trim_prefix(node)
1057                 yield from self.line()
1058                 yield from self.visit(node)
1059
1060     def __attrs_post_init__(self) -> None:
1061         """You are in a twisty little maze of passages."""
1062         v = self.visit_stmt
1063         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
1064         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
1065         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
1066         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
1067         self.visit_except_clause = partial(v, keywords={'except'})
1068         self.visit_funcdef = partial(v, keywords={'def'})
1069         self.visit_with_stmt = partial(v, keywords={'with'})
1070         self.visit_classdef = partial(v, keywords={'class'})
1071         self.visit_async_funcdef = self.visit_async_stmt
1072         self.visit_decorated = self.visit_decorators
1073
1074
1075 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1076 OPENING_BRACKETS = set(BRACKET.keys())
1077 CLOSING_BRACKETS = set(BRACKET.values())
1078 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1079 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1080
1081
1082 def whitespace(leaf: Leaf) -> str:  # noqa C901
1083     """Return whitespace prefix if needed for the given `leaf`."""
1084     NO = ''
1085     SPACE = ' '
1086     DOUBLESPACE = '  '
1087     t = leaf.type
1088     p = leaf.parent
1089     v = leaf.value
1090     if t in ALWAYS_NO_SPACE:
1091         return NO
1092
1093     if t == token.COMMENT:
1094         return DOUBLESPACE
1095
1096     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1097     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1098         return NO
1099
1100     prev = leaf.prev_sibling
1101     if not prev:
1102         prevp = preceding_leaf(p)
1103         if not prevp or prevp.type in OPENING_BRACKETS:
1104             return NO
1105
1106         if t == token.COLON:
1107             return SPACE if prevp.type == token.COMMA else NO
1108
1109         if prevp.type == token.EQUAL:
1110             if prevp.parent:
1111                 if prevp.parent.type in {
1112                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1113                 }:
1114                     return NO
1115
1116                 elif prevp.parent.type == syms.typedargslist:
1117                     # A bit hacky: if the equal sign has whitespace, it means we
1118                     # previously found it's a typed argument.  So, we're using
1119                     # that, too.
1120                     return prevp.prefix
1121
1122         elif prevp.type == token.DOUBLESTAR:
1123             if prevp.parent and prevp.parent.type in {
1124                 syms.arglist,
1125                 syms.argument,
1126                 syms.dictsetmaker,
1127                 syms.parameters,
1128                 syms.typedargslist,
1129                 syms.varargslist,
1130             }:
1131                 return NO
1132
1133         elif prevp.type == token.COLON:
1134             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1135                 return NO
1136
1137         elif (
1138             prevp.parent
1139             and prevp.parent.type in {syms.factor, syms.star_expr}
1140             and prevp.type in MATH_OPERATORS
1141         ):
1142             return NO
1143
1144         elif (
1145             prevp.type == token.RIGHTSHIFT
1146             and prevp.parent
1147             and prevp.parent.type == syms.shift_expr
1148             and prevp.prev_sibling
1149             and prevp.prev_sibling.type == token.NAME
1150             and prevp.prev_sibling.value == 'print'  # type: ignore
1151         ):
1152             # Python 2 print chevron
1153             return NO
1154
1155     elif prev.type in OPENING_BRACKETS:
1156         return NO
1157
1158     if p.type in {syms.parameters, syms.arglist}:
1159         # untyped function signatures or calls
1160         if t == token.RPAR:
1161             return NO
1162
1163         if not prev or prev.type != token.COMMA:
1164             return NO
1165
1166     elif p.type == syms.varargslist:
1167         # lambdas
1168         if t == token.RPAR:
1169             return NO
1170
1171         if prev and prev.type != token.COMMA:
1172             return NO
1173
1174     elif p.type == syms.typedargslist:
1175         # typed function signatures
1176         if not prev:
1177             return NO
1178
1179         if t == token.EQUAL:
1180             if prev.type != syms.tname:
1181                 return NO
1182
1183         elif prev.type == token.EQUAL:
1184             # A bit hacky: if the equal sign has whitespace, it means we
1185             # previously found it's a typed argument.  So, we're using that, too.
1186             return prev.prefix
1187
1188         elif prev.type != token.COMMA:
1189             return NO
1190
1191     elif p.type == syms.tname:
1192         # type names
1193         if not prev:
1194             prevp = preceding_leaf(p)
1195             if not prevp or prevp.type != token.COMMA:
1196                 return NO
1197
1198     elif p.type == syms.trailer:
1199         # attributes and calls
1200         if t == token.LPAR or t == token.RPAR:
1201             return NO
1202
1203         if not prev:
1204             if t == token.DOT:
1205                 prevp = preceding_leaf(p)
1206                 if not prevp or prevp.type != token.NUMBER:
1207                     return NO
1208
1209             elif t == token.LSQB:
1210                 return NO
1211
1212         elif prev.type != token.COMMA:
1213             return NO
1214
1215     elif p.type == syms.argument:
1216         # single argument
1217         if t == token.EQUAL:
1218             return NO
1219
1220         if not prev:
1221             prevp = preceding_leaf(p)
1222             if not prevp or prevp.type == token.LPAR:
1223                 return NO
1224
1225         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1226             return NO
1227
1228     elif p.type == syms.decorator:
1229         # decorators
1230         return NO
1231
1232     elif p.type == syms.dotted_name:
1233         if prev:
1234             return NO
1235
1236         prevp = preceding_leaf(p)
1237         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1238             return NO
1239
1240     elif p.type == syms.classdef:
1241         if t == token.LPAR:
1242             return NO
1243
1244         if prev and prev.type == token.LPAR:
1245             return NO
1246
1247     elif p.type == syms.subscript:
1248         # indexing
1249         if not prev:
1250             assert p.parent is not None, "subscripts are always parented"
1251             if p.parent.type == syms.subscriptlist:
1252                 return SPACE
1253
1254             return NO
1255
1256         else:
1257             return NO
1258
1259     elif p.type == syms.atom:
1260         if prev and t == token.DOT:
1261             # dots, but not the first one.
1262             return NO
1263
1264     elif (
1265         p.type == syms.listmaker
1266         or p.type == syms.testlist_gexp
1267         or p.type == syms.subscriptlist
1268     ):
1269         # list interior, including unpacking
1270         if not prev:
1271             return NO
1272
1273     elif p.type == syms.dictsetmaker:
1274         # dict and set interior, including unpacking
1275         if not prev:
1276             return NO
1277
1278         if prev.type == token.DOUBLESTAR:
1279             return NO
1280
1281     elif p.type in {syms.factor, syms.star_expr}:
1282         # unary ops
1283         if not prev:
1284             prevp = preceding_leaf(p)
1285             if not prevp or prevp.type in OPENING_BRACKETS:
1286                 return NO
1287
1288             prevp_parent = prevp.parent
1289             assert prevp_parent is not None
1290             if prevp.type == token.COLON and prevp_parent.type in {
1291                 syms.subscript, syms.sliceop
1292             }:
1293                 return NO
1294
1295             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1296                 return NO
1297
1298         elif t == token.NAME or t == token.NUMBER:
1299             return NO
1300
1301     elif p.type == syms.import_from:
1302         if t == token.DOT:
1303             if prev and prev.type == token.DOT:
1304                 return NO
1305
1306         elif t == token.NAME:
1307             if v == 'import':
1308                 return SPACE
1309
1310             if prev and prev.type == token.DOT:
1311                 return NO
1312
1313     elif p.type == syms.sliceop:
1314         return NO
1315
1316     return SPACE
1317
1318
1319 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1320     """Return the first leaf that precedes `node`, if any."""
1321     while node:
1322         res = node.prev_sibling
1323         if res:
1324             if isinstance(res, Leaf):
1325                 return res
1326
1327             try:
1328                 return list(res.leaves())[-1]
1329
1330             except IndexError:
1331                 return None
1332
1333         node = node.parent
1334     return None
1335
1336
1337 def is_delimiter(leaf: Leaf) -> int:
1338     """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
1339
1340     Higher numbers are higher priority.
1341     """
1342     if leaf.type == token.COMMA:
1343         return COMMA_PRIORITY
1344
1345     if leaf.type in COMPARATORS:
1346         return COMPARATOR_PRIORITY
1347
1348     if (
1349         leaf.type in MATH_OPERATORS
1350         and leaf.parent
1351         and leaf.parent.type not in {syms.factor, syms.star_expr}
1352     ):
1353         return MATH_PRIORITY
1354
1355     return 0
1356
1357
1358 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1359     """Clean the prefix of the `leaf` and generate comments from it, if any.
1360
1361     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1362     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1363     move because it does away with modifying the grammar to include all the
1364     possible places in which comments can be placed.
1365
1366     The sad consequence for us though is that comments don't "belong" anywhere.
1367     This is why this function generates simple parentless Leaf objects for
1368     comments.  We simply don't know what the correct parent should be.
1369
1370     No matter though, we can live without this.  We really only need to
1371     differentiate between inline and standalone comments.  The latter don't
1372     share the line with any code.
1373
1374     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1375     are emitted with a fake STANDALONE_COMMENT token identifier.
1376     """
1377     p = leaf.prefix
1378     if not p:
1379         return
1380
1381     if '#' not in p:
1382         return
1383
1384     consumed = 0
1385     nlines = 0
1386     for index, line in enumerate(p.split('\n')):
1387         consumed += len(line) + 1  # adding the length of the split '\n'
1388         line = line.lstrip()
1389         if not line:
1390             nlines += 1
1391         if not line.startswith('#'):
1392             continue
1393
1394         if index == 0 and leaf.type != token.ENDMARKER:
1395             comment_type = token.COMMENT  # simple trailing comment
1396         else:
1397             comment_type = STANDALONE_COMMENT
1398         comment = make_comment(line)
1399         yield Leaf(comment_type, comment, prefix='\n' * nlines)
1400
1401         if comment in {'# fmt: on', '# yapf: enable'}:
1402             raise FormatOn(consumed)
1403
1404         if comment in {'# fmt: off', '# yapf: disable'}:
1405             raise FormatOff(consumed)
1406
1407         nlines = 0
1408
1409
1410 def make_comment(content: str) -> str:
1411     """Return a consistently formatted comment from the given `content` string.
1412
1413     All comments (except for "##", "#!", "#:") should have a single space between
1414     the hash sign and the content.
1415
1416     If `content` didn't start with a hash sign, one is provided.
1417     """
1418     content = content.rstrip()
1419     if not content:
1420         return '#'
1421
1422     if content[0] == '#':
1423         content = content[1:]
1424     if content and content[0] not in ' !:#':
1425         content = ' ' + content
1426     return '#' + content
1427
1428
1429 def split_line(
1430     line: Line, line_length: int, inner: bool = False, py36: bool = False
1431 ) -> Iterator[Line]:
1432     """Split a `line` into potentially many lines.
1433
1434     They should fit in the allotted `line_length` but might not be able to.
1435     `inner` signifies that there were a pair of brackets somewhere around the
1436     current `line`, possibly transitively. This means we can fallback to splitting
1437     by delimiters if the LHS/RHS don't yield any results.
1438
1439     If `py36` is True, splitting may generate syntax that is only compatible
1440     with Python 3.6 and later.
1441     """
1442     if isinstance(line, UnformattedLines):
1443         yield line
1444         return
1445
1446     line_str = str(line).strip('\n')
1447     if len(line_str) <= line_length and '\n' not in line_str:
1448         yield line
1449         return
1450
1451     if line.is_def:
1452         split_funcs = [left_hand_split]
1453     elif line.inside_brackets:
1454         split_funcs = [delimiter_split]
1455         if '\n' not in line_str:
1456             # Only attempt RHS if we don't have multiline strings or comments
1457             # on this line.
1458             split_funcs.append(right_hand_split)
1459     else:
1460         split_funcs = [right_hand_split]
1461     for split_func in split_funcs:
1462         # We are accumulating lines in `result` because we might want to abort
1463         # mission and return the original line in the end, or attempt a different
1464         # split altogether.
1465         result: List[Line] = []
1466         try:
1467             for l in split_func(line, py36=py36):
1468                 if str(l).strip('\n') == line_str:
1469                     raise CannotSplit("Split function returned an unchanged result")
1470
1471                 result.extend(
1472                     split_line(l, line_length=line_length, inner=True, py36=py36)
1473                 )
1474         except CannotSplit as cs:
1475             continue
1476
1477         else:
1478             yield from result
1479             break
1480
1481     else:
1482         yield line
1483
1484
1485 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1486     """Split line into many lines, starting with the first matching bracket pair.
1487
1488     Note: this usually looks weird, only use this for function definitions.
1489     Prefer RHS otherwise.
1490     """
1491     head = Line(depth=line.depth)
1492     body = Line(depth=line.depth + 1, inside_brackets=True)
1493     tail = Line(depth=line.depth)
1494     tail_leaves: List[Leaf] = []
1495     body_leaves: List[Leaf] = []
1496     head_leaves: List[Leaf] = []
1497     current_leaves = head_leaves
1498     matching_bracket = None
1499     for leaf in line.leaves:
1500         if (
1501             current_leaves is body_leaves
1502             and leaf.type in CLOSING_BRACKETS
1503             and leaf.opening_bracket is matching_bracket
1504         ):
1505             current_leaves = tail_leaves if body_leaves else head_leaves
1506         current_leaves.append(leaf)
1507         if current_leaves is head_leaves:
1508             if leaf.type in OPENING_BRACKETS:
1509                 matching_bracket = leaf
1510                 current_leaves = body_leaves
1511     # Since body is a new indent level, remove spurious leading whitespace.
1512     if body_leaves:
1513         normalize_prefix(body_leaves[0], inside_brackets=True)
1514     # Build the new lines.
1515     for result, leaves in (
1516         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1517     ):
1518         for leaf in leaves:
1519             result.append(leaf, preformatted=True)
1520             comment_after = line.comments.get(id(leaf))
1521             if comment_after:
1522                 result.append(comment_after, preformatted=True)
1523     bracket_split_succeeded_or_raise(head, body, tail)
1524     for result in (head, body, tail):
1525         if result:
1526             yield result
1527
1528
1529 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1530     """Split line into many lines, starting with the last matching bracket pair."""
1531     head = Line(depth=line.depth)
1532     body = Line(depth=line.depth + 1, inside_brackets=True)
1533     tail = Line(depth=line.depth)
1534     tail_leaves: List[Leaf] = []
1535     body_leaves: List[Leaf] = []
1536     head_leaves: List[Leaf] = []
1537     current_leaves = tail_leaves
1538     opening_bracket = None
1539     for leaf in reversed(line.leaves):
1540         if current_leaves is body_leaves:
1541             if leaf is opening_bracket:
1542                 current_leaves = head_leaves if body_leaves else tail_leaves
1543         current_leaves.append(leaf)
1544         if current_leaves is tail_leaves:
1545             if leaf.type in CLOSING_BRACKETS:
1546                 opening_bracket = leaf.opening_bracket
1547                 current_leaves = body_leaves
1548     tail_leaves.reverse()
1549     body_leaves.reverse()
1550     head_leaves.reverse()
1551     # Since body is a new indent level, remove spurious leading whitespace.
1552     if body_leaves:
1553         normalize_prefix(body_leaves[0], inside_brackets=True)
1554     # Build the new lines.
1555     for result, leaves in (
1556         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1557     ):
1558         for leaf in leaves:
1559             result.append(leaf, preformatted=True)
1560             comment_after = line.comments.get(id(leaf))
1561             if comment_after:
1562                 result.append(comment_after, preformatted=True)
1563     bracket_split_succeeded_or_raise(head, body, tail)
1564     for result in (head, body, tail):
1565         if result:
1566             yield result
1567
1568
1569 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1570     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1571
1572     Do nothing otherwise.
1573
1574     A left- or right-hand split is based on a pair of brackets. Content before
1575     (and including) the opening bracket is left on one line, content inside the
1576     brackets is put on a separate line, and finally content starting with and
1577     following the closing bracket is put on a separate line.
1578
1579     Those are called `head`, `body`, and `tail`, respectively. If the split
1580     produced the same line (all content in `head`) or ended up with an empty `body`
1581     and the `tail` is just the closing bracket, then it's considered failed.
1582     """
1583     tail_len = len(str(tail).strip())
1584     if not body:
1585         if tail_len == 0:
1586             raise CannotSplit("Splitting brackets produced the same line")
1587
1588         elif tail_len < 3:
1589             raise CannotSplit(
1590                 f"Splitting brackets on an empty body to save "
1591                 f"{tail_len} characters is not worth it"
1592             )
1593
1594
1595 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1596     """Split according to delimiters of the highest priority.
1597
1598     This kind of split doesn't increase indentation.
1599     If `py36` is True, the split will add trailing commas also in function
1600     signatures that contain `*` and `**`.
1601     """
1602     try:
1603         last_leaf = line.leaves[-1]
1604     except IndexError:
1605         raise CannotSplit("Line empty")
1606
1607     delimiters = line.bracket_tracker.delimiters
1608     try:
1609         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1610             exclude={id(last_leaf)}
1611         )
1612     except ValueError:
1613         raise CannotSplit("No delimiters found")
1614
1615     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1616     lowest_depth = sys.maxsize
1617     trailing_comma_safe = True
1618     for leaf in line.leaves:
1619         current_line.append(leaf, preformatted=True)
1620         comment_after = line.comments.get(id(leaf))
1621         if comment_after:
1622             current_line.append(comment_after, preformatted=True)
1623         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1624         if (
1625             leaf.bracket_depth == lowest_depth
1626             and leaf.type == token.STAR
1627             or leaf.type == token.DOUBLESTAR
1628         ):
1629             trailing_comma_safe = trailing_comma_safe and py36
1630         leaf_priority = delimiters.get(id(leaf))
1631         if leaf_priority == delimiter_priority:
1632             normalize_prefix(current_line.leaves[0], inside_brackets=True)
1633             yield current_line
1634
1635             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1636     if current_line:
1637         if (
1638             delimiter_priority == COMMA_PRIORITY
1639             and current_line.leaves[-1].type != token.COMMA
1640             and trailing_comma_safe
1641         ):
1642             current_line.append(Leaf(token.COMMA, ','))
1643         normalize_prefix(current_line.leaves[0], inside_brackets=True)
1644         yield current_line
1645
1646
1647 def is_import(leaf: Leaf) -> bool:
1648     """Return True if the given leaf starts an import statement."""
1649     p = leaf.parent
1650     t = leaf.type
1651     v = leaf.value
1652     return bool(
1653         t == token.NAME
1654         and (
1655             (v == 'import' and p and p.type == syms.import_name)
1656             or (v == 'from' and p and p.type == syms.import_from)
1657         )
1658     )
1659
1660
1661 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1662     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1663     else.
1664
1665     Note: don't use backslashes for formatting or you'll lose your voting rights.
1666     """
1667     if not inside_brackets:
1668         spl = leaf.prefix.split('#')
1669         if '\\' not in spl[0]:
1670             nl_count = spl[-1].count('\n')
1671             if len(spl) > 1:
1672                 nl_count -= 1
1673             leaf.prefix = '\n' * nl_count
1674             return
1675
1676     leaf.prefix = ''
1677
1678
1679 def is_python36(node: Node) -> bool:
1680     """Return True if the current file is using Python 3.6+ features.
1681
1682     Currently looking for:
1683     - f-strings; and
1684     - trailing commas after * or ** in function signatures.
1685     """
1686     for n in node.pre_order():
1687         if n.type == token.STRING:
1688             value_head = n.value[:2]  # type: ignore
1689             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1690                 return True
1691
1692         elif (
1693             n.type == syms.typedargslist
1694             and n.children
1695             and n.children[-1].type == token.COMMA
1696         ):
1697             for ch in n.children:
1698                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1699                     return True
1700
1701     return False
1702
1703
1704 PYTHON_EXTENSIONS = {'.py'}
1705 BLACKLISTED_DIRECTORIES = {
1706     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1707 }
1708
1709
1710 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1711     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
1712     and have one of the PYTHON_EXTENSIONS.
1713     """
1714     for child in path.iterdir():
1715         if child.is_dir():
1716             if child.name in BLACKLISTED_DIRECTORIES:
1717                 continue
1718
1719             yield from gen_python_files_in_dir(child)
1720
1721         elif child.suffix in PYTHON_EXTENSIONS:
1722             yield child
1723
1724
1725 @dataclass
1726 class Report:
1727     """Provides a reformatting counter. Can be rendered with `str(report)`."""
1728     check: bool = False
1729     change_count: int = 0
1730     same_count: int = 0
1731     failure_count: int = 0
1732
1733     def done(self, src: Path, changed: bool) -> None:
1734         """Increment the counter for successful reformatting. Write out a message."""
1735         if changed:
1736             reformatted = 'would reformat' if self.check else 'reformatted'
1737             out(f'{reformatted} {src}')
1738             self.change_count += 1
1739         else:
1740             out(f'{src} already well formatted, good job.', bold=False)
1741             self.same_count += 1
1742
1743     def failed(self, src: Path, message: str) -> None:
1744         """Increment the counter for failed reformatting. Write out a message."""
1745         err(f'error: cannot format {src}: {message}')
1746         self.failure_count += 1
1747
1748     @property
1749     def return_code(self) -> int:
1750         """Return the exit code that the app should use.
1751
1752         This considers the current state of changed files and failures:
1753         - if there were any failures, return 123;
1754         - if any files were changed and --check is being used, return 1;
1755         - otherwise return 0.
1756         """
1757         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1758         # 126 we have special returncodes reserved by the shell.
1759         if self.failure_count:
1760             return 123
1761
1762         elif self.change_count and self.check:
1763             return 1
1764
1765         return 0
1766
1767     def __str__(self) -> str:
1768         """Render a color report of the current state.
1769
1770         Use `click.unstyle` to remove colors.
1771         """
1772         if self.check:
1773             reformatted = "would be reformatted"
1774             unchanged = "would be left unchanged"
1775             failed = "would fail to reformat"
1776         else:
1777             reformatted = "reformatted"
1778             unchanged = "left unchanged"
1779             failed = "failed to reformat"
1780         report = []
1781         if self.change_count:
1782             s = 's' if self.change_count > 1 else ''
1783             report.append(
1784                 click.style(f'{self.change_count} file{s} {reformatted}', bold=True)
1785             )
1786         if self.same_count:
1787             s = 's' if self.same_count > 1 else ''
1788             report.append(f'{self.same_count} file{s} {unchanged}')
1789         if self.failure_count:
1790             s = 's' if self.failure_count > 1 else ''
1791             report.append(
1792                 click.style(f'{self.failure_count} file{s} {failed}', fg='red')
1793             )
1794         return ', '.join(report) + '.'
1795
1796
1797 def assert_equivalent(src: str, dst: str) -> None:
1798     """Raise AssertionError if `src` and `dst` aren't equivalent."""
1799
1800     import ast
1801     import traceback
1802
1803     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1804         """Simple visitor generating strings to compare ASTs by content."""
1805         yield f"{'  ' * depth}{node.__class__.__name__}("
1806
1807         for field in sorted(node._fields):
1808             try:
1809                 value = getattr(node, field)
1810             except AttributeError:
1811                 continue
1812
1813             yield f"{'  ' * (depth+1)}{field}="
1814
1815             if isinstance(value, list):
1816                 for item in value:
1817                     if isinstance(item, ast.AST):
1818                         yield from _v(item, depth + 2)
1819
1820             elif isinstance(value, ast.AST):
1821                 yield from _v(value, depth + 2)
1822
1823             else:
1824                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1825
1826         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1827
1828     try:
1829         src_ast = ast.parse(src)
1830     except Exception as exc:
1831         major, minor = sys.version_info[:2]
1832         raise AssertionError(
1833             f"cannot use --safe with this file; failed to parse source file "
1834             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
1835             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
1836         )
1837
1838     try:
1839         dst_ast = ast.parse(dst)
1840     except Exception as exc:
1841         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1842         raise AssertionError(
1843             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1844             f"Please report a bug on https://github.com/ambv/black/issues.  "
1845             f"This invalid output might be helpful: {log}"
1846         ) from None
1847
1848     src_ast_str = '\n'.join(_v(src_ast))
1849     dst_ast_str = '\n'.join(_v(dst_ast))
1850     if src_ast_str != dst_ast_str:
1851         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1852         raise AssertionError(
1853             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1854             f"the source.  "
1855             f"Please report a bug on https://github.com/ambv/black/issues.  "
1856             f"This diff might be helpful: {log}"
1857         ) from None
1858
1859
1860 def assert_stable(src: str, dst: str, line_length: int) -> None:
1861     """Raise AssertionError if `dst` reformats differently the second time."""
1862     newdst = format_str(dst, line_length=line_length)
1863     if dst != newdst:
1864         log = dump_to_file(
1865             diff(src, dst, 'source', 'first pass'),
1866             diff(dst, newdst, 'first pass', 'second pass'),
1867         )
1868         raise AssertionError(
1869             f"INTERNAL ERROR: Black produced different code on the second pass "
1870             f"of the formatter.  "
1871             f"Please report a bug on https://github.com/ambv/black/issues.  "
1872             f"This diff might be helpful: {log}"
1873         ) from None
1874
1875
1876 def dump_to_file(*output: str) -> str:
1877     """Dump `output` to a temporary file. Return path to the file."""
1878     import tempfile
1879
1880     with tempfile.NamedTemporaryFile(
1881         mode='w', prefix='blk_', suffix='.log', delete=False
1882     ) as f:
1883         for lines in output:
1884             f.write(lines)
1885             f.write('\n')
1886     return f.name
1887
1888
1889 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1890     """Return a unified diff string between strings `a` and `b`."""
1891     import difflib
1892
1893     a_lines = [line + '\n' for line in a.split('\n')]
1894     b_lines = [line + '\n' for line in b.split('\n')]
1895     return ''.join(
1896         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1897     )
1898
1899
1900 if __name__ == '__main__':
1901     main()