]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

46754566fd4aaac239b93c0250fd9c12da2c9d39
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Set,
28     Tuple,
29     Type,
30     TypeVar,
31     Union,
32 )
33
34 from appdirs import user_cache_dir
35 from attr import dataclass, Factory
36 import click
37
38 # lib2to3 fork
39 from blib2to3.pytree import Node, Leaf, type_repr
40 from blib2to3 import pygram, pytree
41 from blib2to3.pgen2 import driver, token
42 from blib2to3.pgen2.parse import ParseError
43
44 __version__ = "18.4a4"
45 DEFAULT_LINE_LENGTH = 88
46
47 # types
48 syms = pygram.python_symbols
49 FileContent = str
50 Encoding = str
51 Depth = int
52 NodeType = int
53 LeafID = int
54 Priority = int
55 Index = int
56 LN = Union[Leaf, Node]
57 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
58 Timestamp = float
59 FileSize = int
60 CacheInfo = Tuple[Timestamp, FileSize]
61 Cache = Dict[Path, CacheInfo]
62 out = partial(click.secho, bold=True, err=True)
63 err = partial(click.secho, fg="red", err=True)
64
65
66 class NothingChanged(UserWarning):
67     """Raised by :func:`format_file` when reformatted code is the same as source."""
68
69
70 class CannotSplit(Exception):
71     """A readable split that fits the allotted line length is impossible.
72
73     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
74     :func:`delimiter_split`.
75     """
76
77
78 class FormatError(Exception):
79     """Base exception for `# fmt: on` and `# fmt: off` handling.
80
81     It holds the number of bytes of the prefix consumed before the format
82     control comment appeared.
83     """
84
85     def __init__(self, consumed: int) -> None:
86         super().__init__(consumed)
87         self.consumed = consumed
88
89     def trim_prefix(self, leaf: Leaf) -> None:
90         leaf.prefix = leaf.prefix[self.consumed :]
91
92     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
93         """Returns a new Leaf from the consumed part of the prefix."""
94         unformatted_prefix = leaf.prefix[: self.consumed]
95         return Leaf(token.NEWLINE, unformatted_prefix)
96
97
98 class FormatOn(FormatError):
99     """Found a comment like `# fmt: on` in the file."""
100
101
102 class FormatOff(FormatError):
103     """Found a comment like `# fmt: off` in the file."""
104
105
106 class WriteBack(Enum):
107     NO = 0
108     YES = 1
109     DIFF = 2
110
111
112 class Changed(Enum):
113     NO = 0
114     CACHED = 1
115     YES = 2
116
117
118 @click.command()
119 @click.option(
120     "-l",
121     "--line-length",
122     type=int,
123     default=DEFAULT_LINE_LENGTH,
124     help="How many character per line to allow.",
125     show_default=True,
126 )
127 @click.option(
128     "--check",
129     is_flag=True,
130     help=(
131         "Don't write the files back, just return the status.  Return code 0 "
132         "means nothing would change.  Return code 1 means some files would be "
133         "reformatted.  Return code 123 means there was an internal error."
134     ),
135 )
136 @click.option(
137     "--diff",
138     is_flag=True,
139     help="Don't write the files back, just output a diff for each file on stdout.",
140 )
141 @click.option(
142     "--fast/--safe",
143     is_flag=True,
144     help="If --fast given, skip temporary sanity checks. [default: --safe]",
145 )
146 @click.option(
147     "-q",
148     "--quiet",
149     is_flag=True,
150     help=(
151         "Don't emit non-error messages to stderr. Errors are still emitted, "
152         "silence those with 2>/dev/null."
153     ),
154 )
155 @click.version_option(version=__version__)
156 @click.argument(
157     "src",
158     nargs=-1,
159     type=click.Path(
160         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
161     ),
162 )
163 @click.pass_context
164 def main(
165     ctx: click.Context,
166     line_length: int,
167     check: bool,
168     diff: bool,
169     fast: bool,
170     quiet: bool,
171     src: List[str],
172 ) -> None:
173     """The uncompromising code formatter."""
174     sources: List[Path] = []
175     for s in src:
176         p = Path(s)
177         if p.is_dir():
178             sources.extend(gen_python_files_in_dir(p))
179         elif p.is_file():
180             # if a file was explicitly given, we don't care about its extension
181             sources.append(p)
182         elif s == "-":
183             sources.append(Path("-"))
184         else:
185             err(f"invalid path: {s}")
186
187     if check and not diff:
188         write_back = WriteBack.NO
189     elif diff:
190         write_back = WriteBack.DIFF
191     else:
192         write_back = WriteBack.YES
193     report = Report(check=check, quiet=quiet)
194     if len(sources) == 0:
195         out("No paths given. Nothing to do 😴")
196         ctx.exit(0)
197         return
198
199     elif len(sources) == 1:
200         reformat_one(sources[0], line_length, fast, write_back, report)
201     else:
202         loop = asyncio.get_event_loop()
203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
204         try:
205             loop.run_until_complete(
206                 schedule_formatting(
207                     sources, line_length, fast, write_back, report, loop, executor
208                 )
209             )
210         finally:
211             shutdown(loop)
212         if not quiet:
213             out("All done! ✨ 🍰 ✨")
214             click.echo(str(report))
215     ctx.exit(report.return_code)
216
217
218 def reformat_one(
219     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
220 ) -> None:
221     """Reformat a single file under `src` without spawning child processes.
222
223     If `quiet` is True, non-error messages are not output. `line_length`,
224     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
225     """
226     try:
227         changed = Changed.NO
228         if not src.is_file() and str(src) == "-":
229             if format_stdin_to_stdout(
230                 line_length=line_length, fast=fast, write_back=write_back
231             ):
232                 changed = Changed.YES
233         else:
234             cache: Cache = {}
235             if write_back != WriteBack.DIFF:
236                 cache = read_cache(line_length)
237                 src = src.resolve()
238                 if src in cache and cache[src] == get_cache_info(src):
239                     changed = Changed.CACHED
240             if (
241                 changed is not Changed.CACHED
242                 and format_file_in_place(
243                     src, line_length=line_length, fast=fast, write_back=write_back
244                 )
245             ):
246                 changed = Changed.YES
247             if write_back == WriteBack.YES and changed is not Changed.NO:
248                 write_cache(cache, [src], line_length)
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache(line_length)
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             src: loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             )
289             for src in sources
290         }
291         _task_values = list(tasks.values())
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
294             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     if write_back == WriteBack.YES and formatted:
315         write_cache(cache, formatted, line_length)
316
317
318 def format_file_in_place(
319     src: Path,
320     line_length: int,
321     fast: bool,
322     write_back: WriteBack = WriteBack.NO,
323     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
324 ) -> bool:
325     """Format file under `src` path. Return True if changed.
326
327     If `write_back` is True, write reformatted code back to stdout.
328     `line_length` and `fast` options are passed to :func:`format_file_contents`.
329     """
330
331     with tokenize.open(src) as src_buffer:
332         src_contents = src_buffer.read()
333     try:
334         dst_contents = format_file_contents(
335             src_contents, line_length=line_length, fast=fast
336         )
337     except NothingChanged:
338         return False
339
340     if write_back == write_back.YES:
341         with open(src, "w", encoding=src_buffer.encoding) as f:
342             f.write(dst_contents)
343     elif write_back == write_back.DIFF:
344         src_name = f"{src}  (original)"
345         dst_name = f"{src}  (formatted)"
346         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347         if lock:
348             lock.acquire()
349         try:
350             sys.stdout.write(diff_contents)
351         finally:
352             if lock:
353                 lock.release()
354     return True
355
356
357 def format_stdin_to_stdout(
358     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
359 ) -> bool:
360     """Format file on stdin. Return True if changed.
361
362     If `write_back` is True, write reformatted code back to stdout.
363     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
364     """
365     src = sys.stdin.read()
366     dst = src
367     try:
368         dst = format_file_contents(src, line_length=line_length, fast=fast)
369         return True
370
371     except NothingChanged:
372         return False
373
374     finally:
375         if write_back == WriteBack.YES:
376             sys.stdout.write(dst)
377         elif write_back == WriteBack.DIFF:
378             src_name = "<stdin>  (original)"
379             dst_name = "<stdin>  (formatted)"
380             sys.stdout.write(diff(src, dst, src_name, dst_name))
381
382
383 def format_file_contents(
384     src_contents: str, line_length: int, fast: bool
385 ) -> FileContent:
386     """Reformat contents a file and return new contents.
387
388     If `fast` is False, additionally confirm that the reformatted code is
389     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
390     `line_length` is passed to :func:`format_str`.
391     """
392     if src_contents.strip() == "":
393         raise NothingChanged
394
395     dst_contents = format_str(src_contents, line_length=line_length)
396     if src_contents == dst_contents:
397         raise NothingChanged
398
399     if not fast:
400         assert_equivalent(src_contents, dst_contents)
401         assert_stable(src_contents, dst_contents, line_length=line_length)
402     return dst_contents
403
404
405 def format_str(src_contents: str, line_length: int) -> FileContent:
406     """Reformat a string and return new contents.
407
408     `line_length` determines how many characters per line are allowed.
409     """
410     src_node = lib2to3_parse(src_contents)
411     dst_contents = ""
412     lines = LineGenerator()
413     elt = EmptyLineTracker()
414     py36 = is_python36(src_node)
415     empty_line = Line()
416     after = 0
417     for current_line in lines.visit(src_node):
418         for _ in range(after):
419             dst_contents += str(empty_line)
420         before, after = elt.maybe_empty_lines(current_line)
421         for _ in range(before):
422             dst_contents += str(empty_line)
423         for line in split_line(current_line, line_length=line_length, py36=py36):
424             dst_contents += str(line)
425     return dst_contents
426
427
428 GRAMMARS = [
429     pygram.python_grammar_no_print_statement_no_exec_statement,
430     pygram.python_grammar_no_print_statement,
431     pygram.python_grammar,
432 ]
433
434
435 def lib2to3_parse(src_txt: str) -> Node:
436     """Given a string with source, return the lib2to3 Node."""
437     grammar = pygram.python_grammar_no_print_statement
438     if src_txt[-1] != "\n":
439         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
440         src_txt += nl
441     for grammar in GRAMMARS:
442         drv = driver.Driver(grammar, pytree.convert)
443         try:
444             result = drv.parse_string(src_txt, True)
445             break
446
447         except ParseError as pe:
448             lineno, column = pe.context[1]
449             lines = src_txt.splitlines()
450             try:
451                 faulty_line = lines[lineno - 1]
452             except IndexError:
453                 faulty_line = "<line number missing in source>"
454             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
455     else:
456         raise exc from None
457
458     if isinstance(result, Leaf):
459         result = Node(syms.file_input, [result])
460     return result
461
462
463 def lib2to3_unparse(node: Node) -> str:
464     """Given a lib2to3 node, return its string representation."""
465     code = str(node)
466     return code
467
468
469 T = TypeVar("T")
470
471
472 class Visitor(Generic[T]):
473     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
474
475     def visit(self, node: LN) -> Iterator[T]:
476         """Main method to visit `node` and its children.
477
478         It tries to find a `visit_*()` method for the given `node.type`, like
479         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
480         If no dedicated `visit_*()` method is found, chooses `visit_default()`
481         instead.
482
483         Then yields objects of type `T` from the selected visitor.
484         """
485         if node.type < 256:
486             name = token.tok_name[node.type]
487         else:
488             name = type_repr(node.type)
489         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
490
491     def visit_default(self, node: LN) -> Iterator[T]:
492         """Default `visit_*()` implementation. Recurses to children of `node`."""
493         if isinstance(node, Node):
494             for child in node.children:
495                 yield from self.visit(child)
496
497
498 @dataclass
499 class DebugVisitor(Visitor[T]):
500     tree_depth: int = 0
501
502     def visit_default(self, node: LN) -> Iterator[T]:
503         indent = " " * (2 * self.tree_depth)
504         if isinstance(node, Node):
505             _type = type_repr(node.type)
506             out(f"{indent}{_type}", fg="yellow")
507             self.tree_depth += 1
508             for child in node.children:
509                 yield from self.visit(child)
510
511             self.tree_depth -= 1
512             out(f"{indent}/{_type}", fg="yellow", bold=False)
513         else:
514             _type = token.tok_name.get(node.type, str(node.type))
515             out(f"{indent}{_type}", fg="blue", nl=False)
516             if node.prefix:
517                 # We don't have to handle prefixes for `Node` objects since
518                 # that delegates to the first child anyway.
519                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
520             out(f" {node.value!r}", fg="blue", bold=False)
521
522     @classmethod
523     def show(cls, code: str) -> None:
524         """Pretty-print the lib2to3 AST of a given string of `code`.
525
526         Convenience method for debugging.
527         """
528         v: DebugVisitor[None] = DebugVisitor()
529         list(v.visit(lib2to3_parse(code)))
530
531
532 KEYWORDS = set(keyword.kwlist)
533 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
534 FLOW_CONTROL = {"return", "raise", "break", "continue"}
535 STATEMENT = {
536     syms.if_stmt,
537     syms.while_stmt,
538     syms.for_stmt,
539     syms.try_stmt,
540     syms.except_clause,
541     syms.with_stmt,
542     syms.funcdef,
543     syms.classdef,
544 }
545 STANDALONE_COMMENT = 153
546 LOGIC_OPERATORS = {"and", "or"}
547 COMPARATORS = {
548     token.LESS,
549     token.GREATER,
550     token.EQEQUAL,
551     token.NOTEQUAL,
552     token.LESSEQUAL,
553     token.GREATEREQUAL,
554 }
555 MATH_OPERATORS = {
556     token.PLUS,
557     token.MINUS,
558     token.STAR,
559     token.SLASH,
560     token.VBAR,
561     token.AMPER,
562     token.PERCENT,
563     token.CIRCUMFLEX,
564     token.TILDE,
565     token.LEFTSHIFT,
566     token.RIGHTSHIFT,
567     token.DOUBLESTAR,
568     token.DOUBLESLASH,
569 }
570 STARS = {token.STAR, token.DOUBLESTAR}
571 VARARGS_PARENTS = {
572     syms.arglist,
573     syms.argument,  # double star in arglist
574     syms.trailer,  # single argument to call
575     syms.typedargslist,
576     syms.varargslist,  # lambdas
577 }
578 UNPACKING_PARENTS = {
579     syms.atom,  # single element of a list or set literal
580     syms.dictsetmaker,
581     syms.listmaker,
582     syms.testlist_gexp,
583 }
584 TEST_DESCENDANTS = {
585     syms.test,
586     syms.lambdef,
587     syms.or_test,
588     syms.and_test,
589     syms.not_test,
590     syms.comparison,
591     syms.star_expr,
592     syms.expr,
593     syms.xor_expr,
594     syms.and_expr,
595     syms.shift_expr,
596     syms.arith_expr,
597     syms.trailer,
598     syms.term,
599     syms.power,
600 }
601 COMPREHENSION_PRIORITY = 20
602 COMMA_PRIORITY = 10
603 TERNARY_PRIORITY = 7
604 LOGIC_PRIORITY = 5
605 STRING_PRIORITY = 4
606 COMPARATOR_PRIORITY = 3
607 MATH_PRIORITY = 1
608
609
610 @dataclass
611 class BracketTracker:
612     """Keeps track of brackets on a line."""
613
614     depth: int = 0
615     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
616     delimiters: Dict[LeafID, Priority] = Factory(dict)
617     previous: Optional[Leaf] = None
618     _for_loop_variable: bool = False
619     _lambda_arguments: bool = False
620
621     def mark(self, leaf: Leaf) -> None:
622         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
623
624         All leaves receive an int `bracket_depth` field that stores how deep
625         within brackets a given leaf is. 0 means there are no enclosing brackets
626         that started on this line.
627
628         If a leaf is itself a closing bracket, it receives an `opening_bracket`
629         field that it forms a pair with. This is a one-directional link to
630         avoid reference cycles.
631
632         If a leaf is a delimiter (a token on which Black can split the line if
633         needed) and it's on depth 0, its `id()` is stored in the tracker's
634         `delimiters` field.
635         """
636         if leaf.type == token.COMMENT:
637             return
638
639         self.maybe_decrement_after_for_loop_variable(leaf)
640         self.maybe_decrement_after_lambda_arguments(leaf)
641         if leaf.type in CLOSING_BRACKETS:
642             self.depth -= 1
643             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
644             leaf.opening_bracket = opening_bracket
645         leaf.bracket_depth = self.depth
646         if self.depth == 0:
647             delim = is_split_before_delimiter(leaf, self.previous)
648             if delim and self.previous is not None:
649                 self.delimiters[id(self.previous)] = delim
650             else:
651                 delim = is_split_after_delimiter(leaf, self.previous)
652                 if delim:
653                     self.delimiters[id(leaf)] = delim
654         if leaf.type in OPENING_BRACKETS:
655             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
656             self.depth += 1
657         self.previous = leaf
658         self.maybe_increment_lambda_arguments(leaf)
659         self.maybe_increment_for_loop_variable(leaf)
660
661     def any_open_brackets(self) -> bool:
662         """Return True if there is an yet unmatched open bracket on the line."""
663         return bool(self.bracket_match)
664
665     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
666         """Return the highest priority of a delimiter found on the line.
667
668         Values are consistent with what `is_split_*_delimiter()` return.
669         Raises ValueError on no delimiters.
670         """
671         return max(v for k, v in self.delimiters.items() if k not in exclude)
672
673     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
674         """In a for loop, or comprehension, the variables are often unpacks.
675
676         To avoid splitting on the comma in this situation, increase the depth of
677         tokens between `for` and `in`.
678         """
679         if leaf.type == token.NAME and leaf.value == "for":
680             self.depth += 1
681             self._for_loop_variable = True
682             return True
683
684         return False
685
686     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
687         """See `maybe_increment_for_loop_variable` above for explanation."""
688         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
689             self.depth -= 1
690             self._for_loop_variable = False
691             return True
692
693         return False
694
695     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
696         """In a lambda expression, there might be more than one argument.
697
698         To avoid splitting on the comma in this situation, increase the depth of
699         tokens between `lambda` and `:`.
700         """
701         if leaf.type == token.NAME and leaf.value == "lambda":
702             self.depth += 1
703             self._lambda_arguments = True
704             return True
705
706         return False
707
708     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
709         """See `maybe_increment_lambda_arguments` above for explanation."""
710         if self._lambda_arguments and leaf.type == token.COLON:
711             self.depth -= 1
712             self._lambda_arguments = False
713             return True
714
715         return False
716
717     def get_open_lsqb(self) -> Optional[Leaf]:
718         """Return the most recent opening square bracket (if any)."""
719         return self.bracket_match.get((self.depth - 1, token.RSQB))
720
721
722 @dataclass
723 class Line:
724     """Holds leaves and comments. Can be printed with `str(line)`."""
725
726     depth: int = 0
727     leaves: List[Leaf] = Factory(list)
728     comments: List[Tuple[Index, Leaf]] = Factory(list)
729     bracket_tracker: BracketTracker = Factory(BracketTracker)
730     inside_brackets: bool = False
731
732     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
733         """Add a new `leaf` to the end of the line.
734
735         Unless `preformatted` is True, the `leaf` will receive a new consistent
736         whitespace prefix and metadata applied by :class:`BracketTracker`.
737         Trailing commas are maybe removed, unpacked for loop variables are
738         demoted from being delimiters.
739
740         Inline comments are put aside.
741         """
742         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
743         if not has_value:
744             return
745
746         if token.COLON == leaf.type and self.is_class_paren_empty:
747             del self.leaves[-2:]
748         if self.leaves and not preformatted:
749             # Note: at this point leaf.prefix should be empty except for
750             # imports, for which we only preserve newlines.
751             leaf.prefix += whitespace(
752                 leaf, complex_subscript=self.is_complex_subscript(leaf)
753             )
754         if self.inside_brackets or not preformatted:
755             self.bracket_tracker.mark(leaf)
756             self.maybe_remove_trailing_comma(leaf)
757         if not self.append_comment(leaf):
758             self.leaves.append(leaf)
759
760     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
761         """Like :func:`append()` but disallow invalid standalone comment structure.
762
763         Raises ValueError when any `leaf` is appended after a standalone comment
764         or when a standalone comment is not the first leaf on the line.
765         """
766         if self.bracket_tracker.depth == 0:
767             if self.is_comment:
768                 raise ValueError("cannot append to standalone comments")
769
770             if self.leaves and leaf.type == STANDALONE_COMMENT:
771                 raise ValueError(
772                     "cannot append standalone comments to a populated line"
773                 )
774
775         self.append(leaf, preformatted=preformatted)
776
777     @property
778     def is_comment(self) -> bool:
779         """Is this line a standalone comment?"""
780         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
781
782     @property
783     def is_decorator(self) -> bool:
784         """Is this line a decorator?"""
785         return bool(self) and self.leaves[0].type == token.AT
786
787     @property
788     def is_import(self) -> bool:
789         """Is this an import line?"""
790         return bool(self) and is_import(self.leaves[0])
791
792     @property
793     def is_class(self) -> bool:
794         """Is this line a class definition?"""
795         return (
796             bool(self)
797             and self.leaves[0].type == token.NAME
798             and self.leaves[0].value == "class"
799         )
800
801     @property
802     def is_def(self) -> bool:
803         """Is this a function definition? (Also returns True for async defs.)"""
804         try:
805             first_leaf = self.leaves[0]
806         except IndexError:
807             return False
808
809         try:
810             second_leaf: Optional[Leaf] = self.leaves[1]
811         except IndexError:
812             second_leaf = None
813         return (
814             (first_leaf.type == token.NAME and first_leaf.value == "def")
815             or (
816                 first_leaf.type == token.ASYNC
817                 and second_leaf is not None
818                 and second_leaf.type == token.NAME
819                 and second_leaf.value == "def"
820             )
821         )
822
823     @property
824     def is_flow_control(self) -> bool:
825         """Is this line a flow control statement?
826
827         Those are `return`, `raise`, `break`, and `continue`.
828         """
829         return (
830             bool(self)
831             and self.leaves[0].type == token.NAME
832             and self.leaves[0].value in FLOW_CONTROL
833         )
834
835     @property
836     def is_yield(self) -> bool:
837         """Is this line a yield statement?"""
838         return (
839             bool(self)
840             and self.leaves[0].type == token.NAME
841             and self.leaves[0].value == "yield"
842         )
843
844     @property
845     def is_class_paren_empty(self) -> bool:
846         """Is this a class with no base classes but using parentheses?
847
848         Those are unnecessary and should be removed.
849         """
850         return (
851             bool(self)
852             and len(self.leaves) == 4
853             and self.is_class
854             and self.leaves[2].type == token.LPAR
855             and self.leaves[2].value == "("
856             and self.leaves[3].type == token.RPAR
857             and self.leaves[3].value == ")"
858         )
859
860     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
861         """If so, needs to be split before emitting."""
862         for leaf in self.leaves:
863             if leaf.type == STANDALONE_COMMENT:
864                 if leaf.bracket_depth <= depth_limit:
865                     return True
866
867         return False
868
869     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
870         """Remove trailing comma if there is one and it's safe."""
871         if not (
872             self.leaves
873             and self.leaves[-1].type == token.COMMA
874             and closing.type in CLOSING_BRACKETS
875         ):
876             return False
877
878         if closing.type == token.RBRACE:
879             self.remove_trailing_comma()
880             return True
881
882         if closing.type == token.RSQB:
883             comma = self.leaves[-1]
884             if comma.parent and comma.parent.type == syms.listmaker:
885                 self.remove_trailing_comma()
886                 return True
887
888         # For parens let's check if it's safe to remove the comma.  If the
889         # trailing one is the only one, we might mistakenly change a tuple
890         # into a different type by removing the comma.
891         depth = closing.bracket_depth + 1
892         commas = 0
893         opening = closing.opening_bracket
894         for _opening_index, leaf in enumerate(self.leaves):
895             if leaf is opening:
896                 break
897
898         else:
899             return False
900
901         for leaf in self.leaves[_opening_index + 1 :]:
902             if leaf is closing:
903                 break
904
905             bracket_depth = leaf.bracket_depth
906             if bracket_depth == depth and leaf.type == token.COMMA:
907                 commas += 1
908                 if leaf.parent and leaf.parent.type == syms.arglist:
909                     commas += 1
910                     break
911
912         if commas > 1:
913             self.remove_trailing_comma()
914             return True
915
916         return False
917
918     def append_comment(self, comment: Leaf) -> bool:
919         """Add an inline or standalone comment to the line."""
920         if (
921             comment.type == STANDALONE_COMMENT
922             and self.bracket_tracker.any_open_brackets()
923         ):
924             comment.prefix = ""
925             return False
926
927         if comment.type != token.COMMENT:
928             return False
929
930         after = len(self.leaves) - 1
931         if after == -1:
932             comment.type = STANDALONE_COMMENT
933             comment.prefix = ""
934             return False
935
936         else:
937             self.comments.append((after, comment))
938             return True
939
940     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
941         """Generate comments that should appear directly after `leaf`."""
942         for _leaf_index, _leaf in enumerate(self.leaves):
943             if leaf is _leaf:
944                 break
945
946         else:
947             return
948
949         for index, comment_after in self.comments:
950             if _leaf_index == index:
951                 yield comment_after
952
953     def remove_trailing_comma(self) -> None:
954         """Remove the trailing comma and moves the comments attached to it."""
955         comma_index = len(self.leaves) - 1
956         for i in range(len(self.comments)):
957             comment_index, comment = self.comments[i]
958             if comment_index == comma_index:
959                 self.comments[i] = (comma_index - 1, comment)
960         self.leaves.pop()
961
962     def is_complex_subscript(self, leaf: Leaf) -> bool:
963         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
964         open_lsqb = (
965             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
966         )
967         if open_lsqb is None:
968             return False
969
970         subscript_start = open_lsqb.next_sibling
971         if (
972             isinstance(subscript_start, Node)
973             and subscript_start.type == syms.subscriptlist
974         ):
975             subscript_start = child_towards(subscript_start, leaf)
976         return subscript_start is not None and any(
977             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
978         )
979
980     def __str__(self) -> str:
981         """Render the line."""
982         if not self:
983             return "\n"
984
985         indent = "    " * self.depth
986         leaves = iter(self.leaves)
987         first = next(leaves)
988         res = f"{first.prefix}{indent}{first.value}"
989         for leaf in leaves:
990             res += str(leaf)
991         for _, comment in self.comments:
992             res += str(comment)
993         return res + "\n"
994
995     def __bool__(self) -> bool:
996         """Return True if the line has leaves or comments."""
997         return bool(self.leaves or self.comments)
998
999
1000 class UnformattedLines(Line):
1001     """Just like :class:`Line` but stores lines which aren't reformatted."""
1002
1003     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1004         """Just add a new `leaf` to the end of the lines.
1005
1006         The `preformatted` argument is ignored.
1007
1008         Keeps track of indentation `depth`, which is useful when the user
1009         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1010         """
1011         try:
1012             list(generate_comments(leaf))
1013         except FormatOn as f_on:
1014             self.leaves.append(f_on.leaf_from_consumed(leaf))
1015             raise
1016
1017         self.leaves.append(leaf)
1018         if leaf.type == token.INDENT:
1019             self.depth += 1
1020         elif leaf.type == token.DEDENT:
1021             self.depth -= 1
1022
1023     def __str__(self) -> str:
1024         """Render unformatted lines from leaves which were added with `append()`.
1025
1026         `depth` is not used for indentation in this case.
1027         """
1028         if not self:
1029             return "\n"
1030
1031         res = ""
1032         for leaf in self.leaves:
1033             res += str(leaf)
1034         return res
1035
1036     def append_comment(self, comment: Leaf) -> bool:
1037         """Not implemented in this class. Raises `NotImplementedError`."""
1038         raise NotImplementedError("Unformatted lines don't store comments separately.")
1039
1040     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1041         """Does nothing and returns False."""
1042         return False
1043
1044     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1045         """Does nothing and returns False."""
1046         return False
1047
1048
1049 @dataclass
1050 class EmptyLineTracker:
1051     """Provides a stateful method that returns the number of potential extra
1052     empty lines needed before and after the currently processed line.
1053
1054     Note: this tracker works on lines that haven't been split yet.  It assumes
1055     the prefix of the first leaf consists of optional newlines.  Those newlines
1056     are consumed by `maybe_empty_lines()` and included in the computation.
1057     """
1058     previous_line: Optional[Line] = None
1059     previous_after: int = 0
1060     previous_defs: List[int] = Factory(list)
1061
1062     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1063         """Return the number of extra empty lines before and after the `current_line`.
1064
1065         This is for separating `def`, `async def` and `class` with extra empty
1066         lines (two on module-level), as well as providing an extra empty line
1067         after flow control keywords to make them more prominent.
1068         """
1069         if isinstance(current_line, UnformattedLines):
1070             return 0, 0
1071
1072         before, after = self._maybe_empty_lines(current_line)
1073         before -= self.previous_after
1074         self.previous_after = after
1075         self.previous_line = current_line
1076         return before, after
1077
1078     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1079         max_allowed = 1
1080         if current_line.depth == 0:
1081             max_allowed = 2
1082         if current_line.leaves:
1083             # Consume the first leaf's extra newlines.
1084             first_leaf = current_line.leaves[0]
1085             before = first_leaf.prefix.count("\n")
1086             before = min(before, max_allowed)
1087             first_leaf.prefix = ""
1088         else:
1089             before = 0
1090         depth = current_line.depth
1091         while self.previous_defs and self.previous_defs[-1] >= depth:
1092             self.previous_defs.pop()
1093             before = 1 if depth else 2
1094         is_decorator = current_line.is_decorator
1095         if is_decorator or current_line.is_def or current_line.is_class:
1096             if not is_decorator:
1097                 self.previous_defs.append(depth)
1098             if self.previous_line is None:
1099                 # Don't insert empty lines before the first line in the file.
1100                 return 0, 0
1101
1102             if self.previous_line.is_decorator:
1103                 return 0, 0
1104
1105             if (
1106                 self.previous_line.is_comment
1107                 and self.previous_line.depth == current_line.depth
1108                 and before == 0
1109             ):
1110                 return 0, 0
1111
1112             newlines = 2
1113             if current_line.depth:
1114                 newlines -= 1
1115             return newlines, 0
1116
1117         if (
1118             self.previous_line
1119             and self.previous_line.is_import
1120             and not current_line.is_import
1121             and depth == self.previous_line.depth
1122         ):
1123             return (before or 1), 0
1124
1125         return before, 0
1126
1127
1128 @dataclass
1129 class LineGenerator(Visitor[Line]):
1130     """Generates reformatted Line objects.  Empty lines are not emitted.
1131
1132     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1133     in ways that will no longer stringify to valid Python code on the tree.
1134     """
1135     current_line: Line = Factory(Line)
1136
1137     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1138         """Generate a line.
1139
1140         If the line is empty, only emit if it makes sense.
1141         If the line is too long, split it first and then generate.
1142
1143         If any lines were generated, set up a new current_line.
1144         """
1145         if not self.current_line:
1146             if self.current_line.__class__ == type:
1147                 self.current_line.depth += indent
1148             else:
1149                 self.current_line = type(depth=self.current_line.depth + indent)
1150             return  # Line is empty, don't emit. Creating a new one unnecessary.
1151
1152         complete_line = self.current_line
1153         self.current_line = type(depth=complete_line.depth + indent)
1154         yield complete_line
1155
1156     def visit(self, node: LN) -> Iterator[Line]:
1157         """Main method to visit `node` and its children.
1158
1159         Yields :class:`Line` objects.
1160         """
1161         if isinstance(self.current_line, UnformattedLines):
1162             # File contained `# fmt: off`
1163             yield from self.visit_unformatted(node)
1164
1165         else:
1166             yield from super().visit(node)
1167
1168     def visit_default(self, node: LN) -> Iterator[Line]:
1169         """Default `visit_*()` implementation. Recurses to children of `node`."""
1170         if isinstance(node, Leaf):
1171             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1172             try:
1173                 for comment in generate_comments(node):
1174                     if any_open_brackets:
1175                         # any comment within brackets is subject to splitting
1176                         self.current_line.append(comment)
1177                     elif comment.type == token.COMMENT:
1178                         # regular trailing comment
1179                         self.current_line.append(comment)
1180                         yield from self.line()
1181
1182                     else:
1183                         # regular standalone comment
1184                         yield from self.line()
1185
1186                         self.current_line.append(comment)
1187                         yield from self.line()
1188
1189             except FormatOff as f_off:
1190                 f_off.trim_prefix(node)
1191                 yield from self.line(type=UnformattedLines)
1192                 yield from self.visit(node)
1193
1194             except FormatOn as f_on:
1195                 # This only happens here if somebody says "fmt: on" multiple
1196                 # times in a row.
1197                 f_on.trim_prefix(node)
1198                 yield from self.visit_default(node)
1199
1200             else:
1201                 normalize_prefix(node, inside_brackets=any_open_brackets)
1202                 if node.type == token.STRING:
1203                     normalize_string_quotes(node)
1204                 if node.type not in WHITESPACE:
1205                     self.current_line.append(node)
1206         yield from super().visit_default(node)
1207
1208     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1209         """Increase indentation level, maybe yield a line."""
1210         # In blib2to3 INDENT never holds comments.
1211         yield from self.line(+1)
1212         yield from self.visit_default(node)
1213
1214     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1215         """Decrease indentation level, maybe yield a line."""
1216         # The current line might still wait for trailing comments.  At DEDENT time
1217         # there won't be any (they would be prefixes on the preceding NEWLINE).
1218         # Emit the line then.
1219         yield from self.line()
1220
1221         # While DEDENT has no value, its prefix may contain standalone comments
1222         # that belong to the current indentation level.  Get 'em.
1223         yield from self.visit_default(node)
1224
1225         # Finally, emit the dedent.
1226         yield from self.line(-1)
1227
1228     def visit_stmt(
1229         self, node: Node, keywords: Set[str], parens: Set[str]
1230     ) -> Iterator[Line]:
1231         """Visit a statement.
1232
1233         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1234         `def`, `with`, `class`, and `assert`.
1235
1236         The relevant Python language `keywords` for a given statement will be
1237         NAME leaves within it. This methods puts those on a separate line.
1238
1239         `parens` holds pairs of nodes where invisible parentheses should be put.
1240         Keys hold nodes after which opening parentheses should be put, values
1241         hold nodes before which closing parentheses should be put.
1242         """
1243         normalize_invisible_parens(node, parens_after=parens)
1244         for child in node.children:
1245             if child.type == token.NAME and child.value in keywords:  # type: ignore
1246                 yield from self.line()
1247
1248             yield from self.visit(child)
1249
1250     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1251         """Visit a statement without nested statements."""
1252         is_suite_like = node.parent and node.parent.type in STATEMENT
1253         if is_suite_like:
1254             yield from self.line(+1)
1255             yield from self.visit_default(node)
1256             yield from self.line(-1)
1257
1258         else:
1259             yield from self.line()
1260             yield from self.visit_default(node)
1261
1262     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1263         """Visit `async def`, `async for`, `async with`."""
1264         yield from self.line()
1265
1266         children = iter(node.children)
1267         for child in children:
1268             yield from self.visit(child)
1269
1270             if child.type == token.ASYNC:
1271                 break
1272
1273         internal_stmt = next(children)
1274         for child in internal_stmt.children:
1275             yield from self.visit(child)
1276
1277     def visit_decorators(self, node: Node) -> Iterator[Line]:
1278         """Visit decorators."""
1279         for child in node.children:
1280             yield from self.line()
1281             yield from self.visit(child)
1282
1283     def visit_import_from(self, node: Node) -> Iterator[Line]:
1284         """Visit import_from and maybe put invisible parentheses.
1285
1286         This is separate from `visit_stmt` because import statements don't
1287         support arbitrary atoms and thus handling of parentheses is custom.
1288         """
1289         check_lpar = False
1290         for index, child in enumerate(node.children):
1291             if check_lpar:
1292                 if child.type == token.LPAR:
1293                     # make parentheses invisible
1294                     child.value = ""  # type: ignore
1295                     node.children[-1].value = ""  # type: ignore
1296                 else:
1297                     # insert invisible parentheses
1298                     node.insert_child(index, Leaf(token.LPAR, ""))
1299                     node.append_child(Leaf(token.RPAR, ""))
1300                 break
1301
1302             check_lpar = (
1303                 child.type == token.NAME and child.value == "import"  # type: ignore
1304             )
1305
1306         for child in node.children:
1307             yield from self.visit(child)
1308
1309     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1310         """Remove a semicolon and put the other statement on a separate line."""
1311         yield from self.line()
1312
1313     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1314         """End of file. Process outstanding comments and end with a newline."""
1315         yield from self.visit_default(leaf)
1316         yield from self.line()
1317
1318     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1319         """Used when file contained a `# fmt: off`."""
1320         if isinstance(node, Node):
1321             for child in node.children:
1322                 yield from self.visit(child)
1323
1324         else:
1325             try:
1326                 self.current_line.append(node)
1327             except FormatOn as f_on:
1328                 f_on.trim_prefix(node)
1329                 yield from self.line()
1330                 yield from self.visit(node)
1331
1332             if node.type == token.ENDMARKER:
1333                 # somebody decided not to put a final `# fmt: on`
1334                 yield from self.line()
1335
1336     def __attrs_post_init__(self) -> None:
1337         """You are in a twisty little maze of passages."""
1338         v = self.visit_stmt
1339         Ø: Set[str] = set()
1340         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1341         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1342         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1343         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1344         self.visit_try_stmt = partial(
1345             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1346         )
1347         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1348         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1349         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1350         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1351         self.visit_async_funcdef = self.visit_async_stmt
1352         self.visit_decorated = self.visit_decorators
1353
1354
1355 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1356 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1357 OPENING_BRACKETS = set(BRACKET.keys())
1358 CLOSING_BRACKETS = set(BRACKET.values())
1359 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1360 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1361
1362
1363 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1364     """Return whitespace prefix if needed for the given `leaf`.
1365
1366     `complex_subscript` signals whether the given leaf is part of a subscription
1367     which has non-trivial arguments, like arithmetic expressions or function calls.
1368     """
1369     NO = ""
1370     SPACE = " "
1371     DOUBLESPACE = "  "
1372     t = leaf.type
1373     p = leaf.parent
1374     v = leaf.value
1375     if t in ALWAYS_NO_SPACE:
1376         return NO
1377
1378     if t == token.COMMENT:
1379         return DOUBLESPACE
1380
1381     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1382     if (
1383         t == token.COLON
1384         and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop}
1385     ):
1386         return NO
1387
1388     prev = leaf.prev_sibling
1389     if not prev:
1390         prevp = preceding_leaf(p)
1391         if not prevp or prevp.type in OPENING_BRACKETS:
1392             return NO
1393
1394         if t == token.COLON:
1395             if prevp.type == token.COLON:
1396                 return NO
1397
1398             elif prevp.type != token.COMMA and not complex_subscript:
1399                 return NO
1400
1401             return SPACE
1402
1403         if prevp.type == token.EQUAL:
1404             if prevp.parent:
1405                 if prevp.parent.type in {
1406                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1407                 }:
1408                     return NO
1409
1410                 elif prevp.parent.type == syms.typedargslist:
1411                     # A bit hacky: if the equal sign has whitespace, it means we
1412                     # previously found it's a typed argument.  So, we're using
1413                     # that, too.
1414                     return prevp.prefix
1415
1416         elif prevp.type in STARS:
1417             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1418                 return NO
1419
1420         elif prevp.type == token.COLON:
1421             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1422                 return SPACE if complex_subscript else NO
1423
1424         elif (
1425             prevp.parent
1426             and prevp.parent.type == syms.factor
1427             and prevp.type in MATH_OPERATORS
1428         ):
1429             return NO
1430
1431         elif (
1432             prevp.type == token.RIGHTSHIFT
1433             and prevp.parent
1434             and prevp.parent.type == syms.shift_expr
1435             and prevp.prev_sibling
1436             and prevp.prev_sibling.type == token.NAME
1437             and prevp.prev_sibling.value == "print"  # type: ignore
1438         ):
1439             # Python 2 print chevron
1440             return NO
1441
1442     elif prev.type in OPENING_BRACKETS:
1443         return NO
1444
1445     if p.type in {syms.parameters, syms.arglist}:
1446         # untyped function signatures or calls
1447         if not prev or prev.type != token.COMMA:
1448             return NO
1449
1450     elif p.type == syms.varargslist:
1451         # lambdas
1452         if prev and prev.type != token.COMMA:
1453             return NO
1454
1455     elif p.type == syms.typedargslist:
1456         # typed function signatures
1457         if not prev:
1458             return NO
1459
1460         if t == token.EQUAL:
1461             if prev.type != syms.tname:
1462                 return NO
1463
1464         elif prev.type == token.EQUAL:
1465             # A bit hacky: if the equal sign has whitespace, it means we
1466             # previously found it's a typed argument.  So, we're using that, too.
1467             return prev.prefix
1468
1469         elif prev.type != token.COMMA:
1470             return NO
1471
1472     elif p.type == syms.tname:
1473         # type names
1474         if not prev:
1475             prevp = preceding_leaf(p)
1476             if not prevp or prevp.type != token.COMMA:
1477                 return NO
1478
1479     elif p.type == syms.trailer:
1480         # attributes and calls
1481         if t == token.LPAR or t == token.RPAR:
1482             return NO
1483
1484         if not prev:
1485             if t == token.DOT:
1486                 prevp = preceding_leaf(p)
1487                 if not prevp or prevp.type != token.NUMBER:
1488                     return NO
1489
1490             elif t == token.LSQB:
1491                 return NO
1492
1493         elif prev.type != token.COMMA:
1494             return NO
1495
1496     elif p.type == syms.argument:
1497         # single argument
1498         if t == token.EQUAL:
1499             return NO
1500
1501         if not prev:
1502             prevp = preceding_leaf(p)
1503             if not prevp or prevp.type == token.LPAR:
1504                 return NO
1505
1506         elif prev.type in {token.EQUAL} | STARS:
1507             return NO
1508
1509     elif p.type == syms.decorator:
1510         # decorators
1511         return NO
1512
1513     elif p.type == syms.dotted_name:
1514         if prev:
1515             return NO
1516
1517         prevp = preceding_leaf(p)
1518         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1519             return NO
1520
1521     elif p.type == syms.classdef:
1522         if t == token.LPAR:
1523             return NO
1524
1525         if prev and prev.type == token.LPAR:
1526             return NO
1527
1528     elif p.type in {syms.subscript, syms.sliceop}:
1529         # indexing
1530         if not prev:
1531             assert p.parent is not None, "subscripts are always parented"
1532             if p.parent.type == syms.subscriptlist:
1533                 return SPACE
1534
1535             return NO
1536
1537         elif not complex_subscript:
1538             return NO
1539
1540     elif p.type == syms.atom:
1541         if prev and t == token.DOT:
1542             # dots, but not the first one.
1543             return NO
1544
1545     elif p.type == syms.dictsetmaker:
1546         # dict unpacking
1547         if prev and prev.type == token.DOUBLESTAR:
1548             return NO
1549
1550     elif p.type in {syms.factor, syms.star_expr}:
1551         # unary ops
1552         if not prev:
1553             prevp = preceding_leaf(p)
1554             if not prevp or prevp.type in OPENING_BRACKETS:
1555                 return NO
1556
1557             prevp_parent = prevp.parent
1558             assert prevp_parent is not None
1559             if (
1560                 prevp.type == token.COLON
1561                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1562             ):
1563                 return NO
1564
1565             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1566                 return NO
1567
1568         elif t == token.NAME or t == token.NUMBER:
1569             return NO
1570
1571     elif p.type == syms.import_from:
1572         if t == token.DOT:
1573             if prev and prev.type == token.DOT:
1574                 return NO
1575
1576         elif t == token.NAME:
1577             if v == "import":
1578                 return SPACE
1579
1580             if prev and prev.type == token.DOT:
1581                 return NO
1582
1583     elif p.type == syms.sliceop:
1584         return NO
1585
1586     return SPACE
1587
1588
1589 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1590     """Return the first leaf that precedes `node`, if any."""
1591     while node:
1592         res = node.prev_sibling
1593         if res:
1594             if isinstance(res, Leaf):
1595                 return res
1596
1597             try:
1598                 return list(res.leaves())[-1]
1599
1600             except IndexError:
1601                 return None
1602
1603         node = node.parent
1604     return None
1605
1606
1607 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1608     """Return the child of `ancestor` that contains `descendant`."""
1609     node: Optional[LN] = descendant
1610     while node and node.parent != ancestor:
1611         node = node.parent
1612     return node
1613
1614
1615 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1616     """Return the priority of the `leaf` delimiter, given a line break after it.
1617
1618     The delimiter priorities returned here are from those delimiters that would
1619     cause a line break after themselves.
1620
1621     Higher numbers are higher priority.
1622     """
1623     if leaf.type == token.COMMA:
1624         return COMMA_PRIORITY
1625
1626     return 0
1627
1628
1629 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1630     """Return the priority of the `leaf` delimiter, given a line before after it.
1631
1632     The delimiter priorities returned here are from those delimiters that would
1633     cause a line break before themselves.
1634
1635     Higher numbers are higher priority.
1636     """
1637     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1638         # * and ** might also be MATH_OPERATORS but in this case they are not.
1639         # Don't treat them as a delimiter.
1640         return 0
1641
1642     if (
1643         leaf.type in MATH_OPERATORS
1644         and leaf.parent
1645         and leaf.parent.type not in {syms.factor, syms.star_expr}
1646     ):
1647         return MATH_PRIORITY
1648
1649     if leaf.type in COMPARATORS:
1650         return COMPARATOR_PRIORITY
1651
1652     if (
1653         leaf.type == token.STRING
1654         and previous is not None
1655         and previous.type == token.STRING
1656     ):
1657         return STRING_PRIORITY
1658
1659     if (
1660         leaf.type == token.NAME
1661         and leaf.value == "for"
1662         and leaf.parent
1663         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1664     ):
1665         return COMPREHENSION_PRIORITY
1666
1667     if (
1668         leaf.type == token.NAME
1669         and leaf.value == "if"
1670         and leaf.parent
1671         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1672     ):
1673         return COMPREHENSION_PRIORITY
1674
1675     if (
1676         leaf.type == token.NAME
1677         and leaf.value in {"if", "else"}
1678         and leaf.parent
1679         and leaf.parent.type == syms.test
1680     ):
1681         return TERNARY_PRIORITY
1682
1683     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1684         return LOGIC_PRIORITY
1685
1686     return 0
1687
1688
1689 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1690     """Clean the prefix of the `leaf` and generate comments from it, if any.
1691
1692     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1693     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1694     move because it does away with modifying the grammar to include all the
1695     possible places in which comments can be placed.
1696
1697     The sad consequence for us though is that comments don't "belong" anywhere.
1698     This is why this function generates simple parentless Leaf objects for
1699     comments.  We simply don't know what the correct parent should be.
1700
1701     No matter though, we can live without this.  We really only need to
1702     differentiate between inline and standalone comments.  The latter don't
1703     share the line with any code.
1704
1705     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1706     are emitted with a fake STANDALONE_COMMENT token identifier.
1707     """
1708     p = leaf.prefix
1709     if not p:
1710         return
1711
1712     if "#" not in p:
1713         return
1714
1715     consumed = 0
1716     nlines = 0
1717     for index, line in enumerate(p.split("\n")):
1718         consumed += len(line) + 1  # adding the length of the split '\n'
1719         line = line.lstrip()
1720         if not line:
1721             nlines += 1
1722         if not line.startswith("#"):
1723             continue
1724
1725         if index == 0 and leaf.type != token.ENDMARKER:
1726             comment_type = token.COMMENT  # simple trailing comment
1727         else:
1728             comment_type = STANDALONE_COMMENT
1729         comment = make_comment(line)
1730         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1731
1732         if comment in {"# fmt: on", "# yapf: enable"}:
1733             raise FormatOn(consumed)
1734
1735         if comment in {"# fmt: off", "# yapf: disable"}:
1736             if comment_type == STANDALONE_COMMENT:
1737                 raise FormatOff(consumed)
1738
1739             prev = preceding_leaf(leaf)
1740             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1741                 raise FormatOff(consumed)
1742
1743         nlines = 0
1744
1745
1746 def make_comment(content: str) -> str:
1747     """Return a consistently formatted comment from the given `content` string.
1748
1749     All comments (except for "##", "#!", "#:") should have a single space between
1750     the hash sign and the content.
1751
1752     If `content` didn't start with a hash sign, one is provided.
1753     """
1754     content = content.rstrip()
1755     if not content:
1756         return "#"
1757
1758     if content[0] == "#":
1759         content = content[1:]
1760     if content and content[0] not in " !:#":
1761         content = " " + content
1762     return "#" + content
1763
1764
1765 def split_line(
1766     line: Line, line_length: int, inner: bool = False, py36: bool = False
1767 ) -> Iterator[Line]:
1768     """Split a `line` into potentially many lines.
1769
1770     They should fit in the allotted `line_length` but might not be able to.
1771     `inner` signifies that there were a pair of brackets somewhere around the
1772     current `line`, possibly transitively. This means we can fallback to splitting
1773     by delimiters if the LHS/RHS don't yield any results.
1774
1775     If `py36` is True, splitting may generate syntax that is only compatible
1776     with Python 3.6 and later.
1777     """
1778     if isinstance(line, UnformattedLines) or line.is_comment:
1779         yield line
1780         return
1781
1782     line_str = str(line).strip("\n")
1783     if (
1784         len(line_str) <= line_length
1785         and "\n" not in line_str  # multiline strings
1786         and not line.contains_standalone_comments()
1787     ):
1788         yield line
1789         return
1790
1791     split_funcs: List[SplitFunc]
1792     if line.is_def:
1793         split_funcs = [left_hand_split]
1794     elif line.is_import:
1795         split_funcs = [explode_split]
1796     elif line.inside_brackets:
1797         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1798     else:
1799         split_funcs = [right_hand_split]
1800     for split_func in split_funcs:
1801         # We are accumulating lines in `result` because we might want to abort
1802         # mission and return the original line in the end, or attempt a different
1803         # split altogether.
1804         result: List[Line] = []
1805         try:
1806             for l in split_func(line, py36):
1807                 if str(l).strip("\n") == line_str:
1808                     raise CannotSplit("Split function returned an unchanged result")
1809
1810                 result.extend(
1811                     split_line(l, line_length=line_length, inner=True, py36=py36)
1812                 )
1813         except CannotSplit as cs:
1814             continue
1815
1816         else:
1817             yield from result
1818             break
1819
1820     else:
1821         yield line
1822
1823
1824 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1825     """Split line into many lines, starting with the first matching bracket pair.
1826
1827     Note: this usually looks weird, only use this for function definitions.
1828     Prefer RHS otherwise.
1829     """
1830     head = Line(depth=line.depth)
1831     body = Line(depth=line.depth + 1, inside_brackets=True)
1832     tail = Line(depth=line.depth)
1833     tail_leaves: List[Leaf] = []
1834     body_leaves: List[Leaf] = []
1835     head_leaves: List[Leaf] = []
1836     current_leaves = head_leaves
1837     matching_bracket = None
1838     for leaf in line.leaves:
1839         if (
1840             current_leaves is body_leaves
1841             and leaf.type in CLOSING_BRACKETS
1842             and leaf.opening_bracket is matching_bracket
1843         ):
1844             current_leaves = tail_leaves if body_leaves else head_leaves
1845         current_leaves.append(leaf)
1846         if current_leaves is head_leaves:
1847             if leaf.type in OPENING_BRACKETS:
1848                 matching_bracket = leaf
1849                 current_leaves = body_leaves
1850     # Since body is a new indent level, remove spurious leading whitespace.
1851     if body_leaves:
1852         normalize_prefix(body_leaves[0], inside_brackets=True)
1853     # Build the new lines.
1854     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1855         for leaf in leaves:
1856             result.append(leaf, preformatted=True)
1857             for comment_after in line.comments_after(leaf):
1858                 result.append(comment_after, preformatted=True)
1859     bracket_split_succeeded_or_raise(head, body, tail)
1860     for result in (head, body, tail):
1861         if result:
1862             yield result
1863
1864
1865 def right_hand_split(
1866     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1867 ) -> Iterator[Line]:
1868     """Split line into many lines, starting with the last matching bracket pair."""
1869     head = Line(depth=line.depth)
1870     body = Line(depth=line.depth + 1, inside_brackets=True)
1871     tail = Line(depth=line.depth)
1872     tail_leaves: List[Leaf] = []
1873     body_leaves: List[Leaf] = []
1874     head_leaves: List[Leaf] = []
1875     current_leaves = tail_leaves
1876     opening_bracket = None
1877     closing_bracket = None
1878     for leaf in reversed(line.leaves):
1879         if current_leaves is body_leaves:
1880             if leaf is opening_bracket:
1881                 current_leaves = head_leaves if body_leaves else tail_leaves
1882         current_leaves.append(leaf)
1883         if current_leaves is tail_leaves:
1884             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1885                 opening_bracket = leaf.opening_bracket
1886                 closing_bracket = leaf
1887                 current_leaves = body_leaves
1888     tail_leaves.reverse()
1889     body_leaves.reverse()
1890     head_leaves.reverse()
1891     # Since body is a new indent level, remove spurious leading whitespace.
1892     if body_leaves:
1893         normalize_prefix(body_leaves[0], inside_brackets=True)
1894     elif not head_leaves:
1895         # No `head` and no `body` means the split failed. `tail` has all content.
1896         raise CannotSplit("No brackets found")
1897
1898     # Build the new lines.
1899     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1900         for leaf in leaves:
1901             result.append(leaf, preformatted=True)
1902             for comment_after in line.comments_after(leaf):
1903                 result.append(comment_after, preformatted=True)
1904     bracket_split_succeeded_or_raise(head, body, tail)
1905     assert opening_bracket and closing_bracket
1906     if (
1907         opening_bracket.type == token.LPAR
1908         and not opening_bracket.value
1909         and closing_bracket.type == token.RPAR
1910         and not closing_bracket.value
1911     ):
1912         # These parens were optional. If there aren't any delimiters or standalone
1913         # comments in the body, they were unnecessary and another split without
1914         # them should be attempted.
1915         if not (
1916             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1917         ):
1918             omit = {id(closing_bracket), *omit}
1919             yield from right_hand_split(line, py36=py36, omit=omit)
1920             return
1921
1922     ensure_visible(opening_bracket)
1923     ensure_visible(closing_bracket)
1924     for result in (head, body, tail):
1925         if result:
1926             yield result
1927
1928
1929 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1930     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1931
1932     Do nothing otherwise.
1933
1934     A left- or right-hand split is based on a pair of brackets. Content before
1935     (and including) the opening bracket is left on one line, content inside the
1936     brackets is put on a separate line, and finally content starting with and
1937     following the closing bracket is put on a separate line.
1938
1939     Those are called `head`, `body`, and `tail`, respectively. If the split
1940     produced the same line (all content in `head`) or ended up with an empty `body`
1941     and the `tail` is just the closing bracket, then it's considered failed.
1942     """
1943     tail_len = len(str(tail).strip())
1944     if not body:
1945         if tail_len == 0:
1946             raise CannotSplit("Splitting brackets produced the same line")
1947
1948         elif tail_len < 3:
1949             raise CannotSplit(
1950                 f"Splitting brackets on an empty body to save "
1951                 f"{tail_len} characters is not worth it"
1952             )
1953
1954
1955 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1956     """Normalize prefix of the first leaf in every line returned by `split_func`.
1957
1958     This is a decorator over relevant split functions.
1959     """
1960
1961     @wraps(split_func)
1962     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1963         for l in split_func(line, py36):
1964             normalize_prefix(l.leaves[0], inside_brackets=True)
1965             yield l
1966
1967     return split_wrapper
1968
1969
1970 @dont_increase_indentation
1971 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1972     """Split according to delimiters of the highest priority.
1973
1974     If `py36` is True, the split will add trailing commas also in function
1975     signatures that contain `*` and `**`.
1976     """
1977     try:
1978         last_leaf = line.leaves[-1]
1979     except IndexError:
1980         raise CannotSplit("Line empty")
1981
1982     delimiters = line.bracket_tracker.delimiters
1983     try:
1984         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1985             exclude={id(last_leaf)}
1986         )
1987     except ValueError:
1988         raise CannotSplit("No delimiters found")
1989
1990     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1991     lowest_depth = sys.maxsize
1992     trailing_comma_safe = True
1993
1994     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1995         """Append `leaf` to current line or to new line if appending impossible."""
1996         nonlocal current_line
1997         try:
1998             current_line.append_safe(leaf, preformatted=True)
1999         except ValueError as ve:
2000             yield current_line
2001
2002             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2003             current_line.append(leaf)
2004
2005     for leaf in line.leaves:
2006         yield from append_to_line(leaf)
2007
2008         for comment_after in line.comments_after(leaf):
2009             yield from append_to_line(comment_after)
2010
2011         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2012         if (
2013             leaf.bracket_depth == lowest_depth
2014             and is_vararg(leaf, within=VARARGS_PARENTS)
2015         ):
2016             trailing_comma_safe = trailing_comma_safe and py36
2017         leaf_priority = delimiters.get(id(leaf))
2018         if leaf_priority == delimiter_priority:
2019             yield current_line
2020
2021             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2022     if current_line:
2023         if (
2024             trailing_comma_safe
2025             and delimiter_priority == COMMA_PRIORITY
2026             and current_line.leaves[-1].type != token.COMMA
2027             and current_line.leaves[-1].type != STANDALONE_COMMENT
2028         ):
2029             current_line.append(Leaf(token.COMMA, ","))
2030         yield current_line
2031
2032
2033 @dont_increase_indentation
2034 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2035     """Split standalone comments from the rest of the line."""
2036     if not line.contains_standalone_comments(0):
2037         raise CannotSplit("Line does not have any standalone comments")
2038
2039     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2040
2041     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2042         """Append `leaf` to current line or to new line if appending impossible."""
2043         nonlocal current_line
2044         try:
2045             current_line.append_safe(leaf, preformatted=True)
2046         except ValueError as ve:
2047             yield current_line
2048
2049             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2050             current_line.append(leaf)
2051
2052     for leaf in line.leaves:
2053         yield from append_to_line(leaf)
2054
2055         for comment_after in line.comments_after(leaf):
2056             yield from append_to_line(comment_after)
2057
2058     if current_line:
2059         yield current_line
2060
2061
2062 def explode_split(
2063     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
2064 ) -> Iterator[Line]:
2065     """Split by rightmost bracket and immediately split contents by a delimiter."""
2066     new_lines = list(right_hand_split(line, py36, omit))
2067     if len(new_lines) != 3:
2068         yield from new_lines
2069         return
2070
2071     yield new_lines[0]
2072
2073     try:
2074         yield from delimiter_split(new_lines[1], py36)
2075
2076     except CannotSplit:
2077         yield new_lines[1]
2078
2079     yield new_lines[2]
2080
2081
2082 def is_import(leaf: Leaf) -> bool:
2083     """Return True if the given leaf starts an import statement."""
2084     p = leaf.parent
2085     t = leaf.type
2086     v = leaf.value
2087     return bool(
2088         t == token.NAME
2089         and (
2090             (v == "import" and p and p.type == syms.import_name)
2091             or (v == "from" and p and p.type == syms.import_from)
2092         )
2093     )
2094
2095
2096 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2097     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2098     else.
2099
2100     Note: don't use backslashes for formatting or you'll lose your voting rights.
2101     """
2102     if not inside_brackets:
2103         spl = leaf.prefix.split("#")
2104         if "\\" not in spl[0]:
2105             nl_count = spl[-1].count("\n")
2106             if len(spl) > 1:
2107                 nl_count -= 1
2108             leaf.prefix = "\n" * nl_count
2109             return
2110
2111     leaf.prefix = ""
2112
2113
2114 def normalize_string_quotes(leaf: Leaf) -> None:
2115     """Prefer double quotes but only if it doesn't cause more escaping.
2116
2117     Adds or removes backslashes as appropriate. Doesn't parse and fix
2118     strings nested in f-strings (yet).
2119
2120     Note: Mutates its argument.
2121     """
2122     value = leaf.value.lstrip("furbFURB")
2123     if value[:3] == '"""':
2124         return
2125
2126     elif value[:3] == "'''":
2127         orig_quote = "'''"
2128         new_quote = '"""'
2129     elif value[0] == '"':
2130         orig_quote = '"'
2131         new_quote = "'"
2132     else:
2133         orig_quote = "'"
2134         new_quote = '"'
2135     first_quote_pos = leaf.value.find(orig_quote)
2136     if first_quote_pos == -1:
2137         return  # There's an internal error
2138
2139     prefix = leaf.value[:first_quote_pos]
2140     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2141     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2142     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2143     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2144     if "r" in prefix.casefold():
2145         if unescaped_new_quote.search(body):
2146             # There's at least one unescaped new_quote in this raw string
2147             # so converting is impossible
2148             return
2149
2150         # Do not introduce or remove backslashes in raw strings
2151         new_body = body
2152     else:
2153         # remove unnecessary quotes
2154         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2155         if body != new_body:
2156             # Consider the string without unnecessary quotes as the original
2157             body = new_body
2158             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2159         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2160         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2161     if new_quote == '"""' and new_body[-1] == '"':
2162         # edge case:
2163         new_body = new_body[:-1] + '\\"'
2164     orig_escape_count = body.count("\\")
2165     new_escape_count = new_body.count("\\")
2166     if new_escape_count > orig_escape_count:
2167         return  # Do not introduce more escaping
2168
2169     if new_escape_count == orig_escape_count and orig_quote == '"':
2170         return  # Prefer double quotes
2171
2172     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2173
2174
2175 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2176     """Make existing optional parentheses invisible or create new ones.
2177
2178     Standardizes on visible parentheses for single-element tuples, and keeps
2179     existing visible parentheses for other tuples and generator expressions.
2180     """
2181     check_lpar = False
2182     for child in list(node.children):
2183         if check_lpar:
2184             if child.type == syms.atom:
2185                 if not (
2186                     is_empty_tuple(child)
2187                     or is_one_tuple(child)
2188                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2189                 ):
2190                     first = child.children[0]
2191                     last = child.children[-1]
2192                     if first.type == token.LPAR and last.type == token.RPAR:
2193                         # make parentheses invisible
2194                         first.value = ""  # type: ignore
2195                         last.value = ""  # type: ignore
2196             elif is_one_tuple(child):
2197                 # wrap child in visible parentheses
2198                 lpar = Leaf(token.LPAR, "(")
2199                 rpar = Leaf(token.RPAR, ")")
2200                 index = child.remove() or 0
2201                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2202             else:
2203                 # wrap child in invisible parentheses
2204                 lpar = Leaf(token.LPAR, "")
2205                 rpar = Leaf(token.RPAR, "")
2206                 index = child.remove() or 0
2207                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2208
2209         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2210
2211
2212 def is_empty_tuple(node: LN) -> bool:
2213     """Return True if `node` holds an empty tuple."""
2214     return (
2215         node.type == syms.atom
2216         and len(node.children) == 2
2217         and node.children[0].type == token.LPAR
2218         and node.children[1].type == token.RPAR
2219     )
2220
2221
2222 def is_one_tuple(node: LN) -> bool:
2223     """Return True if `node` holds a tuple with one element, with or without parens."""
2224     if node.type == syms.atom:
2225         if len(node.children) != 3:
2226             return False
2227
2228         lpar, gexp, rpar = node.children
2229         if not (
2230             lpar.type == token.LPAR
2231             and gexp.type == syms.testlist_gexp
2232             and rpar.type == token.RPAR
2233         ):
2234             return False
2235
2236         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2237
2238     return (
2239         node.type in IMPLICIT_TUPLE
2240         and len(node.children) == 2
2241         and node.children[1].type == token.COMMA
2242     )
2243
2244
2245 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2246     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2247
2248     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2249     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2250     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2251     generalizations (PEP 448).
2252     """
2253     if leaf.type not in STARS or not leaf.parent:
2254         return False
2255
2256     p = leaf.parent
2257     if p.type == syms.star_expr:
2258         # Star expressions are also used as assignment targets in extended
2259         # iterable unpacking (PEP 3132).  See what its parent is instead.
2260         if not p.parent:
2261             return False
2262
2263         p = p.parent
2264
2265     return p.type in within
2266
2267
2268 def max_delimiter_priority_in_atom(node: LN) -> int:
2269     """Return maximum delimiter priority inside `node`.
2270
2271     This is specific to atoms with contents contained in a pair of parentheses.
2272     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2273     """
2274     if node.type != syms.atom:
2275         return 0
2276
2277     first = node.children[0]
2278     last = node.children[-1]
2279     if not (first.type == token.LPAR and last.type == token.RPAR):
2280         return 0
2281
2282     bt = BracketTracker()
2283     for c in node.children[1:-1]:
2284         if isinstance(c, Leaf):
2285             bt.mark(c)
2286         else:
2287             for leaf in c.leaves():
2288                 bt.mark(leaf)
2289     try:
2290         return bt.max_delimiter_priority()
2291
2292     except ValueError:
2293         return 0
2294
2295
2296 def ensure_visible(leaf: Leaf) -> None:
2297     """Make sure parentheses are visible.
2298
2299     They could be invisible as part of some statements (see
2300     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2301     """
2302     if leaf.type == token.LPAR:
2303         leaf.value = "("
2304     elif leaf.type == token.RPAR:
2305         leaf.value = ")"
2306
2307
2308 def is_python36(node: Node) -> bool:
2309     """Return True if the current file is using Python 3.6+ features.
2310
2311     Currently looking for:
2312     - f-strings; and
2313     - trailing commas after * or ** in function signatures.
2314     """
2315     for n in node.pre_order():
2316         if n.type == token.STRING:
2317             value_head = n.value[:2]  # type: ignore
2318             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2319                 return True
2320
2321         elif (
2322             n.type == syms.typedargslist
2323             and n.children
2324             and n.children[-1].type == token.COMMA
2325         ):
2326             for ch in n.children:
2327                 if ch.type in STARS:
2328                     return True
2329
2330     return False
2331
2332
2333 PYTHON_EXTENSIONS = {".py"}
2334 BLACKLISTED_DIRECTORIES = {
2335     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2336 }
2337
2338
2339 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2340     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2341     and have one of the PYTHON_EXTENSIONS.
2342     """
2343     for child in path.iterdir():
2344         if child.is_dir():
2345             if child.name in BLACKLISTED_DIRECTORIES:
2346                 continue
2347
2348             yield from gen_python_files_in_dir(child)
2349
2350         elif child.suffix in PYTHON_EXTENSIONS:
2351             yield child
2352
2353
2354 @dataclass
2355 class Report:
2356     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2357     check: bool = False
2358     quiet: bool = False
2359     change_count: int = 0
2360     same_count: int = 0
2361     failure_count: int = 0
2362
2363     def done(self, src: Path, changed: Changed) -> None:
2364         """Increment the counter for successful reformatting. Write out a message."""
2365         if changed is Changed.YES:
2366             reformatted = "would reformat" if self.check else "reformatted"
2367             if not self.quiet:
2368                 out(f"{reformatted} {src}")
2369             self.change_count += 1
2370         else:
2371             if not self.quiet:
2372                 if changed is Changed.NO:
2373                     msg = f"{src} already well formatted, good job."
2374                 else:
2375                     msg = f"{src} wasn't modified on disk since last run."
2376                 out(msg, bold=False)
2377             self.same_count += 1
2378
2379     def failed(self, src: Path, message: str) -> None:
2380         """Increment the counter for failed reformatting. Write out a message."""
2381         err(f"error: cannot format {src}: {message}")
2382         self.failure_count += 1
2383
2384     @property
2385     def return_code(self) -> int:
2386         """Return the exit code that the app should use.
2387
2388         This considers the current state of changed files and failures:
2389         - if there were any failures, return 123;
2390         - if any files were changed and --check is being used, return 1;
2391         - otherwise return 0.
2392         """
2393         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2394         # 126 we have special returncodes reserved by the shell.
2395         if self.failure_count:
2396             return 123
2397
2398         elif self.change_count and self.check:
2399             return 1
2400
2401         return 0
2402
2403     def __str__(self) -> str:
2404         """Render a color report of the current state.
2405
2406         Use `click.unstyle` to remove colors.
2407         """
2408         if self.check:
2409             reformatted = "would be reformatted"
2410             unchanged = "would be left unchanged"
2411             failed = "would fail to reformat"
2412         else:
2413             reformatted = "reformatted"
2414             unchanged = "left unchanged"
2415             failed = "failed to reformat"
2416         report = []
2417         if self.change_count:
2418             s = "s" if self.change_count > 1 else ""
2419             report.append(
2420                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2421             )
2422         if self.same_count:
2423             s = "s" if self.same_count > 1 else ""
2424             report.append(f"{self.same_count} file{s} {unchanged}")
2425         if self.failure_count:
2426             s = "s" if self.failure_count > 1 else ""
2427             report.append(
2428                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2429             )
2430         return ", ".join(report) + "."
2431
2432
2433 def assert_equivalent(src: str, dst: str) -> None:
2434     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2435
2436     import ast
2437     import traceback
2438
2439     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2440         """Simple visitor generating strings to compare ASTs by content."""
2441         yield f"{'  ' * depth}{node.__class__.__name__}("
2442
2443         for field in sorted(node._fields):
2444             try:
2445                 value = getattr(node, field)
2446             except AttributeError:
2447                 continue
2448
2449             yield f"{'  ' * (depth+1)}{field}="
2450
2451             if isinstance(value, list):
2452                 for item in value:
2453                     if isinstance(item, ast.AST):
2454                         yield from _v(item, depth + 2)
2455
2456             elif isinstance(value, ast.AST):
2457                 yield from _v(value, depth + 2)
2458
2459             else:
2460                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2461
2462         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2463
2464     try:
2465         src_ast = ast.parse(src)
2466     except Exception as exc:
2467         major, minor = sys.version_info[:2]
2468         raise AssertionError(
2469             f"cannot use --safe with this file; failed to parse source file "
2470             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2471             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2472         )
2473
2474     try:
2475         dst_ast = ast.parse(dst)
2476     except Exception as exc:
2477         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2478         raise AssertionError(
2479             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2480             f"Please report a bug on https://github.com/ambv/black/issues.  "
2481             f"This invalid output might be helpful: {log}"
2482         ) from None
2483
2484     src_ast_str = "\n".join(_v(src_ast))
2485     dst_ast_str = "\n".join(_v(dst_ast))
2486     if src_ast_str != dst_ast_str:
2487         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2488         raise AssertionError(
2489             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2490             f"the source.  "
2491             f"Please report a bug on https://github.com/ambv/black/issues.  "
2492             f"This diff might be helpful: {log}"
2493         ) from None
2494
2495
2496 def assert_stable(src: str, dst: str, line_length: int) -> None:
2497     """Raise AssertionError if `dst` reformats differently the second time."""
2498     newdst = format_str(dst, line_length=line_length)
2499     if dst != newdst:
2500         log = dump_to_file(
2501             diff(src, dst, "source", "first pass"),
2502             diff(dst, newdst, "first pass", "second pass"),
2503         )
2504         raise AssertionError(
2505             f"INTERNAL ERROR: Black produced different code on the second pass "
2506             f"of the formatter.  "
2507             f"Please report a bug on https://github.com/ambv/black/issues.  "
2508             f"This diff might be helpful: {log}"
2509         ) from None
2510
2511
2512 def dump_to_file(*output: str) -> str:
2513     """Dump `output` to a temporary file. Return path to the file."""
2514     import tempfile
2515
2516     with tempfile.NamedTemporaryFile(
2517         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2518     ) as f:
2519         for lines in output:
2520             f.write(lines)
2521             if lines and lines[-1] != "\n":
2522                 f.write("\n")
2523     return f.name
2524
2525
2526 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2527     """Return a unified diff string between strings `a` and `b`."""
2528     import difflib
2529
2530     a_lines = [line + "\n" for line in a.split("\n")]
2531     b_lines = [line + "\n" for line in b.split("\n")]
2532     return "".join(
2533         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2534     )
2535
2536
2537 def cancel(tasks: List[asyncio.Task]) -> None:
2538     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2539     err("Aborted!")
2540     for task in tasks:
2541         task.cancel()
2542
2543
2544 def shutdown(loop: BaseEventLoop) -> None:
2545     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2546     try:
2547         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2548         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2549         if not to_cancel:
2550             return
2551
2552         for task in to_cancel:
2553             task.cancel()
2554         loop.run_until_complete(
2555             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2556         )
2557     finally:
2558         # `concurrent.futures.Future` objects cannot be cancelled once they
2559         # are already running. There might be some when the `shutdown()` happened.
2560         # Silence their logger's spew about the event loop being closed.
2561         cf_logger = logging.getLogger("concurrent.futures")
2562         cf_logger.setLevel(logging.CRITICAL)
2563         loop.close()
2564
2565
2566 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2567     """Replace `regex` with `replacement` twice on `original`.
2568
2569     This is used by string normalization to perform replaces on
2570     overlapping matches.
2571     """
2572     return regex.sub(replacement, regex.sub(replacement, original))
2573
2574
2575 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2576
2577
2578 def get_cache_file(line_length: int) -> Path:
2579     return CACHE_DIR / f"cache.{line_length}.pickle"
2580
2581
2582 def read_cache(line_length: int) -> Cache:
2583     """Read the cache if it exists and is well formed.
2584
2585     If it is not well formed, the call to write_cache later should resolve the issue.
2586     """
2587     cache_file = get_cache_file(line_length)
2588     if not cache_file.exists():
2589         return {}
2590
2591     with cache_file.open("rb") as fobj:
2592         try:
2593             cache: Cache = pickle.load(fobj)
2594         except pickle.UnpicklingError:
2595             return {}
2596
2597     return cache
2598
2599
2600 def get_cache_info(path: Path) -> CacheInfo:
2601     """Return the information used to check if a file is already formatted or not."""
2602     stat = path.stat()
2603     return stat.st_mtime, stat.st_size
2604
2605
2606 def filter_cached(
2607     cache: Cache, sources: Iterable[Path]
2608 ) -> Tuple[List[Path], List[Path]]:
2609     """Split a list of paths into two.
2610
2611     The first list contains paths of files that modified on disk or are not in the
2612     cache. The other list contains paths to non-modified files.
2613     """
2614     todo, done = [], []
2615     for src in sources:
2616         src = src.resolve()
2617         if cache.get(src) != get_cache_info(src):
2618             todo.append(src)
2619         else:
2620             done.append(src)
2621     return todo, done
2622
2623
2624 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2625     """Update the cache file."""
2626     cache_file = get_cache_file(line_length)
2627     try:
2628         if not CACHE_DIR.exists():
2629             CACHE_DIR.mkdir(parents=True)
2630         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2631         with cache_file.open("wb") as fobj:
2632             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2633     except OSError:
2634         pass
2635
2636
2637 if __name__ == "__main__":
2638     main()