]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix docstrings of visit_stmt and normalize_invisible_parens
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Set,
28     Tuple,
29     Type,
30     TypeVar,
31     Union,
32 )
33
34 from appdirs import user_cache_dir
35 from attr import dataclass, Factory
36 import click
37
38 # lib2to3 fork
39 from blib2to3.pytree import Node, Leaf, type_repr
40 from blib2to3 import pygram, pytree
41 from blib2to3.pgen2 import driver, token
42 from blib2to3.pgen2.parse import ParseError
43
44 __version__ = "18.4a6"
45 DEFAULT_LINE_LENGTH = 88
46
47 # types
48 syms = pygram.python_symbols
49 FileContent = str
50 Encoding = str
51 Depth = int
52 NodeType = int
53 LeafID = int
54 Priority = int
55 Index = int
56 LN = Union[Leaf, Node]
57 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
58 Timestamp = float
59 FileSize = int
60 CacheInfo = Tuple[Timestamp, FileSize]
61 Cache = Dict[Path, CacheInfo]
62 out = partial(click.secho, bold=True, err=True)
63 err = partial(click.secho, fg="red", err=True)
64
65
66 class NothingChanged(UserWarning):
67     """Raised by :func:`format_file` when reformatted code is the same as source."""
68
69
70 class CannotSplit(Exception):
71     """A readable split that fits the allotted line length is impossible.
72
73     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
74     :func:`delimiter_split`.
75     """
76
77
78 class FormatError(Exception):
79     """Base exception for `# fmt: on` and `# fmt: off` handling.
80
81     It holds the number of bytes of the prefix consumed before the format
82     control comment appeared.
83     """
84
85     def __init__(self, consumed: int) -> None:
86         super().__init__(consumed)
87         self.consumed = consumed
88
89     def trim_prefix(self, leaf: Leaf) -> None:
90         leaf.prefix = leaf.prefix[self.consumed :]
91
92     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
93         """Returns a new Leaf from the consumed part of the prefix."""
94         unformatted_prefix = leaf.prefix[: self.consumed]
95         return Leaf(token.NEWLINE, unformatted_prefix)
96
97
98 class FormatOn(FormatError):
99     """Found a comment like `# fmt: on` in the file."""
100
101
102 class FormatOff(FormatError):
103     """Found a comment like `# fmt: off` in the file."""
104
105
106 class WriteBack(Enum):
107     NO = 0
108     YES = 1
109     DIFF = 2
110
111
112 class Changed(Enum):
113     NO = 0
114     CACHED = 1
115     YES = 2
116
117
118 @click.command()
119 @click.option(
120     "-l",
121     "--line-length",
122     type=int,
123     default=DEFAULT_LINE_LENGTH,
124     help="How many character per line to allow.",
125     show_default=True,
126 )
127 @click.option(
128     "--check",
129     is_flag=True,
130     help=(
131         "Don't write the files back, just return the status.  Return code 0 "
132         "means nothing would change.  Return code 1 means some files would be "
133         "reformatted.  Return code 123 means there was an internal error."
134     ),
135 )
136 @click.option(
137     "--diff",
138     is_flag=True,
139     help="Don't write the files back, just output a diff for each file on stdout.",
140 )
141 @click.option(
142     "--fast/--safe",
143     is_flag=True,
144     help="If --fast given, skip temporary sanity checks. [default: --safe]",
145 )
146 @click.option(
147     "-q",
148     "--quiet",
149     is_flag=True,
150     help=(
151         "Don't emit non-error messages to stderr. Errors are still emitted, "
152         "silence those with 2>/dev/null."
153     ),
154 )
155 @click.version_option(version=__version__)
156 @click.argument(
157     "src",
158     nargs=-1,
159     type=click.Path(
160         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
161     ),
162 )
163 @click.pass_context
164 def main(
165     ctx: click.Context,
166     line_length: int,
167     check: bool,
168     diff: bool,
169     fast: bool,
170     quiet: bool,
171     src: List[str],
172 ) -> None:
173     """The uncompromising code formatter."""
174     sources: List[Path] = []
175     for s in src:
176         p = Path(s)
177         if p.is_dir():
178             sources.extend(gen_python_files_in_dir(p))
179         elif p.is_file():
180             # if a file was explicitly given, we don't care about its extension
181             sources.append(p)
182         elif s == "-":
183             sources.append(Path("-"))
184         else:
185             err(f"invalid path: {s}")
186
187     if check and not diff:
188         write_back = WriteBack.NO
189     elif diff:
190         write_back = WriteBack.DIFF
191     else:
192         write_back = WriteBack.YES
193     report = Report(check=check, quiet=quiet)
194     if len(sources) == 0:
195         out("No paths given. Nothing to do 😴")
196         ctx.exit(0)
197         return
198
199     elif len(sources) == 1:
200         reformat_one(sources[0], line_length, fast, write_back, report)
201     else:
202         loop = asyncio.get_event_loop()
203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
204         try:
205             loop.run_until_complete(
206                 schedule_formatting(
207                     sources, line_length, fast, write_back, report, loop, executor
208                 )
209             )
210         finally:
211             shutdown(loop)
212         if not quiet:
213             out("All done! ✨ 🍰 ✨")
214             click.echo(str(report))
215     ctx.exit(report.return_code)
216
217
218 def reformat_one(
219     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
220 ) -> None:
221     """Reformat a single file under `src` without spawning child processes.
222
223     If `quiet` is True, non-error messages are not output. `line_length`,
224     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
225     """
226     try:
227         changed = Changed.NO
228         if not src.is_file() and str(src) == "-":
229             if format_stdin_to_stdout(
230                 line_length=line_length, fast=fast, write_back=write_back
231             ):
232                 changed = Changed.YES
233         else:
234             cache: Cache = {}
235             if write_back != WriteBack.DIFF:
236                 cache = read_cache(line_length)
237                 src = src.resolve()
238                 if src in cache and cache[src] == get_cache_info(src):
239                     changed = Changed.CACHED
240             if (
241                 changed is not Changed.CACHED
242                 and format_file_in_place(
243                     src, line_length=line_length, fast=fast, write_back=write_back
244                 )
245             ):
246                 changed = Changed.YES
247             if write_back == WriteBack.YES and changed is not Changed.NO:
248                 write_cache(cache, [src], line_length)
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache(line_length)
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             src: loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             )
289             for src in sources
290         }
291         _task_values = list(tasks.values())
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
294             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     if write_back == WriteBack.YES and formatted:
315         write_cache(cache, formatted, line_length)
316
317
318 def format_file_in_place(
319     src: Path,
320     line_length: int,
321     fast: bool,
322     write_back: WriteBack = WriteBack.NO,
323     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
324 ) -> bool:
325     """Format file under `src` path. Return True if changed.
326
327     If `write_back` is True, write reformatted code back to stdout.
328     `line_length` and `fast` options are passed to :func:`format_file_contents`.
329     """
330
331     with tokenize.open(src) as src_buffer:
332         src_contents = src_buffer.read()
333     try:
334         dst_contents = format_file_contents(
335             src_contents, line_length=line_length, fast=fast
336         )
337     except NothingChanged:
338         return False
339
340     if write_back == write_back.YES:
341         with open(src, "w", encoding=src_buffer.encoding) as f:
342             f.write(dst_contents)
343     elif write_back == write_back.DIFF:
344         src_name = f"{src}  (original)"
345         dst_name = f"{src}  (formatted)"
346         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347         if lock:
348             lock.acquire()
349         try:
350             sys.stdout.write(diff_contents)
351         finally:
352             if lock:
353                 lock.release()
354     return True
355
356
357 def format_stdin_to_stdout(
358     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
359 ) -> bool:
360     """Format file on stdin. Return True if changed.
361
362     If `write_back` is True, write reformatted code back to stdout.
363     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
364     """
365     src = sys.stdin.read()
366     dst = src
367     try:
368         dst = format_file_contents(src, line_length=line_length, fast=fast)
369         return True
370
371     except NothingChanged:
372         return False
373
374     finally:
375         if write_back == WriteBack.YES:
376             sys.stdout.write(dst)
377         elif write_back == WriteBack.DIFF:
378             src_name = "<stdin>  (original)"
379             dst_name = "<stdin>  (formatted)"
380             sys.stdout.write(diff(src, dst, src_name, dst_name))
381
382
383 def format_file_contents(
384     src_contents: str, line_length: int, fast: bool
385 ) -> FileContent:
386     """Reformat contents a file and return new contents.
387
388     If `fast` is False, additionally confirm that the reformatted code is
389     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
390     `line_length` is passed to :func:`format_str`.
391     """
392     if src_contents.strip() == "":
393         raise NothingChanged
394
395     dst_contents = format_str(src_contents, line_length=line_length)
396     if src_contents == dst_contents:
397         raise NothingChanged
398
399     if not fast:
400         assert_equivalent(src_contents, dst_contents)
401         assert_stable(src_contents, dst_contents, line_length=line_length)
402     return dst_contents
403
404
405 def format_str(src_contents: str, line_length: int) -> FileContent:
406     """Reformat a string and return new contents.
407
408     `line_length` determines how many characters per line are allowed.
409     """
410     src_node = lib2to3_parse(src_contents)
411     dst_contents = ""
412     lines = LineGenerator()
413     elt = EmptyLineTracker()
414     py36 = is_python36(src_node)
415     empty_line = Line()
416     after = 0
417     for current_line in lines.visit(src_node):
418         for _ in range(after):
419             dst_contents += str(empty_line)
420         before, after = elt.maybe_empty_lines(current_line)
421         for _ in range(before):
422             dst_contents += str(empty_line)
423         for line in split_line(current_line, line_length=line_length, py36=py36):
424             dst_contents += str(line)
425     return dst_contents
426
427
428 GRAMMARS = [
429     pygram.python_grammar_no_print_statement_no_exec_statement,
430     pygram.python_grammar_no_print_statement,
431     pygram.python_grammar,
432 ]
433
434
435 def lib2to3_parse(src_txt: str) -> Node:
436     """Given a string with source, return the lib2to3 Node."""
437     grammar = pygram.python_grammar_no_print_statement
438     if src_txt[-1] != "\n":
439         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
440         src_txt += nl
441     for grammar in GRAMMARS:
442         drv = driver.Driver(grammar, pytree.convert)
443         try:
444             result = drv.parse_string(src_txt, True)
445             break
446
447         except ParseError as pe:
448             lineno, column = pe.context[1]
449             lines = src_txt.splitlines()
450             try:
451                 faulty_line = lines[lineno - 1]
452             except IndexError:
453                 faulty_line = "<line number missing in source>"
454             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
455     else:
456         raise exc from None
457
458     if isinstance(result, Leaf):
459         result = Node(syms.file_input, [result])
460     return result
461
462
463 def lib2to3_unparse(node: Node) -> str:
464     """Given a lib2to3 node, return its string representation."""
465     code = str(node)
466     return code
467
468
469 T = TypeVar("T")
470
471
472 class Visitor(Generic[T]):
473     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
474
475     def visit(self, node: LN) -> Iterator[T]:
476         """Main method to visit `node` and its children.
477
478         It tries to find a `visit_*()` method for the given `node.type`, like
479         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
480         If no dedicated `visit_*()` method is found, chooses `visit_default()`
481         instead.
482
483         Then yields objects of type `T` from the selected visitor.
484         """
485         if node.type < 256:
486             name = token.tok_name[node.type]
487         else:
488             name = type_repr(node.type)
489         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
490
491     def visit_default(self, node: LN) -> Iterator[T]:
492         """Default `visit_*()` implementation. Recurses to children of `node`."""
493         if isinstance(node, Node):
494             for child in node.children:
495                 yield from self.visit(child)
496
497
498 @dataclass
499 class DebugVisitor(Visitor[T]):
500     tree_depth: int = 0
501
502     def visit_default(self, node: LN) -> Iterator[T]:
503         indent = " " * (2 * self.tree_depth)
504         if isinstance(node, Node):
505             _type = type_repr(node.type)
506             out(f"{indent}{_type}", fg="yellow")
507             self.tree_depth += 1
508             for child in node.children:
509                 yield from self.visit(child)
510
511             self.tree_depth -= 1
512             out(f"{indent}/{_type}", fg="yellow", bold=False)
513         else:
514             _type = token.tok_name.get(node.type, str(node.type))
515             out(f"{indent}{_type}", fg="blue", nl=False)
516             if node.prefix:
517                 # We don't have to handle prefixes for `Node` objects since
518                 # that delegates to the first child anyway.
519                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
520             out(f" {node.value!r}", fg="blue", bold=False)
521
522     @classmethod
523     def show(cls, code: str) -> None:
524         """Pretty-print the lib2to3 AST of a given string of `code`.
525
526         Convenience method for debugging.
527         """
528         v: DebugVisitor[None] = DebugVisitor()
529         list(v.visit(lib2to3_parse(code)))
530
531
532 KEYWORDS = set(keyword.kwlist)
533 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
534 FLOW_CONTROL = {"return", "raise", "break", "continue"}
535 STATEMENT = {
536     syms.if_stmt,
537     syms.while_stmt,
538     syms.for_stmt,
539     syms.try_stmt,
540     syms.except_clause,
541     syms.with_stmt,
542     syms.funcdef,
543     syms.classdef,
544 }
545 STANDALONE_COMMENT = 153
546 LOGIC_OPERATORS = {"and", "or"}
547 COMPARATORS = {
548     token.LESS,
549     token.GREATER,
550     token.EQEQUAL,
551     token.NOTEQUAL,
552     token.LESSEQUAL,
553     token.GREATEREQUAL,
554 }
555 MATH_OPERATORS = {
556     token.VBAR,
557     token.CIRCUMFLEX,
558     token.AMPER,
559     token.LEFTSHIFT,
560     token.RIGHTSHIFT,
561     token.PLUS,
562     token.MINUS,
563     token.STAR,
564     token.SLASH,
565     token.DOUBLESLASH,
566     token.PERCENT,
567     token.AT,
568     token.TILDE,
569     token.DOUBLESTAR,
570 }
571 STARS = {token.STAR, token.DOUBLESTAR}
572 VARARGS_PARENTS = {
573     syms.arglist,
574     syms.argument,  # double star in arglist
575     syms.trailer,  # single argument to call
576     syms.typedargslist,
577     syms.varargslist,  # lambdas
578 }
579 UNPACKING_PARENTS = {
580     syms.atom,  # single element of a list or set literal
581     syms.dictsetmaker,
582     syms.listmaker,
583     syms.testlist_gexp,
584 }
585 TEST_DESCENDANTS = {
586     syms.test,
587     syms.lambdef,
588     syms.or_test,
589     syms.and_test,
590     syms.not_test,
591     syms.comparison,
592     syms.star_expr,
593     syms.expr,
594     syms.xor_expr,
595     syms.and_expr,
596     syms.shift_expr,
597     syms.arith_expr,
598     syms.trailer,
599     syms.term,
600     syms.power,
601 }
602 COMPREHENSION_PRIORITY = 20
603 COMMA_PRIORITY = 18
604 TERNARY_PRIORITY = 16
605 LOGIC_PRIORITY = 14
606 STRING_PRIORITY = 12
607 COMPARATOR_PRIORITY = 10
608 MATH_PRIORITIES = {
609     token.VBAR: 8,
610     token.CIRCUMFLEX: 7,
611     token.AMPER: 6,
612     token.LEFTSHIFT: 5,
613     token.RIGHTSHIFT: 5,
614     token.PLUS: 4,
615     token.MINUS: 4,
616     token.STAR: 3,
617     token.SLASH: 3,
618     token.DOUBLESLASH: 3,
619     token.PERCENT: 3,
620     token.AT: 3,
621     token.TILDE: 2,
622     token.DOUBLESTAR: 1,
623 }
624
625
626 @dataclass
627 class BracketTracker:
628     """Keeps track of brackets on a line."""
629
630     depth: int = 0
631     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
632     delimiters: Dict[LeafID, Priority] = Factory(dict)
633     previous: Optional[Leaf] = None
634     _for_loop_variable: bool = False
635     _lambda_arguments: bool = False
636
637     def mark(self, leaf: Leaf) -> None:
638         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
639
640         All leaves receive an int `bracket_depth` field that stores how deep
641         within brackets a given leaf is. 0 means there are no enclosing brackets
642         that started on this line.
643
644         If a leaf is itself a closing bracket, it receives an `opening_bracket`
645         field that it forms a pair with. This is a one-directional link to
646         avoid reference cycles.
647
648         If a leaf is a delimiter (a token on which Black can split the line if
649         needed) and it's on depth 0, its `id()` is stored in the tracker's
650         `delimiters` field.
651         """
652         if leaf.type == token.COMMENT:
653             return
654
655         self.maybe_decrement_after_for_loop_variable(leaf)
656         self.maybe_decrement_after_lambda_arguments(leaf)
657         if leaf.type in CLOSING_BRACKETS:
658             self.depth -= 1
659             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
660             leaf.opening_bracket = opening_bracket
661         leaf.bracket_depth = self.depth
662         if self.depth == 0:
663             delim = is_split_before_delimiter(leaf, self.previous)
664             if delim and self.previous is not None:
665                 self.delimiters[id(self.previous)] = delim
666             else:
667                 delim = is_split_after_delimiter(leaf, self.previous)
668                 if delim:
669                     self.delimiters[id(leaf)] = delim
670         if leaf.type in OPENING_BRACKETS:
671             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
672             self.depth += 1
673         self.previous = leaf
674         self.maybe_increment_lambda_arguments(leaf)
675         self.maybe_increment_for_loop_variable(leaf)
676
677     def any_open_brackets(self) -> bool:
678         """Return True if there is an yet unmatched open bracket on the line."""
679         return bool(self.bracket_match)
680
681     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
682         """Return the highest priority of a delimiter found on the line.
683
684         Values are consistent with what `is_split_*_delimiter()` return.
685         Raises ValueError on no delimiters.
686         """
687         return max(v for k, v in self.delimiters.items() if k not in exclude)
688
689     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
690         """In a for loop, or comprehension, the variables are often unpacks.
691
692         To avoid splitting on the comma in this situation, increase the depth of
693         tokens between `for` and `in`.
694         """
695         if leaf.type == token.NAME and leaf.value == "for":
696             self.depth += 1
697             self._for_loop_variable = True
698             return True
699
700         return False
701
702     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
703         """See `maybe_increment_for_loop_variable` above for explanation."""
704         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
705             self.depth -= 1
706             self._for_loop_variable = False
707             return True
708
709         return False
710
711     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
712         """In a lambda expression, there might be more than one argument.
713
714         To avoid splitting on the comma in this situation, increase the depth of
715         tokens between `lambda` and `:`.
716         """
717         if leaf.type == token.NAME and leaf.value == "lambda":
718             self.depth += 1
719             self._lambda_arguments = True
720             return True
721
722         return False
723
724     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
725         """See `maybe_increment_lambda_arguments` above for explanation."""
726         if self._lambda_arguments and leaf.type == token.COLON:
727             self.depth -= 1
728             self._lambda_arguments = False
729             return True
730
731         return False
732
733     def get_open_lsqb(self) -> Optional[Leaf]:
734         """Return the most recent opening square bracket (if any)."""
735         return self.bracket_match.get((self.depth - 1, token.RSQB))
736
737
738 @dataclass
739 class Line:
740     """Holds leaves and comments. Can be printed with `str(line)`."""
741
742     depth: int = 0
743     leaves: List[Leaf] = Factory(list)
744     comments: List[Tuple[Index, Leaf]] = Factory(list)
745     bracket_tracker: BracketTracker = Factory(BracketTracker)
746     inside_brackets: bool = False
747
748     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
749         """Add a new `leaf` to the end of the line.
750
751         Unless `preformatted` is True, the `leaf` will receive a new consistent
752         whitespace prefix and metadata applied by :class:`BracketTracker`.
753         Trailing commas are maybe removed, unpacked for loop variables are
754         demoted from being delimiters.
755
756         Inline comments are put aside.
757         """
758         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
759         if not has_value:
760             return
761
762         if token.COLON == leaf.type and self.is_class_paren_empty:
763             del self.leaves[-2:]
764         if self.leaves and not preformatted:
765             # Note: at this point leaf.prefix should be empty except for
766             # imports, for which we only preserve newlines.
767             leaf.prefix += whitespace(
768                 leaf, complex_subscript=self.is_complex_subscript(leaf)
769             )
770         if self.inside_brackets or not preformatted:
771             self.bracket_tracker.mark(leaf)
772             self.maybe_remove_trailing_comma(leaf)
773         if not self.append_comment(leaf):
774             self.leaves.append(leaf)
775
776     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
777         """Like :func:`append()` but disallow invalid standalone comment structure.
778
779         Raises ValueError when any `leaf` is appended after a standalone comment
780         or when a standalone comment is not the first leaf on the line.
781         """
782         if self.bracket_tracker.depth == 0:
783             if self.is_comment:
784                 raise ValueError("cannot append to standalone comments")
785
786             if self.leaves and leaf.type == STANDALONE_COMMENT:
787                 raise ValueError(
788                     "cannot append standalone comments to a populated line"
789                 )
790
791         self.append(leaf, preformatted=preformatted)
792
793     @property
794     def is_comment(self) -> bool:
795         """Is this line a standalone comment?"""
796         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
797
798     @property
799     def is_decorator(self) -> bool:
800         """Is this line a decorator?"""
801         return bool(self) and self.leaves[0].type == token.AT
802
803     @property
804     def is_import(self) -> bool:
805         """Is this an import line?"""
806         return bool(self) and is_import(self.leaves[0])
807
808     @property
809     def is_class(self) -> bool:
810         """Is this line a class definition?"""
811         return (
812             bool(self)
813             and self.leaves[0].type == token.NAME
814             and self.leaves[0].value == "class"
815         )
816
817     @property
818     def is_def(self) -> bool:
819         """Is this a function definition? (Also returns True for async defs.)"""
820         try:
821             first_leaf = self.leaves[0]
822         except IndexError:
823             return False
824
825         try:
826             second_leaf: Optional[Leaf] = self.leaves[1]
827         except IndexError:
828             second_leaf = None
829         return (
830             (first_leaf.type == token.NAME and first_leaf.value == "def")
831             or (
832                 first_leaf.type == token.ASYNC
833                 and second_leaf is not None
834                 and second_leaf.type == token.NAME
835                 and second_leaf.value == "def"
836             )
837         )
838
839     @property
840     def is_flow_control(self) -> bool:
841         """Is this line a flow control statement?
842
843         Those are `return`, `raise`, `break`, and `continue`.
844         """
845         return (
846             bool(self)
847             and self.leaves[0].type == token.NAME
848             and self.leaves[0].value in FLOW_CONTROL
849         )
850
851     @property
852     def is_yield(self) -> bool:
853         """Is this line a yield statement?"""
854         return (
855             bool(self)
856             and self.leaves[0].type == token.NAME
857             and self.leaves[0].value == "yield"
858         )
859
860     @property
861     def is_class_paren_empty(self) -> bool:
862         """Is this a class with no base classes but using parentheses?
863
864         Those are unnecessary and should be removed.
865         """
866         return (
867             bool(self)
868             and len(self.leaves) == 4
869             and self.is_class
870             and self.leaves[2].type == token.LPAR
871             and self.leaves[2].value == "("
872             and self.leaves[3].type == token.RPAR
873             and self.leaves[3].value == ")"
874         )
875
876     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
877         """If so, needs to be split before emitting."""
878         for leaf in self.leaves:
879             if leaf.type == STANDALONE_COMMENT:
880                 if leaf.bracket_depth <= depth_limit:
881                     return True
882
883         return False
884
885     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
886         """Remove trailing comma if there is one and it's safe."""
887         if not (
888             self.leaves
889             and self.leaves[-1].type == token.COMMA
890             and closing.type in CLOSING_BRACKETS
891         ):
892             return False
893
894         if closing.type == token.RBRACE:
895             self.remove_trailing_comma()
896             return True
897
898         if closing.type == token.RSQB:
899             comma = self.leaves[-1]
900             if comma.parent and comma.parent.type == syms.listmaker:
901                 self.remove_trailing_comma()
902                 return True
903
904         # For parens let's check if it's safe to remove the comma.
905         # Imports are always safe.
906         if self.is_import:
907             self.remove_trailing_comma()
908             return True
909
910         # Otheriwsse, if the trailing one is the only one, we might mistakenly
911         # change a tuple into a different type by removing the comma.
912         depth = closing.bracket_depth + 1
913         commas = 0
914         opening = closing.opening_bracket
915         for _opening_index, leaf in enumerate(self.leaves):
916             if leaf is opening:
917                 break
918
919         else:
920             return False
921
922         for leaf in self.leaves[_opening_index + 1 :]:
923             if leaf is closing:
924                 break
925
926             bracket_depth = leaf.bracket_depth
927             if bracket_depth == depth and leaf.type == token.COMMA:
928                 commas += 1
929                 if leaf.parent and leaf.parent.type == syms.arglist:
930                     commas += 1
931                     break
932
933         if commas > 1:
934             self.remove_trailing_comma()
935             return True
936
937         return False
938
939     def append_comment(self, comment: Leaf) -> bool:
940         """Add an inline or standalone comment to the line."""
941         if (
942             comment.type == STANDALONE_COMMENT
943             and self.bracket_tracker.any_open_brackets()
944         ):
945             comment.prefix = ""
946             return False
947
948         if comment.type != token.COMMENT:
949             return False
950
951         after = len(self.leaves) - 1
952         if after == -1:
953             comment.type = STANDALONE_COMMENT
954             comment.prefix = ""
955             return False
956
957         else:
958             self.comments.append((after, comment))
959             return True
960
961     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
962         """Generate comments that should appear directly after `leaf`."""
963         for _leaf_index, _leaf in enumerate(self.leaves):
964             if leaf is _leaf:
965                 break
966
967         else:
968             return
969
970         for index, comment_after in self.comments:
971             if _leaf_index == index:
972                 yield comment_after
973
974     def remove_trailing_comma(self) -> None:
975         """Remove the trailing comma and moves the comments attached to it."""
976         comma_index = len(self.leaves) - 1
977         for i in range(len(self.comments)):
978             comment_index, comment = self.comments[i]
979             if comment_index == comma_index:
980                 self.comments[i] = (comma_index - 1, comment)
981         self.leaves.pop()
982
983     def is_complex_subscript(self, leaf: Leaf) -> bool:
984         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
985         open_lsqb = (
986             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
987         )
988         if open_lsqb is None:
989             return False
990
991         subscript_start = open_lsqb.next_sibling
992         if (
993             isinstance(subscript_start, Node)
994             and subscript_start.type == syms.subscriptlist
995         ):
996             subscript_start = child_towards(subscript_start, leaf)
997         return subscript_start is not None and any(
998             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
999         )
1000
1001     def __str__(self) -> str:
1002         """Render the line."""
1003         if not self:
1004             return "\n"
1005
1006         indent = "    " * self.depth
1007         leaves = iter(self.leaves)
1008         first = next(leaves)
1009         res = f"{first.prefix}{indent}{first.value}"
1010         for leaf in leaves:
1011             res += str(leaf)
1012         for _, comment in self.comments:
1013             res += str(comment)
1014         return res + "\n"
1015
1016     def __bool__(self) -> bool:
1017         """Return True if the line has leaves or comments."""
1018         return bool(self.leaves or self.comments)
1019
1020
1021 class UnformattedLines(Line):
1022     """Just like :class:`Line` but stores lines which aren't reformatted."""
1023
1024     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1025         """Just add a new `leaf` to the end of the lines.
1026
1027         The `preformatted` argument is ignored.
1028
1029         Keeps track of indentation `depth`, which is useful when the user
1030         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1031         """
1032         try:
1033             list(generate_comments(leaf))
1034         except FormatOn as f_on:
1035             self.leaves.append(f_on.leaf_from_consumed(leaf))
1036             raise
1037
1038         self.leaves.append(leaf)
1039         if leaf.type == token.INDENT:
1040             self.depth += 1
1041         elif leaf.type == token.DEDENT:
1042             self.depth -= 1
1043
1044     def __str__(self) -> str:
1045         """Render unformatted lines from leaves which were added with `append()`.
1046
1047         `depth` is not used for indentation in this case.
1048         """
1049         if not self:
1050             return "\n"
1051
1052         res = ""
1053         for leaf in self.leaves:
1054             res += str(leaf)
1055         return res
1056
1057     def append_comment(self, comment: Leaf) -> bool:
1058         """Not implemented in this class. Raises `NotImplementedError`."""
1059         raise NotImplementedError("Unformatted lines don't store comments separately.")
1060
1061     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1062         """Does nothing and returns False."""
1063         return False
1064
1065     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1066         """Does nothing and returns False."""
1067         return False
1068
1069
1070 @dataclass
1071 class EmptyLineTracker:
1072     """Provides a stateful method that returns the number of potential extra
1073     empty lines needed before and after the currently processed line.
1074
1075     Note: this tracker works on lines that haven't been split yet.  It assumes
1076     the prefix of the first leaf consists of optional newlines.  Those newlines
1077     are consumed by `maybe_empty_lines()` and included in the computation.
1078     """
1079     previous_line: Optional[Line] = None
1080     previous_after: int = 0
1081     previous_defs: List[int] = Factory(list)
1082
1083     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1084         """Return the number of extra empty lines before and after the `current_line`.
1085
1086         This is for separating `def`, `async def` and `class` with extra empty
1087         lines (two on module-level), as well as providing an extra empty line
1088         after flow control keywords to make them more prominent.
1089         """
1090         if isinstance(current_line, UnformattedLines):
1091             return 0, 0
1092
1093         before, after = self._maybe_empty_lines(current_line)
1094         before -= self.previous_after
1095         self.previous_after = after
1096         self.previous_line = current_line
1097         return before, after
1098
1099     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1100         max_allowed = 1
1101         if current_line.depth == 0:
1102             max_allowed = 2
1103         if current_line.leaves:
1104             # Consume the first leaf's extra newlines.
1105             first_leaf = current_line.leaves[0]
1106             before = first_leaf.prefix.count("\n")
1107             before = min(before, max_allowed)
1108             first_leaf.prefix = ""
1109         else:
1110             before = 0
1111         depth = current_line.depth
1112         while self.previous_defs and self.previous_defs[-1] >= depth:
1113             self.previous_defs.pop()
1114             before = 1 if depth else 2
1115         is_decorator = current_line.is_decorator
1116         if is_decorator or current_line.is_def or current_line.is_class:
1117             if not is_decorator:
1118                 self.previous_defs.append(depth)
1119             if self.previous_line is None:
1120                 # Don't insert empty lines before the first line in the file.
1121                 return 0, 0
1122
1123             if self.previous_line.is_decorator:
1124                 return 0, 0
1125
1126             if (
1127                 self.previous_line.is_comment
1128                 and self.previous_line.depth == current_line.depth
1129                 and before == 0
1130             ):
1131                 return 0, 0
1132
1133             newlines = 2
1134             if current_line.depth:
1135                 newlines -= 1
1136             return newlines, 0
1137
1138         if (
1139             self.previous_line
1140             and self.previous_line.is_import
1141             and not current_line.is_import
1142             and depth == self.previous_line.depth
1143         ):
1144             return (before or 1), 0
1145
1146         return before, 0
1147
1148
1149 @dataclass
1150 class LineGenerator(Visitor[Line]):
1151     """Generates reformatted Line objects.  Empty lines are not emitted.
1152
1153     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1154     in ways that will no longer stringify to valid Python code on the tree.
1155     """
1156     current_line: Line = Factory(Line)
1157
1158     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1159         """Generate a line.
1160
1161         If the line is empty, only emit if it makes sense.
1162         If the line is too long, split it first and then generate.
1163
1164         If any lines were generated, set up a new current_line.
1165         """
1166         if not self.current_line:
1167             if self.current_line.__class__ == type:
1168                 self.current_line.depth += indent
1169             else:
1170                 self.current_line = type(depth=self.current_line.depth + indent)
1171             return  # Line is empty, don't emit. Creating a new one unnecessary.
1172
1173         complete_line = self.current_line
1174         self.current_line = type(depth=complete_line.depth + indent)
1175         yield complete_line
1176
1177     def visit(self, node: LN) -> Iterator[Line]:
1178         """Main method to visit `node` and its children.
1179
1180         Yields :class:`Line` objects.
1181         """
1182         if isinstance(self.current_line, UnformattedLines):
1183             # File contained `# fmt: off`
1184             yield from self.visit_unformatted(node)
1185
1186         else:
1187             yield from super().visit(node)
1188
1189     def visit_default(self, node: LN) -> Iterator[Line]:
1190         """Default `visit_*()` implementation. Recurses to children of `node`."""
1191         if isinstance(node, Leaf):
1192             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1193             try:
1194                 for comment in generate_comments(node):
1195                     if any_open_brackets:
1196                         # any comment within brackets is subject to splitting
1197                         self.current_line.append(comment)
1198                     elif comment.type == token.COMMENT:
1199                         # regular trailing comment
1200                         self.current_line.append(comment)
1201                         yield from self.line()
1202
1203                     else:
1204                         # regular standalone comment
1205                         yield from self.line()
1206
1207                         self.current_line.append(comment)
1208                         yield from self.line()
1209
1210             except FormatOff as f_off:
1211                 f_off.trim_prefix(node)
1212                 yield from self.line(type=UnformattedLines)
1213                 yield from self.visit(node)
1214
1215             except FormatOn as f_on:
1216                 # This only happens here if somebody says "fmt: on" multiple
1217                 # times in a row.
1218                 f_on.trim_prefix(node)
1219                 yield from self.visit_default(node)
1220
1221             else:
1222                 normalize_prefix(node, inside_brackets=any_open_brackets)
1223                 if node.type == token.STRING:
1224                     normalize_string_quotes(node)
1225                 if node.type not in WHITESPACE:
1226                     self.current_line.append(node)
1227         yield from super().visit_default(node)
1228
1229     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1230         """Increase indentation level, maybe yield a line."""
1231         # In blib2to3 INDENT never holds comments.
1232         yield from self.line(+1)
1233         yield from self.visit_default(node)
1234
1235     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1236         """Decrease indentation level, maybe yield a line."""
1237         # The current line might still wait for trailing comments.  At DEDENT time
1238         # there won't be any (they would be prefixes on the preceding NEWLINE).
1239         # Emit the line then.
1240         yield from self.line()
1241
1242         # While DEDENT has no value, its prefix may contain standalone comments
1243         # that belong to the current indentation level.  Get 'em.
1244         yield from self.visit_default(node)
1245
1246         # Finally, emit the dedent.
1247         yield from self.line(-1)
1248
1249     def visit_stmt(
1250         self, node: Node, keywords: Set[str], parens: Set[str]
1251     ) -> Iterator[Line]:
1252         """Visit a statement.
1253
1254         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1255         `def`, `with`, `class`, and `assert`.
1256
1257         The relevant Python language `keywords` for a given statement will be
1258         NAME leaves within it. This methods puts those on a separate line.
1259
1260         `parens` holds a set of string leaf values immeditely after which
1261         invisible parens should be put.
1262         """
1263         normalize_invisible_parens(node, parens_after=parens)
1264         for child in node.children:
1265             if child.type == token.NAME and child.value in keywords:  # type: ignore
1266                 yield from self.line()
1267
1268             yield from self.visit(child)
1269
1270     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1271         """Visit a statement without nested statements."""
1272         is_suite_like = node.parent and node.parent.type in STATEMENT
1273         if is_suite_like:
1274             yield from self.line(+1)
1275             yield from self.visit_default(node)
1276             yield from self.line(-1)
1277
1278         else:
1279             yield from self.line()
1280             yield from self.visit_default(node)
1281
1282     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1283         """Visit `async def`, `async for`, `async with`."""
1284         yield from self.line()
1285
1286         children = iter(node.children)
1287         for child in children:
1288             yield from self.visit(child)
1289
1290             if child.type == token.ASYNC:
1291                 break
1292
1293         internal_stmt = next(children)
1294         for child in internal_stmt.children:
1295             yield from self.visit(child)
1296
1297     def visit_decorators(self, node: Node) -> Iterator[Line]:
1298         """Visit decorators."""
1299         for child in node.children:
1300             yield from self.line()
1301             yield from self.visit(child)
1302
1303     def visit_import_from(self, node: Node) -> Iterator[Line]:
1304         """Visit import_from and maybe put invisible parentheses.
1305
1306         This is separate from `visit_stmt` because import statements don't
1307         support arbitrary atoms and thus handling of parentheses is custom.
1308         """
1309         check_lpar = False
1310         for index, child in enumerate(node.children):
1311             if check_lpar:
1312                 if child.type == token.LPAR:
1313                     # make parentheses invisible
1314                     child.value = ""  # type: ignore
1315                     node.children[-1].value = ""  # type: ignore
1316                 else:
1317                     # insert invisible parentheses
1318                     node.insert_child(index, Leaf(token.LPAR, ""))
1319                     node.append_child(Leaf(token.RPAR, ""))
1320                 break
1321
1322             check_lpar = (
1323                 child.type == token.NAME and child.value == "import"  # type: ignore
1324             )
1325
1326         for child in node.children:
1327             yield from self.visit(child)
1328
1329     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1330         """Remove a semicolon and put the other statement on a separate line."""
1331         yield from self.line()
1332
1333     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1334         """End of file. Process outstanding comments and end with a newline."""
1335         yield from self.visit_default(leaf)
1336         yield from self.line()
1337
1338     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1339         """Used when file contained a `# fmt: off`."""
1340         if isinstance(node, Node):
1341             for child in node.children:
1342                 yield from self.visit(child)
1343
1344         else:
1345             try:
1346                 self.current_line.append(node)
1347             except FormatOn as f_on:
1348                 f_on.trim_prefix(node)
1349                 yield from self.line()
1350                 yield from self.visit(node)
1351
1352             if node.type == token.ENDMARKER:
1353                 # somebody decided not to put a final `# fmt: on`
1354                 yield from self.line()
1355
1356     def __attrs_post_init__(self) -> None:
1357         """You are in a twisty little maze of passages."""
1358         v = self.visit_stmt
1359         Ø: Set[str] = set()
1360         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1361         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1362         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1363         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1364         self.visit_try_stmt = partial(
1365             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1366         )
1367         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1368         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1369         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1370         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1371         self.visit_async_funcdef = self.visit_async_stmt
1372         self.visit_decorated = self.visit_decorators
1373
1374
1375 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1376 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1377 OPENING_BRACKETS = set(BRACKET.keys())
1378 CLOSING_BRACKETS = set(BRACKET.values())
1379 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1380 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1381
1382
1383 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1384     """Return whitespace prefix if needed for the given `leaf`.
1385
1386     `complex_subscript` signals whether the given leaf is part of a subscription
1387     which has non-trivial arguments, like arithmetic expressions or function calls.
1388     """
1389     NO = ""
1390     SPACE = " "
1391     DOUBLESPACE = "  "
1392     t = leaf.type
1393     p = leaf.parent
1394     v = leaf.value
1395     if t in ALWAYS_NO_SPACE:
1396         return NO
1397
1398     if t == token.COMMENT:
1399         return DOUBLESPACE
1400
1401     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1402     if (
1403         t == token.COLON
1404         and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop}
1405     ):
1406         return NO
1407
1408     prev = leaf.prev_sibling
1409     if not prev:
1410         prevp = preceding_leaf(p)
1411         if not prevp or prevp.type in OPENING_BRACKETS:
1412             return NO
1413
1414         if t == token.COLON:
1415             if prevp.type == token.COLON:
1416                 return NO
1417
1418             elif prevp.type != token.COMMA and not complex_subscript:
1419                 return NO
1420
1421             return SPACE
1422
1423         if prevp.type == token.EQUAL:
1424             if prevp.parent:
1425                 if prevp.parent.type in {
1426                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1427                 }:
1428                     return NO
1429
1430                 elif prevp.parent.type == syms.typedargslist:
1431                     # A bit hacky: if the equal sign has whitespace, it means we
1432                     # previously found it's a typed argument.  So, we're using
1433                     # that, too.
1434                     return prevp.prefix
1435
1436         elif prevp.type in STARS:
1437             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1438                 return NO
1439
1440         elif prevp.type == token.COLON:
1441             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1442                 return SPACE if complex_subscript else NO
1443
1444         elif (
1445             prevp.parent
1446             and prevp.parent.type == syms.factor
1447             and prevp.type in MATH_OPERATORS
1448         ):
1449             return NO
1450
1451         elif (
1452             prevp.type == token.RIGHTSHIFT
1453             and prevp.parent
1454             and prevp.parent.type == syms.shift_expr
1455             and prevp.prev_sibling
1456             and prevp.prev_sibling.type == token.NAME
1457             and prevp.prev_sibling.value == "print"  # type: ignore
1458         ):
1459             # Python 2 print chevron
1460             return NO
1461
1462     elif prev.type in OPENING_BRACKETS:
1463         return NO
1464
1465     if p.type in {syms.parameters, syms.arglist}:
1466         # untyped function signatures or calls
1467         if not prev or prev.type != token.COMMA:
1468             return NO
1469
1470     elif p.type == syms.varargslist:
1471         # lambdas
1472         if prev and prev.type != token.COMMA:
1473             return NO
1474
1475     elif p.type == syms.typedargslist:
1476         # typed function signatures
1477         if not prev:
1478             return NO
1479
1480         if t == token.EQUAL:
1481             if prev.type != syms.tname:
1482                 return NO
1483
1484         elif prev.type == token.EQUAL:
1485             # A bit hacky: if the equal sign has whitespace, it means we
1486             # previously found it's a typed argument.  So, we're using that, too.
1487             return prev.prefix
1488
1489         elif prev.type != token.COMMA:
1490             return NO
1491
1492     elif p.type == syms.tname:
1493         # type names
1494         if not prev:
1495             prevp = preceding_leaf(p)
1496             if not prevp or prevp.type != token.COMMA:
1497                 return NO
1498
1499     elif p.type == syms.trailer:
1500         # attributes and calls
1501         if t == token.LPAR or t == token.RPAR:
1502             return NO
1503
1504         if not prev:
1505             if t == token.DOT:
1506                 prevp = preceding_leaf(p)
1507                 if not prevp or prevp.type != token.NUMBER:
1508                     return NO
1509
1510             elif t == token.LSQB:
1511                 return NO
1512
1513         elif prev.type != token.COMMA:
1514             return NO
1515
1516     elif p.type == syms.argument:
1517         # single argument
1518         if t == token.EQUAL:
1519             return NO
1520
1521         if not prev:
1522             prevp = preceding_leaf(p)
1523             if not prevp or prevp.type == token.LPAR:
1524                 return NO
1525
1526         elif prev.type in {token.EQUAL} | STARS:
1527             return NO
1528
1529     elif p.type == syms.decorator:
1530         # decorators
1531         return NO
1532
1533     elif p.type == syms.dotted_name:
1534         if prev:
1535             return NO
1536
1537         prevp = preceding_leaf(p)
1538         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1539             return NO
1540
1541     elif p.type == syms.classdef:
1542         if t == token.LPAR:
1543             return NO
1544
1545         if prev and prev.type == token.LPAR:
1546             return NO
1547
1548     elif p.type in {syms.subscript, syms.sliceop}:
1549         # indexing
1550         if not prev:
1551             assert p.parent is not None, "subscripts are always parented"
1552             if p.parent.type == syms.subscriptlist:
1553                 return SPACE
1554
1555             return NO
1556
1557         elif not complex_subscript:
1558             return NO
1559
1560     elif p.type == syms.atom:
1561         if prev and t == token.DOT:
1562             # dots, but not the first one.
1563             return NO
1564
1565     elif p.type == syms.dictsetmaker:
1566         # dict unpacking
1567         if prev and prev.type == token.DOUBLESTAR:
1568             return NO
1569
1570     elif p.type in {syms.factor, syms.star_expr}:
1571         # unary ops
1572         if not prev:
1573             prevp = preceding_leaf(p)
1574             if not prevp or prevp.type in OPENING_BRACKETS:
1575                 return NO
1576
1577             prevp_parent = prevp.parent
1578             assert prevp_parent is not None
1579             if (
1580                 prevp.type == token.COLON
1581                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1582             ):
1583                 return NO
1584
1585             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1586                 return NO
1587
1588         elif t == token.NAME or t == token.NUMBER:
1589             return NO
1590
1591     elif p.type == syms.import_from:
1592         if t == token.DOT:
1593             if prev and prev.type == token.DOT:
1594                 return NO
1595
1596         elif t == token.NAME:
1597             if v == "import":
1598                 return SPACE
1599
1600             if prev and prev.type == token.DOT:
1601                 return NO
1602
1603     elif p.type == syms.sliceop:
1604         return NO
1605
1606     return SPACE
1607
1608
1609 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1610     """Return the first leaf that precedes `node`, if any."""
1611     while node:
1612         res = node.prev_sibling
1613         if res:
1614             if isinstance(res, Leaf):
1615                 return res
1616
1617             try:
1618                 return list(res.leaves())[-1]
1619
1620             except IndexError:
1621                 return None
1622
1623         node = node.parent
1624     return None
1625
1626
1627 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1628     """Return the child of `ancestor` that contains `descendant`."""
1629     node: Optional[LN] = descendant
1630     while node and node.parent != ancestor:
1631         node = node.parent
1632     return node
1633
1634
1635 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1636     """Return the priority of the `leaf` delimiter, given a line break after it.
1637
1638     The delimiter priorities returned here are from those delimiters that would
1639     cause a line break after themselves.
1640
1641     Higher numbers are higher priority.
1642     """
1643     if leaf.type == token.COMMA:
1644         return COMMA_PRIORITY
1645
1646     return 0
1647
1648
1649 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1650     """Return the priority of the `leaf` delimiter, given a line before after it.
1651
1652     The delimiter priorities returned here are from those delimiters that would
1653     cause a line break before themselves.
1654
1655     Higher numbers are higher priority.
1656     """
1657     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1658         # * and ** might also be MATH_OPERATORS but in this case they are not.
1659         # Don't treat them as a delimiter.
1660         return 0
1661
1662     if (
1663         leaf.type in MATH_OPERATORS
1664         and leaf.parent
1665         and leaf.parent.type not in {syms.factor, syms.star_expr}
1666     ):
1667         return MATH_PRIORITIES[leaf.type]
1668
1669     if leaf.type in COMPARATORS:
1670         return COMPARATOR_PRIORITY
1671
1672     if (
1673         leaf.type == token.STRING
1674         and previous is not None
1675         and previous.type == token.STRING
1676     ):
1677         return STRING_PRIORITY
1678
1679     if (
1680         leaf.type == token.NAME
1681         and leaf.value == "for"
1682         and leaf.parent
1683         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1684     ):
1685         return COMPREHENSION_PRIORITY
1686
1687     if (
1688         leaf.type == token.NAME
1689         and leaf.value == "if"
1690         and leaf.parent
1691         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1692     ):
1693         return COMPREHENSION_PRIORITY
1694
1695     if (
1696         leaf.type == token.NAME
1697         and leaf.value in {"if", "else"}
1698         and leaf.parent
1699         and leaf.parent.type == syms.test
1700     ):
1701         return TERNARY_PRIORITY
1702
1703     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1704         return LOGIC_PRIORITY
1705
1706     return 0
1707
1708
1709 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1710     """Clean the prefix of the `leaf` and generate comments from it, if any.
1711
1712     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1713     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1714     move because it does away with modifying the grammar to include all the
1715     possible places in which comments can be placed.
1716
1717     The sad consequence for us though is that comments don't "belong" anywhere.
1718     This is why this function generates simple parentless Leaf objects for
1719     comments.  We simply don't know what the correct parent should be.
1720
1721     No matter though, we can live without this.  We really only need to
1722     differentiate between inline and standalone comments.  The latter don't
1723     share the line with any code.
1724
1725     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1726     are emitted with a fake STANDALONE_COMMENT token identifier.
1727     """
1728     p = leaf.prefix
1729     if not p:
1730         return
1731
1732     if "#" not in p:
1733         return
1734
1735     consumed = 0
1736     nlines = 0
1737     for index, line in enumerate(p.split("\n")):
1738         consumed += len(line) + 1  # adding the length of the split '\n'
1739         line = line.lstrip()
1740         if not line:
1741             nlines += 1
1742         if not line.startswith("#"):
1743             continue
1744
1745         if index == 0 and leaf.type != token.ENDMARKER:
1746             comment_type = token.COMMENT  # simple trailing comment
1747         else:
1748             comment_type = STANDALONE_COMMENT
1749         comment = make_comment(line)
1750         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1751
1752         if comment in {"# fmt: on", "# yapf: enable"}:
1753             raise FormatOn(consumed)
1754
1755         if comment in {"# fmt: off", "# yapf: disable"}:
1756             if comment_type == STANDALONE_COMMENT:
1757                 raise FormatOff(consumed)
1758
1759             prev = preceding_leaf(leaf)
1760             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1761                 raise FormatOff(consumed)
1762
1763         nlines = 0
1764
1765
1766 def make_comment(content: str) -> str:
1767     """Return a consistently formatted comment from the given `content` string.
1768
1769     All comments (except for "##", "#!", "#:") should have a single space between
1770     the hash sign and the content.
1771
1772     If `content` didn't start with a hash sign, one is provided.
1773     """
1774     content = content.rstrip()
1775     if not content:
1776         return "#"
1777
1778     if content[0] == "#":
1779         content = content[1:]
1780     if content and content[0] not in " !:#":
1781         content = " " + content
1782     return "#" + content
1783
1784
1785 def split_line(
1786     line: Line, line_length: int, inner: bool = False, py36: bool = False
1787 ) -> Iterator[Line]:
1788     """Split a `line` into potentially many lines.
1789
1790     They should fit in the allotted `line_length` but might not be able to.
1791     `inner` signifies that there were a pair of brackets somewhere around the
1792     current `line`, possibly transitively. This means we can fallback to splitting
1793     by delimiters if the LHS/RHS don't yield any results.
1794
1795     If `py36` is True, splitting may generate syntax that is only compatible
1796     with Python 3.6 and later.
1797     """
1798     if isinstance(line, UnformattedLines) or line.is_comment:
1799         yield line
1800         return
1801
1802     line_str = str(line).strip("\n")
1803     if (
1804         len(line_str) <= line_length
1805         and "\n" not in line_str  # multiline strings
1806         and not line.contains_standalone_comments()
1807     ):
1808         yield line
1809         return
1810
1811     split_funcs: List[SplitFunc]
1812     if line.is_def:
1813         split_funcs = [left_hand_split]
1814     elif line.is_import:
1815         split_funcs = [explode_split]
1816     elif line.inside_brackets:
1817         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1818     else:
1819         split_funcs = [right_hand_split]
1820     for split_func in split_funcs:
1821         # We are accumulating lines in `result` because we might want to abort
1822         # mission and return the original line in the end, or attempt a different
1823         # split altogether.
1824         result: List[Line] = []
1825         try:
1826             for l in split_func(line, py36):
1827                 if str(l).strip("\n") == line_str:
1828                     raise CannotSplit("Split function returned an unchanged result")
1829
1830                 result.extend(
1831                     split_line(l, line_length=line_length, inner=True, py36=py36)
1832                 )
1833         except CannotSplit as cs:
1834             continue
1835
1836         else:
1837             yield from result
1838             break
1839
1840     else:
1841         yield line
1842
1843
1844 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1845     """Split line into many lines, starting with the first matching bracket pair.
1846
1847     Note: this usually looks weird, only use this for function definitions.
1848     Prefer RHS otherwise.  This is why this function is not symmetrical with
1849     :func:`right_hand_split` which also handles optional parentheses.
1850     """
1851     head = Line(depth=line.depth)
1852     body = Line(depth=line.depth + 1, inside_brackets=True)
1853     tail = Line(depth=line.depth)
1854     tail_leaves: List[Leaf] = []
1855     body_leaves: List[Leaf] = []
1856     head_leaves: List[Leaf] = []
1857     current_leaves = head_leaves
1858     matching_bracket = None
1859     for leaf in line.leaves:
1860         if (
1861             current_leaves is body_leaves
1862             and leaf.type in CLOSING_BRACKETS
1863             and leaf.opening_bracket is matching_bracket
1864         ):
1865             current_leaves = tail_leaves if body_leaves else head_leaves
1866         current_leaves.append(leaf)
1867         if current_leaves is head_leaves:
1868             if leaf.type in OPENING_BRACKETS:
1869                 matching_bracket = leaf
1870                 current_leaves = body_leaves
1871     # Since body is a new indent level, remove spurious leading whitespace.
1872     if body_leaves:
1873         normalize_prefix(body_leaves[0], inside_brackets=True)
1874     # Build the new lines.
1875     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1876         for leaf in leaves:
1877             result.append(leaf, preformatted=True)
1878             for comment_after in line.comments_after(leaf):
1879                 result.append(comment_after, preformatted=True)
1880     bracket_split_succeeded_or_raise(head, body, tail)
1881     for result in (head, body, tail):
1882         if result:
1883             yield result
1884
1885
1886 def right_hand_split(
1887     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1888 ) -> Iterator[Line]:
1889     """Split line into many lines, starting with the last matching bracket pair.
1890
1891     If the split was by optional parentheses, attempt splitting without them, too.
1892     """
1893     head = Line(depth=line.depth)
1894     body = Line(depth=line.depth + 1, inside_brackets=True)
1895     tail = Line(depth=line.depth)
1896     tail_leaves: List[Leaf] = []
1897     body_leaves: List[Leaf] = []
1898     head_leaves: List[Leaf] = []
1899     current_leaves = tail_leaves
1900     opening_bracket = None
1901     closing_bracket = None
1902     for leaf in reversed(line.leaves):
1903         if current_leaves is body_leaves:
1904             if leaf is opening_bracket:
1905                 current_leaves = head_leaves if body_leaves else tail_leaves
1906         current_leaves.append(leaf)
1907         if current_leaves is tail_leaves:
1908             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1909                 opening_bracket = leaf.opening_bracket
1910                 closing_bracket = leaf
1911                 current_leaves = body_leaves
1912     tail_leaves.reverse()
1913     body_leaves.reverse()
1914     head_leaves.reverse()
1915     # Since body is a new indent level, remove spurious leading whitespace.
1916     if body_leaves:
1917         normalize_prefix(body_leaves[0], inside_brackets=True)
1918     elif not head_leaves:
1919         # No `head` and no `body` means the split failed. `tail` has all content.
1920         raise CannotSplit("No brackets found")
1921
1922     # Build the new lines.
1923     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1924         for leaf in leaves:
1925             result.append(leaf, preformatted=True)
1926             for comment_after in line.comments_after(leaf):
1927                 result.append(comment_after, preformatted=True)
1928     bracket_split_succeeded_or_raise(head, body, tail)
1929     assert opening_bracket and closing_bracket
1930     if (
1931         # the opening bracket is an optional paren
1932         opening_bracket.type == token.LPAR
1933         and not opening_bracket.value
1934         # the closing bracket is an optional paren
1935         and closing_bracket.type == token.RPAR
1936         and not closing_bracket.value
1937         # there are no delimiters or standalone comments in the body
1938         and not body.bracket_tracker.delimiters
1939         and not line.contains_standalone_comments(0)
1940         # and it's not an import (optional parens are the only thing we can split
1941         # on in this case; attempting a split without them is a waste of time)
1942         and not line.is_import
1943     ):
1944         omit = {id(closing_bracket), *omit}
1945         try:
1946             yield from right_hand_split(line, py36=py36, omit=omit)
1947             return
1948         except CannotSplit:
1949             pass
1950
1951     ensure_visible(opening_bracket)
1952     ensure_visible(closing_bracket)
1953     for result in (head, body, tail):
1954         if result:
1955             yield result
1956
1957
1958 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1959     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1960
1961     Do nothing otherwise.
1962
1963     A left- or right-hand split is based on a pair of brackets. Content before
1964     (and including) the opening bracket is left on one line, content inside the
1965     brackets is put on a separate line, and finally content starting with and
1966     following the closing bracket is put on a separate line.
1967
1968     Those are called `head`, `body`, and `tail`, respectively. If the split
1969     produced the same line (all content in `head`) or ended up with an empty `body`
1970     and the `tail` is just the closing bracket, then it's considered failed.
1971     """
1972     tail_len = len(str(tail).strip())
1973     if not body:
1974         if tail_len == 0:
1975             raise CannotSplit("Splitting brackets produced the same line")
1976
1977         elif tail_len < 3:
1978             raise CannotSplit(
1979                 f"Splitting brackets on an empty body to save "
1980                 f"{tail_len} characters is not worth it"
1981             )
1982
1983
1984 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1985     """Normalize prefix of the first leaf in every line returned by `split_func`.
1986
1987     This is a decorator over relevant split functions.
1988     """
1989
1990     @wraps(split_func)
1991     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1992         for l in split_func(line, py36):
1993             normalize_prefix(l.leaves[0], inside_brackets=True)
1994             yield l
1995
1996     return split_wrapper
1997
1998
1999 @dont_increase_indentation
2000 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2001     """Split according to delimiters of the highest priority.
2002
2003     If `py36` is True, the split will add trailing commas also in function
2004     signatures that contain `*` and `**`.
2005     """
2006     try:
2007         last_leaf = line.leaves[-1]
2008     except IndexError:
2009         raise CannotSplit("Line empty")
2010
2011     delimiters = line.bracket_tracker.delimiters
2012     try:
2013         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
2014             exclude={id(last_leaf)}
2015         )
2016     except ValueError:
2017         raise CannotSplit("No delimiters found")
2018
2019     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2020     lowest_depth = sys.maxsize
2021     trailing_comma_safe = True
2022
2023     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2024         """Append `leaf` to current line or to new line if appending impossible."""
2025         nonlocal current_line
2026         try:
2027             current_line.append_safe(leaf, preformatted=True)
2028         except ValueError as ve:
2029             yield current_line
2030
2031             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2032             current_line.append(leaf)
2033
2034     for leaf in line.leaves:
2035         yield from append_to_line(leaf)
2036
2037         for comment_after in line.comments_after(leaf):
2038             yield from append_to_line(comment_after)
2039
2040         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2041         if (
2042             leaf.bracket_depth == lowest_depth
2043             and is_vararg(leaf, within=VARARGS_PARENTS)
2044         ):
2045             trailing_comma_safe = trailing_comma_safe and py36
2046         leaf_priority = delimiters.get(id(leaf))
2047         if leaf_priority == delimiter_priority:
2048             yield current_line
2049
2050             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2051     if current_line:
2052         if (
2053             trailing_comma_safe
2054             and delimiter_priority == COMMA_PRIORITY
2055             and current_line.leaves[-1].type != token.COMMA
2056             and current_line.leaves[-1].type != STANDALONE_COMMENT
2057         ):
2058             current_line.append(Leaf(token.COMMA, ","))
2059         yield current_line
2060
2061
2062 @dont_increase_indentation
2063 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2064     """Split standalone comments from the rest of the line."""
2065     if not line.contains_standalone_comments(0):
2066         raise CannotSplit("Line does not have any standalone comments")
2067
2068     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2069
2070     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2071         """Append `leaf` to current line or to new line if appending impossible."""
2072         nonlocal current_line
2073         try:
2074             current_line.append_safe(leaf, preformatted=True)
2075         except ValueError as ve:
2076             yield current_line
2077
2078             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2079             current_line.append(leaf)
2080
2081     for leaf in line.leaves:
2082         yield from append_to_line(leaf)
2083
2084         for comment_after in line.comments_after(leaf):
2085             yield from append_to_line(comment_after)
2086
2087     if current_line:
2088         yield current_line
2089
2090
2091 def explode_split(
2092     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
2093 ) -> Iterator[Line]:
2094     """Split by rightmost bracket and immediately split contents by a delimiter."""
2095     new_lines = list(right_hand_split(line, py36, omit))
2096     if len(new_lines) != 3:
2097         yield from new_lines
2098         return
2099
2100     yield new_lines[0]
2101
2102     try:
2103         yield from delimiter_split(new_lines[1], py36)
2104
2105     except CannotSplit:
2106         yield new_lines[1]
2107
2108     yield new_lines[2]
2109
2110
2111 def is_import(leaf: Leaf) -> bool:
2112     """Return True if the given leaf starts an import statement."""
2113     p = leaf.parent
2114     t = leaf.type
2115     v = leaf.value
2116     return bool(
2117         t == token.NAME
2118         and (
2119             (v == "import" and p and p.type == syms.import_name)
2120             or (v == "from" and p and p.type == syms.import_from)
2121         )
2122     )
2123
2124
2125 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2126     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2127     else.
2128
2129     Note: don't use backslashes for formatting or you'll lose your voting rights.
2130     """
2131     if not inside_brackets:
2132         spl = leaf.prefix.split("#")
2133         if "\\" not in spl[0]:
2134             nl_count = spl[-1].count("\n")
2135             if len(spl) > 1:
2136                 nl_count -= 1
2137             leaf.prefix = "\n" * nl_count
2138             return
2139
2140     leaf.prefix = ""
2141
2142
2143 def normalize_string_quotes(leaf: Leaf) -> None:
2144     """Prefer double quotes but only if it doesn't cause more escaping.
2145
2146     Adds or removes backslashes as appropriate. Doesn't parse and fix
2147     strings nested in f-strings (yet).
2148
2149     Note: Mutates its argument.
2150     """
2151     value = leaf.value.lstrip("furbFURB")
2152     if value[:3] == '"""':
2153         return
2154
2155     elif value[:3] == "'''":
2156         orig_quote = "'''"
2157         new_quote = '"""'
2158     elif value[0] == '"':
2159         orig_quote = '"'
2160         new_quote = "'"
2161     else:
2162         orig_quote = "'"
2163         new_quote = '"'
2164     first_quote_pos = leaf.value.find(orig_quote)
2165     if first_quote_pos == -1:
2166         return  # There's an internal error
2167
2168     prefix = leaf.value[:first_quote_pos]
2169     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2170     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2171     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2172     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2173     if "r" in prefix.casefold():
2174         if unescaped_new_quote.search(body):
2175             # There's at least one unescaped new_quote in this raw string
2176             # so converting is impossible
2177             return
2178
2179         # Do not introduce or remove backslashes in raw strings
2180         new_body = body
2181     else:
2182         # remove unnecessary quotes
2183         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2184         if body != new_body:
2185             # Consider the string without unnecessary quotes as the original
2186             body = new_body
2187             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2188         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2189         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2190     if new_quote == '"""' and new_body[-1] == '"':
2191         # edge case:
2192         new_body = new_body[:-1] + '\\"'
2193     orig_escape_count = body.count("\\")
2194     new_escape_count = new_body.count("\\")
2195     if new_escape_count > orig_escape_count:
2196         return  # Do not introduce more escaping
2197
2198     if new_escape_count == orig_escape_count and orig_quote == '"':
2199         return  # Prefer double quotes
2200
2201     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2202
2203
2204 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2205     """Make existing optional parentheses invisible or create new ones.
2206
2207     `parens_after` is a set of string leaf values immeditely after which parens
2208     should be put.
2209
2210     Standardizes on visible parentheses for single-element tuples, and keeps
2211     existing visible parentheses for other tuples and generator expressions.
2212     """
2213     check_lpar = False
2214     for child in list(node.children):
2215         if check_lpar:
2216             if child.type == syms.atom:
2217                 maybe_make_parens_invisible_in_atom(child)
2218             elif is_one_tuple(child):
2219                 # wrap child in visible parentheses
2220                 lpar = Leaf(token.LPAR, "(")
2221                 rpar = Leaf(token.RPAR, ")")
2222                 index = child.remove() or 0
2223                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2224             else:
2225                 # wrap child in invisible parentheses
2226                 lpar = Leaf(token.LPAR, "")
2227                 rpar = Leaf(token.RPAR, "")
2228                 index = child.remove() or 0
2229                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2230
2231         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2232
2233
2234 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2235     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2236     if (
2237         node.type != syms.atom
2238         or is_empty_tuple(node)
2239         or is_one_tuple(node)
2240         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2241     ):
2242         return False
2243
2244     first = node.children[0]
2245     last = node.children[-1]
2246     if first.type == token.LPAR and last.type == token.RPAR:
2247         # make parentheses invisible
2248         first.value = ""  # type: ignore
2249         last.value = ""  # type: ignore
2250         if len(node.children) > 1:
2251             maybe_make_parens_invisible_in_atom(node.children[1])
2252         return True
2253
2254     return False
2255
2256
2257 def is_empty_tuple(node: LN) -> bool:
2258     """Return True if `node` holds an empty tuple."""
2259     return (
2260         node.type == syms.atom
2261         and len(node.children) == 2
2262         and node.children[0].type == token.LPAR
2263         and node.children[1].type == token.RPAR
2264     )
2265
2266
2267 def is_one_tuple(node: LN) -> bool:
2268     """Return True if `node` holds a tuple with one element, with or without parens."""
2269     if node.type == syms.atom:
2270         if len(node.children) != 3:
2271             return False
2272
2273         lpar, gexp, rpar = node.children
2274         if not (
2275             lpar.type == token.LPAR
2276             and gexp.type == syms.testlist_gexp
2277             and rpar.type == token.RPAR
2278         ):
2279             return False
2280
2281         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2282
2283     return (
2284         node.type in IMPLICIT_TUPLE
2285         and len(node.children) == 2
2286         and node.children[1].type == token.COMMA
2287     )
2288
2289
2290 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2291     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2292
2293     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2294     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2295     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2296     generalizations (PEP 448).
2297     """
2298     if leaf.type not in STARS or not leaf.parent:
2299         return False
2300
2301     p = leaf.parent
2302     if p.type == syms.star_expr:
2303         # Star expressions are also used as assignment targets in extended
2304         # iterable unpacking (PEP 3132).  See what its parent is instead.
2305         if not p.parent:
2306             return False
2307
2308         p = p.parent
2309
2310     return p.type in within
2311
2312
2313 def max_delimiter_priority_in_atom(node: LN) -> int:
2314     """Return maximum delimiter priority inside `node`.
2315
2316     This is specific to atoms with contents contained in a pair of parentheses.
2317     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2318     """
2319     if node.type != syms.atom:
2320         return 0
2321
2322     first = node.children[0]
2323     last = node.children[-1]
2324     if not (first.type == token.LPAR and last.type == token.RPAR):
2325         return 0
2326
2327     bt = BracketTracker()
2328     for c in node.children[1:-1]:
2329         if isinstance(c, Leaf):
2330             bt.mark(c)
2331         else:
2332             for leaf in c.leaves():
2333                 bt.mark(leaf)
2334     try:
2335         return bt.max_delimiter_priority()
2336
2337     except ValueError:
2338         return 0
2339
2340
2341 def ensure_visible(leaf: Leaf) -> None:
2342     """Make sure parentheses are visible.
2343
2344     They could be invisible as part of some statements (see
2345     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2346     """
2347     if leaf.type == token.LPAR:
2348         leaf.value = "("
2349     elif leaf.type == token.RPAR:
2350         leaf.value = ")"
2351
2352
2353 def is_python36(node: Node) -> bool:
2354     """Return True if the current file is using Python 3.6+ features.
2355
2356     Currently looking for:
2357     - f-strings; and
2358     - trailing commas after * or ** in function signatures and calls.
2359     """
2360     for n in node.pre_order():
2361         if n.type == token.STRING:
2362             value_head = n.value[:2]  # type: ignore
2363             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2364                 return True
2365
2366         elif (
2367             n.type in {syms.typedargslist, syms.arglist}
2368             and n.children
2369             and n.children[-1].type == token.COMMA
2370         ):
2371             for ch in n.children:
2372                 if ch.type in STARS:
2373                     return True
2374
2375                 if ch.type == syms.argument:
2376                     for argch in ch.children:
2377                         if argch.type in STARS:
2378                             return True
2379
2380     return False
2381
2382
2383 PYTHON_EXTENSIONS = {".py"}
2384 BLACKLISTED_DIRECTORIES = {
2385     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2386 }
2387
2388
2389 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2390     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2391     and have one of the PYTHON_EXTENSIONS.
2392     """
2393     for child in path.iterdir():
2394         if child.is_dir():
2395             if child.name in BLACKLISTED_DIRECTORIES:
2396                 continue
2397
2398             yield from gen_python_files_in_dir(child)
2399
2400         elif child.suffix in PYTHON_EXTENSIONS:
2401             yield child
2402
2403
2404 @dataclass
2405 class Report:
2406     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2407     check: bool = False
2408     quiet: bool = False
2409     change_count: int = 0
2410     same_count: int = 0
2411     failure_count: int = 0
2412
2413     def done(self, src: Path, changed: Changed) -> None:
2414         """Increment the counter for successful reformatting. Write out a message."""
2415         if changed is Changed.YES:
2416             reformatted = "would reformat" if self.check else "reformatted"
2417             if not self.quiet:
2418                 out(f"{reformatted} {src}")
2419             self.change_count += 1
2420         else:
2421             if not self.quiet:
2422                 if changed is Changed.NO:
2423                     msg = f"{src} already well formatted, good job."
2424                 else:
2425                     msg = f"{src} wasn't modified on disk since last run."
2426                 out(msg, bold=False)
2427             self.same_count += 1
2428
2429     def failed(self, src: Path, message: str) -> None:
2430         """Increment the counter for failed reformatting. Write out a message."""
2431         err(f"error: cannot format {src}: {message}")
2432         self.failure_count += 1
2433
2434     @property
2435     def return_code(self) -> int:
2436         """Return the exit code that the app should use.
2437
2438         This considers the current state of changed files and failures:
2439         - if there were any failures, return 123;
2440         - if any files were changed and --check is being used, return 1;
2441         - otherwise return 0.
2442         """
2443         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2444         # 126 we have special returncodes reserved by the shell.
2445         if self.failure_count:
2446             return 123
2447
2448         elif self.change_count and self.check:
2449             return 1
2450
2451         return 0
2452
2453     def __str__(self) -> str:
2454         """Render a color report of the current state.
2455
2456         Use `click.unstyle` to remove colors.
2457         """
2458         if self.check:
2459             reformatted = "would be reformatted"
2460             unchanged = "would be left unchanged"
2461             failed = "would fail to reformat"
2462         else:
2463             reformatted = "reformatted"
2464             unchanged = "left unchanged"
2465             failed = "failed to reformat"
2466         report = []
2467         if self.change_count:
2468             s = "s" if self.change_count > 1 else ""
2469             report.append(
2470                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2471             )
2472         if self.same_count:
2473             s = "s" if self.same_count > 1 else ""
2474             report.append(f"{self.same_count} file{s} {unchanged}")
2475         if self.failure_count:
2476             s = "s" if self.failure_count > 1 else ""
2477             report.append(
2478                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2479             )
2480         return ", ".join(report) + "."
2481
2482
2483 def assert_equivalent(src: str, dst: str) -> None:
2484     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2485
2486     import ast
2487     import traceback
2488
2489     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2490         """Simple visitor generating strings to compare ASTs by content."""
2491         yield f"{'  ' * depth}{node.__class__.__name__}("
2492
2493         for field in sorted(node._fields):
2494             try:
2495                 value = getattr(node, field)
2496             except AttributeError:
2497                 continue
2498
2499             yield f"{'  ' * (depth+1)}{field}="
2500
2501             if isinstance(value, list):
2502                 for item in value:
2503                     if isinstance(item, ast.AST):
2504                         yield from _v(item, depth + 2)
2505
2506             elif isinstance(value, ast.AST):
2507                 yield from _v(value, depth + 2)
2508
2509             else:
2510                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2511
2512         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2513
2514     try:
2515         src_ast = ast.parse(src)
2516     except Exception as exc:
2517         major, minor = sys.version_info[:2]
2518         raise AssertionError(
2519             f"cannot use --safe with this file; failed to parse source file "
2520             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2521             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2522         )
2523
2524     try:
2525         dst_ast = ast.parse(dst)
2526     except Exception as exc:
2527         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2528         raise AssertionError(
2529             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2530             f"Please report a bug on https://github.com/ambv/black/issues.  "
2531             f"This invalid output might be helpful: {log}"
2532         ) from None
2533
2534     src_ast_str = "\n".join(_v(src_ast))
2535     dst_ast_str = "\n".join(_v(dst_ast))
2536     if src_ast_str != dst_ast_str:
2537         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2538         raise AssertionError(
2539             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2540             f"the source.  "
2541             f"Please report a bug on https://github.com/ambv/black/issues.  "
2542             f"This diff might be helpful: {log}"
2543         ) from None
2544
2545
2546 def assert_stable(src: str, dst: str, line_length: int) -> None:
2547     """Raise AssertionError if `dst` reformats differently the second time."""
2548     newdst = format_str(dst, line_length=line_length)
2549     if dst != newdst:
2550         log = dump_to_file(
2551             diff(src, dst, "source", "first pass"),
2552             diff(dst, newdst, "first pass", "second pass"),
2553         )
2554         raise AssertionError(
2555             f"INTERNAL ERROR: Black produced different code on the second pass "
2556             f"of the formatter.  "
2557             f"Please report a bug on https://github.com/ambv/black/issues.  "
2558             f"This diff might be helpful: {log}"
2559         ) from None
2560
2561
2562 def dump_to_file(*output: str) -> str:
2563     """Dump `output` to a temporary file. Return path to the file."""
2564     import tempfile
2565
2566     with tempfile.NamedTemporaryFile(
2567         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2568     ) as f:
2569         for lines in output:
2570             f.write(lines)
2571             if lines and lines[-1] != "\n":
2572                 f.write("\n")
2573     return f.name
2574
2575
2576 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2577     """Return a unified diff string between strings `a` and `b`."""
2578     import difflib
2579
2580     a_lines = [line + "\n" for line in a.split("\n")]
2581     b_lines = [line + "\n" for line in b.split("\n")]
2582     return "".join(
2583         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2584     )
2585
2586
2587 def cancel(tasks: List[asyncio.Task]) -> None:
2588     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2589     err("Aborted!")
2590     for task in tasks:
2591         task.cancel()
2592
2593
2594 def shutdown(loop: BaseEventLoop) -> None:
2595     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2596     try:
2597         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2598         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2599         if not to_cancel:
2600             return
2601
2602         for task in to_cancel:
2603             task.cancel()
2604         loop.run_until_complete(
2605             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2606         )
2607     finally:
2608         # `concurrent.futures.Future` objects cannot be cancelled once they
2609         # are already running. There might be some when the `shutdown()` happened.
2610         # Silence their logger's spew about the event loop being closed.
2611         cf_logger = logging.getLogger("concurrent.futures")
2612         cf_logger.setLevel(logging.CRITICAL)
2613         loop.close()
2614
2615
2616 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2617     """Replace `regex` with `replacement` twice on `original`.
2618
2619     This is used by string normalization to perform replaces on
2620     overlapping matches.
2621     """
2622     return regex.sub(replacement, regex.sub(replacement, original))
2623
2624
2625 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2626
2627
2628 def get_cache_file(line_length: int) -> Path:
2629     return CACHE_DIR / f"cache.{line_length}.pickle"
2630
2631
2632 def read_cache(line_length: int) -> Cache:
2633     """Read the cache if it exists and is well formed.
2634
2635     If it is not well formed, the call to write_cache later should resolve the issue.
2636     """
2637     cache_file = get_cache_file(line_length)
2638     if not cache_file.exists():
2639         return {}
2640
2641     with cache_file.open("rb") as fobj:
2642         try:
2643             cache: Cache = pickle.load(fobj)
2644         except pickle.UnpicklingError:
2645             return {}
2646
2647     return cache
2648
2649
2650 def get_cache_info(path: Path) -> CacheInfo:
2651     """Return the information used to check if a file is already formatted or not."""
2652     stat = path.stat()
2653     return stat.st_mtime, stat.st_size
2654
2655
2656 def filter_cached(
2657     cache: Cache, sources: Iterable[Path]
2658 ) -> Tuple[List[Path], List[Path]]:
2659     """Split a list of paths into two.
2660
2661     The first list contains paths of files that modified on disk or are not in the
2662     cache. The other list contains paths to non-modified files.
2663     """
2664     todo, done = [], []
2665     for src in sources:
2666         src = src.resolve()
2667         if cache.get(src) != get_cache_info(src):
2668             todo.append(src)
2669         else:
2670             done.append(src)
2671     return todo, done
2672
2673
2674 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2675     """Update the cache file."""
2676     cache_file = get_cache_file(line_length)
2677     try:
2678         if not CACHE_DIR.exists():
2679             CACHE_DIR.mkdir(parents=True)
2680         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2681         with cache_file.open("wb") as fobj:
2682             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2683     except OSError:
2684         pass
2685
2686
2687 if __name__ == "__main__":
2688     main()