]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

fdefb746de14fc28002c820e2a19454db5873298
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Set,
28     Tuple,
29     Type,
30     TypeVar,
31     Union,
32 )
33
34 from appdirs import user_cache_dir
35 from attr import dataclass, Factory
36 import click
37
38 # lib2to3 fork
39 from blib2to3.pytree import Node, Leaf, type_repr
40 from blib2to3 import pygram, pytree
41 from blib2to3.pgen2 import driver, token
42 from blib2to3.pgen2.parse import ParseError
43
44 __version__ = "18.4a4"
45 DEFAULT_LINE_LENGTH = 88
46
47 # types
48 syms = pygram.python_symbols
49 FileContent = str
50 Encoding = str
51 Depth = int
52 NodeType = int
53 LeafID = int
54 Priority = int
55 Index = int
56 LN = Union[Leaf, Node]
57 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
58 Timestamp = float
59 FileSize = int
60 CacheInfo = Tuple[Timestamp, FileSize]
61 Cache = Dict[Path, CacheInfo]
62 out = partial(click.secho, bold=True, err=True)
63 err = partial(click.secho, fg="red", err=True)
64
65
66 class NothingChanged(UserWarning):
67     """Raised by :func:`format_file` when reformatted code is the same as source."""
68
69
70 class CannotSplit(Exception):
71     """A readable split that fits the allotted line length is impossible.
72
73     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
74     :func:`delimiter_split`.
75     """
76
77
78 class FormatError(Exception):
79     """Base exception for `# fmt: on` and `# fmt: off` handling.
80
81     It holds the number of bytes of the prefix consumed before the format
82     control comment appeared.
83     """
84
85     def __init__(self, consumed: int) -> None:
86         super().__init__(consumed)
87         self.consumed = consumed
88
89     def trim_prefix(self, leaf: Leaf) -> None:
90         leaf.prefix = leaf.prefix[self.consumed :]
91
92     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
93         """Returns a new Leaf from the consumed part of the prefix."""
94         unformatted_prefix = leaf.prefix[: self.consumed]
95         return Leaf(token.NEWLINE, unformatted_prefix)
96
97
98 class FormatOn(FormatError):
99     """Found a comment like `# fmt: on` in the file."""
100
101
102 class FormatOff(FormatError):
103     """Found a comment like `# fmt: off` in the file."""
104
105
106 class WriteBack(Enum):
107     NO = 0
108     YES = 1
109     DIFF = 2
110
111
112 class Changed(Enum):
113     NO = 0
114     CACHED = 1
115     YES = 2
116
117
118 @click.command()
119 @click.option(
120     "-l",
121     "--line-length",
122     type=int,
123     default=DEFAULT_LINE_LENGTH,
124     help="How many character per line to allow.",
125     show_default=True,
126 )
127 @click.option(
128     "--check",
129     is_flag=True,
130     help=(
131         "Don't write the files back, just return the status.  Return code 0 "
132         "means nothing would change.  Return code 1 means some files would be "
133         "reformatted.  Return code 123 means there was an internal error."
134     ),
135 )
136 @click.option(
137     "--diff",
138     is_flag=True,
139     help="Don't write the files back, just output a diff for each file on stdout.",
140 )
141 @click.option(
142     "--fast/--safe",
143     is_flag=True,
144     help="If --fast given, skip temporary sanity checks. [default: --safe]",
145 )
146 @click.option(
147     "-q",
148     "--quiet",
149     is_flag=True,
150     help=(
151         "Don't emit non-error messages to stderr. Errors are still emitted, "
152         "silence those with 2>/dev/null."
153     ),
154 )
155 @click.version_option(version=__version__)
156 @click.argument(
157     "src",
158     nargs=-1,
159     type=click.Path(
160         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
161     ),
162 )
163 @click.pass_context
164 def main(
165     ctx: click.Context,
166     line_length: int,
167     check: bool,
168     diff: bool,
169     fast: bool,
170     quiet: bool,
171     src: List[str],
172 ) -> None:
173     """The uncompromising code formatter."""
174     sources: List[Path] = []
175     for s in src:
176         p = Path(s)
177         if p.is_dir():
178             sources.extend(gen_python_files_in_dir(p))
179         elif p.is_file():
180             # if a file was explicitly given, we don't care about its extension
181             sources.append(p)
182         elif s == "-":
183             sources.append(Path("-"))
184         else:
185             err(f"invalid path: {s}")
186
187     if check and not diff:
188         write_back = WriteBack.NO
189     elif diff:
190         write_back = WriteBack.DIFF
191     else:
192         write_back = WriteBack.YES
193     report = Report(check=check, quiet=quiet)
194     if len(sources) == 0:
195         ctx.exit(0)
196         return
197
198     elif len(sources) == 1:
199         reformat_one(sources[0], line_length, fast, write_back, report)
200     else:
201         loop = asyncio.get_event_loop()
202         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
203         try:
204             loop.run_until_complete(
205                 schedule_formatting(
206                     sources, line_length, fast, write_back, report, loop, executor
207                 )
208             )
209         finally:
210             shutdown(loop)
211         if not quiet:
212             out("All done! ✨ 🍰 ✨")
213             click.echo(str(report))
214     ctx.exit(report.return_code)
215
216
217 def reformat_one(
218     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
219 ) -> None:
220     """Reformat a single file under `src` without spawning child processes.
221
222     If `quiet` is True, non-error messages are not output. `line_length`,
223     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
224     """
225     try:
226         changed = Changed.NO
227         if not src.is_file() and str(src) == "-":
228             if format_stdin_to_stdout(
229                 line_length=line_length, fast=fast, write_back=write_back
230             ):
231                 changed = Changed.YES
232         else:
233             cache: Cache = {}
234             if write_back != WriteBack.DIFF:
235                 cache = read_cache(line_length)
236                 src = src.resolve()
237                 if src in cache and cache[src] == get_cache_info(src):
238                     changed = Changed.CACHED
239             if (
240                 changed is not Changed.CACHED
241                 and format_file_in_place(
242                     src, line_length=line_length, fast=fast, write_back=write_back
243                 )
244             ):
245                 changed = Changed.YES
246             if write_back == WriteBack.YES and changed is not Changed.NO:
247                 write_cache(cache, [src], line_length)
248         report.done(src, changed)
249     except Exception as exc:
250         report.failed(src, str(exc))
251
252
253 async def schedule_formatting(
254     sources: List[Path],
255     line_length: int,
256     fast: bool,
257     write_back: WriteBack,
258     report: "Report",
259     loop: BaseEventLoop,
260     executor: Executor,
261 ) -> None:
262     """Run formatting of `sources` in parallel using the provided `executor`.
263
264     (Use ProcessPoolExecutors for actual parallelism.)
265
266     `line_length`, `write_back`, and `fast` options are passed to
267     :func:`format_file_in_place`.
268     """
269     cache: Cache = {}
270     if write_back != WriteBack.DIFF:
271         cache = read_cache(line_length)
272         sources, cached = filter_cached(cache, sources)
273         for src in cached:
274             report.done(src, Changed.CACHED)
275     cancelled = []
276     formatted = []
277     if sources:
278         lock = None
279         if write_back == WriteBack.DIFF:
280             # For diff output, we need locks to ensure we don't interleave output
281             # from different processes.
282             manager = Manager()
283             lock = manager.Lock()
284         tasks = {
285             src: loop.run_in_executor(
286                 executor, format_file_in_place, src, line_length, fast, write_back, lock
287             )
288             for src in sources
289         }
290         _task_values = list(tasks.values())
291         try:
292             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
293             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
294         except NotImplementedError:
295             # There are no good alternatives for these on Windows
296             pass
297         await asyncio.wait(_task_values)
298         for src, task in tasks.items():
299             if not task.done():
300                 report.failed(src, "timed out, cancelling")
301                 task.cancel()
302                 cancelled.append(task)
303             elif task.cancelled():
304                 cancelled.append(task)
305             elif task.exception():
306                 report.failed(src, str(task.exception()))
307             else:
308                 formatted.append(src)
309                 report.done(src, Changed.YES if task.result() else Changed.NO)
310
311     if cancelled:
312         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
313     if write_back == WriteBack.YES and formatted:
314         write_cache(cache, formatted, line_length)
315
316
317 def format_file_in_place(
318     src: Path,
319     line_length: int,
320     fast: bool,
321     write_back: WriteBack = WriteBack.NO,
322     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
323 ) -> bool:
324     """Format file under `src` path. Return True if changed.
325
326     If `write_back` is True, write reformatted code back to stdout.
327     `line_length` and `fast` options are passed to :func:`format_file_contents`.
328     """
329
330     with tokenize.open(src) as src_buffer:
331         src_contents = src_buffer.read()
332     try:
333         dst_contents = format_file_contents(
334             src_contents, line_length=line_length, fast=fast
335         )
336     except NothingChanged:
337         return False
338
339     if write_back == write_back.YES:
340         with open(src, "w", encoding=src_buffer.encoding) as f:
341             f.write(dst_contents)
342     elif write_back == write_back.DIFF:
343         src_name = f"{src}  (original)"
344         dst_name = f"{src}  (formatted)"
345         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
346         if lock:
347             lock.acquire()
348         try:
349             sys.stdout.write(diff_contents)
350         finally:
351             if lock:
352                 lock.release()
353     return True
354
355
356 def format_stdin_to_stdout(
357     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
358 ) -> bool:
359     """Format file on stdin. Return True if changed.
360
361     If `write_back` is True, write reformatted code back to stdout.
362     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
363     """
364     src = sys.stdin.read()
365     dst = src
366     try:
367         dst = format_file_contents(src, line_length=line_length, fast=fast)
368         return True
369
370     except NothingChanged:
371         return False
372
373     finally:
374         if write_back == WriteBack.YES:
375             sys.stdout.write(dst)
376         elif write_back == WriteBack.DIFF:
377             src_name = "<stdin>  (original)"
378             dst_name = "<stdin>  (formatted)"
379             sys.stdout.write(diff(src, dst, src_name, dst_name))
380
381
382 def format_file_contents(
383     src_contents: str, line_length: int, fast: bool
384 ) -> FileContent:
385     """Reformat contents a file and return new contents.
386
387     If `fast` is False, additionally confirm that the reformatted code is
388     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
389     `line_length` is passed to :func:`format_str`.
390     """
391     if src_contents.strip() == "":
392         raise NothingChanged
393
394     dst_contents = format_str(src_contents, line_length=line_length)
395     if src_contents == dst_contents:
396         raise NothingChanged
397
398     if not fast:
399         assert_equivalent(src_contents, dst_contents)
400         assert_stable(src_contents, dst_contents, line_length=line_length)
401     return dst_contents
402
403
404 def format_str(src_contents: str, line_length: int) -> FileContent:
405     """Reformat a string and return new contents.
406
407     `line_length` determines how many characters per line are allowed.
408     """
409     src_node = lib2to3_parse(src_contents)
410     dst_contents = ""
411     lines = LineGenerator()
412     elt = EmptyLineTracker()
413     py36 = is_python36(src_node)
414     empty_line = Line()
415     after = 0
416     for current_line in lines.visit(src_node):
417         for _ in range(after):
418             dst_contents += str(empty_line)
419         before, after = elt.maybe_empty_lines(current_line)
420         for _ in range(before):
421             dst_contents += str(empty_line)
422         for line in split_line(current_line, line_length=line_length, py36=py36):
423             dst_contents += str(line)
424     return dst_contents
425
426
427 GRAMMARS = [
428     pygram.python_grammar_no_print_statement_no_exec_statement,
429     pygram.python_grammar_no_print_statement,
430     pygram.python_grammar,
431 ]
432
433
434 def lib2to3_parse(src_txt: str) -> Node:
435     """Given a string with source, return the lib2to3 Node."""
436     grammar = pygram.python_grammar_no_print_statement
437     if src_txt[-1] != "\n":
438         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
439         src_txt += nl
440     for grammar in GRAMMARS:
441         drv = driver.Driver(grammar, pytree.convert)
442         try:
443             result = drv.parse_string(src_txt, True)
444             break
445
446         except ParseError as pe:
447             lineno, column = pe.context[1]
448             lines = src_txt.splitlines()
449             try:
450                 faulty_line = lines[lineno - 1]
451             except IndexError:
452                 faulty_line = "<line number missing in source>"
453             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
454     else:
455         raise exc from None
456
457     if isinstance(result, Leaf):
458         result = Node(syms.file_input, [result])
459     return result
460
461
462 def lib2to3_unparse(node: Node) -> str:
463     """Given a lib2to3 node, return its string representation."""
464     code = str(node)
465     return code
466
467
468 T = TypeVar("T")
469
470
471 class Visitor(Generic[T]):
472     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
473
474     def visit(self, node: LN) -> Iterator[T]:
475         """Main method to visit `node` and its children.
476
477         It tries to find a `visit_*()` method for the given `node.type`, like
478         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
479         If no dedicated `visit_*()` method is found, chooses `visit_default()`
480         instead.
481
482         Then yields objects of type `T` from the selected visitor.
483         """
484         if node.type < 256:
485             name = token.tok_name[node.type]
486         else:
487             name = type_repr(node.type)
488         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
489
490     def visit_default(self, node: LN) -> Iterator[T]:
491         """Default `visit_*()` implementation. Recurses to children of `node`."""
492         if isinstance(node, Node):
493             for child in node.children:
494                 yield from self.visit(child)
495
496
497 @dataclass
498 class DebugVisitor(Visitor[T]):
499     tree_depth: int = 0
500
501     def visit_default(self, node: LN) -> Iterator[T]:
502         indent = " " * (2 * self.tree_depth)
503         if isinstance(node, Node):
504             _type = type_repr(node.type)
505             out(f"{indent}{_type}", fg="yellow")
506             self.tree_depth += 1
507             for child in node.children:
508                 yield from self.visit(child)
509
510             self.tree_depth -= 1
511             out(f"{indent}/{_type}", fg="yellow", bold=False)
512         else:
513             _type = token.tok_name.get(node.type, str(node.type))
514             out(f"{indent}{_type}", fg="blue", nl=False)
515             if node.prefix:
516                 # We don't have to handle prefixes for `Node` objects since
517                 # that delegates to the first child anyway.
518                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
519             out(f" {node.value!r}", fg="blue", bold=False)
520
521     @classmethod
522     def show(cls, code: str) -> None:
523         """Pretty-print the lib2to3 AST of a given string of `code`.
524
525         Convenience method for debugging.
526         """
527         v: DebugVisitor[None] = DebugVisitor()
528         list(v.visit(lib2to3_parse(code)))
529
530
531 KEYWORDS = set(keyword.kwlist)
532 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
533 FLOW_CONTROL = {"return", "raise", "break", "continue"}
534 STATEMENT = {
535     syms.if_stmt,
536     syms.while_stmt,
537     syms.for_stmt,
538     syms.try_stmt,
539     syms.except_clause,
540     syms.with_stmt,
541     syms.funcdef,
542     syms.classdef,
543 }
544 STANDALONE_COMMENT = 153
545 LOGIC_OPERATORS = {"and", "or"}
546 COMPARATORS = {
547     token.LESS,
548     token.GREATER,
549     token.EQEQUAL,
550     token.NOTEQUAL,
551     token.LESSEQUAL,
552     token.GREATEREQUAL,
553 }
554 MATH_OPERATORS = {
555     token.PLUS,
556     token.MINUS,
557     token.STAR,
558     token.SLASH,
559     token.VBAR,
560     token.AMPER,
561     token.PERCENT,
562     token.CIRCUMFLEX,
563     token.TILDE,
564     token.LEFTSHIFT,
565     token.RIGHTSHIFT,
566     token.DOUBLESTAR,
567     token.DOUBLESLASH,
568 }
569 STARS = {token.STAR, token.DOUBLESTAR}
570 VARARGS_PARENTS = {
571     syms.arglist,
572     syms.argument,  # double star in arglist
573     syms.trailer,  # single argument to call
574     syms.typedargslist,
575     syms.varargslist,  # lambdas
576 }
577 UNPACKING_PARENTS = {
578     syms.atom,  # single element of a list or set literal
579     syms.dictsetmaker,
580     syms.listmaker,
581     syms.testlist_gexp,
582 }
583 TEST_DESCENDANTS = {
584     syms.test,
585     syms.lambdef,
586     syms.or_test,
587     syms.and_test,
588     syms.not_test,
589     syms.comparison,
590     syms.star_expr,
591     syms.expr,
592     syms.xor_expr,
593     syms.and_expr,
594     syms.shift_expr,
595     syms.arith_expr,
596     syms.trailer,
597     syms.term,
598     syms.power,
599 }
600 COMPREHENSION_PRIORITY = 20
601 COMMA_PRIORITY = 10
602 TERNARY_PRIORITY = 7
603 LOGIC_PRIORITY = 5
604 STRING_PRIORITY = 4
605 COMPARATOR_PRIORITY = 3
606 MATH_PRIORITY = 1
607
608
609 @dataclass
610 class BracketTracker:
611     """Keeps track of brackets on a line."""
612
613     depth: int = 0
614     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
615     delimiters: Dict[LeafID, Priority] = Factory(dict)
616     previous: Optional[Leaf] = None
617     _for_loop_variable: bool = False
618     _lambda_arguments: bool = False
619
620     def mark(self, leaf: Leaf) -> None:
621         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
622
623         All leaves receive an int `bracket_depth` field that stores how deep
624         within brackets a given leaf is. 0 means there are no enclosing brackets
625         that started on this line.
626
627         If a leaf is itself a closing bracket, it receives an `opening_bracket`
628         field that it forms a pair with. This is a one-directional link to
629         avoid reference cycles.
630
631         If a leaf is a delimiter (a token on which Black can split the line if
632         needed) and it's on depth 0, its `id()` is stored in the tracker's
633         `delimiters` field.
634         """
635         if leaf.type == token.COMMENT:
636             return
637
638         self.maybe_decrement_after_for_loop_variable(leaf)
639         self.maybe_decrement_after_lambda_arguments(leaf)
640         if leaf.type in CLOSING_BRACKETS:
641             self.depth -= 1
642             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
643             leaf.opening_bracket = opening_bracket
644         leaf.bracket_depth = self.depth
645         if self.depth == 0:
646             delim = is_split_before_delimiter(leaf, self.previous)
647             if delim and self.previous is not None:
648                 self.delimiters[id(self.previous)] = delim
649             else:
650                 delim = is_split_after_delimiter(leaf, self.previous)
651                 if delim:
652                     self.delimiters[id(leaf)] = delim
653         if leaf.type in OPENING_BRACKETS:
654             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
655             self.depth += 1
656         self.previous = leaf
657         self.maybe_increment_lambda_arguments(leaf)
658         self.maybe_increment_for_loop_variable(leaf)
659
660     def any_open_brackets(self) -> bool:
661         """Return True if there is an yet unmatched open bracket on the line."""
662         return bool(self.bracket_match)
663
664     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
665         """Return the highest priority of a delimiter found on the line.
666
667         Values are consistent with what `is_split_*_delimiter()` return.
668         Raises ValueError on no delimiters.
669         """
670         return max(v for k, v in self.delimiters.items() if k not in exclude)
671
672     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
673         """In a for loop, or comprehension, the variables are often unpacks.
674
675         To avoid splitting on the comma in this situation, increase the depth of
676         tokens between `for` and `in`.
677         """
678         if leaf.type == token.NAME and leaf.value == "for":
679             self.depth += 1
680             self._for_loop_variable = True
681             return True
682
683         return False
684
685     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
686         """See `maybe_increment_for_loop_variable` above for explanation."""
687         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
688             self.depth -= 1
689             self._for_loop_variable = False
690             return True
691
692         return False
693
694     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
695         """In a lambda expression, there might be more than one argument.
696
697         To avoid splitting on the comma in this situation, increase the depth of
698         tokens between `lambda` and `:`.
699         """
700         if leaf.type == token.NAME and leaf.value == "lambda":
701             self.depth += 1
702             self._lambda_arguments = True
703             return True
704
705         return False
706
707     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
708         """See `maybe_increment_lambda_arguments` above for explanation."""
709         if self._lambda_arguments and leaf.type == token.COLON:
710             self.depth -= 1
711             self._lambda_arguments = False
712             return True
713
714         return False
715
716     def get_open_lsqb(self) -> Optional[Leaf]:
717         """Return the most recent opening square bracket (if any)."""
718         return self.bracket_match.get((self.depth - 1, token.RSQB))
719
720
721 @dataclass
722 class Line:
723     """Holds leaves and comments. Can be printed with `str(line)`."""
724
725     depth: int = 0
726     leaves: List[Leaf] = Factory(list)
727     comments: List[Tuple[Index, Leaf]] = Factory(list)
728     bracket_tracker: BracketTracker = Factory(BracketTracker)
729     inside_brackets: bool = False
730
731     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
732         """Add a new `leaf` to the end of the line.
733
734         Unless `preformatted` is True, the `leaf` will receive a new consistent
735         whitespace prefix and metadata applied by :class:`BracketTracker`.
736         Trailing commas are maybe removed, unpacked for loop variables are
737         demoted from being delimiters.
738
739         Inline comments are put aside.
740         """
741         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
742         if not has_value:
743             return
744
745         if self.leaves and not preformatted:
746             # Note: at this point leaf.prefix should be empty except for
747             # imports, for which we only preserve newlines.
748             leaf.prefix += whitespace(
749                 leaf, complex_subscript=self.is_complex_subscript(leaf)
750             )
751         if self.inside_brackets or not preformatted:
752             self.bracket_tracker.mark(leaf)
753             self.maybe_remove_trailing_comma(leaf)
754
755         if not self.append_comment(leaf):
756             self.leaves.append(leaf)
757
758     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
759         """Like :func:`append()` but disallow invalid standalone comment structure.
760
761         Raises ValueError when any `leaf` is appended after a standalone comment
762         or when a standalone comment is not the first leaf on the line.
763         """
764         if self.bracket_tracker.depth == 0:
765             if self.is_comment:
766                 raise ValueError("cannot append to standalone comments")
767
768             if self.leaves and leaf.type == STANDALONE_COMMENT:
769                 raise ValueError(
770                     "cannot append standalone comments to a populated line"
771                 )
772
773         self.append(leaf, preformatted=preformatted)
774
775     @property
776     def is_comment(self) -> bool:
777         """Is this line a standalone comment?"""
778         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
779
780     @property
781     def is_decorator(self) -> bool:
782         """Is this line a decorator?"""
783         return bool(self) and self.leaves[0].type == token.AT
784
785     @property
786     def is_import(self) -> bool:
787         """Is this an import line?"""
788         return bool(self) and is_import(self.leaves[0])
789
790     @property
791     def is_class(self) -> bool:
792         """Is this line a class definition?"""
793         return (
794             bool(self)
795             and self.leaves[0].type == token.NAME
796             and self.leaves[0].value == "class"
797         )
798
799     @property
800     def is_def(self) -> bool:
801         """Is this a function definition? (Also returns True for async defs.)"""
802         try:
803             first_leaf = self.leaves[0]
804         except IndexError:
805             return False
806
807         try:
808             second_leaf: Optional[Leaf] = self.leaves[1]
809         except IndexError:
810             second_leaf = None
811         return (
812             (first_leaf.type == token.NAME and first_leaf.value == "def")
813             or (
814                 first_leaf.type == token.ASYNC
815                 and second_leaf is not None
816                 and second_leaf.type == token.NAME
817                 and second_leaf.value == "def"
818             )
819         )
820
821     @property
822     def is_flow_control(self) -> bool:
823         """Is this line a flow control statement?
824
825         Those are `return`, `raise`, `break`, and `continue`.
826         """
827         return (
828             bool(self)
829             and self.leaves[0].type == token.NAME
830             and self.leaves[0].value in FLOW_CONTROL
831         )
832
833     @property
834     def is_yield(self) -> bool:
835         """Is this line a yield statement?"""
836         return (
837             bool(self)
838             and self.leaves[0].type == token.NAME
839             and self.leaves[0].value == "yield"
840         )
841
842     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
843         """If so, needs to be split before emitting."""
844         for leaf in self.leaves:
845             if leaf.type == STANDALONE_COMMENT:
846                 if leaf.bracket_depth <= depth_limit:
847                     return True
848
849         return False
850
851     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
852         """Remove trailing comma if there is one and it's safe."""
853         if not (
854             self.leaves
855             and self.leaves[-1].type == token.COMMA
856             and closing.type in CLOSING_BRACKETS
857         ):
858             return False
859
860         if closing.type == token.RBRACE:
861             self.remove_trailing_comma()
862             return True
863
864         if closing.type == token.RSQB:
865             comma = self.leaves[-1]
866             if comma.parent and comma.parent.type == syms.listmaker:
867                 self.remove_trailing_comma()
868                 return True
869
870         # For parens let's check if it's safe to remove the comma.  If the
871         # trailing one is the only one, we might mistakenly change a tuple
872         # into a different type by removing the comma.
873         depth = closing.bracket_depth + 1
874         commas = 0
875         opening = closing.opening_bracket
876         for _opening_index, leaf in enumerate(self.leaves):
877             if leaf is opening:
878                 break
879
880         else:
881             return False
882
883         for leaf in self.leaves[_opening_index + 1 :]:
884             if leaf is closing:
885                 break
886
887             bracket_depth = leaf.bracket_depth
888             if bracket_depth == depth and leaf.type == token.COMMA:
889                 commas += 1
890                 if leaf.parent and leaf.parent.type == syms.arglist:
891                     commas += 1
892                     break
893
894         if commas > 1:
895             self.remove_trailing_comma()
896             return True
897
898         return False
899
900     def append_comment(self, comment: Leaf) -> bool:
901         """Add an inline or standalone comment to the line."""
902         if (
903             comment.type == STANDALONE_COMMENT
904             and self.bracket_tracker.any_open_brackets()
905         ):
906             comment.prefix = ""
907             return False
908
909         if comment.type != token.COMMENT:
910             return False
911
912         after = len(self.leaves) - 1
913         if after == -1:
914             comment.type = STANDALONE_COMMENT
915             comment.prefix = ""
916             return False
917
918         else:
919             self.comments.append((after, comment))
920             return True
921
922     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
923         """Generate comments that should appear directly after `leaf`."""
924         for _leaf_index, _leaf in enumerate(self.leaves):
925             if leaf is _leaf:
926                 break
927
928         else:
929             return
930
931         for index, comment_after in self.comments:
932             if _leaf_index == index:
933                 yield comment_after
934
935     def remove_trailing_comma(self) -> None:
936         """Remove the trailing comma and moves the comments attached to it."""
937         comma_index = len(self.leaves) - 1
938         for i in range(len(self.comments)):
939             comment_index, comment = self.comments[i]
940             if comment_index == comma_index:
941                 self.comments[i] = (comma_index - 1, comment)
942         self.leaves.pop()
943
944     def is_complex_subscript(self, leaf: Leaf) -> bool:
945         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
946         open_lsqb = (
947             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
948         )
949         if open_lsqb is None:
950             return False
951
952         subscript_start = open_lsqb.next_sibling
953         if (
954             isinstance(subscript_start, Node)
955             and subscript_start.type == syms.subscriptlist
956         ):
957             subscript_start = child_towards(subscript_start, leaf)
958         return subscript_start is not None and any(
959             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
960         )
961
962     def __str__(self) -> str:
963         """Render the line."""
964         if not self:
965             return "\n"
966
967         indent = "    " * self.depth
968         leaves = iter(self.leaves)
969         first = next(leaves)
970         res = f"{first.prefix}{indent}{first.value}"
971         for leaf in leaves:
972             res += str(leaf)
973         for _, comment in self.comments:
974             res += str(comment)
975         return res + "\n"
976
977     def __bool__(self) -> bool:
978         """Return True if the line has leaves or comments."""
979         return bool(self.leaves or self.comments)
980
981
982 class UnformattedLines(Line):
983     """Just like :class:`Line` but stores lines which aren't reformatted."""
984
985     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
986         """Just add a new `leaf` to the end of the lines.
987
988         The `preformatted` argument is ignored.
989
990         Keeps track of indentation `depth`, which is useful when the user
991         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
992         """
993         try:
994             list(generate_comments(leaf))
995         except FormatOn as f_on:
996             self.leaves.append(f_on.leaf_from_consumed(leaf))
997             raise
998
999         self.leaves.append(leaf)
1000         if leaf.type == token.INDENT:
1001             self.depth += 1
1002         elif leaf.type == token.DEDENT:
1003             self.depth -= 1
1004
1005     def __str__(self) -> str:
1006         """Render unformatted lines from leaves which were added with `append()`.
1007
1008         `depth` is not used for indentation in this case.
1009         """
1010         if not self:
1011             return "\n"
1012
1013         res = ""
1014         for leaf in self.leaves:
1015             res += str(leaf)
1016         return res
1017
1018     def append_comment(self, comment: Leaf) -> bool:
1019         """Not implemented in this class. Raises `NotImplementedError`."""
1020         raise NotImplementedError("Unformatted lines don't store comments separately.")
1021
1022     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1023         """Does nothing and returns False."""
1024         return False
1025
1026     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1027         """Does nothing and returns False."""
1028         return False
1029
1030
1031 @dataclass
1032 class EmptyLineTracker:
1033     """Provides a stateful method that returns the number of potential extra
1034     empty lines needed before and after the currently processed line.
1035
1036     Note: this tracker works on lines that haven't been split yet.  It assumes
1037     the prefix of the first leaf consists of optional newlines.  Those newlines
1038     are consumed by `maybe_empty_lines()` and included in the computation.
1039     """
1040     previous_line: Optional[Line] = None
1041     previous_after: int = 0
1042     previous_defs: List[int] = Factory(list)
1043
1044     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1045         """Return the number of extra empty lines before and after the `current_line`.
1046
1047         This is for separating `def`, `async def` and `class` with extra empty
1048         lines (two on module-level), as well as providing an extra empty line
1049         after flow control keywords to make them more prominent.
1050         """
1051         if isinstance(current_line, UnformattedLines):
1052             return 0, 0
1053
1054         before, after = self._maybe_empty_lines(current_line)
1055         before -= self.previous_after
1056         self.previous_after = after
1057         self.previous_line = current_line
1058         return before, after
1059
1060     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1061         max_allowed = 1
1062         if current_line.depth == 0:
1063             max_allowed = 2
1064         if current_line.leaves:
1065             # Consume the first leaf's extra newlines.
1066             first_leaf = current_line.leaves[0]
1067             before = first_leaf.prefix.count("\n")
1068             before = min(before, max_allowed)
1069             first_leaf.prefix = ""
1070         else:
1071             before = 0
1072         depth = current_line.depth
1073         while self.previous_defs and self.previous_defs[-1] >= depth:
1074             self.previous_defs.pop()
1075             before = 1 if depth else 2
1076         is_decorator = current_line.is_decorator
1077         if is_decorator or current_line.is_def or current_line.is_class:
1078             if not is_decorator:
1079                 self.previous_defs.append(depth)
1080             if self.previous_line is None:
1081                 # Don't insert empty lines before the first line in the file.
1082                 return 0, 0
1083
1084             if self.previous_line.is_decorator:
1085                 return 0, 0
1086
1087             if (
1088                 self.previous_line.is_comment
1089                 and self.previous_line.depth == current_line.depth
1090                 and before == 0
1091             ):
1092                 return 0, 0
1093
1094             newlines = 2
1095             if current_line.depth:
1096                 newlines -= 1
1097             return newlines, 0
1098
1099         if (
1100             self.previous_line
1101             and self.previous_line.is_import
1102             and not current_line.is_import
1103             and depth == self.previous_line.depth
1104         ):
1105             return (before or 1), 0
1106
1107         return before, 0
1108
1109
1110 @dataclass
1111 class LineGenerator(Visitor[Line]):
1112     """Generates reformatted Line objects.  Empty lines are not emitted.
1113
1114     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1115     in ways that will no longer stringify to valid Python code on the tree.
1116     """
1117     current_line: Line = Factory(Line)
1118
1119     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1120         """Generate a line.
1121
1122         If the line is empty, only emit if it makes sense.
1123         If the line is too long, split it first and then generate.
1124
1125         If any lines were generated, set up a new current_line.
1126         """
1127         if not self.current_line:
1128             if self.current_line.__class__ == type:
1129                 self.current_line.depth += indent
1130             else:
1131                 self.current_line = type(depth=self.current_line.depth + indent)
1132             return  # Line is empty, don't emit. Creating a new one unnecessary.
1133
1134         complete_line = self.current_line
1135         self.current_line = type(depth=complete_line.depth + indent)
1136         yield complete_line
1137
1138     def visit(self, node: LN) -> Iterator[Line]:
1139         """Main method to visit `node` and its children.
1140
1141         Yields :class:`Line` objects.
1142         """
1143         if isinstance(self.current_line, UnformattedLines):
1144             # File contained `# fmt: off`
1145             yield from self.visit_unformatted(node)
1146
1147         else:
1148             yield from super().visit(node)
1149
1150     def visit_default(self, node: LN) -> Iterator[Line]:
1151         """Default `visit_*()` implementation. Recurses to children of `node`."""
1152         if isinstance(node, Leaf):
1153             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1154             try:
1155                 for comment in generate_comments(node):
1156                     if any_open_brackets:
1157                         # any comment within brackets is subject to splitting
1158                         self.current_line.append(comment)
1159                     elif comment.type == token.COMMENT:
1160                         # regular trailing comment
1161                         self.current_line.append(comment)
1162                         yield from self.line()
1163
1164                     else:
1165                         # regular standalone comment
1166                         yield from self.line()
1167
1168                         self.current_line.append(comment)
1169                         yield from self.line()
1170
1171             except FormatOff as f_off:
1172                 f_off.trim_prefix(node)
1173                 yield from self.line(type=UnformattedLines)
1174                 yield from self.visit(node)
1175
1176             except FormatOn as f_on:
1177                 # This only happens here if somebody says "fmt: on" multiple
1178                 # times in a row.
1179                 f_on.trim_prefix(node)
1180                 yield from self.visit_default(node)
1181
1182             else:
1183                 normalize_prefix(node, inside_brackets=any_open_brackets)
1184                 if node.type == token.STRING:
1185                     normalize_string_quotes(node)
1186                 if node.type not in WHITESPACE:
1187                     self.current_line.append(node)
1188         yield from super().visit_default(node)
1189
1190     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1191         """Increase indentation level, maybe yield a line."""
1192         # In blib2to3 INDENT never holds comments.
1193         yield from self.line(+1)
1194         yield from self.visit_default(node)
1195
1196     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1197         """Decrease indentation level, maybe yield a line."""
1198         # The current line might still wait for trailing comments.  At DEDENT time
1199         # there won't be any (they would be prefixes on the preceding NEWLINE).
1200         # Emit the line then.
1201         yield from self.line()
1202
1203         # While DEDENT has no value, its prefix may contain standalone comments
1204         # that belong to the current indentation level.  Get 'em.
1205         yield from self.visit_default(node)
1206
1207         # Finally, emit the dedent.
1208         yield from self.line(-1)
1209
1210     def visit_stmt(
1211         self, node: Node, keywords: Set[str], parens: Set[str]
1212     ) -> Iterator[Line]:
1213         """Visit a statement.
1214
1215         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1216         `def`, `with`, `class`, and `assert`.
1217
1218         The relevant Python language `keywords` for a given statement will be
1219         NAME leaves within it. This methods puts those on a separate line.
1220
1221         `parens` holds pairs of nodes where invisible parentheses should be put.
1222         Keys hold nodes after which opening parentheses should be put, values
1223         hold nodes before which closing parentheses should be put.
1224         """
1225         normalize_invisible_parens(node, parens_after=parens)
1226         for child in node.children:
1227             if child.type == token.NAME and child.value in keywords:  # type: ignore
1228                 yield from self.line()
1229
1230             yield from self.visit(child)
1231
1232     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1233         """Visit a statement without nested statements."""
1234         is_suite_like = node.parent and node.parent.type in STATEMENT
1235         if is_suite_like:
1236             yield from self.line(+1)
1237             yield from self.visit_default(node)
1238             yield from self.line(-1)
1239
1240         else:
1241             yield from self.line()
1242             yield from self.visit_default(node)
1243
1244     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1245         """Visit `async def`, `async for`, `async with`."""
1246         yield from self.line()
1247
1248         children = iter(node.children)
1249         for child in children:
1250             yield from self.visit(child)
1251
1252             if child.type == token.ASYNC:
1253                 break
1254
1255         internal_stmt = next(children)
1256         for child in internal_stmt.children:
1257             yield from self.visit(child)
1258
1259     def visit_decorators(self, node: Node) -> Iterator[Line]:
1260         """Visit decorators."""
1261         for child in node.children:
1262             yield from self.line()
1263             yield from self.visit(child)
1264
1265     def visit_import_from(self, node: Node) -> Iterator[Line]:
1266         """Visit import_from and maybe put invisible parentheses.
1267
1268         This is separate from `visit_stmt` because import statements don't
1269         support arbitrary atoms and thus handling of parentheses is custom.
1270         """
1271         check_lpar = False
1272         for index, child in enumerate(node.children):
1273             if check_lpar:
1274                 if child.type == token.LPAR:
1275                     # make parentheses invisible
1276                     child.value = ""  # type: ignore
1277                     node.children[-1].value = ""  # type: ignore
1278                 else:
1279                     # insert invisible parentheses
1280                     node.insert_child(index, Leaf(token.LPAR, ""))
1281                     node.append_child(Leaf(token.RPAR, ""))
1282                 break
1283
1284             check_lpar = (
1285                 child.type == token.NAME and child.value == "import"  # type: ignore
1286             )
1287
1288         for child in node.children:
1289             yield from self.visit(child)
1290
1291     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1292         """Remove a semicolon and put the other statement on a separate line."""
1293         yield from self.line()
1294
1295     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1296         """End of file. Process outstanding comments and end with a newline."""
1297         yield from self.visit_default(leaf)
1298         yield from self.line()
1299
1300     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1301         """Used when file contained a `# fmt: off`."""
1302         if isinstance(node, Node):
1303             for child in node.children:
1304                 yield from self.visit(child)
1305
1306         else:
1307             try:
1308                 self.current_line.append(node)
1309             except FormatOn as f_on:
1310                 f_on.trim_prefix(node)
1311                 yield from self.line()
1312                 yield from self.visit(node)
1313
1314             if node.type == token.ENDMARKER:
1315                 # somebody decided not to put a final `# fmt: on`
1316                 yield from self.line()
1317
1318     def __attrs_post_init__(self) -> None:
1319         """You are in a twisty little maze of passages."""
1320         v = self.visit_stmt
1321         Ø: Set[str] = set()
1322         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1323         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1324         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1325         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1326         self.visit_try_stmt = partial(
1327             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1328         )
1329         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1330         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1331         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1332         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1333         self.visit_async_funcdef = self.visit_async_stmt
1334         self.visit_decorated = self.visit_decorators
1335
1336
1337 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1338 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1339 OPENING_BRACKETS = set(BRACKET.keys())
1340 CLOSING_BRACKETS = set(BRACKET.values())
1341 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1342 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1343
1344
1345 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1346     """Return whitespace prefix if needed for the given `leaf`.
1347
1348     `complex_subscript` signals whether the given leaf is part of a subscription
1349     which has non-trivial arguments, like arithmetic expressions or function calls.
1350     """
1351     NO = ""
1352     SPACE = " "
1353     DOUBLESPACE = "  "
1354     t = leaf.type
1355     p = leaf.parent
1356     v = leaf.value
1357     if t in ALWAYS_NO_SPACE:
1358         return NO
1359
1360     if t == token.COMMENT:
1361         return DOUBLESPACE
1362
1363     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1364     if (
1365         t == token.COLON
1366         and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop}
1367     ):
1368         return NO
1369
1370     prev = leaf.prev_sibling
1371     if not prev:
1372         prevp = preceding_leaf(p)
1373         if not prevp or prevp.type in OPENING_BRACKETS:
1374             return NO
1375
1376         if t == token.COLON:
1377             if prevp.type == token.COLON:
1378                 return NO
1379
1380             elif prevp.type != token.COMMA and not complex_subscript:
1381                 return NO
1382
1383             return SPACE
1384
1385         if prevp.type == token.EQUAL:
1386             if prevp.parent:
1387                 if prevp.parent.type in {
1388                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1389                 }:
1390                     return NO
1391
1392                 elif prevp.parent.type == syms.typedargslist:
1393                     # A bit hacky: if the equal sign has whitespace, it means we
1394                     # previously found it's a typed argument.  So, we're using
1395                     # that, too.
1396                     return prevp.prefix
1397
1398         elif prevp.type in STARS:
1399             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1400                 return NO
1401
1402         elif prevp.type == token.COLON:
1403             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1404                 return SPACE if complex_subscript else NO
1405
1406         elif (
1407             prevp.parent
1408             and prevp.parent.type == syms.factor
1409             and prevp.type in MATH_OPERATORS
1410         ):
1411             return NO
1412
1413         elif (
1414             prevp.type == token.RIGHTSHIFT
1415             and prevp.parent
1416             and prevp.parent.type == syms.shift_expr
1417             and prevp.prev_sibling
1418             and prevp.prev_sibling.type == token.NAME
1419             and prevp.prev_sibling.value == "print"  # type: ignore
1420         ):
1421             # Python 2 print chevron
1422             return NO
1423
1424     elif prev.type in OPENING_BRACKETS:
1425         return NO
1426
1427     if p.type in {syms.parameters, syms.arglist}:
1428         # untyped function signatures or calls
1429         if not prev or prev.type != token.COMMA:
1430             return NO
1431
1432     elif p.type == syms.varargslist:
1433         # lambdas
1434         if prev and prev.type != token.COMMA:
1435             return NO
1436
1437     elif p.type == syms.typedargslist:
1438         # typed function signatures
1439         if not prev:
1440             return NO
1441
1442         if t == token.EQUAL:
1443             if prev.type != syms.tname:
1444                 return NO
1445
1446         elif prev.type == token.EQUAL:
1447             # A bit hacky: if the equal sign has whitespace, it means we
1448             # previously found it's a typed argument.  So, we're using that, too.
1449             return prev.prefix
1450
1451         elif prev.type != token.COMMA:
1452             return NO
1453
1454     elif p.type == syms.tname:
1455         # type names
1456         if not prev:
1457             prevp = preceding_leaf(p)
1458             if not prevp or prevp.type != token.COMMA:
1459                 return NO
1460
1461     elif p.type == syms.trailer:
1462         # attributes and calls
1463         if t == token.LPAR or t == token.RPAR:
1464             return NO
1465
1466         if not prev:
1467             if t == token.DOT:
1468                 prevp = preceding_leaf(p)
1469                 if not prevp or prevp.type != token.NUMBER:
1470                     return NO
1471
1472             elif t == token.LSQB:
1473                 return NO
1474
1475         elif prev.type != token.COMMA:
1476             return NO
1477
1478     elif p.type == syms.argument:
1479         # single argument
1480         if t == token.EQUAL:
1481             return NO
1482
1483         if not prev:
1484             prevp = preceding_leaf(p)
1485             if not prevp or prevp.type == token.LPAR:
1486                 return NO
1487
1488         elif prev.type in {token.EQUAL} | STARS:
1489             return NO
1490
1491     elif p.type == syms.decorator:
1492         # decorators
1493         return NO
1494
1495     elif p.type == syms.dotted_name:
1496         if prev:
1497             return NO
1498
1499         prevp = preceding_leaf(p)
1500         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1501             return NO
1502
1503     elif p.type == syms.classdef:
1504         if t == token.LPAR:
1505             return NO
1506
1507         if prev and prev.type == token.LPAR:
1508             return NO
1509
1510     elif p.type in {syms.subscript, syms.sliceop}:
1511         # indexing
1512         if not prev:
1513             assert p.parent is not None, "subscripts are always parented"
1514             if p.parent.type == syms.subscriptlist:
1515                 return SPACE
1516
1517             return NO
1518
1519         elif not complex_subscript:
1520             return NO
1521
1522     elif p.type == syms.atom:
1523         if prev and t == token.DOT:
1524             # dots, but not the first one.
1525             return NO
1526
1527     elif p.type == syms.dictsetmaker:
1528         # dict unpacking
1529         if prev and prev.type == token.DOUBLESTAR:
1530             return NO
1531
1532     elif p.type in {syms.factor, syms.star_expr}:
1533         # unary ops
1534         if not prev:
1535             prevp = preceding_leaf(p)
1536             if not prevp or prevp.type in OPENING_BRACKETS:
1537                 return NO
1538
1539             prevp_parent = prevp.parent
1540             assert prevp_parent is not None
1541             if (
1542                 prevp.type == token.COLON
1543                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1544             ):
1545                 return NO
1546
1547             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1548                 return NO
1549
1550         elif t == token.NAME or t == token.NUMBER:
1551             return NO
1552
1553     elif p.type == syms.import_from:
1554         if t == token.DOT:
1555             if prev and prev.type == token.DOT:
1556                 return NO
1557
1558         elif t == token.NAME:
1559             if v == "import":
1560                 return SPACE
1561
1562             if prev and prev.type == token.DOT:
1563                 return NO
1564
1565     elif p.type == syms.sliceop:
1566         return NO
1567
1568     return SPACE
1569
1570
1571 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1572     """Return the first leaf that precedes `node`, if any."""
1573     while node:
1574         res = node.prev_sibling
1575         if res:
1576             if isinstance(res, Leaf):
1577                 return res
1578
1579             try:
1580                 return list(res.leaves())[-1]
1581
1582             except IndexError:
1583                 return None
1584
1585         node = node.parent
1586     return None
1587
1588
1589 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1590     """Return the child of `ancestor` that contains `descendant`."""
1591     node: Optional[LN] = descendant
1592     while node and node.parent != ancestor:
1593         node = node.parent
1594     return node
1595
1596
1597 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1598     """Return the priority of the `leaf` delimiter, given a line break after it.
1599
1600     The delimiter priorities returned here are from those delimiters that would
1601     cause a line break after themselves.
1602
1603     Higher numbers are higher priority.
1604     """
1605     if leaf.type == token.COMMA:
1606         return COMMA_PRIORITY
1607
1608     return 0
1609
1610
1611 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1612     """Return the priority of the `leaf` delimiter, given a line before after it.
1613
1614     The delimiter priorities returned here are from those delimiters that would
1615     cause a line break before themselves.
1616
1617     Higher numbers are higher priority.
1618     """
1619     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1620         # * and ** might also be MATH_OPERATORS but in this case they are not.
1621         # Don't treat them as a delimiter.
1622         return 0
1623
1624     if (
1625         leaf.type in MATH_OPERATORS
1626         and leaf.parent
1627         and leaf.parent.type not in {syms.factor, syms.star_expr}
1628     ):
1629         return MATH_PRIORITY
1630
1631     if leaf.type in COMPARATORS:
1632         return COMPARATOR_PRIORITY
1633
1634     if (
1635         leaf.type == token.STRING
1636         and previous is not None
1637         and previous.type == token.STRING
1638     ):
1639         return STRING_PRIORITY
1640
1641     if (
1642         leaf.type == token.NAME
1643         and leaf.value == "for"
1644         and leaf.parent
1645         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1646     ):
1647         return COMPREHENSION_PRIORITY
1648
1649     if (
1650         leaf.type == token.NAME
1651         and leaf.value == "if"
1652         and leaf.parent
1653         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1654     ):
1655         return COMPREHENSION_PRIORITY
1656
1657     if (
1658         leaf.type == token.NAME
1659         and leaf.value in {"if", "else"}
1660         and leaf.parent
1661         and leaf.parent.type == syms.test
1662     ):
1663         return TERNARY_PRIORITY
1664
1665     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1666         return LOGIC_PRIORITY
1667
1668     return 0
1669
1670
1671 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1672     """Clean the prefix of the `leaf` and generate comments from it, if any.
1673
1674     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1675     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1676     move because it does away with modifying the grammar to include all the
1677     possible places in which comments can be placed.
1678
1679     The sad consequence for us though is that comments don't "belong" anywhere.
1680     This is why this function generates simple parentless Leaf objects for
1681     comments.  We simply don't know what the correct parent should be.
1682
1683     No matter though, we can live without this.  We really only need to
1684     differentiate between inline and standalone comments.  The latter don't
1685     share the line with any code.
1686
1687     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1688     are emitted with a fake STANDALONE_COMMENT token identifier.
1689     """
1690     p = leaf.prefix
1691     if not p:
1692         return
1693
1694     if "#" not in p:
1695         return
1696
1697     consumed = 0
1698     nlines = 0
1699     for index, line in enumerate(p.split("\n")):
1700         consumed += len(line) + 1  # adding the length of the split '\n'
1701         line = line.lstrip()
1702         if not line:
1703             nlines += 1
1704         if not line.startswith("#"):
1705             continue
1706
1707         if index == 0 and leaf.type != token.ENDMARKER:
1708             comment_type = token.COMMENT  # simple trailing comment
1709         else:
1710             comment_type = STANDALONE_COMMENT
1711         comment = make_comment(line)
1712         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1713
1714         if comment in {"# fmt: on", "# yapf: enable"}:
1715             raise FormatOn(consumed)
1716
1717         if comment in {"# fmt: off", "# yapf: disable"}:
1718             if comment_type == STANDALONE_COMMENT:
1719                 raise FormatOff(consumed)
1720
1721             prev = preceding_leaf(leaf)
1722             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1723                 raise FormatOff(consumed)
1724
1725         nlines = 0
1726
1727
1728 def make_comment(content: str) -> str:
1729     """Return a consistently formatted comment from the given `content` string.
1730
1731     All comments (except for "##", "#!", "#:") should have a single space between
1732     the hash sign and the content.
1733
1734     If `content` didn't start with a hash sign, one is provided.
1735     """
1736     content = content.rstrip()
1737     if not content:
1738         return "#"
1739
1740     if content[0] == "#":
1741         content = content[1:]
1742     if content and content[0] not in " !:#":
1743         content = " " + content
1744     return "#" + content
1745
1746
1747 def split_line(
1748     line: Line, line_length: int, inner: bool = False, py36: bool = False
1749 ) -> Iterator[Line]:
1750     """Split a `line` into potentially many lines.
1751
1752     They should fit in the allotted `line_length` but might not be able to.
1753     `inner` signifies that there were a pair of brackets somewhere around the
1754     current `line`, possibly transitively. This means we can fallback to splitting
1755     by delimiters if the LHS/RHS don't yield any results.
1756
1757     If `py36` is True, splitting may generate syntax that is only compatible
1758     with Python 3.6 and later.
1759     """
1760     if isinstance(line, UnformattedLines) or line.is_comment:
1761         yield line
1762         return
1763
1764     line_str = str(line).strip("\n")
1765     if (
1766         len(line_str) <= line_length
1767         and "\n" not in line_str  # multiline strings
1768         and not line.contains_standalone_comments()
1769     ):
1770         yield line
1771         return
1772
1773     split_funcs: List[SplitFunc]
1774     if line.is_def:
1775         split_funcs = [left_hand_split]
1776     elif line.is_import:
1777         split_funcs = [explode_split]
1778     elif line.inside_brackets:
1779         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1780     else:
1781         split_funcs = [right_hand_split]
1782     for split_func in split_funcs:
1783         # We are accumulating lines in `result` because we might want to abort
1784         # mission and return the original line in the end, or attempt a different
1785         # split altogether.
1786         result: List[Line] = []
1787         try:
1788             for l in split_func(line, py36):
1789                 if str(l).strip("\n") == line_str:
1790                     raise CannotSplit("Split function returned an unchanged result")
1791
1792                 result.extend(
1793                     split_line(l, line_length=line_length, inner=True, py36=py36)
1794                 )
1795         except CannotSplit as cs:
1796             continue
1797
1798         else:
1799             yield from result
1800             break
1801
1802     else:
1803         yield line
1804
1805
1806 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1807     """Split line into many lines, starting with the first matching bracket pair.
1808
1809     Note: this usually looks weird, only use this for function definitions.
1810     Prefer RHS otherwise.
1811     """
1812     head = Line(depth=line.depth)
1813     body = Line(depth=line.depth + 1, inside_brackets=True)
1814     tail = Line(depth=line.depth)
1815     tail_leaves: List[Leaf] = []
1816     body_leaves: List[Leaf] = []
1817     head_leaves: List[Leaf] = []
1818     current_leaves = head_leaves
1819     matching_bracket = None
1820     for leaf in line.leaves:
1821         if (
1822             current_leaves is body_leaves
1823             and leaf.type in CLOSING_BRACKETS
1824             and leaf.opening_bracket is matching_bracket
1825         ):
1826             current_leaves = tail_leaves if body_leaves else head_leaves
1827         current_leaves.append(leaf)
1828         if current_leaves is head_leaves:
1829             if leaf.type in OPENING_BRACKETS:
1830                 matching_bracket = leaf
1831                 current_leaves = body_leaves
1832     # Since body is a new indent level, remove spurious leading whitespace.
1833     if body_leaves:
1834         normalize_prefix(body_leaves[0], inside_brackets=True)
1835     # Build the new lines.
1836     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1837         for leaf in leaves:
1838             result.append(leaf, preformatted=True)
1839             for comment_after in line.comments_after(leaf):
1840                 result.append(comment_after, preformatted=True)
1841     bracket_split_succeeded_or_raise(head, body, tail)
1842     for result in (head, body, tail):
1843         if result:
1844             yield result
1845
1846
1847 def right_hand_split(
1848     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1849 ) -> Iterator[Line]:
1850     """Split line into many lines, starting with the last matching bracket pair."""
1851     head = Line(depth=line.depth)
1852     body = Line(depth=line.depth + 1, inside_brackets=True)
1853     tail = Line(depth=line.depth)
1854     tail_leaves: List[Leaf] = []
1855     body_leaves: List[Leaf] = []
1856     head_leaves: List[Leaf] = []
1857     current_leaves = tail_leaves
1858     opening_bracket = None
1859     closing_bracket = None
1860     for leaf in reversed(line.leaves):
1861         if current_leaves is body_leaves:
1862             if leaf is opening_bracket:
1863                 current_leaves = head_leaves if body_leaves else tail_leaves
1864         current_leaves.append(leaf)
1865         if current_leaves is tail_leaves:
1866             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1867                 opening_bracket = leaf.opening_bracket
1868                 closing_bracket = leaf
1869                 current_leaves = body_leaves
1870     tail_leaves.reverse()
1871     body_leaves.reverse()
1872     head_leaves.reverse()
1873     # Since body is a new indent level, remove spurious leading whitespace.
1874     if body_leaves:
1875         normalize_prefix(body_leaves[0], inside_brackets=True)
1876     elif not head_leaves:
1877         # No `head` and no `body` means the split failed. `tail` has all content.
1878         raise CannotSplit("No brackets found")
1879
1880     # Build the new lines.
1881     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1882         for leaf in leaves:
1883             result.append(leaf, preformatted=True)
1884             for comment_after in line.comments_after(leaf):
1885                 result.append(comment_after, preformatted=True)
1886     bracket_split_succeeded_or_raise(head, body, tail)
1887     assert opening_bracket and closing_bracket
1888     if (
1889         opening_bracket.type == token.LPAR
1890         and not opening_bracket.value
1891         and closing_bracket.type == token.RPAR
1892         and not closing_bracket.value
1893     ):
1894         # These parens were optional. If there aren't any delimiters or standalone
1895         # comments in the body, they were unnecessary and another split without
1896         # them should be attempted.
1897         if not (
1898             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1899         ):
1900             omit = {id(closing_bracket), *omit}
1901             yield from right_hand_split(line, py36=py36, omit=omit)
1902             return
1903
1904     ensure_visible(opening_bracket)
1905     ensure_visible(closing_bracket)
1906     for result in (head, body, tail):
1907         if result:
1908             yield result
1909
1910
1911 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1912     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1913
1914     Do nothing otherwise.
1915
1916     A left- or right-hand split is based on a pair of brackets. Content before
1917     (and including) the opening bracket is left on one line, content inside the
1918     brackets is put on a separate line, and finally content starting with and
1919     following the closing bracket is put on a separate line.
1920
1921     Those are called `head`, `body`, and `tail`, respectively. If the split
1922     produced the same line (all content in `head`) or ended up with an empty `body`
1923     and the `tail` is just the closing bracket, then it's considered failed.
1924     """
1925     tail_len = len(str(tail).strip())
1926     if not body:
1927         if tail_len == 0:
1928             raise CannotSplit("Splitting brackets produced the same line")
1929
1930         elif tail_len < 3:
1931             raise CannotSplit(
1932                 f"Splitting brackets on an empty body to save "
1933                 f"{tail_len} characters is not worth it"
1934             )
1935
1936
1937 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1938     """Normalize prefix of the first leaf in every line returned by `split_func`.
1939
1940     This is a decorator over relevant split functions.
1941     """
1942
1943     @wraps(split_func)
1944     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1945         for l in split_func(line, py36):
1946             normalize_prefix(l.leaves[0], inside_brackets=True)
1947             yield l
1948
1949     return split_wrapper
1950
1951
1952 @dont_increase_indentation
1953 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1954     """Split according to delimiters of the highest priority.
1955
1956     If `py36` is True, the split will add trailing commas also in function
1957     signatures that contain `*` and `**`.
1958     """
1959     try:
1960         last_leaf = line.leaves[-1]
1961     except IndexError:
1962         raise CannotSplit("Line empty")
1963
1964     delimiters = line.bracket_tracker.delimiters
1965     try:
1966         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1967             exclude={id(last_leaf)}
1968         )
1969     except ValueError:
1970         raise CannotSplit("No delimiters found")
1971
1972     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1973     lowest_depth = sys.maxsize
1974     trailing_comma_safe = True
1975
1976     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1977         """Append `leaf` to current line or to new line if appending impossible."""
1978         nonlocal current_line
1979         try:
1980             current_line.append_safe(leaf, preformatted=True)
1981         except ValueError as ve:
1982             yield current_line
1983
1984             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1985             current_line.append(leaf)
1986
1987     for leaf in line.leaves:
1988         yield from append_to_line(leaf)
1989
1990         for comment_after in line.comments_after(leaf):
1991             yield from append_to_line(comment_after)
1992
1993         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1994         if (
1995             leaf.bracket_depth == lowest_depth
1996             and is_vararg(leaf, within=VARARGS_PARENTS)
1997         ):
1998             trailing_comma_safe = trailing_comma_safe and py36
1999         leaf_priority = delimiters.get(id(leaf))
2000         if leaf_priority == delimiter_priority:
2001             yield current_line
2002
2003             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2004     if current_line:
2005         if (
2006             trailing_comma_safe
2007             and delimiter_priority == COMMA_PRIORITY
2008             and current_line.leaves[-1].type != token.COMMA
2009             and current_line.leaves[-1].type != STANDALONE_COMMENT
2010         ):
2011             current_line.append(Leaf(token.COMMA, ","))
2012         yield current_line
2013
2014
2015 @dont_increase_indentation
2016 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2017     """Split standalone comments from the rest of the line."""
2018     if not line.contains_standalone_comments(0):
2019         raise CannotSplit("Line does not have any standalone comments")
2020
2021     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2022
2023     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2024         """Append `leaf` to current line or to new line if appending impossible."""
2025         nonlocal current_line
2026         try:
2027             current_line.append_safe(leaf, preformatted=True)
2028         except ValueError as ve:
2029             yield current_line
2030
2031             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2032             current_line.append(leaf)
2033
2034     for leaf in line.leaves:
2035         yield from append_to_line(leaf)
2036
2037         for comment_after in line.comments_after(leaf):
2038             yield from append_to_line(comment_after)
2039
2040     if current_line:
2041         yield current_line
2042
2043
2044 def explode_split(
2045     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
2046 ) -> Iterator[Line]:
2047     """Split by rightmost bracket and immediately split contents by a delimiter."""
2048     new_lines = list(right_hand_split(line, py36, omit))
2049     if len(new_lines) != 3:
2050         yield from new_lines
2051         return
2052
2053     yield new_lines[0]
2054
2055     try:
2056         yield from delimiter_split(new_lines[1], py36)
2057
2058     except CannotSplit:
2059         yield new_lines[1]
2060
2061     yield new_lines[2]
2062
2063
2064 def is_import(leaf: Leaf) -> bool:
2065     """Return True if the given leaf starts an import statement."""
2066     p = leaf.parent
2067     t = leaf.type
2068     v = leaf.value
2069     return bool(
2070         t == token.NAME
2071         and (
2072             (v == "import" and p and p.type == syms.import_name)
2073             or (v == "from" and p and p.type == syms.import_from)
2074         )
2075     )
2076
2077
2078 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2079     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2080     else.
2081
2082     Note: don't use backslashes for formatting or you'll lose your voting rights.
2083     """
2084     if not inside_brackets:
2085         spl = leaf.prefix.split("#")
2086         if "\\" not in spl[0]:
2087             nl_count = spl[-1].count("\n")
2088             if len(spl) > 1:
2089                 nl_count -= 1
2090             leaf.prefix = "\n" * nl_count
2091             return
2092
2093     leaf.prefix = ""
2094
2095
2096 def normalize_string_quotes(leaf: Leaf) -> None:
2097     """Prefer double quotes but only if it doesn't cause more escaping.
2098
2099     Adds or removes backslashes as appropriate. Doesn't parse and fix
2100     strings nested in f-strings (yet).
2101
2102     Note: Mutates its argument.
2103     """
2104     value = leaf.value.lstrip("furbFURB")
2105     if value[:3] == '"""':
2106         return
2107
2108     elif value[:3] == "'''":
2109         orig_quote = "'''"
2110         new_quote = '"""'
2111     elif value[0] == '"':
2112         orig_quote = '"'
2113         new_quote = "'"
2114     else:
2115         orig_quote = "'"
2116         new_quote = '"'
2117     first_quote_pos = leaf.value.find(orig_quote)
2118     if first_quote_pos == -1:
2119         return  # There's an internal error
2120
2121     prefix = leaf.value[:first_quote_pos]
2122     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2123     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2124     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2125     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2126     if "r" in prefix.casefold():
2127         if unescaped_new_quote.search(body):
2128             # There's at least one unescaped new_quote in this raw string
2129             # so converting is impossible
2130             return
2131
2132         # Do not introduce or remove backslashes in raw strings
2133         new_body = body
2134     else:
2135         # remove unnecessary quotes
2136         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2137         if body != new_body:
2138             # Consider the string without unnecessary quotes as the original
2139             body = new_body
2140             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2141         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2142         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2143     if new_quote == '"""' and new_body[-1] == '"':
2144         # edge case:
2145         new_body = new_body[:-1] + '\\"'
2146     orig_escape_count = body.count("\\")
2147     new_escape_count = new_body.count("\\")
2148     if new_escape_count > orig_escape_count:
2149         return  # Do not introduce more escaping
2150
2151     if new_escape_count == orig_escape_count and orig_quote == '"':
2152         return  # Prefer double quotes
2153
2154     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2155
2156
2157 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2158     """Make existing optional parentheses invisible or create new ones.
2159
2160     Standardizes on visible parentheses for single-element tuples, and keeps
2161     existing visible parentheses for other tuples and generator expressions.
2162     """
2163     check_lpar = False
2164     for child in list(node.children):
2165         if check_lpar:
2166             if child.type == syms.atom:
2167                 if not (
2168                     is_empty_tuple(child)
2169                     or is_one_tuple(child)
2170                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2171                 ):
2172                     first = child.children[0]
2173                     last = child.children[-1]
2174                     if first.type == token.LPAR and last.type == token.RPAR:
2175                         # make parentheses invisible
2176                         first.value = ""  # type: ignore
2177                         last.value = ""  # type: ignore
2178             elif is_one_tuple(child):
2179                 # wrap child in visible parentheses
2180                 lpar = Leaf(token.LPAR, "(")
2181                 rpar = Leaf(token.RPAR, ")")
2182                 index = child.remove() or 0
2183                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2184             else:
2185                 # wrap child in invisible parentheses
2186                 lpar = Leaf(token.LPAR, "")
2187                 rpar = Leaf(token.RPAR, "")
2188                 index = child.remove() or 0
2189                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2190
2191         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2192
2193
2194 def is_empty_tuple(node: LN) -> bool:
2195     """Return True if `node` holds an empty tuple."""
2196     return (
2197         node.type == syms.atom
2198         and len(node.children) == 2
2199         and node.children[0].type == token.LPAR
2200         and node.children[1].type == token.RPAR
2201     )
2202
2203
2204 def is_one_tuple(node: LN) -> bool:
2205     """Return True if `node` holds a tuple with one element, with or without parens."""
2206     if node.type == syms.atom:
2207         if len(node.children) != 3:
2208             return False
2209
2210         lpar, gexp, rpar = node.children
2211         if not (
2212             lpar.type == token.LPAR
2213             and gexp.type == syms.testlist_gexp
2214             and rpar.type == token.RPAR
2215         ):
2216             return False
2217
2218         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2219
2220     return (
2221         node.type in IMPLICIT_TUPLE
2222         and len(node.children) == 2
2223         and node.children[1].type == token.COMMA
2224     )
2225
2226
2227 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2228     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2229
2230     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2231     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2232     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2233     generalizations (PEP 448).
2234     """
2235     if leaf.type not in STARS or not leaf.parent:
2236         return False
2237
2238     p = leaf.parent
2239     if p.type == syms.star_expr:
2240         # Star expressions are also used as assignment targets in extended
2241         # iterable unpacking (PEP 3132).  See what its parent is instead.
2242         if not p.parent:
2243             return False
2244
2245         p = p.parent
2246
2247     return p.type in within
2248
2249
2250 def max_delimiter_priority_in_atom(node: LN) -> int:
2251     """Return maximum delimiter priority inside `node`.
2252
2253     This is specific to atoms with contents contained in a pair of parentheses.
2254     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2255     """
2256     if node.type != syms.atom:
2257         return 0
2258
2259     first = node.children[0]
2260     last = node.children[-1]
2261     if not (first.type == token.LPAR and last.type == token.RPAR):
2262         return 0
2263
2264     bt = BracketTracker()
2265     for c in node.children[1:-1]:
2266         if isinstance(c, Leaf):
2267             bt.mark(c)
2268         else:
2269             for leaf in c.leaves():
2270                 bt.mark(leaf)
2271     try:
2272         return bt.max_delimiter_priority()
2273
2274     except ValueError:
2275         return 0
2276
2277
2278 def ensure_visible(leaf: Leaf) -> None:
2279     """Make sure parentheses are visible.
2280
2281     They could be invisible as part of some statements (see
2282     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2283     """
2284     if leaf.type == token.LPAR:
2285         leaf.value = "("
2286     elif leaf.type == token.RPAR:
2287         leaf.value = ")"
2288
2289
2290 def is_python36(node: Node) -> bool:
2291     """Return True if the current file is using Python 3.6+ features.
2292
2293     Currently looking for:
2294     - f-strings; and
2295     - trailing commas after * or ** in function signatures.
2296     """
2297     for n in node.pre_order():
2298         if n.type == token.STRING:
2299             value_head = n.value[:2]  # type: ignore
2300             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2301                 return True
2302
2303         elif (
2304             n.type == syms.typedargslist
2305             and n.children
2306             and n.children[-1].type == token.COMMA
2307         ):
2308             for ch in n.children:
2309                 if ch.type in STARS:
2310                     return True
2311
2312     return False
2313
2314
2315 PYTHON_EXTENSIONS = {".py"}
2316 BLACKLISTED_DIRECTORIES = {
2317     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2318 }
2319
2320
2321 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2322     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2323     and have one of the PYTHON_EXTENSIONS.
2324     """
2325     for child in path.iterdir():
2326         if child.is_dir():
2327             if child.name in BLACKLISTED_DIRECTORIES:
2328                 continue
2329
2330             yield from gen_python_files_in_dir(child)
2331
2332         elif child.suffix in PYTHON_EXTENSIONS:
2333             yield child
2334
2335
2336 @dataclass
2337 class Report:
2338     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2339     check: bool = False
2340     quiet: bool = False
2341     change_count: int = 0
2342     same_count: int = 0
2343     failure_count: int = 0
2344
2345     def done(self, src: Path, changed: Changed) -> None:
2346         """Increment the counter for successful reformatting. Write out a message."""
2347         if changed is Changed.YES:
2348             reformatted = "would reformat" if self.check else "reformatted"
2349             if not self.quiet:
2350                 out(f"{reformatted} {src}")
2351             self.change_count += 1
2352         else:
2353             if not self.quiet:
2354                 if changed is Changed.NO:
2355                     msg = f"{src} already well formatted, good job."
2356                 else:
2357                     msg = f"{src} wasn't modified on disk since last run."
2358                 out(msg, bold=False)
2359             self.same_count += 1
2360
2361     def failed(self, src: Path, message: str) -> None:
2362         """Increment the counter for failed reformatting. Write out a message."""
2363         err(f"error: cannot format {src}: {message}")
2364         self.failure_count += 1
2365
2366     @property
2367     def return_code(self) -> int:
2368         """Return the exit code that the app should use.
2369
2370         This considers the current state of changed files and failures:
2371         - if there were any failures, return 123;
2372         - if any files were changed and --check is being used, return 1;
2373         - otherwise return 0.
2374         """
2375         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2376         # 126 we have special returncodes reserved by the shell.
2377         if self.failure_count:
2378             return 123
2379
2380         elif self.change_count and self.check:
2381             return 1
2382
2383         return 0
2384
2385     def __str__(self) -> str:
2386         """Render a color report of the current state.
2387
2388         Use `click.unstyle` to remove colors.
2389         """
2390         if self.check:
2391             reformatted = "would be reformatted"
2392             unchanged = "would be left unchanged"
2393             failed = "would fail to reformat"
2394         else:
2395             reformatted = "reformatted"
2396             unchanged = "left unchanged"
2397             failed = "failed to reformat"
2398         report = []
2399         if self.change_count:
2400             s = "s" if self.change_count > 1 else ""
2401             report.append(
2402                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2403             )
2404         if self.same_count:
2405             s = "s" if self.same_count > 1 else ""
2406             report.append(f"{self.same_count} file{s} {unchanged}")
2407         if self.failure_count:
2408             s = "s" if self.failure_count > 1 else ""
2409             report.append(
2410                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2411             )
2412         return ", ".join(report) + "."
2413
2414
2415 def assert_equivalent(src: str, dst: str) -> None:
2416     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2417
2418     import ast
2419     import traceback
2420
2421     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2422         """Simple visitor generating strings to compare ASTs by content."""
2423         yield f"{'  ' * depth}{node.__class__.__name__}("
2424
2425         for field in sorted(node._fields):
2426             try:
2427                 value = getattr(node, field)
2428             except AttributeError:
2429                 continue
2430
2431             yield f"{'  ' * (depth+1)}{field}="
2432
2433             if isinstance(value, list):
2434                 for item in value:
2435                     if isinstance(item, ast.AST):
2436                         yield from _v(item, depth + 2)
2437
2438             elif isinstance(value, ast.AST):
2439                 yield from _v(value, depth + 2)
2440
2441             else:
2442                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2443
2444         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2445
2446     try:
2447         src_ast = ast.parse(src)
2448     except Exception as exc:
2449         major, minor = sys.version_info[:2]
2450         raise AssertionError(
2451             f"cannot use --safe with this file; failed to parse source file "
2452             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2453             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2454         )
2455
2456     try:
2457         dst_ast = ast.parse(dst)
2458     except Exception as exc:
2459         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2460         raise AssertionError(
2461             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2462             f"Please report a bug on https://github.com/ambv/black/issues.  "
2463             f"This invalid output might be helpful: {log}"
2464         ) from None
2465
2466     src_ast_str = "\n".join(_v(src_ast))
2467     dst_ast_str = "\n".join(_v(dst_ast))
2468     if src_ast_str != dst_ast_str:
2469         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2470         raise AssertionError(
2471             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2472             f"the source.  "
2473             f"Please report a bug on https://github.com/ambv/black/issues.  "
2474             f"This diff might be helpful: {log}"
2475         ) from None
2476
2477
2478 def assert_stable(src: str, dst: str, line_length: int) -> None:
2479     """Raise AssertionError if `dst` reformats differently the second time."""
2480     newdst = format_str(dst, line_length=line_length)
2481     if dst != newdst:
2482         log = dump_to_file(
2483             diff(src, dst, "source", "first pass"),
2484             diff(dst, newdst, "first pass", "second pass"),
2485         )
2486         raise AssertionError(
2487             f"INTERNAL ERROR: Black produced different code on the second pass "
2488             f"of the formatter.  "
2489             f"Please report a bug on https://github.com/ambv/black/issues.  "
2490             f"This diff might be helpful: {log}"
2491         ) from None
2492
2493
2494 def dump_to_file(*output: str) -> str:
2495     """Dump `output` to a temporary file. Return path to the file."""
2496     import tempfile
2497
2498     with tempfile.NamedTemporaryFile(
2499         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2500     ) as f:
2501         for lines in output:
2502             f.write(lines)
2503             if lines and lines[-1] != "\n":
2504                 f.write("\n")
2505     return f.name
2506
2507
2508 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2509     """Return a unified diff string between strings `a` and `b`."""
2510     import difflib
2511
2512     a_lines = [line + "\n" for line in a.split("\n")]
2513     b_lines = [line + "\n" for line in b.split("\n")]
2514     return "".join(
2515         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2516     )
2517
2518
2519 def cancel(tasks: List[asyncio.Task]) -> None:
2520     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2521     err("Aborted!")
2522     for task in tasks:
2523         task.cancel()
2524
2525
2526 def shutdown(loop: BaseEventLoop) -> None:
2527     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2528     try:
2529         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2530         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2531         if not to_cancel:
2532             return
2533
2534         for task in to_cancel:
2535             task.cancel()
2536         loop.run_until_complete(
2537             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2538         )
2539     finally:
2540         # `concurrent.futures.Future` objects cannot be cancelled once they
2541         # are already running. There might be some when the `shutdown()` happened.
2542         # Silence their logger's spew about the event loop being closed.
2543         cf_logger = logging.getLogger("concurrent.futures")
2544         cf_logger.setLevel(logging.CRITICAL)
2545         loop.close()
2546
2547
2548 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2549     """Replace `regex` with `replacement` twice on `original`.
2550
2551     This is used by string normalization to perform replaces on
2552     overlapping matches.
2553     """
2554     return regex.sub(replacement, regex.sub(replacement, original))
2555
2556
2557 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2558
2559
2560 def get_cache_file(line_length: int) -> Path:
2561     return CACHE_DIR / f"cache.{line_length}.pickle"
2562
2563
2564 def read_cache(line_length: int) -> Cache:
2565     """Read the cache if it exists and is well formed.
2566
2567     If it is not well formed, the call to write_cache later should resolve the issue.
2568     """
2569     cache_file = get_cache_file(line_length)
2570     if not cache_file.exists():
2571         return {}
2572
2573     with cache_file.open("rb") as fobj:
2574         try:
2575             cache: Cache = pickle.load(fobj)
2576         except pickle.UnpicklingError:
2577             return {}
2578
2579     return cache
2580
2581
2582 def get_cache_info(path: Path) -> CacheInfo:
2583     """Return the information used to check if a file is already formatted or not."""
2584     stat = path.stat()
2585     return stat.st_mtime, stat.st_size
2586
2587
2588 def filter_cached(
2589     cache: Cache, sources: Iterable[Path]
2590 ) -> Tuple[List[Path], List[Path]]:
2591     """Split a list of paths into two.
2592
2593     The first list contains paths of files that modified on disk or are not in the
2594     cache. The other list contains paths to non-modified files.
2595     """
2596     todo, done = [], []
2597     for src in sources:
2598         src = src.resolve()
2599         if cache.get(src) != get_cache_info(src):
2600             todo.append(src)
2601         else:
2602             done.append(src)
2603     return todo, done
2604
2605
2606 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2607     """Update the cache file."""
2608     cache_file = get_cache_file(line_length)
2609     try:
2610         if not CACHE_DIR.exists():
2611             CACHE_DIR.mkdir(parents=True)
2612         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2613         with cache_file.open("wb") as fobj:
2614             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2615     except OSError:
2616         pass
2617
2618
2619 if __name__ == "__main__":
2620     main()