]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

151dc8cccfe30efd38a1555e83aaaa9a90eadb48
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Set,
28     Tuple,
29     Type,
30     TypeVar,
31     Union,
32 )
33
34 from appdirs import user_cache_dir
35 from attr import dataclass, Factory
36 import click
37
38 # lib2to3 fork
39 from blib2to3.pytree import Node, Leaf, type_repr
40 from blib2to3 import pygram, pytree
41 from blib2to3.pgen2 import driver, token
42 from blib2to3.pgen2.parse import ParseError
43
44 __version__ = "18.4a6"
45 DEFAULT_LINE_LENGTH = 88
46
47 # types
48 syms = pygram.python_symbols
49 FileContent = str
50 Encoding = str
51 Depth = int
52 NodeType = int
53 LeafID = int
54 Priority = int
55 Index = int
56 LN = Union[Leaf, Node]
57 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
58 Timestamp = float
59 FileSize = int
60 CacheInfo = Tuple[Timestamp, FileSize]
61 Cache = Dict[Path, CacheInfo]
62 out = partial(click.secho, bold=True, err=True)
63 err = partial(click.secho, fg="red", err=True)
64
65
66 class NothingChanged(UserWarning):
67     """Raised by :func:`format_file` when reformatted code is the same as source."""
68
69
70 class CannotSplit(Exception):
71     """A readable split that fits the allotted line length is impossible.
72
73     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
74     :func:`delimiter_split`.
75     """
76
77
78 class FormatError(Exception):
79     """Base exception for `# fmt: on` and `# fmt: off` handling.
80
81     It holds the number of bytes of the prefix consumed before the format
82     control comment appeared.
83     """
84
85     def __init__(self, consumed: int) -> None:
86         super().__init__(consumed)
87         self.consumed = consumed
88
89     def trim_prefix(self, leaf: Leaf) -> None:
90         leaf.prefix = leaf.prefix[self.consumed :]
91
92     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
93         """Returns a new Leaf from the consumed part of the prefix."""
94         unformatted_prefix = leaf.prefix[: self.consumed]
95         return Leaf(token.NEWLINE, unformatted_prefix)
96
97
98 class FormatOn(FormatError):
99     """Found a comment like `# fmt: on` in the file."""
100
101
102 class FormatOff(FormatError):
103     """Found a comment like `# fmt: off` in the file."""
104
105
106 class WriteBack(Enum):
107     NO = 0
108     YES = 1
109     DIFF = 2
110
111
112 class Changed(Enum):
113     NO = 0
114     CACHED = 1
115     YES = 2
116
117
118 @click.command()
119 @click.option(
120     "-l",
121     "--line-length",
122     type=int,
123     default=DEFAULT_LINE_LENGTH,
124     help="How many character per line to allow.",
125     show_default=True,
126 )
127 @click.option(
128     "--check",
129     is_flag=True,
130     help=(
131         "Don't write the files back, just return the status.  Return code 0 "
132         "means nothing would change.  Return code 1 means some files would be "
133         "reformatted.  Return code 123 means there was an internal error."
134     ),
135 )
136 @click.option(
137     "--diff",
138     is_flag=True,
139     help="Don't write the files back, just output a diff for each file on stdout.",
140 )
141 @click.option(
142     "--fast/--safe",
143     is_flag=True,
144     help="If --fast given, skip temporary sanity checks. [default: --safe]",
145 )
146 @click.option(
147     "-q",
148     "--quiet",
149     is_flag=True,
150     help=(
151         "Don't emit non-error messages to stderr. Errors are still emitted, "
152         "silence those with 2>/dev/null."
153     ),
154 )
155 @click.version_option(version=__version__)
156 @click.argument(
157     "src",
158     nargs=-1,
159     type=click.Path(
160         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
161     ),
162 )
163 @click.pass_context
164 def main(
165     ctx: click.Context,
166     line_length: int,
167     check: bool,
168     diff: bool,
169     fast: bool,
170     quiet: bool,
171     src: List[str],
172 ) -> None:
173     """The uncompromising code formatter."""
174     sources: List[Path] = []
175     for s in src:
176         p = Path(s)
177         if p.is_dir():
178             sources.extend(gen_python_files_in_dir(p))
179         elif p.is_file():
180             # if a file was explicitly given, we don't care about its extension
181             sources.append(p)
182         elif s == "-":
183             sources.append(Path("-"))
184         else:
185             err(f"invalid path: {s}")
186
187     if check and not diff:
188         write_back = WriteBack.NO
189     elif diff:
190         write_back = WriteBack.DIFF
191     else:
192         write_back = WriteBack.YES
193     report = Report(check=check, quiet=quiet)
194     if len(sources) == 0:
195         out("No paths given. Nothing to do 😴")
196         ctx.exit(0)
197         return
198
199     elif len(sources) == 1:
200         reformat_one(sources[0], line_length, fast, write_back, report)
201     else:
202         loop = asyncio.get_event_loop()
203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
204         try:
205             loop.run_until_complete(
206                 schedule_formatting(
207                     sources, line_length, fast, write_back, report, loop, executor
208                 )
209             )
210         finally:
211             shutdown(loop)
212         if not quiet:
213             out("All done! ✨ 🍰 ✨")
214             click.echo(str(report))
215     ctx.exit(report.return_code)
216
217
218 def reformat_one(
219     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
220 ) -> None:
221     """Reformat a single file under `src` without spawning child processes.
222
223     If `quiet` is True, non-error messages are not output. `line_length`,
224     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
225     """
226     try:
227         changed = Changed.NO
228         if not src.is_file() and str(src) == "-":
229             if format_stdin_to_stdout(
230                 line_length=line_length, fast=fast, write_back=write_back
231             ):
232                 changed = Changed.YES
233         else:
234             cache: Cache = {}
235             if write_back != WriteBack.DIFF:
236                 cache = read_cache(line_length)
237                 src = src.resolve()
238                 if src in cache and cache[src] == get_cache_info(src):
239                     changed = Changed.CACHED
240             if (
241                 changed is not Changed.CACHED
242                 and format_file_in_place(
243                     src, line_length=line_length, fast=fast, write_back=write_back
244                 )
245             ):
246                 changed = Changed.YES
247             if write_back == WriteBack.YES and changed is not Changed.NO:
248                 write_cache(cache, [src], line_length)
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache(line_length)
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             src: loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             )
289             for src in sources
290         }
291         _task_values = list(tasks.values())
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
294             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     if write_back == WriteBack.YES and formatted:
315         write_cache(cache, formatted, line_length)
316
317
318 def format_file_in_place(
319     src: Path,
320     line_length: int,
321     fast: bool,
322     write_back: WriteBack = WriteBack.NO,
323     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
324 ) -> bool:
325     """Format file under `src` path. Return True if changed.
326
327     If `write_back` is True, write reformatted code back to stdout.
328     `line_length` and `fast` options are passed to :func:`format_file_contents`.
329     """
330
331     with tokenize.open(src) as src_buffer:
332         src_contents = src_buffer.read()
333     try:
334         dst_contents = format_file_contents(
335             src_contents, line_length=line_length, fast=fast
336         )
337     except NothingChanged:
338         return False
339
340     if write_back == write_back.YES:
341         with open(src, "w", encoding=src_buffer.encoding) as f:
342             f.write(dst_contents)
343     elif write_back == write_back.DIFF:
344         src_name = f"{src}  (original)"
345         dst_name = f"{src}  (formatted)"
346         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347         if lock:
348             lock.acquire()
349         try:
350             sys.stdout.write(diff_contents)
351         finally:
352             if lock:
353                 lock.release()
354     return True
355
356
357 def format_stdin_to_stdout(
358     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
359 ) -> bool:
360     """Format file on stdin. Return True if changed.
361
362     If `write_back` is True, write reformatted code back to stdout.
363     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
364     """
365     src = sys.stdin.read()
366     dst = src
367     try:
368         dst = format_file_contents(src, line_length=line_length, fast=fast)
369         return True
370
371     except NothingChanged:
372         return False
373
374     finally:
375         if write_back == WriteBack.YES:
376             sys.stdout.write(dst)
377         elif write_back == WriteBack.DIFF:
378             src_name = "<stdin>  (original)"
379             dst_name = "<stdin>  (formatted)"
380             sys.stdout.write(diff(src, dst, src_name, dst_name))
381
382
383 def format_file_contents(
384     src_contents: str, line_length: int, fast: bool
385 ) -> FileContent:
386     """Reformat contents a file and return new contents.
387
388     If `fast` is False, additionally confirm that the reformatted code is
389     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
390     `line_length` is passed to :func:`format_str`.
391     """
392     if src_contents.strip() == "":
393         raise NothingChanged
394
395     dst_contents = format_str(src_contents, line_length=line_length)
396     if src_contents == dst_contents:
397         raise NothingChanged
398
399     if not fast:
400         assert_equivalent(src_contents, dst_contents)
401         assert_stable(src_contents, dst_contents, line_length=line_length)
402     return dst_contents
403
404
405 def format_str(src_contents: str, line_length: int) -> FileContent:
406     """Reformat a string and return new contents.
407
408     `line_length` determines how many characters per line are allowed.
409     """
410     src_node = lib2to3_parse(src_contents)
411     dst_contents = ""
412     lines = LineGenerator()
413     elt = EmptyLineTracker()
414     py36 = is_python36(src_node)
415     empty_line = Line()
416     after = 0
417     for current_line in lines.visit(src_node):
418         for _ in range(after):
419             dst_contents += str(empty_line)
420         before, after = elt.maybe_empty_lines(current_line)
421         for _ in range(before):
422             dst_contents += str(empty_line)
423         for line in split_line(current_line, line_length=line_length, py36=py36):
424             dst_contents += str(line)
425     return dst_contents
426
427
428 GRAMMARS = [
429     pygram.python_grammar_no_print_statement_no_exec_statement,
430     pygram.python_grammar_no_print_statement,
431     pygram.python_grammar,
432 ]
433
434
435 def lib2to3_parse(src_txt: str) -> Node:
436     """Given a string with source, return the lib2to3 Node."""
437     grammar = pygram.python_grammar_no_print_statement
438     if src_txt[-1] != "\n":
439         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
440         src_txt += nl
441     for grammar in GRAMMARS:
442         drv = driver.Driver(grammar, pytree.convert)
443         try:
444             result = drv.parse_string(src_txt, True)
445             break
446
447         except ParseError as pe:
448             lineno, column = pe.context[1]
449             lines = src_txt.splitlines()
450             try:
451                 faulty_line = lines[lineno - 1]
452             except IndexError:
453                 faulty_line = "<line number missing in source>"
454             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
455     else:
456         raise exc from None
457
458     if isinstance(result, Leaf):
459         result = Node(syms.file_input, [result])
460     return result
461
462
463 def lib2to3_unparse(node: Node) -> str:
464     """Given a lib2to3 node, return its string representation."""
465     code = str(node)
466     return code
467
468
469 T = TypeVar("T")
470
471
472 class Visitor(Generic[T]):
473     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
474
475     def visit(self, node: LN) -> Iterator[T]:
476         """Main method to visit `node` and its children.
477
478         It tries to find a `visit_*()` method for the given `node.type`, like
479         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
480         If no dedicated `visit_*()` method is found, chooses `visit_default()`
481         instead.
482
483         Then yields objects of type `T` from the selected visitor.
484         """
485         if node.type < 256:
486             name = token.tok_name[node.type]
487         else:
488             name = type_repr(node.type)
489         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
490
491     def visit_default(self, node: LN) -> Iterator[T]:
492         """Default `visit_*()` implementation. Recurses to children of `node`."""
493         if isinstance(node, Node):
494             for child in node.children:
495                 yield from self.visit(child)
496
497
498 @dataclass
499 class DebugVisitor(Visitor[T]):
500     tree_depth: int = 0
501
502     def visit_default(self, node: LN) -> Iterator[T]:
503         indent = " " * (2 * self.tree_depth)
504         if isinstance(node, Node):
505             _type = type_repr(node.type)
506             out(f"{indent}{_type}", fg="yellow")
507             self.tree_depth += 1
508             for child in node.children:
509                 yield from self.visit(child)
510
511             self.tree_depth -= 1
512             out(f"{indent}/{_type}", fg="yellow", bold=False)
513         else:
514             _type = token.tok_name.get(node.type, str(node.type))
515             out(f"{indent}{_type}", fg="blue", nl=False)
516             if node.prefix:
517                 # We don't have to handle prefixes for `Node` objects since
518                 # that delegates to the first child anyway.
519                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
520             out(f" {node.value!r}", fg="blue", bold=False)
521
522     @classmethod
523     def show(cls, code: str) -> None:
524         """Pretty-print the lib2to3 AST of a given string of `code`.
525
526         Convenience method for debugging.
527         """
528         v: DebugVisitor[None] = DebugVisitor()
529         list(v.visit(lib2to3_parse(code)))
530
531
532 KEYWORDS = set(keyword.kwlist)
533 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
534 FLOW_CONTROL = {"return", "raise", "break", "continue"}
535 STATEMENT = {
536     syms.if_stmt,
537     syms.while_stmt,
538     syms.for_stmt,
539     syms.try_stmt,
540     syms.except_clause,
541     syms.with_stmt,
542     syms.funcdef,
543     syms.classdef,
544 }
545 STANDALONE_COMMENT = 153
546 LOGIC_OPERATORS = {"and", "or"}
547 COMPARATORS = {
548     token.LESS,
549     token.GREATER,
550     token.EQEQUAL,
551     token.NOTEQUAL,
552     token.LESSEQUAL,
553     token.GREATEREQUAL,
554 }
555 MATH_OPERATORS = {
556     token.PLUS,
557     token.MINUS,
558     token.STAR,
559     token.SLASH,
560     token.VBAR,
561     token.AMPER,
562     token.PERCENT,
563     token.CIRCUMFLEX,
564     token.TILDE,
565     token.LEFTSHIFT,
566     token.RIGHTSHIFT,
567     token.DOUBLESTAR,
568     token.DOUBLESLASH,
569 }
570 STARS = {token.STAR, token.DOUBLESTAR}
571 VARARGS_PARENTS = {
572     syms.arglist,
573     syms.argument,  # double star in arglist
574     syms.trailer,  # single argument to call
575     syms.typedargslist,
576     syms.varargslist,  # lambdas
577 }
578 UNPACKING_PARENTS = {
579     syms.atom,  # single element of a list or set literal
580     syms.dictsetmaker,
581     syms.listmaker,
582     syms.testlist_gexp,
583 }
584 TEST_DESCENDANTS = {
585     syms.test,
586     syms.lambdef,
587     syms.or_test,
588     syms.and_test,
589     syms.not_test,
590     syms.comparison,
591     syms.star_expr,
592     syms.expr,
593     syms.xor_expr,
594     syms.and_expr,
595     syms.shift_expr,
596     syms.arith_expr,
597     syms.trailer,
598     syms.term,
599     syms.power,
600 }
601 COMPREHENSION_PRIORITY = 20
602 COMMA_PRIORITY = 10
603 TERNARY_PRIORITY = 7
604 LOGIC_PRIORITY = 5
605 STRING_PRIORITY = 4
606 COMPARATOR_PRIORITY = 3
607 MATH_PRIORITY = 1
608
609
610 @dataclass
611 class BracketTracker:
612     """Keeps track of brackets on a line."""
613
614     depth: int = 0
615     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
616     delimiters: Dict[LeafID, Priority] = Factory(dict)
617     previous: Optional[Leaf] = None
618     _for_loop_variable: bool = False
619     _lambda_arguments: bool = False
620
621     def mark(self, leaf: Leaf) -> None:
622         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
623
624         All leaves receive an int `bracket_depth` field that stores how deep
625         within brackets a given leaf is. 0 means there are no enclosing brackets
626         that started on this line.
627
628         If a leaf is itself a closing bracket, it receives an `opening_bracket`
629         field that it forms a pair with. This is a one-directional link to
630         avoid reference cycles.
631
632         If a leaf is a delimiter (a token on which Black can split the line if
633         needed) and it's on depth 0, its `id()` is stored in the tracker's
634         `delimiters` field.
635         """
636         if leaf.type == token.COMMENT:
637             return
638
639         self.maybe_decrement_after_for_loop_variable(leaf)
640         self.maybe_decrement_after_lambda_arguments(leaf)
641         if leaf.type in CLOSING_BRACKETS:
642             self.depth -= 1
643             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
644             leaf.opening_bracket = opening_bracket
645         leaf.bracket_depth = self.depth
646         if self.depth == 0:
647             delim = is_split_before_delimiter(leaf, self.previous)
648             if delim and self.previous is not None:
649                 self.delimiters[id(self.previous)] = delim
650             else:
651                 delim = is_split_after_delimiter(leaf, self.previous)
652                 if delim:
653                     self.delimiters[id(leaf)] = delim
654         if leaf.type in OPENING_BRACKETS:
655             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
656             self.depth += 1
657         self.previous = leaf
658         self.maybe_increment_lambda_arguments(leaf)
659         self.maybe_increment_for_loop_variable(leaf)
660
661     def any_open_brackets(self) -> bool:
662         """Return True if there is an yet unmatched open bracket on the line."""
663         return bool(self.bracket_match)
664
665     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
666         """Return the highest priority of a delimiter found on the line.
667
668         Values are consistent with what `is_split_*_delimiter()` return.
669         Raises ValueError on no delimiters.
670         """
671         return max(v for k, v in self.delimiters.items() if k not in exclude)
672
673     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
674         """In a for loop, or comprehension, the variables are often unpacks.
675
676         To avoid splitting on the comma in this situation, increase the depth of
677         tokens between `for` and `in`.
678         """
679         if leaf.type == token.NAME and leaf.value == "for":
680             self.depth += 1
681             self._for_loop_variable = True
682             return True
683
684         return False
685
686     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
687         """See `maybe_increment_for_loop_variable` above for explanation."""
688         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
689             self.depth -= 1
690             self._for_loop_variable = False
691             return True
692
693         return False
694
695     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
696         """In a lambda expression, there might be more than one argument.
697
698         To avoid splitting on the comma in this situation, increase the depth of
699         tokens between `lambda` and `:`.
700         """
701         if leaf.type == token.NAME and leaf.value == "lambda":
702             self.depth += 1
703             self._lambda_arguments = True
704             return True
705
706         return False
707
708     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
709         """See `maybe_increment_lambda_arguments` above for explanation."""
710         if self._lambda_arguments and leaf.type == token.COLON:
711             self.depth -= 1
712             self._lambda_arguments = False
713             return True
714
715         return False
716
717     def get_open_lsqb(self) -> Optional[Leaf]:
718         """Return the most recent opening square bracket (if any)."""
719         return self.bracket_match.get((self.depth - 1, token.RSQB))
720
721
722 @dataclass
723 class Line:
724     """Holds leaves and comments. Can be printed with `str(line)`."""
725
726     depth: int = 0
727     leaves: List[Leaf] = Factory(list)
728     comments: List[Tuple[Index, Leaf]] = Factory(list)
729     bracket_tracker: BracketTracker = Factory(BracketTracker)
730     inside_brackets: bool = False
731
732     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
733         """Add a new `leaf` to the end of the line.
734
735         Unless `preformatted` is True, the `leaf` will receive a new consistent
736         whitespace prefix and metadata applied by :class:`BracketTracker`.
737         Trailing commas are maybe removed, unpacked for loop variables are
738         demoted from being delimiters.
739
740         Inline comments are put aside.
741         """
742         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
743         if not has_value:
744             return
745
746         if token.COLON == leaf.type and self.is_class_paren_empty:
747             del self.leaves[-2:]
748         if self.leaves and not preformatted:
749             # Note: at this point leaf.prefix should be empty except for
750             # imports, for which we only preserve newlines.
751             leaf.prefix += whitespace(
752                 leaf, complex_subscript=self.is_complex_subscript(leaf)
753             )
754         if self.inside_brackets or not preformatted:
755             self.bracket_tracker.mark(leaf)
756             self.maybe_remove_trailing_comma(leaf)
757         if not self.append_comment(leaf):
758             self.leaves.append(leaf)
759
760     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
761         """Like :func:`append()` but disallow invalid standalone comment structure.
762
763         Raises ValueError when any `leaf` is appended after a standalone comment
764         or when a standalone comment is not the first leaf on the line.
765         """
766         if self.bracket_tracker.depth == 0:
767             if self.is_comment:
768                 raise ValueError("cannot append to standalone comments")
769
770             if self.leaves and leaf.type == STANDALONE_COMMENT:
771                 raise ValueError(
772                     "cannot append standalone comments to a populated line"
773                 )
774
775         self.append(leaf, preformatted=preformatted)
776
777     @property
778     def is_comment(self) -> bool:
779         """Is this line a standalone comment?"""
780         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
781
782     @property
783     def is_decorator(self) -> bool:
784         """Is this line a decorator?"""
785         return bool(self) and self.leaves[0].type == token.AT
786
787     @property
788     def is_import(self) -> bool:
789         """Is this an import line?"""
790         return bool(self) and is_import(self.leaves[0])
791
792     @property
793     def is_class(self) -> bool:
794         """Is this line a class definition?"""
795         return (
796             bool(self)
797             and self.leaves[0].type == token.NAME
798             and self.leaves[0].value == "class"
799         )
800
801     @property
802     def is_def(self) -> bool:
803         """Is this a function definition? (Also returns True for async defs.)"""
804         try:
805             first_leaf = self.leaves[0]
806         except IndexError:
807             return False
808
809         try:
810             second_leaf: Optional[Leaf] = self.leaves[1]
811         except IndexError:
812             second_leaf = None
813         return (
814             (first_leaf.type == token.NAME and first_leaf.value == "def")
815             or (
816                 first_leaf.type == token.ASYNC
817                 and second_leaf is not None
818                 and second_leaf.type == token.NAME
819                 and second_leaf.value == "def"
820             )
821         )
822
823     @property
824     def is_flow_control(self) -> bool:
825         """Is this line a flow control statement?
826
827         Those are `return`, `raise`, `break`, and `continue`.
828         """
829         return (
830             bool(self)
831             and self.leaves[0].type == token.NAME
832             and self.leaves[0].value in FLOW_CONTROL
833         )
834
835     @property
836     def is_yield(self) -> bool:
837         """Is this line a yield statement?"""
838         return (
839             bool(self)
840             and self.leaves[0].type == token.NAME
841             and self.leaves[0].value == "yield"
842         )
843
844     @property
845     def is_class_paren_empty(self) -> bool:
846         """Is this a class with no base classes but using parentheses?
847
848         Those are unnecessary and should be removed.
849         """
850         return (
851             bool(self)
852             and len(self.leaves) == 4
853             and self.is_class
854             and self.leaves[2].type == token.LPAR
855             and self.leaves[2].value == "("
856             and self.leaves[3].type == token.RPAR
857             and self.leaves[3].value == ")"
858         )
859
860     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
861         """If so, needs to be split before emitting."""
862         for leaf in self.leaves:
863             if leaf.type == STANDALONE_COMMENT:
864                 if leaf.bracket_depth <= depth_limit:
865                     return True
866
867         return False
868
869     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
870         """Remove trailing comma if there is one and it's safe."""
871         if not (
872             self.leaves
873             and self.leaves[-1].type == token.COMMA
874             and closing.type in CLOSING_BRACKETS
875         ):
876             return False
877
878         if closing.type == token.RBRACE:
879             self.remove_trailing_comma()
880             return True
881
882         if closing.type == token.RSQB:
883             comma = self.leaves[-1]
884             if comma.parent and comma.parent.type == syms.listmaker:
885                 self.remove_trailing_comma()
886                 return True
887
888         # For parens let's check if it's safe to remove the comma.
889         # Imports are always safe.
890         if self.is_import:
891             self.remove_trailing_comma()
892             return True
893
894         # Otheriwsse, if the trailing one is the only one, we might mistakenly
895         # change a tuple into a different type by removing the comma.
896         depth = closing.bracket_depth + 1
897         commas = 0
898         opening = closing.opening_bracket
899         for _opening_index, leaf in enumerate(self.leaves):
900             if leaf is opening:
901                 break
902
903         else:
904             return False
905
906         for leaf in self.leaves[_opening_index + 1 :]:
907             if leaf is closing:
908                 break
909
910             bracket_depth = leaf.bracket_depth
911             if bracket_depth == depth and leaf.type == token.COMMA:
912                 commas += 1
913                 if leaf.parent and leaf.parent.type == syms.arglist:
914                     commas += 1
915                     break
916
917         if commas > 1:
918             self.remove_trailing_comma()
919             return True
920
921         return False
922
923     def append_comment(self, comment: Leaf) -> bool:
924         """Add an inline or standalone comment to the line."""
925         if (
926             comment.type == STANDALONE_COMMENT
927             and self.bracket_tracker.any_open_brackets()
928         ):
929             comment.prefix = ""
930             return False
931
932         if comment.type != token.COMMENT:
933             return False
934
935         after = len(self.leaves) - 1
936         if after == -1:
937             comment.type = STANDALONE_COMMENT
938             comment.prefix = ""
939             return False
940
941         else:
942             self.comments.append((after, comment))
943             return True
944
945     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
946         """Generate comments that should appear directly after `leaf`."""
947         for _leaf_index, _leaf in enumerate(self.leaves):
948             if leaf is _leaf:
949                 break
950
951         else:
952             return
953
954         for index, comment_after in self.comments:
955             if _leaf_index == index:
956                 yield comment_after
957
958     def remove_trailing_comma(self) -> None:
959         """Remove the trailing comma and moves the comments attached to it."""
960         comma_index = len(self.leaves) - 1
961         for i in range(len(self.comments)):
962             comment_index, comment = self.comments[i]
963             if comment_index == comma_index:
964                 self.comments[i] = (comma_index - 1, comment)
965         self.leaves.pop()
966
967     def is_complex_subscript(self, leaf: Leaf) -> bool:
968         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
969         open_lsqb = (
970             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
971         )
972         if open_lsqb is None:
973             return False
974
975         subscript_start = open_lsqb.next_sibling
976         if (
977             isinstance(subscript_start, Node)
978             and subscript_start.type == syms.subscriptlist
979         ):
980             subscript_start = child_towards(subscript_start, leaf)
981         return subscript_start is not None and any(
982             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
983         )
984
985     def __str__(self) -> str:
986         """Render the line."""
987         if not self:
988             return "\n"
989
990         indent = "    " * self.depth
991         leaves = iter(self.leaves)
992         first = next(leaves)
993         res = f"{first.prefix}{indent}{first.value}"
994         for leaf in leaves:
995             res += str(leaf)
996         for _, comment in self.comments:
997             res += str(comment)
998         return res + "\n"
999
1000     def __bool__(self) -> bool:
1001         """Return True if the line has leaves or comments."""
1002         return bool(self.leaves or self.comments)
1003
1004
1005 class UnformattedLines(Line):
1006     """Just like :class:`Line` but stores lines which aren't reformatted."""
1007
1008     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1009         """Just add a new `leaf` to the end of the lines.
1010
1011         The `preformatted` argument is ignored.
1012
1013         Keeps track of indentation `depth`, which is useful when the user
1014         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1015         """
1016         try:
1017             list(generate_comments(leaf))
1018         except FormatOn as f_on:
1019             self.leaves.append(f_on.leaf_from_consumed(leaf))
1020             raise
1021
1022         self.leaves.append(leaf)
1023         if leaf.type == token.INDENT:
1024             self.depth += 1
1025         elif leaf.type == token.DEDENT:
1026             self.depth -= 1
1027
1028     def __str__(self) -> str:
1029         """Render unformatted lines from leaves which were added with `append()`.
1030
1031         `depth` is not used for indentation in this case.
1032         """
1033         if not self:
1034             return "\n"
1035
1036         res = ""
1037         for leaf in self.leaves:
1038             res += str(leaf)
1039         return res
1040
1041     def append_comment(self, comment: Leaf) -> bool:
1042         """Not implemented in this class. Raises `NotImplementedError`."""
1043         raise NotImplementedError("Unformatted lines don't store comments separately.")
1044
1045     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1046         """Does nothing and returns False."""
1047         return False
1048
1049     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1050         """Does nothing and returns False."""
1051         return False
1052
1053
1054 @dataclass
1055 class EmptyLineTracker:
1056     """Provides a stateful method that returns the number of potential extra
1057     empty lines needed before and after the currently processed line.
1058
1059     Note: this tracker works on lines that haven't been split yet.  It assumes
1060     the prefix of the first leaf consists of optional newlines.  Those newlines
1061     are consumed by `maybe_empty_lines()` and included in the computation.
1062     """
1063     previous_line: Optional[Line] = None
1064     previous_after: int = 0
1065     previous_defs: List[int] = Factory(list)
1066
1067     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1068         """Return the number of extra empty lines before and after the `current_line`.
1069
1070         This is for separating `def`, `async def` and `class` with extra empty
1071         lines (two on module-level), as well as providing an extra empty line
1072         after flow control keywords to make them more prominent.
1073         """
1074         if isinstance(current_line, UnformattedLines):
1075             return 0, 0
1076
1077         before, after = self._maybe_empty_lines(current_line)
1078         before -= self.previous_after
1079         self.previous_after = after
1080         self.previous_line = current_line
1081         return before, after
1082
1083     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1084         max_allowed = 1
1085         if current_line.depth == 0:
1086             max_allowed = 2
1087         if current_line.leaves:
1088             # Consume the first leaf's extra newlines.
1089             first_leaf = current_line.leaves[0]
1090             before = first_leaf.prefix.count("\n")
1091             before = min(before, max_allowed)
1092             first_leaf.prefix = ""
1093         else:
1094             before = 0
1095         depth = current_line.depth
1096         while self.previous_defs and self.previous_defs[-1] >= depth:
1097             self.previous_defs.pop()
1098             before = 1 if depth else 2
1099         is_decorator = current_line.is_decorator
1100         if is_decorator or current_line.is_def or current_line.is_class:
1101             if not is_decorator:
1102                 self.previous_defs.append(depth)
1103             if self.previous_line is None:
1104                 # Don't insert empty lines before the first line in the file.
1105                 return 0, 0
1106
1107             if self.previous_line.is_decorator:
1108                 return 0, 0
1109
1110             if (
1111                 self.previous_line.is_comment
1112                 and self.previous_line.depth == current_line.depth
1113                 and before == 0
1114             ):
1115                 return 0, 0
1116
1117             newlines = 2
1118             if current_line.depth:
1119                 newlines -= 1
1120             return newlines, 0
1121
1122         if (
1123             self.previous_line
1124             and self.previous_line.is_import
1125             and not current_line.is_import
1126             and depth == self.previous_line.depth
1127         ):
1128             return (before or 1), 0
1129
1130         return before, 0
1131
1132
1133 @dataclass
1134 class LineGenerator(Visitor[Line]):
1135     """Generates reformatted Line objects.  Empty lines are not emitted.
1136
1137     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1138     in ways that will no longer stringify to valid Python code on the tree.
1139     """
1140     current_line: Line = Factory(Line)
1141
1142     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1143         """Generate a line.
1144
1145         If the line is empty, only emit if it makes sense.
1146         If the line is too long, split it first and then generate.
1147
1148         If any lines were generated, set up a new current_line.
1149         """
1150         if not self.current_line:
1151             if self.current_line.__class__ == type:
1152                 self.current_line.depth += indent
1153             else:
1154                 self.current_line = type(depth=self.current_line.depth + indent)
1155             return  # Line is empty, don't emit. Creating a new one unnecessary.
1156
1157         complete_line = self.current_line
1158         self.current_line = type(depth=complete_line.depth + indent)
1159         yield complete_line
1160
1161     def visit(self, node: LN) -> Iterator[Line]:
1162         """Main method to visit `node` and its children.
1163
1164         Yields :class:`Line` objects.
1165         """
1166         if isinstance(self.current_line, UnformattedLines):
1167             # File contained `# fmt: off`
1168             yield from self.visit_unformatted(node)
1169
1170         else:
1171             yield from super().visit(node)
1172
1173     def visit_default(self, node: LN) -> Iterator[Line]:
1174         """Default `visit_*()` implementation. Recurses to children of `node`."""
1175         if isinstance(node, Leaf):
1176             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1177             try:
1178                 for comment in generate_comments(node):
1179                     if any_open_brackets:
1180                         # any comment within brackets is subject to splitting
1181                         self.current_line.append(comment)
1182                     elif comment.type == token.COMMENT:
1183                         # regular trailing comment
1184                         self.current_line.append(comment)
1185                         yield from self.line()
1186
1187                     else:
1188                         # regular standalone comment
1189                         yield from self.line()
1190
1191                         self.current_line.append(comment)
1192                         yield from self.line()
1193
1194             except FormatOff as f_off:
1195                 f_off.trim_prefix(node)
1196                 yield from self.line(type=UnformattedLines)
1197                 yield from self.visit(node)
1198
1199             except FormatOn as f_on:
1200                 # This only happens here if somebody says "fmt: on" multiple
1201                 # times in a row.
1202                 f_on.trim_prefix(node)
1203                 yield from self.visit_default(node)
1204
1205             else:
1206                 normalize_prefix(node, inside_brackets=any_open_brackets)
1207                 if node.type == token.STRING:
1208                     normalize_string_quotes(node)
1209                 if node.type not in WHITESPACE:
1210                     self.current_line.append(node)
1211         yield from super().visit_default(node)
1212
1213     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1214         """Increase indentation level, maybe yield a line."""
1215         # In blib2to3 INDENT never holds comments.
1216         yield from self.line(+1)
1217         yield from self.visit_default(node)
1218
1219     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1220         """Decrease indentation level, maybe yield a line."""
1221         # The current line might still wait for trailing comments.  At DEDENT time
1222         # there won't be any (they would be prefixes on the preceding NEWLINE).
1223         # Emit the line then.
1224         yield from self.line()
1225
1226         # While DEDENT has no value, its prefix may contain standalone comments
1227         # that belong to the current indentation level.  Get 'em.
1228         yield from self.visit_default(node)
1229
1230         # Finally, emit the dedent.
1231         yield from self.line(-1)
1232
1233     def visit_stmt(
1234         self, node: Node, keywords: Set[str], parens: Set[str]
1235     ) -> Iterator[Line]:
1236         """Visit a statement.
1237
1238         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1239         `def`, `with`, `class`, and `assert`.
1240
1241         The relevant Python language `keywords` for a given statement will be
1242         NAME leaves within it. This methods puts those on a separate line.
1243
1244         `parens` holds pairs of nodes where invisible parentheses should be put.
1245         Keys hold nodes after which opening parentheses should be put, values
1246         hold nodes before which closing parentheses should be put.
1247         """
1248         normalize_invisible_parens(node, parens_after=parens)
1249         for child in node.children:
1250             if child.type == token.NAME and child.value in keywords:  # type: ignore
1251                 yield from self.line()
1252
1253             yield from self.visit(child)
1254
1255     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1256         """Visit a statement without nested statements."""
1257         is_suite_like = node.parent and node.parent.type in STATEMENT
1258         if is_suite_like:
1259             yield from self.line(+1)
1260             yield from self.visit_default(node)
1261             yield from self.line(-1)
1262
1263         else:
1264             yield from self.line()
1265             yield from self.visit_default(node)
1266
1267     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1268         """Visit `async def`, `async for`, `async with`."""
1269         yield from self.line()
1270
1271         children = iter(node.children)
1272         for child in children:
1273             yield from self.visit(child)
1274
1275             if child.type == token.ASYNC:
1276                 break
1277
1278         internal_stmt = next(children)
1279         for child in internal_stmt.children:
1280             yield from self.visit(child)
1281
1282     def visit_decorators(self, node: Node) -> Iterator[Line]:
1283         """Visit decorators."""
1284         for child in node.children:
1285             yield from self.line()
1286             yield from self.visit(child)
1287
1288     def visit_import_from(self, node: Node) -> Iterator[Line]:
1289         """Visit import_from and maybe put invisible parentheses.
1290
1291         This is separate from `visit_stmt` because import statements don't
1292         support arbitrary atoms and thus handling of parentheses is custom.
1293         """
1294         check_lpar = False
1295         for index, child in enumerate(node.children):
1296             if check_lpar:
1297                 if child.type == token.LPAR:
1298                     # make parentheses invisible
1299                     child.value = ""  # type: ignore
1300                     node.children[-1].value = ""  # type: ignore
1301                 else:
1302                     # insert invisible parentheses
1303                     node.insert_child(index, Leaf(token.LPAR, ""))
1304                     node.append_child(Leaf(token.RPAR, ""))
1305                 break
1306
1307             check_lpar = (
1308                 child.type == token.NAME and child.value == "import"  # type: ignore
1309             )
1310
1311         for child in node.children:
1312             yield from self.visit(child)
1313
1314     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1315         """Remove a semicolon and put the other statement on a separate line."""
1316         yield from self.line()
1317
1318     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1319         """End of file. Process outstanding comments and end with a newline."""
1320         yield from self.visit_default(leaf)
1321         yield from self.line()
1322
1323     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1324         """Used when file contained a `# fmt: off`."""
1325         if isinstance(node, Node):
1326             for child in node.children:
1327                 yield from self.visit(child)
1328
1329         else:
1330             try:
1331                 self.current_line.append(node)
1332             except FormatOn as f_on:
1333                 f_on.trim_prefix(node)
1334                 yield from self.line()
1335                 yield from self.visit(node)
1336
1337             if node.type == token.ENDMARKER:
1338                 # somebody decided not to put a final `# fmt: on`
1339                 yield from self.line()
1340
1341     def __attrs_post_init__(self) -> None:
1342         """You are in a twisty little maze of passages."""
1343         v = self.visit_stmt
1344         Ø: Set[str] = set()
1345         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1346         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1347         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1348         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1349         self.visit_try_stmt = partial(
1350             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1351         )
1352         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1353         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1354         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1355         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1356         self.visit_async_funcdef = self.visit_async_stmt
1357         self.visit_decorated = self.visit_decorators
1358
1359
1360 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1361 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1362 OPENING_BRACKETS = set(BRACKET.keys())
1363 CLOSING_BRACKETS = set(BRACKET.values())
1364 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1365 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1366
1367
1368 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1369     """Return whitespace prefix if needed for the given `leaf`.
1370
1371     `complex_subscript` signals whether the given leaf is part of a subscription
1372     which has non-trivial arguments, like arithmetic expressions or function calls.
1373     """
1374     NO = ""
1375     SPACE = " "
1376     DOUBLESPACE = "  "
1377     t = leaf.type
1378     p = leaf.parent
1379     v = leaf.value
1380     if t in ALWAYS_NO_SPACE:
1381         return NO
1382
1383     if t == token.COMMENT:
1384         return DOUBLESPACE
1385
1386     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1387     if (
1388         t == token.COLON
1389         and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop}
1390     ):
1391         return NO
1392
1393     prev = leaf.prev_sibling
1394     if not prev:
1395         prevp = preceding_leaf(p)
1396         if not prevp or prevp.type in OPENING_BRACKETS:
1397             return NO
1398
1399         if t == token.COLON:
1400             if prevp.type == token.COLON:
1401                 return NO
1402
1403             elif prevp.type != token.COMMA and not complex_subscript:
1404                 return NO
1405
1406             return SPACE
1407
1408         if prevp.type == token.EQUAL:
1409             if prevp.parent:
1410                 if prevp.parent.type in {
1411                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1412                 }:
1413                     return NO
1414
1415                 elif prevp.parent.type == syms.typedargslist:
1416                     # A bit hacky: if the equal sign has whitespace, it means we
1417                     # previously found it's a typed argument.  So, we're using
1418                     # that, too.
1419                     return prevp.prefix
1420
1421         elif prevp.type in STARS:
1422             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1423                 return NO
1424
1425         elif prevp.type == token.COLON:
1426             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1427                 return SPACE if complex_subscript else NO
1428
1429         elif (
1430             prevp.parent
1431             and prevp.parent.type == syms.factor
1432             and prevp.type in MATH_OPERATORS
1433         ):
1434             return NO
1435
1436         elif (
1437             prevp.type == token.RIGHTSHIFT
1438             and prevp.parent
1439             and prevp.parent.type == syms.shift_expr
1440             and prevp.prev_sibling
1441             and prevp.prev_sibling.type == token.NAME
1442             and prevp.prev_sibling.value == "print"  # type: ignore
1443         ):
1444             # Python 2 print chevron
1445             return NO
1446
1447     elif prev.type in OPENING_BRACKETS:
1448         return NO
1449
1450     if p.type in {syms.parameters, syms.arglist}:
1451         # untyped function signatures or calls
1452         if not prev or prev.type != token.COMMA:
1453             return NO
1454
1455     elif p.type == syms.varargslist:
1456         # lambdas
1457         if prev and prev.type != token.COMMA:
1458             return NO
1459
1460     elif p.type == syms.typedargslist:
1461         # typed function signatures
1462         if not prev:
1463             return NO
1464
1465         if t == token.EQUAL:
1466             if prev.type != syms.tname:
1467                 return NO
1468
1469         elif prev.type == token.EQUAL:
1470             # A bit hacky: if the equal sign has whitespace, it means we
1471             # previously found it's a typed argument.  So, we're using that, too.
1472             return prev.prefix
1473
1474         elif prev.type != token.COMMA:
1475             return NO
1476
1477     elif p.type == syms.tname:
1478         # type names
1479         if not prev:
1480             prevp = preceding_leaf(p)
1481             if not prevp or prevp.type != token.COMMA:
1482                 return NO
1483
1484     elif p.type == syms.trailer:
1485         # attributes and calls
1486         if t == token.LPAR or t == token.RPAR:
1487             return NO
1488
1489         if not prev:
1490             if t == token.DOT:
1491                 prevp = preceding_leaf(p)
1492                 if not prevp or prevp.type != token.NUMBER:
1493                     return NO
1494
1495             elif t == token.LSQB:
1496                 return NO
1497
1498         elif prev.type != token.COMMA:
1499             return NO
1500
1501     elif p.type == syms.argument:
1502         # single argument
1503         if t == token.EQUAL:
1504             return NO
1505
1506         if not prev:
1507             prevp = preceding_leaf(p)
1508             if not prevp or prevp.type == token.LPAR:
1509                 return NO
1510
1511         elif prev.type in {token.EQUAL} | STARS:
1512             return NO
1513
1514     elif p.type == syms.decorator:
1515         # decorators
1516         return NO
1517
1518     elif p.type == syms.dotted_name:
1519         if prev:
1520             return NO
1521
1522         prevp = preceding_leaf(p)
1523         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1524             return NO
1525
1526     elif p.type == syms.classdef:
1527         if t == token.LPAR:
1528             return NO
1529
1530         if prev and prev.type == token.LPAR:
1531             return NO
1532
1533     elif p.type in {syms.subscript, syms.sliceop}:
1534         # indexing
1535         if not prev:
1536             assert p.parent is not None, "subscripts are always parented"
1537             if p.parent.type == syms.subscriptlist:
1538                 return SPACE
1539
1540             return NO
1541
1542         elif not complex_subscript:
1543             return NO
1544
1545     elif p.type == syms.atom:
1546         if prev and t == token.DOT:
1547             # dots, but not the first one.
1548             return NO
1549
1550     elif p.type == syms.dictsetmaker:
1551         # dict unpacking
1552         if prev and prev.type == token.DOUBLESTAR:
1553             return NO
1554
1555     elif p.type in {syms.factor, syms.star_expr}:
1556         # unary ops
1557         if not prev:
1558             prevp = preceding_leaf(p)
1559             if not prevp or prevp.type in OPENING_BRACKETS:
1560                 return NO
1561
1562             prevp_parent = prevp.parent
1563             assert prevp_parent is not None
1564             if (
1565                 prevp.type == token.COLON
1566                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1567             ):
1568                 return NO
1569
1570             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1571                 return NO
1572
1573         elif t == token.NAME or t == token.NUMBER:
1574             return NO
1575
1576     elif p.type == syms.import_from:
1577         if t == token.DOT:
1578             if prev and prev.type == token.DOT:
1579                 return NO
1580
1581         elif t == token.NAME:
1582             if v == "import":
1583                 return SPACE
1584
1585             if prev and prev.type == token.DOT:
1586                 return NO
1587
1588     elif p.type == syms.sliceop:
1589         return NO
1590
1591     return SPACE
1592
1593
1594 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1595     """Return the first leaf that precedes `node`, if any."""
1596     while node:
1597         res = node.prev_sibling
1598         if res:
1599             if isinstance(res, Leaf):
1600                 return res
1601
1602             try:
1603                 return list(res.leaves())[-1]
1604
1605             except IndexError:
1606                 return None
1607
1608         node = node.parent
1609     return None
1610
1611
1612 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1613     """Return the child of `ancestor` that contains `descendant`."""
1614     node: Optional[LN] = descendant
1615     while node and node.parent != ancestor:
1616         node = node.parent
1617     return node
1618
1619
1620 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1621     """Return the priority of the `leaf` delimiter, given a line break after it.
1622
1623     The delimiter priorities returned here are from those delimiters that would
1624     cause a line break after themselves.
1625
1626     Higher numbers are higher priority.
1627     """
1628     if leaf.type == token.COMMA:
1629         return COMMA_PRIORITY
1630
1631     return 0
1632
1633
1634 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1635     """Return the priority of the `leaf` delimiter, given a line before after it.
1636
1637     The delimiter priorities returned here are from those delimiters that would
1638     cause a line break before themselves.
1639
1640     Higher numbers are higher priority.
1641     """
1642     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1643         # * and ** might also be MATH_OPERATORS but in this case they are not.
1644         # Don't treat them as a delimiter.
1645         return 0
1646
1647     if (
1648         leaf.type in MATH_OPERATORS
1649         and leaf.parent
1650         and leaf.parent.type not in {syms.factor, syms.star_expr}
1651     ):
1652         return MATH_PRIORITY
1653
1654     if leaf.type in COMPARATORS:
1655         return COMPARATOR_PRIORITY
1656
1657     if (
1658         leaf.type == token.STRING
1659         and previous is not None
1660         and previous.type == token.STRING
1661     ):
1662         return STRING_PRIORITY
1663
1664     if (
1665         leaf.type == token.NAME
1666         and leaf.value == "for"
1667         and leaf.parent
1668         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1669     ):
1670         return COMPREHENSION_PRIORITY
1671
1672     if (
1673         leaf.type == token.NAME
1674         and leaf.value == "if"
1675         and leaf.parent
1676         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1677     ):
1678         return COMPREHENSION_PRIORITY
1679
1680     if (
1681         leaf.type == token.NAME
1682         and leaf.value in {"if", "else"}
1683         and leaf.parent
1684         and leaf.parent.type == syms.test
1685     ):
1686         return TERNARY_PRIORITY
1687
1688     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1689         return LOGIC_PRIORITY
1690
1691     return 0
1692
1693
1694 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1695     """Clean the prefix of the `leaf` and generate comments from it, if any.
1696
1697     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1698     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1699     move because it does away with modifying the grammar to include all the
1700     possible places in which comments can be placed.
1701
1702     The sad consequence for us though is that comments don't "belong" anywhere.
1703     This is why this function generates simple parentless Leaf objects for
1704     comments.  We simply don't know what the correct parent should be.
1705
1706     No matter though, we can live without this.  We really only need to
1707     differentiate between inline and standalone comments.  The latter don't
1708     share the line with any code.
1709
1710     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1711     are emitted with a fake STANDALONE_COMMENT token identifier.
1712     """
1713     p = leaf.prefix
1714     if not p:
1715         return
1716
1717     if "#" not in p:
1718         return
1719
1720     consumed = 0
1721     nlines = 0
1722     for index, line in enumerate(p.split("\n")):
1723         consumed += len(line) + 1  # adding the length of the split '\n'
1724         line = line.lstrip()
1725         if not line:
1726             nlines += 1
1727         if not line.startswith("#"):
1728             continue
1729
1730         if index == 0 and leaf.type != token.ENDMARKER:
1731             comment_type = token.COMMENT  # simple trailing comment
1732         else:
1733             comment_type = STANDALONE_COMMENT
1734         comment = make_comment(line)
1735         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1736
1737         if comment in {"# fmt: on", "# yapf: enable"}:
1738             raise FormatOn(consumed)
1739
1740         if comment in {"# fmt: off", "# yapf: disable"}:
1741             if comment_type == STANDALONE_COMMENT:
1742                 raise FormatOff(consumed)
1743
1744             prev = preceding_leaf(leaf)
1745             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1746                 raise FormatOff(consumed)
1747
1748         nlines = 0
1749
1750
1751 def make_comment(content: str) -> str:
1752     """Return a consistently formatted comment from the given `content` string.
1753
1754     All comments (except for "##", "#!", "#:") should have a single space between
1755     the hash sign and the content.
1756
1757     If `content` didn't start with a hash sign, one is provided.
1758     """
1759     content = content.rstrip()
1760     if not content:
1761         return "#"
1762
1763     if content[0] == "#":
1764         content = content[1:]
1765     if content and content[0] not in " !:#":
1766         content = " " + content
1767     return "#" + content
1768
1769
1770 def split_line(
1771     line: Line, line_length: int, inner: bool = False, py36: bool = False
1772 ) -> Iterator[Line]:
1773     """Split a `line` into potentially many lines.
1774
1775     They should fit in the allotted `line_length` but might not be able to.
1776     `inner` signifies that there were a pair of brackets somewhere around the
1777     current `line`, possibly transitively. This means we can fallback to splitting
1778     by delimiters if the LHS/RHS don't yield any results.
1779
1780     If `py36` is True, splitting may generate syntax that is only compatible
1781     with Python 3.6 and later.
1782     """
1783     if isinstance(line, UnformattedLines) or line.is_comment:
1784         yield line
1785         return
1786
1787     line_str = str(line).strip("\n")
1788     if (
1789         len(line_str) <= line_length
1790         and "\n" not in line_str  # multiline strings
1791         and not line.contains_standalone_comments()
1792     ):
1793         yield line
1794         return
1795
1796     split_funcs: List[SplitFunc]
1797     if line.is_def:
1798         split_funcs = [left_hand_split]
1799     elif line.is_import:
1800         split_funcs = [explode_split]
1801     elif line.inside_brackets:
1802         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1803     else:
1804         split_funcs = [right_hand_split]
1805     for split_func in split_funcs:
1806         # We are accumulating lines in `result` because we might want to abort
1807         # mission and return the original line in the end, or attempt a different
1808         # split altogether.
1809         result: List[Line] = []
1810         try:
1811             for l in split_func(line, py36):
1812                 if str(l).strip("\n") == line_str:
1813                     raise CannotSplit("Split function returned an unchanged result")
1814
1815                 result.extend(
1816                     split_line(l, line_length=line_length, inner=True, py36=py36)
1817                 )
1818         except CannotSplit as cs:
1819             continue
1820
1821         else:
1822             yield from result
1823             break
1824
1825     else:
1826         yield line
1827
1828
1829 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1830     """Split line into many lines, starting with the first matching bracket pair.
1831
1832     Note: this usually looks weird, only use this for function definitions.
1833     Prefer RHS otherwise.  This is why this function is not symmetrical with
1834     :func:`right_hand_split` which also handles optional parentheses.
1835     """
1836     head = Line(depth=line.depth)
1837     body = Line(depth=line.depth + 1, inside_brackets=True)
1838     tail = Line(depth=line.depth)
1839     tail_leaves: List[Leaf] = []
1840     body_leaves: List[Leaf] = []
1841     head_leaves: List[Leaf] = []
1842     current_leaves = head_leaves
1843     matching_bracket = None
1844     for leaf in line.leaves:
1845         if (
1846             current_leaves is body_leaves
1847             and leaf.type in CLOSING_BRACKETS
1848             and leaf.opening_bracket is matching_bracket
1849         ):
1850             current_leaves = tail_leaves if body_leaves else head_leaves
1851         current_leaves.append(leaf)
1852         if current_leaves is head_leaves:
1853             if leaf.type in OPENING_BRACKETS:
1854                 matching_bracket = leaf
1855                 current_leaves = body_leaves
1856     # Since body is a new indent level, remove spurious leading whitespace.
1857     if body_leaves:
1858         normalize_prefix(body_leaves[0], inside_brackets=True)
1859     # Build the new lines.
1860     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1861         for leaf in leaves:
1862             result.append(leaf, preformatted=True)
1863             for comment_after in line.comments_after(leaf):
1864                 result.append(comment_after, preformatted=True)
1865     bracket_split_succeeded_or_raise(head, body, tail)
1866     for result in (head, body, tail):
1867         if result:
1868             yield result
1869
1870
1871 def right_hand_split(
1872     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1873 ) -> Iterator[Line]:
1874     """Split line into many lines, starting with the last matching bracket pair.
1875
1876     If the split was by optional parentheses, attempt splitting without them, too.
1877     """
1878     head = Line(depth=line.depth)
1879     body = Line(depth=line.depth + 1, inside_brackets=True)
1880     tail = Line(depth=line.depth)
1881     tail_leaves: List[Leaf] = []
1882     body_leaves: List[Leaf] = []
1883     head_leaves: List[Leaf] = []
1884     current_leaves = tail_leaves
1885     opening_bracket = None
1886     closing_bracket = None
1887     for leaf in reversed(line.leaves):
1888         if current_leaves is body_leaves:
1889             if leaf is opening_bracket:
1890                 current_leaves = head_leaves if body_leaves else tail_leaves
1891         current_leaves.append(leaf)
1892         if current_leaves is tail_leaves:
1893             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1894                 opening_bracket = leaf.opening_bracket
1895                 closing_bracket = leaf
1896                 current_leaves = body_leaves
1897     tail_leaves.reverse()
1898     body_leaves.reverse()
1899     head_leaves.reverse()
1900     # Since body is a new indent level, remove spurious leading whitespace.
1901     if body_leaves:
1902         normalize_prefix(body_leaves[0], inside_brackets=True)
1903     elif not head_leaves:
1904         # No `head` and no `body` means the split failed. `tail` has all content.
1905         raise CannotSplit("No brackets found")
1906
1907     # Build the new lines.
1908     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1909         for leaf in leaves:
1910             result.append(leaf, preformatted=True)
1911             for comment_after in line.comments_after(leaf):
1912                 result.append(comment_after, preformatted=True)
1913     bracket_split_succeeded_or_raise(head, body, tail)
1914     assert opening_bracket and closing_bracket
1915     if (
1916         # the opening bracket is an optional paren
1917         opening_bracket.type == token.LPAR
1918         and not opening_bracket.value
1919         # the closing bracket is an optional paren
1920         and closing_bracket.type == token.RPAR
1921         and not closing_bracket.value
1922         # there are no delimiters or standalone comments in the body
1923         and not body.bracket_tracker.delimiters
1924         and not line.contains_standalone_comments(0)
1925         # and it's not an import (optional parens are the only thing we can split
1926         # on in this case; attempting a split without them is a waste of time)
1927         and not line.is_import
1928     ):
1929         omit = {id(closing_bracket), *omit}
1930         try:
1931             yield from right_hand_split(line, py36=py36, omit=omit)
1932             return
1933         except CannotSplit:
1934             pass
1935
1936     ensure_visible(opening_bracket)
1937     ensure_visible(closing_bracket)
1938     for result in (head, body, tail):
1939         if result:
1940             yield result
1941
1942
1943 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1944     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1945
1946     Do nothing otherwise.
1947
1948     A left- or right-hand split is based on a pair of brackets. Content before
1949     (and including) the opening bracket is left on one line, content inside the
1950     brackets is put on a separate line, and finally content starting with and
1951     following the closing bracket is put on a separate line.
1952
1953     Those are called `head`, `body`, and `tail`, respectively. If the split
1954     produced the same line (all content in `head`) or ended up with an empty `body`
1955     and the `tail` is just the closing bracket, then it's considered failed.
1956     """
1957     tail_len = len(str(tail).strip())
1958     if not body:
1959         if tail_len == 0:
1960             raise CannotSplit("Splitting brackets produced the same line")
1961
1962         elif tail_len < 3:
1963             raise CannotSplit(
1964                 f"Splitting brackets on an empty body to save "
1965                 f"{tail_len} characters is not worth it"
1966             )
1967
1968
1969 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1970     """Normalize prefix of the first leaf in every line returned by `split_func`.
1971
1972     This is a decorator over relevant split functions.
1973     """
1974
1975     @wraps(split_func)
1976     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1977         for l in split_func(line, py36):
1978             normalize_prefix(l.leaves[0], inside_brackets=True)
1979             yield l
1980
1981     return split_wrapper
1982
1983
1984 @dont_increase_indentation
1985 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1986     """Split according to delimiters of the highest priority.
1987
1988     If `py36` is True, the split will add trailing commas also in function
1989     signatures that contain `*` and `**`.
1990     """
1991     try:
1992         last_leaf = line.leaves[-1]
1993     except IndexError:
1994         raise CannotSplit("Line empty")
1995
1996     delimiters = line.bracket_tracker.delimiters
1997     try:
1998         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1999             exclude={id(last_leaf)}
2000         )
2001     except ValueError:
2002         raise CannotSplit("No delimiters found")
2003
2004     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2005     lowest_depth = sys.maxsize
2006     trailing_comma_safe = True
2007
2008     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2009         """Append `leaf` to current line or to new line if appending impossible."""
2010         nonlocal current_line
2011         try:
2012             current_line.append_safe(leaf, preformatted=True)
2013         except ValueError as ve:
2014             yield current_line
2015
2016             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2017             current_line.append(leaf)
2018
2019     for leaf in line.leaves:
2020         yield from append_to_line(leaf)
2021
2022         for comment_after in line.comments_after(leaf):
2023             yield from append_to_line(comment_after)
2024
2025         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2026         if (
2027             leaf.bracket_depth == lowest_depth
2028             and is_vararg(leaf, within=VARARGS_PARENTS)
2029         ):
2030             trailing_comma_safe = trailing_comma_safe and py36
2031         leaf_priority = delimiters.get(id(leaf))
2032         if leaf_priority == delimiter_priority:
2033             yield current_line
2034
2035             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2036     if current_line:
2037         if (
2038             trailing_comma_safe
2039             and delimiter_priority == COMMA_PRIORITY
2040             and current_line.leaves[-1].type != token.COMMA
2041             and current_line.leaves[-1].type != STANDALONE_COMMENT
2042         ):
2043             current_line.append(Leaf(token.COMMA, ","))
2044         yield current_line
2045
2046
2047 @dont_increase_indentation
2048 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2049     """Split standalone comments from the rest of the line."""
2050     if not line.contains_standalone_comments(0):
2051         raise CannotSplit("Line does not have any standalone comments")
2052
2053     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2054
2055     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2056         """Append `leaf` to current line or to new line if appending impossible."""
2057         nonlocal current_line
2058         try:
2059             current_line.append_safe(leaf, preformatted=True)
2060         except ValueError as ve:
2061             yield current_line
2062
2063             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2064             current_line.append(leaf)
2065
2066     for leaf in line.leaves:
2067         yield from append_to_line(leaf)
2068
2069         for comment_after in line.comments_after(leaf):
2070             yield from append_to_line(comment_after)
2071
2072     if current_line:
2073         yield current_line
2074
2075
2076 def explode_split(
2077     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
2078 ) -> Iterator[Line]:
2079     """Split by rightmost bracket and immediately split contents by a delimiter."""
2080     new_lines = list(right_hand_split(line, py36, omit))
2081     if len(new_lines) != 3:
2082         yield from new_lines
2083         return
2084
2085     yield new_lines[0]
2086
2087     try:
2088         yield from delimiter_split(new_lines[1], py36)
2089
2090     except CannotSplit:
2091         yield new_lines[1]
2092
2093     yield new_lines[2]
2094
2095
2096 def is_import(leaf: Leaf) -> bool:
2097     """Return True if the given leaf starts an import statement."""
2098     p = leaf.parent
2099     t = leaf.type
2100     v = leaf.value
2101     return bool(
2102         t == token.NAME
2103         and (
2104             (v == "import" and p and p.type == syms.import_name)
2105             or (v == "from" and p and p.type == syms.import_from)
2106         )
2107     )
2108
2109
2110 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2111     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2112     else.
2113
2114     Note: don't use backslashes for formatting or you'll lose your voting rights.
2115     """
2116     if not inside_brackets:
2117         spl = leaf.prefix.split("#")
2118         if "\\" not in spl[0]:
2119             nl_count = spl[-1].count("\n")
2120             if len(spl) > 1:
2121                 nl_count -= 1
2122             leaf.prefix = "\n" * nl_count
2123             return
2124
2125     leaf.prefix = ""
2126
2127
2128 def normalize_string_quotes(leaf: Leaf) -> None:
2129     """Prefer double quotes but only if it doesn't cause more escaping.
2130
2131     Adds or removes backslashes as appropriate. Doesn't parse and fix
2132     strings nested in f-strings (yet).
2133
2134     Note: Mutates its argument.
2135     """
2136     value = leaf.value.lstrip("furbFURB")
2137     if value[:3] == '"""':
2138         return
2139
2140     elif value[:3] == "'''":
2141         orig_quote = "'''"
2142         new_quote = '"""'
2143     elif value[0] == '"':
2144         orig_quote = '"'
2145         new_quote = "'"
2146     else:
2147         orig_quote = "'"
2148         new_quote = '"'
2149     first_quote_pos = leaf.value.find(orig_quote)
2150     if first_quote_pos == -1:
2151         return  # There's an internal error
2152
2153     prefix = leaf.value[:first_quote_pos]
2154     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2155     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2156     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2157     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2158     if "r" in prefix.casefold():
2159         if unescaped_new_quote.search(body):
2160             # There's at least one unescaped new_quote in this raw string
2161             # so converting is impossible
2162             return
2163
2164         # Do not introduce or remove backslashes in raw strings
2165         new_body = body
2166     else:
2167         # remove unnecessary quotes
2168         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2169         if body != new_body:
2170             # Consider the string without unnecessary quotes as the original
2171             body = new_body
2172             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2173         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2174         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2175     if new_quote == '"""' and new_body[-1] == '"':
2176         # edge case:
2177         new_body = new_body[:-1] + '\\"'
2178     orig_escape_count = body.count("\\")
2179     new_escape_count = new_body.count("\\")
2180     if new_escape_count > orig_escape_count:
2181         return  # Do not introduce more escaping
2182
2183     if new_escape_count == orig_escape_count and orig_quote == '"':
2184         return  # Prefer double quotes
2185
2186     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2187
2188
2189 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2190     """Make existing optional parentheses invisible or create new ones.
2191
2192     Standardizes on visible parentheses for single-element tuples, and keeps
2193     existing visible parentheses for other tuples and generator expressions.
2194     """
2195     check_lpar = False
2196     for child in list(node.children):
2197         if check_lpar:
2198             if child.type == syms.atom:
2199                 maybe_make_parens_invisible_in_atom(child)
2200             elif is_one_tuple(child):
2201                 # wrap child in visible parentheses
2202                 lpar = Leaf(token.LPAR, "(")
2203                 rpar = Leaf(token.RPAR, ")")
2204                 index = child.remove() or 0
2205                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2206             else:
2207                 # wrap child in invisible parentheses
2208                 lpar = Leaf(token.LPAR, "")
2209                 rpar = Leaf(token.RPAR, "")
2210                 index = child.remove() or 0
2211                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2212
2213         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2214
2215
2216 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2217     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2218     if (
2219         node.type != syms.atom
2220         or is_empty_tuple(node)
2221         or is_one_tuple(node)
2222         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2223     ):
2224         return False
2225
2226     first = node.children[0]
2227     last = node.children[-1]
2228     if first.type == token.LPAR and last.type == token.RPAR:
2229         # make parentheses invisible
2230         first.value = ""  # type: ignore
2231         last.value = ""  # type: ignore
2232         if len(node.children) > 1:
2233             maybe_make_parens_invisible_in_atom(node.children[1])
2234         return True
2235
2236     return False
2237
2238
2239 def is_empty_tuple(node: LN) -> bool:
2240     """Return True if `node` holds an empty tuple."""
2241     return (
2242         node.type == syms.atom
2243         and len(node.children) == 2
2244         and node.children[0].type == token.LPAR
2245         and node.children[1].type == token.RPAR
2246     )
2247
2248
2249 def is_one_tuple(node: LN) -> bool:
2250     """Return True if `node` holds a tuple with one element, with or without parens."""
2251     if node.type == syms.atom:
2252         if len(node.children) != 3:
2253             return False
2254
2255         lpar, gexp, rpar = node.children
2256         if not (
2257             lpar.type == token.LPAR
2258             and gexp.type == syms.testlist_gexp
2259             and rpar.type == token.RPAR
2260         ):
2261             return False
2262
2263         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2264
2265     return (
2266         node.type in IMPLICIT_TUPLE
2267         and len(node.children) == 2
2268         and node.children[1].type == token.COMMA
2269     )
2270
2271
2272 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2273     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2274
2275     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2276     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2277     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2278     generalizations (PEP 448).
2279     """
2280     if leaf.type not in STARS or not leaf.parent:
2281         return False
2282
2283     p = leaf.parent
2284     if p.type == syms.star_expr:
2285         # Star expressions are also used as assignment targets in extended
2286         # iterable unpacking (PEP 3132).  See what its parent is instead.
2287         if not p.parent:
2288             return False
2289
2290         p = p.parent
2291
2292     return p.type in within
2293
2294
2295 def max_delimiter_priority_in_atom(node: LN) -> int:
2296     """Return maximum delimiter priority inside `node`.
2297
2298     This is specific to atoms with contents contained in a pair of parentheses.
2299     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2300     """
2301     if node.type != syms.atom:
2302         return 0
2303
2304     first = node.children[0]
2305     last = node.children[-1]
2306     if not (first.type == token.LPAR and last.type == token.RPAR):
2307         return 0
2308
2309     bt = BracketTracker()
2310     for c in node.children[1:-1]:
2311         if isinstance(c, Leaf):
2312             bt.mark(c)
2313         else:
2314             for leaf in c.leaves():
2315                 bt.mark(leaf)
2316     try:
2317         return bt.max_delimiter_priority()
2318
2319     except ValueError:
2320         return 0
2321
2322
2323 def ensure_visible(leaf: Leaf) -> None:
2324     """Make sure parentheses are visible.
2325
2326     They could be invisible as part of some statements (see
2327     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2328     """
2329     if leaf.type == token.LPAR:
2330         leaf.value = "("
2331     elif leaf.type == token.RPAR:
2332         leaf.value = ")"
2333
2334
2335 def is_python36(node: Node) -> bool:
2336     """Return True if the current file is using Python 3.6+ features.
2337
2338     Currently looking for:
2339     - f-strings; and
2340     - trailing commas after * or ** in function signatures and calls.
2341     """
2342     for n in node.pre_order():
2343         if n.type == token.STRING:
2344             value_head = n.value[:2]  # type: ignore
2345             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2346                 return True
2347
2348         elif (
2349             n.type in {syms.typedargslist, syms.arglist}
2350             and n.children
2351             and n.children[-1].type == token.COMMA
2352         ):
2353             for ch in n.children:
2354                 if ch.type in STARS:
2355                     return True
2356
2357                 if ch.type == syms.argument:
2358                     for argch in ch.children:
2359                         if argch.type in STARS:
2360                             return True
2361
2362     return False
2363
2364
2365 PYTHON_EXTENSIONS = {".py"}
2366 BLACKLISTED_DIRECTORIES = {
2367     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2368 }
2369
2370
2371 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2372     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2373     and have one of the PYTHON_EXTENSIONS.
2374     """
2375     for child in path.iterdir():
2376         if child.is_dir():
2377             if child.name in BLACKLISTED_DIRECTORIES:
2378                 continue
2379
2380             yield from gen_python_files_in_dir(child)
2381
2382         elif child.suffix in PYTHON_EXTENSIONS:
2383             yield child
2384
2385
2386 @dataclass
2387 class Report:
2388     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2389     check: bool = False
2390     quiet: bool = False
2391     change_count: int = 0
2392     same_count: int = 0
2393     failure_count: int = 0
2394
2395     def done(self, src: Path, changed: Changed) -> None:
2396         """Increment the counter for successful reformatting. Write out a message."""
2397         if changed is Changed.YES:
2398             reformatted = "would reformat" if self.check else "reformatted"
2399             if not self.quiet:
2400                 out(f"{reformatted} {src}")
2401             self.change_count += 1
2402         else:
2403             if not self.quiet:
2404                 if changed is Changed.NO:
2405                     msg = f"{src} already well formatted, good job."
2406                 else:
2407                     msg = f"{src} wasn't modified on disk since last run."
2408                 out(msg, bold=False)
2409             self.same_count += 1
2410
2411     def failed(self, src: Path, message: str) -> None:
2412         """Increment the counter for failed reformatting. Write out a message."""
2413         err(f"error: cannot format {src}: {message}")
2414         self.failure_count += 1
2415
2416     @property
2417     def return_code(self) -> int:
2418         """Return the exit code that the app should use.
2419
2420         This considers the current state of changed files and failures:
2421         - if there were any failures, return 123;
2422         - if any files were changed and --check is being used, return 1;
2423         - otherwise return 0.
2424         """
2425         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2426         # 126 we have special returncodes reserved by the shell.
2427         if self.failure_count:
2428             return 123
2429
2430         elif self.change_count and self.check:
2431             return 1
2432
2433         return 0
2434
2435     def __str__(self) -> str:
2436         """Render a color report of the current state.
2437
2438         Use `click.unstyle` to remove colors.
2439         """
2440         if self.check:
2441             reformatted = "would be reformatted"
2442             unchanged = "would be left unchanged"
2443             failed = "would fail to reformat"
2444         else:
2445             reformatted = "reformatted"
2446             unchanged = "left unchanged"
2447             failed = "failed to reformat"
2448         report = []
2449         if self.change_count:
2450             s = "s" if self.change_count > 1 else ""
2451             report.append(
2452                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2453             )
2454         if self.same_count:
2455             s = "s" if self.same_count > 1 else ""
2456             report.append(f"{self.same_count} file{s} {unchanged}")
2457         if self.failure_count:
2458             s = "s" if self.failure_count > 1 else ""
2459             report.append(
2460                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2461             )
2462         return ", ".join(report) + "."
2463
2464
2465 def assert_equivalent(src: str, dst: str) -> None:
2466     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2467
2468     import ast
2469     import traceback
2470
2471     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2472         """Simple visitor generating strings to compare ASTs by content."""
2473         yield f"{'  ' * depth}{node.__class__.__name__}("
2474
2475         for field in sorted(node._fields):
2476             try:
2477                 value = getattr(node, field)
2478             except AttributeError:
2479                 continue
2480
2481             yield f"{'  ' * (depth+1)}{field}="
2482
2483             if isinstance(value, list):
2484                 for item in value:
2485                     if isinstance(item, ast.AST):
2486                         yield from _v(item, depth + 2)
2487
2488             elif isinstance(value, ast.AST):
2489                 yield from _v(value, depth + 2)
2490
2491             else:
2492                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2493
2494         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2495
2496     try:
2497         src_ast = ast.parse(src)
2498     except Exception as exc:
2499         major, minor = sys.version_info[:2]
2500         raise AssertionError(
2501             f"cannot use --safe with this file; failed to parse source file "
2502             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2503             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2504         )
2505
2506     try:
2507         dst_ast = ast.parse(dst)
2508     except Exception as exc:
2509         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2510         raise AssertionError(
2511             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2512             f"Please report a bug on https://github.com/ambv/black/issues.  "
2513             f"This invalid output might be helpful: {log}"
2514         ) from None
2515
2516     src_ast_str = "\n".join(_v(src_ast))
2517     dst_ast_str = "\n".join(_v(dst_ast))
2518     if src_ast_str != dst_ast_str:
2519         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2520         raise AssertionError(
2521             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2522             f"the source.  "
2523             f"Please report a bug on https://github.com/ambv/black/issues.  "
2524             f"This diff might be helpful: {log}"
2525         ) from None
2526
2527
2528 def assert_stable(src: str, dst: str, line_length: int) -> None:
2529     """Raise AssertionError if `dst` reformats differently the second time."""
2530     newdst = format_str(dst, line_length=line_length)
2531     if dst != newdst:
2532         log = dump_to_file(
2533             diff(src, dst, "source", "first pass"),
2534             diff(dst, newdst, "first pass", "second pass"),
2535         )
2536         raise AssertionError(
2537             f"INTERNAL ERROR: Black produced different code on the second pass "
2538             f"of the formatter.  "
2539             f"Please report a bug on https://github.com/ambv/black/issues.  "
2540             f"This diff might be helpful: {log}"
2541         ) from None
2542
2543
2544 def dump_to_file(*output: str) -> str:
2545     """Dump `output` to a temporary file. Return path to the file."""
2546     import tempfile
2547
2548     with tempfile.NamedTemporaryFile(
2549         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2550     ) as f:
2551         for lines in output:
2552             f.write(lines)
2553             if lines and lines[-1] != "\n":
2554                 f.write("\n")
2555     return f.name
2556
2557
2558 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2559     """Return a unified diff string between strings `a` and `b`."""
2560     import difflib
2561
2562     a_lines = [line + "\n" for line in a.split("\n")]
2563     b_lines = [line + "\n" for line in b.split("\n")]
2564     return "".join(
2565         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2566     )
2567
2568
2569 def cancel(tasks: List[asyncio.Task]) -> None:
2570     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2571     err("Aborted!")
2572     for task in tasks:
2573         task.cancel()
2574
2575
2576 def shutdown(loop: BaseEventLoop) -> None:
2577     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2578     try:
2579         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2580         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2581         if not to_cancel:
2582             return
2583
2584         for task in to_cancel:
2585             task.cancel()
2586         loop.run_until_complete(
2587             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2588         )
2589     finally:
2590         # `concurrent.futures.Future` objects cannot be cancelled once they
2591         # are already running. There might be some when the `shutdown()` happened.
2592         # Silence their logger's spew about the event loop being closed.
2593         cf_logger = logging.getLogger("concurrent.futures")
2594         cf_logger.setLevel(logging.CRITICAL)
2595         loop.close()
2596
2597
2598 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2599     """Replace `regex` with `replacement` twice on `original`.
2600
2601     This is used by string normalization to perform replaces on
2602     overlapping matches.
2603     """
2604     return regex.sub(replacement, regex.sub(replacement, original))
2605
2606
2607 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2608
2609
2610 def get_cache_file(line_length: int) -> Path:
2611     return CACHE_DIR / f"cache.{line_length}.pickle"
2612
2613
2614 def read_cache(line_length: int) -> Cache:
2615     """Read the cache if it exists and is well formed.
2616
2617     If it is not well formed, the call to write_cache later should resolve the issue.
2618     """
2619     cache_file = get_cache_file(line_length)
2620     if not cache_file.exists():
2621         return {}
2622
2623     with cache_file.open("rb") as fobj:
2624         try:
2625             cache: Cache = pickle.load(fobj)
2626         except pickle.UnpicklingError:
2627             return {}
2628
2629     return cache
2630
2631
2632 def get_cache_info(path: Path) -> CacheInfo:
2633     """Return the information used to check if a file is already formatted or not."""
2634     stat = path.stat()
2635     return stat.st_mtime, stat.st_size
2636
2637
2638 def filter_cached(
2639     cache: Cache, sources: Iterable[Path]
2640 ) -> Tuple[List[Path], List[Path]]:
2641     """Split a list of paths into two.
2642
2643     The first list contains paths of files that modified on disk or are not in the
2644     cache. The other list contains paths to non-modified files.
2645     """
2646     todo, done = [], []
2647     for src in sources:
2648         src = src.resolve()
2649         if cache.get(src) != get_cache_info(src):
2650             todo.append(src)
2651         else:
2652             done.append(src)
2653     return todo, done
2654
2655
2656 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2657     """Update the cache file."""
2658     cache_file = get_cache_file(line_length)
2659     try:
2660         if not CACHE_DIR.exists():
2661             CACHE_DIR.mkdir(parents=True)
2662         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2663         with cache_file.open("wb") as fobj:
2664             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2665     except OSError:
2666         pass
2667
2668
2669 if __name__ == "__main__":
2670     main()