]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Class new line between docstrings / vars / methods (#219)
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33     cast,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46
47 __version__ = "18.5b0"
48 DEFAULT_LINE_LENGTH = 88
49 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
50
51
52 # types
53 FileContent = str
54 Encoding = str
55 Depth = int
56 NodeType = int
57 LeafID = int
58 Priority = int
59 Index = int
60 LN = Union[Leaf, Node]
61 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
62 Timestamp = float
63 FileSize = int
64 CacheInfo = Tuple[Timestamp, FileSize]
65 Cache = Dict[Path, CacheInfo]
66 out = partial(click.secho, bold=True, err=True)
67 err = partial(click.secho, fg="red", err=True)
68
69 pygram.initialize(CACHE_DIR)
70 syms = pygram.python_symbols
71
72
73 class NothingChanged(UserWarning):
74     """Raised by :func:`format_file` when reformatted code is the same as source."""
75
76
77 class CannotSplit(Exception):
78     """A readable split that fits the allotted line length is impossible.
79
80     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
81     :func:`delimiter_split`.
82     """
83
84
85 class FormatError(Exception):
86     """Base exception for `# fmt: on` and `# fmt: off` handling.
87
88     It holds the number of bytes of the prefix consumed before the format
89     control comment appeared.
90     """
91
92     def __init__(self, consumed: int) -> None:
93         super().__init__(consumed)
94         self.consumed = consumed
95
96     def trim_prefix(self, leaf: Leaf) -> None:
97         leaf.prefix = leaf.prefix[self.consumed :]
98
99     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
100         """Returns a new Leaf from the consumed part of the prefix."""
101         unformatted_prefix = leaf.prefix[: self.consumed]
102         return Leaf(token.NEWLINE, unformatted_prefix)
103
104
105 class FormatOn(FormatError):
106     """Found a comment like `# fmt: on` in the file."""
107
108
109 class FormatOff(FormatError):
110     """Found a comment like `# fmt: off` in the file."""
111
112
113 class WriteBack(Enum):
114     NO = 0
115     YES = 1
116     DIFF = 2
117
118
119 class Changed(Enum):
120     NO = 0
121     CACHED = 1
122     YES = 2
123
124
125 @click.command()
126 @click.option(
127     "-l",
128     "--line-length",
129     type=int,
130     default=DEFAULT_LINE_LENGTH,
131     help="How many character per line to allow.",
132     show_default=True,
133 )
134 @click.option(
135     "--check",
136     is_flag=True,
137     help=(
138         "Don't write the files back, just return the status.  Return code 0 "
139         "means nothing would change.  Return code 1 means some files would be "
140         "reformatted.  Return code 123 means there was an internal error."
141     ),
142 )
143 @click.option(
144     "--diff",
145     is_flag=True,
146     help="Don't write the files back, just output a diff for each file on stdout.",
147 )
148 @click.option(
149     "--fast/--safe",
150     is_flag=True,
151     help="If --fast given, skip temporary sanity checks. [default: --safe]",
152 )
153 @click.option(
154     "-q",
155     "--quiet",
156     is_flag=True,
157     help=(
158         "Don't emit non-error messages to stderr. Errors are still emitted, "
159         "silence those with 2>/dev/null."
160     ),
161 )
162 @click.option(
163     "--pyi",
164     is_flag=True,
165     help=(
166         "Consider all input files typing stubs regardless of file extension "
167         "(useful when piping source on standard input)."
168     ),
169 )
170 @click.option(
171     "--py36",
172     is_flag=True,
173     help=(
174         "Allow using Python 3.6-only syntax on all input files.  This will put "
175         "trailing commas in function signatures and calls also after *args and "
176         "**kwargs.  [default: per-file auto-detection]"
177     ),
178 )
179 @click.version_option(version=__version__)
180 @click.argument(
181     "src",
182     nargs=-1,
183     type=click.Path(
184         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
185     ),
186 )
187 @click.pass_context
188 def main(
189     ctx: click.Context,
190     line_length: int,
191     check: bool,
192     diff: bool,
193     fast: bool,
194     pyi: bool,
195     py36: bool,
196     quiet: bool,
197     src: List[str],
198 ) -> None:
199     """The uncompromising code formatter."""
200     sources: List[Path] = []
201     for s in src:
202         p = Path(s)
203         if p.is_dir():
204             sources.extend(gen_python_files_in_dir(p))
205         elif p.is_file():
206             # if a file was explicitly given, we don't care about its extension
207             sources.append(p)
208         elif s == "-":
209             sources.append(Path("-"))
210         else:
211             err(f"invalid path: {s}")
212
213     if check and not diff:
214         write_back = WriteBack.NO
215     elif diff:
216         write_back = WriteBack.DIFF
217     else:
218         write_back = WriteBack.YES
219     report = Report(check=check, quiet=quiet)
220     if len(sources) == 0:
221         out("No paths given. Nothing to do 😴")
222         ctx.exit(0)
223         return
224
225     elif len(sources) == 1:
226         reformat_one(
227             src=sources[0],
228             line_length=line_length,
229             fast=fast,
230             pyi=pyi,
231             py36=py36,
232             write_back=write_back,
233             report=report,
234         )
235     else:
236         loop = asyncio.get_event_loop()
237         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
238         try:
239             loop.run_until_complete(
240                 schedule_formatting(
241                     sources=sources,
242                     line_length=line_length,
243                     fast=fast,
244                     pyi=pyi,
245                     py36=py36,
246                     write_back=write_back,
247                     report=report,
248                     loop=loop,
249                     executor=executor,
250                 )
251             )
252         finally:
253             shutdown(loop)
254         if not quiet:
255             out("All done! ✨ 🍰 ✨")
256             click.echo(str(report))
257     ctx.exit(report.return_code)
258
259
260 def reformat_one(
261     src: Path,
262     line_length: int,
263     fast: bool,
264     pyi: bool,
265     py36: bool,
266     write_back: WriteBack,
267     report: "Report",
268 ) -> None:
269     """Reformat a single file under `src` without spawning child processes.
270
271     If `quiet` is True, non-error messages are not output. `line_length`,
272     `write_back`, `fast` and `pyi` options are passed to
273     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
274     """
275     try:
276         changed = Changed.NO
277         if not src.is_file() and str(src) == "-":
278             if format_stdin_to_stdout(
279                 line_length=line_length,
280                 fast=fast,
281                 is_pyi=pyi,
282                 force_py36=py36,
283                 write_back=write_back,
284             ):
285                 changed = Changed.YES
286         else:
287             cache: Cache = {}
288             if write_back != WriteBack.DIFF:
289                 cache = read_cache(line_length, pyi, py36)
290                 src = src.resolve()
291                 if src in cache and cache[src] == get_cache_info(src):
292                     changed = Changed.CACHED
293             if changed is not Changed.CACHED and format_file_in_place(
294                 src,
295                 line_length=line_length,
296                 fast=fast,
297                 force_pyi=pyi,
298                 force_py36=py36,
299                 write_back=write_back,
300             ):
301                 changed = Changed.YES
302             if write_back == WriteBack.YES and changed is not Changed.NO:
303                 write_cache(cache, [src], line_length, pyi, py36)
304         report.done(src, changed)
305     except Exception as exc:
306         report.failed(src, str(exc))
307
308
309 async def schedule_formatting(
310     sources: List[Path],
311     line_length: int,
312     fast: bool,
313     pyi: bool,
314     py36: bool,
315     write_back: WriteBack,
316     report: "Report",
317     loop: BaseEventLoop,
318     executor: Executor,
319 ) -> None:
320     """Run formatting of `sources` in parallel using the provided `executor`.
321
322     (Use ProcessPoolExecutors for actual parallelism.)
323
324     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
325     :func:`format_file_in_place`.
326     """
327     cache: Cache = {}
328     if write_back != WriteBack.DIFF:
329         cache = read_cache(line_length, pyi, py36)
330         sources, cached = filter_cached(cache, sources)
331         for src in cached:
332             report.done(src, Changed.CACHED)
333     cancelled = []
334     formatted = []
335     if sources:
336         lock = None
337         if write_back == WriteBack.DIFF:
338             # For diff output, we need locks to ensure we don't interleave output
339             # from different processes.
340             manager = Manager()
341             lock = manager.Lock()
342         tasks = {
343             loop.run_in_executor(
344                 executor,
345                 format_file_in_place,
346                 src,
347                 line_length,
348                 fast,
349                 pyi,
350                 py36,
351                 write_back,
352                 lock,
353             ): src
354             for src in sorted(sources)
355         }
356         pending: Iterable[asyncio.Task] = tasks.keys()
357         try:
358             loop.add_signal_handler(signal.SIGINT, cancel, pending)
359             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
360         except NotImplementedError:
361             # There are no good alternatives for these on Windows
362             pass
363         while pending:
364             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
365             for task in done:
366                 src = tasks.pop(task)
367                 if task.cancelled():
368                     cancelled.append(task)
369                 elif task.exception():
370                     report.failed(src, str(task.exception()))
371                 else:
372                     formatted.append(src)
373                     report.done(src, Changed.YES if task.result() else Changed.NO)
374     if cancelled:
375         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
376     if write_back == WriteBack.YES and formatted:
377         write_cache(cache, formatted, line_length, pyi, py36)
378
379
380 def format_file_in_place(
381     src: Path,
382     line_length: int,
383     fast: bool,
384     force_pyi: bool = False,
385     force_py36: bool = False,
386     write_back: WriteBack = WriteBack.NO,
387     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
388 ) -> bool:
389     """Format file under `src` path. Return True if changed.
390
391     If `write_back` is True, write reformatted code back to stdout.
392     `line_length` and `fast` options are passed to :func:`format_file_contents`.
393     """
394     is_pyi = force_pyi or src.suffix == ".pyi"
395
396     with tokenize.open(src) as src_buffer:
397         src_contents = src_buffer.read()
398     try:
399         dst_contents = format_file_contents(
400             src_contents,
401             line_length=line_length,
402             fast=fast,
403             is_pyi=is_pyi,
404             force_py36=force_py36,
405         )
406     except NothingChanged:
407         return False
408
409     if write_back == write_back.YES:
410         with open(src, "w", encoding=src_buffer.encoding) as f:
411             f.write(dst_contents)
412     elif write_back == write_back.DIFF:
413         src_name = f"{src}  (original)"
414         dst_name = f"{src}  (formatted)"
415         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
416         if lock:
417             lock.acquire()
418         try:
419             sys.stdout.write(diff_contents)
420         finally:
421             if lock:
422                 lock.release()
423     return True
424
425
426 def format_stdin_to_stdout(
427     line_length: int,
428     fast: bool,
429     is_pyi: bool = False,
430     force_py36: bool = False,
431     write_back: WriteBack = WriteBack.NO,
432 ) -> bool:
433     """Format file on stdin. Return True if changed.
434
435     If `write_back` is True, write reformatted code back to stdout.
436     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
437     :func:`format_file_contents`.
438     """
439     src = sys.stdin.read()
440     dst = src
441     try:
442         dst = format_file_contents(
443             src,
444             line_length=line_length,
445             fast=fast,
446             is_pyi=is_pyi,
447             force_py36=force_py36,
448         )
449         return True
450
451     except NothingChanged:
452         return False
453
454     finally:
455         if write_back == WriteBack.YES:
456             sys.stdout.write(dst)
457         elif write_back == WriteBack.DIFF:
458             src_name = "<stdin>  (original)"
459             dst_name = "<stdin>  (formatted)"
460             sys.stdout.write(diff(src, dst, src_name, dst_name))
461
462
463 def format_file_contents(
464     src_contents: str,
465     *,
466     line_length: int,
467     fast: bool,
468     is_pyi: bool = False,
469     force_py36: bool = False,
470 ) -> FileContent:
471     """Reformat contents a file and return new contents.
472
473     If `fast` is False, additionally confirm that the reformatted code is
474     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
475     `line_length` is passed to :func:`format_str`.
476     """
477     if src_contents.strip() == "":
478         raise NothingChanged
479
480     dst_contents = format_str(
481         src_contents, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
482     )
483     if src_contents == dst_contents:
484         raise NothingChanged
485
486     if not fast:
487         assert_equivalent(src_contents, dst_contents)
488         assert_stable(
489             src_contents,
490             dst_contents,
491             line_length=line_length,
492             is_pyi=is_pyi,
493             force_py36=force_py36,
494         )
495     return dst_contents
496
497
498 def format_str(
499     src_contents: str,
500     line_length: int,
501     *,
502     is_pyi: bool = False,
503     force_py36: bool = False,
504 ) -> FileContent:
505     """Reformat a string and return new contents.
506
507     `line_length` determines how many characters per line are allowed.
508     """
509     src_node = lib2to3_parse(src_contents)
510     dst_contents = ""
511     future_imports = get_future_imports(src_node)
512     elt = EmptyLineTracker(is_pyi=is_pyi)
513     py36 = force_py36 or is_python36(src_node)
514     lines = LineGenerator(
515         remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
516     )
517     empty_line = Line()
518     after = 0
519     for current_line in lines.visit(src_node):
520         for _ in range(after):
521             dst_contents += str(empty_line)
522         before, after = elt.maybe_empty_lines(current_line)
523         for _ in range(before):
524             dst_contents += str(empty_line)
525         for line in split_line(current_line, line_length=line_length, py36=py36):
526             dst_contents += str(line)
527     return dst_contents
528
529
530 GRAMMARS = [
531     pygram.python_grammar_no_print_statement_no_exec_statement,
532     pygram.python_grammar_no_print_statement,
533     pygram.python_grammar,
534 ]
535
536
537 def lib2to3_parse(src_txt: str) -> Node:
538     """Given a string with source, return the lib2to3 Node."""
539     grammar = pygram.python_grammar_no_print_statement
540     if src_txt[-1] != "\n":
541         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
542         src_txt += nl
543     for grammar in GRAMMARS:
544         drv = driver.Driver(grammar, pytree.convert)
545         try:
546             result = drv.parse_string(src_txt, True)
547             break
548
549         except ParseError as pe:
550             lineno, column = pe.context[1]
551             lines = src_txt.splitlines()
552             try:
553                 faulty_line = lines[lineno - 1]
554             except IndexError:
555                 faulty_line = "<line number missing in source>"
556             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
557     else:
558         raise exc from None
559
560     if isinstance(result, Leaf):
561         result = Node(syms.file_input, [result])
562     return result
563
564
565 def lib2to3_unparse(node: Node) -> str:
566     """Given a lib2to3 node, return its string representation."""
567     code = str(node)
568     return code
569
570
571 T = TypeVar("T")
572
573
574 class Visitor(Generic[T]):
575     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
576
577     def visit(self, node: LN) -> Iterator[T]:
578         """Main method to visit `node` and its children.
579
580         It tries to find a `visit_*()` method for the given `node.type`, like
581         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
582         If no dedicated `visit_*()` method is found, chooses `visit_default()`
583         instead.
584
585         Then yields objects of type `T` from the selected visitor.
586         """
587         if node.type < 256:
588             name = token.tok_name[node.type]
589         else:
590             name = type_repr(node.type)
591         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
592
593     def visit_default(self, node: LN) -> Iterator[T]:
594         """Default `visit_*()` implementation. Recurses to children of `node`."""
595         if isinstance(node, Node):
596             for child in node.children:
597                 yield from self.visit(child)
598
599
600 @dataclass
601 class DebugVisitor(Visitor[T]):
602     tree_depth: int = 0
603
604     def visit_default(self, node: LN) -> Iterator[T]:
605         indent = " " * (2 * self.tree_depth)
606         if isinstance(node, Node):
607             _type = type_repr(node.type)
608             out(f"{indent}{_type}", fg="yellow")
609             self.tree_depth += 1
610             for child in node.children:
611                 yield from self.visit(child)
612
613             self.tree_depth -= 1
614             out(f"{indent}/{_type}", fg="yellow", bold=False)
615         else:
616             _type = token.tok_name.get(node.type, str(node.type))
617             out(f"{indent}{_type}", fg="blue", nl=False)
618             if node.prefix:
619                 # We don't have to handle prefixes for `Node` objects since
620                 # that delegates to the first child anyway.
621                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
622             out(f" {node.value!r}", fg="blue", bold=False)
623
624     @classmethod
625     def show(cls, code: str) -> None:
626         """Pretty-print the lib2to3 AST of a given string of `code`.
627
628         Convenience method for debugging.
629         """
630         v: DebugVisitor[None] = DebugVisitor()
631         list(v.visit(lib2to3_parse(code)))
632
633
634 KEYWORDS = set(keyword.kwlist)
635 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
636 FLOW_CONTROL = {"return", "raise", "break", "continue"}
637 STATEMENT = {
638     syms.if_stmt,
639     syms.while_stmt,
640     syms.for_stmt,
641     syms.try_stmt,
642     syms.except_clause,
643     syms.with_stmt,
644     syms.funcdef,
645     syms.classdef,
646 }
647 STANDALONE_COMMENT = 153
648 LOGIC_OPERATORS = {"and", "or"}
649 COMPARATORS = {
650     token.LESS,
651     token.GREATER,
652     token.EQEQUAL,
653     token.NOTEQUAL,
654     token.LESSEQUAL,
655     token.GREATEREQUAL,
656 }
657 MATH_OPERATORS = {
658     token.VBAR,
659     token.CIRCUMFLEX,
660     token.AMPER,
661     token.LEFTSHIFT,
662     token.RIGHTSHIFT,
663     token.PLUS,
664     token.MINUS,
665     token.STAR,
666     token.SLASH,
667     token.DOUBLESLASH,
668     token.PERCENT,
669     token.AT,
670     token.TILDE,
671     token.DOUBLESTAR,
672 }
673 STARS = {token.STAR, token.DOUBLESTAR}
674 VARARGS_PARENTS = {
675     syms.arglist,
676     syms.argument,  # double star in arglist
677     syms.trailer,  # single argument to call
678     syms.typedargslist,
679     syms.varargslist,  # lambdas
680 }
681 UNPACKING_PARENTS = {
682     syms.atom,  # single element of a list or set literal
683     syms.dictsetmaker,
684     syms.listmaker,
685     syms.testlist_gexp,
686 }
687 TEST_DESCENDANTS = {
688     syms.test,
689     syms.lambdef,
690     syms.or_test,
691     syms.and_test,
692     syms.not_test,
693     syms.comparison,
694     syms.star_expr,
695     syms.expr,
696     syms.xor_expr,
697     syms.and_expr,
698     syms.shift_expr,
699     syms.arith_expr,
700     syms.trailer,
701     syms.term,
702     syms.power,
703 }
704 ASSIGNMENTS = {
705     "=",
706     "+=",
707     "-=",
708     "*=",
709     "@=",
710     "/=",
711     "%=",
712     "&=",
713     "|=",
714     "^=",
715     "<<=",
716     ">>=",
717     "**=",
718     "//=",
719 }
720 COMPREHENSION_PRIORITY = 20
721 COMMA_PRIORITY = 18
722 TERNARY_PRIORITY = 16
723 LOGIC_PRIORITY = 14
724 STRING_PRIORITY = 12
725 COMPARATOR_PRIORITY = 10
726 MATH_PRIORITIES = {
727     token.VBAR: 9,
728     token.CIRCUMFLEX: 8,
729     token.AMPER: 7,
730     token.LEFTSHIFT: 6,
731     token.RIGHTSHIFT: 6,
732     token.PLUS: 5,
733     token.MINUS: 5,
734     token.STAR: 4,
735     token.SLASH: 4,
736     token.DOUBLESLASH: 4,
737     token.PERCENT: 4,
738     token.AT: 4,
739     token.TILDE: 3,
740     token.DOUBLESTAR: 2,
741 }
742 DOT_PRIORITY = 1
743
744
745 @dataclass
746 class BracketTracker:
747     """Keeps track of brackets on a line."""
748
749     depth: int = 0
750     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
751     delimiters: Dict[LeafID, Priority] = Factory(dict)
752     previous: Optional[Leaf] = None
753     _for_loop_variable: int = 0
754     _lambda_arguments: int = 0
755
756     def mark(self, leaf: Leaf) -> None:
757         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
758
759         All leaves receive an int `bracket_depth` field that stores how deep
760         within brackets a given leaf is. 0 means there are no enclosing brackets
761         that started on this line.
762
763         If a leaf is itself a closing bracket, it receives an `opening_bracket`
764         field that it forms a pair with. This is a one-directional link to
765         avoid reference cycles.
766
767         If a leaf is a delimiter (a token on which Black can split the line if
768         needed) and it's on depth 0, its `id()` is stored in the tracker's
769         `delimiters` field.
770         """
771         if leaf.type == token.COMMENT:
772             return
773
774         self.maybe_decrement_after_for_loop_variable(leaf)
775         self.maybe_decrement_after_lambda_arguments(leaf)
776         if leaf.type in CLOSING_BRACKETS:
777             self.depth -= 1
778             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
779             leaf.opening_bracket = opening_bracket
780         leaf.bracket_depth = self.depth
781         if self.depth == 0:
782             delim = is_split_before_delimiter(leaf, self.previous)
783             if delim and self.previous is not None:
784                 self.delimiters[id(self.previous)] = delim
785             else:
786                 delim = is_split_after_delimiter(leaf, self.previous)
787                 if delim:
788                     self.delimiters[id(leaf)] = delim
789         if leaf.type in OPENING_BRACKETS:
790             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
791             self.depth += 1
792         self.previous = leaf
793         self.maybe_increment_lambda_arguments(leaf)
794         self.maybe_increment_for_loop_variable(leaf)
795
796     def any_open_brackets(self) -> bool:
797         """Return True if there is an yet unmatched open bracket on the line."""
798         return bool(self.bracket_match)
799
800     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
801         """Return the highest priority of a delimiter found on the line.
802
803         Values are consistent with what `is_split_*_delimiter()` return.
804         Raises ValueError on no delimiters.
805         """
806         return max(v for k, v in self.delimiters.items() if k not in exclude)
807
808     def delimiter_count_with_priority(self, priority: int = 0) -> int:
809         """Return the number of delimiters with the given `priority`.
810
811         If no `priority` is passed, defaults to max priority on the line.
812         """
813         if not self.delimiters:
814             return 0
815
816         priority = priority or self.max_delimiter_priority()
817         return sum(1 for p in self.delimiters.values() if p == priority)
818
819     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
820         """In a for loop, or comprehension, the variables are often unpacks.
821
822         To avoid splitting on the comma in this situation, increase the depth of
823         tokens between `for` and `in`.
824         """
825         if leaf.type == token.NAME and leaf.value == "for":
826             self.depth += 1
827             self._for_loop_variable += 1
828             return True
829
830         return False
831
832     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
833         """See `maybe_increment_for_loop_variable` above for explanation."""
834         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
835             self.depth -= 1
836             self._for_loop_variable -= 1
837             return True
838
839         return False
840
841     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
842         """In a lambda expression, there might be more than one argument.
843
844         To avoid splitting on the comma in this situation, increase the depth of
845         tokens between `lambda` and `:`.
846         """
847         if leaf.type == token.NAME and leaf.value == "lambda":
848             self.depth += 1
849             self._lambda_arguments += 1
850             return True
851
852         return False
853
854     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
855         """See `maybe_increment_lambda_arguments` above for explanation."""
856         if self._lambda_arguments and leaf.type == token.COLON:
857             self.depth -= 1
858             self._lambda_arguments -= 1
859             return True
860
861         return False
862
863     def get_open_lsqb(self) -> Optional[Leaf]:
864         """Return the most recent opening square bracket (if any)."""
865         return self.bracket_match.get((self.depth - 1, token.RSQB))
866
867
868 @dataclass
869 class Line:
870     """Holds leaves and comments. Can be printed with `str(line)`."""
871
872     depth: int = 0
873     leaves: List[Leaf] = Factory(list)
874     comments: List[Tuple[Index, Leaf]] = Factory(list)
875     bracket_tracker: BracketTracker = Factory(BracketTracker)
876     inside_brackets: bool = False
877     should_explode: bool = False
878
879     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
880         """Add a new `leaf` to the end of the line.
881
882         Unless `preformatted` is True, the `leaf` will receive a new consistent
883         whitespace prefix and metadata applied by :class:`BracketTracker`.
884         Trailing commas are maybe removed, unpacked for loop variables are
885         demoted from being delimiters.
886
887         Inline comments are put aside.
888         """
889         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
890         if not has_value:
891             return
892
893         if token.COLON == leaf.type and self.is_class_paren_empty:
894             del self.leaves[-2:]
895         if self.leaves and not preformatted:
896             # Note: at this point leaf.prefix should be empty except for
897             # imports, for which we only preserve newlines.
898             leaf.prefix += whitespace(
899                 leaf, complex_subscript=self.is_complex_subscript(leaf)
900             )
901         if self.inside_brackets or not preformatted:
902             self.bracket_tracker.mark(leaf)
903             self.maybe_remove_trailing_comma(leaf)
904         if not self.append_comment(leaf):
905             self.leaves.append(leaf)
906
907     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
908         """Like :func:`append()` but disallow invalid standalone comment structure.
909
910         Raises ValueError when any `leaf` is appended after a standalone comment
911         or when a standalone comment is not the first leaf on the line.
912         """
913         if self.bracket_tracker.depth == 0:
914             if self.is_comment:
915                 raise ValueError("cannot append to standalone comments")
916
917             if self.leaves and leaf.type == STANDALONE_COMMENT:
918                 raise ValueError(
919                     "cannot append standalone comments to a populated line"
920                 )
921
922         self.append(leaf, preformatted=preformatted)
923
924     @property
925     def is_comment(self) -> bool:
926         """Is this line a standalone comment?"""
927         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
928
929     @property
930     def is_decorator(self) -> bool:
931         """Is this line a decorator?"""
932         return bool(self) and self.leaves[0].type == token.AT
933
934     @property
935     def is_import(self) -> bool:
936         """Is this an import line?"""
937         return bool(self) and is_import(self.leaves[0])
938
939     @property
940     def is_class(self) -> bool:
941         """Is this line a class definition?"""
942         return (
943             bool(self)
944             and self.leaves[0].type == token.NAME
945             and self.leaves[0].value == "class"
946         )
947
948     @property
949     def is_stub_class(self) -> bool:
950         """Is this line a class definition with a body consisting only of "..."?"""
951         return self.is_class and self.leaves[-3:] == [
952             Leaf(token.DOT, ".") for _ in range(3)
953         ]
954
955     @property
956     def is_def(self) -> bool:
957         """Is this a function definition? (Also returns True for async defs.)"""
958         try:
959             first_leaf = self.leaves[0]
960         except IndexError:
961             return False
962
963         try:
964             second_leaf: Optional[Leaf] = self.leaves[1]
965         except IndexError:
966             second_leaf = None
967         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
968             first_leaf.type == token.ASYNC
969             and second_leaf is not None
970             and second_leaf.type == token.NAME
971             and second_leaf.value == "def"
972         )
973
974     @property
975     def is_class_paren_empty(self) -> bool:
976         """Is this a class with no base classes but using parentheses?
977
978         Those are unnecessary and should be removed.
979         """
980         return (
981             bool(self)
982             and len(self.leaves) == 4
983             and self.is_class
984             and self.leaves[2].type == token.LPAR
985             and self.leaves[2].value == "("
986             and self.leaves[3].type == token.RPAR
987             and self.leaves[3].value == ")"
988         )
989
990     @property
991     def is_triple_quoted_string(self) -> bool:
992         """Is the line a triple quoted docstring?"""
993         return (
994             bool(self)
995             and self.leaves[0].type == token.STRING
996             and (
997                 self.leaves[0].value.startswith('"""')
998                 or self.leaves[0].value.startswith("'''")
999             )
1000         )
1001
1002     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1003         """If so, needs to be split before emitting."""
1004         for leaf in self.leaves:
1005             if leaf.type == STANDALONE_COMMENT:
1006                 if leaf.bracket_depth <= depth_limit:
1007                     return True
1008
1009         return False
1010
1011     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1012         """Remove trailing comma if there is one and it's safe."""
1013         if not (
1014             self.leaves
1015             and self.leaves[-1].type == token.COMMA
1016             and closing.type in CLOSING_BRACKETS
1017         ):
1018             return False
1019
1020         if closing.type == token.RBRACE:
1021             self.remove_trailing_comma()
1022             return True
1023
1024         if closing.type == token.RSQB:
1025             comma = self.leaves[-1]
1026             if comma.parent and comma.parent.type == syms.listmaker:
1027                 self.remove_trailing_comma()
1028                 return True
1029
1030         # For parens let's check if it's safe to remove the comma.
1031         # Imports are always safe.
1032         if self.is_import:
1033             self.remove_trailing_comma()
1034             return True
1035
1036         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1037         # change a tuple into a different type by removing the comma.
1038         depth = closing.bracket_depth + 1
1039         commas = 0
1040         opening = closing.opening_bracket
1041         for _opening_index, leaf in enumerate(self.leaves):
1042             if leaf is opening:
1043                 break
1044
1045         else:
1046             return False
1047
1048         for leaf in self.leaves[_opening_index + 1 :]:
1049             if leaf is closing:
1050                 break
1051
1052             bracket_depth = leaf.bracket_depth
1053             if bracket_depth == depth and leaf.type == token.COMMA:
1054                 commas += 1
1055                 if leaf.parent and leaf.parent.type == syms.arglist:
1056                     commas += 1
1057                     break
1058
1059         if commas > 1:
1060             self.remove_trailing_comma()
1061             return True
1062
1063         return False
1064
1065     def append_comment(self, comment: Leaf) -> bool:
1066         """Add an inline or standalone comment to the line."""
1067         if (
1068             comment.type == STANDALONE_COMMENT
1069             and self.bracket_tracker.any_open_brackets()
1070         ):
1071             comment.prefix = ""
1072             return False
1073
1074         if comment.type != token.COMMENT:
1075             return False
1076
1077         after = len(self.leaves) - 1
1078         if after == -1:
1079             comment.type = STANDALONE_COMMENT
1080             comment.prefix = ""
1081             return False
1082
1083         else:
1084             self.comments.append((after, comment))
1085             return True
1086
1087     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1088         """Generate comments that should appear directly after `leaf`.
1089
1090         Provide a non-negative leaf `_index` to speed up the function.
1091         """
1092         if _index == -1:
1093             for _index, _leaf in enumerate(self.leaves):
1094                 if leaf is _leaf:
1095                     break
1096
1097             else:
1098                 return
1099
1100         for index, comment_after in self.comments:
1101             if _index == index:
1102                 yield comment_after
1103
1104     def remove_trailing_comma(self) -> None:
1105         """Remove the trailing comma and moves the comments attached to it."""
1106         comma_index = len(self.leaves) - 1
1107         for i in range(len(self.comments)):
1108             comment_index, comment = self.comments[i]
1109             if comment_index == comma_index:
1110                 self.comments[i] = (comma_index - 1, comment)
1111         self.leaves.pop()
1112
1113     def is_complex_subscript(self, leaf: Leaf) -> bool:
1114         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1115         open_lsqb = (
1116             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1117         )
1118         if open_lsqb is None:
1119             return False
1120
1121         subscript_start = open_lsqb.next_sibling
1122         if (
1123             isinstance(subscript_start, Node)
1124             and subscript_start.type == syms.subscriptlist
1125         ):
1126             subscript_start = child_towards(subscript_start, leaf)
1127         return subscript_start is not None and any(
1128             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1129         )
1130
1131     def __str__(self) -> str:
1132         """Render the line."""
1133         if not self:
1134             return "\n"
1135
1136         indent = "    " * self.depth
1137         leaves = iter(self.leaves)
1138         first = next(leaves)
1139         res = f"{first.prefix}{indent}{first.value}"
1140         for leaf in leaves:
1141             res += str(leaf)
1142         for _, comment in self.comments:
1143             res += str(comment)
1144         return res + "\n"
1145
1146     def __bool__(self) -> bool:
1147         """Return True if the line has leaves or comments."""
1148         return bool(self.leaves or self.comments)
1149
1150
1151 class UnformattedLines(Line):
1152     """Just like :class:`Line` but stores lines which aren't reformatted."""
1153
1154     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1155         """Just add a new `leaf` to the end of the lines.
1156
1157         The `preformatted` argument is ignored.
1158
1159         Keeps track of indentation `depth`, which is useful when the user
1160         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1161         """
1162         try:
1163             list(generate_comments(leaf))
1164         except FormatOn as f_on:
1165             self.leaves.append(f_on.leaf_from_consumed(leaf))
1166             raise
1167
1168         self.leaves.append(leaf)
1169         if leaf.type == token.INDENT:
1170             self.depth += 1
1171         elif leaf.type == token.DEDENT:
1172             self.depth -= 1
1173
1174     def __str__(self) -> str:
1175         """Render unformatted lines from leaves which were added with `append()`.
1176
1177         `depth` is not used for indentation in this case.
1178         """
1179         if not self:
1180             return "\n"
1181
1182         res = ""
1183         for leaf in self.leaves:
1184             res += str(leaf)
1185         return res
1186
1187     def append_comment(self, comment: Leaf) -> bool:
1188         """Not implemented in this class. Raises `NotImplementedError`."""
1189         raise NotImplementedError("Unformatted lines don't store comments separately.")
1190
1191     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1192         """Does nothing and returns False."""
1193         return False
1194
1195     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1196         """Does nothing and returns False."""
1197         return False
1198
1199
1200 @dataclass
1201 class EmptyLineTracker:
1202     """Provides a stateful method that returns the number of potential extra
1203     empty lines needed before and after the currently processed line.
1204
1205     Note: this tracker works on lines that haven't been split yet.  It assumes
1206     the prefix of the first leaf consists of optional newlines.  Those newlines
1207     are consumed by `maybe_empty_lines()` and included in the computation.
1208     """
1209
1210     is_pyi: bool = False
1211     previous_line: Optional[Line] = None
1212     previous_after: int = 0
1213     previous_defs: List[int] = Factory(list)
1214
1215     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1216         """Return the number of extra empty lines before and after the `current_line`.
1217
1218         This is for separating `def`, `async def` and `class` with extra empty
1219         lines (two on module-level).
1220         """
1221         if isinstance(current_line, UnformattedLines):
1222             return 0, 0
1223
1224         before, after = self._maybe_empty_lines(current_line)
1225         before -= self.previous_after
1226         self.previous_after = after
1227         self.previous_line = current_line
1228         return before, after
1229
1230     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1231         max_allowed = 1
1232         if current_line.depth == 0:
1233             max_allowed = 1 if self.is_pyi else 2
1234         if current_line.leaves:
1235             # Consume the first leaf's extra newlines.
1236             first_leaf = current_line.leaves[0]
1237             before = first_leaf.prefix.count("\n")
1238             before = min(before, max_allowed)
1239             first_leaf.prefix = ""
1240         else:
1241             before = 0
1242         depth = current_line.depth
1243         while self.previous_defs and self.previous_defs[-1] >= depth:
1244             self.previous_defs.pop()
1245             if self.is_pyi:
1246                 before = 0 if depth else 1
1247             else:
1248                 before = 1 if depth else 2
1249         is_decorator = current_line.is_decorator
1250         if is_decorator or current_line.is_def or current_line.is_class:
1251             if not is_decorator:
1252                 self.previous_defs.append(depth)
1253             if self.previous_line is None:
1254                 # Don't insert empty lines before the first line in the file.
1255                 return 0, 0
1256
1257             if self.previous_line.is_decorator:
1258                 return 0, 0
1259
1260             if (
1261                 self.previous_line.is_class
1262                 and self.previous_line.depth != current_line.depth
1263             ):
1264                 return 0, 0
1265
1266             if (
1267                 self.previous_line.is_comment
1268                 and self.previous_line.depth == current_line.depth
1269                 and before == 0
1270             ):
1271                 return 0, 0
1272
1273             if self.is_pyi:
1274                 if self.previous_line.depth > current_line.depth:
1275                     newlines = 1
1276                 elif current_line.is_class or self.previous_line.is_class:
1277                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1278                         newlines = 0
1279                     else:
1280                         newlines = 1
1281                 else:
1282                     newlines = 0
1283             else:
1284                 newlines = 2
1285             if current_line.depth and newlines:
1286                 newlines -= 1
1287             return newlines, 0
1288
1289         if (
1290             self.previous_line
1291             and self.previous_line.is_import
1292             and not current_line.is_import
1293             and depth == self.previous_line.depth
1294         ):
1295             return (before or 1), 0
1296
1297         if (
1298             self.previous_line
1299             and self.previous_line.is_class
1300             and current_line.is_triple_quoted_string
1301         ):
1302             return before, 1
1303
1304         return before, 0
1305
1306
1307 @dataclass
1308 class LineGenerator(Visitor[Line]):
1309     """Generates reformatted Line objects.  Empty lines are not emitted.
1310
1311     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1312     in ways that will no longer stringify to valid Python code on the tree.
1313     """
1314
1315     is_pyi: bool = False
1316     current_line: Line = Factory(Line)
1317     remove_u_prefix: bool = False
1318
1319     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1320         """Generate a line.
1321
1322         If the line is empty, only emit if it makes sense.
1323         If the line is too long, split it first and then generate.
1324
1325         If any lines were generated, set up a new current_line.
1326         """
1327         if not self.current_line:
1328             if self.current_line.__class__ == type:
1329                 self.current_line.depth += indent
1330             else:
1331                 self.current_line = type(depth=self.current_line.depth + indent)
1332             return  # Line is empty, don't emit. Creating a new one unnecessary.
1333
1334         complete_line = self.current_line
1335         self.current_line = type(depth=complete_line.depth + indent)
1336         yield complete_line
1337
1338     def visit(self, node: LN) -> Iterator[Line]:
1339         """Main method to visit `node` and its children.
1340
1341         Yields :class:`Line` objects.
1342         """
1343         if isinstance(self.current_line, UnformattedLines):
1344             # File contained `# fmt: off`
1345             yield from self.visit_unformatted(node)
1346
1347         else:
1348             yield from super().visit(node)
1349
1350     def visit_default(self, node: LN) -> Iterator[Line]:
1351         """Default `visit_*()` implementation. Recurses to children of `node`."""
1352         if isinstance(node, Leaf):
1353             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1354             try:
1355                 for comment in generate_comments(node):
1356                     if any_open_brackets:
1357                         # any comment within brackets is subject to splitting
1358                         self.current_line.append(comment)
1359                     elif comment.type == token.COMMENT:
1360                         # regular trailing comment
1361                         self.current_line.append(comment)
1362                         yield from self.line()
1363
1364                     else:
1365                         # regular standalone comment
1366                         yield from self.line()
1367
1368                         self.current_line.append(comment)
1369                         yield from self.line()
1370
1371             except FormatOff as f_off:
1372                 f_off.trim_prefix(node)
1373                 yield from self.line(type=UnformattedLines)
1374                 yield from self.visit(node)
1375
1376             except FormatOn as f_on:
1377                 # This only happens here if somebody says "fmt: on" multiple
1378                 # times in a row.
1379                 f_on.trim_prefix(node)
1380                 yield from self.visit_default(node)
1381
1382             else:
1383                 normalize_prefix(node, inside_brackets=any_open_brackets)
1384                 if node.type == token.STRING:
1385                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1386                     normalize_string_quotes(node)
1387                 if node.type not in WHITESPACE:
1388                     self.current_line.append(node)
1389         yield from super().visit_default(node)
1390
1391     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1392         """Increase indentation level, maybe yield a line."""
1393         # In blib2to3 INDENT never holds comments.
1394         yield from self.line(+1)
1395         yield from self.visit_default(node)
1396
1397     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1398         """Decrease indentation level, maybe yield a line."""
1399         # The current line might still wait for trailing comments.  At DEDENT time
1400         # there won't be any (they would be prefixes on the preceding NEWLINE).
1401         # Emit the line then.
1402         yield from self.line()
1403
1404         # While DEDENT has no value, its prefix may contain standalone comments
1405         # that belong to the current indentation level.  Get 'em.
1406         yield from self.visit_default(node)
1407
1408         # Finally, emit the dedent.
1409         yield from self.line(-1)
1410
1411     def visit_stmt(
1412         self, node: Node, keywords: Set[str], parens: Set[str]
1413     ) -> Iterator[Line]:
1414         """Visit a statement.
1415
1416         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1417         `def`, `with`, `class`, `assert` and assignments.
1418
1419         The relevant Python language `keywords` for a given statement will be
1420         NAME leaves within it. This methods puts those on a separate line.
1421
1422         `parens` holds a set of string leaf values immediately after which
1423         invisible parens should be put.
1424         """
1425         normalize_invisible_parens(node, parens_after=parens)
1426         for child in node.children:
1427             if child.type == token.NAME and child.value in keywords:  # type: ignore
1428                 yield from self.line()
1429
1430             yield from self.visit(child)
1431
1432     def visit_suite(self, node: Node) -> Iterator[Line]:
1433         """Visit a suite."""
1434         if self.is_pyi and is_stub_suite(node):
1435             yield from self.visit(node.children[2])
1436         else:
1437             yield from self.visit_default(node)
1438
1439     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1440         """Visit a statement without nested statements."""
1441         is_suite_like = node.parent and node.parent.type in STATEMENT
1442         if is_suite_like:
1443             if self.is_pyi and is_stub_body(node):
1444                 yield from self.visit_default(node)
1445             else:
1446                 yield from self.line(+1)
1447                 yield from self.visit_default(node)
1448                 yield from self.line(-1)
1449
1450         else:
1451             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1452                 yield from self.line()
1453             yield from self.visit_default(node)
1454
1455     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1456         """Visit `async def`, `async for`, `async with`."""
1457         yield from self.line()
1458
1459         children = iter(node.children)
1460         for child in children:
1461             yield from self.visit(child)
1462
1463             if child.type == token.ASYNC:
1464                 break
1465
1466         internal_stmt = next(children)
1467         for child in internal_stmt.children:
1468             yield from self.visit(child)
1469
1470     def visit_decorators(self, node: Node) -> Iterator[Line]:
1471         """Visit decorators."""
1472         for child in node.children:
1473             yield from self.line()
1474             yield from self.visit(child)
1475
1476     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1477         """Remove a semicolon and put the other statement on a separate line."""
1478         yield from self.line()
1479
1480     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1481         """End of file. Process outstanding comments and end with a newline."""
1482         yield from self.visit_default(leaf)
1483         yield from self.line()
1484
1485     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1486         """Used when file contained a `# fmt: off`."""
1487         if isinstance(node, Node):
1488             for child in node.children:
1489                 yield from self.visit(child)
1490
1491         else:
1492             try:
1493                 self.current_line.append(node)
1494             except FormatOn as f_on:
1495                 f_on.trim_prefix(node)
1496                 yield from self.line()
1497                 yield from self.visit(node)
1498
1499             if node.type == token.ENDMARKER:
1500                 # somebody decided not to put a final `# fmt: on`
1501                 yield from self.line()
1502
1503     def __attrs_post_init__(self) -> None:
1504         """You are in a twisty little maze of passages."""
1505         v = self.visit_stmt
1506         Ø: Set[str] = set()
1507         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1508         self.visit_if_stmt = partial(
1509             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1510         )
1511         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1512         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1513         self.visit_try_stmt = partial(
1514             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1515         )
1516         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1517         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1518         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1519         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1520         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1521         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1522         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1523         self.visit_async_funcdef = self.visit_async_stmt
1524         self.visit_decorated = self.visit_decorators
1525
1526
1527 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1528 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1529 OPENING_BRACKETS = set(BRACKET.keys())
1530 CLOSING_BRACKETS = set(BRACKET.values())
1531 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1532 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1533
1534
1535 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1536     """Return whitespace prefix if needed for the given `leaf`.
1537
1538     `complex_subscript` signals whether the given leaf is part of a subscription
1539     which has non-trivial arguments, like arithmetic expressions or function calls.
1540     """
1541     NO = ""
1542     SPACE = " "
1543     DOUBLESPACE = "  "
1544     t = leaf.type
1545     p = leaf.parent
1546     v = leaf.value
1547     if t in ALWAYS_NO_SPACE:
1548         return NO
1549
1550     if t == token.COMMENT:
1551         return DOUBLESPACE
1552
1553     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1554     if t == token.COLON and p.type not in {
1555         syms.subscript,
1556         syms.subscriptlist,
1557         syms.sliceop,
1558     }:
1559         return NO
1560
1561     prev = leaf.prev_sibling
1562     if not prev:
1563         prevp = preceding_leaf(p)
1564         if not prevp or prevp.type in OPENING_BRACKETS:
1565             return NO
1566
1567         if t == token.COLON:
1568             if prevp.type == token.COLON:
1569                 return NO
1570
1571             elif prevp.type != token.COMMA and not complex_subscript:
1572                 return NO
1573
1574             return SPACE
1575
1576         if prevp.type == token.EQUAL:
1577             if prevp.parent:
1578                 if prevp.parent.type in {
1579                     syms.arglist,
1580                     syms.argument,
1581                     syms.parameters,
1582                     syms.varargslist,
1583                 }:
1584                     return NO
1585
1586                 elif prevp.parent.type == syms.typedargslist:
1587                     # A bit hacky: if the equal sign has whitespace, it means we
1588                     # previously found it's a typed argument.  So, we're using
1589                     # that, too.
1590                     return prevp.prefix
1591
1592         elif prevp.type in STARS:
1593             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1594                 return NO
1595
1596         elif prevp.type == token.COLON:
1597             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1598                 return SPACE if complex_subscript else NO
1599
1600         elif (
1601             prevp.parent
1602             and prevp.parent.type == syms.factor
1603             and prevp.type in MATH_OPERATORS
1604         ):
1605             return NO
1606
1607         elif (
1608             prevp.type == token.RIGHTSHIFT
1609             and prevp.parent
1610             and prevp.parent.type == syms.shift_expr
1611             and prevp.prev_sibling
1612             and prevp.prev_sibling.type == token.NAME
1613             and prevp.prev_sibling.value == "print"  # type: ignore
1614         ):
1615             # Python 2 print chevron
1616             return NO
1617
1618     elif prev.type in OPENING_BRACKETS:
1619         return NO
1620
1621     if p.type in {syms.parameters, syms.arglist}:
1622         # untyped function signatures or calls
1623         if not prev or prev.type != token.COMMA:
1624             return NO
1625
1626     elif p.type == syms.varargslist:
1627         # lambdas
1628         if prev and prev.type != token.COMMA:
1629             return NO
1630
1631     elif p.type == syms.typedargslist:
1632         # typed function signatures
1633         if not prev:
1634             return NO
1635
1636         if t == token.EQUAL:
1637             if prev.type != syms.tname:
1638                 return NO
1639
1640         elif prev.type == token.EQUAL:
1641             # A bit hacky: if the equal sign has whitespace, it means we
1642             # previously found it's a typed argument.  So, we're using that, too.
1643             return prev.prefix
1644
1645         elif prev.type != token.COMMA:
1646             return NO
1647
1648     elif p.type == syms.tname:
1649         # type names
1650         if not prev:
1651             prevp = preceding_leaf(p)
1652             if not prevp or prevp.type != token.COMMA:
1653                 return NO
1654
1655     elif p.type == syms.trailer:
1656         # attributes and calls
1657         if t == token.LPAR or t == token.RPAR:
1658             return NO
1659
1660         if not prev:
1661             if t == token.DOT:
1662                 prevp = preceding_leaf(p)
1663                 if not prevp or prevp.type != token.NUMBER:
1664                     return NO
1665
1666             elif t == token.LSQB:
1667                 return NO
1668
1669         elif prev.type != token.COMMA:
1670             return NO
1671
1672     elif p.type == syms.argument:
1673         # single argument
1674         if t == token.EQUAL:
1675             return NO
1676
1677         if not prev:
1678             prevp = preceding_leaf(p)
1679             if not prevp or prevp.type == token.LPAR:
1680                 return NO
1681
1682         elif prev.type in {token.EQUAL} | STARS:
1683             return NO
1684
1685     elif p.type == syms.decorator:
1686         # decorators
1687         return NO
1688
1689     elif p.type == syms.dotted_name:
1690         if prev:
1691             return NO
1692
1693         prevp = preceding_leaf(p)
1694         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1695             return NO
1696
1697     elif p.type == syms.classdef:
1698         if t == token.LPAR:
1699             return NO
1700
1701         if prev and prev.type == token.LPAR:
1702             return NO
1703
1704     elif p.type in {syms.subscript, syms.sliceop}:
1705         # indexing
1706         if not prev:
1707             assert p.parent is not None, "subscripts are always parented"
1708             if p.parent.type == syms.subscriptlist:
1709                 return SPACE
1710
1711             return NO
1712
1713         elif not complex_subscript:
1714             return NO
1715
1716     elif p.type == syms.atom:
1717         if prev and t == token.DOT:
1718             # dots, but not the first one.
1719             return NO
1720
1721     elif p.type == syms.dictsetmaker:
1722         # dict unpacking
1723         if prev and prev.type == token.DOUBLESTAR:
1724             return NO
1725
1726     elif p.type in {syms.factor, syms.star_expr}:
1727         # unary ops
1728         if not prev:
1729             prevp = preceding_leaf(p)
1730             if not prevp or prevp.type in OPENING_BRACKETS:
1731                 return NO
1732
1733             prevp_parent = prevp.parent
1734             assert prevp_parent is not None
1735             if prevp.type == token.COLON and prevp_parent.type in {
1736                 syms.subscript,
1737                 syms.sliceop,
1738             }:
1739                 return NO
1740
1741             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1742                 return NO
1743
1744         elif t == token.NAME or t == token.NUMBER:
1745             return NO
1746
1747     elif p.type == syms.import_from:
1748         if t == token.DOT:
1749             if prev and prev.type == token.DOT:
1750                 return NO
1751
1752         elif t == token.NAME:
1753             if v == "import":
1754                 return SPACE
1755
1756             if prev and prev.type == token.DOT:
1757                 return NO
1758
1759     elif p.type == syms.sliceop:
1760         return NO
1761
1762     return SPACE
1763
1764
1765 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1766     """Return the first leaf that precedes `node`, if any."""
1767     while node:
1768         res = node.prev_sibling
1769         if res:
1770             if isinstance(res, Leaf):
1771                 return res
1772
1773             try:
1774                 return list(res.leaves())[-1]
1775
1776             except IndexError:
1777                 return None
1778
1779         node = node.parent
1780     return None
1781
1782
1783 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1784     """Return the child of `ancestor` that contains `descendant`."""
1785     node: Optional[LN] = descendant
1786     while node and node.parent != ancestor:
1787         node = node.parent
1788     return node
1789
1790
1791 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1792     """Return the priority of the `leaf` delimiter, given a line break after it.
1793
1794     The delimiter priorities returned here are from those delimiters that would
1795     cause a line break after themselves.
1796
1797     Higher numbers are higher priority.
1798     """
1799     if leaf.type == token.COMMA:
1800         return COMMA_PRIORITY
1801
1802     return 0
1803
1804
1805 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1806     """Return the priority of the `leaf` delimiter, given a line before after it.
1807
1808     The delimiter priorities returned here are from those delimiters that would
1809     cause a line break before themselves.
1810
1811     Higher numbers are higher priority.
1812     """
1813     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1814         # * and ** might also be MATH_OPERATORS but in this case they are not.
1815         # Don't treat them as a delimiter.
1816         return 0
1817
1818     if (
1819         leaf.type == token.DOT
1820         and leaf.parent
1821         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1822         and (previous is None or previous.type in CLOSING_BRACKETS)
1823     ):
1824         return DOT_PRIORITY
1825
1826     if (
1827         leaf.type in MATH_OPERATORS
1828         and leaf.parent
1829         and leaf.parent.type not in {syms.factor, syms.star_expr}
1830     ):
1831         return MATH_PRIORITIES[leaf.type]
1832
1833     if leaf.type in COMPARATORS:
1834         return COMPARATOR_PRIORITY
1835
1836     if (
1837         leaf.type == token.STRING
1838         and previous is not None
1839         and previous.type == token.STRING
1840     ):
1841         return STRING_PRIORITY
1842
1843     if leaf.type != token.NAME:
1844         return 0
1845
1846     if (
1847         leaf.value == "for"
1848         and leaf.parent
1849         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1850     ):
1851         return COMPREHENSION_PRIORITY
1852
1853     if (
1854         leaf.value == "if"
1855         and leaf.parent
1856         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1857     ):
1858         return COMPREHENSION_PRIORITY
1859
1860     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1861         return TERNARY_PRIORITY
1862
1863     if leaf.value == "is":
1864         return COMPARATOR_PRIORITY
1865
1866     if (
1867         leaf.value == "in"
1868         and leaf.parent
1869         and leaf.parent.type in {syms.comp_op, syms.comparison}
1870         and not (
1871             previous is not None
1872             and previous.type == token.NAME
1873             and previous.value == "not"
1874         )
1875     ):
1876         return COMPARATOR_PRIORITY
1877
1878     if (
1879         leaf.value == "not"
1880         and leaf.parent
1881         and leaf.parent.type == syms.comp_op
1882         and not (
1883             previous is not None
1884             and previous.type == token.NAME
1885             and previous.value == "is"
1886         )
1887     ):
1888         return COMPARATOR_PRIORITY
1889
1890     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1891         return LOGIC_PRIORITY
1892
1893     return 0
1894
1895
1896 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1897     """Clean the prefix of the `leaf` and generate comments from it, if any.
1898
1899     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1900     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1901     move because it does away with modifying the grammar to include all the
1902     possible places in which comments can be placed.
1903
1904     The sad consequence for us though is that comments don't "belong" anywhere.
1905     This is why this function generates simple parentless Leaf objects for
1906     comments.  We simply don't know what the correct parent should be.
1907
1908     No matter though, we can live without this.  We really only need to
1909     differentiate between inline and standalone comments.  The latter don't
1910     share the line with any code.
1911
1912     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1913     are emitted with a fake STANDALONE_COMMENT token identifier.
1914     """
1915     p = leaf.prefix
1916     if not p:
1917         return
1918
1919     if "#" not in p:
1920         return
1921
1922     consumed = 0
1923     nlines = 0
1924     for index, line in enumerate(p.split("\n")):
1925         consumed += len(line) + 1  # adding the length of the split '\n'
1926         line = line.lstrip()
1927         if not line:
1928             nlines += 1
1929         if not line.startswith("#"):
1930             continue
1931
1932         if index == 0 and leaf.type != token.ENDMARKER:
1933             comment_type = token.COMMENT  # simple trailing comment
1934         else:
1935             comment_type = STANDALONE_COMMENT
1936         comment = make_comment(line)
1937         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1938
1939         if comment in {"# fmt: on", "# yapf: enable"}:
1940             raise FormatOn(consumed)
1941
1942         if comment in {"# fmt: off", "# yapf: disable"}:
1943             if comment_type == STANDALONE_COMMENT:
1944                 raise FormatOff(consumed)
1945
1946             prev = preceding_leaf(leaf)
1947             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1948                 raise FormatOff(consumed)
1949
1950         nlines = 0
1951
1952
1953 def make_comment(content: str) -> str:
1954     """Return a consistently formatted comment from the given `content` string.
1955
1956     All comments (except for "##", "#!", "#:") should have a single space between
1957     the hash sign and the content.
1958
1959     If `content` didn't start with a hash sign, one is provided.
1960     """
1961     content = content.rstrip()
1962     if not content:
1963         return "#"
1964
1965     if content[0] == "#":
1966         content = content[1:]
1967     if content and content[0] not in " !:#":
1968         content = " " + content
1969     return "#" + content
1970
1971
1972 def split_line(
1973     line: Line, line_length: int, inner: bool = False, py36: bool = False
1974 ) -> Iterator[Line]:
1975     """Split a `line` into potentially many lines.
1976
1977     They should fit in the allotted `line_length` but might not be able to.
1978     `inner` signifies that there were a pair of brackets somewhere around the
1979     current `line`, possibly transitively. This means we can fallback to splitting
1980     by delimiters if the LHS/RHS don't yield any results.
1981
1982     If `py36` is True, splitting may generate syntax that is only compatible
1983     with Python 3.6 and later.
1984     """
1985     if isinstance(line, UnformattedLines) or line.is_comment:
1986         yield line
1987         return
1988
1989     line_str = str(line).strip("\n")
1990     if not line.should_explode and is_line_short_enough(
1991         line, line_length=line_length, line_str=line_str
1992     ):
1993         yield line
1994         return
1995
1996     split_funcs: List[SplitFunc]
1997     if line.is_def:
1998         split_funcs = [left_hand_split]
1999     else:
2000
2001         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2002             for omit in generate_trailers_to_omit(line, line_length):
2003                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2004                 if is_line_short_enough(lines[0], line_length=line_length):
2005                     yield from lines
2006                     return
2007
2008             # All splits failed, best effort split with no omits.
2009             # This mostly happens to multiline strings that are by definition
2010             # reported as not fitting a single line.
2011             yield from right_hand_split(line, py36)
2012
2013         if line.inside_brackets:
2014             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2015         else:
2016             split_funcs = [rhs]
2017     for split_func in split_funcs:
2018         # We are accumulating lines in `result` because we might want to abort
2019         # mission and return the original line in the end, or attempt a different
2020         # split altogether.
2021         result: List[Line] = []
2022         try:
2023             for l in split_func(line, py36):
2024                 if str(l).strip("\n") == line_str:
2025                     raise CannotSplit("Split function returned an unchanged result")
2026
2027                 result.extend(
2028                     split_line(l, line_length=line_length, inner=True, py36=py36)
2029                 )
2030         except CannotSplit as cs:
2031             continue
2032
2033         else:
2034             yield from result
2035             break
2036
2037     else:
2038         yield line
2039
2040
2041 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2042     """Split line into many lines, starting with the first matching bracket pair.
2043
2044     Note: this usually looks weird, only use this for function definitions.
2045     Prefer RHS otherwise.  This is why this function is not symmetrical with
2046     :func:`right_hand_split` which also handles optional parentheses.
2047     """
2048     head = Line(depth=line.depth)
2049     body = Line(depth=line.depth + 1, inside_brackets=True)
2050     tail = Line(depth=line.depth)
2051     tail_leaves: List[Leaf] = []
2052     body_leaves: List[Leaf] = []
2053     head_leaves: List[Leaf] = []
2054     current_leaves = head_leaves
2055     matching_bracket = None
2056     for leaf in line.leaves:
2057         if (
2058             current_leaves is body_leaves
2059             and leaf.type in CLOSING_BRACKETS
2060             and leaf.opening_bracket is matching_bracket
2061         ):
2062             current_leaves = tail_leaves if body_leaves else head_leaves
2063         current_leaves.append(leaf)
2064         if current_leaves is head_leaves:
2065             if leaf.type in OPENING_BRACKETS:
2066                 matching_bracket = leaf
2067                 current_leaves = body_leaves
2068     # Since body is a new indent level, remove spurious leading whitespace.
2069     if body_leaves:
2070         normalize_prefix(body_leaves[0], inside_brackets=True)
2071     # Build the new lines.
2072     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2073         for leaf in leaves:
2074             result.append(leaf, preformatted=True)
2075             for comment_after in line.comments_after(leaf):
2076                 result.append(comment_after, preformatted=True)
2077     bracket_split_succeeded_or_raise(head, body, tail)
2078     for result in (head, body, tail):
2079         if result:
2080             yield result
2081
2082
2083 def right_hand_split(
2084     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2085 ) -> Iterator[Line]:
2086     """Split line into many lines, starting with the last matching bracket pair.
2087
2088     If the split was by optional parentheses, attempt splitting without them, too.
2089     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2090     this split.
2091
2092     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2093     """
2094     head = Line(depth=line.depth)
2095     body = Line(depth=line.depth + 1, inside_brackets=True)
2096     tail = Line(depth=line.depth)
2097     tail_leaves: List[Leaf] = []
2098     body_leaves: List[Leaf] = []
2099     head_leaves: List[Leaf] = []
2100     current_leaves = tail_leaves
2101     opening_bracket = None
2102     closing_bracket = None
2103     for leaf in reversed(line.leaves):
2104         if current_leaves is body_leaves:
2105             if leaf is opening_bracket:
2106                 current_leaves = head_leaves if body_leaves else tail_leaves
2107         current_leaves.append(leaf)
2108         if current_leaves is tail_leaves:
2109             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2110                 opening_bracket = leaf.opening_bracket
2111                 closing_bracket = leaf
2112                 current_leaves = body_leaves
2113     tail_leaves.reverse()
2114     body_leaves.reverse()
2115     head_leaves.reverse()
2116     # Since body is a new indent level, remove spurious leading whitespace.
2117     if body_leaves:
2118         normalize_prefix(body_leaves[0], inside_brackets=True)
2119     if not head_leaves:
2120         # No `head` means the split failed. Either `tail` has all content or
2121         # the matching `opening_bracket` wasn't available on `line` anymore.
2122         raise CannotSplit("No brackets found")
2123
2124     # Build the new lines.
2125     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2126         for leaf in leaves:
2127             result.append(leaf, preformatted=True)
2128             for comment_after in line.comments_after(leaf):
2129                 result.append(comment_after, preformatted=True)
2130     bracket_split_succeeded_or_raise(head, body, tail)
2131     assert opening_bracket and closing_bracket
2132     if (
2133         # the opening bracket is an optional paren
2134         opening_bracket.type == token.LPAR
2135         and not opening_bracket.value
2136         # the closing bracket is an optional paren
2137         and closing_bracket.type == token.RPAR
2138         and not closing_bracket.value
2139         # there are no standalone comments in the body
2140         and not line.contains_standalone_comments(0)
2141         # and it's not an import (optional parens are the only thing we can split
2142         # on in this case; attempting a split without them is a waste of time)
2143         and not line.is_import
2144     ):
2145         omit = {id(closing_bracket), *omit}
2146         if can_omit_invisible_parens(body, line_length):
2147             try:
2148                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2149                 return
2150             except CannotSplit:
2151                 pass
2152
2153     ensure_visible(opening_bracket)
2154     ensure_visible(closing_bracket)
2155     body.should_explode = should_explode(body, opening_bracket)
2156     for result in (head, body, tail):
2157         if result:
2158             yield result
2159
2160
2161 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2162     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2163
2164     Do nothing otherwise.
2165
2166     A left- or right-hand split is based on a pair of brackets. Content before
2167     (and including) the opening bracket is left on one line, content inside the
2168     brackets is put on a separate line, and finally content starting with and
2169     following the closing bracket is put on a separate line.
2170
2171     Those are called `head`, `body`, and `tail`, respectively. If the split
2172     produced the same line (all content in `head`) or ended up with an empty `body`
2173     and the `tail` is just the closing bracket, then it's considered failed.
2174     """
2175     tail_len = len(str(tail).strip())
2176     if not body:
2177         if tail_len == 0:
2178             raise CannotSplit("Splitting brackets produced the same line")
2179
2180         elif tail_len < 3:
2181             raise CannotSplit(
2182                 f"Splitting brackets on an empty body to save "
2183                 f"{tail_len} characters is not worth it"
2184             )
2185
2186
2187 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2188     """Normalize prefix of the first leaf in every line returned by `split_func`.
2189
2190     This is a decorator over relevant split functions.
2191     """
2192
2193     @wraps(split_func)
2194     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2195         for l in split_func(line, py36):
2196             normalize_prefix(l.leaves[0], inside_brackets=True)
2197             yield l
2198
2199     return split_wrapper
2200
2201
2202 @dont_increase_indentation
2203 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2204     """Split according to delimiters of the highest priority.
2205
2206     If `py36` is True, the split will add trailing commas also in function
2207     signatures that contain `*` and `**`.
2208     """
2209     try:
2210         last_leaf = line.leaves[-1]
2211     except IndexError:
2212         raise CannotSplit("Line empty")
2213
2214     bt = line.bracket_tracker
2215     try:
2216         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2217     except ValueError:
2218         raise CannotSplit("No delimiters found")
2219
2220     if delimiter_priority == DOT_PRIORITY:
2221         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2222             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2223
2224     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2225     lowest_depth = sys.maxsize
2226     trailing_comma_safe = True
2227
2228     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2229         """Append `leaf` to current line or to new line if appending impossible."""
2230         nonlocal current_line
2231         try:
2232             current_line.append_safe(leaf, preformatted=True)
2233         except ValueError as ve:
2234             yield current_line
2235
2236             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2237             current_line.append(leaf)
2238
2239     for index, leaf in enumerate(line.leaves):
2240         yield from append_to_line(leaf)
2241
2242         for comment_after in line.comments_after(leaf, index):
2243             yield from append_to_line(comment_after)
2244
2245         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2246         if leaf.bracket_depth == lowest_depth and is_vararg(
2247             leaf, within=VARARGS_PARENTS
2248         ):
2249             trailing_comma_safe = trailing_comma_safe and py36
2250         leaf_priority = bt.delimiters.get(id(leaf))
2251         if leaf_priority == delimiter_priority:
2252             yield current_line
2253
2254             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2255     if current_line:
2256         if (
2257             trailing_comma_safe
2258             and delimiter_priority == COMMA_PRIORITY
2259             and current_line.leaves[-1].type != token.COMMA
2260             and current_line.leaves[-1].type != STANDALONE_COMMENT
2261         ):
2262             current_line.append(Leaf(token.COMMA, ","))
2263         yield current_line
2264
2265
2266 @dont_increase_indentation
2267 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2268     """Split standalone comments from the rest of the line."""
2269     if not line.contains_standalone_comments(0):
2270         raise CannotSplit("Line does not have any standalone comments")
2271
2272     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2273
2274     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2275         """Append `leaf` to current line or to new line if appending impossible."""
2276         nonlocal current_line
2277         try:
2278             current_line.append_safe(leaf, preformatted=True)
2279         except ValueError as ve:
2280             yield current_line
2281
2282             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2283             current_line.append(leaf)
2284
2285     for index, leaf in enumerate(line.leaves):
2286         yield from append_to_line(leaf)
2287
2288         for comment_after in line.comments_after(leaf, index):
2289             yield from append_to_line(comment_after)
2290
2291     if current_line:
2292         yield current_line
2293
2294
2295 def is_import(leaf: Leaf) -> bool:
2296     """Return True if the given leaf starts an import statement."""
2297     p = leaf.parent
2298     t = leaf.type
2299     v = leaf.value
2300     return bool(
2301         t == token.NAME
2302         and (
2303             (v == "import" and p and p.type == syms.import_name)
2304             or (v == "from" and p and p.type == syms.import_from)
2305         )
2306     )
2307
2308
2309 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2310     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2311     else.
2312
2313     Note: don't use backslashes for formatting or you'll lose your voting rights.
2314     """
2315     if not inside_brackets:
2316         spl = leaf.prefix.split("#")
2317         if "\\" not in spl[0]:
2318             nl_count = spl[-1].count("\n")
2319             if len(spl) > 1:
2320                 nl_count -= 1
2321             leaf.prefix = "\n" * nl_count
2322             return
2323
2324     leaf.prefix = ""
2325
2326
2327 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2328     """Make all string prefixes lowercase.
2329
2330     If remove_u_prefix is given, also removes any u prefix from the string.
2331
2332     Note: Mutates its argument.
2333     """
2334     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2335     assert match is not None, f"failed to match string {leaf.value!r}"
2336     orig_prefix = match.group(1)
2337     new_prefix = orig_prefix.lower()
2338     if remove_u_prefix:
2339         new_prefix = new_prefix.replace("u", "")
2340     leaf.value = f"{new_prefix}{match.group(2)}"
2341
2342
2343 def normalize_string_quotes(leaf: Leaf) -> None:
2344     """Prefer double quotes but only if it doesn't cause more escaping.
2345
2346     Adds or removes backslashes as appropriate. Doesn't parse and fix
2347     strings nested in f-strings (yet).
2348
2349     Note: Mutates its argument.
2350     """
2351     value = leaf.value.lstrip("furbFURB")
2352     if value[:3] == '"""':
2353         return
2354
2355     elif value[:3] == "'''":
2356         orig_quote = "'''"
2357         new_quote = '"""'
2358     elif value[0] == '"':
2359         orig_quote = '"'
2360         new_quote = "'"
2361     else:
2362         orig_quote = "'"
2363         new_quote = '"'
2364     first_quote_pos = leaf.value.find(orig_quote)
2365     if first_quote_pos == -1:
2366         return  # There's an internal error
2367
2368     prefix = leaf.value[:first_quote_pos]
2369     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2370     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2371     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2372     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2373     if "r" in prefix.casefold():
2374         if unescaped_new_quote.search(body):
2375             # There's at least one unescaped new_quote in this raw string
2376             # so converting is impossible
2377             return
2378
2379         # Do not introduce or remove backslashes in raw strings
2380         new_body = body
2381     else:
2382         # remove unnecessary quotes
2383         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2384         if body != new_body:
2385             # Consider the string without unnecessary quotes as the original
2386             body = new_body
2387             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2388         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2389         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2390     if new_quote == '"""' and new_body[-1] == '"':
2391         # edge case:
2392         new_body = new_body[:-1] + '\\"'
2393     orig_escape_count = body.count("\\")
2394     new_escape_count = new_body.count("\\")
2395     if new_escape_count > orig_escape_count:
2396         return  # Do not introduce more escaping
2397
2398     if new_escape_count == orig_escape_count and orig_quote == '"':
2399         return  # Prefer double quotes
2400
2401     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2402
2403
2404 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2405     """Make existing optional parentheses invisible or create new ones.
2406
2407     `parens_after` is a set of string leaf values immeditely after which parens
2408     should be put.
2409
2410     Standardizes on visible parentheses for single-element tuples, and keeps
2411     existing visible parentheses for other tuples and generator expressions.
2412     """
2413     try:
2414         list(generate_comments(node))
2415     except FormatOff:
2416         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2417
2418     check_lpar = False
2419     for index, child in enumerate(list(node.children)):
2420         if check_lpar:
2421             if child.type == syms.atom:
2422                 maybe_make_parens_invisible_in_atom(child)
2423             elif is_one_tuple(child):
2424                 # wrap child in visible parentheses
2425                 lpar = Leaf(token.LPAR, "(")
2426                 rpar = Leaf(token.RPAR, ")")
2427                 child.remove()
2428                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2429             elif node.type == syms.import_from:
2430                 # "import from" nodes store parentheses directly as part of
2431                 # the statement
2432                 if child.type == token.LPAR:
2433                     # make parentheses invisible
2434                     child.value = ""  # type: ignore
2435                     node.children[-1].value = ""  # type: ignore
2436                 elif child.type != token.STAR:
2437                     # insert invisible parentheses
2438                     node.insert_child(index, Leaf(token.LPAR, ""))
2439                     node.append_child(Leaf(token.RPAR, ""))
2440                 break
2441
2442             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2443                 # wrap child in invisible parentheses
2444                 lpar = Leaf(token.LPAR, "")
2445                 rpar = Leaf(token.RPAR, "")
2446                 index = child.remove() or 0
2447                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2448
2449         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2450
2451
2452 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2453     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2454     if (
2455         node.type != syms.atom
2456         or is_empty_tuple(node)
2457         or is_one_tuple(node)
2458         or is_yield(node)
2459         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2460     ):
2461         return False
2462
2463     first = node.children[0]
2464     last = node.children[-1]
2465     if first.type == token.LPAR and last.type == token.RPAR:
2466         # make parentheses invisible
2467         first.value = ""  # type: ignore
2468         last.value = ""  # type: ignore
2469         if len(node.children) > 1:
2470             maybe_make_parens_invisible_in_atom(node.children[1])
2471         return True
2472
2473     return False
2474
2475
2476 def is_empty_tuple(node: LN) -> bool:
2477     """Return True if `node` holds an empty tuple."""
2478     return (
2479         node.type == syms.atom
2480         and len(node.children) == 2
2481         and node.children[0].type == token.LPAR
2482         and node.children[1].type == token.RPAR
2483     )
2484
2485
2486 def is_one_tuple(node: LN) -> bool:
2487     """Return True if `node` holds a tuple with one element, with or without parens."""
2488     if node.type == syms.atom:
2489         if len(node.children) != 3:
2490             return False
2491
2492         lpar, gexp, rpar = node.children
2493         if not (
2494             lpar.type == token.LPAR
2495             and gexp.type == syms.testlist_gexp
2496             and rpar.type == token.RPAR
2497         ):
2498             return False
2499
2500         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2501
2502     return (
2503         node.type in IMPLICIT_TUPLE
2504         and len(node.children) == 2
2505         and node.children[1].type == token.COMMA
2506     )
2507
2508
2509 def is_yield(node: LN) -> bool:
2510     """Return True if `node` holds a `yield` or `yield from` expression."""
2511     if node.type == syms.yield_expr:
2512         return True
2513
2514     if node.type == token.NAME and node.value == "yield":  # type: ignore
2515         return True
2516
2517     if node.type != syms.atom:
2518         return False
2519
2520     if len(node.children) != 3:
2521         return False
2522
2523     lpar, expr, rpar = node.children
2524     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2525         return is_yield(expr)
2526
2527     return False
2528
2529
2530 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2531     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2532
2533     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2534     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2535     extended iterable unpacking (PEP 3132) and additional unpacking
2536     generalizations (PEP 448).
2537     """
2538     if leaf.type not in STARS or not leaf.parent:
2539         return False
2540
2541     p = leaf.parent
2542     if p.type == syms.star_expr:
2543         # Star expressions are also used as assignment targets in extended
2544         # iterable unpacking (PEP 3132).  See what its parent is instead.
2545         if not p.parent:
2546             return False
2547
2548         p = p.parent
2549
2550     return p.type in within
2551
2552
2553 def is_multiline_string(leaf: Leaf) -> bool:
2554     """Return True if `leaf` is a multiline string that actually spans many lines."""
2555     value = leaf.value.lstrip("furbFURB")
2556     return value[:3] in {'"""', "'''"} and "\n" in value
2557
2558
2559 def is_stub_suite(node: Node) -> bool:
2560     """Return True if `node` is a suite with a stub body."""
2561     if (
2562         len(node.children) != 4
2563         or node.children[0].type != token.NEWLINE
2564         or node.children[1].type != token.INDENT
2565         or node.children[3].type != token.DEDENT
2566     ):
2567         return False
2568
2569     return is_stub_body(node.children[2])
2570
2571
2572 def is_stub_body(node: LN) -> bool:
2573     """Return True if `node` is a simple statement containing an ellipsis."""
2574     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2575         return False
2576
2577     if len(node.children) != 2:
2578         return False
2579
2580     child = node.children[0]
2581     return (
2582         child.type == syms.atom
2583         and len(child.children) == 3
2584         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2585     )
2586
2587
2588 def max_delimiter_priority_in_atom(node: LN) -> int:
2589     """Return maximum delimiter priority inside `node`.
2590
2591     This is specific to atoms with contents contained in a pair of parentheses.
2592     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2593     """
2594     if node.type != syms.atom:
2595         return 0
2596
2597     first = node.children[0]
2598     last = node.children[-1]
2599     if not (first.type == token.LPAR and last.type == token.RPAR):
2600         return 0
2601
2602     bt = BracketTracker()
2603     for c in node.children[1:-1]:
2604         if isinstance(c, Leaf):
2605             bt.mark(c)
2606         else:
2607             for leaf in c.leaves():
2608                 bt.mark(leaf)
2609     try:
2610         return bt.max_delimiter_priority()
2611
2612     except ValueError:
2613         return 0
2614
2615
2616 def ensure_visible(leaf: Leaf) -> None:
2617     """Make sure parentheses are visible.
2618
2619     They could be invisible as part of some statements (see
2620     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2621     """
2622     if leaf.type == token.LPAR:
2623         leaf.value = "("
2624     elif leaf.type == token.RPAR:
2625         leaf.value = ")"
2626
2627
2628 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2629     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2630     if not (
2631         opening_bracket.parent
2632         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2633         and opening_bracket.value in "[{("
2634     ):
2635         return False
2636
2637     try:
2638         last_leaf = line.leaves[-1]
2639         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2640         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2641     except (IndexError, ValueError):
2642         return False
2643
2644     return max_priority == COMMA_PRIORITY
2645
2646
2647 def is_python36(node: Node) -> bool:
2648     """Return True if the current file is using Python 3.6+ features.
2649
2650     Currently looking for:
2651     - f-strings; and
2652     - trailing commas after * or ** in function signatures and calls.
2653     """
2654     for n in node.pre_order():
2655         if n.type == token.STRING:
2656             value_head = n.value[:2]  # type: ignore
2657             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2658                 return True
2659
2660         elif (
2661             n.type in {syms.typedargslist, syms.arglist}
2662             and n.children
2663             and n.children[-1].type == token.COMMA
2664         ):
2665             for ch in n.children:
2666                 if ch.type in STARS:
2667                     return True
2668
2669                 if ch.type == syms.argument:
2670                     for argch in ch.children:
2671                         if argch.type in STARS:
2672                             return True
2673
2674     return False
2675
2676
2677 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2678     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2679
2680     Brackets can be omitted if the entire trailer up to and including
2681     a preceding closing bracket fits in one line.
2682
2683     Yielded sets are cumulative (contain results of previous yields, too).  First
2684     set is empty.
2685     """
2686
2687     omit: Set[LeafID] = set()
2688     yield omit
2689
2690     length = 4 * line.depth
2691     opening_bracket = None
2692     closing_bracket = None
2693     optional_brackets: Set[LeafID] = set()
2694     inner_brackets: Set[LeafID] = set()
2695     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2696         length += leaf_length
2697         if length > line_length:
2698             break
2699
2700         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2701         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2702             break
2703
2704         optional_brackets.discard(id(leaf))
2705         if opening_bracket:
2706             if leaf is opening_bracket:
2707                 opening_bracket = None
2708             elif leaf.type in CLOSING_BRACKETS:
2709                 inner_brackets.add(id(leaf))
2710         elif leaf.type in CLOSING_BRACKETS:
2711             if not leaf.value:
2712                 optional_brackets.add(id(opening_bracket))
2713                 continue
2714
2715             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2716                 # Empty brackets would fail a split so treat them as "inner"
2717                 # brackets (e.g. only add them to the `omit` set if another
2718                 # pair of brackets was good enough.
2719                 inner_brackets.add(id(leaf))
2720                 continue
2721
2722             opening_bracket = leaf.opening_bracket
2723             if closing_bracket:
2724                 omit.add(id(closing_bracket))
2725                 omit.update(inner_brackets)
2726                 inner_brackets.clear()
2727                 yield omit
2728             closing_bracket = leaf
2729
2730
2731 def get_future_imports(node: Node) -> Set[str]:
2732     """Return a set of __future__ imports in the file."""
2733     imports = set()
2734     for child in node.children:
2735         if child.type != syms.simple_stmt:
2736             break
2737         first_child = child.children[0]
2738         if isinstance(first_child, Leaf):
2739             # Continue looking if we see a docstring; otherwise stop.
2740             if (
2741                 len(child.children) == 2
2742                 and first_child.type == token.STRING
2743                 and child.children[1].type == token.NEWLINE
2744             ):
2745                 continue
2746             else:
2747                 break
2748         elif first_child.type == syms.import_from:
2749             module_name = first_child.children[1]
2750             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2751                 break
2752             for import_from_child in first_child.children[3:]:
2753                 if isinstance(import_from_child, Leaf):
2754                     if import_from_child.type == token.NAME:
2755                         imports.add(import_from_child.value)
2756                 else:
2757                     assert import_from_child.type == syms.import_as_names
2758                     for leaf in import_from_child.children:
2759                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2760                             imports.add(leaf.value)
2761         else:
2762             break
2763     return imports
2764
2765
2766 PYTHON_EXTENSIONS = {".py", ".pyi"}
2767 BLACKLISTED_DIRECTORIES = {
2768     "build",
2769     "buck-out",
2770     "dist",
2771     "_build",
2772     ".git",
2773     ".hg",
2774     ".mypy_cache",
2775     ".tox",
2776     ".venv",
2777 }
2778
2779
2780 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2781     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2782     and have one of the PYTHON_EXTENSIONS.
2783     """
2784     for child in path.iterdir():
2785         if child.is_dir():
2786             if child.name in BLACKLISTED_DIRECTORIES:
2787                 continue
2788
2789             yield from gen_python_files_in_dir(child)
2790
2791         elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
2792             yield child
2793
2794
2795 @dataclass
2796 class Report:
2797     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2798
2799     check: bool = False
2800     quiet: bool = False
2801     change_count: int = 0
2802     same_count: int = 0
2803     failure_count: int = 0
2804
2805     def done(self, src: Path, changed: Changed) -> None:
2806         """Increment the counter for successful reformatting. Write out a message."""
2807         if changed is Changed.YES:
2808             reformatted = "would reformat" if self.check else "reformatted"
2809             if not self.quiet:
2810                 out(f"{reformatted} {src}")
2811             self.change_count += 1
2812         else:
2813             if not self.quiet:
2814                 if changed is Changed.NO:
2815                     msg = f"{src} already well formatted, good job."
2816                 else:
2817                     msg = f"{src} wasn't modified on disk since last run."
2818                 out(msg, bold=False)
2819             self.same_count += 1
2820
2821     def failed(self, src: Path, message: str) -> None:
2822         """Increment the counter for failed reformatting. Write out a message."""
2823         err(f"error: cannot format {src}: {message}")
2824         self.failure_count += 1
2825
2826     @property
2827     def return_code(self) -> int:
2828         """Return the exit code that the app should use.
2829
2830         This considers the current state of changed files and failures:
2831         - if there were any failures, return 123;
2832         - if any files were changed and --check is being used, return 1;
2833         - otherwise return 0.
2834         """
2835         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2836         # 126 we have special returncodes reserved by the shell.
2837         if self.failure_count:
2838             return 123
2839
2840         elif self.change_count and self.check:
2841             return 1
2842
2843         return 0
2844
2845     def __str__(self) -> str:
2846         """Render a color report of the current state.
2847
2848         Use `click.unstyle` to remove colors.
2849         """
2850         if self.check:
2851             reformatted = "would be reformatted"
2852             unchanged = "would be left unchanged"
2853             failed = "would fail to reformat"
2854         else:
2855             reformatted = "reformatted"
2856             unchanged = "left unchanged"
2857             failed = "failed to reformat"
2858         report = []
2859         if self.change_count:
2860             s = "s" if self.change_count > 1 else ""
2861             report.append(
2862                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2863             )
2864         if self.same_count:
2865             s = "s" if self.same_count > 1 else ""
2866             report.append(f"{self.same_count} file{s} {unchanged}")
2867         if self.failure_count:
2868             s = "s" if self.failure_count > 1 else ""
2869             report.append(
2870                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2871             )
2872         return ", ".join(report) + "."
2873
2874
2875 def assert_equivalent(src: str, dst: str) -> None:
2876     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2877
2878     import ast
2879     import traceback
2880
2881     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2882         """Simple visitor generating strings to compare ASTs by content."""
2883         yield f"{'  ' * depth}{node.__class__.__name__}("
2884
2885         for field in sorted(node._fields):
2886             try:
2887                 value = getattr(node, field)
2888             except AttributeError:
2889                 continue
2890
2891             yield f"{'  ' * (depth+1)}{field}="
2892
2893             if isinstance(value, list):
2894                 for item in value:
2895                     if isinstance(item, ast.AST):
2896                         yield from _v(item, depth + 2)
2897
2898             elif isinstance(value, ast.AST):
2899                 yield from _v(value, depth + 2)
2900
2901             else:
2902                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2903
2904         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2905
2906     try:
2907         src_ast = ast.parse(src)
2908     except Exception as exc:
2909         major, minor = sys.version_info[:2]
2910         raise AssertionError(
2911             f"cannot use --safe with this file; failed to parse source file "
2912             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2913             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2914         )
2915
2916     try:
2917         dst_ast = ast.parse(dst)
2918     except Exception as exc:
2919         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2920         raise AssertionError(
2921             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2922             f"Please report a bug on https://github.com/ambv/black/issues.  "
2923             f"This invalid output might be helpful: {log}"
2924         ) from None
2925
2926     src_ast_str = "\n".join(_v(src_ast))
2927     dst_ast_str = "\n".join(_v(dst_ast))
2928     if src_ast_str != dst_ast_str:
2929         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2930         raise AssertionError(
2931             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2932             f"the source.  "
2933             f"Please report a bug on https://github.com/ambv/black/issues.  "
2934             f"This diff might be helpful: {log}"
2935         ) from None
2936
2937
2938 def assert_stable(
2939     src: str, dst: str, line_length: int, is_pyi: bool = False, force_py36: bool = False
2940 ) -> None:
2941     """Raise AssertionError if `dst` reformats differently the second time."""
2942     newdst = format_str(
2943         dst, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
2944     )
2945     if dst != newdst:
2946         log = dump_to_file(
2947             diff(src, dst, "source", "first pass"),
2948             diff(dst, newdst, "first pass", "second pass"),
2949         )
2950         raise AssertionError(
2951             f"INTERNAL ERROR: Black produced different code on the second pass "
2952             f"of the formatter.  "
2953             f"Please report a bug on https://github.com/ambv/black/issues.  "
2954             f"This diff might be helpful: {log}"
2955         ) from None
2956
2957
2958 def dump_to_file(*output: str) -> str:
2959     """Dump `output` to a temporary file. Return path to the file."""
2960     import tempfile
2961
2962     with tempfile.NamedTemporaryFile(
2963         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2964     ) as f:
2965         for lines in output:
2966             f.write(lines)
2967             if lines and lines[-1] != "\n":
2968                 f.write("\n")
2969     return f.name
2970
2971
2972 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2973     """Return a unified diff string between strings `a` and `b`."""
2974     import difflib
2975
2976     a_lines = [line + "\n" for line in a.split("\n")]
2977     b_lines = [line + "\n" for line in b.split("\n")]
2978     return "".join(
2979         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2980     )
2981
2982
2983 def cancel(tasks: Iterable[asyncio.Task]) -> None:
2984     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2985     err("Aborted!")
2986     for task in tasks:
2987         task.cancel()
2988
2989
2990 def shutdown(loop: BaseEventLoop) -> None:
2991     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2992     try:
2993         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2994         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2995         if not to_cancel:
2996             return
2997
2998         for task in to_cancel:
2999             task.cancel()
3000         loop.run_until_complete(
3001             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3002         )
3003     finally:
3004         # `concurrent.futures.Future` objects cannot be cancelled once they
3005         # are already running. There might be some when the `shutdown()` happened.
3006         # Silence their logger's spew about the event loop being closed.
3007         cf_logger = logging.getLogger("concurrent.futures")
3008         cf_logger.setLevel(logging.CRITICAL)
3009         loop.close()
3010
3011
3012 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3013     """Replace `regex` with `replacement` twice on `original`.
3014
3015     This is used by string normalization to perform replaces on
3016     overlapping matches.
3017     """
3018     return regex.sub(replacement, regex.sub(replacement, original))
3019
3020
3021 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3022     """Like `reversed(enumerate(sequence))` if that were possible."""
3023     index = len(sequence) - 1
3024     for element in reversed(sequence):
3025         yield (index, element)
3026         index -= 1
3027
3028
3029 def enumerate_with_length(
3030     line: Line, reversed: bool = False
3031 ) -> Iterator[Tuple[Index, Leaf, int]]:
3032     """Return an enumeration of leaves with their length.
3033
3034     Stops prematurely on multiline strings and standalone comments.
3035     """
3036     op = cast(
3037         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3038         enumerate_reversed if reversed else enumerate,
3039     )
3040     for index, leaf in op(line.leaves):
3041         length = len(leaf.prefix) + len(leaf.value)
3042         if "\n" in leaf.value:
3043             return  # Multiline strings, we can't continue.
3044
3045         comment: Optional[Leaf]
3046         for comment in line.comments_after(leaf, index):
3047             length += len(comment.value)
3048
3049         yield index, leaf, length
3050
3051
3052 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3053     """Return True if `line` is no longer than `line_length`.
3054
3055     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3056     """
3057     if not line_str:
3058         line_str = str(line).strip("\n")
3059     return (
3060         len(line_str) <= line_length
3061         and "\n" not in line_str  # multiline strings
3062         and not line.contains_standalone_comments()
3063     )
3064
3065
3066 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3067     """Does `line` have a shape safe to reformat without optional parens around it?
3068
3069     Returns True for only a subset of potentially nice looking formattings but
3070     the point is to not return false positives that end up producing lines that
3071     are too long.
3072     """
3073     bt = line.bracket_tracker
3074     if not bt.delimiters:
3075         # Without delimiters the optional parentheses are useless.
3076         return True
3077
3078     max_priority = bt.max_delimiter_priority()
3079     if bt.delimiter_count_with_priority(max_priority) > 1:
3080         # With more than one delimiter of a kind the optional parentheses read better.
3081         return False
3082
3083     if max_priority == DOT_PRIORITY:
3084         # A single stranded method call doesn't require optional parentheses.
3085         return True
3086
3087     assert len(line.leaves) >= 2, "Stranded delimiter"
3088
3089     first = line.leaves[0]
3090     second = line.leaves[1]
3091     penultimate = line.leaves[-2]
3092     last = line.leaves[-1]
3093
3094     # With a single delimiter, omit if the expression starts or ends with
3095     # a bracket.
3096     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3097         remainder = False
3098         length = 4 * line.depth
3099         for _index, leaf, leaf_length in enumerate_with_length(line):
3100             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3101                 remainder = True
3102             if remainder:
3103                 length += leaf_length
3104                 if length > line_length:
3105                     break
3106
3107                 if leaf.type in OPENING_BRACKETS:
3108                     # There are brackets we can further split on.
3109                     remainder = False
3110
3111         else:
3112             # checked the entire string and line length wasn't exceeded
3113             if len(line.leaves) == _index + 1:
3114                 return True
3115
3116         # Note: we are not returning False here because a line might have *both*
3117         # a leading opening bracket and a trailing closing bracket.  If the
3118         # opening bracket doesn't match our rule, maybe the closing will.
3119
3120     if (
3121         last.type == token.RPAR
3122         or last.type == token.RBRACE
3123         or (
3124             # don't use indexing for omitting optional parentheses;
3125             # it looks weird
3126             last.type == token.RSQB
3127             and last.parent
3128             and last.parent.type != syms.trailer
3129         )
3130     ):
3131         if penultimate.type in OPENING_BRACKETS:
3132             # Empty brackets don't help.
3133             return False
3134
3135         if is_multiline_string(first):
3136             # Additional wrapping of a multiline string in this situation is
3137             # unnecessary.
3138             return True
3139
3140         length = 4 * line.depth
3141         seen_other_brackets = False
3142         for _index, leaf, leaf_length in enumerate_with_length(line):
3143             length += leaf_length
3144             if leaf is last.opening_bracket:
3145                 if seen_other_brackets or length <= line_length:
3146                     return True
3147
3148             elif leaf.type in OPENING_BRACKETS:
3149                 # There are brackets we can further split on.
3150                 seen_other_brackets = True
3151
3152     return False
3153
3154
3155 def get_cache_file(line_length: int, pyi: bool = False, py36: bool = False) -> Path:
3156     return (
3157         CACHE_DIR
3158         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3159     )
3160
3161
3162 def read_cache(line_length: int, pyi: bool = False, py36: bool = False) -> Cache:
3163     """Read the cache if it exists and is well formed.
3164
3165     If it is not well formed, the call to write_cache later should resolve the issue.
3166     """
3167     cache_file = get_cache_file(line_length, pyi, py36)
3168     if not cache_file.exists():
3169         return {}
3170
3171     with cache_file.open("rb") as fobj:
3172         try:
3173             cache: Cache = pickle.load(fobj)
3174         except pickle.UnpicklingError:
3175             return {}
3176
3177     return cache
3178
3179
3180 def get_cache_info(path: Path) -> CacheInfo:
3181     """Return the information used to check if a file is already formatted or not."""
3182     stat = path.stat()
3183     return stat.st_mtime, stat.st_size
3184
3185
3186 def filter_cached(
3187     cache: Cache, sources: Iterable[Path]
3188 ) -> Tuple[List[Path], List[Path]]:
3189     """Split a list of paths into two.
3190
3191     The first list contains paths of files that modified on disk or are not in the
3192     cache. The other list contains paths to non-modified files.
3193     """
3194     todo, done = [], []
3195     for src in sources:
3196         src = src.resolve()
3197         if cache.get(src) != get_cache_info(src):
3198             todo.append(src)
3199         else:
3200             done.append(src)
3201     return todo, done
3202
3203
3204 def write_cache(
3205     cache: Cache,
3206     sources: List[Path],
3207     line_length: int,
3208     pyi: bool = False,
3209     py36: bool = False,
3210 ) -> None:
3211     """Update the cache file."""
3212     cache_file = get_cache_file(line_length, pyi, py36)
3213     try:
3214         if not CACHE_DIR.exists():
3215             CACHE_DIR.mkdir(parents=True)
3216         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3217         with cache_file.open("wb") as fobj:
3218             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3219     except OSError:
3220         pass
3221
3222
3223 if __name__ == "__main__":
3224     main()