]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add instructions for running Black on save in Vim (#255)
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33     cast,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46
47 __version__ = "18.5b0"
48 DEFAULT_LINE_LENGTH = 88
49 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
50
51
52 # types
53 FileContent = str
54 Encoding = str
55 Depth = int
56 NodeType = int
57 LeafID = int
58 Priority = int
59 Index = int
60 LN = Union[Leaf, Node]
61 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
62 Timestamp = float
63 FileSize = int
64 CacheInfo = Tuple[Timestamp, FileSize]
65 Cache = Dict[Path, CacheInfo]
66 out = partial(click.secho, bold=True, err=True)
67 err = partial(click.secho, fg="red", err=True)
68
69 pygram.initialize(CACHE_DIR)
70 syms = pygram.python_symbols
71
72
73 class NothingChanged(UserWarning):
74     """Raised by :func:`format_file` when reformatted code is the same as source."""
75
76
77 class CannotSplit(Exception):
78     """A readable split that fits the allotted line length is impossible.
79
80     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
81     :func:`delimiter_split`.
82     """
83
84
85 class FormatError(Exception):
86     """Base exception for `# fmt: on` and `# fmt: off` handling.
87
88     It holds the number of bytes of the prefix consumed before the format
89     control comment appeared.
90     """
91
92     def __init__(self, consumed: int) -> None:
93         super().__init__(consumed)
94         self.consumed = consumed
95
96     def trim_prefix(self, leaf: Leaf) -> None:
97         leaf.prefix = leaf.prefix[self.consumed :]
98
99     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
100         """Returns a new Leaf from the consumed part of the prefix."""
101         unformatted_prefix = leaf.prefix[: self.consumed]
102         return Leaf(token.NEWLINE, unformatted_prefix)
103
104
105 class FormatOn(FormatError):
106     """Found a comment like `# fmt: on` in the file."""
107
108
109 class FormatOff(FormatError):
110     """Found a comment like `# fmt: off` in the file."""
111
112
113 class WriteBack(Enum):
114     NO = 0
115     YES = 1
116     DIFF = 2
117
118
119 class Changed(Enum):
120     NO = 0
121     CACHED = 1
122     YES = 2
123
124
125 @click.command()
126 @click.option(
127     "-l",
128     "--line-length",
129     type=int,
130     default=DEFAULT_LINE_LENGTH,
131     help="How many character per line to allow.",
132     show_default=True,
133 )
134 @click.option(
135     "--check",
136     is_flag=True,
137     help=(
138         "Don't write the files back, just return the status.  Return code 0 "
139         "means nothing would change.  Return code 1 means some files would be "
140         "reformatted.  Return code 123 means there was an internal error."
141     ),
142 )
143 @click.option(
144     "--diff",
145     is_flag=True,
146     help="Don't write the files back, just output a diff for each file on stdout.",
147 )
148 @click.option(
149     "--fast/--safe",
150     is_flag=True,
151     help="If --fast given, skip temporary sanity checks. [default: --safe]",
152 )
153 @click.option(
154     "-q",
155     "--quiet",
156     is_flag=True,
157     help=(
158         "Don't emit non-error messages to stderr. Errors are still emitted, "
159         "silence those with 2>/dev/null."
160     ),
161 )
162 @click.option(
163     "--pyi",
164     is_flag=True,
165     help=(
166         "Consider all input files typing stubs regardless of file extension "
167         "(useful when piping source on standard input)."
168     ),
169 )
170 @click.option(
171     "--py36",
172     is_flag=True,
173     help=(
174         "Allow using Python 3.6-only syntax on all input files.  This will put "
175         "trailing commas in function signatures and calls also after *args and "
176         "**kwargs.  [default: per-file auto-detection]"
177     ),
178 )
179 @click.version_option(version=__version__)
180 @click.argument(
181     "src",
182     nargs=-1,
183     type=click.Path(
184         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
185     ),
186 )
187 @click.pass_context
188 def main(
189     ctx: click.Context,
190     line_length: int,
191     check: bool,
192     diff: bool,
193     fast: bool,
194     pyi: bool,
195     py36: bool,
196     quiet: bool,
197     src: List[str],
198 ) -> None:
199     """The uncompromising code formatter."""
200     sources: List[Path] = []
201     for s in src:
202         p = Path(s)
203         if p.is_dir():
204             sources.extend(gen_python_files_in_dir(p))
205         elif p.is_file():
206             # if a file was explicitly given, we don't care about its extension
207             sources.append(p)
208         elif s == "-":
209             sources.append(Path("-"))
210         else:
211             err(f"invalid path: {s}")
212
213     if check and not diff:
214         write_back = WriteBack.NO
215     elif diff:
216         write_back = WriteBack.DIFF
217     else:
218         write_back = WriteBack.YES
219     report = Report(check=check, quiet=quiet)
220     if len(sources) == 0:
221         out("No paths given. Nothing to do 😴")
222         ctx.exit(0)
223         return
224
225     elif len(sources) == 1:
226         reformat_one(
227             src=sources[0],
228             line_length=line_length,
229             fast=fast,
230             pyi=pyi,
231             py36=py36,
232             write_back=write_back,
233             report=report,
234         )
235     else:
236         loop = asyncio.get_event_loop()
237         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
238         try:
239             loop.run_until_complete(
240                 schedule_formatting(
241                     sources=sources,
242                     line_length=line_length,
243                     fast=fast,
244                     pyi=pyi,
245                     py36=py36,
246                     write_back=write_back,
247                     report=report,
248                     loop=loop,
249                     executor=executor,
250                 )
251             )
252         finally:
253             shutdown(loop)
254         if not quiet:
255             out("All done! ✨ 🍰 ✨")
256             click.echo(str(report))
257     ctx.exit(report.return_code)
258
259
260 def reformat_one(
261     src: Path,
262     line_length: int,
263     fast: bool,
264     pyi: bool,
265     py36: bool,
266     write_back: WriteBack,
267     report: "Report",
268 ) -> None:
269     """Reformat a single file under `src` without spawning child processes.
270
271     If `quiet` is True, non-error messages are not output. `line_length`,
272     `write_back`, `fast` and `pyi` options are passed to
273     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
274     """
275     try:
276         changed = Changed.NO
277         if not src.is_file() and str(src) == "-":
278             if format_stdin_to_stdout(
279                 line_length=line_length,
280                 fast=fast,
281                 is_pyi=pyi,
282                 force_py36=py36,
283                 write_back=write_back,
284             ):
285                 changed = Changed.YES
286         else:
287             cache: Cache = {}
288             if write_back != WriteBack.DIFF:
289                 cache = read_cache(line_length, pyi, py36)
290                 src = src.resolve()
291                 if src in cache and cache[src] == get_cache_info(src):
292                     changed = Changed.CACHED
293             if changed is not Changed.CACHED and format_file_in_place(
294                 src,
295                 line_length=line_length,
296                 fast=fast,
297                 force_pyi=pyi,
298                 force_py36=py36,
299                 write_back=write_back,
300             ):
301                 changed = Changed.YES
302             if write_back == WriteBack.YES and changed is not Changed.NO:
303                 write_cache(cache, [src], line_length, pyi, py36)
304         report.done(src, changed)
305     except Exception as exc:
306         report.failed(src, str(exc))
307
308
309 async def schedule_formatting(
310     sources: List[Path],
311     line_length: int,
312     fast: bool,
313     pyi: bool,
314     py36: bool,
315     write_back: WriteBack,
316     report: "Report",
317     loop: BaseEventLoop,
318     executor: Executor,
319 ) -> None:
320     """Run formatting of `sources` in parallel using the provided `executor`.
321
322     (Use ProcessPoolExecutors for actual parallelism.)
323
324     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
325     :func:`format_file_in_place`.
326     """
327     cache: Cache = {}
328     if write_back != WriteBack.DIFF:
329         cache = read_cache(line_length, pyi, py36)
330         sources, cached = filter_cached(cache, sources)
331         for src in cached:
332             report.done(src, Changed.CACHED)
333     cancelled = []
334     formatted = []
335     if sources:
336         lock = None
337         if write_back == WriteBack.DIFF:
338             # For diff output, we need locks to ensure we don't interleave output
339             # from different processes.
340             manager = Manager()
341             lock = manager.Lock()
342         tasks = {
343             loop.run_in_executor(
344                 executor,
345                 format_file_in_place,
346                 src,
347                 line_length,
348                 fast,
349                 pyi,
350                 py36,
351                 write_back,
352                 lock,
353             ): src
354             for src in sorted(sources)
355         }
356         pending: Iterable[asyncio.Task] = tasks.keys()
357         try:
358             loop.add_signal_handler(signal.SIGINT, cancel, pending)
359             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
360         except NotImplementedError:
361             # There are no good alternatives for these on Windows
362             pass
363         while pending:
364             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
365             for task in done:
366                 src = tasks.pop(task)
367                 if task.cancelled():
368                     cancelled.append(task)
369                 elif task.exception():
370                     report.failed(src, str(task.exception()))
371                 else:
372                     formatted.append(src)
373                     report.done(src, Changed.YES if task.result() else Changed.NO)
374     if cancelled:
375         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
376     if write_back == WriteBack.YES and formatted:
377         write_cache(cache, formatted, line_length, pyi, py36)
378
379
380 def format_file_in_place(
381     src: Path,
382     line_length: int,
383     fast: bool,
384     force_pyi: bool = False,
385     force_py36: bool = False,
386     write_back: WriteBack = WriteBack.NO,
387     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
388 ) -> bool:
389     """Format file under `src` path. Return True if changed.
390
391     If `write_back` is True, write reformatted code back to stdout.
392     `line_length` and `fast` options are passed to :func:`format_file_contents`.
393     """
394     is_pyi = force_pyi or src.suffix == ".pyi"
395
396     with tokenize.open(src) as src_buffer:
397         src_contents = src_buffer.read()
398     try:
399         dst_contents = format_file_contents(
400             src_contents,
401             line_length=line_length,
402             fast=fast,
403             is_pyi=is_pyi,
404             force_py36=force_py36,
405         )
406     except NothingChanged:
407         return False
408
409     if write_back == write_back.YES:
410         with open(src, "w", encoding=src_buffer.encoding) as f:
411             f.write(dst_contents)
412     elif write_back == write_back.DIFF:
413         src_name = f"{src}  (original)"
414         dst_name = f"{src}  (formatted)"
415         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
416         if lock:
417             lock.acquire()
418         try:
419             sys.stdout.write(diff_contents)
420         finally:
421             if lock:
422                 lock.release()
423     return True
424
425
426 def format_stdin_to_stdout(
427     line_length: int,
428     fast: bool,
429     is_pyi: bool = False,
430     force_py36: bool = False,
431     write_back: WriteBack = WriteBack.NO,
432 ) -> bool:
433     """Format file on stdin. Return True if changed.
434
435     If `write_back` is True, write reformatted code back to stdout.
436     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
437     :func:`format_file_contents`.
438     """
439     src = sys.stdin.read()
440     dst = src
441     try:
442         dst = format_file_contents(
443             src,
444             line_length=line_length,
445             fast=fast,
446             is_pyi=is_pyi,
447             force_py36=force_py36,
448         )
449         return True
450
451     except NothingChanged:
452         return False
453
454     finally:
455         if write_back == WriteBack.YES:
456             sys.stdout.write(dst)
457         elif write_back == WriteBack.DIFF:
458             src_name = "<stdin>  (original)"
459             dst_name = "<stdin>  (formatted)"
460             sys.stdout.write(diff(src, dst, src_name, dst_name))
461
462
463 def format_file_contents(
464     src_contents: str,
465     *,
466     line_length: int,
467     fast: bool,
468     is_pyi: bool = False,
469     force_py36: bool = False,
470 ) -> FileContent:
471     """Reformat contents a file and return new contents.
472
473     If `fast` is False, additionally confirm that the reformatted code is
474     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
475     `line_length` is passed to :func:`format_str`.
476     """
477     if src_contents.strip() == "":
478         raise NothingChanged
479
480     dst_contents = format_str(
481         src_contents, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
482     )
483     if src_contents == dst_contents:
484         raise NothingChanged
485
486     if not fast:
487         assert_equivalent(src_contents, dst_contents)
488         assert_stable(
489             src_contents,
490             dst_contents,
491             line_length=line_length,
492             is_pyi=is_pyi,
493             force_py36=force_py36,
494         )
495     return dst_contents
496
497
498 def format_str(
499     src_contents: str,
500     line_length: int,
501     *,
502     is_pyi: bool = False,
503     force_py36: bool = False,
504 ) -> FileContent:
505     """Reformat a string and return new contents.
506
507     `line_length` determines how many characters per line are allowed.
508     """
509     src_node = lib2to3_parse(src_contents)
510     dst_contents = ""
511     future_imports = get_future_imports(src_node)
512     elt = EmptyLineTracker(is_pyi=is_pyi)
513     py36 = force_py36 or is_python36(src_node)
514     lines = LineGenerator(
515         remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
516     )
517     empty_line = Line()
518     after = 0
519     for current_line in lines.visit(src_node):
520         for _ in range(after):
521             dst_contents += str(empty_line)
522         before, after = elt.maybe_empty_lines(current_line)
523         for _ in range(before):
524             dst_contents += str(empty_line)
525         for line in split_line(current_line, line_length=line_length, py36=py36):
526             dst_contents += str(line)
527     return dst_contents
528
529
530 GRAMMARS = [
531     pygram.python_grammar_no_print_statement_no_exec_statement,
532     pygram.python_grammar_no_print_statement,
533     pygram.python_grammar,
534 ]
535
536
537 def lib2to3_parse(src_txt: str) -> Node:
538     """Given a string with source, return the lib2to3 Node."""
539     grammar = pygram.python_grammar_no_print_statement
540     if src_txt[-1] != "\n":
541         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
542         src_txt += nl
543     for grammar in GRAMMARS:
544         drv = driver.Driver(grammar, pytree.convert)
545         try:
546             result = drv.parse_string(src_txt, True)
547             break
548
549         except ParseError as pe:
550             lineno, column = pe.context[1]
551             lines = src_txt.splitlines()
552             try:
553                 faulty_line = lines[lineno - 1]
554             except IndexError:
555                 faulty_line = "<line number missing in source>"
556             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
557     else:
558         raise exc from None
559
560     if isinstance(result, Leaf):
561         result = Node(syms.file_input, [result])
562     return result
563
564
565 def lib2to3_unparse(node: Node) -> str:
566     """Given a lib2to3 node, return its string representation."""
567     code = str(node)
568     return code
569
570
571 T = TypeVar("T")
572
573
574 class Visitor(Generic[T]):
575     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
576
577     def visit(self, node: LN) -> Iterator[T]:
578         """Main method to visit `node` and its children.
579
580         It tries to find a `visit_*()` method for the given `node.type`, like
581         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
582         If no dedicated `visit_*()` method is found, chooses `visit_default()`
583         instead.
584
585         Then yields objects of type `T` from the selected visitor.
586         """
587         if node.type < 256:
588             name = token.tok_name[node.type]
589         else:
590             name = type_repr(node.type)
591         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
592
593     def visit_default(self, node: LN) -> Iterator[T]:
594         """Default `visit_*()` implementation. Recurses to children of `node`."""
595         if isinstance(node, Node):
596             for child in node.children:
597                 yield from self.visit(child)
598
599
600 @dataclass
601 class DebugVisitor(Visitor[T]):
602     tree_depth: int = 0
603
604     def visit_default(self, node: LN) -> Iterator[T]:
605         indent = " " * (2 * self.tree_depth)
606         if isinstance(node, Node):
607             _type = type_repr(node.type)
608             out(f"{indent}{_type}", fg="yellow")
609             self.tree_depth += 1
610             for child in node.children:
611                 yield from self.visit(child)
612
613             self.tree_depth -= 1
614             out(f"{indent}/{_type}", fg="yellow", bold=False)
615         else:
616             _type = token.tok_name.get(node.type, str(node.type))
617             out(f"{indent}{_type}", fg="blue", nl=False)
618             if node.prefix:
619                 # We don't have to handle prefixes for `Node` objects since
620                 # that delegates to the first child anyway.
621                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
622             out(f" {node.value!r}", fg="blue", bold=False)
623
624     @classmethod
625     def show(cls, code: str) -> None:
626         """Pretty-print the lib2to3 AST of a given string of `code`.
627
628         Convenience method for debugging.
629         """
630         v: DebugVisitor[None] = DebugVisitor()
631         list(v.visit(lib2to3_parse(code)))
632
633
634 KEYWORDS = set(keyword.kwlist)
635 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
636 FLOW_CONTROL = {"return", "raise", "break", "continue"}
637 STATEMENT = {
638     syms.if_stmt,
639     syms.while_stmt,
640     syms.for_stmt,
641     syms.try_stmt,
642     syms.except_clause,
643     syms.with_stmt,
644     syms.funcdef,
645     syms.classdef,
646 }
647 STANDALONE_COMMENT = 153
648 LOGIC_OPERATORS = {"and", "or"}
649 COMPARATORS = {
650     token.LESS,
651     token.GREATER,
652     token.EQEQUAL,
653     token.NOTEQUAL,
654     token.LESSEQUAL,
655     token.GREATEREQUAL,
656 }
657 MATH_OPERATORS = {
658     token.VBAR,
659     token.CIRCUMFLEX,
660     token.AMPER,
661     token.LEFTSHIFT,
662     token.RIGHTSHIFT,
663     token.PLUS,
664     token.MINUS,
665     token.STAR,
666     token.SLASH,
667     token.DOUBLESLASH,
668     token.PERCENT,
669     token.AT,
670     token.TILDE,
671     token.DOUBLESTAR,
672 }
673 STARS = {token.STAR, token.DOUBLESTAR}
674 VARARGS_PARENTS = {
675     syms.arglist,
676     syms.argument,  # double star in arglist
677     syms.trailer,  # single argument to call
678     syms.typedargslist,
679     syms.varargslist,  # lambdas
680 }
681 UNPACKING_PARENTS = {
682     syms.atom,  # single element of a list or set literal
683     syms.dictsetmaker,
684     syms.listmaker,
685     syms.testlist_gexp,
686 }
687 TEST_DESCENDANTS = {
688     syms.test,
689     syms.lambdef,
690     syms.or_test,
691     syms.and_test,
692     syms.not_test,
693     syms.comparison,
694     syms.star_expr,
695     syms.expr,
696     syms.xor_expr,
697     syms.and_expr,
698     syms.shift_expr,
699     syms.arith_expr,
700     syms.trailer,
701     syms.term,
702     syms.power,
703 }
704 ASSIGNMENTS = {
705     "=",
706     "+=",
707     "-=",
708     "*=",
709     "@=",
710     "/=",
711     "%=",
712     "&=",
713     "|=",
714     "^=",
715     "<<=",
716     ">>=",
717     "**=",
718     "//=",
719 }
720 COMPREHENSION_PRIORITY = 20
721 COMMA_PRIORITY = 18
722 TERNARY_PRIORITY = 16
723 LOGIC_PRIORITY = 14
724 STRING_PRIORITY = 12
725 COMPARATOR_PRIORITY = 10
726 MATH_PRIORITIES = {
727     token.VBAR: 9,
728     token.CIRCUMFLEX: 8,
729     token.AMPER: 7,
730     token.LEFTSHIFT: 6,
731     token.RIGHTSHIFT: 6,
732     token.PLUS: 5,
733     token.MINUS: 5,
734     token.STAR: 4,
735     token.SLASH: 4,
736     token.DOUBLESLASH: 4,
737     token.PERCENT: 4,
738     token.AT: 4,
739     token.TILDE: 3,
740     token.DOUBLESTAR: 2,
741 }
742 DOT_PRIORITY = 1
743
744
745 @dataclass
746 class BracketTracker:
747     """Keeps track of brackets on a line."""
748
749     depth: int = 0
750     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
751     delimiters: Dict[LeafID, Priority] = Factory(dict)
752     previous: Optional[Leaf] = None
753     _for_loop_variable: int = 0
754     _lambda_arguments: int = 0
755
756     def mark(self, leaf: Leaf) -> None:
757         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
758
759         All leaves receive an int `bracket_depth` field that stores how deep
760         within brackets a given leaf is. 0 means there are no enclosing brackets
761         that started on this line.
762
763         If a leaf is itself a closing bracket, it receives an `opening_bracket`
764         field that it forms a pair with. This is a one-directional link to
765         avoid reference cycles.
766
767         If a leaf is a delimiter (a token on which Black can split the line if
768         needed) and it's on depth 0, its `id()` is stored in the tracker's
769         `delimiters` field.
770         """
771         if leaf.type == token.COMMENT:
772             return
773
774         self.maybe_decrement_after_for_loop_variable(leaf)
775         self.maybe_decrement_after_lambda_arguments(leaf)
776         if leaf.type in CLOSING_BRACKETS:
777             self.depth -= 1
778             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
779             leaf.opening_bracket = opening_bracket
780         leaf.bracket_depth = self.depth
781         if self.depth == 0:
782             delim = is_split_before_delimiter(leaf, self.previous)
783             if delim and self.previous is not None:
784                 self.delimiters[id(self.previous)] = delim
785             else:
786                 delim = is_split_after_delimiter(leaf, self.previous)
787                 if delim:
788                     self.delimiters[id(leaf)] = delim
789         if leaf.type in OPENING_BRACKETS:
790             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
791             self.depth += 1
792         self.previous = leaf
793         self.maybe_increment_lambda_arguments(leaf)
794         self.maybe_increment_for_loop_variable(leaf)
795
796     def any_open_brackets(self) -> bool:
797         """Return True if there is an yet unmatched open bracket on the line."""
798         return bool(self.bracket_match)
799
800     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
801         """Return the highest priority of a delimiter found on the line.
802
803         Values are consistent with what `is_split_*_delimiter()` return.
804         Raises ValueError on no delimiters.
805         """
806         return max(v for k, v in self.delimiters.items() if k not in exclude)
807
808     def delimiter_count_with_priority(self, priority: int = 0) -> int:
809         """Return the number of delimiters with the given `priority`.
810
811         If no `priority` is passed, defaults to max priority on the line.
812         """
813         if not self.delimiters:
814             return 0
815
816         priority = priority or self.max_delimiter_priority()
817         return sum(1 for p in self.delimiters.values() if p == priority)
818
819     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
820         """In a for loop, or comprehension, the variables are often unpacks.
821
822         To avoid splitting on the comma in this situation, increase the depth of
823         tokens between `for` and `in`.
824         """
825         if leaf.type == token.NAME and leaf.value == "for":
826             self.depth += 1
827             self._for_loop_variable += 1
828             return True
829
830         return False
831
832     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
833         """See `maybe_increment_for_loop_variable` above for explanation."""
834         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
835             self.depth -= 1
836             self._for_loop_variable -= 1
837             return True
838
839         return False
840
841     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
842         """In a lambda expression, there might be more than one argument.
843
844         To avoid splitting on the comma in this situation, increase the depth of
845         tokens between `lambda` and `:`.
846         """
847         if leaf.type == token.NAME and leaf.value == "lambda":
848             self.depth += 1
849             self._lambda_arguments += 1
850             return True
851
852         return False
853
854     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
855         """See `maybe_increment_lambda_arguments` above for explanation."""
856         if self._lambda_arguments and leaf.type == token.COLON:
857             self.depth -= 1
858             self._lambda_arguments -= 1
859             return True
860
861         return False
862
863     def get_open_lsqb(self) -> Optional[Leaf]:
864         """Return the most recent opening square bracket (if any)."""
865         return self.bracket_match.get((self.depth - 1, token.RSQB))
866
867
868 @dataclass
869 class Line:
870     """Holds leaves and comments. Can be printed with `str(line)`."""
871
872     depth: int = 0
873     leaves: List[Leaf] = Factory(list)
874     comments: List[Tuple[Index, Leaf]] = Factory(list)
875     bracket_tracker: BracketTracker = Factory(BracketTracker)
876     inside_brackets: bool = False
877     should_explode: bool = False
878
879     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
880         """Add a new `leaf` to the end of the line.
881
882         Unless `preformatted` is True, the `leaf` will receive a new consistent
883         whitespace prefix and metadata applied by :class:`BracketTracker`.
884         Trailing commas are maybe removed, unpacked for loop variables are
885         demoted from being delimiters.
886
887         Inline comments are put aside.
888         """
889         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
890         if not has_value:
891             return
892
893         if token.COLON == leaf.type and self.is_class_paren_empty:
894             del self.leaves[-2:]
895         if self.leaves and not preformatted:
896             # Note: at this point leaf.prefix should be empty except for
897             # imports, for which we only preserve newlines.
898             leaf.prefix += whitespace(
899                 leaf, complex_subscript=self.is_complex_subscript(leaf)
900             )
901         if self.inside_brackets or not preformatted:
902             self.bracket_tracker.mark(leaf)
903             self.maybe_remove_trailing_comma(leaf)
904         if not self.append_comment(leaf):
905             self.leaves.append(leaf)
906
907     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
908         """Like :func:`append()` but disallow invalid standalone comment structure.
909
910         Raises ValueError when any `leaf` is appended after a standalone comment
911         or when a standalone comment is not the first leaf on the line.
912         """
913         if self.bracket_tracker.depth == 0:
914             if self.is_comment:
915                 raise ValueError("cannot append to standalone comments")
916
917             if self.leaves and leaf.type == STANDALONE_COMMENT:
918                 raise ValueError(
919                     "cannot append standalone comments to a populated line"
920                 )
921
922         self.append(leaf, preformatted=preformatted)
923
924     @property
925     def is_comment(self) -> bool:
926         """Is this line a standalone comment?"""
927         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
928
929     @property
930     def is_decorator(self) -> bool:
931         """Is this line a decorator?"""
932         return bool(self) and self.leaves[0].type == token.AT
933
934     @property
935     def is_import(self) -> bool:
936         """Is this an import line?"""
937         return bool(self) and is_import(self.leaves[0])
938
939     @property
940     def is_class(self) -> bool:
941         """Is this line a class definition?"""
942         return (
943             bool(self)
944             and self.leaves[0].type == token.NAME
945             and self.leaves[0].value == "class"
946         )
947
948     @property
949     def is_stub_class(self) -> bool:
950         """Is this line a class definition with a body consisting only of "..."?"""
951         return self.is_class and self.leaves[-3:] == [
952             Leaf(token.DOT, ".") for _ in range(3)
953         ]
954
955     @property
956     def is_def(self) -> bool:
957         """Is this a function definition? (Also returns True for async defs.)"""
958         try:
959             first_leaf = self.leaves[0]
960         except IndexError:
961             return False
962
963         try:
964             second_leaf: Optional[Leaf] = self.leaves[1]
965         except IndexError:
966             second_leaf = None
967         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
968             first_leaf.type == token.ASYNC
969             and second_leaf is not None
970             and second_leaf.type == token.NAME
971             and second_leaf.value == "def"
972         )
973
974     @property
975     def is_class_paren_empty(self) -> bool:
976         """Is this a class with no base classes but using parentheses?
977
978         Those are unnecessary and should be removed.
979         """
980         return (
981             bool(self)
982             and len(self.leaves) == 4
983             and self.is_class
984             and self.leaves[2].type == token.LPAR
985             and self.leaves[2].value == "("
986             and self.leaves[3].type == token.RPAR
987             and self.leaves[3].value == ")"
988         )
989
990     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
991         """If so, needs to be split before emitting."""
992         for leaf in self.leaves:
993             if leaf.type == STANDALONE_COMMENT:
994                 if leaf.bracket_depth <= depth_limit:
995                     return True
996
997         return False
998
999     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1000         """Remove trailing comma if there is one and it's safe."""
1001         if not (
1002             self.leaves
1003             and self.leaves[-1].type == token.COMMA
1004             and closing.type in CLOSING_BRACKETS
1005         ):
1006             return False
1007
1008         if closing.type == token.RBRACE:
1009             self.remove_trailing_comma()
1010             return True
1011
1012         if closing.type == token.RSQB:
1013             comma = self.leaves[-1]
1014             if comma.parent and comma.parent.type == syms.listmaker:
1015                 self.remove_trailing_comma()
1016                 return True
1017
1018         # For parens let's check if it's safe to remove the comma.
1019         # Imports are always safe.
1020         if self.is_import:
1021             self.remove_trailing_comma()
1022             return True
1023
1024         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1025         # change a tuple into a different type by removing the comma.
1026         depth = closing.bracket_depth + 1
1027         commas = 0
1028         opening = closing.opening_bracket
1029         for _opening_index, leaf in enumerate(self.leaves):
1030             if leaf is opening:
1031                 break
1032
1033         else:
1034             return False
1035
1036         for leaf in self.leaves[_opening_index + 1 :]:
1037             if leaf is closing:
1038                 break
1039
1040             bracket_depth = leaf.bracket_depth
1041             if bracket_depth == depth and leaf.type == token.COMMA:
1042                 commas += 1
1043                 if leaf.parent and leaf.parent.type == syms.arglist:
1044                     commas += 1
1045                     break
1046
1047         if commas > 1:
1048             self.remove_trailing_comma()
1049             return True
1050
1051         return False
1052
1053     def append_comment(self, comment: Leaf) -> bool:
1054         """Add an inline or standalone comment to the line."""
1055         if (
1056             comment.type == STANDALONE_COMMENT
1057             and self.bracket_tracker.any_open_brackets()
1058         ):
1059             comment.prefix = ""
1060             return False
1061
1062         if comment.type != token.COMMENT:
1063             return False
1064
1065         after = len(self.leaves) - 1
1066         if after == -1:
1067             comment.type = STANDALONE_COMMENT
1068             comment.prefix = ""
1069             return False
1070
1071         else:
1072             self.comments.append((after, comment))
1073             return True
1074
1075     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1076         """Generate comments that should appear directly after `leaf`.
1077
1078         Provide a non-negative leaf `_index` to speed up the function.
1079         """
1080         if _index == -1:
1081             for _index, _leaf in enumerate(self.leaves):
1082                 if leaf is _leaf:
1083                     break
1084
1085             else:
1086                 return
1087
1088         for index, comment_after in self.comments:
1089             if _index == index:
1090                 yield comment_after
1091
1092     def remove_trailing_comma(self) -> None:
1093         """Remove the trailing comma and moves the comments attached to it."""
1094         comma_index = len(self.leaves) - 1
1095         for i in range(len(self.comments)):
1096             comment_index, comment = self.comments[i]
1097             if comment_index == comma_index:
1098                 self.comments[i] = (comma_index - 1, comment)
1099         self.leaves.pop()
1100
1101     def is_complex_subscript(self, leaf: Leaf) -> bool:
1102         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1103         open_lsqb = (
1104             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1105         )
1106         if open_lsqb is None:
1107             return False
1108
1109         subscript_start = open_lsqb.next_sibling
1110         if (
1111             isinstance(subscript_start, Node)
1112             and subscript_start.type == syms.subscriptlist
1113         ):
1114             subscript_start = child_towards(subscript_start, leaf)
1115         return subscript_start is not None and any(
1116             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1117         )
1118
1119     def __str__(self) -> str:
1120         """Render the line."""
1121         if not self:
1122             return "\n"
1123
1124         indent = "    " * self.depth
1125         leaves = iter(self.leaves)
1126         first = next(leaves)
1127         res = f"{first.prefix}{indent}{first.value}"
1128         for leaf in leaves:
1129             res += str(leaf)
1130         for _, comment in self.comments:
1131             res += str(comment)
1132         return res + "\n"
1133
1134     def __bool__(self) -> bool:
1135         """Return True if the line has leaves or comments."""
1136         return bool(self.leaves or self.comments)
1137
1138
1139 class UnformattedLines(Line):
1140     """Just like :class:`Line` but stores lines which aren't reformatted."""
1141
1142     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1143         """Just add a new `leaf` to the end of the lines.
1144
1145         The `preformatted` argument is ignored.
1146
1147         Keeps track of indentation `depth`, which is useful when the user
1148         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1149         """
1150         try:
1151             list(generate_comments(leaf))
1152         except FormatOn as f_on:
1153             self.leaves.append(f_on.leaf_from_consumed(leaf))
1154             raise
1155
1156         self.leaves.append(leaf)
1157         if leaf.type == token.INDENT:
1158             self.depth += 1
1159         elif leaf.type == token.DEDENT:
1160             self.depth -= 1
1161
1162     def __str__(self) -> str:
1163         """Render unformatted lines from leaves which were added with `append()`.
1164
1165         `depth` is not used for indentation in this case.
1166         """
1167         if not self:
1168             return "\n"
1169
1170         res = ""
1171         for leaf in self.leaves:
1172             res += str(leaf)
1173         return res
1174
1175     def append_comment(self, comment: Leaf) -> bool:
1176         """Not implemented in this class. Raises `NotImplementedError`."""
1177         raise NotImplementedError("Unformatted lines don't store comments separately.")
1178
1179     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1180         """Does nothing and returns False."""
1181         return False
1182
1183     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1184         """Does nothing and returns False."""
1185         return False
1186
1187
1188 @dataclass
1189 class EmptyLineTracker:
1190     """Provides a stateful method that returns the number of potential extra
1191     empty lines needed before and after the currently processed line.
1192
1193     Note: this tracker works on lines that haven't been split yet.  It assumes
1194     the prefix of the first leaf consists of optional newlines.  Those newlines
1195     are consumed by `maybe_empty_lines()` and included in the computation.
1196     """
1197     is_pyi: bool = False
1198     previous_line: Optional[Line] = None
1199     previous_after: int = 0
1200     previous_defs: List[int] = Factory(list)
1201
1202     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1203         """Return the number of extra empty lines before and after the `current_line`.
1204
1205         This is for separating `def`, `async def` and `class` with extra empty
1206         lines (two on module-level).
1207         """
1208         if isinstance(current_line, UnformattedLines):
1209             return 0, 0
1210
1211         before, after = self._maybe_empty_lines(current_line)
1212         before -= self.previous_after
1213         self.previous_after = after
1214         self.previous_line = current_line
1215         return before, after
1216
1217     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1218         max_allowed = 1
1219         if current_line.depth == 0:
1220             max_allowed = 1 if self.is_pyi else 2
1221         if current_line.leaves:
1222             # Consume the first leaf's extra newlines.
1223             first_leaf = current_line.leaves[0]
1224             before = first_leaf.prefix.count("\n")
1225             before = min(before, max_allowed)
1226             first_leaf.prefix = ""
1227         else:
1228             before = 0
1229         depth = current_line.depth
1230         while self.previous_defs and self.previous_defs[-1] >= depth:
1231             self.previous_defs.pop()
1232             if self.is_pyi:
1233                 before = 0 if depth else 1
1234             else:
1235                 before = 1 if depth else 2
1236         is_decorator = current_line.is_decorator
1237         if is_decorator or current_line.is_def or current_line.is_class:
1238             if not is_decorator:
1239                 self.previous_defs.append(depth)
1240             if self.previous_line is None:
1241                 # Don't insert empty lines before the first line in the file.
1242                 return 0, 0
1243
1244             if self.previous_line.is_decorator:
1245                 return 0, 0
1246
1247             if (
1248                 self.previous_line.is_comment
1249                 and self.previous_line.depth == current_line.depth
1250                 and before == 0
1251             ):
1252                 return 0, 0
1253
1254             if self.is_pyi:
1255                 if self.previous_line.depth > current_line.depth:
1256                     newlines = 1
1257                 elif current_line.is_class or self.previous_line.is_class:
1258                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1259                         newlines = 0
1260                     else:
1261                         newlines = 1
1262                 else:
1263                     newlines = 0
1264             else:
1265                 newlines = 2
1266             if current_line.depth and newlines:
1267                 newlines -= 1
1268             return newlines, 0
1269
1270         if (
1271             self.previous_line
1272             and self.previous_line.is_import
1273             and not current_line.is_import
1274             and depth == self.previous_line.depth
1275         ):
1276             return (before or 1), 0
1277
1278         return before, 0
1279
1280
1281 @dataclass
1282 class LineGenerator(Visitor[Line]):
1283     """Generates reformatted Line objects.  Empty lines are not emitted.
1284
1285     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1286     in ways that will no longer stringify to valid Python code on the tree.
1287     """
1288     is_pyi: bool = False
1289     current_line: Line = Factory(Line)
1290     remove_u_prefix: bool = False
1291
1292     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1293         """Generate a line.
1294
1295         If the line is empty, only emit if it makes sense.
1296         If the line is too long, split it first and then generate.
1297
1298         If any lines were generated, set up a new current_line.
1299         """
1300         if not self.current_line:
1301             if self.current_line.__class__ == type:
1302                 self.current_line.depth += indent
1303             else:
1304                 self.current_line = type(depth=self.current_line.depth + indent)
1305             return  # Line is empty, don't emit. Creating a new one unnecessary.
1306
1307         complete_line = self.current_line
1308         self.current_line = type(depth=complete_line.depth + indent)
1309         yield complete_line
1310
1311     def visit(self, node: LN) -> Iterator[Line]:
1312         """Main method to visit `node` and its children.
1313
1314         Yields :class:`Line` objects.
1315         """
1316         if isinstance(self.current_line, UnformattedLines):
1317             # File contained `# fmt: off`
1318             yield from self.visit_unformatted(node)
1319
1320         else:
1321             yield from super().visit(node)
1322
1323     def visit_default(self, node: LN) -> Iterator[Line]:
1324         """Default `visit_*()` implementation. Recurses to children of `node`."""
1325         if isinstance(node, Leaf):
1326             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1327             try:
1328                 for comment in generate_comments(node):
1329                     if any_open_brackets:
1330                         # any comment within brackets is subject to splitting
1331                         self.current_line.append(comment)
1332                     elif comment.type == token.COMMENT:
1333                         # regular trailing comment
1334                         self.current_line.append(comment)
1335                         yield from self.line()
1336
1337                     else:
1338                         # regular standalone comment
1339                         yield from self.line()
1340
1341                         self.current_line.append(comment)
1342                         yield from self.line()
1343
1344             except FormatOff as f_off:
1345                 f_off.trim_prefix(node)
1346                 yield from self.line(type=UnformattedLines)
1347                 yield from self.visit(node)
1348
1349             except FormatOn as f_on:
1350                 # This only happens here if somebody says "fmt: on" multiple
1351                 # times in a row.
1352                 f_on.trim_prefix(node)
1353                 yield from self.visit_default(node)
1354
1355             else:
1356                 normalize_prefix(node, inside_brackets=any_open_brackets)
1357                 if node.type == token.STRING:
1358                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1359                     normalize_string_quotes(node)
1360                 if node.type not in WHITESPACE:
1361                     self.current_line.append(node)
1362         yield from super().visit_default(node)
1363
1364     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1365         """Increase indentation level, maybe yield a line."""
1366         # In blib2to3 INDENT never holds comments.
1367         yield from self.line(+1)
1368         yield from self.visit_default(node)
1369
1370     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1371         """Decrease indentation level, maybe yield a line."""
1372         # The current line might still wait for trailing comments.  At DEDENT time
1373         # there won't be any (they would be prefixes on the preceding NEWLINE).
1374         # Emit the line then.
1375         yield from self.line()
1376
1377         # While DEDENT has no value, its prefix may contain standalone comments
1378         # that belong to the current indentation level.  Get 'em.
1379         yield from self.visit_default(node)
1380
1381         # Finally, emit the dedent.
1382         yield from self.line(-1)
1383
1384     def visit_stmt(
1385         self, node: Node, keywords: Set[str], parens: Set[str]
1386     ) -> Iterator[Line]:
1387         """Visit a statement.
1388
1389         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1390         `def`, `with`, `class`, `assert` and assignments.
1391
1392         The relevant Python language `keywords` for a given statement will be
1393         NAME leaves within it. This methods puts those on a separate line.
1394
1395         `parens` holds a set of string leaf values immediately after which
1396         invisible parens should be put.
1397         """
1398         normalize_invisible_parens(node, parens_after=parens)
1399         for child in node.children:
1400             if child.type == token.NAME and child.value in keywords:  # type: ignore
1401                 yield from self.line()
1402
1403             yield from self.visit(child)
1404
1405     def visit_suite(self, node: Node) -> Iterator[Line]:
1406         """Visit a suite."""
1407         if self.is_pyi and is_stub_suite(node):
1408             yield from self.visit(node.children[2])
1409         else:
1410             yield from self.visit_default(node)
1411
1412     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1413         """Visit a statement without nested statements."""
1414         is_suite_like = node.parent and node.parent.type in STATEMENT
1415         if is_suite_like:
1416             if self.is_pyi and is_stub_body(node):
1417                 yield from self.visit_default(node)
1418             else:
1419                 yield from self.line(+1)
1420                 yield from self.visit_default(node)
1421                 yield from self.line(-1)
1422
1423         else:
1424             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1425                 yield from self.line()
1426             yield from self.visit_default(node)
1427
1428     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1429         """Visit `async def`, `async for`, `async with`."""
1430         yield from self.line()
1431
1432         children = iter(node.children)
1433         for child in children:
1434             yield from self.visit(child)
1435
1436             if child.type == token.ASYNC:
1437                 break
1438
1439         internal_stmt = next(children)
1440         for child in internal_stmt.children:
1441             yield from self.visit(child)
1442
1443     def visit_decorators(self, node: Node) -> Iterator[Line]:
1444         """Visit decorators."""
1445         for child in node.children:
1446             yield from self.line()
1447             yield from self.visit(child)
1448
1449     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1450         """Remove a semicolon and put the other statement on a separate line."""
1451         yield from self.line()
1452
1453     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1454         """End of file. Process outstanding comments and end with a newline."""
1455         yield from self.visit_default(leaf)
1456         yield from self.line()
1457
1458     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1459         """Used when file contained a `# fmt: off`."""
1460         if isinstance(node, Node):
1461             for child in node.children:
1462                 yield from self.visit(child)
1463
1464         else:
1465             try:
1466                 self.current_line.append(node)
1467             except FormatOn as f_on:
1468                 f_on.trim_prefix(node)
1469                 yield from self.line()
1470                 yield from self.visit(node)
1471
1472             if node.type == token.ENDMARKER:
1473                 # somebody decided not to put a final `# fmt: on`
1474                 yield from self.line()
1475
1476     def __attrs_post_init__(self) -> None:
1477         """You are in a twisty little maze of passages."""
1478         v = self.visit_stmt
1479         Ø: Set[str] = set()
1480         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1481         self.visit_if_stmt = partial(
1482             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1483         )
1484         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1485         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1486         self.visit_try_stmt = partial(
1487             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1488         )
1489         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1490         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1491         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1492         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1493         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1494         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1495         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1496         self.visit_async_funcdef = self.visit_async_stmt
1497         self.visit_decorated = self.visit_decorators
1498
1499
1500 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1501 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1502 OPENING_BRACKETS = set(BRACKET.keys())
1503 CLOSING_BRACKETS = set(BRACKET.values())
1504 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1505 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1506
1507
1508 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1509     """Return whitespace prefix if needed for the given `leaf`.
1510
1511     `complex_subscript` signals whether the given leaf is part of a subscription
1512     which has non-trivial arguments, like arithmetic expressions or function calls.
1513     """
1514     NO = ""
1515     SPACE = " "
1516     DOUBLESPACE = "  "
1517     t = leaf.type
1518     p = leaf.parent
1519     v = leaf.value
1520     if t in ALWAYS_NO_SPACE:
1521         return NO
1522
1523     if t == token.COMMENT:
1524         return DOUBLESPACE
1525
1526     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1527     if t == token.COLON and p.type not in {
1528         syms.subscript,
1529         syms.subscriptlist,
1530         syms.sliceop,
1531     }:
1532         return NO
1533
1534     prev = leaf.prev_sibling
1535     if not prev:
1536         prevp = preceding_leaf(p)
1537         if not prevp or prevp.type in OPENING_BRACKETS:
1538             return NO
1539
1540         if t == token.COLON:
1541             if prevp.type == token.COLON:
1542                 return NO
1543
1544             elif prevp.type != token.COMMA and not complex_subscript:
1545                 return NO
1546
1547             return SPACE
1548
1549         if prevp.type == token.EQUAL:
1550             if prevp.parent:
1551                 if prevp.parent.type in {
1552                     syms.arglist,
1553                     syms.argument,
1554                     syms.parameters,
1555                     syms.varargslist,
1556                 }:
1557                     return NO
1558
1559                 elif prevp.parent.type == syms.typedargslist:
1560                     # A bit hacky: if the equal sign has whitespace, it means we
1561                     # previously found it's a typed argument.  So, we're using
1562                     # that, too.
1563                     return prevp.prefix
1564
1565         elif prevp.type in STARS:
1566             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1567                 return NO
1568
1569         elif prevp.type == token.COLON:
1570             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1571                 return SPACE if complex_subscript else NO
1572
1573         elif (
1574             prevp.parent
1575             and prevp.parent.type == syms.factor
1576             and prevp.type in MATH_OPERATORS
1577         ):
1578             return NO
1579
1580         elif (
1581             prevp.type == token.RIGHTSHIFT
1582             and prevp.parent
1583             and prevp.parent.type == syms.shift_expr
1584             and prevp.prev_sibling
1585             and prevp.prev_sibling.type == token.NAME
1586             and prevp.prev_sibling.value == "print"  # type: ignore
1587         ):
1588             # Python 2 print chevron
1589             return NO
1590
1591     elif prev.type in OPENING_BRACKETS:
1592         return NO
1593
1594     if p.type in {syms.parameters, syms.arglist}:
1595         # untyped function signatures or calls
1596         if not prev or prev.type != token.COMMA:
1597             return NO
1598
1599     elif p.type == syms.varargslist:
1600         # lambdas
1601         if prev and prev.type != token.COMMA:
1602             return NO
1603
1604     elif p.type == syms.typedargslist:
1605         # typed function signatures
1606         if not prev:
1607             return NO
1608
1609         if t == token.EQUAL:
1610             if prev.type != syms.tname:
1611                 return NO
1612
1613         elif prev.type == token.EQUAL:
1614             # A bit hacky: if the equal sign has whitespace, it means we
1615             # previously found it's a typed argument.  So, we're using that, too.
1616             return prev.prefix
1617
1618         elif prev.type != token.COMMA:
1619             return NO
1620
1621     elif p.type == syms.tname:
1622         # type names
1623         if not prev:
1624             prevp = preceding_leaf(p)
1625             if not prevp or prevp.type != token.COMMA:
1626                 return NO
1627
1628     elif p.type == syms.trailer:
1629         # attributes and calls
1630         if t == token.LPAR or t == token.RPAR:
1631             return NO
1632
1633         if not prev:
1634             if t == token.DOT:
1635                 prevp = preceding_leaf(p)
1636                 if not prevp or prevp.type != token.NUMBER:
1637                     return NO
1638
1639             elif t == token.LSQB:
1640                 return NO
1641
1642         elif prev.type != token.COMMA:
1643             return NO
1644
1645     elif p.type == syms.argument:
1646         # single argument
1647         if t == token.EQUAL:
1648             return NO
1649
1650         if not prev:
1651             prevp = preceding_leaf(p)
1652             if not prevp or prevp.type == token.LPAR:
1653                 return NO
1654
1655         elif prev.type in {token.EQUAL} | STARS:
1656             return NO
1657
1658     elif p.type == syms.decorator:
1659         # decorators
1660         return NO
1661
1662     elif p.type == syms.dotted_name:
1663         if prev:
1664             return NO
1665
1666         prevp = preceding_leaf(p)
1667         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1668             return NO
1669
1670     elif p.type == syms.classdef:
1671         if t == token.LPAR:
1672             return NO
1673
1674         if prev and prev.type == token.LPAR:
1675             return NO
1676
1677     elif p.type in {syms.subscript, syms.sliceop}:
1678         # indexing
1679         if not prev:
1680             assert p.parent is not None, "subscripts are always parented"
1681             if p.parent.type == syms.subscriptlist:
1682                 return SPACE
1683
1684             return NO
1685
1686         elif not complex_subscript:
1687             return NO
1688
1689     elif p.type == syms.atom:
1690         if prev and t == token.DOT:
1691             # dots, but not the first one.
1692             return NO
1693
1694     elif p.type == syms.dictsetmaker:
1695         # dict unpacking
1696         if prev and prev.type == token.DOUBLESTAR:
1697             return NO
1698
1699     elif p.type in {syms.factor, syms.star_expr}:
1700         # unary ops
1701         if not prev:
1702             prevp = preceding_leaf(p)
1703             if not prevp or prevp.type in OPENING_BRACKETS:
1704                 return NO
1705
1706             prevp_parent = prevp.parent
1707             assert prevp_parent is not None
1708             if prevp.type == token.COLON and prevp_parent.type in {
1709                 syms.subscript,
1710                 syms.sliceop,
1711             }:
1712                 return NO
1713
1714             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1715                 return NO
1716
1717         elif t == token.NAME or t == token.NUMBER:
1718             return NO
1719
1720     elif p.type == syms.import_from:
1721         if t == token.DOT:
1722             if prev and prev.type == token.DOT:
1723                 return NO
1724
1725         elif t == token.NAME:
1726             if v == "import":
1727                 return SPACE
1728
1729             if prev and prev.type == token.DOT:
1730                 return NO
1731
1732     elif p.type == syms.sliceop:
1733         return NO
1734
1735     return SPACE
1736
1737
1738 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1739     """Return the first leaf that precedes `node`, if any."""
1740     while node:
1741         res = node.prev_sibling
1742         if res:
1743             if isinstance(res, Leaf):
1744                 return res
1745
1746             try:
1747                 return list(res.leaves())[-1]
1748
1749             except IndexError:
1750                 return None
1751
1752         node = node.parent
1753     return None
1754
1755
1756 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1757     """Return the child of `ancestor` that contains `descendant`."""
1758     node: Optional[LN] = descendant
1759     while node and node.parent != ancestor:
1760         node = node.parent
1761     return node
1762
1763
1764 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1765     """Return the priority of the `leaf` delimiter, given a line break after it.
1766
1767     The delimiter priorities returned here are from those delimiters that would
1768     cause a line break after themselves.
1769
1770     Higher numbers are higher priority.
1771     """
1772     if leaf.type == token.COMMA:
1773         return COMMA_PRIORITY
1774
1775     return 0
1776
1777
1778 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1779     """Return the priority of the `leaf` delimiter, given a line before after it.
1780
1781     The delimiter priorities returned here are from those delimiters that would
1782     cause a line break before themselves.
1783
1784     Higher numbers are higher priority.
1785     """
1786     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1787         # * and ** might also be MATH_OPERATORS but in this case they are not.
1788         # Don't treat them as a delimiter.
1789         return 0
1790
1791     if (
1792         leaf.type == token.DOT
1793         and leaf.parent
1794         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1795         and (previous is None or previous.type in CLOSING_BRACKETS)
1796     ):
1797         return DOT_PRIORITY
1798
1799     if (
1800         leaf.type in MATH_OPERATORS
1801         and leaf.parent
1802         and leaf.parent.type not in {syms.factor, syms.star_expr}
1803     ):
1804         return MATH_PRIORITIES[leaf.type]
1805
1806     if leaf.type in COMPARATORS:
1807         return COMPARATOR_PRIORITY
1808
1809     if (
1810         leaf.type == token.STRING
1811         and previous is not None
1812         and previous.type == token.STRING
1813     ):
1814         return STRING_PRIORITY
1815
1816     if leaf.type != token.NAME:
1817         return 0
1818
1819     if (
1820         leaf.value == "for"
1821         and leaf.parent
1822         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1823     ):
1824         return COMPREHENSION_PRIORITY
1825
1826     if (
1827         leaf.value == "if"
1828         and leaf.parent
1829         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1830     ):
1831         return COMPREHENSION_PRIORITY
1832
1833     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1834         return TERNARY_PRIORITY
1835
1836     if leaf.value == "is":
1837         return COMPARATOR_PRIORITY
1838
1839     if (
1840         leaf.value == "in"
1841         and leaf.parent
1842         and leaf.parent.type in {syms.comp_op, syms.comparison}
1843         and not (
1844             previous is not None
1845             and previous.type == token.NAME
1846             and previous.value == "not"
1847         )
1848     ):
1849         return COMPARATOR_PRIORITY
1850
1851     if (
1852         leaf.value == "not"
1853         and leaf.parent
1854         and leaf.parent.type == syms.comp_op
1855         and not (
1856             previous is not None
1857             and previous.type == token.NAME
1858             and previous.value == "is"
1859         )
1860     ):
1861         return COMPARATOR_PRIORITY
1862
1863     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1864         return LOGIC_PRIORITY
1865
1866     return 0
1867
1868
1869 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1870     """Clean the prefix of the `leaf` and generate comments from it, if any.
1871
1872     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1873     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1874     move because it does away with modifying the grammar to include all the
1875     possible places in which comments can be placed.
1876
1877     The sad consequence for us though is that comments don't "belong" anywhere.
1878     This is why this function generates simple parentless Leaf objects for
1879     comments.  We simply don't know what the correct parent should be.
1880
1881     No matter though, we can live without this.  We really only need to
1882     differentiate between inline and standalone comments.  The latter don't
1883     share the line with any code.
1884
1885     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1886     are emitted with a fake STANDALONE_COMMENT token identifier.
1887     """
1888     p = leaf.prefix
1889     if not p:
1890         return
1891
1892     if "#" not in p:
1893         return
1894
1895     consumed = 0
1896     nlines = 0
1897     for index, line in enumerate(p.split("\n")):
1898         consumed += len(line) + 1  # adding the length of the split '\n'
1899         line = line.lstrip()
1900         if not line:
1901             nlines += 1
1902         if not line.startswith("#"):
1903             continue
1904
1905         if index == 0 and leaf.type != token.ENDMARKER:
1906             comment_type = token.COMMENT  # simple trailing comment
1907         else:
1908             comment_type = STANDALONE_COMMENT
1909         comment = make_comment(line)
1910         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1911
1912         if comment in {"# fmt: on", "# yapf: enable"}:
1913             raise FormatOn(consumed)
1914
1915         if comment in {"# fmt: off", "# yapf: disable"}:
1916             if comment_type == STANDALONE_COMMENT:
1917                 raise FormatOff(consumed)
1918
1919             prev = preceding_leaf(leaf)
1920             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1921                 raise FormatOff(consumed)
1922
1923         nlines = 0
1924
1925
1926 def make_comment(content: str) -> str:
1927     """Return a consistently formatted comment from the given `content` string.
1928
1929     All comments (except for "##", "#!", "#:") should have a single space between
1930     the hash sign and the content.
1931
1932     If `content` didn't start with a hash sign, one is provided.
1933     """
1934     content = content.rstrip()
1935     if not content:
1936         return "#"
1937
1938     if content[0] == "#":
1939         content = content[1:]
1940     if content and content[0] not in " !:#":
1941         content = " " + content
1942     return "#" + content
1943
1944
1945 def split_line(
1946     line: Line, line_length: int, inner: bool = False, py36: bool = False
1947 ) -> Iterator[Line]:
1948     """Split a `line` into potentially many lines.
1949
1950     They should fit in the allotted `line_length` but might not be able to.
1951     `inner` signifies that there were a pair of brackets somewhere around the
1952     current `line`, possibly transitively. This means we can fallback to splitting
1953     by delimiters if the LHS/RHS don't yield any results.
1954
1955     If `py36` is True, splitting may generate syntax that is only compatible
1956     with Python 3.6 and later.
1957     """
1958     if isinstance(line, UnformattedLines) or line.is_comment:
1959         yield line
1960         return
1961
1962     line_str = str(line).strip("\n")
1963     if not line.should_explode and is_line_short_enough(
1964         line, line_length=line_length, line_str=line_str
1965     ):
1966         yield line
1967         return
1968
1969     split_funcs: List[SplitFunc]
1970     if line.is_def:
1971         split_funcs = [left_hand_split]
1972     else:
1973
1974         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
1975             for omit in generate_trailers_to_omit(line, line_length):
1976                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
1977                 if is_line_short_enough(lines[0], line_length=line_length):
1978                     yield from lines
1979                     return
1980
1981             # All splits failed, best effort split with no omits.
1982             # This mostly happens to multiline strings that are by definition
1983             # reported as not fitting a single line.
1984             yield from right_hand_split(line, py36)
1985
1986         if line.inside_brackets:
1987             split_funcs = [delimiter_split, standalone_comment_split, rhs]
1988         else:
1989             split_funcs = [rhs]
1990     for split_func in split_funcs:
1991         # We are accumulating lines in `result` because we might want to abort
1992         # mission and return the original line in the end, or attempt a different
1993         # split altogether.
1994         result: List[Line] = []
1995         try:
1996             for l in split_func(line, py36):
1997                 if str(l).strip("\n") == line_str:
1998                     raise CannotSplit("Split function returned an unchanged result")
1999
2000                 result.extend(
2001                     split_line(l, line_length=line_length, inner=True, py36=py36)
2002                 )
2003         except CannotSplit as cs:
2004             continue
2005
2006         else:
2007             yield from result
2008             break
2009
2010     else:
2011         yield line
2012
2013
2014 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2015     """Split line into many lines, starting with the first matching bracket pair.
2016
2017     Note: this usually looks weird, only use this for function definitions.
2018     Prefer RHS otherwise.  This is why this function is not symmetrical with
2019     :func:`right_hand_split` which also handles optional parentheses.
2020     """
2021     head = Line(depth=line.depth)
2022     body = Line(depth=line.depth + 1, inside_brackets=True)
2023     tail = Line(depth=line.depth)
2024     tail_leaves: List[Leaf] = []
2025     body_leaves: List[Leaf] = []
2026     head_leaves: List[Leaf] = []
2027     current_leaves = head_leaves
2028     matching_bracket = None
2029     for leaf in line.leaves:
2030         if (
2031             current_leaves is body_leaves
2032             and leaf.type in CLOSING_BRACKETS
2033             and leaf.opening_bracket is matching_bracket
2034         ):
2035             current_leaves = tail_leaves if body_leaves else head_leaves
2036         current_leaves.append(leaf)
2037         if current_leaves is head_leaves:
2038             if leaf.type in OPENING_BRACKETS:
2039                 matching_bracket = leaf
2040                 current_leaves = body_leaves
2041     # Since body is a new indent level, remove spurious leading whitespace.
2042     if body_leaves:
2043         normalize_prefix(body_leaves[0], inside_brackets=True)
2044     # Build the new lines.
2045     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2046         for leaf in leaves:
2047             result.append(leaf, preformatted=True)
2048             for comment_after in line.comments_after(leaf):
2049                 result.append(comment_after, preformatted=True)
2050     bracket_split_succeeded_or_raise(head, body, tail)
2051     for result in (head, body, tail):
2052         if result:
2053             yield result
2054
2055
2056 def right_hand_split(
2057     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2058 ) -> Iterator[Line]:
2059     """Split line into many lines, starting with the last matching bracket pair.
2060
2061     If the split was by optional parentheses, attempt splitting without them, too.
2062     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2063     this split.
2064
2065     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2066     """
2067     head = Line(depth=line.depth)
2068     body = Line(depth=line.depth + 1, inside_brackets=True)
2069     tail = Line(depth=line.depth)
2070     tail_leaves: List[Leaf] = []
2071     body_leaves: List[Leaf] = []
2072     head_leaves: List[Leaf] = []
2073     current_leaves = tail_leaves
2074     opening_bracket = None
2075     closing_bracket = None
2076     for leaf in reversed(line.leaves):
2077         if current_leaves is body_leaves:
2078             if leaf is opening_bracket:
2079                 current_leaves = head_leaves if body_leaves else tail_leaves
2080         current_leaves.append(leaf)
2081         if current_leaves is tail_leaves:
2082             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2083                 opening_bracket = leaf.opening_bracket
2084                 closing_bracket = leaf
2085                 current_leaves = body_leaves
2086     tail_leaves.reverse()
2087     body_leaves.reverse()
2088     head_leaves.reverse()
2089     # Since body is a new indent level, remove spurious leading whitespace.
2090     if body_leaves:
2091         normalize_prefix(body_leaves[0], inside_brackets=True)
2092     if not head_leaves:
2093         # No `head` means the split failed. Either `tail` has all content or
2094         # the matching `opening_bracket` wasn't available on `line` anymore.
2095         raise CannotSplit("No brackets found")
2096
2097     # Build the new lines.
2098     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2099         for leaf in leaves:
2100             result.append(leaf, preformatted=True)
2101             for comment_after in line.comments_after(leaf):
2102                 result.append(comment_after, preformatted=True)
2103     bracket_split_succeeded_or_raise(head, body, tail)
2104     assert opening_bracket and closing_bracket
2105     if (
2106         # the opening bracket is an optional paren
2107         opening_bracket.type == token.LPAR
2108         and not opening_bracket.value
2109         # the closing bracket is an optional paren
2110         and closing_bracket.type == token.RPAR
2111         and not closing_bracket.value
2112         # there are no standalone comments in the body
2113         and not line.contains_standalone_comments(0)
2114         # and it's not an import (optional parens are the only thing we can split
2115         # on in this case; attempting a split without them is a waste of time)
2116         and not line.is_import
2117     ):
2118         omit = {id(closing_bracket), *omit}
2119         if can_omit_invisible_parens(body, line_length):
2120             try:
2121                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2122                 return
2123             except CannotSplit:
2124                 pass
2125
2126     ensure_visible(opening_bracket)
2127     ensure_visible(closing_bracket)
2128     body.should_explode = should_explode(body, opening_bracket)
2129     for result in (head, body, tail):
2130         if result:
2131             yield result
2132
2133
2134 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2135     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2136
2137     Do nothing otherwise.
2138
2139     A left- or right-hand split is based on a pair of brackets. Content before
2140     (and including) the opening bracket is left on one line, content inside the
2141     brackets is put on a separate line, and finally content starting with and
2142     following the closing bracket is put on a separate line.
2143
2144     Those are called `head`, `body`, and `tail`, respectively. If the split
2145     produced the same line (all content in `head`) or ended up with an empty `body`
2146     and the `tail` is just the closing bracket, then it's considered failed.
2147     """
2148     tail_len = len(str(tail).strip())
2149     if not body:
2150         if tail_len == 0:
2151             raise CannotSplit("Splitting brackets produced the same line")
2152
2153         elif tail_len < 3:
2154             raise CannotSplit(
2155                 f"Splitting brackets on an empty body to save "
2156                 f"{tail_len} characters is not worth it"
2157             )
2158
2159
2160 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2161     """Normalize prefix of the first leaf in every line returned by `split_func`.
2162
2163     This is a decorator over relevant split functions.
2164     """
2165
2166     @wraps(split_func)
2167     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2168         for l in split_func(line, py36):
2169             normalize_prefix(l.leaves[0], inside_brackets=True)
2170             yield l
2171
2172     return split_wrapper
2173
2174
2175 @dont_increase_indentation
2176 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2177     """Split according to delimiters of the highest priority.
2178
2179     If `py36` is True, the split will add trailing commas also in function
2180     signatures that contain `*` and `**`.
2181     """
2182     try:
2183         last_leaf = line.leaves[-1]
2184     except IndexError:
2185         raise CannotSplit("Line empty")
2186
2187     bt = line.bracket_tracker
2188     try:
2189         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2190     except ValueError:
2191         raise CannotSplit("No delimiters found")
2192
2193     if delimiter_priority == DOT_PRIORITY:
2194         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2195             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2196
2197     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2198     lowest_depth = sys.maxsize
2199     trailing_comma_safe = True
2200
2201     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2202         """Append `leaf` to current line or to new line if appending impossible."""
2203         nonlocal current_line
2204         try:
2205             current_line.append_safe(leaf, preformatted=True)
2206         except ValueError as ve:
2207             yield current_line
2208
2209             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2210             current_line.append(leaf)
2211
2212     for index, leaf in enumerate(line.leaves):
2213         yield from append_to_line(leaf)
2214
2215         for comment_after in line.comments_after(leaf, index):
2216             yield from append_to_line(comment_after)
2217
2218         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2219         if leaf.bracket_depth == lowest_depth and is_vararg(
2220             leaf, within=VARARGS_PARENTS
2221         ):
2222             trailing_comma_safe = trailing_comma_safe and py36
2223         leaf_priority = bt.delimiters.get(id(leaf))
2224         if leaf_priority == delimiter_priority:
2225             yield current_line
2226
2227             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2228     if current_line:
2229         if (
2230             trailing_comma_safe
2231             and delimiter_priority == COMMA_PRIORITY
2232             and current_line.leaves[-1].type != token.COMMA
2233             and current_line.leaves[-1].type != STANDALONE_COMMENT
2234         ):
2235             current_line.append(Leaf(token.COMMA, ","))
2236         yield current_line
2237
2238
2239 @dont_increase_indentation
2240 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2241     """Split standalone comments from the rest of the line."""
2242     if not line.contains_standalone_comments(0):
2243         raise CannotSplit("Line does not have any standalone comments")
2244
2245     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2246
2247     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2248         """Append `leaf` to current line or to new line if appending impossible."""
2249         nonlocal current_line
2250         try:
2251             current_line.append_safe(leaf, preformatted=True)
2252         except ValueError as ve:
2253             yield current_line
2254
2255             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2256             current_line.append(leaf)
2257
2258     for index, leaf in enumerate(line.leaves):
2259         yield from append_to_line(leaf)
2260
2261         for comment_after in line.comments_after(leaf, index):
2262             yield from append_to_line(comment_after)
2263
2264     if current_line:
2265         yield current_line
2266
2267
2268 def is_import(leaf: Leaf) -> bool:
2269     """Return True if the given leaf starts an import statement."""
2270     p = leaf.parent
2271     t = leaf.type
2272     v = leaf.value
2273     return bool(
2274         t == token.NAME
2275         and (
2276             (v == "import" and p and p.type == syms.import_name)
2277             or (v == "from" and p and p.type == syms.import_from)
2278         )
2279     )
2280
2281
2282 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2283     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2284     else.
2285
2286     Note: don't use backslashes for formatting or you'll lose your voting rights.
2287     """
2288     if not inside_brackets:
2289         spl = leaf.prefix.split("#")
2290         if "\\" not in spl[0]:
2291             nl_count = spl[-1].count("\n")
2292             if len(spl) > 1:
2293                 nl_count -= 1
2294             leaf.prefix = "\n" * nl_count
2295             return
2296
2297     leaf.prefix = ""
2298
2299
2300 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2301     """Make all string prefixes lowercase.
2302
2303     If remove_u_prefix is given, also removes any u prefix from the string.
2304
2305     Note: Mutates its argument.
2306     """
2307     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2308     assert match is not None, f"failed to match string {leaf.value!r}"
2309     orig_prefix = match.group(1)
2310     new_prefix = orig_prefix.lower()
2311     if remove_u_prefix:
2312         new_prefix = new_prefix.replace("u", "")
2313     leaf.value = f"{new_prefix}{match.group(2)}"
2314
2315
2316 def normalize_string_quotes(leaf: Leaf) -> None:
2317     """Prefer double quotes but only if it doesn't cause more escaping.
2318
2319     Adds or removes backslashes as appropriate. Doesn't parse and fix
2320     strings nested in f-strings (yet).
2321
2322     Note: Mutates its argument.
2323     """
2324     value = leaf.value.lstrip("furbFURB")
2325     if value[:3] == '"""':
2326         return
2327
2328     elif value[:3] == "'''":
2329         orig_quote = "'''"
2330         new_quote = '"""'
2331     elif value[0] == '"':
2332         orig_quote = '"'
2333         new_quote = "'"
2334     else:
2335         orig_quote = "'"
2336         new_quote = '"'
2337     first_quote_pos = leaf.value.find(orig_quote)
2338     if first_quote_pos == -1:
2339         return  # There's an internal error
2340
2341     prefix = leaf.value[:first_quote_pos]
2342     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2343     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2344     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2345     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2346     if "r" in prefix.casefold():
2347         if unescaped_new_quote.search(body):
2348             # There's at least one unescaped new_quote in this raw string
2349             # so converting is impossible
2350             return
2351
2352         # Do not introduce or remove backslashes in raw strings
2353         new_body = body
2354     else:
2355         # remove unnecessary quotes
2356         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2357         if body != new_body:
2358             # Consider the string without unnecessary quotes as the original
2359             body = new_body
2360             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2361         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2362         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2363     if new_quote == '"""' and new_body[-1] == '"':
2364         # edge case:
2365         new_body = new_body[:-1] + '\\"'
2366     orig_escape_count = body.count("\\")
2367     new_escape_count = new_body.count("\\")
2368     if new_escape_count > orig_escape_count:
2369         return  # Do not introduce more escaping
2370
2371     if new_escape_count == orig_escape_count and orig_quote == '"':
2372         return  # Prefer double quotes
2373
2374     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2375
2376
2377 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2378     """Make existing optional parentheses invisible or create new ones.
2379
2380     `parens_after` is a set of string leaf values immeditely after which parens
2381     should be put.
2382
2383     Standardizes on visible parentheses for single-element tuples, and keeps
2384     existing visible parentheses for other tuples and generator expressions.
2385     """
2386     try:
2387         list(generate_comments(node))
2388     except FormatOff:
2389         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2390
2391     check_lpar = False
2392     for index, child in enumerate(list(node.children)):
2393         if check_lpar:
2394             if child.type == syms.atom:
2395                 maybe_make_parens_invisible_in_atom(child)
2396             elif is_one_tuple(child):
2397                 # wrap child in visible parentheses
2398                 lpar = Leaf(token.LPAR, "(")
2399                 rpar = Leaf(token.RPAR, ")")
2400                 child.remove()
2401                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2402             elif node.type == syms.import_from:
2403                 # "import from" nodes store parentheses directly as part of
2404                 # the statement
2405                 if child.type == token.LPAR:
2406                     # make parentheses invisible
2407                     child.value = ""  # type: ignore
2408                     node.children[-1].value = ""  # type: ignore
2409                 elif child.type != token.STAR:
2410                     # insert invisible parentheses
2411                     node.insert_child(index, Leaf(token.LPAR, ""))
2412                     node.append_child(Leaf(token.RPAR, ""))
2413                 break
2414
2415             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2416                 # wrap child in invisible parentheses
2417                 lpar = Leaf(token.LPAR, "")
2418                 rpar = Leaf(token.RPAR, "")
2419                 index = child.remove() or 0
2420                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2421
2422         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2423
2424
2425 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2426     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2427     if (
2428         node.type != syms.atom
2429         or is_empty_tuple(node)
2430         or is_one_tuple(node)
2431         or is_yield(node)
2432         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2433     ):
2434         return False
2435
2436     first = node.children[0]
2437     last = node.children[-1]
2438     if first.type == token.LPAR and last.type == token.RPAR:
2439         # make parentheses invisible
2440         first.value = ""  # type: ignore
2441         last.value = ""  # type: ignore
2442         if len(node.children) > 1:
2443             maybe_make_parens_invisible_in_atom(node.children[1])
2444         return True
2445
2446     return False
2447
2448
2449 def is_empty_tuple(node: LN) -> bool:
2450     """Return True if `node` holds an empty tuple."""
2451     return (
2452         node.type == syms.atom
2453         and len(node.children) == 2
2454         and node.children[0].type == token.LPAR
2455         and node.children[1].type == token.RPAR
2456     )
2457
2458
2459 def is_one_tuple(node: LN) -> bool:
2460     """Return True if `node` holds a tuple with one element, with or without parens."""
2461     if node.type == syms.atom:
2462         if len(node.children) != 3:
2463             return False
2464
2465         lpar, gexp, rpar = node.children
2466         if not (
2467             lpar.type == token.LPAR
2468             and gexp.type == syms.testlist_gexp
2469             and rpar.type == token.RPAR
2470         ):
2471             return False
2472
2473         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2474
2475     return (
2476         node.type in IMPLICIT_TUPLE
2477         and len(node.children) == 2
2478         and node.children[1].type == token.COMMA
2479     )
2480
2481
2482 def is_yield(node: LN) -> bool:
2483     """Return True if `node` holds a `yield` or `yield from` expression."""
2484     if node.type == syms.yield_expr:
2485         return True
2486
2487     if node.type == token.NAME and node.value == "yield":  # type: ignore
2488         return True
2489
2490     if node.type != syms.atom:
2491         return False
2492
2493     if len(node.children) != 3:
2494         return False
2495
2496     lpar, expr, rpar = node.children
2497     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2498         return is_yield(expr)
2499
2500     return False
2501
2502
2503 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2504     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2505
2506     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2507     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2508     extended iterable unpacking (PEP 3132) and additional unpacking
2509     generalizations (PEP 448).
2510     """
2511     if leaf.type not in STARS or not leaf.parent:
2512         return False
2513
2514     p = leaf.parent
2515     if p.type == syms.star_expr:
2516         # Star expressions are also used as assignment targets in extended
2517         # iterable unpacking (PEP 3132).  See what its parent is instead.
2518         if not p.parent:
2519             return False
2520
2521         p = p.parent
2522
2523     return p.type in within
2524
2525
2526 def is_multiline_string(leaf: Leaf) -> bool:
2527     """Return True if `leaf` is a multiline string that actually spans many lines."""
2528     value = leaf.value.lstrip("furbFURB")
2529     return value[:3] in {'"""', "'''"} and "\n" in value
2530
2531
2532 def is_stub_suite(node: Node) -> bool:
2533     """Return True if `node` is a suite with a stub body."""
2534     if (
2535         len(node.children) != 4
2536         or node.children[0].type != token.NEWLINE
2537         or node.children[1].type != token.INDENT
2538         or node.children[3].type != token.DEDENT
2539     ):
2540         return False
2541
2542     return is_stub_body(node.children[2])
2543
2544
2545 def is_stub_body(node: LN) -> bool:
2546     """Return True if `node` is a simple statement containing an ellipsis."""
2547     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2548         return False
2549
2550     if len(node.children) != 2:
2551         return False
2552
2553     child = node.children[0]
2554     return (
2555         child.type == syms.atom
2556         and len(child.children) == 3
2557         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2558     )
2559
2560
2561 def max_delimiter_priority_in_atom(node: LN) -> int:
2562     """Return maximum delimiter priority inside `node`.
2563
2564     This is specific to atoms with contents contained in a pair of parentheses.
2565     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2566     """
2567     if node.type != syms.atom:
2568         return 0
2569
2570     first = node.children[0]
2571     last = node.children[-1]
2572     if not (first.type == token.LPAR and last.type == token.RPAR):
2573         return 0
2574
2575     bt = BracketTracker()
2576     for c in node.children[1:-1]:
2577         if isinstance(c, Leaf):
2578             bt.mark(c)
2579         else:
2580             for leaf in c.leaves():
2581                 bt.mark(leaf)
2582     try:
2583         return bt.max_delimiter_priority()
2584
2585     except ValueError:
2586         return 0
2587
2588
2589 def ensure_visible(leaf: Leaf) -> None:
2590     """Make sure parentheses are visible.
2591
2592     They could be invisible as part of some statements (see
2593     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2594     """
2595     if leaf.type == token.LPAR:
2596         leaf.value = "("
2597     elif leaf.type == token.RPAR:
2598         leaf.value = ")"
2599
2600
2601 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2602     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2603     if not (
2604         opening_bracket.parent
2605         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2606         and opening_bracket.value in "[{("
2607     ):
2608         return False
2609
2610     try:
2611         last_leaf = line.leaves[-1]
2612         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2613         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2614     except (IndexError, ValueError):
2615         return False
2616
2617     return max_priority == COMMA_PRIORITY
2618
2619
2620 def is_python36(node: Node) -> bool:
2621     """Return True if the current file is using Python 3.6+ features.
2622
2623     Currently looking for:
2624     - f-strings; and
2625     - trailing commas after * or ** in function signatures and calls.
2626     """
2627     for n in node.pre_order():
2628         if n.type == token.STRING:
2629             value_head = n.value[:2]  # type: ignore
2630             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2631                 return True
2632
2633         elif (
2634             n.type in {syms.typedargslist, syms.arglist}
2635             and n.children
2636             and n.children[-1].type == token.COMMA
2637         ):
2638             for ch in n.children:
2639                 if ch.type in STARS:
2640                     return True
2641
2642                 if ch.type == syms.argument:
2643                     for argch in ch.children:
2644                         if argch.type in STARS:
2645                             return True
2646
2647     return False
2648
2649
2650 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2651     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2652
2653     Brackets can be omitted if the entire trailer up to and including
2654     a preceding closing bracket fits in one line.
2655
2656     Yielded sets are cumulative (contain results of previous yields, too).  First
2657     set is empty.
2658     """
2659
2660     omit: Set[LeafID] = set()
2661     yield omit
2662
2663     length = 4 * line.depth
2664     opening_bracket = None
2665     closing_bracket = None
2666     optional_brackets: Set[LeafID] = set()
2667     inner_brackets: Set[LeafID] = set()
2668     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2669         length += leaf_length
2670         if length > line_length:
2671             break
2672
2673         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2674         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2675             break
2676
2677         optional_brackets.discard(id(leaf))
2678         if opening_bracket:
2679             if leaf is opening_bracket:
2680                 opening_bracket = None
2681             elif leaf.type in CLOSING_BRACKETS:
2682                 inner_brackets.add(id(leaf))
2683         elif leaf.type in CLOSING_BRACKETS:
2684             if not leaf.value:
2685                 optional_brackets.add(id(opening_bracket))
2686                 continue
2687
2688             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2689                 # Empty brackets would fail a split so treat them as "inner"
2690                 # brackets (e.g. only add them to the `omit` set if another
2691                 # pair of brackets was good enough.
2692                 inner_brackets.add(id(leaf))
2693                 continue
2694
2695             opening_bracket = leaf.opening_bracket
2696             if closing_bracket:
2697                 omit.add(id(closing_bracket))
2698                 omit.update(inner_brackets)
2699                 inner_brackets.clear()
2700                 yield omit
2701             closing_bracket = leaf
2702
2703
2704 def get_future_imports(node: Node) -> Set[str]:
2705     """Return a set of __future__ imports in the file."""
2706     imports = set()
2707     for child in node.children:
2708         if child.type != syms.simple_stmt:
2709             break
2710         first_child = child.children[0]
2711         if isinstance(first_child, Leaf):
2712             # Continue looking if we see a docstring; otherwise stop.
2713             if (
2714                 len(child.children) == 2
2715                 and first_child.type == token.STRING
2716                 and child.children[1].type == token.NEWLINE
2717             ):
2718                 continue
2719             else:
2720                 break
2721         elif first_child.type == syms.import_from:
2722             module_name = first_child.children[1]
2723             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2724                 break
2725             for import_from_child in first_child.children[3:]:
2726                 if isinstance(import_from_child, Leaf):
2727                     if import_from_child.type == token.NAME:
2728                         imports.add(import_from_child.value)
2729                 else:
2730                     assert import_from_child.type == syms.import_as_names
2731                     for leaf in import_from_child.children:
2732                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2733                             imports.add(leaf.value)
2734         else:
2735             break
2736     return imports
2737
2738
2739 PYTHON_EXTENSIONS = {".py", ".pyi"}
2740 BLACKLISTED_DIRECTORIES = {
2741     "build",
2742     "buck-out",
2743     "dist",
2744     "_build",
2745     ".git",
2746     ".hg",
2747     ".mypy_cache",
2748     ".tox",
2749     ".venv",
2750 }
2751
2752
2753 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2754     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2755     and have one of the PYTHON_EXTENSIONS.
2756     """
2757     for child in path.iterdir():
2758         if child.is_dir():
2759             if child.name in BLACKLISTED_DIRECTORIES:
2760                 continue
2761
2762             yield from gen_python_files_in_dir(child)
2763
2764         elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
2765             yield child
2766
2767
2768 @dataclass
2769 class Report:
2770     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2771     check: bool = False
2772     quiet: bool = False
2773     change_count: int = 0
2774     same_count: int = 0
2775     failure_count: int = 0
2776
2777     def done(self, src: Path, changed: Changed) -> None:
2778         """Increment the counter for successful reformatting. Write out a message."""
2779         if changed is Changed.YES:
2780             reformatted = "would reformat" if self.check else "reformatted"
2781             if not self.quiet:
2782                 out(f"{reformatted} {src}")
2783             self.change_count += 1
2784         else:
2785             if not self.quiet:
2786                 if changed is Changed.NO:
2787                     msg = f"{src} already well formatted, good job."
2788                 else:
2789                     msg = f"{src} wasn't modified on disk since last run."
2790                 out(msg, bold=False)
2791             self.same_count += 1
2792
2793     def failed(self, src: Path, message: str) -> None:
2794         """Increment the counter for failed reformatting. Write out a message."""
2795         err(f"error: cannot format {src}: {message}")
2796         self.failure_count += 1
2797
2798     @property
2799     def return_code(self) -> int:
2800         """Return the exit code that the app should use.
2801
2802         This considers the current state of changed files and failures:
2803         - if there were any failures, return 123;
2804         - if any files were changed and --check is being used, return 1;
2805         - otherwise return 0.
2806         """
2807         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2808         # 126 we have special returncodes reserved by the shell.
2809         if self.failure_count:
2810             return 123
2811
2812         elif self.change_count and self.check:
2813             return 1
2814
2815         return 0
2816
2817     def __str__(self) -> str:
2818         """Render a color report of the current state.
2819
2820         Use `click.unstyle` to remove colors.
2821         """
2822         if self.check:
2823             reformatted = "would be reformatted"
2824             unchanged = "would be left unchanged"
2825             failed = "would fail to reformat"
2826         else:
2827             reformatted = "reformatted"
2828             unchanged = "left unchanged"
2829             failed = "failed to reformat"
2830         report = []
2831         if self.change_count:
2832             s = "s" if self.change_count > 1 else ""
2833             report.append(
2834                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2835             )
2836         if self.same_count:
2837             s = "s" if self.same_count > 1 else ""
2838             report.append(f"{self.same_count} file{s} {unchanged}")
2839         if self.failure_count:
2840             s = "s" if self.failure_count > 1 else ""
2841             report.append(
2842                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2843             )
2844         return ", ".join(report) + "."
2845
2846
2847 def assert_equivalent(src: str, dst: str) -> None:
2848     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2849
2850     import ast
2851     import traceback
2852
2853     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2854         """Simple visitor generating strings to compare ASTs by content."""
2855         yield f"{'  ' * depth}{node.__class__.__name__}("
2856
2857         for field in sorted(node._fields):
2858             try:
2859                 value = getattr(node, field)
2860             except AttributeError:
2861                 continue
2862
2863             yield f"{'  ' * (depth+1)}{field}="
2864
2865             if isinstance(value, list):
2866                 for item in value:
2867                     if isinstance(item, ast.AST):
2868                         yield from _v(item, depth + 2)
2869
2870             elif isinstance(value, ast.AST):
2871                 yield from _v(value, depth + 2)
2872
2873             else:
2874                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2875
2876         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2877
2878     try:
2879         src_ast = ast.parse(src)
2880     except Exception as exc:
2881         major, minor = sys.version_info[:2]
2882         raise AssertionError(
2883             f"cannot use --safe with this file; failed to parse source file "
2884             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2885             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2886         )
2887
2888     try:
2889         dst_ast = ast.parse(dst)
2890     except Exception as exc:
2891         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2892         raise AssertionError(
2893             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2894             f"Please report a bug on https://github.com/ambv/black/issues.  "
2895             f"This invalid output might be helpful: {log}"
2896         ) from None
2897
2898     src_ast_str = "\n".join(_v(src_ast))
2899     dst_ast_str = "\n".join(_v(dst_ast))
2900     if src_ast_str != dst_ast_str:
2901         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2902         raise AssertionError(
2903             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2904             f"the source.  "
2905             f"Please report a bug on https://github.com/ambv/black/issues.  "
2906             f"This diff might be helpful: {log}"
2907         ) from None
2908
2909
2910 def assert_stable(
2911     src: str, dst: str, line_length: int, is_pyi: bool = False, force_py36: bool = False
2912 ) -> None:
2913     """Raise AssertionError if `dst` reformats differently the second time."""
2914     newdst = format_str(
2915         dst, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
2916     )
2917     if dst != newdst:
2918         log = dump_to_file(
2919             diff(src, dst, "source", "first pass"),
2920             diff(dst, newdst, "first pass", "second pass"),
2921         )
2922         raise AssertionError(
2923             f"INTERNAL ERROR: Black produced different code on the second pass "
2924             f"of the formatter.  "
2925             f"Please report a bug on https://github.com/ambv/black/issues.  "
2926             f"This diff might be helpful: {log}"
2927         ) from None
2928
2929
2930 def dump_to_file(*output: str) -> str:
2931     """Dump `output` to a temporary file. Return path to the file."""
2932     import tempfile
2933
2934     with tempfile.NamedTemporaryFile(
2935         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2936     ) as f:
2937         for lines in output:
2938             f.write(lines)
2939             if lines and lines[-1] != "\n":
2940                 f.write("\n")
2941     return f.name
2942
2943
2944 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2945     """Return a unified diff string between strings `a` and `b`."""
2946     import difflib
2947
2948     a_lines = [line + "\n" for line in a.split("\n")]
2949     b_lines = [line + "\n" for line in b.split("\n")]
2950     return "".join(
2951         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2952     )
2953
2954
2955 def cancel(tasks: Iterable[asyncio.Task]) -> None:
2956     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2957     err("Aborted!")
2958     for task in tasks:
2959         task.cancel()
2960
2961
2962 def shutdown(loop: BaseEventLoop) -> None:
2963     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2964     try:
2965         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2966         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2967         if not to_cancel:
2968             return
2969
2970         for task in to_cancel:
2971             task.cancel()
2972         loop.run_until_complete(
2973             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2974         )
2975     finally:
2976         # `concurrent.futures.Future` objects cannot be cancelled once they
2977         # are already running. There might be some when the `shutdown()` happened.
2978         # Silence their logger's spew about the event loop being closed.
2979         cf_logger = logging.getLogger("concurrent.futures")
2980         cf_logger.setLevel(logging.CRITICAL)
2981         loop.close()
2982
2983
2984 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2985     """Replace `regex` with `replacement` twice on `original`.
2986
2987     This is used by string normalization to perform replaces on
2988     overlapping matches.
2989     """
2990     return regex.sub(replacement, regex.sub(replacement, original))
2991
2992
2993 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
2994     """Like `reversed(enumerate(sequence))` if that were possible."""
2995     index = len(sequence) - 1
2996     for element in reversed(sequence):
2997         yield (index, element)
2998         index -= 1
2999
3000
3001 def enumerate_with_length(
3002     line: Line, reversed: bool = False
3003 ) -> Iterator[Tuple[Index, Leaf, int]]:
3004     """Return an enumeration of leaves with their length.
3005
3006     Stops prematurely on multiline strings and standalone comments.
3007     """
3008     op = cast(
3009         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3010         enumerate_reversed if reversed else enumerate,
3011     )
3012     for index, leaf in op(line.leaves):
3013         length = len(leaf.prefix) + len(leaf.value)
3014         if "\n" in leaf.value:
3015             return  # Multiline strings, we can't continue.
3016
3017         comment: Optional[Leaf]
3018         for comment in line.comments_after(leaf, index):
3019             length += len(comment.value)
3020
3021         yield index, leaf, length
3022
3023
3024 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3025     """Return True if `line` is no longer than `line_length`.
3026
3027     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3028     """
3029     if not line_str:
3030         line_str = str(line).strip("\n")
3031     return (
3032         len(line_str) <= line_length
3033         and "\n" not in line_str  # multiline strings
3034         and not line.contains_standalone_comments()
3035     )
3036
3037
3038 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3039     """Does `line` have a shape safe to reformat without optional parens around it?
3040
3041     Returns True for only a subset of potentially nice looking formattings but
3042     the point is to not return false positives that end up producing lines that
3043     are too long.
3044     """
3045     bt = line.bracket_tracker
3046     if not bt.delimiters:
3047         # Without delimiters the optional parentheses are useless.
3048         return True
3049
3050     max_priority = bt.max_delimiter_priority()
3051     if bt.delimiter_count_with_priority(max_priority) > 1:
3052         # With more than one delimiter of a kind the optional parentheses read better.
3053         return False
3054
3055     if max_priority == DOT_PRIORITY:
3056         # A single stranded method call doesn't require optional parentheses.
3057         return True
3058
3059     assert len(line.leaves) >= 2, "Stranded delimiter"
3060
3061     first = line.leaves[0]
3062     second = line.leaves[1]
3063     penultimate = line.leaves[-2]
3064     last = line.leaves[-1]
3065
3066     # With a single delimiter, omit if the expression starts or ends with
3067     # a bracket.
3068     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3069         remainder = False
3070         length = 4 * line.depth
3071         for _index, leaf, leaf_length in enumerate_with_length(line):
3072             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3073                 remainder = True
3074             if remainder:
3075                 length += leaf_length
3076                 if length > line_length:
3077                     break
3078
3079                 if leaf.type in OPENING_BRACKETS:
3080                     # There are brackets we can further split on.
3081                     remainder = False
3082
3083         else:
3084             # checked the entire string and line length wasn't exceeded
3085             if len(line.leaves) == _index + 1:
3086                 return True
3087
3088         # Note: we are not returning False here because a line might have *both*
3089         # a leading opening bracket and a trailing closing bracket.  If the
3090         # opening bracket doesn't match our rule, maybe the closing will.
3091
3092     if (
3093         last.type == token.RPAR
3094         or last.type == token.RBRACE
3095         or (
3096             # don't use indexing for omitting optional parentheses;
3097             # it looks weird
3098             last.type == token.RSQB
3099             and last.parent
3100             and last.parent.type != syms.trailer
3101         )
3102     ):
3103         if penultimate.type in OPENING_BRACKETS:
3104             # Empty brackets don't help.
3105             return False
3106
3107         if is_multiline_string(first):
3108             # Additional wrapping of a multiline string in this situation is
3109             # unnecessary.
3110             return True
3111
3112         length = 4 * line.depth
3113         seen_other_brackets = False
3114         for _index, leaf, leaf_length in enumerate_with_length(line):
3115             length += leaf_length
3116             if leaf is last.opening_bracket:
3117                 if seen_other_brackets or length <= line_length:
3118                     return True
3119
3120             elif leaf.type in OPENING_BRACKETS:
3121                 # There are brackets we can further split on.
3122                 seen_other_brackets = True
3123
3124     return False
3125
3126
3127 def get_cache_file(line_length: int, pyi: bool = False, py36: bool = False) -> Path:
3128     return (
3129         CACHE_DIR
3130         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3131     )
3132
3133
3134 def read_cache(line_length: int, pyi: bool = False, py36: bool = False) -> Cache:
3135     """Read the cache if it exists and is well formed.
3136
3137     If it is not well formed, the call to write_cache later should resolve the issue.
3138     """
3139     cache_file = get_cache_file(line_length, pyi, py36)
3140     if not cache_file.exists():
3141         return {}
3142
3143     with cache_file.open("rb") as fobj:
3144         try:
3145             cache: Cache = pickle.load(fobj)
3146         except pickle.UnpicklingError:
3147             return {}
3148
3149     return cache
3150
3151
3152 def get_cache_info(path: Path) -> CacheInfo:
3153     """Return the information used to check if a file is already formatted or not."""
3154     stat = path.stat()
3155     return stat.st_mtime, stat.st_size
3156
3157
3158 def filter_cached(
3159     cache: Cache, sources: Iterable[Path]
3160 ) -> Tuple[List[Path], List[Path]]:
3161     """Split a list of paths into two.
3162
3163     The first list contains paths of files that modified on disk or are not in the
3164     cache. The other list contains paths to non-modified files.
3165     """
3166     todo, done = [], []
3167     for src in sources:
3168         src = src.resolve()
3169         if cache.get(src) != get_cache_info(src):
3170             todo.append(src)
3171         else:
3172             done.append(src)
3173     return todo, done
3174
3175
3176 def write_cache(
3177     cache: Cache,
3178     sources: List[Path],
3179     line_length: int,
3180     pyi: bool = False,
3181     py36: bool = False,
3182 ) -> None:
3183     """Update the cache file."""
3184     cache_file = get_cache_file(line_length, pyi, py36)
3185     try:
3186         if not CACHE_DIR.exists():
3187             CACHE_DIR.mkdir(parents=True)
3188         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3189         with cache_file.open("wb") as fobj:
3190             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3191     except OSError:
3192         pass
3193
3194
3195 if __name__ == "__main__":
3196     main()