]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Mention fix for #196 in the README
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33     cast,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46
47 __version__ = "18.5b0"
48 DEFAULT_LINE_LENGTH = 88
49 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
50
51
52 # types
53 FileContent = str
54 Encoding = str
55 Depth = int
56 NodeType = int
57 LeafID = int
58 Priority = int
59 Index = int
60 LN = Union[Leaf, Node]
61 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
62 Timestamp = float
63 FileSize = int
64 CacheInfo = Tuple[Timestamp, FileSize]
65 Cache = Dict[Path, CacheInfo]
66 out = partial(click.secho, bold=True, err=True)
67 err = partial(click.secho, fg="red", err=True)
68
69 pygram.initialize(CACHE_DIR)
70 syms = pygram.python_symbols
71
72
73 class NothingChanged(UserWarning):
74     """Raised by :func:`format_file` when reformatted code is the same as source."""
75
76
77 class CannotSplit(Exception):
78     """A readable split that fits the allotted line length is impossible.
79
80     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
81     :func:`delimiter_split`.
82     """
83
84
85 class FormatError(Exception):
86     """Base exception for `# fmt: on` and `# fmt: off` handling.
87
88     It holds the number of bytes of the prefix consumed before the format
89     control comment appeared.
90     """
91
92     def __init__(self, consumed: int) -> None:
93         super().__init__(consumed)
94         self.consumed = consumed
95
96     def trim_prefix(self, leaf: Leaf) -> None:
97         leaf.prefix = leaf.prefix[self.consumed :]
98
99     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
100         """Returns a new Leaf from the consumed part of the prefix."""
101         unformatted_prefix = leaf.prefix[: self.consumed]
102         return Leaf(token.NEWLINE, unformatted_prefix)
103
104
105 class FormatOn(FormatError):
106     """Found a comment like `# fmt: on` in the file."""
107
108
109 class FormatOff(FormatError):
110     """Found a comment like `# fmt: off` in the file."""
111
112
113 class WriteBack(Enum):
114     NO = 0
115     YES = 1
116     DIFF = 2
117
118
119 class Changed(Enum):
120     NO = 0
121     CACHED = 1
122     YES = 2
123
124
125 @click.command()
126 @click.option(
127     "-l",
128     "--line-length",
129     type=int,
130     default=DEFAULT_LINE_LENGTH,
131     help="How many character per line to allow.",
132     show_default=True,
133 )
134 @click.option(
135     "--check",
136     is_flag=True,
137     help=(
138         "Don't write the files back, just return the status.  Return code 0 "
139         "means nothing would change.  Return code 1 means some files would be "
140         "reformatted.  Return code 123 means there was an internal error."
141     ),
142 )
143 @click.option(
144     "--diff",
145     is_flag=True,
146     help="Don't write the files back, just output a diff for each file on stdout.",
147 )
148 @click.option(
149     "--fast/--safe",
150     is_flag=True,
151     help="If --fast given, skip temporary sanity checks. [default: --safe]",
152 )
153 @click.option(
154     "-q",
155     "--quiet",
156     is_flag=True,
157     help=(
158         "Don't emit non-error messages to stderr. Errors are still emitted, "
159         "silence those with 2>/dev/null."
160     ),
161 )
162 @click.option(
163     "--pyi",
164     is_flag=True,
165     help=(
166         "Consider all input files typing stubs regardless of file extension "
167         "(useful when piping source on standard input)."
168     ),
169 )
170 @click.option(
171     "--py36",
172     is_flag=True,
173     help=(
174         "Allow using Python 3.6-only syntax on all input files.  This will put "
175         "trailing commas in function signatures and calls also after *args and "
176         "**kwargs.  [default: per-file auto-detection]"
177     ),
178 )
179 @click.version_option(version=__version__)
180 @click.argument(
181     "src",
182     nargs=-1,
183     type=click.Path(
184         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
185     ),
186 )
187 @click.pass_context
188 def main(
189     ctx: click.Context,
190     line_length: int,
191     check: bool,
192     diff: bool,
193     fast: bool,
194     pyi: bool,
195     py36: bool,
196     quiet: bool,
197     src: List[str],
198 ) -> None:
199     """The uncompromising code formatter."""
200     sources: List[Path] = []
201     for s in src:
202         p = Path(s)
203         if p.is_dir():
204             sources.extend(gen_python_files_in_dir(p))
205         elif p.is_file():
206             # if a file was explicitly given, we don't care about its extension
207             sources.append(p)
208         elif s == "-":
209             sources.append(Path("-"))
210         else:
211             err(f"invalid path: {s}")
212
213     if check and not diff:
214         write_back = WriteBack.NO
215     elif diff:
216         write_back = WriteBack.DIFF
217     else:
218         write_back = WriteBack.YES
219     report = Report(check=check, quiet=quiet)
220     if len(sources) == 0:
221         out("No paths given. Nothing to do 😴")
222         ctx.exit(0)
223         return
224
225     elif len(sources) == 1:
226         reformat_one(
227             src=sources[0],
228             line_length=line_length,
229             fast=fast,
230             pyi=pyi,
231             py36=py36,
232             write_back=write_back,
233             report=report,
234         )
235     else:
236         loop = asyncio.get_event_loop()
237         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
238         try:
239             loop.run_until_complete(
240                 schedule_formatting(
241                     sources=sources,
242                     line_length=line_length,
243                     fast=fast,
244                     pyi=pyi,
245                     py36=py36,
246                     write_back=write_back,
247                     report=report,
248                     loop=loop,
249                     executor=executor,
250                 )
251             )
252         finally:
253             shutdown(loop)
254         if not quiet:
255             out("All done! ✨ 🍰 ✨")
256             click.echo(str(report))
257     ctx.exit(report.return_code)
258
259
260 def reformat_one(
261     src: Path,
262     line_length: int,
263     fast: bool,
264     pyi: bool,
265     py36: bool,
266     write_back: WriteBack,
267     report: "Report",
268 ) -> None:
269     """Reformat a single file under `src` without spawning child processes.
270
271     If `quiet` is True, non-error messages are not output. `line_length`,
272     `write_back`, `fast` and `pyi` options are passed to
273     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
274     """
275     try:
276         changed = Changed.NO
277         if not src.is_file() and str(src) == "-":
278             if format_stdin_to_stdout(
279                 line_length=line_length,
280                 fast=fast,
281                 is_pyi=pyi,
282                 force_py36=py36,
283                 write_back=write_back,
284             ):
285                 changed = Changed.YES
286         else:
287             cache: Cache = {}
288             if write_back != WriteBack.DIFF:
289                 cache = read_cache(line_length, pyi, py36)
290                 src = src.resolve()
291                 if src in cache and cache[src] == get_cache_info(src):
292                     changed = Changed.CACHED
293             if changed is not Changed.CACHED and format_file_in_place(
294                 src,
295                 line_length=line_length,
296                 fast=fast,
297                 force_pyi=pyi,
298                 force_py36=py36,
299                 write_back=write_back,
300             ):
301                 changed = Changed.YES
302             if write_back == WriteBack.YES and changed is not Changed.NO:
303                 write_cache(cache, [src], line_length, pyi, py36)
304         report.done(src, changed)
305     except Exception as exc:
306         report.failed(src, str(exc))
307
308
309 async def schedule_formatting(
310     sources: List[Path],
311     line_length: int,
312     fast: bool,
313     pyi: bool,
314     py36: bool,
315     write_back: WriteBack,
316     report: "Report",
317     loop: BaseEventLoop,
318     executor: Executor,
319 ) -> None:
320     """Run formatting of `sources` in parallel using the provided `executor`.
321
322     (Use ProcessPoolExecutors for actual parallelism.)
323
324     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
325     :func:`format_file_in_place`.
326     """
327     cache: Cache = {}
328     if write_back != WriteBack.DIFF:
329         cache = read_cache(line_length, pyi, py36)
330         sources, cached = filter_cached(cache, sources)
331         for src in cached:
332             report.done(src, Changed.CACHED)
333     cancelled = []
334     formatted = []
335     if sources:
336         lock = None
337         if write_back == WriteBack.DIFF:
338             # For diff output, we need locks to ensure we don't interleave output
339             # from different processes.
340             manager = Manager()
341             lock = manager.Lock()
342         tasks = {
343             loop.run_in_executor(
344                 executor,
345                 format_file_in_place,
346                 src,
347                 line_length,
348                 fast,
349                 pyi,
350                 py36,
351                 write_back,
352                 lock,
353             ): src
354             for src in sorted(sources)
355         }
356         pending: Iterable[asyncio.Task] = tasks.keys()
357         try:
358             loop.add_signal_handler(signal.SIGINT, cancel, pending)
359             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
360         except NotImplementedError:
361             # There are no good alternatives for these on Windows
362             pass
363         while pending:
364             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
365             for task in done:
366                 src = tasks.pop(task)
367                 if task.cancelled():
368                     cancelled.append(task)
369                 elif task.exception():
370                     report.failed(src, str(task.exception()))
371                 else:
372                     formatted.append(src)
373                     report.done(src, Changed.YES if task.result() else Changed.NO)
374     if cancelled:
375         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
376     if write_back == WriteBack.YES and formatted:
377         write_cache(cache, formatted, line_length, pyi, py36)
378
379
380 def format_file_in_place(
381     src: Path,
382     line_length: int,
383     fast: bool,
384     force_pyi: bool = False,
385     force_py36: bool = False,
386     write_back: WriteBack = WriteBack.NO,
387     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
388 ) -> bool:
389     """Format file under `src` path. Return True if changed.
390
391     If `write_back` is True, write reformatted code back to stdout.
392     `line_length` and `fast` options are passed to :func:`format_file_contents`.
393     """
394     is_pyi = force_pyi or src.suffix == ".pyi"
395
396     with tokenize.open(src) as src_buffer:
397         src_contents = src_buffer.read()
398     try:
399         dst_contents = format_file_contents(
400             src_contents,
401             line_length=line_length,
402             fast=fast,
403             is_pyi=is_pyi,
404             force_py36=force_py36,
405         )
406     except NothingChanged:
407         return False
408
409     if write_back == write_back.YES:
410         with open(src, "w", encoding=src_buffer.encoding) as f:
411             f.write(dst_contents)
412     elif write_back == write_back.DIFF:
413         src_name = f"{src}  (original)"
414         dst_name = f"{src}  (formatted)"
415         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
416         if lock:
417             lock.acquire()
418         try:
419             sys.stdout.write(diff_contents)
420         finally:
421             if lock:
422                 lock.release()
423     return True
424
425
426 def format_stdin_to_stdout(
427     line_length: int,
428     fast: bool,
429     is_pyi: bool = False,
430     force_py36: bool = False,
431     write_back: WriteBack = WriteBack.NO,
432 ) -> bool:
433     """Format file on stdin. Return True if changed.
434
435     If `write_back` is True, write reformatted code back to stdout.
436     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
437     :func:`format_file_contents`.
438     """
439     src = sys.stdin.read()
440     dst = src
441     try:
442         dst = format_file_contents(
443             src,
444             line_length=line_length,
445             fast=fast,
446             is_pyi=is_pyi,
447             force_py36=force_py36,
448         )
449         return True
450
451     except NothingChanged:
452         return False
453
454     finally:
455         if write_back == WriteBack.YES:
456             sys.stdout.write(dst)
457         elif write_back == WriteBack.DIFF:
458             src_name = "<stdin>  (original)"
459             dst_name = "<stdin>  (formatted)"
460             sys.stdout.write(diff(src, dst, src_name, dst_name))
461
462
463 def format_file_contents(
464     src_contents: str,
465     *,
466     line_length: int,
467     fast: bool,
468     is_pyi: bool = False,
469     force_py36: bool = False,
470 ) -> FileContent:
471     """Reformat contents a file and return new contents.
472
473     If `fast` is False, additionally confirm that the reformatted code is
474     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
475     `line_length` is passed to :func:`format_str`.
476     """
477     if src_contents.strip() == "":
478         raise NothingChanged
479
480     dst_contents = format_str(
481         src_contents, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
482     )
483     if src_contents == dst_contents:
484         raise NothingChanged
485
486     if not fast:
487         assert_equivalent(src_contents, dst_contents)
488         assert_stable(
489             src_contents,
490             dst_contents,
491             line_length=line_length,
492             is_pyi=is_pyi,
493             force_py36=force_py36,
494         )
495     return dst_contents
496
497
498 def format_str(
499     src_contents: str,
500     line_length: int,
501     *,
502     is_pyi: bool = False,
503     force_py36: bool = False,
504 ) -> FileContent:
505     """Reformat a string and return new contents.
506
507     `line_length` determines how many characters per line are allowed.
508     """
509     src_node = lib2to3_parse(src_contents)
510     dst_contents = ""
511     future_imports = get_future_imports(src_node)
512     elt = EmptyLineTracker(is_pyi=is_pyi)
513     py36 = force_py36 or is_python36(src_node)
514     lines = LineGenerator(
515         remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
516     )
517     empty_line = Line()
518     after = 0
519     for current_line in lines.visit(src_node):
520         for _ in range(after):
521             dst_contents += str(empty_line)
522         before, after = elt.maybe_empty_lines(current_line)
523         for _ in range(before):
524             dst_contents += str(empty_line)
525         for line in split_line(current_line, line_length=line_length, py36=py36):
526             dst_contents += str(line)
527     return dst_contents
528
529
530 GRAMMARS = [
531     pygram.python_grammar_no_print_statement_no_exec_statement,
532     pygram.python_grammar_no_print_statement,
533     pygram.python_grammar,
534 ]
535
536
537 def lib2to3_parse(src_txt: str) -> Node:
538     """Given a string with source, return the lib2to3 Node."""
539     grammar = pygram.python_grammar_no_print_statement
540     if src_txt[-1] != "\n":
541         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
542         src_txt += nl
543     for grammar in GRAMMARS:
544         drv = driver.Driver(grammar, pytree.convert)
545         try:
546             result = drv.parse_string(src_txt, True)
547             break
548
549         except ParseError as pe:
550             lineno, column = pe.context[1]
551             lines = src_txt.splitlines()
552             try:
553                 faulty_line = lines[lineno - 1]
554             except IndexError:
555                 faulty_line = "<line number missing in source>"
556             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
557     else:
558         raise exc from None
559
560     if isinstance(result, Leaf):
561         result = Node(syms.file_input, [result])
562     return result
563
564
565 def lib2to3_unparse(node: Node) -> str:
566     """Given a lib2to3 node, return its string representation."""
567     code = str(node)
568     return code
569
570
571 T = TypeVar("T")
572
573
574 class Visitor(Generic[T]):
575     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
576
577     def visit(self, node: LN) -> Iterator[T]:
578         """Main method to visit `node` and its children.
579
580         It tries to find a `visit_*()` method for the given `node.type`, like
581         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
582         If no dedicated `visit_*()` method is found, chooses `visit_default()`
583         instead.
584
585         Then yields objects of type `T` from the selected visitor.
586         """
587         if node.type < 256:
588             name = token.tok_name[node.type]
589         else:
590             name = type_repr(node.type)
591         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
592
593     def visit_default(self, node: LN) -> Iterator[T]:
594         """Default `visit_*()` implementation. Recurses to children of `node`."""
595         if isinstance(node, Node):
596             for child in node.children:
597                 yield from self.visit(child)
598
599
600 @dataclass
601 class DebugVisitor(Visitor[T]):
602     tree_depth: int = 0
603
604     def visit_default(self, node: LN) -> Iterator[T]:
605         indent = " " * (2 * self.tree_depth)
606         if isinstance(node, Node):
607             _type = type_repr(node.type)
608             out(f"{indent}{_type}", fg="yellow")
609             self.tree_depth += 1
610             for child in node.children:
611                 yield from self.visit(child)
612
613             self.tree_depth -= 1
614             out(f"{indent}/{_type}", fg="yellow", bold=False)
615         else:
616             _type = token.tok_name.get(node.type, str(node.type))
617             out(f"{indent}{_type}", fg="blue", nl=False)
618             if node.prefix:
619                 # We don't have to handle prefixes for `Node` objects since
620                 # that delegates to the first child anyway.
621                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
622             out(f" {node.value!r}", fg="blue", bold=False)
623
624     @classmethod
625     def show(cls, code: str) -> None:
626         """Pretty-print the lib2to3 AST of a given string of `code`.
627
628         Convenience method for debugging.
629         """
630         v: DebugVisitor[None] = DebugVisitor()
631         list(v.visit(lib2to3_parse(code)))
632
633
634 KEYWORDS = set(keyword.kwlist)
635 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
636 FLOW_CONTROL = {"return", "raise", "break", "continue"}
637 STATEMENT = {
638     syms.if_stmt,
639     syms.while_stmt,
640     syms.for_stmt,
641     syms.try_stmt,
642     syms.except_clause,
643     syms.with_stmt,
644     syms.funcdef,
645     syms.classdef,
646 }
647 STANDALONE_COMMENT = 153
648 LOGIC_OPERATORS = {"and", "or"}
649 COMPARATORS = {
650     token.LESS,
651     token.GREATER,
652     token.EQEQUAL,
653     token.NOTEQUAL,
654     token.LESSEQUAL,
655     token.GREATEREQUAL,
656 }
657 MATH_OPERATORS = {
658     token.VBAR,
659     token.CIRCUMFLEX,
660     token.AMPER,
661     token.LEFTSHIFT,
662     token.RIGHTSHIFT,
663     token.PLUS,
664     token.MINUS,
665     token.STAR,
666     token.SLASH,
667     token.DOUBLESLASH,
668     token.PERCENT,
669     token.AT,
670     token.TILDE,
671     token.DOUBLESTAR,
672 }
673 STARS = {token.STAR, token.DOUBLESTAR}
674 VARARGS_PARENTS = {
675     syms.arglist,
676     syms.argument,  # double star in arglist
677     syms.trailer,  # single argument to call
678     syms.typedargslist,
679     syms.varargslist,  # lambdas
680 }
681 UNPACKING_PARENTS = {
682     syms.atom,  # single element of a list or set literal
683     syms.dictsetmaker,
684     syms.listmaker,
685     syms.testlist_gexp,
686 }
687 TEST_DESCENDANTS = {
688     syms.test,
689     syms.lambdef,
690     syms.or_test,
691     syms.and_test,
692     syms.not_test,
693     syms.comparison,
694     syms.star_expr,
695     syms.expr,
696     syms.xor_expr,
697     syms.and_expr,
698     syms.shift_expr,
699     syms.arith_expr,
700     syms.trailer,
701     syms.term,
702     syms.power,
703 }
704 ASSIGNMENTS = {
705     "=",
706     "+=",
707     "-=",
708     "*=",
709     "@=",
710     "/=",
711     "%=",
712     "&=",
713     "|=",
714     "^=",
715     "<<=",
716     ">>=",
717     "**=",
718     "//=",
719 }
720 COMPREHENSION_PRIORITY = 20
721 COMMA_PRIORITY = 18
722 TERNARY_PRIORITY = 16
723 LOGIC_PRIORITY = 14
724 STRING_PRIORITY = 12
725 COMPARATOR_PRIORITY = 10
726 MATH_PRIORITIES = {
727     token.VBAR: 9,
728     token.CIRCUMFLEX: 8,
729     token.AMPER: 7,
730     token.LEFTSHIFT: 6,
731     token.RIGHTSHIFT: 6,
732     token.PLUS: 5,
733     token.MINUS: 5,
734     token.STAR: 4,
735     token.SLASH: 4,
736     token.DOUBLESLASH: 4,
737     token.PERCENT: 4,
738     token.AT: 4,
739     token.TILDE: 3,
740     token.DOUBLESTAR: 2,
741 }
742 DOT_PRIORITY = 1
743
744
745 @dataclass
746 class BracketTracker:
747     """Keeps track of brackets on a line."""
748
749     depth: int = 0
750     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
751     delimiters: Dict[LeafID, Priority] = Factory(dict)
752     previous: Optional[Leaf] = None
753     _for_loop_variable: int = 0
754     _lambda_arguments: int = 0
755
756     def mark(self, leaf: Leaf) -> None:
757         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
758
759         All leaves receive an int `bracket_depth` field that stores how deep
760         within brackets a given leaf is. 0 means there are no enclosing brackets
761         that started on this line.
762
763         If a leaf is itself a closing bracket, it receives an `opening_bracket`
764         field that it forms a pair with. This is a one-directional link to
765         avoid reference cycles.
766
767         If a leaf is a delimiter (a token on which Black can split the line if
768         needed) and it's on depth 0, its `id()` is stored in the tracker's
769         `delimiters` field.
770         """
771         if leaf.type == token.COMMENT:
772             return
773
774         self.maybe_decrement_after_for_loop_variable(leaf)
775         self.maybe_decrement_after_lambda_arguments(leaf)
776         if leaf.type in CLOSING_BRACKETS:
777             self.depth -= 1
778             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
779             leaf.opening_bracket = opening_bracket
780         leaf.bracket_depth = self.depth
781         if self.depth == 0:
782             delim = is_split_before_delimiter(leaf, self.previous)
783             if delim and self.previous is not None:
784                 self.delimiters[id(self.previous)] = delim
785             else:
786                 delim = is_split_after_delimiter(leaf, self.previous)
787                 if delim:
788                     self.delimiters[id(leaf)] = delim
789         if leaf.type in OPENING_BRACKETS:
790             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
791             self.depth += 1
792         self.previous = leaf
793         self.maybe_increment_lambda_arguments(leaf)
794         self.maybe_increment_for_loop_variable(leaf)
795
796     def any_open_brackets(self) -> bool:
797         """Return True if there is an yet unmatched open bracket on the line."""
798         return bool(self.bracket_match)
799
800     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
801         """Return the highest priority of a delimiter found on the line.
802
803         Values are consistent with what `is_split_*_delimiter()` return.
804         Raises ValueError on no delimiters.
805         """
806         return max(v for k, v in self.delimiters.items() if k not in exclude)
807
808     def delimiter_count_with_priority(self, priority: int = 0) -> int:
809         """Return the number of delimiters with the given `priority`.
810
811         If no `priority` is passed, defaults to max priority on the line.
812         """
813         if not self.delimiters:
814             return 0
815
816         priority = priority or self.max_delimiter_priority()
817         return sum(1 for p in self.delimiters.values() if p == priority)
818
819     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
820         """In a for loop, or comprehension, the variables are often unpacks.
821
822         To avoid splitting on the comma in this situation, increase the depth of
823         tokens between `for` and `in`.
824         """
825         if leaf.type == token.NAME and leaf.value == "for":
826             self.depth += 1
827             self._for_loop_variable += 1
828             return True
829
830         return False
831
832     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
833         """See `maybe_increment_for_loop_variable` above for explanation."""
834         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
835             self.depth -= 1
836             self._for_loop_variable -= 1
837             return True
838
839         return False
840
841     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
842         """In a lambda expression, there might be more than one argument.
843
844         To avoid splitting on the comma in this situation, increase the depth of
845         tokens between `lambda` and `:`.
846         """
847         if leaf.type == token.NAME and leaf.value == "lambda":
848             self.depth += 1
849             self._lambda_arguments += 1
850             return True
851
852         return False
853
854     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
855         """See `maybe_increment_lambda_arguments` above for explanation."""
856         if self._lambda_arguments and leaf.type == token.COLON:
857             self.depth -= 1
858             self._lambda_arguments -= 1
859             return True
860
861         return False
862
863     def get_open_lsqb(self) -> Optional[Leaf]:
864         """Return the most recent opening square bracket (if any)."""
865         return self.bracket_match.get((self.depth - 1, token.RSQB))
866
867
868 @dataclass
869 class Line:
870     """Holds leaves and comments. Can be printed with `str(line)`."""
871
872     depth: int = 0
873     leaves: List[Leaf] = Factory(list)
874     comments: List[Tuple[Index, Leaf]] = Factory(list)
875     bracket_tracker: BracketTracker = Factory(BracketTracker)
876     inside_brackets: bool = False
877     should_explode: bool = False
878
879     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
880         """Add a new `leaf` to the end of the line.
881
882         Unless `preformatted` is True, the `leaf` will receive a new consistent
883         whitespace prefix and metadata applied by :class:`BracketTracker`.
884         Trailing commas are maybe removed, unpacked for loop variables are
885         demoted from being delimiters.
886
887         Inline comments are put aside.
888         """
889         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
890         if not has_value:
891             return
892
893         if token.COLON == leaf.type and self.is_class_paren_empty:
894             del self.leaves[-2:]
895         if self.leaves and not preformatted:
896             # Note: at this point leaf.prefix should be empty except for
897             # imports, for which we only preserve newlines.
898             leaf.prefix += whitespace(
899                 leaf, complex_subscript=self.is_complex_subscript(leaf)
900             )
901         if self.inside_brackets or not preformatted:
902             self.bracket_tracker.mark(leaf)
903             self.maybe_remove_trailing_comma(leaf)
904         if not self.append_comment(leaf):
905             self.leaves.append(leaf)
906
907     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
908         """Like :func:`append()` but disallow invalid standalone comment structure.
909
910         Raises ValueError when any `leaf` is appended after a standalone comment
911         or when a standalone comment is not the first leaf on the line.
912         """
913         if self.bracket_tracker.depth == 0:
914             if self.is_comment:
915                 raise ValueError("cannot append to standalone comments")
916
917             if self.leaves and leaf.type == STANDALONE_COMMENT:
918                 raise ValueError(
919                     "cannot append standalone comments to a populated line"
920                 )
921
922         self.append(leaf, preformatted=preformatted)
923
924     @property
925     def is_comment(self) -> bool:
926         """Is this line a standalone comment?"""
927         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
928
929     @property
930     def is_decorator(self) -> bool:
931         """Is this line a decorator?"""
932         return bool(self) and self.leaves[0].type == token.AT
933
934     @property
935     def is_import(self) -> bool:
936         """Is this an import line?"""
937         return bool(self) and is_import(self.leaves[0])
938
939     @property
940     def is_class(self) -> bool:
941         """Is this line a class definition?"""
942         return (
943             bool(self)
944             and self.leaves[0].type == token.NAME
945             and self.leaves[0].value == "class"
946         )
947
948     @property
949     def is_stub_class(self) -> bool:
950         """Is this line a class definition with a body consisting only of "..."?"""
951         return self.is_class and self.leaves[-3:] == [
952             Leaf(token.DOT, ".") for _ in range(3)
953         ]
954
955     @property
956     def is_def(self) -> bool:
957         """Is this a function definition? (Also returns True for async defs.)"""
958         try:
959             first_leaf = self.leaves[0]
960         except IndexError:
961             return False
962
963         try:
964             second_leaf: Optional[Leaf] = self.leaves[1]
965         except IndexError:
966             second_leaf = None
967         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
968             first_leaf.type == token.ASYNC
969             and second_leaf is not None
970             and second_leaf.type == token.NAME
971             and second_leaf.value == "def"
972         )
973
974     @property
975     def is_class_paren_empty(self) -> bool:
976         """Is this a class with no base classes but using parentheses?
977
978         Those are unnecessary and should be removed.
979         """
980         return (
981             bool(self)
982             and len(self.leaves) == 4
983             and self.is_class
984             and self.leaves[2].type == token.LPAR
985             and self.leaves[2].value == "("
986             and self.leaves[3].type == token.RPAR
987             and self.leaves[3].value == ")"
988         )
989
990     @property
991     def is_triple_quoted_string(self) -> bool:
992         """Is the line a triple quoted string?"""
993         return (
994             bool(self)
995             and self.leaves[0].type == token.STRING
996             and self.leaves[0].value.startswith(('"""', "'''"))
997         )
998
999     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1000         """If so, needs to be split before emitting."""
1001         for leaf in self.leaves:
1002             if leaf.type == STANDALONE_COMMENT:
1003                 if leaf.bracket_depth <= depth_limit:
1004                     return True
1005
1006         return False
1007
1008     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1009         """Remove trailing comma if there is one and it's safe."""
1010         if not (
1011             self.leaves
1012             and self.leaves[-1].type == token.COMMA
1013             and closing.type in CLOSING_BRACKETS
1014         ):
1015             return False
1016
1017         if closing.type == token.RBRACE:
1018             self.remove_trailing_comma()
1019             return True
1020
1021         if closing.type == token.RSQB:
1022             comma = self.leaves[-1]
1023             if comma.parent and comma.parent.type == syms.listmaker:
1024                 self.remove_trailing_comma()
1025                 return True
1026
1027         # For parens let's check if it's safe to remove the comma.
1028         # Imports are always safe.
1029         if self.is_import:
1030             self.remove_trailing_comma()
1031             return True
1032
1033         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1034         # change a tuple into a different type by removing the comma.
1035         depth = closing.bracket_depth + 1
1036         commas = 0
1037         opening = closing.opening_bracket
1038         for _opening_index, leaf in enumerate(self.leaves):
1039             if leaf is opening:
1040                 break
1041
1042         else:
1043             return False
1044
1045         for leaf in self.leaves[_opening_index + 1 :]:
1046             if leaf is closing:
1047                 break
1048
1049             bracket_depth = leaf.bracket_depth
1050             if bracket_depth == depth and leaf.type == token.COMMA:
1051                 commas += 1
1052                 if leaf.parent and leaf.parent.type == syms.arglist:
1053                     commas += 1
1054                     break
1055
1056         if commas > 1:
1057             self.remove_trailing_comma()
1058             return True
1059
1060         return False
1061
1062     def append_comment(self, comment: Leaf) -> bool:
1063         """Add an inline or standalone comment to the line."""
1064         if (
1065             comment.type == STANDALONE_COMMENT
1066             and self.bracket_tracker.any_open_brackets()
1067         ):
1068             comment.prefix = ""
1069             return False
1070
1071         if comment.type != token.COMMENT:
1072             return False
1073
1074         after = len(self.leaves) - 1
1075         if after == -1:
1076             comment.type = STANDALONE_COMMENT
1077             comment.prefix = ""
1078             return False
1079
1080         else:
1081             self.comments.append((after, comment))
1082             return True
1083
1084     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1085         """Generate comments that should appear directly after `leaf`.
1086
1087         Provide a non-negative leaf `_index` to speed up the function.
1088         """
1089         if _index == -1:
1090             for _index, _leaf in enumerate(self.leaves):
1091                 if leaf is _leaf:
1092                     break
1093
1094             else:
1095                 return
1096
1097         for index, comment_after in self.comments:
1098             if _index == index:
1099                 yield comment_after
1100
1101     def remove_trailing_comma(self) -> None:
1102         """Remove the trailing comma and moves the comments attached to it."""
1103         comma_index = len(self.leaves) - 1
1104         for i in range(len(self.comments)):
1105             comment_index, comment = self.comments[i]
1106             if comment_index == comma_index:
1107                 self.comments[i] = (comma_index - 1, comment)
1108         self.leaves.pop()
1109
1110     def is_complex_subscript(self, leaf: Leaf) -> bool:
1111         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1112         open_lsqb = (
1113             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1114         )
1115         if open_lsqb is None:
1116             return False
1117
1118         subscript_start = open_lsqb.next_sibling
1119         if (
1120             isinstance(subscript_start, Node)
1121             and subscript_start.type == syms.subscriptlist
1122         ):
1123             subscript_start = child_towards(subscript_start, leaf)
1124         return subscript_start is not None and any(
1125             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1126         )
1127
1128     def __str__(self) -> str:
1129         """Render the line."""
1130         if not self:
1131             return "\n"
1132
1133         indent = "    " * self.depth
1134         leaves = iter(self.leaves)
1135         first = next(leaves)
1136         res = f"{first.prefix}{indent}{first.value}"
1137         for leaf in leaves:
1138             res += str(leaf)
1139         for _, comment in self.comments:
1140             res += str(comment)
1141         return res + "\n"
1142
1143     def __bool__(self) -> bool:
1144         """Return True if the line has leaves or comments."""
1145         return bool(self.leaves or self.comments)
1146
1147
1148 class UnformattedLines(Line):
1149     """Just like :class:`Line` but stores lines which aren't reformatted."""
1150
1151     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1152         """Just add a new `leaf` to the end of the lines.
1153
1154         The `preformatted` argument is ignored.
1155
1156         Keeps track of indentation `depth`, which is useful when the user
1157         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1158         """
1159         try:
1160             list(generate_comments(leaf))
1161         except FormatOn as f_on:
1162             self.leaves.append(f_on.leaf_from_consumed(leaf))
1163             raise
1164
1165         self.leaves.append(leaf)
1166         if leaf.type == token.INDENT:
1167             self.depth += 1
1168         elif leaf.type == token.DEDENT:
1169             self.depth -= 1
1170
1171     def __str__(self) -> str:
1172         """Render unformatted lines from leaves which were added with `append()`.
1173
1174         `depth` is not used for indentation in this case.
1175         """
1176         if not self:
1177             return "\n"
1178
1179         res = ""
1180         for leaf in self.leaves:
1181             res += str(leaf)
1182         return res
1183
1184     def append_comment(self, comment: Leaf) -> bool:
1185         """Not implemented in this class. Raises `NotImplementedError`."""
1186         raise NotImplementedError("Unformatted lines don't store comments separately.")
1187
1188     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1189         """Does nothing and returns False."""
1190         return False
1191
1192     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1193         """Does nothing and returns False."""
1194         return False
1195
1196
1197 @dataclass
1198 class EmptyLineTracker:
1199     """Provides a stateful method that returns the number of potential extra
1200     empty lines needed before and after the currently processed line.
1201
1202     Note: this tracker works on lines that haven't been split yet.  It assumes
1203     the prefix of the first leaf consists of optional newlines.  Those newlines
1204     are consumed by `maybe_empty_lines()` and included in the computation.
1205     """
1206
1207     is_pyi: bool = False
1208     previous_line: Optional[Line] = None
1209     previous_after: int = 0
1210     previous_defs: List[int] = Factory(list)
1211
1212     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1213         """Return the number of extra empty lines before and after the `current_line`.
1214
1215         This is for separating `def`, `async def` and `class` with extra empty
1216         lines (two on module-level).
1217         """
1218         if isinstance(current_line, UnformattedLines):
1219             return 0, 0
1220
1221         before, after = self._maybe_empty_lines(current_line)
1222         before -= self.previous_after
1223         self.previous_after = after
1224         self.previous_line = current_line
1225         return before, after
1226
1227     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1228         max_allowed = 1
1229         if current_line.depth == 0:
1230             max_allowed = 1 if self.is_pyi else 2
1231         if current_line.leaves:
1232             # Consume the first leaf's extra newlines.
1233             first_leaf = current_line.leaves[0]
1234             before = first_leaf.prefix.count("\n")
1235             before = min(before, max_allowed)
1236             first_leaf.prefix = ""
1237         else:
1238             before = 0
1239         depth = current_line.depth
1240         while self.previous_defs and self.previous_defs[-1] >= depth:
1241             self.previous_defs.pop()
1242             if self.is_pyi:
1243                 before = 0 if depth else 1
1244             else:
1245                 before = 1 if depth else 2
1246         is_decorator = current_line.is_decorator
1247         if is_decorator or current_line.is_def or current_line.is_class:
1248             if not is_decorator:
1249                 self.previous_defs.append(depth)
1250             if self.previous_line is None:
1251                 # Don't insert empty lines before the first line in the file.
1252                 return 0, 0
1253
1254             if self.previous_line.is_decorator:
1255                 return 0, 0
1256
1257             if self.previous_line.depth < current_line.depth and (
1258                 self.previous_line.is_class or self.previous_line.is_def
1259             ):
1260                 return 0, 0
1261
1262             if (
1263                 self.previous_line.is_comment
1264                 and self.previous_line.depth == current_line.depth
1265                 and before == 0
1266             ):
1267                 return 0, 0
1268
1269             if self.is_pyi:
1270                 if self.previous_line.depth > current_line.depth:
1271                     newlines = 1
1272                 elif current_line.is_class or self.previous_line.is_class:
1273                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1274                         newlines = 0
1275                     else:
1276                         newlines = 1
1277                 else:
1278                     newlines = 0
1279             else:
1280                 newlines = 2
1281             if current_line.depth and newlines:
1282                 newlines -= 1
1283             return newlines, 0
1284
1285         if (
1286             self.previous_line
1287             and self.previous_line.is_import
1288             and not current_line.is_import
1289             and depth == self.previous_line.depth
1290         ):
1291             return (before or 1), 0
1292
1293         if (
1294             self.previous_line
1295             and self.previous_line.is_class
1296             and current_line.is_triple_quoted_string
1297         ):
1298             return before, 1
1299
1300         return before, 0
1301
1302
1303 @dataclass
1304 class LineGenerator(Visitor[Line]):
1305     """Generates reformatted Line objects.  Empty lines are not emitted.
1306
1307     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1308     in ways that will no longer stringify to valid Python code on the tree.
1309     """
1310
1311     is_pyi: bool = False
1312     current_line: Line = Factory(Line)
1313     remove_u_prefix: bool = False
1314
1315     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1316         """Generate a line.
1317
1318         If the line is empty, only emit if it makes sense.
1319         If the line is too long, split it first and then generate.
1320
1321         If any lines were generated, set up a new current_line.
1322         """
1323         if not self.current_line:
1324             if self.current_line.__class__ == type:
1325                 self.current_line.depth += indent
1326             else:
1327                 self.current_line = type(depth=self.current_line.depth + indent)
1328             return  # Line is empty, don't emit. Creating a new one unnecessary.
1329
1330         complete_line = self.current_line
1331         self.current_line = type(depth=complete_line.depth + indent)
1332         yield complete_line
1333
1334     def visit(self, node: LN) -> Iterator[Line]:
1335         """Main method to visit `node` and its children.
1336
1337         Yields :class:`Line` objects.
1338         """
1339         if isinstance(self.current_line, UnformattedLines):
1340             # File contained `# fmt: off`
1341             yield from self.visit_unformatted(node)
1342
1343         else:
1344             yield from super().visit(node)
1345
1346     def visit_default(self, node: LN) -> Iterator[Line]:
1347         """Default `visit_*()` implementation. Recurses to children of `node`."""
1348         if isinstance(node, Leaf):
1349             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1350             try:
1351                 for comment in generate_comments(node):
1352                     if any_open_brackets:
1353                         # any comment within brackets is subject to splitting
1354                         self.current_line.append(comment)
1355                     elif comment.type == token.COMMENT:
1356                         # regular trailing comment
1357                         self.current_line.append(comment)
1358                         yield from self.line()
1359
1360                     else:
1361                         # regular standalone comment
1362                         yield from self.line()
1363
1364                         self.current_line.append(comment)
1365                         yield from self.line()
1366
1367             except FormatOff as f_off:
1368                 f_off.trim_prefix(node)
1369                 yield from self.line(type=UnformattedLines)
1370                 yield from self.visit(node)
1371
1372             except FormatOn as f_on:
1373                 # This only happens here if somebody says "fmt: on" multiple
1374                 # times in a row.
1375                 f_on.trim_prefix(node)
1376                 yield from self.visit_default(node)
1377
1378             else:
1379                 normalize_prefix(node, inside_brackets=any_open_brackets)
1380                 if node.type == token.STRING:
1381                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1382                     normalize_string_quotes(node)
1383                 if node.type not in WHITESPACE:
1384                     self.current_line.append(node)
1385         yield from super().visit_default(node)
1386
1387     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1388         """Increase indentation level, maybe yield a line."""
1389         # In blib2to3 INDENT never holds comments.
1390         yield from self.line(+1)
1391         yield from self.visit_default(node)
1392
1393     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1394         """Decrease indentation level, maybe yield a line."""
1395         # The current line might still wait for trailing comments.  At DEDENT time
1396         # there won't be any (they would be prefixes on the preceding NEWLINE).
1397         # Emit the line then.
1398         yield from self.line()
1399
1400         # While DEDENT has no value, its prefix may contain standalone comments
1401         # that belong to the current indentation level.  Get 'em.
1402         yield from self.visit_default(node)
1403
1404         # Finally, emit the dedent.
1405         yield from self.line(-1)
1406
1407     def visit_stmt(
1408         self, node: Node, keywords: Set[str], parens: Set[str]
1409     ) -> Iterator[Line]:
1410         """Visit a statement.
1411
1412         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1413         `def`, `with`, `class`, `assert` and assignments.
1414
1415         The relevant Python language `keywords` for a given statement will be
1416         NAME leaves within it. This methods puts those on a separate line.
1417
1418         `parens` holds a set of string leaf values immediately after which
1419         invisible parens should be put.
1420         """
1421         normalize_invisible_parens(node, parens_after=parens)
1422         for child in node.children:
1423             if child.type == token.NAME and child.value in keywords:  # type: ignore
1424                 yield from self.line()
1425
1426             yield from self.visit(child)
1427
1428     def visit_suite(self, node: Node) -> Iterator[Line]:
1429         """Visit a suite."""
1430         if self.is_pyi and is_stub_suite(node):
1431             yield from self.visit(node.children[2])
1432         else:
1433             yield from self.visit_default(node)
1434
1435     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1436         """Visit a statement without nested statements."""
1437         is_suite_like = node.parent and node.parent.type in STATEMENT
1438         if is_suite_like:
1439             if self.is_pyi and is_stub_body(node):
1440                 yield from self.visit_default(node)
1441             else:
1442                 yield from self.line(+1)
1443                 yield from self.visit_default(node)
1444                 yield from self.line(-1)
1445
1446         else:
1447             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1448                 yield from self.line()
1449             yield from self.visit_default(node)
1450
1451     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1452         """Visit `async def`, `async for`, `async with`."""
1453         yield from self.line()
1454
1455         children = iter(node.children)
1456         for child in children:
1457             yield from self.visit(child)
1458
1459             if child.type == token.ASYNC:
1460                 break
1461
1462         internal_stmt = next(children)
1463         for child in internal_stmt.children:
1464             yield from self.visit(child)
1465
1466     def visit_decorators(self, node: Node) -> Iterator[Line]:
1467         """Visit decorators."""
1468         for child in node.children:
1469             yield from self.line()
1470             yield from self.visit(child)
1471
1472     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1473         """Remove a semicolon and put the other statement on a separate line."""
1474         yield from self.line()
1475
1476     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1477         """End of file. Process outstanding comments and end with a newline."""
1478         yield from self.visit_default(leaf)
1479         yield from self.line()
1480
1481     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1482         """Used when file contained a `# fmt: off`."""
1483         if isinstance(node, Node):
1484             for child in node.children:
1485                 yield from self.visit(child)
1486
1487         else:
1488             try:
1489                 self.current_line.append(node)
1490             except FormatOn as f_on:
1491                 f_on.trim_prefix(node)
1492                 yield from self.line()
1493                 yield from self.visit(node)
1494
1495             if node.type == token.ENDMARKER:
1496                 # somebody decided not to put a final `# fmt: on`
1497                 yield from self.line()
1498
1499     def __attrs_post_init__(self) -> None:
1500         """You are in a twisty little maze of passages."""
1501         v = self.visit_stmt
1502         Ø: Set[str] = set()
1503         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1504         self.visit_if_stmt = partial(
1505             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1506         )
1507         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1508         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1509         self.visit_try_stmt = partial(
1510             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1511         )
1512         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1513         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1514         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1515         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1516         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1517         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1518         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1519         self.visit_async_funcdef = self.visit_async_stmt
1520         self.visit_decorated = self.visit_decorators
1521
1522
1523 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1524 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1525 OPENING_BRACKETS = set(BRACKET.keys())
1526 CLOSING_BRACKETS = set(BRACKET.values())
1527 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1528 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1529
1530
1531 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1532     """Return whitespace prefix if needed for the given `leaf`.
1533
1534     `complex_subscript` signals whether the given leaf is part of a subscription
1535     which has non-trivial arguments, like arithmetic expressions or function calls.
1536     """
1537     NO = ""
1538     SPACE = " "
1539     DOUBLESPACE = "  "
1540     t = leaf.type
1541     p = leaf.parent
1542     v = leaf.value
1543     if t in ALWAYS_NO_SPACE:
1544         return NO
1545
1546     if t == token.COMMENT:
1547         return DOUBLESPACE
1548
1549     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1550     if t == token.COLON and p.type not in {
1551         syms.subscript,
1552         syms.subscriptlist,
1553         syms.sliceop,
1554     }:
1555         return NO
1556
1557     prev = leaf.prev_sibling
1558     if not prev:
1559         prevp = preceding_leaf(p)
1560         if not prevp or prevp.type in OPENING_BRACKETS:
1561             return NO
1562
1563         if t == token.COLON:
1564             if prevp.type == token.COLON:
1565                 return NO
1566
1567             elif prevp.type != token.COMMA and not complex_subscript:
1568                 return NO
1569
1570             return SPACE
1571
1572         if prevp.type == token.EQUAL:
1573             if prevp.parent:
1574                 if prevp.parent.type in {
1575                     syms.arglist,
1576                     syms.argument,
1577                     syms.parameters,
1578                     syms.varargslist,
1579                 }:
1580                     return NO
1581
1582                 elif prevp.parent.type == syms.typedargslist:
1583                     # A bit hacky: if the equal sign has whitespace, it means we
1584                     # previously found it's a typed argument.  So, we're using
1585                     # that, too.
1586                     return prevp.prefix
1587
1588         elif prevp.type in STARS:
1589             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1590                 return NO
1591
1592         elif prevp.type == token.COLON:
1593             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1594                 return SPACE if complex_subscript else NO
1595
1596         elif (
1597             prevp.parent
1598             and prevp.parent.type == syms.factor
1599             and prevp.type in MATH_OPERATORS
1600         ):
1601             return NO
1602
1603         elif (
1604             prevp.type == token.RIGHTSHIFT
1605             and prevp.parent
1606             and prevp.parent.type == syms.shift_expr
1607             and prevp.prev_sibling
1608             and prevp.prev_sibling.type == token.NAME
1609             and prevp.prev_sibling.value == "print"  # type: ignore
1610         ):
1611             # Python 2 print chevron
1612             return NO
1613
1614     elif prev.type in OPENING_BRACKETS:
1615         return NO
1616
1617     if p.type in {syms.parameters, syms.arglist}:
1618         # untyped function signatures or calls
1619         if not prev or prev.type != token.COMMA:
1620             return NO
1621
1622     elif p.type == syms.varargslist:
1623         # lambdas
1624         if prev and prev.type != token.COMMA:
1625             return NO
1626
1627     elif p.type == syms.typedargslist:
1628         # typed function signatures
1629         if not prev:
1630             return NO
1631
1632         if t == token.EQUAL:
1633             if prev.type != syms.tname:
1634                 return NO
1635
1636         elif prev.type == token.EQUAL:
1637             # A bit hacky: if the equal sign has whitespace, it means we
1638             # previously found it's a typed argument.  So, we're using that, too.
1639             return prev.prefix
1640
1641         elif prev.type != token.COMMA:
1642             return NO
1643
1644     elif p.type == syms.tname:
1645         # type names
1646         if not prev:
1647             prevp = preceding_leaf(p)
1648             if not prevp or prevp.type != token.COMMA:
1649                 return NO
1650
1651     elif p.type == syms.trailer:
1652         # attributes and calls
1653         if t == token.LPAR or t == token.RPAR:
1654             return NO
1655
1656         if not prev:
1657             if t == token.DOT:
1658                 prevp = preceding_leaf(p)
1659                 if not prevp or prevp.type != token.NUMBER:
1660                     return NO
1661
1662             elif t == token.LSQB:
1663                 return NO
1664
1665         elif prev.type != token.COMMA:
1666             return NO
1667
1668     elif p.type == syms.argument:
1669         # single argument
1670         if t == token.EQUAL:
1671             return NO
1672
1673         if not prev:
1674             prevp = preceding_leaf(p)
1675             if not prevp or prevp.type == token.LPAR:
1676                 return NO
1677
1678         elif prev.type in {token.EQUAL} | STARS:
1679             return NO
1680
1681     elif p.type == syms.decorator:
1682         # decorators
1683         return NO
1684
1685     elif p.type == syms.dotted_name:
1686         if prev:
1687             return NO
1688
1689         prevp = preceding_leaf(p)
1690         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1691             return NO
1692
1693     elif p.type == syms.classdef:
1694         if t == token.LPAR:
1695             return NO
1696
1697         if prev and prev.type == token.LPAR:
1698             return NO
1699
1700     elif p.type in {syms.subscript, syms.sliceop}:
1701         # indexing
1702         if not prev:
1703             assert p.parent is not None, "subscripts are always parented"
1704             if p.parent.type == syms.subscriptlist:
1705                 return SPACE
1706
1707             return NO
1708
1709         elif not complex_subscript:
1710             return NO
1711
1712     elif p.type == syms.atom:
1713         if prev and t == token.DOT:
1714             # dots, but not the first one.
1715             return NO
1716
1717     elif p.type == syms.dictsetmaker:
1718         # dict unpacking
1719         if prev and prev.type == token.DOUBLESTAR:
1720             return NO
1721
1722     elif p.type in {syms.factor, syms.star_expr}:
1723         # unary ops
1724         if not prev:
1725             prevp = preceding_leaf(p)
1726             if not prevp or prevp.type in OPENING_BRACKETS:
1727                 return NO
1728
1729             prevp_parent = prevp.parent
1730             assert prevp_parent is not None
1731             if prevp.type == token.COLON and prevp_parent.type in {
1732                 syms.subscript,
1733                 syms.sliceop,
1734             }:
1735                 return NO
1736
1737             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1738                 return NO
1739
1740         elif t == token.NAME or t == token.NUMBER:
1741             return NO
1742
1743     elif p.type == syms.import_from:
1744         if t == token.DOT:
1745             if prev and prev.type == token.DOT:
1746                 return NO
1747
1748         elif t == token.NAME:
1749             if v == "import":
1750                 return SPACE
1751
1752             if prev and prev.type == token.DOT:
1753                 return NO
1754
1755     elif p.type == syms.sliceop:
1756         return NO
1757
1758     return SPACE
1759
1760
1761 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1762     """Return the first leaf that precedes `node`, if any."""
1763     while node:
1764         res = node.prev_sibling
1765         if res:
1766             if isinstance(res, Leaf):
1767                 return res
1768
1769             try:
1770                 return list(res.leaves())[-1]
1771
1772             except IndexError:
1773                 return None
1774
1775         node = node.parent
1776     return None
1777
1778
1779 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1780     """Return the child of `ancestor` that contains `descendant`."""
1781     node: Optional[LN] = descendant
1782     while node and node.parent != ancestor:
1783         node = node.parent
1784     return node
1785
1786
1787 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1788     """Return the priority of the `leaf` delimiter, given a line break after it.
1789
1790     The delimiter priorities returned here are from those delimiters that would
1791     cause a line break after themselves.
1792
1793     Higher numbers are higher priority.
1794     """
1795     if leaf.type == token.COMMA:
1796         return COMMA_PRIORITY
1797
1798     return 0
1799
1800
1801 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1802     """Return the priority of the `leaf` delimiter, given a line before after it.
1803
1804     The delimiter priorities returned here are from those delimiters that would
1805     cause a line break before themselves.
1806
1807     Higher numbers are higher priority.
1808     """
1809     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1810         # * and ** might also be MATH_OPERATORS but in this case they are not.
1811         # Don't treat them as a delimiter.
1812         return 0
1813
1814     if (
1815         leaf.type == token.DOT
1816         and leaf.parent
1817         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1818         and (previous is None or previous.type in CLOSING_BRACKETS)
1819     ):
1820         return DOT_PRIORITY
1821
1822     if (
1823         leaf.type in MATH_OPERATORS
1824         and leaf.parent
1825         and leaf.parent.type not in {syms.factor, syms.star_expr}
1826     ):
1827         return MATH_PRIORITIES[leaf.type]
1828
1829     if leaf.type in COMPARATORS:
1830         return COMPARATOR_PRIORITY
1831
1832     if (
1833         leaf.type == token.STRING
1834         and previous is not None
1835         and previous.type == token.STRING
1836     ):
1837         return STRING_PRIORITY
1838
1839     if leaf.type != token.NAME:
1840         return 0
1841
1842     if (
1843         leaf.value == "for"
1844         and leaf.parent
1845         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1846     ):
1847         return COMPREHENSION_PRIORITY
1848
1849     if (
1850         leaf.value == "if"
1851         and leaf.parent
1852         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1853     ):
1854         return COMPREHENSION_PRIORITY
1855
1856     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1857         return TERNARY_PRIORITY
1858
1859     if leaf.value == "is":
1860         return COMPARATOR_PRIORITY
1861
1862     if (
1863         leaf.value == "in"
1864         and leaf.parent
1865         and leaf.parent.type in {syms.comp_op, syms.comparison}
1866         and not (
1867             previous is not None
1868             and previous.type == token.NAME
1869             and previous.value == "not"
1870         )
1871     ):
1872         return COMPARATOR_PRIORITY
1873
1874     if (
1875         leaf.value == "not"
1876         and leaf.parent
1877         and leaf.parent.type == syms.comp_op
1878         and not (
1879             previous is not None
1880             and previous.type == token.NAME
1881             and previous.value == "is"
1882         )
1883     ):
1884         return COMPARATOR_PRIORITY
1885
1886     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1887         return LOGIC_PRIORITY
1888
1889     return 0
1890
1891
1892 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1893     """Clean the prefix of the `leaf` and generate comments from it, if any.
1894
1895     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1896     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1897     move because it does away with modifying the grammar to include all the
1898     possible places in which comments can be placed.
1899
1900     The sad consequence for us though is that comments don't "belong" anywhere.
1901     This is why this function generates simple parentless Leaf objects for
1902     comments.  We simply don't know what the correct parent should be.
1903
1904     No matter though, we can live without this.  We really only need to
1905     differentiate between inline and standalone comments.  The latter don't
1906     share the line with any code.
1907
1908     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1909     are emitted with a fake STANDALONE_COMMENT token identifier.
1910     """
1911     p = leaf.prefix
1912     if not p:
1913         return
1914
1915     if "#" not in p:
1916         return
1917
1918     consumed = 0
1919     nlines = 0
1920     for index, line in enumerate(p.split("\n")):
1921         consumed += len(line) + 1  # adding the length of the split '\n'
1922         line = line.lstrip()
1923         if not line:
1924             nlines += 1
1925         if not line.startswith("#"):
1926             continue
1927
1928         if index == 0 and leaf.type != token.ENDMARKER:
1929             comment_type = token.COMMENT  # simple trailing comment
1930         else:
1931             comment_type = STANDALONE_COMMENT
1932         comment = make_comment(line)
1933         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1934
1935         if comment in {"# fmt: on", "# yapf: enable"}:
1936             raise FormatOn(consumed)
1937
1938         if comment in {"# fmt: off", "# yapf: disable"}:
1939             if comment_type == STANDALONE_COMMENT:
1940                 raise FormatOff(consumed)
1941
1942             prev = preceding_leaf(leaf)
1943             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1944                 raise FormatOff(consumed)
1945
1946         nlines = 0
1947
1948
1949 def make_comment(content: str) -> str:
1950     """Return a consistently formatted comment from the given `content` string.
1951
1952     All comments (except for "##", "#!", "#:") should have a single space between
1953     the hash sign and the content.
1954
1955     If `content` didn't start with a hash sign, one is provided.
1956     """
1957     content = content.rstrip()
1958     if not content:
1959         return "#"
1960
1961     if content[0] == "#":
1962         content = content[1:]
1963     if content and content[0] not in " !:#":
1964         content = " " + content
1965     return "#" + content
1966
1967
1968 def split_line(
1969     line: Line, line_length: int, inner: bool = False, py36: bool = False
1970 ) -> Iterator[Line]:
1971     """Split a `line` into potentially many lines.
1972
1973     They should fit in the allotted `line_length` but might not be able to.
1974     `inner` signifies that there were a pair of brackets somewhere around the
1975     current `line`, possibly transitively. This means we can fallback to splitting
1976     by delimiters if the LHS/RHS don't yield any results.
1977
1978     If `py36` is True, splitting may generate syntax that is only compatible
1979     with Python 3.6 and later.
1980     """
1981     if isinstance(line, UnformattedLines) or line.is_comment:
1982         yield line
1983         return
1984
1985     line_str = str(line).strip("\n")
1986     if not line.should_explode and is_line_short_enough(
1987         line, line_length=line_length, line_str=line_str
1988     ):
1989         yield line
1990         return
1991
1992     split_funcs: List[SplitFunc]
1993     if line.is_def:
1994         split_funcs = [left_hand_split]
1995     else:
1996
1997         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
1998             for omit in generate_trailers_to_omit(line, line_length):
1999                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2000                 if is_line_short_enough(lines[0], line_length=line_length):
2001                     yield from lines
2002                     return
2003
2004             # All splits failed, best effort split with no omits.
2005             # This mostly happens to multiline strings that are by definition
2006             # reported as not fitting a single line.
2007             yield from right_hand_split(line, py36)
2008
2009         if line.inside_brackets:
2010             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2011         else:
2012             split_funcs = [rhs]
2013     for split_func in split_funcs:
2014         # We are accumulating lines in `result` because we might want to abort
2015         # mission and return the original line in the end, or attempt a different
2016         # split altogether.
2017         result: List[Line] = []
2018         try:
2019             for l in split_func(line, py36):
2020                 if str(l).strip("\n") == line_str:
2021                     raise CannotSplit("Split function returned an unchanged result")
2022
2023                 result.extend(
2024                     split_line(l, line_length=line_length, inner=True, py36=py36)
2025                 )
2026         except CannotSplit as cs:
2027             continue
2028
2029         else:
2030             yield from result
2031             break
2032
2033     else:
2034         yield line
2035
2036
2037 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2038     """Split line into many lines, starting with the first matching bracket pair.
2039
2040     Note: this usually looks weird, only use this for function definitions.
2041     Prefer RHS otherwise.  This is why this function is not symmetrical with
2042     :func:`right_hand_split` which also handles optional parentheses.
2043     """
2044     head = Line(depth=line.depth)
2045     body = Line(depth=line.depth + 1, inside_brackets=True)
2046     tail = Line(depth=line.depth)
2047     tail_leaves: List[Leaf] = []
2048     body_leaves: List[Leaf] = []
2049     head_leaves: List[Leaf] = []
2050     current_leaves = head_leaves
2051     matching_bracket = None
2052     for leaf in line.leaves:
2053         if (
2054             current_leaves is body_leaves
2055             and leaf.type in CLOSING_BRACKETS
2056             and leaf.opening_bracket is matching_bracket
2057         ):
2058             current_leaves = tail_leaves if body_leaves else head_leaves
2059         current_leaves.append(leaf)
2060         if current_leaves is head_leaves:
2061             if leaf.type in OPENING_BRACKETS:
2062                 matching_bracket = leaf
2063                 current_leaves = body_leaves
2064     # Since body is a new indent level, remove spurious leading whitespace.
2065     if body_leaves:
2066         normalize_prefix(body_leaves[0], inside_brackets=True)
2067     # Build the new lines.
2068     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2069         for leaf in leaves:
2070             result.append(leaf, preformatted=True)
2071             for comment_after in line.comments_after(leaf):
2072                 result.append(comment_after, preformatted=True)
2073     bracket_split_succeeded_or_raise(head, body, tail)
2074     for result in (head, body, tail):
2075         if result:
2076             yield result
2077
2078
2079 def right_hand_split(
2080     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2081 ) -> Iterator[Line]:
2082     """Split line into many lines, starting with the last matching bracket pair.
2083
2084     If the split was by optional parentheses, attempt splitting without them, too.
2085     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2086     this split.
2087
2088     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2089     """
2090     head = Line(depth=line.depth)
2091     body = Line(depth=line.depth + 1, inside_brackets=True)
2092     tail = Line(depth=line.depth)
2093     tail_leaves: List[Leaf] = []
2094     body_leaves: List[Leaf] = []
2095     head_leaves: List[Leaf] = []
2096     current_leaves = tail_leaves
2097     opening_bracket = None
2098     closing_bracket = None
2099     for leaf in reversed(line.leaves):
2100         if current_leaves is body_leaves:
2101             if leaf is opening_bracket:
2102                 current_leaves = head_leaves if body_leaves else tail_leaves
2103         current_leaves.append(leaf)
2104         if current_leaves is tail_leaves:
2105             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2106                 opening_bracket = leaf.opening_bracket
2107                 closing_bracket = leaf
2108                 current_leaves = body_leaves
2109     tail_leaves.reverse()
2110     body_leaves.reverse()
2111     head_leaves.reverse()
2112     # Since body is a new indent level, remove spurious leading whitespace.
2113     if body_leaves:
2114         normalize_prefix(body_leaves[0], inside_brackets=True)
2115     if not head_leaves:
2116         # No `head` means the split failed. Either `tail` has all content or
2117         # the matching `opening_bracket` wasn't available on `line` anymore.
2118         raise CannotSplit("No brackets found")
2119
2120     # Build the new lines.
2121     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2122         for leaf in leaves:
2123             result.append(leaf, preformatted=True)
2124             for comment_after in line.comments_after(leaf):
2125                 result.append(comment_after, preformatted=True)
2126     bracket_split_succeeded_or_raise(head, body, tail)
2127     assert opening_bracket and closing_bracket
2128     if (
2129         # the opening bracket is an optional paren
2130         opening_bracket.type == token.LPAR
2131         and not opening_bracket.value
2132         # the closing bracket is an optional paren
2133         and closing_bracket.type == token.RPAR
2134         and not closing_bracket.value
2135         # there are no standalone comments in the body
2136         and not line.contains_standalone_comments(0)
2137         # and it's not an import (optional parens are the only thing we can split
2138         # on in this case; attempting a split without them is a waste of time)
2139         and not line.is_import
2140     ):
2141         omit = {id(closing_bracket), *omit}
2142         if can_omit_invisible_parens(body, line_length):
2143             try:
2144                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2145                 return
2146             except CannotSplit:
2147                 pass
2148
2149     ensure_visible(opening_bracket)
2150     ensure_visible(closing_bracket)
2151     body.should_explode = should_explode(body, opening_bracket)
2152     for result in (head, body, tail):
2153         if result:
2154             yield result
2155
2156
2157 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2158     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2159
2160     Do nothing otherwise.
2161
2162     A left- or right-hand split is based on a pair of brackets. Content before
2163     (and including) the opening bracket is left on one line, content inside the
2164     brackets is put on a separate line, and finally content starting with and
2165     following the closing bracket is put on a separate line.
2166
2167     Those are called `head`, `body`, and `tail`, respectively. If the split
2168     produced the same line (all content in `head`) or ended up with an empty `body`
2169     and the `tail` is just the closing bracket, then it's considered failed.
2170     """
2171     tail_len = len(str(tail).strip())
2172     if not body:
2173         if tail_len == 0:
2174             raise CannotSplit("Splitting brackets produced the same line")
2175
2176         elif tail_len < 3:
2177             raise CannotSplit(
2178                 f"Splitting brackets on an empty body to save "
2179                 f"{tail_len} characters is not worth it"
2180             )
2181
2182
2183 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2184     """Normalize prefix of the first leaf in every line returned by `split_func`.
2185
2186     This is a decorator over relevant split functions.
2187     """
2188
2189     @wraps(split_func)
2190     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2191         for l in split_func(line, py36):
2192             normalize_prefix(l.leaves[0], inside_brackets=True)
2193             yield l
2194
2195     return split_wrapper
2196
2197
2198 @dont_increase_indentation
2199 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2200     """Split according to delimiters of the highest priority.
2201
2202     If `py36` is True, the split will add trailing commas also in function
2203     signatures that contain `*` and `**`.
2204     """
2205     try:
2206         last_leaf = line.leaves[-1]
2207     except IndexError:
2208         raise CannotSplit("Line empty")
2209
2210     bt = line.bracket_tracker
2211     try:
2212         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2213     except ValueError:
2214         raise CannotSplit("No delimiters found")
2215
2216     if delimiter_priority == DOT_PRIORITY:
2217         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2218             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2219
2220     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2221     lowest_depth = sys.maxsize
2222     trailing_comma_safe = True
2223
2224     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2225         """Append `leaf` to current line or to new line if appending impossible."""
2226         nonlocal current_line
2227         try:
2228             current_line.append_safe(leaf, preformatted=True)
2229         except ValueError as ve:
2230             yield current_line
2231
2232             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2233             current_line.append(leaf)
2234
2235     for index, leaf in enumerate(line.leaves):
2236         yield from append_to_line(leaf)
2237
2238         for comment_after in line.comments_after(leaf, index):
2239             yield from append_to_line(comment_after)
2240
2241         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2242         if leaf.bracket_depth == lowest_depth and is_vararg(
2243             leaf, within=VARARGS_PARENTS
2244         ):
2245             trailing_comma_safe = trailing_comma_safe and py36
2246         leaf_priority = bt.delimiters.get(id(leaf))
2247         if leaf_priority == delimiter_priority:
2248             yield current_line
2249
2250             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2251     if current_line:
2252         if (
2253             trailing_comma_safe
2254             and delimiter_priority == COMMA_PRIORITY
2255             and current_line.leaves[-1].type != token.COMMA
2256             and current_line.leaves[-1].type != STANDALONE_COMMENT
2257         ):
2258             current_line.append(Leaf(token.COMMA, ","))
2259         yield current_line
2260
2261
2262 @dont_increase_indentation
2263 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2264     """Split standalone comments from the rest of the line."""
2265     if not line.contains_standalone_comments(0):
2266         raise CannotSplit("Line does not have any standalone comments")
2267
2268     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2269
2270     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2271         """Append `leaf` to current line or to new line if appending impossible."""
2272         nonlocal current_line
2273         try:
2274             current_line.append_safe(leaf, preformatted=True)
2275         except ValueError as ve:
2276             yield current_line
2277
2278             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2279             current_line.append(leaf)
2280
2281     for index, leaf in enumerate(line.leaves):
2282         yield from append_to_line(leaf)
2283
2284         for comment_after in line.comments_after(leaf, index):
2285             yield from append_to_line(comment_after)
2286
2287     if current_line:
2288         yield current_line
2289
2290
2291 def is_import(leaf: Leaf) -> bool:
2292     """Return True if the given leaf starts an import statement."""
2293     p = leaf.parent
2294     t = leaf.type
2295     v = leaf.value
2296     return bool(
2297         t == token.NAME
2298         and (
2299             (v == "import" and p and p.type == syms.import_name)
2300             or (v == "from" and p and p.type == syms.import_from)
2301         )
2302     )
2303
2304
2305 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2306     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2307     else.
2308
2309     Note: don't use backslashes for formatting or you'll lose your voting rights.
2310     """
2311     if not inside_brackets:
2312         spl = leaf.prefix.split("#")
2313         if "\\" not in spl[0]:
2314             nl_count = spl[-1].count("\n")
2315             if len(spl) > 1:
2316                 nl_count -= 1
2317             leaf.prefix = "\n" * nl_count
2318             return
2319
2320     leaf.prefix = ""
2321
2322
2323 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2324     """Make all string prefixes lowercase.
2325
2326     If remove_u_prefix is given, also removes any u prefix from the string.
2327
2328     Note: Mutates its argument.
2329     """
2330     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2331     assert match is not None, f"failed to match string {leaf.value!r}"
2332     orig_prefix = match.group(1)
2333     new_prefix = orig_prefix.lower()
2334     if remove_u_prefix:
2335         new_prefix = new_prefix.replace("u", "")
2336     leaf.value = f"{new_prefix}{match.group(2)}"
2337
2338
2339 def normalize_string_quotes(leaf: Leaf) -> None:
2340     """Prefer double quotes but only if it doesn't cause more escaping.
2341
2342     Adds or removes backslashes as appropriate. Doesn't parse and fix
2343     strings nested in f-strings (yet).
2344
2345     Note: Mutates its argument.
2346     """
2347     value = leaf.value.lstrip("furbFURB")
2348     if value[:3] == '"""':
2349         return
2350
2351     elif value[:3] == "'''":
2352         orig_quote = "'''"
2353         new_quote = '"""'
2354     elif value[0] == '"':
2355         orig_quote = '"'
2356         new_quote = "'"
2357     else:
2358         orig_quote = "'"
2359         new_quote = '"'
2360     first_quote_pos = leaf.value.find(orig_quote)
2361     if first_quote_pos == -1:
2362         return  # There's an internal error
2363
2364     prefix = leaf.value[:first_quote_pos]
2365     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2366     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2367     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2368     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2369     if "r" in prefix.casefold():
2370         if unescaped_new_quote.search(body):
2371             # There's at least one unescaped new_quote in this raw string
2372             # so converting is impossible
2373             return
2374
2375         # Do not introduce or remove backslashes in raw strings
2376         new_body = body
2377     else:
2378         # remove unnecessary quotes
2379         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2380         if body != new_body:
2381             # Consider the string without unnecessary quotes as the original
2382             body = new_body
2383             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2384         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2385         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2386     if new_quote == '"""' and new_body[-1] == '"':
2387         # edge case:
2388         new_body = new_body[:-1] + '\\"'
2389     orig_escape_count = body.count("\\")
2390     new_escape_count = new_body.count("\\")
2391     if new_escape_count > orig_escape_count:
2392         return  # Do not introduce more escaping
2393
2394     if new_escape_count == orig_escape_count and orig_quote == '"':
2395         return  # Prefer double quotes
2396
2397     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2398
2399
2400 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2401     """Make existing optional parentheses invisible or create new ones.
2402
2403     `parens_after` is a set of string leaf values immeditely after which parens
2404     should be put.
2405
2406     Standardizes on visible parentheses for single-element tuples, and keeps
2407     existing visible parentheses for other tuples and generator expressions.
2408     """
2409     try:
2410         list(generate_comments(node))
2411     except FormatOff:
2412         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2413
2414     check_lpar = False
2415     for index, child in enumerate(list(node.children)):
2416         if check_lpar:
2417             if child.type == syms.atom:
2418                 maybe_make_parens_invisible_in_atom(child)
2419             elif is_one_tuple(child):
2420                 # wrap child in visible parentheses
2421                 lpar = Leaf(token.LPAR, "(")
2422                 rpar = Leaf(token.RPAR, ")")
2423                 child.remove()
2424                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2425             elif node.type == syms.import_from:
2426                 # "import from" nodes store parentheses directly as part of
2427                 # the statement
2428                 if child.type == token.LPAR:
2429                     # make parentheses invisible
2430                     child.value = ""  # type: ignore
2431                     node.children[-1].value = ""  # type: ignore
2432                 elif child.type != token.STAR:
2433                     # insert invisible parentheses
2434                     node.insert_child(index, Leaf(token.LPAR, ""))
2435                     node.append_child(Leaf(token.RPAR, ""))
2436                 break
2437
2438             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2439                 # wrap child in invisible parentheses
2440                 lpar = Leaf(token.LPAR, "")
2441                 rpar = Leaf(token.RPAR, "")
2442                 index = child.remove() or 0
2443                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2444
2445         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2446
2447
2448 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2449     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2450     if (
2451         node.type != syms.atom
2452         or is_empty_tuple(node)
2453         or is_one_tuple(node)
2454         or is_yield(node)
2455         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2456     ):
2457         return False
2458
2459     first = node.children[0]
2460     last = node.children[-1]
2461     if first.type == token.LPAR and last.type == token.RPAR:
2462         # make parentheses invisible
2463         first.value = ""  # type: ignore
2464         last.value = ""  # type: ignore
2465         if len(node.children) > 1:
2466             maybe_make_parens_invisible_in_atom(node.children[1])
2467         return True
2468
2469     return False
2470
2471
2472 def is_empty_tuple(node: LN) -> bool:
2473     """Return True if `node` holds an empty tuple."""
2474     return (
2475         node.type == syms.atom
2476         and len(node.children) == 2
2477         and node.children[0].type == token.LPAR
2478         and node.children[1].type == token.RPAR
2479     )
2480
2481
2482 def is_one_tuple(node: LN) -> bool:
2483     """Return True if `node` holds a tuple with one element, with or without parens."""
2484     if node.type == syms.atom:
2485         if len(node.children) != 3:
2486             return False
2487
2488         lpar, gexp, rpar = node.children
2489         if not (
2490             lpar.type == token.LPAR
2491             and gexp.type == syms.testlist_gexp
2492             and rpar.type == token.RPAR
2493         ):
2494             return False
2495
2496         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2497
2498     return (
2499         node.type in IMPLICIT_TUPLE
2500         and len(node.children) == 2
2501         and node.children[1].type == token.COMMA
2502     )
2503
2504
2505 def is_yield(node: LN) -> bool:
2506     """Return True if `node` holds a `yield` or `yield from` expression."""
2507     if node.type == syms.yield_expr:
2508         return True
2509
2510     if node.type == token.NAME and node.value == "yield":  # type: ignore
2511         return True
2512
2513     if node.type != syms.atom:
2514         return False
2515
2516     if len(node.children) != 3:
2517         return False
2518
2519     lpar, expr, rpar = node.children
2520     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2521         return is_yield(expr)
2522
2523     return False
2524
2525
2526 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2527     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2528
2529     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2530     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2531     extended iterable unpacking (PEP 3132) and additional unpacking
2532     generalizations (PEP 448).
2533     """
2534     if leaf.type not in STARS or not leaf.parent:
2535         return False
2536
2537     p = leaf.parent
2538     if p.type == syms.star_expr:
2539         # Star expressions are also used as assignment targets in extended
2540         # iterable unpacking (PEP 3132).  See what its parent is instead.
2541         if not p.parent:
2542             return False
2543
2544         p = p.parent
2545
2546     return p.type in within
2547
2548
2549 def is_multiline_string(leaf: Leaf) -> bool:
2550     """Return True if `leaf` is a multiline string that actually spans many lines."""
2551     value = leaf.value.lstrip("furbFURB")
2552     return value[:3] in {'"""', "'''"} and "\n" in value
2553
2554
2555 def is_stub_suite(node: Node) -> bool:
2556     """Return True if `node` is a suite with a stub body."""
2557     if (
2558         len(node.children) != 4
2559         or node.children[0].type != token.NEWLINE
2560         or node.children[1].type != token.INDENT
2561         or node.children[3].type != token.DEDENT
2562     ):
2563         return False
2564
2565     return is_stub_body(node.children[2])
2566
2567
2568 def is_stub_body(node: LN) -> bool:
2569     """Return True if `node` is a simple statement containing an ellipsis."""
2570     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2571         return False
2572
2573     if len(node.children) != 2:
2574         return False
2575
2576     child = node.children[0]
2577     return (
2578         child.type == syms.atom
2579         and len(child.children) == 3
2580         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2581     )
2582
2583
2584 def max_delimiter_priority_in_atom(node: LN) -> int:
2585     """Return maximum delimiter priority inside `node`.
2586
2587     This is specific to atoms with contents contained in a pair of parentheses.
2588     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2589     """
2590     if node.type != syms.atom:
2591         return 0
2592
2593     first = node.children[0]
2594     last = node.children[-1]
2595     if not (first.type == token.LPAR and last.type == token.RPAR):
2596         return 0
2597
2598     bt = BracketTracker()
2599     for c in node.children[1:-1]:
2600         if isinstance(c, Leaf):
2601             bt.mark(c)
2602         else:
2603             for leaf in c.leaves():
2604                 bt.mark(leaf)
2605     try:
2606         return bt.max_delimiter_priority()
2607
2608     except ValueError:
2609         return 0
2610
2611
2612 def ensure_visible(leaf: Leaf) -> None:
2613     """Make sure parentheses are visible.
2614
2615     They could be invisible as part of some statements (see
2616     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2617     """
2618     if leaf.type == token.LPAR:
2619         leaf.value = "("
2620     elif leaf.type == token.RPAR:
2621         leaf.value = ")"
2622
2623
2624 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2625     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2626     if not (
2627         opening_bracket.parent
2628         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2629         and opening_bracket.value in "[{("
2630     ):
2631         return False
2632
2633     try:
2634         last_leaf = line.leaves[-1]
2635         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2636         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2637     except (IndexError, ValueError):
2638         return False
2639
2640     return max_priority == COMMA_PRIORITY
2641
2642
2643 def is_python36(node: Node) -> bool:
2644     """Return True if the current file is using Python 3.6+ features.
2645
2646     Currently looking for:
2647     - f-strings; and
2648     - trailing commas after * or ** in function signatures and calls.
2649     """
2650     for n in node.pre_order():
2651         if n.type == token.STRING:
2652             value_head = n.value[:2]  # type: ignore
2653             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2654                 return True
2655
2656         elif (
2657             n.type in {syms.typedargslist, syms.arglist}
2658             and n.children
2659             and n.children[-1].type == token.COMMA
2660         ):
2661             for ch in n.children:
2662                 if ch.type in STARS:
2663                     return True
2664
2665                 if ch.type == syms.argument:
2666                     for argch in ch.children:
2667                         if argch.type in STARS:
2668                             return True
2669
2670     return False
2671
2672
2673 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2674     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2675
2676     Brackets can be omitted if the entire trailer up to and including
2677     a preceding closing bracket fits in one line.
2678
2679     Yielded sets are cumulative (contain results of previous yields, too).  First
2680     set is empty.
2681     """
2682
2683     omit: Set[LeafID] = set()
2684     yield omit
2685
2686     length = 4 * line.depth
2687     opening_bracket = None
2688     closing_bracket = None
2689     optional_brackets: Set[LeafID] = set()
2690     inner_brackets: Set[LeafID] = set()
2691     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2692         length += leaf_length
2693         if length > line_length:
2694             break
2695
2696         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2697         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2698             break
2699
2700         optional_brackets.discard(id(leaf))
2701         if opening_bracket:
2702             if leaf is opening_bracket:
2703                 opening_bracket = None
2704             elif leaf.type in CLOSING_BRACKETS:
2705                 inner_brackets.add(id(leaf))
2706         elif leaf.type in CLOSING_BRACKETS:
2707             if not leaf.value:
2708                 optional_brackets.add(id(opening_bracket))
2709                 continue
2710
2711             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2712                 # Empty brackets would fail a split so treat them as "inner"
2713                 # brackets (e.g. only add them to the `omit` set if another
2714                 # pair of brackets was good enough.
2715                 inner_brackets.add(id(leaf))
2716                 continue
2717
2718             opening_bracket = leaf.opening_bracket
2719             if closing_bracket:
2720                 omit.add(id(closing_bracket))
2721                 omit.update(inner_brackets)
2722                 inner_brackets.clear()
2723                 yield omit
2724             closing_bracket = leaf
2725
2726
2727 def get_future_imports(node: Node) -> Set[str]:
2728     """Return a set of __future__ imports in the file."""
2729     imports = set()
2730     for child in node.children:
2731         if child.type != syms.simple_stmt:
2732             break
2733         first_child = child.children[0]
2734         if isinstance(first_child, Leaf):
2735             # Continue looking if we see a docstring; otherwise stop.
2736             if (
2737                 len(child.children) == 2
2738                 and first_child.type == token.STRING
2739                 and child.children[1].type == token.NEWLINE
2740             ):
2741                 continue
2742             else:
2743                 break
2744         elif first_child.type == syms.import_from:
2745             module_name = first_child.children[1]
2746             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2747                 break
2748             for import_from_child in first_child.children[3:]:
2749                 if isinstance(import_from_child, Leaf):
2750                     if import_from_child.type == token.NAME:
2751                         imports.add(import_from_child.value)
2752                 else:
2753                     assert import_from_child.type == syms.import_as_names
2754                     for leaf in import_from_child.children:
2755                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2756                             imports.add(leaf.value)
2757         else:
2758             break
2759     return imports
2760
2761
2762 PYTHON_EXTENSIONS = {".py", ".pyi"}
2763 BLACKLISTED_DIRECTORIES = {
2764     "build",
2765     "buck-out",
2766     "dist",
2767     "_build",
2768     ".git",
2769     ".hg",
2770     ".mypy_cache",
2771     ".tox",
2772     ".venv",
2773 }
2774
2775
2776 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2777     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2778     and have one of the PYTHON_EXTENSIONS.
2779     """
2780     for child in path.iterdir():
2781         if child.is_dir():
2782             if child.name in BLACKLISTED_DIRECTORIES:
2783                 continue
2784
2785             yield from gen_python_files_in_dir(child)
2786
2787         elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
2788             yield child
2789
2790
2791 @dataclass
2792 class Report:
2793     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2794
2795     check: bool = False
2796     quiet: bool = False
2797     change_count: int = 0
2798     same_count: int = 0
2799     failure_count: int = 0
2800
2801     def done(self, src: Path, changed: Changed) -> None:
2802         """Increment the counter for successful reformatting. Write out a message."""
2803         if changed is Changed.YES:
2804             reformatted = "would reformat" if self.check else "reformatted"
2805             if not self.quiet:
2806                 out(f"{reformatted} {src}")
2807             self.change_count += 1
2808         else:
2809             if not self.quiet:
2810                 if changed is Changed.NO:
2811                     msg = f"{src} already well formatted, good job."
2812                 else:
2813                     msg = f"{src} wasn't modified on disk since last run."
2814                 out(msg, bold=False)
2815             self.same_count += 1
2816
2817     def failed(self, src: Path, message: str) -> None:
2818         """Increment the counter for failed reformatting. Write out a message."""
2819         err(f"error: cannot format {src}: {message}")
2820         self.failure_count += 1
2821
2822     @property
2823     def return_code(self) -> int:
2824         """Return the exit code that the app should use.
2825
2826         This considers the current state of changed files and failures:
2827         - if there were any failures, return 123;
2828         - if any files were changed and --check is being used, return 1;
2829         - otherwise return 0.
2830         """
2831         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2832         # 126 we have special returncodes reserved by the shell.
2833         if self.failure_count:
2834             return 123
2835
2836         elif self.change_count and self.check:
2837             return 1
2838
2839         return 0
2840
2841     def __str__(self) -> str:
2842         """Render a color report of the current state.
2843
2844         Use `click.unstyle` to remove colors.
2845         """
2846         if self.check:
2847             reformatted = "would be reformatted"
2848             unchanged = "would be left unchanged"
2849             failed = "would fail to reformat"
2850         else:
2851             reformatted = "reformatted"
2852             unchanged = "left unchanged"
2853             failed = "failed to reformat"
2854         report = []
2855         if self.change_count:
2856             s = "s" if self.change_count > 1 else ""
2857             report.append(
2858                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2859             )
2860         if self.same_count:
2861             s = "s" if self.same_count > 1 else ""
2862             report.append(f"{self.same_count} file{s} {unchanged}")
2863         if self.failure_count:
2864             s = "s" if self.failure_count > 1 else ""
2865             report.append(
2866                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2867             )
2868         return ", ".join(report) + "."
2869
2870
2871 def assert_equivalent(src: str, dst: str) -> None:
2872     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2873
2874     import ast
2875     import traceback
2876
2877     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2878         """Simple visitor generating strings to compare ASTs by content."""
2879         yield f"{'  ' * depth}{node.__class__.__name__}("
2880
2881         for field in sorted(node._fields):
2882             try:
2883                 value = getattr(node, field)
2884             except AttributeError:
2885                 continue
2886
2887             yield f"{'  ' * (depth+1)}{field}="
2888
2889             if isinstance(value, list):
2890                 for item in value:
2891                     if isinstance(item, ast.AST):
2892                         yield from _v(item, depth + 2)
2893
2894             elif isinstance(value, ast.AST):
2895                 yield from _v(value, depth + 2)
2896
2897             else:
2898                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2899
2900         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2901
2902     try:
2903         src_ast = ast.parse(src)
2904     except Exception as exc:
2905         major, minor = sys.version_info[:2]
2906         raise AssertionError(
2907             f"cannot use --safe with this file; failed to parse source file "
2908             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2909             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2910         )
2911
2912     try:
2913         dst_ast = ast.parse(dst)
2914     except Exception as exc:
2915         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2916         raise AssertionError(
2917             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2918             f"Please report a bug on https://github.com/ambv/black/issues.  "
2919             f"This invalid output might be helpful: {log}"
2920         ) from None
2921
2922     src_ast_str = "\n".join(_v(src_ast))
2923     dst_ast_str = "\n".join(_v(dst_ast))
2924     if src_ast_str != dst_ast_str:
2925         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2926         raise AssertionError(
2927             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2928             f"the source.  "
2929             f"Please report a bug on https://github.com/ambv/black/issues.  "
2930             f"This diff might be helpful: {log}"
2931         ) from None
2932
2933
2934 def assert_stable(
2935     src: str, dst: str, line_length: int, is_pyi: bool = False, force_py36: bool = False
2936 ) -> None:
2937     """Raise AssertionError if `dst` reformats differently the second time."""
2938     newdst = format_str(
2939         dst, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
2940     )
2941     if dst != newdst:
2942         log = dump_to_file(
2943             diff(src, dst, "source", "first pass"),
2944             diff(dst, newdst, "first pass", "second pass"),
2945         )
2946         raise AssertionError(
2947             f"INTERNAL ERROR: Black produced different code on the second pass "
2948             f"of the formatter.  "
2949             f"Please report a bug on https://github.com/ambv/black/issues.  "
2950             f"This diff might be helpful: {log}"
2951         ) from None
2952
2953
2954 def dump_to_file(*output: str) -> str:
2955     """Dump `output` to a temporary file. Return path to the file."""
2956     import tempfile
2957
2958     with tempfile.NamedTemporaryFile(
2959         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2960     ) as f:
2961         for lines in output:
2962             f.write(lines)
2963             if lines and lines[-1] != "\n":
2964                 f.write("\n")
2965     return f.name
2966
2967
2968 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2969     """Return a unified diff string between strings `a` and `b`."""
2970     import difflib
2971
2972     a_lines = [line + "\n" for line in a.split("\n")]
2973     b_lines = [line + "\n" for line in b.split("\n")]
2974     return "".join(
2975         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2976     )
2977
2978
2979 def cancel(tasks: Iterable[asyncio.Task]) -> None:
2980     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2981     err("Aborted!")
2982     for task in tasks:
2983         task.cancel()
2984
2985
2986 def shutdown(loop: BaseEventLoop) -> None:
2987     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2988     try:
2989         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2990         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2991         if not to_cancel:
2992             return
2993
2994         for task in to_cancel:
2995             task.cancel()
2996         loop.run_until_complete(
2997             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2998         )
2999     finally:
3000         # `concurrent.futures.Future` objects cannot be cancelled once they
3001         # are already running. There might be some when the `shutdown()` happened.
3002         # Silence their logger's spew about the event loop being closed.
3003         cf_logger = logging.getLogger("concurrent.futures")
3004         cf_logger.setLevel(logging.CRITICAL)
3005         loop.close()
3006
3007
3008 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3009     """Replace `regex` with `replacement` twice on `original`.
3010
3011     This is used by string normalization to perform replaces on
3012     overlapping matches.
3013     """
3014     return regex.sub(replacement, regex.sub(replacement, original))
3015
3016
3017 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3018     """Like `reversed(enumerate(sequence))` if that were possible."""
3019     index = len(sequence) - 1
3020     for element in reversed(sequence):
3021         yield (index, element)
3022         index -= 1
3023
3024
3025 def enumerate_with_length(
3026     line: Line, reversed: bool = False
3027 ) -> Iterator[Tuple[Index, Leaf, int]]:
3028     """Return an enumeration of leaves with their length.
3029
3030     Stops prematurely on multiline strings and standalone comments.
3031     """
3032     op = cast(
3033         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3034         enumerate_reversed if reversed else enumerate,
3035     )
3036     for index, leaf in op(line.leaves):
3037         length = len(leaf.prefix) + len(leaf.value)
3038         if "\n" in leaf.value:
3039             return  # Multiline strings, we can't continue.
3040
3041         comment: Optional[Leaf]
3042         for comment in line.comments_after(leaf, index):
3043             length += len(comment.value)
3044
3045         yield index, leaf, length
3046
3047
3048 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3049     """Return True if `line` is no longer than `line_length`.
3050
3051     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3052     """
3053     if not line_str:
3054         line_str = str(line).strip("\n")
3055     return (
3056         len(line_str) <= line_length
3057         and "\n" not in line_str  # multiline strings
3058         and not line.contains_standalone_comments()
3059     )
3060
3061
3062 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3063     """Does `line` have a shape safe to reformat without optional parens around it?
3064
3065     Returns True for only a subset of potentially nice looking formattings but
3066     the point is to not return false positives that end up producing lines that
3067     are too long.
3068     """
3069     bt = line.bracket_tracker
3070     if not bt.delimiters:
3071         # Without delimiters the optional parentheses are useless.
3072         return True
3073
3074     max_priority = bt.max_delimiter_priority()
3075     if bt.delimiter_count_with_priority(max_priority) > 1:
3076         # With more than one delimiter of a kind the optional parentheses read better.
3077         return False
3078
3079     if max_priority == DOT_PRIORITY:
3080         # A single stranded method call doesn't require optional parentheses.
3081         return True
3082
3083     assert len(line.leaves) >= 2, "Stranded delimiter"
3084
3085     first = line.leaves[0]
3086     second = line.leaves[1]
3087     penultimate = line.leaves[-2]
3088     last = line.leaves[-1]
3089
3090     # With a single delimiter, omit if the expression starts or ends with
3091     # a bracket.
3092     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3093         remainder = False
3094         length = 4 * line.depth
3095         for _index, leaf, leaf_length in enumerate_with_length(line):
3096             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3097                 remainder = True
3098             if remainder:
3099                 length += leaf_length
3100                 if length > line_length:
3101                     break
3102
3103                 if leaf.type in OPENING_BRACKETS:
3104                     # There are brackets we can further split on.
3105                     remainder = False
3106
3107         else:
3108             # checked the entire string and line length wasn't exceeded
3109             if len(line.leaves) == _index + 1:
3110                 return True
3111
3112         # Note: we are not returning False here because a line might have *both*
3113         # a leading opening bracket and a trailing closing bracket.  If the
3114         # opening bracket doesn't match our rule, maybe the closing will.
3115
3116     if (
3117         last.type == token.RPAR
3118         or last.type == token.RBRACE
3119         or (
3120             # don't use indexing for omitting optional parentheses;
3121             # it looks weird
3122             last.type == token.RSQB
3123             and last.parent
3124             and last.parent.type != syms.trailer
3125         )
3126     ):
3127         if penultimate.type in OPENING_BRACKETS:
3128             # Empty brackets don't help.
3129             return False
3130
3131         if is_multiline_string(first):
3132             # Additional wrapping of a multiline string in this situation is
3133             # unnecessary.
3134             return True
3135
3136         length = 4 * line.depth
3137         seen_other_brackets = False
3138         for _index, leaf, leaf_length in enumerate_with_length(line):
3139             length += leaf_length
3140             if leaf is last.opening_bracket:
3141                 if seen_other_brackets or length <= line_length:
3142                     return True
3143
3144             elif leaf.type in OPENING_BRACKETS:
3145                 # There are brackets we can further split on.
3146                 seen_other_brackets = True
3147
3148     return False
3149
3150
3151 def get_cache_file(line_length: int, pyi: bool = False, py36: bool = False) -> Path:
3152     return (
3153         CACHE_DIR
3154         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3155     )
3156
3157
3158 def read_cache(line_length: int, pyi: bool = False, py36: bool = False) -> Cache:
3159     """Read the cache if it exists and is well formed.
3160
3161     If it is not well formed, the call to write_cache later should resolve the issue.
3162     """
3163     cache_file = get_cache_file(line_length, pyi, py36)
3164     if not cache_file.exists():
3165         return {}
3166
3167     with cache_file.open("rb") as fobj:
3168         try:
3169             cache: Cache = pickle.load(fobj)
3170         except pickle.UnpicklingError:
3171             return {}
3172
3173     return cache
3174
3175
3176 def get_cache_info(path: Path) -> CacheInfo:
3177     """Return the information used to check if a file is already formatted or not."""
3178     stat = path.stat()
3179     return stat.st_mtime, stat.st_size
3180
3181
3182 def filter_cached(
3183     cache: Cache, sources: Iterable[Path]
3184 ) -> Tuple[List[Path], List[Path]]:
3185     """Split a list of paths into two.
3186
3187     The first list contains paths of files that modified on disk or are not in the
3188     cache. The other list contains paths to non-modified files.
3189     """
3190     todo, done = [], []
3191     for src in sources:
3192         src = src.resolve()
3193         if cache.get(src) != get_cache_info(src):
3194             todo.append(src)
3195         else:
3196             done.append(src)
3197     return todo, done
3198
3199
3200 def write_cache(
3201     cache: Cache,
3202     sources: List[Path],
3203     line_length: int,
3204     pyi: bool = False,
3205     py36: bool = False,
3206 ) -> None:
3207     """Update the cache file."""
3208     cache_file = get_cache_file(line_length, pyi, py36)
3209     try:
3210         if not CACHE_DIR.exists():
3211             CACHE_DIR.mkdir(parents=True)
3212         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3213         with cache_file.open("wb") as fobj:
3214             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3215     except OSError:
3216         pass
3217
3218
3219 if __name__ == "__main__":
3220     main()