git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Normalize string quotes (#75)
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial, wraps
7 import keyword
8 import logging
9 import os
10 from pathlib import Path
11 import tokenize
12 import signal
13 import sys
14 from typing import (
15     Callable,
16     Dict,
17     Generic,
18     Iterable,
19     Iterator,
20     List,
21     Optional,
22     Set,
23     Tuple,
24     Type,
25     TypeVar,
26     Union,
27 )
28
29 from attr import dataclass, Factory
30 import click
31
32 # lib2to3 fork
33 from blib2to3.pytree import Node, Leaf, type_repr
34 from blib2to3 import pygram, pytree
35 from blib2to3.pgen2 import driver, token
36 from blib2to3.pgen2.parse import ParseError
37
__version__ = "18.3a4"
# Default maximum number of characters per line; used as the default of the
# `--line-length` option of `main`.
DEFAULT_LINE_LENGTH = 88
# types
# Grammar symbol numbers of lib2to3 nodes (e.g. `syms.file_input`).
syms = pygram.python_symbols
# The aliases below only make signatures self-documenting; no runtime checks.
FileContent = str
Encoding = str
Depth = int
NodeType = int
LeafID = int  # the `id()` of a Leaf, as used in `BracketTracker.delimiters`
Priority = int
Index = int
LN = Union[Leaf, Node]
SplitFunc = Callable[["Line", bool], Iterator["Line"]]
# `out` prints bold messages to stderr; `err` prints red messages to stderr.
out = partial(click.secho, bold=True, err=True)
err = partial(click.secho, fg="red", err=True)
53
54
class NothingChanged(UserWarning):
    """Raised when reformatted code is the same as source.

    Raised by :func:`format_file_contents` both when the source is empty (or
    whitespace-only) and when formatting produces identical output.  Callers
    such as :func:`format_file_in_place` catch it to report "unchanged".
    """
57
58
class CannotSplit(Exception):
    """Signal that no readable split fitting the allotted line length exists.

    The line-splitting functions :func:`left_hand_split`,
    :func:`right_hand_split` and :func:`delimiter_split` raise this when they
    cannot produce such a split.
    """
65
66
class FormatError(Exception):
    """Common base for the `# fmt: on` / `# fmt: off` control-comment signals.

    `consumed` is the number of characters at the start of a leaf's prefix
    that had been consumed before the format control comment appeared.
    """

    def __init__(self, consumed: int) -> None:
        super().__init__(consumed)
        self.consumed = consumed

    def trim_prefix(self, leaf: Leaf) -> None:
        """Drop the consumed part from the front of `leaf.prefix`, in place."""
        remaining = leaf.prefix[self.consumed:]
        leaf.prefix = remaining

    def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
        """Return a new NEWLINE Leaf holding the consumed part of the prefix."""
        consumed_text = leaf.prefix[:self.consumed]
        return Leaf(token.NEWLINE, consumed_text)
85
86
class FormatOn(FormatError):
    """Signals that a comment like `# fmt: on` was found in the file."""
89
90
class FormatOff(FormatError):
    """Signals that a comment like `# fmt: off` was found in the file."""
93
94
# Command-line entry point.  NOTE: the docstring of `main` doubles as the
# `--help` text shown by click, so it is runtime output.
@click.command()
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many character per line to allow.",
    show_default=True,
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write back the files, just return the status.  Return code 0 "
        "means nothing would change.  Return code 1 means some files would be "
        "reformatted.  Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.version_option(version=__version__)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
)
@click.pass_context
def main(
    ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
) -> None:
    """The uncompromising code formatter."""
    # Expand the command-line arguments into concrete sources.  Directories
    # are walked by `gen_python_files_in_dir`; an invalid path is reported
    # but does not abort the run.
    sources: List[Path] = []
    for s in src:
        p = Path(s)
        if p.is_dir():
            sources.extend(gen_python_files_in_dir(p))
        elif p.is_file():
            # if a file was explicitly given, we don't care about its extension
            sources.append(p)
        elif s == "-":
            # "-" stands for stdin/stdout.
            sources.append(Path("-"))
        else:
            err(f"invalid path: {s}")
    if len(sources) == 0:
        ctx.exit(0)
    elif len(sources) == 1:
        # A single source is formatted in this process, without an executor.
        p = sources[0]
        report = Report(check=check)
        try:
            # In --check mode nothing is written back (write_back=not check).
            if not p.is_file() and str(p) == "-":
                changed = format_stdin_to_stdout(
                    line_length=line_length, fast=fast, write_back=not check
                )
            else:
                changed = format_file_in_place(
                    p, line_length=line_length, fast=fast, write_back=not check
                )
            report.done(p, changed)
        except Exception as exc:
            report.failed(p, str(exc))
        ctx.exit(report.return_code)
    else:
        # Several sources: format them in parallel worker processes.
        loop = asyncio.get_event_loop()
        executor = ProcessPoolExecutor(max_workers=os.cpu_count())
        return_code = 1
        try:
            return_code = loop.run_until_complete(
                schedule_formatting(
                    sources, line_length, not check, fast, loop, executor
                )
            )
        finally:
            # NOTE(review): ctx.exit() raises click's exit exception from
            # inside `finally`, which replaces any exception raised by
            # run_until_complete; return_code is still 1 in that case.
            shutdown(loop)
            ctx.exit(return_code)
174
175
async def schedule_formatting(
    sources: List[Path],
    line_length: int,
    write_back: bool,
    fast: bool,
    loop: BaseEventLoop,
    executor: Executor,
) -> int:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `line_length`, `write_back`, and `fast` options are passed to
    :func:`format_file_in_place`.

    SIGINT and SIGTERM handlers calling `cancel()` on all tasks are installed
    on `loop`.  Prints a summary of the `Report` and returns its return code.
    """
    # One executor job per source.  The positional order matches
    # format_file_in_place(src, line_length, fast, write_back).
    tasks = {
        src: loop.run_in_executor(
            executor, format_file_in_place, src, line_length, fast, write_back
        )
        for src in sources
    }
    _task_values = list(tasks.values())
    loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
    loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
    # No timeout is given, so this returns only once every task is done.
    await asyncio.wait(tasks.values())
    cancelled = []
    report = Report(check=not write_back)
    for src, task in tasks.items():
        if not task.done():
            # Defensive: not expected after waiting for all tasks above.
            report.failed(src, "timed out, cancelling")
            task.cancel()
            cancelled.append(task)
        elif task.cancelled():
            cancelled.append(task)
        elif task.exception():
            report.failed(src, str(task.exception()))
        else:
            # The task's result is format_file_in_place's "changed" flag.
            report.done(src, task.result())
    if cancelled:
        # Let cancelled tasks finish; their exceptions are collected, not raised.
        await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    else:
        out("All done! ✨ 🍰 ✨")
    click.echo(str(report))
    return report.return_code
220
221
def format_file_in_place(
    src: Path, line_length: int, fast: bool, write_back: bool = False
) -> bool:
    """Format file under `src` path. Return True if changed.

    If `write_back` is True, write the reformatted code back to the `src`
    file, using the encoding it was read with (as detected by
    :func:`tokenize.open`).  Otherwise the file is left untouched and only the
    return value tells whether it would change.
    `line_length` and `fast` options are passed to :func:`format_file_contents`.
    """
    with tokenize.open(src) as src_buffer:
        src_contents = src_buffer.read()
        # Remember the detected source encoding so the file is written back in
        # the same encoding.
        encoding = src_buffer.encoding
    try:
        contents = format_file_contents(
            src_contents, line_length=line_length, fast=fast
        )
    except NothingChanged:
        return False

    if write_back:
        with open(src, "w", encoding=encoding) as f:
            f.write(contents)
    return True
243
244
def format_stdin_to_stdout(
    line_length: int, fast: bool, write_back: bool = False
) -> bool:
    """Format source code read from stdin. Return True if it changed.

    If `write_back` is True, the result is written to stdout.  That is the
    reformatted code, or the original input when nothing changed or when
    formatting raised (the exception still propagates).
    `line_length` and `fast` are forwarded to :func:`format_file_contents`.
    """
    src = sys.stdin.read()
    dst = src
    try:
        dst = format_file_contents(src, line_length=line_length, fast=fast)
        return True

    except NothingChanged:
        return False

    finally:
        if write_back:
            sys.stdout.write(dst)
264
265
def format_file_contents(
    src_contents: str, line_length: int, fast: bool
) -> FileContent:
    """Reformat the contents of a file and return the new contents.

    Raises :class:`NothingChanged` if `src_contents` is empty or
    whitespace-only, or if formatting would not change it.
    Unless `fast` is True, the result is first verified with
    :func:`assert_equivalent` and :func:`assert_stable`.
    `line_length` is passed to :func:`format_str`.
    """
    if not src_contents.strip():
        raise NothingChanged

    dst_contents = format_str(src_contents, line_length=line_length)
    if dst_contents == src_contents:
        raise NothingChanged

    if fast:
        return dst_contents

    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, line_length=line_length)
    return dst_contents
286
287
def format_str(src_contents: str, line_length: int) -> FileContent:
    """Reformat a string and return new contents.

    `line_length` determines how many characters per line are allowed.

    Blank lines requested by :class:`EmptyLineTracker` are emitted around each
    logical line.  The "after" blank lines requested for the very last line
    are dropped.
    """
    src_node = lib2to3_parse(src_contents)
    # Collect output pieces and join once at the end, instead of repeatedly
    # concatenating onto a growing string (quadratic in the worst case).
    dst_parts: List[str] = []
    lines = LineGenerator()
    elt = EmptyLineTracker()
    py36 = is_python36(src_node)
    # An empty Line renders as a lone newline; compute it once.
    empty_line = str(Line())
    after = 0
    for current_line in lines.visit(src_node):
        # Blank lines requested *after* the previous line are only emitted once
        # another line follows.  Negative counts produce no output.
        dst_parts.append(empty_line * after)
        before, after = elt.maybe_empty_lines(current_line)
        dst_parts.append(empty_line * before)
        for line in split_line(current_line, line_length=line_length, py36=py36):
            dst_parts.append(str(line))
    return "".join(dst_parts)
309
310
# Grammars tried in this order by `lib2to3_parse`; the first one that parses
# the source wins, so order matters.  Judging by their names, the first drops
# both the `print` and `exec` statements, and the last is the full grammar
# that keeps both.
GRAMMARS = [
    pygram.python_grammar_no_print_statement_no_exec_statement,
    pygram.python_grammar_no_print_statement,
    pygram.python_grammar_no_exec_statement,
    pygram.python_grammar,
]
317
318
def lib2to3_parse(src_txt: str) -> Node:
    """Given a string with source, return the lib2to3 Node.

    The grammars in `GRAMMARS` are tried in order, and the first one that
    parses `src_txt` wins.  If the source does not end with a newline, one is
    appended.  It is CRLF when the first 1024 characters contain one and LF
    otherwise, so empty input is parsed as a single newline.

    Raises ValueError naming the line and column reported by the last grammar
    tried if no grammar can parse the source.
    """
    if not src_txt.endswith("\n"):
        # `endswith` also handles empty input, where indexing `[-1]` would
        # raise IndexError.
        nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
        src_txt += nl
    for grammar in GRAMMARS:
        drv = driver.Driver(grammar, pytree.convert)
        try:
            result = drv.parse_string(src_txt, True)
            break

        except ParseError as pe:
            lineno, column = pe.context[1]
            lines = src_txt.splitlines()
            try:
                faulty_line = lines[lineno - 1]
            except IndexError:
                faulty_line = "<line number missing in source>"
            exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
    else:
        # No grammar succeeded: report the error from the last attempt.
        raise exc from None

    # A source that parses to a single leaf is wrapped so callers always get
    # a file_input Node.
    if isinstance(result, Leaf):
        result = Node(syms.file_input, [result])
    return result
345
346
def lib2to3_unparse(node: Node) -> str:
    """Render a lib2to3 node back into source text via its `str()`."""
    return str(node)
351
352
# Type of the values yielded by `Visitor.visit()` and its `visit_*()` methods.
T = TypeVar("T")
354
355
class Visitor(Generic[T]):
    """Basic lib2to3 visitor that yields things of type `T` on `visit()`.

    Subclasses add `visit_<NAME>()` methods.  NAME is a token name for leaves
    (e.g. `visit_INDENT`) or a grammar symbol name for nodes (e.g.
    `visit_simple_stmt`).
    """

    def visit(self, node: LN) -> Iterator[T]:
        """Dispatch `node` to its `visit_*()` method and yield what it yields.

        Type numbers below 256 are tokens, named via `token.tok_name`; anything
        else is a grammar symbol, named via `type_repr()`.  When no matching
        method exists, :meth:`visit_default` is used instead.
        """
        node_type = node.type
        if node_type >= 256:
            name = type_repr(node_type)
        else:
            name = token.tok_name[node_type]
        handler = getattr(self, f"visit_{name}", self.visit_default)
        yield from handler(node)

    def visit_default(self, node: LN) -> Iterator[T]:
        """Fallback visitor: recurse into a `Node`'s children; leaves yield nothing."""
        if not isinstance(node, Node):
            return
        for child in node.children:
            yield from self.visit(child)
380
381
@dataclass
class DebugVisitor(Visitor[T]):
    """Visitor that prints a lib2to3 tree through `out()` while walking it.

    Nodes are printed in yellow as an opening/closing pair around their
    children.  Leaves are printed in blue with their value and, if present,
    their prefix in green.  Each nesting level indents by two spaces.
    """

    tree_depth: int = 0

    def visit_default(self, node: LN) -> Iterator[T]:
        indent = " " * (2 * self.tree_depth)
        if isinstance(node, Node):
            node_name = type_repr(node.type)
            out(f"{indent}{node_name}", fg="yellow")
            self.tree_depth += 1
            for child in node.children:
                yield from self.visit(child)

            self.tree_depth -= 1
            out(f"{indent}/{node_name}", fg="yellow", bold=False)
            return

        leaf_name = token.tok_name.get(node.type, str(node.type))
        out(f"{indent}{leaf_name}", fg="blue", nl=False)
        if node.prefix:
            # We don't have to handle prefixes for `Node` objects since
            # that delegates to the first child anyway.
            out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
        out(f" {node.value!r}", fg="blue", bold=False)

    @classmethod
    def show(cls, code: str) -> None:
        """Parse `code` and print its lib2to3 tree; a convenience for debugging."""
        visitor: DebugVisitor[None] = DebugVisitor()
        for _ in visitor.visit(lib2to3_parse(code)):
            pass
414
415
KEYWORDS = set(keyword.kwlist)
WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
# Statements checked by `Line.is_flow_control`.
FLOW_CONTROL = {"return", "raise", "break", "continue"}
STATEMENT = {
    syms.if_stmt,
    syms.while_stmt,
    syms.for_stmt,
    syms.try_stmt,
    syms.except_clause,
    syms.with_stmt,
    syms.funcdef,
    syms.classdef,
}
# Synthetic leaf type for comments that sit on their own line
# (`Line.append_comment` assigns it).  NOTE(review): 153 is presumably chosen
# not to collide with real token numbers (< 256); confirm against the grammar.
STANDALONE_COMMENT = 153
LOGIC_OPERATORS = {"and", "or"}
COMPARATORS = {
    token.LESS,
    token.GREATER,
    token.EQEQUAL,
    token.NOTEQUAL,
    token.LESSEQUAL,
    token.GREATEREQUAL,
}
MATH_OPERATORS = {
    token.PLUS,
    token.MINUS,
    token.STAR,
    token.SLASH,
    token.VBAR,
    token.AMPER,
    token.PERCENT,
    token.CIRCUMFLEX,
    token.TILDE,
    token.LEFTSHIFT,
    token.RIGHTSHIFT,
    token.DOUBLESTAR,
    token.DOUBLESLASH,
}
# Delimiter priorities stored in `BracketTracker.delimiters`; a higher value
# wins in `BracketTracker.max_delimiter_priority`.  COMMA/COMPARATOR/MATH are
# presumably produced by `is_delimiter` (not shown here).
COMPREHENSION_PRIORITY = 20
COMMA_PRIORITY = 10
LOGIC_PRIORITY = 5
STRING_PRIORITY = 4
COMPARATOR_PRIORITY = 3
MATH_PRIORITY = 1
460
461
@dataclass
class BracketTracker:
    """Keeps track of brackets on a line."""

    # Current bracket nesting depth.  `Line` also raises it temporarily while
    # appending for-loop variables (between `for` and `in`).
    depth: int = 0
    # Unmatched opening brackets, keyed by (depth, type of the closing bracket
    # that will match them).
    bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
    # Maps the id() of each depth-0 delimiter leaf to its priority.
    delimiters: Dict[LeafID, Priority] = Factory(dict)
    # The last non-comment leaf passed to `mark()`.
    previous: Optional[Leaf] = None

    def mark(self, leaf: Leaf) -> None:
        """Mark `leaf` with bracket-related metadata. Keep track of delimiters.

        All leaves receive an int `bracket_depth` field that stores how deep
        within brackets a given leaf is. 0 means there are no enclosing brackets
        that started on this line.

        If a leaf is itself a closing bracket, it receives an `opening_bracket`
        field that it forms a pair with. This is a one-directional link to
        avoid reference cycles.

        If a leaf is a delimiter (a token on which Black can split the line if
        needed) and it's on depth 0, its `id()` is stored in the tracker's
        `delimiters` field.

        Comment leaves are ignored entirely.  Raises KeyError (from
        `bracket_match.pop`) for a closing bracket with no matching opening
        bracket recorded at the same depth.
        """
        if leaf.type == token.COMMENT:
            return

        if leaf.type in CLOSING_BRACKETS:
            # A closing bracket gets the same depth as its opening bracket.
            self.depth -= 1
            opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
            leaf.opening_bracket = opening_bracket
        leaf.bracket_depth = self.depth
        if self.depth == 0:
            delim = is_delimiter(leaf)
            if delim:
                self.delimiters[id(leaf)] = delim
            elif self.previous is not None:
                # The priorities below are recorded on the *preceding* leaf,
                # not on `leaf` itself.
                if leaf.type == token.STRING and self.previous.type == token.STRING:
                    # Adjacent string literals (implicit concatenation).
                    self.delimiters[id(self.previous)] = STRING_PRIORITY
                elif (
                    leaf.type == token.NAME
                    and leaf.value == "for"
                    and leaf.parent
                    and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
                ):
                    # The `for` of a comprehension.
                    self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
                elif (
                    leaf.type == token.NAME
                    and leaf.value == "if"
                    and leaf.parent
                    and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
                ):
                    # The `if` filter of a comprehension.
                    self.delimiters[id(self.previous)] = COMPREHENSION_PRIORITY
                elif (
                    leaf.type == token.NAME
                    and leaf.value in LOGIC_OPERATORS
                    and leaf.parent
                ):
                    # Boolean `and` / `or`.
                    self.delimiters[id(self.previous)] = LOGIC_PRIORITY
        if leaf.type in OPENING_BRACKETS:
            self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
            self.depth += 1
        self.previous = leaf

    def any_open_brackets(self) -> bool:
        """Return True if there is a yet unmatched open bracket on the line."""
        return bool(self.bracket_match)

    def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
        """Return the highest priority of a delimiter found on the line.

        Values are consistent with what `is_delimiter()` returns.
        Delimiters whose leaf id is in `exclude` are skipped.  Raises
        ValueError (from `max()`) when no delimiters remain.
        """
        return max(v for k, v in self.delimiters.items() if k not in exclude)
536
537
@dataclass
class Line:
    """Holds leaves and comments. Can be printed with `str(line)`."""

    # Indentation level; `__str__` renders four spaces per level.
    depth: int = 0
    leaves: List[Leaf] = Factory(list)
    # Inline comments as (index of the leaf they follow, comment leaf) pairs;
    # see `append_comment()`.
    comments: List[Tuple[Index, Leaf]] = Factory(list)
    bracket_tracker: BracketTracker = Factory(BracketTracker)
    # When True, bracket tracking also runs for preformatted leaves (see
    # `append()`).
    inside_brackets: bool = False
    # Set once a `for` leaf went through `maybe_increment_for_loop_variable()`.
    has_for: bool = False
    # True between a `for` and the following `in`, while the tracker's depth
    # is temporarily raised.
    _for_loop_variable: bool = False

    def append(self, leaf: Leaf, preformatted: bool = False) -> None:
        """Add a new `leaf` to the end of the line.

        Leaves whose value is empty or whitespace-only are ignored.

        Unless `preformatted` is True, the `leaf` receives a new consistent
        whitespace prefix (the first leaf on the line is left alone).
        :class:`BracketTracker` metadata is applied unless `preformatted` is
        True and the line is not `inside_brackets`.  In that same step,
        trailing commas may be removed and unpacked for loop variables are
        demoted from being delimiters.

        Inline comments are put aside (see :meth:`append_comment`).
        """
        has_value = leaf.value.strip()
        if not has_value:
            return

        if self.leaves and not preformatted:
            # Note: at this point leaf.prefix should be empty except for
            # imports, for which we only preserve newlines.
            leaf.prefix += whitespace(leaf)
        if self.inside_brackets or not preformatted:
            # Order matters: `in` restores the depth before being marked, and
            # `for` raises the depth only after being marked itself.
            self.maybe_decrement_after_for_loop_variable(leaf)
            self.bracket_tracker.mark(leaf)
            self.maybe_remove_trailing_comma(leaf)
            self.maybe_increment_for_loop_variable(leaf)

        if not self.append_comment(leaf):
            self.leaves.append(leaf)

    def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
        """Like :func:`append()` but disallow invalid standalone comment structure.

        Raises ValueError when any `leaf` is appended after a standalone comment
        or when a standalone comment is not the first leaf on the line.
        """
        # Only enforced outside brackets; inside brackets anything goes.
        if self.bracket_tracker.depth == 0:
            if self.is_comment:
                raise ValueError("cannot append to standalone comments")

            if self.leaves and leaf.type == STANDALONE_COMMENT:
                raise ValueError(
                    "cannot append standalone comments to a populated line"
                )

        self.append(leaf, preformatted=preformatted)

    @property
    def is_comment(self) -> bool:
        """Is this line a standalone comment?"""
        return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT

    @property
    def is_decorator(self) -> bool:
        """Is this line a decorator?"""
        return bool(self) and self.leaves[0].type == token.AT

    @property
    def is_import(self) -> bool:
        """Is this an import line?"""
        return bool(self) and is_import(self.leaves[0])

    @property
    def is_class(self) -> bool:
        """Is this line a class definition?"""
        return (
            bool(self)
            and self.leaves[0].type == token.NAME
            and self.leaves[0].value == "class"
        )

    @property
    def is_def(self) -> bool:
        """Is this a function definition? (Also returns True for async defs.)"""
        try:
            first_leaf = self.leaves[0]
        except IndexError:
            return False

        try:
            second_leaf: Optional[Leaf] = self.leaves[1]
        except IndexError:
            second_leaf = None
        return (
            (first_leaf.type == token.NAME and first_leaf.value == "def")
            or (
                first_leaf.type == token.ASYNC
                and second_leaf is not None
                and second_leaf.type == token.NAME
                and second_leaf.value == "def"
            )
        )

    @property
    def is_flow_control(self) -> bool:
        """Is this line a flow control statement?

        Those are `return`, `raise`, `break`, and `continue`.
        """
        return (
            bool(self)
            and self.leaves[0].type == token.NAME
            and self.leaves[0].value in FLOW_CONTROL
        )

    @property
    def is_yield(self) -> bool:
        """Is this line a yield statement?"""
        return (
            bool(self)
            and self.leaves[0].type == token.NAME
            and self.leaves[0].value == "yield"
        )

    @property
    def contains_standalone_comments(self) -> bool:
        """If so, needs to be split before emitting."""
        for leaf in self.leaves:
            if leaf.type == STANDALONE_COMMENT:
                return True

        return False

    def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
        """Remove trailing comma if there is one and it's safe.

        `closing` is the leaf about to be appended (it is not yet in
        `leaves`); only closing brackets qualify.  Returns True if the comma
        was removed.
        """
        if not (
            self.leaves
            and self.leaves[-1].type == token.COMMA
            and closing.type in CLOSING_BRACKETS
        ):
            return False

        if closing.type == token.RBRACE:
            # Inside braces the trailing comma is always removed.
            self.remove_trailing_comma()
            return True

        if closing.type == token.RSQB:
            comma = self.leaves[-1]
            if comma.parent and comma.parent.type == syms.listmaker:
                # List display: safe to drop.
                self.remove_trailing_comma()
                return True
            # Otherwise (e.g. a subscript), fall through to the paren logic.

        # For parens let's check if it's safe to remove the comma.  If the
        # trailing one is the only one, we might mistakenly change a tuple
        # into a different type by removing the comma.
        depth = closing.bracket_depth + 1
        commas = 0
        opening = closing.opening_bracket
        for _opening_index, leaf in enumerate(self.leaves):
            if leaf is opening:
                break

        else:
            # The opening bracket is not on this line; leave the comma alone.
            return False

        for leaf in self.leaves[_opening_index + 1:]:
            if leaf is closing:
                break

            bracket_depth = leaf.bracket_depth
            if bracket_depth == depth and leaf.type == token.COMMA:
                commas += 1
                if leaf.parent and leaf.parent.type == syms.arglist:
                    # A comma in an argument list suffices on its own: it
                    # counts twice, which guarantees removal below.
                    commas += 1
                    break

        if commas > 1:
            self.remove_trailing_comma()
            return True

        return False

    def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
        """In a for loop, or comprehension, the variables are often unpacks.

        To avoid splitting on the comma in this situation, increase the depth of
        tokens between `for` and `in`.
        """
        if leaf.type == token.NAME and leaf.value == "for":
            self.has_for = True
            # Commas of the loop variables now sit at depth > 0, so the
            # tracker does not record them as depth-0 delimiters.
            self.bracket_tracker.depth += 1
            self._for_loop_variable = True
            return True

        return False

    def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
        """See `maybe_increment_for_loop_variable` above for explanation."""
        if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
            self.bracket_tracker.depth -= 1
            self._for_loop_variable = False
            return True

        return False

    def append_comment(self, comment: Leaf) -> bool:
        """Add an inline or standalone comment to the line.

        Returns True if `comment` was stored in `comments`.  False means the
        caller should append it as a regular leaf (it may not be a comment at
        all).
        """
        if (
            comment.type == STANDALONE_COMMENT
            and self.bracket_tracker.any_open_brackets()
        ):
            # Inside open brackets a standalone comment stays a regular leaf.
            comment.prefix = ""
            return False

        if comment.type != token.COMMENT:
            return False

        after = len(self.leaves) - 1
        if after == -1:
            # A comment on an otherwise empty line becomes a standalone one.
            comment.type = STANDALONE_COMMENT
            comment.prefix = ""
            return False

        else:
            # Attach the comment to the last leaf currently on the line.
            self.comments.append((after, comment))
            return True

    def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
        """Generate comments that should appear directly after `leaf`."""
        for _leaf_index, _leaf in enumerate(self.leaves):
            if leaf is _leaf:
                break

        else:
            # `leaf` is not on this line.
            return

        for index, comment_after in self.comments:
            if _leaf_index == index:
                yield comment_after

    def remove_trailing_comma(self) -> None:
        """Remove the trailing comma and move its comments to the previous leaf."""
        comma_index = len(self.leaves) - 1
        for i in range(len(self.comments)):
            comment_index, comment = self.comments[i]
            if comment_index == comma_index:
                self.comments[i] = (comma_index - 1, comment)
        self.leaves.pop()

    def __str__(self) -> str:
        """Render the line."""
        if not self:
            return "\n"

        # The first leaf's prefix (e.g. preserved newlines) goes before the
        # indentation; inline comments come after all leaves.
        indent = "    " * self.depth
        leaves = iter(self.leaves)
        first = next(leaves)
        res = f"{first.prefix}{indent}{first.value}"
        for leaf in leaves:
            res += str(leaf)
        for _, comment in self.comments:
            res += str(comment)
        return res + "\n"

    def __bool__(self) -> bool:
        """Return True if the line has leaves or comments."""
        return bool(self.leaves or self.comments)
804
805
class UnformattedLines(Line):
    """A :class:`Line` variant that stores lines which aren't reformatted.

    Leaves are kept as-is and rendered verbatim.  `append()` stops with
    :class:`FormatOn` when the leaf carries a `# fmt: on` style comment.
    """

    def append(self, leaf: Leaf, preformatted: bool = True) -> None:
        """Store `leaf` verbatim; the `preformatted` argument is ignored.

        `depth` follows INDENT/DEDENT leaves, which is useful once the user
        says `# fmt: on`.  If scanning the leaf's comments raises
        :class:`FormatOn`, the consumed part of the prefix is stored first as
        its own leaf, and then the exception is re-raised.
        """
        try:
            for _ in generate_comments(leaf):
                pass
        except FormatOn as f_on:
            self.leaves.append(f_on.leaf_from_consumed(leaf))
            raise

        self.leaves.append(leaf)
        depth_change = {token.INDENT: 1, token.DEDENT: -1}
        self.depth += depth_change.get(leaf.type, 0)

    def __str__(self) -> str:
        """Concatenate the stored leaves exactly as they were appended.

        `depth` is not used for indentation in this case.
        """
        if not self:
            return "\n"

        return "".join(str(leaf) for leaf in self.leaves)

    def append_comment(self, comment: Leaf) -> bool:
        """Not supported here: always raises `NotImplementedError`."""
        raise NotImplementedError("Unformatted lines don't store comments separately.")

    def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
        """No-op for unformatted lines; always returns False."""
        return False

    def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
        """No-op for unformatted lines; always returns False."""
        return False
853
854
@dataclass
class EmptyLineTracker:
    """Decide how many empty lines go around each line as it is processed.

    The tracker is stateful. It remembers:
    - the previously processed line;
    - how many empty lines were requested after that line;
    - the depths of the `def`/`class` lines whose bodies are still open.

    Note: this tracker works on lines that haven't been split yet.  It assumes
    the prefix of the first leaf consists of optional newlines.  Those newlines
    are consumed by `maybe_empty_lines()` and included in the computation.
    """
    previous_line: Optional[Line] = None
    previous_after: int = 0
    previous_defs: List[int] = Factory(list)

    def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
        """Return `(before, after)`: extra empty lines around `current_line`.

        Definitions (`def`, `async def`, `class`) are separated by extra empty
        lines: two at module level, one when nested.  Flow control lines
        request one empty line after them.  `UnformattedLines` never get any.
        """
        if isinstance(current_line, UnformattedLines):
            return 0, 0

        before, after = self._maybe_empty_lines(current_line)
        # Empty lines already requested after the previous line count towards
        # the ones wanted before this one.
        adjusted_before = before - self.previous_after
        self.previous_after = after
        self.previous_line = current_line
        return adjusted_before, after

    def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
        depth = current_line.depth
        max_allowed = 1 if depth else 2

        before = 0
        if current_line.leaves:
            # Consume the newlines the user put before the first leaf, capped.
            first_leaf = current_line.leaves[0]
            before = min(first_leaf.prefix.count("\n"), max_allowed)
            first_leaf.prefix = ""

        # Leaving the body of one or more definitions at this depth or deeper.
        while self.previous_defs and self.previous_defs[-1] >= depth:
            self.previous_defs.pop()
            before = 1 if depth else 2

        is_decorator = current_line.is_decorator
        if is_decorator or current_line.is_def or current_line.is_class:
            if not is_decorator:
                self.previous_defs.append(depth)
            previous = self.previous_line
            if previous is None or (previous and previous.is_decorator):
                # Nothing before the first line in the file, and nothing
                # between a decorator and the line that follows it.
                return 0, 0

            return (1 if depth else 2), 0

        if current_line.is_flow_control:
            return before, 1

        previous = self.previous_line
        if previous and (
            (
                previous.is_import
                and not current_line.is_import
                and depth == previous.depth
            )
            or (
                previous.is_yield
                and (not current_line.is_yield or depth != previous.depth)
            )
        ):
            # Set a run of imports (or yields) apart from what follows it.
            return (before or 1), 0

        return before, 0
936
937
@dataclass
class LineGenerator(Visitor[Line]):
    """Generates reformatted Line objects.  Empty lines are not emitted.

    Walks the blib2to3 tree and accumulates leaves into `current_line`.
    The line is emitted, and a fresh one started, at statement boundaries:
    - new statements;
    - block keywords (`if`/`else`/`def`/...);
    - INDENT/DEDENT, semicolons, decorators;
    - standalone and trailing comments.

    After a `# fmt: off` (or `# yapf: disable`) comment, `current_line` is
    switched to `UnformattedLines` and leaves are collected verbatim until
    formatting is turned back on.

    Note: destroys the tree it's visiting by mutating prefixes of its leaves
    in ways that will no longer stringify to valid Python code on the tree.
    """
    # The line currently being accumulated; handed out and replaced by `line()`.
    current_line: Line = Factory(Line)

    def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
        """Emit the current line if it is non-empty, then start a fresh one.

        The new current line is an instance of `type`.  Its depth is the old
        line's depth plus `indent` (+1 / -1 for INDENT / DEDENT).

        An empty current line is never emitted:
        - if its class is exactly `type`, it is reused and only its depth is
          adjusted;
        - otherwise it is replaced by a new `type` instance.
        """
        if not self.current_line:
            # Exact class check on purpose: switching between `Line` and
            # `UnformattedLines` must replace the object even when empty.
            if self.current_line.__class__ == type:
                self.current_line.depth += indent
            else:
                self.current_line = type(depth=self.current_line.depth + indent)
            return  # Line is empty, don't emit. Creating a new one unnecessary.

        complete_line = self.current_line
        self.current_line = type(depth=complete_line.depth + indent)
        yield complete_line

    def visit(self, node: LN) -> Iterator[Line]:
        """Main method to visit `node` and its children.

        While formatting is on, this defers to the base `Visitor.visit()`.
        Presumably that dispatches to the matching `visit_*()` method; the
        base class is not shown here.  Once `current_line` is an
        `UnformattedLines` (after `# fmt: off`), every node goes through
        `visit_unformatted()` instead.

        Yields :class:`Line` objects.
        """
        if isinstance(self.current_line, UnformattedLines):
            # File contained `# fmt: off`
            yield from self.visit_unformatted(node)

        else:
            yield from super().visit(node)

    def visit_default(self, node: LN) -> Iterator[Line]:
        """Default `visit_*()` implementation. Recurses to children of `node`.

        For a leaf, first turns comments found in its prefix into comment
        leaves on the current line, emitting lines as needed.  Then:
        - its prefix is normalized;
        - string quotes are normalized for STRING leaves;
        - it is appended to the current line unless its type is in WHITESPACE.
        """
        if isinstance(node, Leaf):
            # Computed once, before any comment is appended below.
            any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
            try:
                for comment in generate_comments(node):
                    if any_open_brackets:
                        # any comment within brackets is subject to splitting
                        self.current_line.append(comment)
                    elif comment.type == token.COMMENT:
                        # regular trailing comment
                        self.current_line.append(comment)
                        yield from self.line()

                    else:
                        # regular standalone comment
                        yield from self.line()

                        self.current_line.append(comment)
                        yield from self.line()

            except FormatOff as f_off:
                # `# fmt: off` seen: drop the consumed part of the prefix
                # (presumably what `trim_prefix` does; defined elsewhere), switch
                # to an `UnformattedLines` current line and re-visit this leaf,
                # which now routes to `visit_unformatted()`.
                f_off.trim_prefix(node)
                yield from self.line(type=UnformattedLines)
                yield from self.visit(node)

            except FormatOn as f_on:
                # This only happens here if somebody says "fmt: on" multiple
                # times in a row.
                f_on.trim_prefix(node)
                yield from self.visit_default(node)

            else:
                normalize_prefix(node, inside_brackets=any_open_brackets)
                if node.type == token.STRING:
                    normalize_string_quotes(node)
                if node.type not in WHITESPACE:
                    self.current_line.append(node)
        yield from super().visit_default(node)

    def visit_INDENT(self, node: Node) -> Iterator[Line]:
        """Increase indentation level, maybe yield a line."""
        # In blib2to3 INDENT never holds comments.
        yield from self.line(+1)
        yield from self.visit_default(node)

    def visit_DEDENT(self, node: Node) -> Iterator[Line]:
        """Decrease indentation level, maybe yield a line."""
        # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
        yield from self.line(-1)

    def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
        """Visit a statement.

        This implementation is shared for `if`, `while`, `for`, `try`, `except`,
        `def`, `with`, and `class` (bound via `partial` in `__attrs_post_init__`).

        The relevant Python language `keywords` for a given statement will be NAME
        leaves within it. This methods puts those on a separate line.
        """
        for child in node.children:
            if child.type == token.NAME and child.value in keywords:  # type: ignore
                yield from self.line()

            yield from self.visit(child)

    def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
        """Visit a statement without nested statements."""
        # A simple statement directly under a compound statement (e.g. the
        # body in `if x: pass`) is placed on its own line, one level deeper.
        is_suite_like = node.parent and node.parent.type in STATEMENT
        if is_suite_like:
            yield from self.line(+1)
            yield from self.visit_default(node)
            yield from self.line(-1)

        else:
            yield from self.line()
            yield from self.visit_default(node)

    def visit_async_stmt(self, node: Node) -> Iterator[Line]:
        """Visit `async def`, `async for`, `async with`."""
        yield from self.line()

        children = iter(node.children)
        for child in children:
            yield from self.visit(child)

            if child.type == token.ASYNC:
                break

        # The statement right after `async` is flattened: its children are
        # visited directly so they land on the same line as `async`.
        # NOTE(review): assumes `async` is always followed by a statement node;
        # otherwise next() raises StopIteration, which becomes RuntimeError
        # inside a generator (PEP 479) — confirm.
        internal_stmt = next(children)
        for child in internal_stmt.children:
            yield from self.visit(child)

    def visit_decorators(self, node: Node) -> Iterator[Line]:
        """Visit decorators, starting a new line before each child."""
        for child in node.children:
            yield from self.line()
            yield from self.visit(child)

    def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
        """Remove a semicolon and put the other statement on a separate line."""
        # The SEMI leaf itself is never appended: only the line break is kept.
        yield from self.line()

    def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
        """End of file. Process outstanding comments and end with a newline."""
        yield from self.visit_default(leaf)
        yield from self.line()

    def visit_unformatted(self, node: LN) -> Iterator[Line]:
        """Used when file contained a `# fmt: off`.

        Leaves are appended verbatim to the `UnformattedLines` current line.
        If appending raises `FormatOn`, that line is emitted and the leaf is
        visited again with formatting back on.  `UnformattedLines.append()` is
        expected to raise it when the prefix re-enables formatting; its body is
        not shown here.
        """
        if isinstance(node, Node):
            for child in node.children:
                yield from self.visit(child)

        else:
            try:
                self.current_line.append(node)
            except FormatOn as f_on:
                f_on.trim_prefix(node)
                # `line()` with the default `type=Line` ends unformatted mode.
                yield from self.line()
                yield from self.visit(node)

    def __attrs_post_init__(self) -> None:
        """Bind per-statement `visit_*` handlers once attrs has built `self`.

        Each compound statement type gets `visit_stmt` with its own keyword
        set.  `async_funcdef` and `decorated` are aliased to the async and
        decorator visitors.  Presumably `Visitor.visit()` looks these up by
        `visit_<type name>`; the base class is not shown here.
        """
        v = self.visit_stmt
        self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"})
        self.visit_while_stmt = partial(v, keywords={"while", "else"})
        self.visit_for_stmt = partial(v, keywords={"for", "else"})
        self.visit_try_stmt = partial(v, keywords={"try", "except", "else", "finally"})
        self.visit_except_clause = partial(v, keywords={"except"})
        self.visit_funcdef = partial(v, keywords={"def"})
        self.visit_with_stmt = partial(v, keywords={"with"})
        self.visit_classdef = partial(v, keywords={"class"})
        self.visit_async_funcdef = self.visit_async_stmt
        self.visit_decorated = self.visit_decorators
1113
1114
# Opening bracket token type -> matching closing bracket token type.
BRACKET = {
    token.LPAR: token.RPAR,
    token.LSQB: token.RSQB,
    token.LBRACE: token.RBRACE,
}
OPENING_BRACKETS = set(BRACKET)
CLOSING_BRACKETS = set(BRACKET.values())
BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
# Token types that `whitespace()` never puts a space in front of.
ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1120
1121
def whitespace(leaf: Leaf) -> str:  # noqa C901
    """Return whitespace prefix if needed for the given `leaf`.

    The result is usually one of:
    - "" (no space);
    - " " (one space);
    - "  " (two spaces, only before a trailing COMMENT).

    In two `typedargslist` cases the prefix already stored on a preceding `=`
    leaf is returned verbatim instead.

    The decision is driven by:
    - the leaf's own type `t`;
    - its parent's grammar symbol `p.type`;
    - its previous sibling `prev`;
    - when there is no previous sibling, the leaf that precedes the parent
      (`prevp`, via `preceding_leaf()`).
    """
    NO = ""
    SPACE = " "
    DOUBLESPACE = "  "
    t = leaf.type
    p = leaf.parent
    v = leaf.value
    # Closing brackets, commas and standalone comments never get a space.
    if t in ALWAYS_NO_SPACE:
        return NO

    # Trailing comments are separated from the code by two spaces.
    if t == token.COMMENT:
        return DOUBLESPACE

    # Comment leaves (which have no parent) were handled above; any other
    # leaf must be attached to the tree.
    assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
    # Colons hug the preceding token except inside subscripts/slices.
    if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
        return NO

    prev = leaf.prev_sibling
    if not prev:
        # First child of its parent: inspect the leaf that comes before the
        # parent in the source instead.
        prevp = preceding_leaf(p)
        if not prevp or prevp.type in OPENING_BRACKETS:
            return NO

        if t == token.COLON:
            return SPACE if prevp.type == token.COMMA else NO

        if prevp.type == token.EQUAL:
            if prevp.parent:
                # Keyword argument / default value: `f(a=-1)`.
                if prevp.parent.type in {
                    syms.arglist, syms.argument, syms.parameters, syms.varargslist
                }:
                    return NO

                elif prevp.parent.type == syms.typedargslist:
                    # A bit hacky: if the equal sign has whitespace, it means we
                    # previously found it's a typed argument.  So, we're using
                    # that, too.
                    return prevp.prefix

        elif prevp.type == token.DOUBLESTAR:
            # `**` unpacking in calls, signatures and dict displays.
            if prevp.parent and prevp.parent.type in {
                syms.arglist,
                syms.argument,
                syms.dictsetmaker,
                syms.parameters,
                syms.typedargslist,
                syms.varargslist,
            }:
                return NO

        elif prevp.type == token.COLON:
            # Right after a slice colon: `x[a:-1]`.
            if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
                return NO

        elif (
            prevp.parent
            and prevp.parent.type in {syms.factor, syms.star_expr}
            and prevp.type in MATH_OPERATORS
        ):
            # Operand of a unary operator or star expression: `-x`, `*rest`.
            return NO

        elif (
            prevp.type == token.RIGHTSHIFT
            and prevp.parent
            and prevp.parent.type == syms.shift_expr
            and prevp.prev_sibling
            and prevp.prev_sibling.type == token.NAME
            and prevp.prev_sibling.value == "print"  # type: ignore
        ):
            # Python 2 print chevron
            return NO

    elif prev.type in OPENING_BRACKETS:
        return NO

    if p.type in {syms.parameters, syms.arglist}:
        # untyped function signatures or calls
        if t == token.RPAR:
            return NO

        if not prev or prev.type != token.COMMA:
            return NO

    elif p.type == syms.varargslist:
        # lambdas
        if t == token.RPAR:
            return NO

        if prev and prev.type != token.COMMA:
            return NO

    elif p.type == syms.typedargslist:
        # typed function signatures
        if not prev:
            return NO

        if t == token.EQUAL:
            # `=` after an annotated name (tname) keeps its spaces.
            if prev.type != syms.tname:
                return NO

        elif prev.type == token.EQUAL:
            # A bit hacky: if the equal sign has whitespace, it means we
            # previously found it's a typed argument.  So, we're using that, too.
            return prev.prefix

        elif prev.type != token.COMMA:
            return NO

    elif p.type == syms.tname:
        # type names
        if not prev:
            prevp = preceding_leaf(p)
            if not prevp or prevp.type != token.COMMA:
                return NO

    elif p.type == syms.trailer:
        # attributes and calls
        if t == token.LPAR or t == token.RPAR:
            return NO

        if not prev:
            if t == token.DOT:
                # A space is kept only after a number, where `1 .real`
                # needs it (`1.real` would not parse).
                prevp = preceding_leaf(p)
                if not prevp or prevp.type != token.NUMBER:
                    return NO

            elif t == token.LSQB:
                return NO

        elif prev.type != token.COMMA:
            return NO

    elif p.type == syms.argument:
        # single argument
        if t == token.EQUAL:
            return NO

        if not prev:
            prevp = preceding_leaf(p)
            if not prevp or prevp.type == token.LPAR:
                return NO

        elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
            return NO

    elif p.type == syms.decorator:
        # decorators
        return NO

    elif p.type == syms.dotted_name:
        if prev:
            return NO

        prevp = preceding_leaf(p)
        if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
            return NO

    elif p.type == syms.classdef:
        if t == token.LPAR:
            return NO

        if prev and prev.type == token.LPAR:
            return NO

    elif p.type == syms.subscript:
        # indexing
        if not prev:
            assert p.parent is not None, "subscripts are always parented"
            # Separate subscripts after a comma: `x[a:b, c:d]`.
            if p.parent.type == syms.subscriptlist:
                return SPACE

            return NO

        else:
            return NO

    elif p.type == syms.atom:
        if prev and t == token.DOT:
            # dots, but not the first one.
            return NO

    elif (
        p.type == syms.listmaker
        or p.type == syms.testlist_gexp
        or p.type == syms.subscriptlist
    ):
        # list interior, including unpacking
        if not prev:
            return NO

    elif p.type == syms.dictsetmaker:
        # dict and set interior, including unpacking
        if not prev:
            return NO

        if prev.type == token.DOUBLESTAR:
            return NO

    elif p.type in {syms.factor, syms.star_expr}:
        # unary ops
        if not prev:
            prevp = preceding_leaf(p)
            if not prevp or prevp.type in OPENING_BRACKETS:
                return NO

            prevp_parent = prevp.parent
            assert prevp_parent is not None
            if prevp.type == token.COLON and prevp_parent.type in {
                syms.subscript, syms.sliceop
            }:
                return NO

            elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
                return NO

        elif t == token.NAME or t == token.NUMBER:
            # The operand right after the unary operator: `-x`, `~1`.
            return NO

    elif p.type == syms.import_from:
        if t == token.DOT:
            # Consecutive dots of a relative import: `from .. import x`.
            if prev and prev.type == token.DOT:
                return NO

        elif t == token.NAME:
            if v == "import":
                return SPACE

            if prev and prev.type == token.DOT:
                return NO

    elif p.type == syms.sliceop:
        return NO

    return SPACE
1357
1358
def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
    """Return the leaf that comes right before `node` in the tree, if any.

    Climbs from `node` through its ancestors until one of them has a
    previous sibling.  That sibling is returned if it is a leaf; otherwise
    its last leaf is returned.

    Returns None in three cases:
    - `node` is None;
    - nothing precedes `node`;
    - the preceding sibling node has no leaves at all.
    """
    current = node
    while current:
        sibling = current.prev_sibling
        if not sibling:
            current = current.parent
            continue

        if isinstance(sibling, Leaf):
            return sibling

        sibling_leaves = list(sibling.leaves())
        return sibling_leaves[-1] if sibling_leaves else None

    return None
1375
1376
def is_delimiter(leaf: Leaf) -> int:
    """Return the split priority of `leaf` as a delimiter, or 0 if it isn't one.

    Higher numbers are higher priority.  Commas and comparison operators
    always count.  Math operators count only when their parent is neither a
    `factor` nor a `star_expr` node; `factor` covers unary operators like
    the `-` in `-x`.
    """
    kind = leaf.type
    if kind == token.COMMA:
        return COMMA_PRIORITY

    if kind in COMPARATORS:
        return COMPARATOR_PRIORITY

    parent = leaf.parent
    if (
        kind in MATH_OPERATORS
        and parent
        and parent.type not in {syms.factor, syms.star_expr}
    ):
        return MATH_PRIORITY

    return 0
1396
1397
def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
    """Generate comment leaves found in the whitespace prefix of `leaf`.

    lib2to3 keeps comments in the whitespace prefix of the following token
    (see `pgen2/driver.py:Driver.parse_tokens()`), so the grammar never has
    to mention them.  As a consequence, comments don't belong to any node.
    The leaves generated here are therefore parentless.

    Only two kinds are distinguished:
    - token.COMMENT: a trailing comment, found before the first newline of
      the prefix (i.e. sharing a line with code).  Never used for
      ENDMARKER.
    - STANDALONE_COMMENT: every other comment.

    Each generated leaf's prefix holds one newline per whitespace-only prefix
    line seen since the previous comment.  The prefix of `leaf` itself is
    only read here, never modified.

    Right after yielding a `# fmt: on` / `# yapf: enable` comment this raises
    FormatOn.  After `# fmt: off` / `# yapf: disable` it raises FormatOff.
    Either exception carries the number of prefix characters consumed up to
    and including that comment's line.
    """
    prefix = leaf.prefix
    if not prefix or "#" not in prefix:
        return

    consumed = 0  # characters of `prefix` scanned so far
    empty_lines = 0
    for index, raw in enumerate(prefix.split("\n")):
        consumed += len(raw) + 1  # +1 accounts for the "\n" removed by split()
        text = raw.lstrip()
        if not text:
            empty_lines += 1
            continue

        if not text.startswith("#"):
            continue

        is_trailing = index == 0 and leaf.type != token.ENDMARKER
        comment_type = token.COMMENT if is_trailing else STANDALONE_COMMENT
        comment = make_comment(text)
        yield Leaf(comment_type, comment, prefix="\n" * empty_lines)

        if comment in {"# fmt: on", "# yapf: enable"}:
            raise FormatOn(consumed)

        if comment in {"# fmt: off", "# yapf: disable"}:
            raise FormatOff(consumed)

        empty_lines = 0
1448
1449
def make_comment(content: str) -> str:
    """Return `content` as a consistently formatted comment string.

    Trailing whitespace is stripped and a leading `#` is ensured.  A single
    space is inserted after the hash, except in two cases:
    - the remaining text already starts with a space, so existing
      indentation is kept as-is;
    - it starts with `!`, `:` or `#`, as in "#!", "#:", "##".

    Empty or whitespace-only input yields "#".
    """
    text = content.rstrip()
    if text.startswith("#"):
        text = text[1:]
    if text and text[0] not in " !:#":
        text = " " + text
    return "#" + text
1467
1468
def split_line(
    line: Line, line_length: int, inner: bool = False, py36: bool = False
) -> Iterator[Line]:
    """Yield `line`, split into lines that fit `line_length` where possible.

    These are yielded unchanged:
    - unformatted lines and comment-only lines;
    - a line that already fits, contains no newline (multiline strings) and
      has no standalone comments.

    Otherwise split functions are tried in order:
    - `def` lines: left-hand split;
    - lines inside brackets: delimiter split, then standalone-comment split,
      then right-hand split;
    - all other lines: right-hand split.

    The output of the first function that succeeds is split recursively and
    yielded.  If all of them fail, `line` is yielded as-is.

    `inner` is only forwarded (as True) to the recursive calls.  If `py36` is
    True, splitting may generate syntax that is only compatible with Python
    3.6 and later.
    """
    if isinstance(line, UnformattedLines) or line.is_comment:
        yield line
        return

    line_str = str(line).strip("\n")
    fits_as_is = (
        len(line_str) <= line_length
        and "\n" not in line_str  # multiline strings
        and not line.contains_standalone_comments
    )
    if fits_as_is:
        yield line
        return

    split_funcs: List[SplitFunc]
    if line.is_def:
        split_funcs = [left_hand_split]
    elif line.inside_brackets:
        split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
    else:
        split_funcs = [right_hand_split]

    for split_func in split_funcs:
        # Collect everything before yielding: a CannotSplit raised part-way
        # through must abandon this attempt without having emitted anything.
        result: List[Line] = []
        try:
            for piece in split_func(line, py36):
                if str(piece).strip("\n") == line_str:
                    raise CannotSplit("Split function returned an unchanged result")

                result.extend(
                    split_line(piece, line_length=line_length, inner=True, py36=py36)
                )
        except CannotSplit:
            continue

        yield from result
        return

    yield line
1524
1525
def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split `line` at its first bracket pair with non-empty contents.

    The split produces three lines:
    - head: everything up to and including the opening bracket;
    - body: the bracket contents, one level deeper and inside brackets;
    - tail: the closing bracket and everything after it.

    Empty bracket pairs before that one stay in the head.  Comments attached
    to a leaf follow it into the same new line.  Raises CannotSplit (via
    `bracket_split_succeeded_or_raise()`) when the split isn't worthwhile.

    Note: this usually looks weird, only use this for function definitions.
    Prefer RHS otherwise.
    """
    head_leaves: List[Leaf] = []
    body_leaves: List[Leaf] = []
    tail_leaves: List[Leaf] = []
    target = head_leaves
    opening = None
    for leaf in line.leaves:
        if (
            target is body_leaves
            and leaf.type in CLOSING_BRACKETS
            and leaf.opening_bracket is opening
        ):
            # Closing the tracked pair; an empty pair goes back to the head.
            target = tail_leaves if body_leaves else head_leaves
        target.append(leaf)
        if target is head_leaves and leaf.type in OPENING_BRACKETS:
            opening = leaf
            target = body_leaves

    # The body starts a new indentation level: drop its leading whitespace.
    if body_leaves:
        normalize_prefix(body_leaves[0], inside_brackets=True)

    head = Line(depth=line.depth)
    body = Line(depth=line.depth + 1, inside_brackets=True)
    tail = Line(depth=line.depth)
    for new_line, leaves in (
        (head, head_leaves),
        (body, body_leaves),
        (tail, tail_leaves),
    ):
        for leaf in leaves:
            new_line.append(leaf, preformatted=True)
            for comment in line.comments_after(leaf):
                new_line.append(comment, preformatted=True)

    bracket_split_succeeded_or_raise(head, body, tail)
    yield from (part for part in (head, body, tail) if part)
1567
1568
def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split `line` at its last bracket pair with non-empty contents.

    This mirrors `left_hand_split()`, scanning leaves from the end:
    - tail: the trailing closing bracket and what follows it;
    - body: the bracket contents, one level deeper and inside brackets;
    - head: the opening bracket and everything before it.

    Empty bracket pairs at the end stay in the tail.  Comments attached to a
    leaf follow it into the same new line.  Raises CannotSplit (via
    `bracket_split_succeeded_or_raise()`) when the split isn't worthwhile.
    """
    head_leaves: List[Leaf] = []
    body_leaves: List[Leaf] = []
    tail_leaves: List[Leaf] = []
    target = tail_leaves
    opening = None
    for leaf in reversed(line.leaves):
        if target is body_leaves and leaf is opening:
            # Reached the tracked opening bracket; an empty pair stays in the tail.
            target = head_leaves if body_leaves else tail_leaves
        target.append(leaf)
        if target is tail_leaves and leaf.type in CLOSING_BRACKETS:
            opening = leaf.opening_bracket
            target = body_leaves

    # Leaves were collected back to front; restore source order.
    for leaves in (tail_leaves, body_leaves, head_leaves):
        leaves.reverse()

    # The body starts a new indentation level: drop its leading whitespace.
    if body_leaves:
        normalize_prefix(body_leaves[0], inside_brackets=True)

    head = Line(depth=line.depth)
    body = Line(depth=line.depth + 1, inside_brackets=True)
    tail = Line(depth=line.depth)
    for new_line, leaves in (
        (head, head_leaves),
        (body, body_leaves),
        (tail, tail_leaves),
    ):
        for leaf in leaves:
            new_line.append(leaf, preformatted=True)
            for comment in line.comments_after(leaf):
                new_line.append(comment, preformatted=True)

    bracket_split_succeeded_or_raise(head, body, tail)
    yield from (part for part in (head, body, tail) if part)
1606
1607
def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
    """Validate the outcome of a left- or right-hand bracket split.

    Such a split produces three lines:
    - `head`: everything up to and including the opening bracket;
    - `body`: the bracket contents;
    - `tail`: the closing bracket and anything after it.

    `head` is accepted for symmetry but not inspected.  Returns None when the
    split is acceptable.

    Raises:
        CannotSplit: if `body` is empty and the stripped `tail` is either
            empty (the split reproduced the original line) or shorter than
            three characters (not worth splitting an empty bracket pair).
    """
    tail_len = len(str(tail).strip())
    if body:
        return

    if not tail_len:
        raise CannotSplit("Splitting brackets produced the same line")

    if tail_len < 3:
        raise CannotSplit(
            f"Splitting brackets on an empty body to save "
            f"{tail_len} characters is not worth it"
        )
1632
1633
def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
    """Decorate `split_func` so every line it yields starts without a prefix.

    The first leaf of each produced line goes through
    `normalize_prefix(..., inside_brackets=True)`.  That removes its leading
    whitespace, so the new line's indentation comes from its depth alone.
    """

    @wraps(split_func)
    def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
        for produced in split_func(line, py36):
            first_leaf = produced.leaves[0]
            normalize_prefix(first_leaf, inside_brackets=True)
            yield produced

    return split_wrapper
1647
1648
@dont_increase_indentation
def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split according to delimiters of the highest priority.

    A new line starts right after every delimiter whose priority equals the
    highest priority in `line`.  The last leaf is excluded when finding that
    priority.  Comments attached to a leaf follow it onto the same line
    whenever `append_safe()` accepts them.

    When splitting on commas, a trailing comma is appended to the last line
    unless it already ends with one.  If `line` contains a `*` or `**` at its
    outermost bracket depth (e.g. `*args` / `**kwargs`), the trailing comma
    is only added when `py36` is True.

    Raises:
        CannotSplit: if `line` is empty or contains no delimiters.
    """
    try:
        last_leaf = line.leaves[-1]
    except IndexError:
        raise CannotSplit("Line empty")

    delimiters = line.bracket_tracker.delimiters
    try:
        delimiter_priority = line.bracket_tracker.max_delimiter_priority(
            exclude={id(last_leaf)}
        )
    except ValueError:
        raise CannotSplit("No delimiters found")

    current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
    lowest_depth = sys.maxsize
    trailing_comma_safe = True

    def append_to_line(leaf: Leaf) -> Iterator[Line]:
        """Append `leaf` to current line or to new line if appending impossible."""
        nonlocal current_line
        try:
            current_line.append_safe(leaf, preformatted=True)
        except ValueError:
            # `append_safe` rejected the leaf: emit what we have so far and
            # start a fresh line with it.
            yield current_line

            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
            current_line.append(leaf)

    for leaf in line.leaves:
        yield from append_to_line(leaf)

        for comment_after in line.comments_after(leaf):
            yield from append_to_line(comment_after)

        lowest_depth = min(lowest_depth, leaf.bracket_depth)
        # Only a `*`/`**` at this line's outermost depth (star-args of the
        # call or signature being split) affects its trailing comma.  Nested
        # ones, e.g. `g(**kw)` passed as an argument, must not disable it.
        # Binary `*`/`**` at that depth also match; that errs on the side of
        # not adding a comma.
        if leaf.bracket_depth == lowest_depth and leaf.type in {
            token.STAR,
            token.DOUBLESTAR,
        }:
            trailing_comma_safe = trailing_comma_safe and py36
        leaf_priority = delimiters.get(id(leaf))
        if leaf_priority == delimiter_priority:
            yield current_line

            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
    if current_line:
        if (
            delimiter_priority == COMMA_PRIORITY
            and current_line.leaves[-1].type != token.COMMA
            and trailing_comma_safe
        ):
            current_line.append(Leaf(token.COMMA, ","))
        yield current_line
1710
1711
@dont_increase_indentation
def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split standalone comments from the rest of the line.

    Raises CannotSplit unless `line` has a STANDALONE_COMMENT leaf at bracket
    depth 0. Otherwise every leaf (followed by the comments returned by
    `line.comments_after()` for it) is appended to a fresh `Line`; whenever
    `append_safe` rejects a leaf with ValueError, the line built so far is
    yielded and a new one is started with that leaf.

    `py36` is unused here; it is part of the shared `SplitFunc` signature.
    """
    if not any(
        leaf.type == STANDALONE_COMMENT and leaf.bracket_depth == 0
        for leaf in line.leaves
    ):
        raise CannotSplit("Line does not have any standalone comments")

    current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)

    def append_to_line(leaf: Leaf) -> Iterator[Line]:
        """Append `leaf` to current line or to new line if appending impossible."""
        nonlocal current_line
        try:
            current_line.append_safe(leaf, preformatted=True)
        except ValueError:
            # `append_safe` refused the leaf: emit what we have and start over.
            yield current_line

            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
            current_line.append(leaf)

    for leaf in line.leaves:
        yield from append_to_line(leaf)

        for comment_after in line.comments_after(leaf):
            yield from append_to_line(comment_after)

    if current_line:
        yield current_line
1744
1745
def is_import(leaf: Leaf) -> bool:
    """Tell whether `leaf` is the keyword opening an import statement.

    That is either an `import` NAME whose parent is an `import_name` node, or a
    `from` NAME whose parent is an `import_from` node.
    """
    if leaf.type != token.NAME:
        return False

    parent = leaf.parent
    if not parent:
        return False

    if leaf.value == "import":
        return parent.type == syms.import_name

    if leaf.value == "from":
        return parent.type == syms.import_from

    return False
1758
1759
def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
    """Leave existing extra newlines if not `inside_brackets`. Remove everything
    else.

    Outside brackets, the prefix is reduced to the newlines found after its last
    `#` (minus one when a `#` is present), unless the text before the first `#`
    contains a backslash, in which case the prefix is cleared.

    Note: don't use backslashes for formatting or you'll lose your voting rights.
    """
    if inside_brackets:
        leaf.prefix = ""
        return

    head, *rest = leaf.prefix.split("#")
    if "\\" in head:
        leaf.prefix = ""
        return

    # With comments present, only newlines after the last one matter, and the
    # first of those terminates the comment itself.
    tail = rest[-1] if rest else head
    newline_count = tail.count("\n") - (1 if rest else 0)
    leaf.prefix = "\n" * newline_count
1776
1777
def _has_escaped_quote(body: str, quote: str) -> bool:
    """Return True if `quote` occurs in `body` right after an escaping backslash.

    Backslash escapes are consumed pairwise, so a quote preceded by an even run
    of backslashes (i.e. by escaped backslashes) does not count as escaped.
    """
    index = 0
    while index < len(body):
        if body[index] != "\\":
            index += 1
            continue

        if body.startswith(quote, index + 1):
            return True

        index += 2  # skip the backslash together with the character it escapes
    return False


def normalize_string_quotes(leaf: Leaf) -> None:
    """Prefer double quotes for the string literal in `leaf`, rewriting it in place.

    Triple-double-quoted strings are left untouched; triple-single-quoted ones
    become triple-double-quoted. The body is re-escaped for the new quotes and
    the change is abandoned if it would add backslashes. A double-quoted string
    is switched to single quotes only when that removes backslashes.

    Raw strings are only re-quoted when their body needs no re-escaping, since
    backslashes in a raw string are part of its value.
    """
    value = leaf.value.lstrip("furbFURB")
    if value[:3] == '"""':
        return

    elif value[:3] == "'''":
        orig_quote = "'''"
        new_quote = '"""'
    elif value[0] == '"':
        orig_quote = '"'
        new_quote = "'"
    else:
        orig_quote = "'"
        new_quote = '"'
    first_quote_pos = leaf.value.find(orig_quote)
    if first_quote_pos == -1:
        return  # There's an internal error

    prefix = leaf.value[:first_quote_pos]
    body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
    if "r" in prefix.lower():
        # Escapes are literal in raw strings: adding or removing a backslash
        # would change the string's value. Only swap the delimiters when the
        # body contains nothing that could clash with the new quotes.
        if new_quote in body or body.endswith(new_quote[0]):
            return

        new_body = body
    else:
        if len(orig_quote) == 1 and _has_escaped_quote(body, new_quote):
            # Escaping every `new_quote` below would turn an already-escaped
            # quote into an escaped backslash followed by a bare quote, which
            # terminates the string early and produces invalid code.
            return

        new_body = body.replace(f"\\{orig_quote}", orig_quote).replace(
            new_quote, f"\\{new_quote}"
        )
        if new_quote == '"""' and new_body[-1:] == '"':
            # edge case: a trailing quote would merge with the closing `"""`.
            # Slicing (not indexing) keeps the empty string `''''''` safe.
            new_body = new_body[:-1] + '\\"'
    orig_escape_count = body.count("\\")
    new_escape_count = new_body.count("\\")
    if new_escape_count > orig_escape_count:
        return  # Do not introduce more escaping

    if new_escape_count == orig_escape_count and orig_quote == '"':
        return  # Prefer double quotes

    leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
1813
1814
def is_python36(node: Node) -> bool:
    """Return True if the current file is using Python 3.6+ features.

    Currently looking for:
    - f-strings, i.e. any string prefix containing `f` in either case and in
      any combination (`f`, `F`, `rf`, `Rf`, `fR`, ...); and
    - trailing commas after * or ** in function signatures.
    """
    for n in node.pre_order():
        if n.type == token.STRING:
            value = n.value  # type: ignore
            # The prefix is whatever precedes the opening quote character.
            prefix = value[:len(value) - len(value.lstrip("furbFURB"))]
            if "f" in prefix.lower():
                return True

        elif (
            n.type == syms.typedargslist
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            for ch in n.children:
                if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
                    return True

    return False
1838
1839
# File suffixes treated as Python sources when walking directories.
PYTHON_EXTENSIONS = {".py"}
# Directory names never descended into when walking directories.
BLACKLISTED_DIRECTORIES = {
    "build",
    "buck-out",
    "dist",
    "_build",
    ".git",
    ".hg",
    ".mypy_cache",
    ".tox",
    ".venv",
}
1844
1845
def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
    """Recursively yield files below `path` whose suffix is in PYTHON_EXTENSIONS.

    Directories whose name is in BLACKLISTED_DIRECTORIES are not descended into.
    """
    for entry in path.iterdir():
        if not entry.is_dir():
            if entry.suffix in PYTHON_EXTENSIONS:
                yield entry
        elif entry.name not in BLACKLISTED_DIRECTORIES:
            yield from gen_python_files_in_dir(entry)
1859
1860
@dataclass
class Report:
    """Tally of per-file formatting outcomes.

    `str(report)` renders a (colored) one-line summary; `return_code` maps the
    tally to the exit status the app should use.
    """
    check: bool = False
    change_count: int = 0
    same_count: int = 0
    failure_count: int = 0

    def done(self, src: Path, changed: bool) -> None:
        """Count `src` as successfully processed and print a line about it."""
        if not changed:
            out(f"{src} already well formatted, good job.", bold=False)
            self.same_count += 1
            return

        verb = "would reformat" if self.check else "reformatted"
        out(f"{verb} {src}")
        self.change_count += 1

    def failed(self, src: Path, message: str) -> None:
        """Count `src` as a failure and print `message` as the reason."""
        err(f"error: cannot format {src}: {message}")
        self.failure_count += 1

    @property
    def return_code(self) -> int:
        """Exit status derived from the tally.

        - 123 if any file failed;
        - 1 if files changed (or would change) while --check is in effect;
        - 0 otherwise.
        """
        # Shells reserve exit codes from 126 upwards
        # (see http://tldp.org/LDP/abs/html/exitcodes.html), hence 123.
        if self.failure_count:
            return 123

        if self.check and self.change_count:
            return 1

        return 0

    def __str__(self) -> str:
        """Summarize the tally; `click.unstyle` strips the colors."""
        if self.check:
            reformatted, unchanged, failed = (
                "would be reformatted",
                "would be left unchanged",
                "would fail to reformat",
            )
        else:
            reformatted, unchanged, failed = (
                "reformatted",
                "left unchanged",
                "failed to reformat",
            )

        def files(count: int) -> str:
            return f"{count} file{'s' if count > 1 else ''}"

        summary = []
        if self.change_count:
            summary.append(
                click.style(f"{files(self.change_count)} {reformatted}", bold=True)
            )
        if self.same_count:
            summary.append(f"{files(self.same_count)} {unchanged}")
        if self.failure_count:
            summary.append(
                click.style(f"{files(self.failure_count)} {failed}", fg="red")
            )
        return ", ".join(summary) + "."
1931
1932
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent.

    Both strings are parsed with the running interpreter's `ast` module and the
    trees are rendered to text by `_v`; the renderings must match exactly. A
    failure to parse `src` or `dst`, or any difference, raises AssertionError
    (after dumping diagnostic output to a temporary file for `dst` problems).
    """

    import ast
    import traceback

    def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
        """Simple visitor generating strings to compare ASTs by content."""
        yield f"{'  ' * depth}{node.__class__.__name__}("

        for field in sorted(node._fields):
            try:
                value = getattr(node, field)
            except AttributeError:
                continue

            yield f"{'  ' * (depth+1)}{field}="

            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        yield from _v(item, depth + 2)

                    else:
                        # Non-node entries (e.g. `Global.names` strings or the
                        # `None` placeholders in `arguments.kw_defaults`) must be
                        # rendered too, or differences in them go unnoticed.
                        yield f"{'  ' * (depth+2)}{item!r},  # {item.__class__.__name__}"

            elif isinstance(value, ast.AST):
                yield from _v(value, depth + 2)

            else:
                yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"

        yield f"{'  ' * depth})  # /{node.__class__.__name__}"

    try:
        src_ast = ast.parse(src)
    except Exception as exc:
        major, minor = sys.version_info[:2]
        raise AssertionError(
            f"cannot use --safe with this file; failed to parse source file "
            f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
            f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
        )

    try:
        dst_ast = ast.parse(dst)
    except Exception as exc:
        log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            f"Please report a bug on https://github.com/ambv/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = "\n".join(_v(src_ast))
    dst_ast_str = "\n".join(_v(dst_ast))
    if src_ast_str != dst_ast_str:
        log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
        raise AssertionError(
            f"INTERNAL ERROR: Black produced code that is not equivalent to "
            f"the source.  "
            f"Please report a bug on https://github.com/ambv/black/issues.  "
            f"This diff might be helpful: {log}"
        ) from None
1994
1995
def assert_stable(src: str, dst: str, line_length: int) -> None:
    """Raise AssertionError unless formatting `dst` once more leaves it unchanged.

    On failure, both the source-to-first-pass and first-to-second-pass diffs are
    dumped to a temporary file whose path is included in the error message.
    """
    second_pass = format_str(dst, line_length=line_length)
    if second_pass == dst:
        return

    first_diff = diff(src, dst, "source", "first pass")
    second_diff = diff(dst, second_pass, "first pass", "second pass")
    log = dump_to_file(first_diff, second_diff)
    raise AssertionError(
        f"INTERNAL ERROR: Black produced different code on the second pass "
        f"of the formatter.  "
        f"Please report a bug on https://github.com/ambv/black/issues.  "
        f"This diff might be helpful: {log}"
    ) from None
2010
2011
def dump_to_file(*output: str) -> str:
    """Dump `output` to a temporary file. Return path to the file.

    Each string in `output` is written followed by a newline. The file is not
    deleted on close (`delete=False`), so the returned path stays usable.

    The file is written as UTF-8 rather than the locale's encoding: this runs
    while reporting errors, and source text that the locale cannot encode must
    not turn that report into a UnicodeEncodeError.
    """
    import tempfile

    with tempfile.NamedTemporaryFile(
        mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
    ) as f:
        for lines in output:
            f.write(lines)
            f.write("\n")
    return f.name
2023
2024
def diff(a: str, b: str, a_name: str, b_name: str) -> str:
    """Render a unified diff (5 lines of context) turning `a` into `b`.

    `a_name` and `b_name` label the two sides in the diff header.
    """
    import difflib

    def as_lines(text: str) -> List[str]:
        # Split on "\n" only and re-terminate every piece, including the last.
        return [f"{piece}\n" for piece in text.split("\n")]

    hunks = difflib.unified_diff(
        as_lines(a), as_lines(b), fromfile=a_name, tofile=b_name, n=5
    )
    return "".join(hunks)
2034
2035
def cancel(tasks: List[asyncio.Task]) -> None:
    """Signal-handler callback: report the abort on stderr, then cancel `tasks`."""
    err("Aborted!")
    for pending in tasks:
        pending.cancel()
2041
2042
def shutdown(loop: BaseEventLoop) -> None:
    """Cancel all pending tasks on `loop`, wait for them, and close the loop.

    The loop is closed even if cancelling or waiting raises.
    """
    try:
        # This part is borrowed from asyncio/runners.py in Python 3.7b2.
        # `asyncio.all_tasks()` only exists on 3.7+, while the older
        # `asyncio.Task.all_tasks()` was removed in 3.9: use whichever exists.
        all_tasks = getattr(asyncio, "all_tasks", None)
        if all_tasks is None:
            all_tasks = asyncio.Task.all_tasks  # type: ignore
        to_cancel = [task for task in all_tasks(loop) if not task.done()]
        if not to_cancel:
            return

        for task in to_cancel:
            task.cancel()
        # `gather()` infers the loop from the tasks themselves; its `loop=`
        # argument was deprecated in 3.8 and removed in 3.10.
        loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
    finally:
        # `concurrent.futures.Future` objects cannot be cancelled once they
        # are already running. There might be some when the `shutdown()` happened.
        # Silence their logger's spew about the event loop being closed.
        cf_logger = logging.getLogger("concurrent.futures")
        cf_logger.setLevel(logging.CRITICAL)
        loop.close()
2063
2064
if __name__ == "__main__":
    # Invoke `main()` only when this file is executed directly, not on import.
    main()