]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Graceful shutdown in case of cancellation
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial, wraps
7 import keyword
8 import logging
9 import os
10 from pathlib import Path
11 import tokenize
12 import signal
13 import sys
14 from typing import (
15     Callable,
16     Dict,
17     Generic,
18     Iterable,
19     Iterator,
20     List,
21     Optional,
22     Set,
23     Tuple,
24     Type,
25     TypeVar,
26     Union,
27 )
28
29 from attr import dataclass, Factory
30 import click
31
32 # lib2to3 fork
33 from blib2to3.pytree import Node, Leaf, type_repr
34 from blib2to3 import pygram, pytree
35 from blib2to3.pgen2 import driver, token
36 from blib2to3.pgen2.parse import ParseError
37
38 __version__ = "18.3a4"
39 DEFAULT_LINE_LENGTH = 88
40 # types
41 syms = pygram.python_symbols
42 FileContent = str
43 Encoding = str
44 Depth = int
45 NodeType = int
46 LeafID = int
47 Priority = int
48 Index = int
49 LN = Union[Leaf, Node]
50 SplitFunc = Callable[['Line', bool], Iterator['Line']]
51 out = partial(click.secho, bold=True, err=True)
52 err = partial(click.secho, fg='red', err=True)
53
54
55 class NothingChanged(UserWarning):
56     """Raised by :func:`format_file` when reformatted code is the same as source."""
57
58
59 class CannotSplit(Exception):
60     """A readable split that fits the allotted line length is impossible.
61
62     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
63     :func:`delimiter_split`.
64     """
65
66
67 class FormatError(Exception):
68     """Base exception for `# fmt: on` and `# fmt: off` handling.
69
70     It holds the number of bytes of the prefix consumed before the format
71     control comment appeared.
72     """
73
74     def __init__(self, consumed: int) -> None:
75         super().__init__(consumed)
76         self.consumed = consumed
77
78     def trim_prefix(self, leaf: Leaf) -> None:
79         leaf.prefix = leaf.prefix[self.consumed:]
80
81     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
82         """Returns a new Leaf from the consumed part of the prefix."""
83         unformatted_prefix = leaf.prefix[:self.consumed]
84         return Leaf(token.NEWLINE, unformatted_prefix)
85
86
87 class FormatOn(FormatError):
88     """Found a comment like `# fmt: on` in the file."""
89
90
91 class FormatOff(FormatError):
92     """Found a comment like `# fmt: off` in the file."""
93
94
95 @click.command()
96 @click.option(
97     '-l',
98     '--line-length',
99     type=int,
100     default=DEFAULT_LINE_LENGTH,
101     help='How many character per line to allow.',
102     show_default=True,
103 )
104 @click.option(
105     '--check',
106     is_flag=True,
107     help=(
108         "Don't write back the files, just return the status.  Return code 0 "
109         "means nothing would change.  Return code 1 means some files would be "
110         "reformatted.  Return code 123 means there was an internal error."
111     ),
112 )
113 @click.option(
114     '--fast/--safe',
115     is_flag=True,
116     help='If --fast given, skip temporary sanity checks. [default: --safe]',
117 )
118 @click.version_option(version=__version__)
119 @click.argument(
120     'src',
121     nargs=-1,
122     type=click.Path(
123         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
124     ),
125 )
126 @click.pass_context
127 def main(
128     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
129 ) -> None:
130     """The uncompromising code formatter."""
131     sources: List[Path] = []
132     for s in src:
133         p = Path(s)
134         if p.is_dir():
135             sources.extend(gen_python_files_in_dir(p))
136         elif p.is_file():
137             # if a file was explicitly given, we don't care about its extension
138             sources.append(p)
139         elif s == '-':
140             sources.append(Path('-'))
141         else:
142             err(f'invalid path: {s}')
143     if len(sources) == 0:
144         ctx.exit(0)
145     elif len(sources) == 1:
146         p = sources[0]
147         report = Report(check=check)
148         try:
149             if not p.is_file() and str(p) == '-':
150                 changed = format_stdin_to_stdout(
151                     line_length=line_length, fast=fast, write_back=not check
152                 )
153             else:
154                 changed = format_file_in_place(
155                     p, line_length=line_length, fast=fast, write_back=not check
156                 )
157             report.done(p, changed)
158         except Exception as exc:
159             report.failed(p, str(exc))
160         ctx.exit(report.return_code)
161     else:
162         loop = asyncio.get_event_loop()
163         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
164         return_code = 1
165         try:
166             return_code = loop.run_until_complete(
167                 schedule_formatting(
168                     sources, line_length, not check, fast, loop, executor
169                 )
170             )
171         finally:
172             shutdown(loop)
173             ctx.exit(return_code)
174
175
176 async def schedule_formatting(
177     sources: List[Path],
178     line_length: int,
179     write_back: bool,
180     fast: bool,
181     loop: BaseEventLoop,
182     executor: Executor,
183 ) -> int:
184     """Run formatting of `sources` in parallel using the provided `executor`.
185
186     (Use ProcessPoolExecutors for actual parallelism.)
187
188     `line_length`, `write_back`, and `fast` options are passed to
189     :func:`format_file_in_place`.
190     """
191     tasks = {
192         src: loop.run_in_executor(
193             executor, format_file_in_place, src, line_length, fast, write_back
194         )
195         for src in sources
196     }
197     _task_values = list(tasks.values())
198     loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
199     loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
200     await asyncio.wait(tasks.values())
201     cancelled = []
202     report = Report(check=not write_back)
203     for src, task in tasks.items():
204         if not task.done():
205             report.failed(src, 'timed out, cancelling')
206             task.cancel()
207             cancelled.append(task)
208         elif task.cancelled():
209             cancelled.append(task)
210         elif task.exception():
211             report.failed(src, str(task.exception()))
212         else:
213             report.done(src, task.result())
214     if cancelled:
215         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
216     else:
217         out('All done! ✨ 🍰 ✨')
218     click.echo(str(report))
219     return report.return_code
220
221
222 def format_file_in_place(
223     src: Path, line_length: int, fast: bool, write_back: bool = False
224 ) -> bool:
225     """Format file under `src` path. Return True if changed.
226
227     If `write_back` is True, write reformatted code back to stdout.
228     `line_length` and `fast` options are passed to :func:`format_file_contents`.
229     """
230     with tokenize.open(src) as src_buffer:
231         src_contents = src_buffer.read()
232     try:
233         contents = format_file_contents(
234             src_contents, line_length=line_length, fast=fast
235         )
236     except NothingChanged:
237         return False
238
239     if write_back:
240         with open(src, "w", encoding=src_buffer.encoding) as f:
241             f.write(contents)
242     return True
243
244
245 def format_stdin_to_stdout(
246     line_length: int, fast: bool, write_back: bool = False
247 ) -> bool:
248     """Format file on stdin. Return True if changed.
249
250     If `write_back` is True, write reformatted code back to stdout.
251     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
252     """
253     contents = sys.stdin.read()
254     try:
255         contents = format_file_contents(contents, line_length=line_length, fast=fast)
256         return True
257
258     except NothingChanged:
259         return False
260
261     finally:
262         if write_back:
263             sys.stdout.write(contents)
264
265
266 def format_file_contents(
267     src_contents: str, line_length: int, fast: bool
268 ) -> FileContent:
269     """Reformat contents a file and return new contents.
270
271     If `fast` is False, additionally confirm that the reformatted code is
272     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
273     `line_length` is passed to :func:`format_str`.
274     """
275     if src_contents.strip() == '':
276         raise NothingChanged
277
278     dst_contents = format_str(src_contents, line_length=line_length)
279     if src_contents == dst_contents:
280         raise NothingChanged
281
282     if not fast:
283         assert_equivalent(src_contents, dst_contents)
284         assert_stable(src_contents, dst_contents, line_length=line_length)
285     return dst_contents
286
287
288 def format_str(src_contents: str, line_length: int) -> FileContent:
289     """Reformat a string and return new contents.
290
291     `line_length` determines how many characters per line are allowed.
292     """
293     src_node = lib2to3_parse(src_contents)
294     dst_contents = ""
295     lines = LineGenerator()
296     elt = EmptyLineTracker()
297     py36 = is_python36(src_node)
298     empty_line = Line()
299     after = 0
300     for current_line in lines.visit(src_node):
301         for _ in range(after):
302             dst_contents += str(empty_line)
303         before, after = elt.maybe_empty_lines(current_line)
304         for _ in range(before):
305             dst_contents += str(empty_line)
306         for line in split_line(current_line, line_length=line_length, py36=py36):
307             dst_contents += str(line)
308     return dst_contents
309
310
311 GRAMMARS = [
312     pygram.python_grammar_no_print_statement_no_exec_statement,
313     pygram.python_grammar_no_print_statement,
314     pygram.python_grammar_no_exec_statement,
315     pygram.python_grammar,
316 ]
317
318
319 def lib2to3_parse(src_txt: str) -> Node:
320     """Given a string with source, return the lib2to3 Node."""
321     grammar = pygram.python_grammar_no_print_statement
322     if src_txt[-1] != '\n':
323         nl = '\r\n' if '\r\n' in src_txt[:1024] else '\n'
324         src_txt += nl
325     for grammar in GRAMMARS:
326         drv = driver.Driver(grammar, pytree.convert)
327         try:
328             result = drv.parse_string(src_txt, True)
329             break
330
331         except ParseError as pe:
332             lineno, column = pe.context[1]
333             lines = src_txt.splitlines()
334             try:
335                 faulty_line = lines[lineno - 1]
336             except IndexError:
337                 faulty_line = "<line number missing in source>"
338             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
339     else:
340         raise exc from None
341
342     if isinstance(result, Leaf):
343         result = Node(syms.file_input, [result])
344     return result
345
346
347 def lib2to3_unparse(node: Node) -> str:
348     """Given a lib2to3 node, return its string representation."""
349     code = str(node)
350     return code
351
352
353 T = TypeVar('T')
354
355
356 class Visitor(Generic[T]):
357     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
358
359     def visit(self, node: LN) -> Iterator[T]:
360         """Main method to visit `node` and its children.
361
362         It tries to find a `visit_*()` method for the given `node.type`, like
363         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
364         If no dedicated `visit_*()` method is found, chooses `visit_default()`
365         instead.
366
367         Then yields objects of type `T` from the selected visitor.
368         """
369         if node.type < 256:
370             name = token.tok_name[node.type]
371         else:
372             name = type_repr(node.type)
373         yield from getattr(self, f'visit_{name}', self.visit_default)(node)
374
375     def visit_default(self, node: LN) -> Iterator[T]:
376         """Default `visit_*()` implementation. Recurses to children of `node`."""
377         if isinstance(node, Node):
378             for child in node.children:
379                 yield from self.visit(child)
380
381
382 @dataclass
383 class DebugVisitor(Visitor[T]):
384     tree_depth: int = 0
385
386     def visit_default(self, node: LN) -> Iterator[T]:
387         indent = ' ' * (2 * self.tree_depth)
388         if isinstance(node, Node):
389             _type = type_repr(node.type)
390             out(f'{indent}{_type}', fg='yellow')
391             self.tree_depth += 1
392             for child in node.children:
393                 yield from self.visit(child)
394
395             self.tree_depth -= 1
396             out(f'{indent}/{_type}', fg='yellow', bold=False)
397         else:
398             _type = token.tok_name.get(node.type, str(node.type))
399             out(f'{indent}{_type}', fg='blue', nl=False)
400             if node.prefix:
401                 # We don't have to handle prefixes for `Node` objects since
402                 # that delegates to the first child anyway.
403                 out(f' {node.prefix!r}', fg='green', bold=False, nl=False)
404             out(f' {node.value!r}', fg='blue', bold=False)
405
406     @classmethod
407     def show(cls, code: str) -> None:
408         """Pretty-print the lib2to3 AST of a given string of `code`.
409
410         Convenience method for debugging.
411         """
412         v: DebugVisitor[None] = DebugVisitor()
413         list(v.visit(lib2to3_parse(code)))
414
415
416 KEYWORDS = set(keyword.kwlist)
417 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
418 FLOW_CONTROL = {'return', 'raise', 'break', 'continue'}
419 STATEMENT = {
420     syms.if_stmt,
421     syms.while_stmt,
422     syms.for_stmt,
423     syms.try_stmt,
424     syms.except_clause,
425     syms.with_stmt,
426     syms.funcdef,
427     syms.classdef,
428 }
429 STANDALONE_COMMENT = 153
430 LOGIC_OPERATORS = {'and', 'or'}
431 COMPARATORS = {
432     token.LESS,
433     token.GREATER,
434     token.EQEQUAL,
435     token.NOTEQUAL,
436     token.LESSEQUAL,
437     token.GREATEREQUAL,
438 }
439 MATH_OPERATORS = {
440     token.PLUS,
441     token.MINUS,
442     token.STAR,
443     token.SLASH,
444     token.VBAR,
445     token.AMPER,
446     token.PERCENT,
447     token.CIRCUMFLEX,
448     token.TILDE,
449     token.LEFTSHIFT,
450     token.RIGHTSHIFT,
451     token.DOUBLESTAR,
452     token.DOUBLESLASH,
453 }
454 COMPREHENSION_PRIORITY = 20
455 COMMA_PRIORITY = 10
456 LOGIC_PRIORITY = 5
457 STRING_PRIORITY = 4
458 COMPARATOR_PRIORITY = 3
459 MATH_PRIORITY = 1
460
461
462 @dataclass
463 class BracketTracker:
464     """Keeps track of brackets on a line."""
465
466     depth: int = 0
467     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
468     delimiters: Dict[LeafID, Priority] = Factory(dict)
469     previous: Optional[Leaf] = None
470
471     def mark(self, leaf: Leaf) -> None:
472         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
473
474         All leaves receive an int `bracket_depth` field that stores how deep
475         within brackets a given leaf is. 0 means there are no enclosing brackets
476         that started on this line.
477
478         If a leaf is itself a closing bracket, it receives an `opening_bracket`
479         field that it forms a pair with. This is a one-directional link to
480         avoid reference cycles.
481
482         If a leaf is a delimiter (a token on which Black can split the line if
483         needed) and it's on depth 0, its `id()` is stored in the tracker's
484         `delimiters` field.
485         """
486         if leaf.type == token.COMMENT:
487             return
488
489         if leaf.type in CLOSING_BRACKETS:
490             self.depth -= 1
491             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
492             leaf.opening_bracket = opening_bracket
493         leaf.bracket_depth = self.depth
494         if self.depth == 0:
495             delim = is_delimiter(leaf)
496             if delim:
497                 self.delimiters[id(leaf)] = delim
498             elif self.previous is not None:
499                 if leaf.type == token.STRING and self.previous.type == token.STRING:
500                     self.delimiters[id(self.previous)] = STRING_PRIORITY
501                 elif (
502                     leaf.type == token.NAME
503                     and leaf.value == 'for'
504                     and leaf.parent
505                     and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
506                 ):
507                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
508                 elif (
509                     leaf.type == token.NAME
510                     and leaf.value == 'if'
511                     and leaf.parent
512                     and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
513                 ):
514                     self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
515                 elif (
516                     leaf.type == token.NAME
517                     and leaf.value in LOGIC_OPERATORS
518                     and leaf.parent
519                 ):
520                     self.delimiters[id(self.previous)] = LOGIC_PRIORITY
521         if leaf.type in OPENING_BRACKETS:
522             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
523             self.depth += 1
524         self.previous = leaf
525
526     def any_open_brackets(self) -> bool:
527         """Return True if there is an yet unmatched open bracket on the line."""
528         return bool(self.bracket_match)
529
530     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
531         """Return the highest priority of a delimiter found on the line.
532
533         Values are consistent with what `is_delimiter()` returns.
534         """
535         return max(v for k, v in self.delimiters.items() if k not in exclude)
536
537
538 @dataclass
539 class Line:
540     """Holds leaves and comments. Can be printed with `str(line)`."""
541
542     depth: int = 0
543     leaves: List[Leaf] = Factory(list)
544     comments: List[Tuple[Index, Leaf]] = Factory(list)
545     bracket_tracker: BracketTracker = Factory(BracketTracker)
546     inside_brackets: bool = False
547     has_for: bool = False
548     _for_loop_variable: bool = False
549
550     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
551         """Add a new `leaf` to the end of the line.
552
553         Unless `preformatted` is True, the `leaf` will receive a new consistent
554         whitespace prefix and metadata applied by :class:`BracketTracker`.
555         Trailing commas are maybe removed, unpacked for loop variables are
556         demoted from being delimiters.
557
558         Inline comments are put aside.
559         """
560         has_value = leaf.value.strip()
561         if not has_value:
562             return
563
564         if self.leaves and not preformatted:
565             # Note: at this point leaf.prefix should be empty except for
566             # imports, for which we only preserve newlines.
567             leaf.prefix += whitespace(leaf)
568         if self.inside_brackets or not preformatted:
569             self.maybe_decrement_after_for_loop_variable(leaf)
570             self.bracket_tracker.mark(leaf)
571             self.maybe_remove_trailing_comma(leaf)
572             self.maybe_increment_for_loop_variable(leaf)
573
574         if not self.append_comment(leaf):
575             self.leaves.append(leaf)
576
577     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
578         """Like :func:`append()` but disallow invalid standalone comment structure.
579
580         Raises ValueError when any `leaf` is appended after a standalone comment
581         or when a standalone comment is not the first leaf on the line.
582         """
583         if self.bracket_tracker.depth == 0:
584             if self.is_comment:
585                 raise ValueError("cannot append to standalone comments")
586
587             if self.leaves and leaf.type == STANDALONE_COMMENT:
588                 raise ValueError(
589                     "cannot append standalone comments to a populated line"
590                 )
591
592         self.append(leaf, preformatted=preformatted)
593
594     @property
595     def is_comment(self) -> bool:
596         """Is this line a standalone comment?"""
597         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
598
599     @property
600     def is_decorator(self) -> bool:
601         """Is this line a decorator?"""
602         return bool(self) and self.leaves[0].type == token.AT
603
604     @property
605     def is_import(self) -> bool:
606         """Is this an import line?"""
607         return bool(self) and is_import(self.leaves[0])
608
609     @property
610     def is_class(self) -> bool:
611         """Is this line a class definition?"""
612         return (
613             bool(self)
614             and self.leaves[0].type == token.NAME
615             and self.leaves[0].value == 'class'
616         )
617
618     @property
619     def is_def(self) -> bool:
620         """Is this a function definition? (Also returns True for async defs.)"""
621         try:
622             first_leaf = self.leaves[0]
623         except IndexError:
624             return False
625
626         try:
627             second_leaf: Optional[Leaf] = self.leaves[1]
628         except IndexError:
629             second_leaf = None
630         return (
631             (first_leaf.type == token.NAME and first_leaf.value == 'def')
632             or (
633                 first_leaf.type == token.ASYNC
634                 and second_leaf is not None
635                 and second_leaf.type == token.NAME
636                 and second_leaf.value == 'def'
637             )
638         )
639
640     @property
641     def is_flow_control(self) -> bool:
642         """Is this line a flow control statement?
643
644         Those are `return`, `raise`, `break`, and `continue`.
645         """
646         return (
647             bool(self)
648             and self.leaves[0].type == token.NAME
649             and self.leaves[0].value in FLOW_CONTROL
650         )
651
652     @property
653     def is_yield(self) -> bool:
654         """Is this line a yield statement?"""
655         return (
656             bool(self)
657             and self.leaves[0].type == token.NAME
658             and self.leaves[0].value == 'yield'
659         )
660
661     @property
662     def contains_standalone_comments(self) -> bool:
663         """If so, needs to be split before emitting."""
664         for leaf in self.leaves:
665             if leaf.type == STANDALONE_COMMENT:
666                 return True
667
668         return False
669
670     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
671         """Remove trailing comma if there is one and it's safe."""
672         if not (
673             self.leaves
674             and self.leaves[-1].type == token.COMMA
675             and closing.type in CLOSING_BRACKETS
676         ):
677             return False
678
679         if closing.type == token.RBRACE:
680             self.remove_trailing_comma()
681             return True
682
683         if closing.type == token.RSQB:
684             comma = self.leaves[-1]
685             if comma.parent and comma.parent.type == syms.listmaker:
686                 self.remove_trailing_comma()
687                 return True
688
689         # For parens let's check if it's safe to remove the comma.  If the
690         # trailing one is the only one, we might mistakenly change a tuple
691         # into a different type by removing the comma.
692         depth = closing.bracket_depth + 1
693         commas = 0
694         opening = closing.opening_bracket
695         for _opening_index, leaf in enumerate(self.leaves):
696             if leaf is opening:
697                 break
698
699         else:
700             return False
701
702         for leaf in self.leaves[_opening_index + 1:]:
703             if leaf is closing:
704                 break
705
706             bracket_depth = leaf.bracket_depth
707             if bracket_depth == depth and leaf.type == token.COMMA:
708                 commas += 1
709                 if leaf.parent and leaf.parent.type == syms.arglist:
710                     commas += 1
711                     break
712
713         if commas > 1:
714             self.remove_trailing_comma()
715             return True
716
717         return False
718
719     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
720         """In a for loop, or comprehension, the variables are often unpacks.
721
722         To avoid splitting on the comma in this situation, increase the depth of
723         tokens between `for` and `in`.
724         """
725         if leaf.type == token.NAME and leaf.value == 'for':
726             self.has_for = True
727             self.bracket_tracker.depth += 1
728             self._for_loop_variable = True
729             return True
730
731         return False
732
733     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
734         """See `maybe_increment_for_loop_variable` above for explanation."""
735         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
736             self.bracket_tracker.depth -= 1
737             self._for_loop_variable = False
738             return True
739
740         return False
741
742     def append_comment(self, comment: Leaf) -> bool:
743         """Add an inline or standalone comment to the line."""
744         if (
745             comment.type == STANDALONE_COMMENT
746             and self.bracket_tracker.any_open_brackets()
747         ):
748             comment.prefix = ''
749             return False
750
751         if comment.type != token.COMMENT:
752             return False
753
754         after = len(self.leaves) - 1
755         if after == -1:
756             comment.type = STANDALONE_COMMENT
757             comment.prefix = ''
758             return False
759
760         else:
761             self.comments.append((after, comment))
762             return True
763
764     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
765         """Generate comments that should appear directly after `leaf`."""
766         for _leaf_index, _leaf in enumerate(self.leaves):
767             if leaf is _leaf:
768                 break
769
770         else:
771             return
772
773         for index, comment_after in self.comments:
774             if _leaf_index == index:
775                 yield comment_after
776
777     def remove_trailing_comma(self) -> None:
778         """Remove the trailing comma and moves the comments attached to it."""
779         comma_index = len(self.leaves) - 1
780         for i in range(len(self.comments)):
781             comment_index, comment = self.comments[i]
782             if comment_index == comma_index:
783                 self.comments[i] = (comma_index - 1, comment)
784         self.leaves.pop()
785
786     def __str__(self) -> str:
787         """Render the line."""
788         if not self:
789             return '\n'
790
791         indent = '    ' * self.depth
792         leaves = iter(self.leaves)
793         first = next(leaves)
794         res = f'{first.prefix}{indent}{first.value}'
795         for leaf in leaves:
796             res += str(leaf)
797         for _, comment in self.comments:
798             res += str(comment)
799         return res + '\n'
800
801     def __bool__(self) -> bool:
802         """Return True if the line has leaves or comments."""
803         return bool(self.leaves or self.comments)
804
805
806 class UnformattedLines(Line):
807     """Just like :class:`Line` but stores lines which aren't reformatted."""
808
809     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
810         """Just add a new `leaf` to the end of the lines.
811
812         The `preformatted` argument is ignored.
813
814         Keeps track of indentation `depth`, which is useful when the user
815         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
816         """
817         try:
818             list(generate_comments(leaf))
819         except FormatOn as f_on:
820             self.leaves.append(f_on.leaf_from_consumed(leaf))
821             raise
822
823         self.leaves.append(leaf)
824         if leaf.type == token.INDENT:
825             self.depth += 1
826         elif leaf.type == token.DEDENT:
827             self.depth -= 1
828
829     def __str__(self) -> str:
830         """Render unformatted lines from leaves which were added with `append()`.
831
832         `depth` is not used for indentation in this case.
833         """
834         if not self:
835             return '\n'
836
837         res = ''
838         for leaf in self.leaves:
839             res += str(leaf)
840         return res
841
842     def append_comment(self, comment: Leaf) -> bool:
843         """Not implemented in this class. Raises `NotImplementedError`."""
844         raise NotImplementedError("Unformatted lines don't store comments separately.")
845
846     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
847         """Does nothing and returns False."""
848         return False
849
850     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
851         """Does nothing and returns False."""
852         return False
853
854
855 @dataclass
856 class EmptyLineTracker:
857     """Provides a stateful method that returns the number of potential extra
858     empty lines needed before and after the currently processed line.
859
860     Note: this tracker works on lines that haven't been split yet.  It assumes
861     the prefix of the first leaf consists of optional newlines.  Those newlines
862     are consumed by `maybe_empty_lines()` and included in the computation.
863     """
864     previous_line: Optional[Line] = None
865     previous_after: int = 0
866     previous_defs: List[int] = Factory(list)
867
868     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
869         """Return the number of extra empty lines before and after the `current_line`.
870
871         This is for separating `def`, `async def` and `class` with extra empty
872         lines (two on module-level), as well as providing an extra empty line
873         after flow control keywords to make them more prominent.
874         """
875         if isinstance(current_line, UnformattedLines):
876             return 0, 0
877
878         before, after = self._maybe_empty_lines(current_line)
879         before -= self.previous_after
880         self.previous_after = after
881         self.previous_line = current_line
882         return before, after
883
884     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
885         max_allowed = 1
886         if current_line.depth == 0:
887             max_allowed = 2
888         if current_line.leaves:
889             # Consume the first leaf's extra newlines.
890             first_leaf = current_line.leaves[0]
891             before = first_leaf.prefix.count('\n')
892             before = min(before, max_allowed)
893             first_leaf.prefix = ''
894         else:
895             before = 0
896         depth = current_line.depth
897         while self.previous_defs and self.previous_defs[-1] >= depth:
898             self.previous_defs.pop()
899             before = 1 if depth else 2
900         is_decorator = current_line.is_decorator
901         if is_decorator or current_line.is_def or current_line.is_class:
902             if not is_decorator:
903                 self.previous_defs.append(depth)
904             if self.previous_line is None:
905                 # Don't insert empty lines before the first line in the file.
906                 return 0, 0
907
908             if self.previous_line and self.previous_line.is_decorator:
909                 # Don't insert empty lines between decorators.
910                 return 0, 0
911
912             newlines = 2
913             if current_line.depth:
914                 newlines -= 1
915             return newlines, 0
916
917         if current_line.is_flow_control:
918             return before, 1
919
920         if (
921             self.previous_line
922             and self.previous_line.is_import
923             and not current_line.is_import
924             and depth == self.previous_line.depth
925         ):
926             return (before or 1), 0
927
928         if (
929             self.previous_line
930             and self.previous_line.is_yield
931             and (not current_line.is_yield or depth != self.previous_line.depth)
932         ):
933             return (before or 1), 0
934
935         return before, 0
936
937
938 @dataclass
939 class LineGenerator(Visitor[Line]):
940     """Generates reformatted Line objects.  Empty lines are not emitted.
941
942     Note: destroys the tree it's visiting by mutating prefixes of its leaves
943     in ways that will no longer stringify to valid Python code on the tree.
944     """
945     current_line: Line = Factory(Line)
946
947     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
948         """Generate a line.
949
950         If the line is empty, only emit if it makes sense.
951         If the line is too long, split it first and then generate.
952
953         If any lines were generated, set up a new current_line.
954         """
955         if not self.current_line:
956             if self.current_line.__class__ == type:
957                 self.current_line.depth += indent
958             else:
959                 self.current_line = type(depth=self.current_line.depth + indent)
960             return  # Line is empty, don't emit. Creating a new one unnecessary.
961
962         complete_line = self.current_line
963         self.current_line = type(depth=complete_line.depth + indent)
964         yield complete_line
965
966     def visit(self, node: LN) -> Iterator[Line]:
967         """Main method to visit `node` and its children.
968
969         Yields :class:`Line` objects.
970         """
971         if isinstance(self.current_line, UnformattedLines):
972             # File contained `# fmt: off`
973             yield from self.visit_unformatted(node)
974
975         else:
976             yield from super().visit(node)
977
978     def visit_default(self, node: LN) -> Iterator[Line]:
979         """Default `visit_*()` implementation. Recurses to children of `node`."""
980         if isinstance(node, Leaf):
981             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
982             try:
983                 for comment in generate_comments(node):
984                     if any_open_brackets:
985                         # any comment within brackets is subject to splitting
986                         self.current_line.append(comment)
987                     elif comment.type == token.COMMENT:
988                         # regular trailing comment
989                         self.current_line.append(comment)
990                         yield from self.line()
991
992                     else:
993                         # regular standalone comment
994                         yield from self.line()
995
996                         self.current_line.append(comment)
997                         yield from self.line()
998
999             except FormatOff as f_off:
1000                 f_off.trim_prefix(node)
1001                 yield from self.line(type=UnformattedLines)
1002                 yield from self.visit(node)
1003
1004             except FormatOn as f_on:
1005                 # This only happens here if somebody says "fmt: on" multiple
1006                 # times in a row.
1007                 f_on.trim_prefix(node)
1008                 yield from self.visit_default(node)
1009
1010             else:
1011                 normalize_prefix(node, inside_brackets=any_open_brackets)
1012                 if node.type not in WHITESPACE:
1013                     self.current_line.append(node)
1014         yield from super().visit_default(node)
1015
1016     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1017         """Increase indentation level, maybe yield a line."""
1018         # In blib2to3 INDENT never holds comments.
1019         yield from self.line(+1)
1020         yield from self.visit_default(node)
1021
1022     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1023         """Decrease indentation level, maybe yield a line."""
1024         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1025         yield from self.line(-1)
1026
1027     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
1028         """Visit a statement.
1029
1030         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1031         `def`, `with`, and `class`.
1032
1033         The relevant Python language `keywords` for a given statement will be NAME
1034         leaves within it. This methods puts those on a separate line.
1035         """
1036         for child in node.children:
1037             if child.type == token.NAME and child.value in keywords:  # type: ignore
1038                 yield from self.line()
1039
1040             yield from self.visit(child)
1041
1042     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1043         """Visit a statement without nested statements."""
1044         is_suite_like = node.parent and node.parent.type in STATEMENT
1045         if is_suite_like:
1046             yield from self.line(+1)
1047             yield from self.visit_default(node)
1048             yield from self.line(-1)
1049
1050         else:
1051             yield from self.line()
1052             yield from self.visit_default(node)
1053
1054     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1055         """Visit `async def`, `async for`, `async with`."""
1056         yield from self.line()
1057
1058         children = iter(node.children)
1059         for child in children:
1060             yield from self.visit(child)
1061
1062             if child.type == token.ASYNC:
1063                 break
1064
1065         internal_stmt = next(children)
1066         for child in internal_stmt.children:
1067             yield from self.visit(child)
1068
1069     def visit_decorators(self, node: Node) -> Iterator[Line]:
1070         """Visit decorators."""
1071         for child in node.children:
1072             yield from self.line()
1073             yield from self.visit(child)
1074
1075     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1076         """Remove a semicolon and put the other statement on a separate line."""
1077         yield from self.line()
1078
1079     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1080         """End of file. Process outstanding comments and end with a newline."""
1081         yield from self.visit_default(leaf)
1082         yield from self.line()
1083
1084     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1085         """Used when file contained a `# fmt: off`."""
1086         if isinstance(node, Node):
1087             for child in node.children:
1088                 yield from self.visit(child)
1089
1090         else:
1091             try:
1092                 self.current_line.append(node)
1093             except FormatOn as f_on:
1094                 f_on.trim_prefix(node)
1095                 yield from self.line()
1096                 yield from self.visit(node)
1097
1098     def __attrs_post_init__(self) -> None:
1099         """You are in a twisty little maze of passages."""
1100         v = self.visit_stmt
1101         self.visit_if_stmt = partial(v, keywords={'if', 'else', 'elif'})
1102         self.visit_while_stmt = partial(v, keywords={'while', 'else'})
1103         self.visit_for_stmt = partial(v, keywords={'for', 'else'})
1104         self.visit_try_stmt = partial(v, keywords={'try', 'except', 'else', 'finally'})
1105         self.visit_except_clause = partial(v, keywords={'except'})
1106         self.visit_funcdef = partial(v, keywords={'def'})
1107         self.visit_with_stmt = partial(v, keywords={'with'})
1108         self.visit_classdef = partial(v, keywords={'class'})
1109         self.visit_async_funcdef = self.visit_async_stmt
1110         self.visit_decorated = self.visit_decorators
1111
1112
1113 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1114 OPENING_BRACKETS = set(BRACKET.keys())
1115 CLOSING_BRACKETS = set(BRACKET.values())
1116 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1117 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1118
1119
1120 def whitespace(leaf: Leaf) -> str:  # noqa C901
1121     """Return whitespace prefix if needed for the given `leaf`."""
1122     NO = ''
1123     SPACE = ' '
1124     DOUBLESPACE = '  '
1125     t = leaf.type
1126     p = leaf.parent
1127     v = leaf.value
1128     if t in ALWAYS_NO_SPACE:
1129         return NO
1130
1131     if t == token.COMMENT:
1132         return DOUBLESPACE
1133
1134     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1135     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1136         return NO
1137
1138     prev = leaf.prev_sibling
1139     if not prev:
1140         prevp = preceding_leaf(p)
1141         if not prevp or prevp.type in OPENING_BRACKETS:
1142             return NO
1143
1144         if t == token.COLON:
1145             return SPACE if prevp.type == token.COMMA else NO
1146
1147         if prevp.type == token.EQUAL:
1148             if prevp.parent:
1149                 if prevp.parent.type in {
1150                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1151                 }:
1152                     return NO
1153
1154                 elif prevp.parent.type == syms.typedargslist:
1155                     # A bit hacky: if the equal sign has whitespace, it means we
1156                     # previously found it's a typed argument.  So, we're using
1157                     # that, too.
1158                     return prevp.prefix
1159
1160         elif prevp.type == token.DOUBLESTAR:
1161             if prevp.parent and prevp.parent.type in {
1162                 syms.arglist,
1163                 syms.argument,
1164                 syms.dictsetmaker,
1165                 syms.parameters,
1166                 syms.typedargslist,
1167                 syms.varargslist,
1168             }:
1169                 return NO
1170
1171         elif prevp.type == token.COLON:
1172             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1173                 return NO
1174
1175         elif (
1176             prevp.parent
1177             and prevp.parent.type in {syms.factor, syms.star_expr}
1178             and prevp.type in MATH_OPERATORS
1179         ):
1180             return NO
1181
1182         elif (
1183             prevp.type == token.RIGHTSHIFT
1184             and prevp.parent
1185             and prevp.parent.type == syms.shift_expr
1186             and prevp.prev_sibling
1187             and prevp.prev_sibling.type == token.NAME
1188             and prevp.prev_sibling.value == 'print'  # type: ignore
1189         ):
1190             # Python 2 print chevron
1191             return NO
1192
1193     elif prev.type in OPENING_BRACKETS:
1194         return NO
1195
1196     if p.type in {syms.parameters, syms.arglist}:
1197         # untyped function signatures or calls
1198         if t == token.RPAR:
1199             return NO
1200
1201         if not prev or prev.type != token.COMMA:
1202             return NO
1203
1204     elif p.type == syms.varargslist:
1205         # lambdas
1206         if t == token.RPAR:
1207             return NO
1208
1209         if prev and prev.type != token.COMMA:
1210             return NO
1211
1212     elif p.type == syms.typedargslist:
1213         # typed function signatures
1214         if not prev:
1215             return NO
1216
1217         if t == token.EQUAL:
1218             if prev.type != syms.tname:
1219                 return NO
1220
1221         elif prev.type == token.EQUAL:
1222             # A bit hacky: if the equal sign has whitespace, it means we
1223             # previously found it's a typed argument.  So, we're using that, too.
1224             return prev.prefix
1225
1226         elif prev.type != token.COMMA:
1227             return NO
1228
1229     elif p.type == syms.tname:
1230         # type names
1231         if not prev:
1232             prevp = preceding_leaf(p)
1233             if not prevp or prevp.type != token.COMMA:
1234                 return NO
1235
1236     elif p.type == syms.trailer:
1237         # attributes and calls
1238         if t == token.LPAR or t == token.RPAR:
1239             return NO
1240
1241         if not prev:
1242             if t == token.DOT:
1243                 prevp = preceding_leaf(p)
1244                 if not prevp or prevp.type != token.NUMBER:
1245                     return NO
1246
1247             elif t == token.LSQB:
1248                 return NO
1249
1250         elif prev.type != token.COMMA:
1251             return NO
1252
1253     elif p.type == syms.argument:
1254         # single argument
1255         if t == token.EQUAL:
1256             return NO
1257
1258         if not prev:
1259             prevp = preceding_leaf(p)
1260             if not prevp or prevp.type == token.LPAR:
1261                 return NO
1262
1263         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1264             return NO
1265
1266     elif p.type == syms.decorator:
1267         # decorators
1268         return NO
1269
1270     elif p.type == syms.dotted_name:
1271         if prev:
1272             return NO
1273
1274         prevp = preceding_leaf(p)
1275         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1276             return NO
1277
1278     elif p.type == syms.classdef:
1279         if t == token.LPAR:
1280             return NO
1281
1282         if prev and prev.type == token.LPAR:
1283             return NO
1284
1285     elif p.type == syms.subscript:
1286         # indexing
1287         if not prev:
1288             assert p.parent is not None, "subscripts are always parented"
1289             if p.parent.type == syms.subscriptlist:
1290                 return SPACE
1291
1292             return NO
1293
1294         else:
1295             return NO
1296
1297     elif p.type == syms.atom:
1298         if prev and t == token.DOT:
1299             # dots, but not the first one.
1300             return NO
1301
1302     elif (
1303         p.type == syms.listmaker
1304         or p.type == syms.testlist_gexp
1305         or p.type == syms.subscriptlist
1306     ):
1307         # list interior, including unpacking
1308         if not prev:
1309             return NO
1310
1311     elif p.type == syms.dictsetmaker:
1312         # dict and set interior, including unpacking
1313         if not prev:
1314             return NO
1315
1316         if prev.type == token.DOUBLESTAR:
1317             return NO
1318
1319     elif p.type in {syms.factor, syms.star_expr}:
1320         # unary ops
1321         if not prev:
1322             prevp = preceding_leaf(p)
1323             if not prevp or prevp.type in OPENING_BRACKETS:
1324                 return NO
1325
1326             prevp_parent = prevp.parent
1327             assert prevp_parent is not None
1328             if prevp.type == token.COLON and prevp_parent.type in {
1329                 syms.subscript, syms.sliceop
1330             }:
1331                 return NO
1332
1333             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1334                 return NO
1335
1336         elif t == token.NAME or t == token.NUMBER:
1337             return NO
1338
1339     elif p.type == syms.import_from:
1340         if t == token.DOT:
1341             if prev and prev.type == token.DOT:
1342                 return NO
1343
1344         elif t == token.NAME:
1345             if v == 'import':
1346                 return SPACE
1347
1348             if prev and prev.type == token.DOT:
1349                 return NO
1350
1351     elif p.type == syms.sliceop:
1352         return NO
1353
1354     return SPACE
1355
1356
1357 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1358     """Return the first leaf that precedes `node`, if any."""
1359     while node:
1360         res = node.prev_sibling
1361         if res:
1362             if isinstance(res, Leaf):
1363                 return res
1364
1365             try:
1366                 return list(res.leaves())[-1]
1367
1368             except IndexError:
1369                 return None
1370
1371         node = node.parent
1372     return None
1373
1374
1375 def is_delimiter(leaf: Leaf) -> int:
1376     """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
1377
1378     Higher numbers are higher priority.
1379     """
1380     if leaf.type == token.COMMA:
1381         return COMMA_PRIORITY
1382
1383     if leaf.type in COMPARATORS:
1384         return COMPARATOR_PRIORITY
1385
1386     if (
1387         leaf.type in MATH_OPERATORS
1388         and leaf.parent
1389         and leaf.parent.type not in {syms.factor, syms.star_expr}
1390     ):
1391         return MATH_PRIORITY
1392
1393     return 0
1394
1395
1396 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1397     """Clean the prefix of the `leaf` and generate comments from it, if any.
1398
1399     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1400     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1401     move because it does away with modifying the grammar to include all the
1402     possible places in which comments can be placed.
1403
1404     The sad consequence for us though is that comments don't "belong" anywhere.
1405     This is why this function generates simple parentless Leaf objects for
1406     comments.  We simply don't know what the correct parent should be.
1407
1408     No matter though, we can live without this.  We really only need to
1409     differentiate between inline and standalone comments.  The latter don't
1410     share the line with any code.
1411
1412     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1413     are emitted with a fake STANDALONE_COMMENT token identifier.
1414     """
1415     p = leaf.prefix
1416     if not p:
1417         return
1418
1419     if '#' not in p:
1420         return
1421
1422     consumed = 0
1423     nlines = 0
1424     for index, line in enumerate(p.split('\n')):
1425         consumed += len(line) + 1  # adding the length of the split '\n'
1426         line = line.lstrip()
1427         if not line:
1428             nlines += 1
1429         if not line.startswith('#'):
1430             continue
1431
1432         if index == 0 and leaf.type != token.ENDMARKER:
1433             comment_type = token.COMMENT  # simple trailing comment
1434         else:
1435             comment_type = STANDALONE_COMMENT
1436         comment = make_comment(line)
1437         yield Leaf(comment_type, comment, prefix='\n' * nlines)
1438
1439         if comment in {'# fmt: on', '# yapf: enable'}:
1440             raise FormatOn(consumed)
1441
1442         if comment in {'# fmt: off', '# yapf: disable'}:
1443             raise FormatOff(consumed)
1444
1445         nlines = 0
1446
1447
1448 def make_comment(content: str) -> str:
1449     """Return a consistently formatted comment from the given `content` string.
1450
1451     All comments (except for "##", "#!", "#:") should have a single space between
1452     the hash sign and the content.
1453
1454     If `content` didn't start with a hash sign, one is provided.
1455     """
1456     content = content.rstrip()
1457     if not content:
1458         return '#'
1459
1460     if content[0] == '#':
1461         content = content[1:]
1462     if content and content[0] not in ' !:#':
1463         content = ' ' + content
1464     return '#' + content
1465
1466
1467 def split_line(
1468     line: Line, line_length: int, inner: bool = False, py36: bool = False
1469 ) -> Iterator[Line]:
1470     """Split a `line` into potentially many lines.
1471
1472     They should fit in the allotted `line_length` but might not be able to.
1473     `inner` signifies that there were a pair of brackets somewhere around the
1474     current `line`, possibly transitively. This means we can fallback to splitting
1475     by delimiters if the LHS/RHS don't yield any results.
1476
1477     If `py36` is True, splitting may generate syntax that is only compatible
1478     with Python 3.6 and later.
1479     """
1480     if isinstance(line, UnformattedLines) or line.is_comment:
1481         yield line
1482         return
1483
1484     line_str = str(line).strip('\n')
1485     if (
1486         len(line_str) <= line_length
1487         and '\n' not in line_str  # multiline strings
1488         and not line.contains_standalone_comments
1489     ):
1490         yield line
1491         return
1492
1493     split_funcs: List[SplitFunc]
1494     if line.is_def:
1495         split_funcs = [left_hand_split]
1496     elif line.inside_brackets:
1497         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1498     else:
1499         split_funcs = [right_hand_split]
1500     for split_func in split_funcs:
1501         # We are accumulating lines in `result` because we might want to abort
1502         # mission and return the original line in the end, or attempt a different
1503         # split altogether.
1504         result: List[Line] = []
1505         try:
1506             for l in split_func(line, py36):
1507                 if str(l).strip('\n') == line_str:
1508                     raise CannotSplit("Split function returned an unchanged result")
1509
1510                 result.extend(
1511                     split_line(l, line_length=line_length, inner=True, py36=py36)
1512                 )
1513         except CannotSplit as cs:
1514             continue
1515
1516         else:
1517             yield from result
1518             break
1519
1520     else:
1521         yield line
1522
1523
1524 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1525     """Split line into many lines, starting with the first matching bracket pair.
1526
1527     Note: this usually looks weird, only use this for function definitions.
1528     Prefer RHS otherwise.
1529     """
1530     head = Line(depth=line.depth)
1531     body = Line(depth=line.depth + 1, inside_brackets=True)
1532     tail = Line(depth=line.depth)
1533     tail_leaves: List[Leaf] = []
1534     body_leaves: List[Leaf] = []
1535     head_leaves: List[Leaf] = []
1536     current_leaves = head_leaves
1537     matching_bracket = None
1538     for leaf in line.leaves:
1539         if (
1540             current_leaves is body_leaves
1541             and leaf.type in CLOSING_BRACKETS
1542             and leaf.opening_bracket is matching_bracket
1543         ):
1544             current_leaves = tail_leaves if body_leaves else head_leaves
1545         current_leaves.append(leaf)
1546         if current_leaves is head_leaves:
1547             if leaf.type in OPENING_BRACKETS:
1548                 matching_bracket = leaf
1549                 current_leaves = body_leaves
1550     # Since body is a new indent level, remove spurious leading whitespace.
1551     if body_leaves:
1552         normalize_prefix(body_leaves[0], inside_brackets=True)
1553     # Build the new lines.
1554     for result, leaves in (
1555         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1556     ):
1557         for leaf in leaves:
1558             result.append(leaf, preformatted=True)
1559             for comment_after in line.comments_after(leaf):
1560                 result.append(comment_after, preformatted=True)
1561     bracket_split_succeeded_or_raise(head, body, tail)
1562     for result in (head, body, tail):
1563         if result:
1564             yield result
1565
1566
1567 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1568     """Split line into many lines, starting with the last matching bracket pair."""
1569     head = Line(depth=line.depth)
1570     body = Line(depth=line.depth + 1, inside_brackets=True)
1571     tail = Line(depth=line.depth)
1572     tail_leaves: List[Leaf] = []
1573     body_leaves: List[Leaf] = []
1574     head_leaves: List[Leaf] = []
1575     current_leaves = tail_leaves
1576     opening_bracket = None
1577     for leaf in reversed(line.leaves):
1578         if current_leaves is body_leaves:
1579             if leaf is opening_bracket:
1580                 current_leaves = head_leaves if body_leaves else tail_leaves
1581         current_leaves.append(leaf)
1582         if current_leaves is tail_leaves:
1583             if leaf.type in CLOSING_BRACKETS:
1584                 opening_bracket = leaf.opening_bracket
1585                 current_leaves = body_leaves
1586     tail_leaves.reverse()
1587     body_leaves.reverse()
1588     head_leaves.reverse()
1589     # Since body is a new indent level, remove spurious leading whitespace.
1590     if body_leaves:
1591         normalize_prefix(body_leaves[0], inside_brackets=True)
1592     # Build the new lines.
1593     for result, leaves in (
1594         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1595     ):
1596         for leaf in leaves:
1597             result.append(leaf, preformatted=True)
1598             for comment_after in line.comments_after(leaf):
1599                 result.append(comment_after, preformatted=True)
1600     bracket_split_succeeded_or_raise(head, body, tail)
1601     for result in (head, body, tail):
1602         if result:
1603             yield result
1604
1605
1606 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1607     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1608
1609     Do nothing otherwise.
1610
1611     A left- or right-hand split is based on a pair of brackets. Content before
1612     (and including) the opening bracket is left on one line, content inside the
1613     brackets is put on a separate line, and finally content starting with and
1614     following the closing bracket is put on a separate line.
1615
1616     Those are called `head`, `body`, and `tail`, respectively. If the split
1617     produced the same line (all content in `head`) or ended up with an empty `body`
1618     and the `tail` is just the closing bracket, then it's considered failed.
1619     """
1620     tail_len = len(str(tail).strip())
1621     if not body:
1622         if tail_len == 0:
1623             raise CannotSplit("Splitting brackets produced the same line")
1624
1625         elif tail_len < 3:
1626             raise CannotSplit(
1627                 f"Splitting brackets on an empty body to save "
1628                 f"{tail_len} characters is not worth it"
1629             )
1630
1631
1632 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1633     """Normalize prefix of the first leaf in every line returned by `split_func`.
1634
1635     This is a decorator over relevant split functions.
1636     """
1637
1638     @wraps(split_func)
1639     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1640         for l in split_func(line, py36):
1641             normalize_prefix(l.leaves[0], inside_brackets=True)
1642             yield l
1643
1644     return split_wrapper
1645
1646
1647 @dont_increase_indentation
1648 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1649     """Split according to delimiters of the highest priority.
1650
1651     If `py36` is True, the split will add trailing commas also in function
1652     signatures that contain `*` and `**`.
1653     """
1654     try:
1655         last_leaf = line.leaves[-1]
1656     except IndexError:
1657         raise CannotSplit("Line empty")
1658
1659     delimiters = line.bracket_tracker.delimiters
1660     try:
1661         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1662             exclude={id(last_leaf)}
1663         )
1664     except ValueError:
1665         raise CannotSplit("No delimiters found")
1666
1667     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1668     lowest_depth = sys.maxsize
1669     trailing_comma_safe = True
1670
1671     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1672         """Append `leaf` to current line or to new line if appending impossible."""
1673         nonlocal current_line
1674         try:
1675             current_line.append_safe(leaf, preformatted=True)
1676         except ValueError as ve:
1677             yield current_line
1678
1679             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1680             current_line.append(leaf)
1681
1682     for leaf in line.leaves:
1683         yield from append_to_line(leaf)
1684
1685         for comment_after in line.comments_after(leaf):
1686             yield from append_to_line(comment_after)
1687
1688         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1689         if (
1690             leaf.bracket_depth == lowest_depth
1691             and leaf.type == token.STAR
1692             or leaf.type == token.DOUBLESTAR
1693         ):
1694             trailing_comma_safe = trailing_comma_safe and py36
1695         leaf_priority = delimiters.get(id(leaf))
1696         if leaf_priority == delimiter_priority:
1697             yield current_line
1698
1699             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1700     if current_line:
1701         if (
1702             delimiter_priority == COMMA_PRIORITY
1703             and current_line.leaves[-1].type != token.COMMA
1704             and trailing_comma_safe
1705         ):
1706             current_line.append(Leaf(token.COMMA, ','))
1707         yield current_line
1708
1709
1710 @dont_increase_indentation
1711 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1712     """Split standalone comments from the rest of the line."""
1713     for leaf in line.leaves:
1714         if leaf.type == STANDALONE_COMMENT:
1715             if leaf.bracket_depth == 0:
1716                 break
1717
1718     else:
1719         raise CannotSplit("Line does not have any standalone comments")
1720
1721     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1722
1723     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1724         """Append `leaf` to current line or to new line if appending impossible."""
1725         nonlocal current_line
1726         try:
1727             current_line.append_safe(leaf, preformatted=True)
1728         except ValueError as ve:
1729             yield current_line
1730
1731             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1732             current_line.append(leaf)
1733
1734     for leaf in line.leaves:
1735         yield from append_to_line(leaf)
1736
1737         for comment_after in line.comments_after(leaf):
1738             yield from append_to_line(comment_after)
1739
1740     if current_line:
1741         yield current_line
1742
1743
1744 def is_import(leaf: Leaf) -> bool:
1745     """Return True if the given leaf starts an import statement."""
1746     p = leaf.parent
1747     t = leaf.type
1748     v = leaf.value
1749     return bool(
1750         t == token.NAME
1751         and (
1752             (v == 'import' and p and p.type == syms.import_name)
1753             or (v == 'from' and p and p.type == syms.import_from)
1754         )
1755     )
1756
1757
1758 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1759     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1760     else.
1761
1762     Note: don't use backslashes for formatting or you'll lose your voting rights.
1763     """
1764     if not inside_brackets:
1765         spl = leaf.prefix.split('#')
1766         if '\\' not in spl[0]:
1767             nl_count = spl[-1].count('\n')
1768             if len(spl) > 1:
1769                 nl_count -= 1
1770             leaf.prefix = '\n' * nl_count
1771             return
1772
1773     leaf.prefix = ''
1774
1775
1776 def is_python36(node: Node) -> bool:
1777     """Return True if the current file is using Python 3.6+ features.
1778
1779     Currently looking for:
1780     - f-strings; and
1781     - trailing commas after * or ** in function signatures.
1782     """
1783     for n in node.pre_order():
1784         if n.type == token.STRING:
1785             value_head = n.value[:2]  # type: ignore
1786             if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
1787                 return True
1788
1789         elif (
1790             n.type == syms.typedargslist
1791             and n.children
1792             and n.children[-1].type == token.COMMA
1793         ):
1794             for ch in n.children:
1795                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1796                     return True
1797
1798     return False
1799
1800
1801 PYTHON_EXTENSIONS = {'.py'}
1802 BLACKLISTED_DIRECTORIES = {
1803     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
1804 }
1805
1806
1807 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1808     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
1809     and have one of the PYTHON_EXTENSIONS.
1810     """
1811     for child in path.iterdir():
1812         if child.is_dir():
1813             if child.name in BLACKLISTED_DIRECTORIES:
1814                 continue
1815
1816             yield from gen_python_files_in_dir(child)
1817
1818         elif child.suffix in PYTHON_EXTENSIONS:
1819             yield child
1820
1821
1822 @dataclass
1823 class Report:
1824     """Provides a reformatting counter. Can be rendered with `str(report)`."""
1825     check: bool = False
1826     change_count: int = 0
1827     same_count: int = 0
1828     failure_count: int = 0
1829
1830     def done(self, src: Path, changed: bool) -> None:
1831         """Increment the counter for successful reformatting. Write out a message."""
1832         if changed:
1833             reformatted = 'would reformat' if self.check else 'reformatted'
1834             out(f'{reformatted} {src}')
1835             self.change_count += 1
1836         else:
1837             out(f'{src} already well formatted, good job.', bold=False)
1838             self.same_count += 1
1839
1840     def failed(self, src: Path, message: str) -> None:
1841         """Increment the counter for failed reformatting. Write out a message."""
1842         err(f'error: cannot format {src}: {message}')
1843         self.failure_count += 1
1844
1845     @property
1846     def return_code(self) -> int:
1847         """Return the exit code that the app should use.
1848
1849         This considers the current state of changed files and failures:
1850         - if there were any failures, return 123;
1851         - if any files were changed and --check is being used, return 1;
1852         - otherwise return 0.
1853         """
1854         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1855         # 126 we have special returncodes reserved by the shell.
1856         if self.failure_count:
1857             return 123
1858
1859         elif self.change_count and self.check:
1860             return 1
1861
1862         return 0
1863
1864     def __str__(self) -> str:
1865         """Render a color report of the current state.
1866
1867         Use `click.unstyle` to remove colors.
1868         """
1869         if self.check:
1870             reformatted = "would be reformatted"
1871             unchanged = "would be left unchanged"
1872             failed = "would fail to reformat"
1873         else:
1874             reformatted = "reformatted"
1875             unchanged = "left unchanged"
1876             failed = "failed to reformat"
1877         report = []
1878         if self.change_count:
1879             s = 's' if self.change_count > 1 else ''
1880             report.append(
1881                 click.style(f'{self.change_count} file{s} {reformatted}', bold=True)
1882             )
1883         if self.same_count:
1884             s = 's' if self.same_count > 1 else ''
1885             report.append(f'{self.same_count} file{s} {unchanged}')
1886         if self.failure_count:
1887             s = 's' if self.failure_count > 1 else ''
1888             report.append(
1889                 click.style(f'{self.failure_count} file{s} {failed}', fg='red')
1890             )
1891         return ', '.join(report) + '.'
1892
1893
1894 def assert_equivalent(src: str, dst: str) -> None:
1895     """Raise AssertionError if `src` and `dst` aren't equivalent."""
1896
1897     import ast
1898     import traceback
1899
1900     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1901         """Simple visitor generating strings to compare ASTs by content."""
1902         yield f"{'  ' * depth}{node.__class__.__name__}("
1903
1904         for field in sorted(node._fields):
1905             try:
1906                 value = getattr(node, field)
1907             except AttributeError:
1908                 continue
1909
1910             yield f"{'  ' * (depth+1)}{field}="
1911
1912             if isinstance(value, list):
1913                 for item in value:
1914                     if isinstance(item, ast.AST):
1915                         yield from _v(item, depth + 2)
1916
1917             elif isinstance(value, ast.AST):
1918                 yield from _v(value, depth + 2)
1919
1920             else:
1921                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
1922
1923         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
1924
1925     try:
1926         src_ast = ast.parse(src)
1927     except Exception as exc:
1928         major, minor = sys.version_info[:2]
1929         raise AssertionError(
1930             f"cannot use --safe with this file; failed to parse source file "
1931             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
1932             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
1933         )
1934
1935     try:
1936         dst_ast = ast.parse(dst)
1937     except Exception as exc:
1938         log = dump_to_file(''.join(traceback.format_tb(exc.__traceback__)), dst)
1939         raise AssertionError(
1940             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
1941             f"Please report a bug on https://github.com/ambv/black/issues.  "
1942             f"This invalid output might be helpful: {log}"
1943         ) from None
1944
1945     src_ast_str = '\n'.join(_v(src_ast))
1946     dst_ast_str = '\n'.join(_v(dst_ast))
1947     if src_ast_str != dst_ast_str:
1948         log = dump_to_file(diff(src_ast_str, dst_ast_str, 'src', 'dst'))
1949         raise AssertionError(
1950             f"INTERNAL ERROR: Black produced code that is not equivalent to "
1951             f"the source.  "
1952             f"Please report a bug on https://github.com/ambv/black/issues.  "
1953             f"This diff might be helpful: {log}"
1954         ) from None
1955
1956
1957 def assert_stable(src: str, dst: str, line_length: int) -> None:
1958     """Raise AssertionError if `dst` reformats differently the second time."""
1959     newdst = format_str(dst, line_length=line_length)
1960     if dst != newdst:
1961         log = dump_to_file(
1962             diff(src, dst, 'source', 'first pass'),
1963             diff(dst, newdst, 'first pass', 'second pass'),
1964         )
1965         raise AssertionError(
1966             f"INTERNAL ERROR: Black produced different code on the second pass "
1967             f"of the formatter.  "
1968             f"Please report a bug on https://github.com/ambv/black/issues.  "
1969             f"This diff might be helpful: {log}"
1970         ) from None
1971
1972
1973 def dump_to_file(*output: str) -> str:
1974     """Dump `output` to a temporary file. Return path to the file."""
1975     import tempfile
1976
1977     with tempfile.NamedTemporaryFile(
1978         mode='w', prefix='blk_', suffix='.log', delete=False
1979     ) as f:
1980         for lines in output:
1981             f.write(lines)
1982             f.write('\n')
1983     return f.name
1984
1985
1986 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
1987     """Return a unified diff string between strings `a` and `b`."""
1988     import difflib
1989
1990     a_lines = [line + '\n' for line in a.split('\n')]
1991     b_lines = [line + '\n' for line in b.split('\n')]
1992     return ''.join(
1993         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
1994     )
1995
1996
1997 def cancel(tasks: List[asyncio.Task]) -> None:
1998     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
1999     err("Aborted!")
2000     for task in tasks:
2001         task.cancel()
2002
2003
2004 def shutdown(loop: BaseEventLoop) -> None:
2005     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2006     try:
2007         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2008         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2009         if not to_cancel:
2010             return
2011
2012         for task in to_cancel:
2013             task.cancel()
2014         loop.run_until_complete(
2015             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2016         )
2017     finally:
2018         # `concurrent.futures.Future` objects cannot be cancelled once they
2019         # are already running. There might be some when the `shutdown()` happened.
2020         # Silence their logger's spew about the event loop being closed.
2021         cf_logger = logging.getLogger("concurrent.futures")
2022         cf_logger.setLevel(logging.CRITICAL)
2023         loop.close()
2024
2025
2026 if __name__ == '__main__':
2027     main()