]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

a31aa7ac4402812d3c54e98125b7e8e3e2a6feb6
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33     cast,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46
47 __version__ = "18.5b0"
48 DEFAULT_LINE_LENGTH = 88
49 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
50
51
52 # types
53 FileContent = str
54 Encoding = str
55 Depth = int
56 NodeType = int
57 LeafID = int
58 Priority = int
59 Index = int
60 LN = Union[Leaf, Node]
61 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
62 Timestamp = float
63 FileSize = int
64 CacheInfo = Tuple[Timestamp, FileSize]
65 Cache = Dict[Path, CacheInfo]
66 out = partial(click.secho, bold=True, err=True)
67 err = partial(click.secho, fg="red", err=True)
68
69 pygram.initialize(CACHE_DIR)
70 syms = pygram.python_symbols
71
72
73 class NothingChanged(UserWarning):
74     """Raised by :func:`format_file` when reformatted code is the same as source."""
75
76
77 class CannotSplit(Exception):
78     """A readable split that fits the allotted line length is impossible.
79
80     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
81     :func:`delimiter_split`.
82     """
83
84
85 class FormatError(Exception):
86     """Base exception for `# fmt: on` and `# fmt: off` handling.
87
88     It holds the number of bytes of the prefix consumed before the format
89     control comment appeared.
90     """
91
92     def __init__(self, consumed: int) -> None:
93         super().__init__(consumed)
94         self.consumed = consumed
95
96     def trim_prefix(self, leaf: Leaf) -> None:
97         leaf.prefix = leaf.prefix[self.consumed :]
98
99     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
100         """Returns a new Leaf from the consumed part of the prefix."""
101         unformatted_prefix = leaf.prefix[: self.consumed]
102         return Leaf(token.NEWLINE, unformatted_prefix)
103
104
105 class FormatOn(FormatError):
106     """Found a comment like `# fmt: on` in the file."""
107
108
109 class FormatOff(FormatError):
110     """Found a comment like `# fmt: off` in the file."""
111
112
113 class WriteBack(Enum):
114     NO = 0
115     YES = 1
116     DIFF = 2
117
118
119 class Changed(Enum):
120     NO = 0
121     CACHED = 1
122     YES = 2
123
124
125 @click.command()
126 @click.option(
127     "-l",
128     "--line-length",
129     type=int,
130     default=DEFAULT_LINE_LENGTH,
131     help="How many character per line to allow.",
132     show_default=True,
133 )
134 @click.option(
135     "--check",
136     is_flag=True,
137     help=(
138         "Don't write the files back, just return the status.  Return code 0 "
139         "means nothing would change.  Return code 1 means some files would be "
140         "reformatted.  Return code 123 means there was an internal error."
141     ),
142 )
143 @click.option(
144     "--diff",
145     is_flag=True,
146     help="Don't write the files back, just output a diff for each file on stdout.",
147 )
148 @click.option(
149     "--fast/--safe",
150     is_flag=True,
151     help="If --fast given, skip temporary sanity checks. [default: --safe]",
152 )
153 @click.option(
154     "-q",
155     "--quiet",
156     is_flag=True,
157     help=(
158         "Don't emit non-error messages to stderr. Errors are still emitted, "
159         "silence those with 2>/dev/null."
160     ),
161 )
162 @click.option(
163     "--pyi",
164     is_flag=True,
165     help=(
166         "Consider all input files typing stubs regardless of file extension "
167         "(useful when piping source on standard input)."
168     ),
169 )
170 @click.option(
171     "--py36",
172     is_flag=True,
173     help=(
174         "Allow using Python 3.6-only syntax on all input files.  This will put "
175         "trailing commas in function signatures and calls also after *args and "
176         "**kwargs.  [default: per-file auto-detection]"
177     ),
178 )
179 @click.version_option(version=__version__)
180 @click.argument(
181     "src",
182     nargs=-1,
183     type=click.Path(
184         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
185     ),
186 )
187 @click.pass_context
188 def main(
189     ctx: click.Context,
190     line_length: int,
191     check: bool,
192     diff: bool,
193     fast: bool,
194     pyi: bool,
195     py36: bool,
196     quiet: bool,
197     src: List[str],
198 ) -> None:
199     """The uncompromising code formatter."""
200     sources: List[Path] = []
201     for s in src:
202         p = Path(s)
203         if p.is_dir():
204             sources.extend(gen_python_files_in_dir(p))
205         elif p.is_file():
206             # if a file was explicitly given, we don't care about its extension
207             sources.append(p)
208         elif s == "-":
209             sources.append(Path("-"))
210         else:
211             err(f"invalid path: {s}")
212
213     if check and not diff:
214         write_back = WriteBack.NO
215     elif diff:
216         write_back = WriteBack.DIFF
217     else:
218         write_back = WriteBack.YES
219     report = Report(check=check, quiet=quiet)
220     if len(sources) == 0:
221         out("No paths given. Nothing to do 😴")
222         ctx.exit(0)
223         return
224
225     elif len(sources) == 1:
226         reformat_one(
227             src=sources[0],
228             line_length=line_length,
229             fast=fast,
230             pyi=pyi,
231             py36=py36,
232             write_back=write_back,
233             report=report,
234         )
235     else:
236         loop = asyncio.get_event_loop()
237         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
238         try:
239             loop.run_until_complete(
240                 schedule_formatting(
241                     sources=sources,
242                     line_length=line_length,
243                     fast=fast,
244                     pyi=pyi,
245                     py36=py36,
246                     write_back=write_back,
247                     report=report,
248                     loop=loop,
249                     executor=executor,
250                 )
251             )
252         finally:
253             shutdown(loop)
254         if not quiet:
255             out("All done! ✨ 🍰 ✨")
256             click.echo(str(report))
257     ctx.exit(report.return_code)
258
259
260 def reformat_one(
261     src: Path,
262     line_length: int,
263     fast: bool,
264     pyi: bool,
265     py36: bool,
266     write_back: WriteBack,
267     report: "Report",
268 ) -> None:
269     """Reformat a single file under `src` without spawning child processes.
270
271     If `quiet` is True, non-error messages are not output. `line_length`,
272     `write_back`, `fast` and `pyi` options are passed to
273     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
274     """
275     try:
276         changed = Changed.NO
277         if not src.is_file() and str(src) == "-":
278             if format_stdin_to_stdout(
279                 line_length=line_length,
280                 fast=fast,
281                 is_pyi=pyi,
282                 force_py36=py36,
283                 write_back=write_back,
284             ):
285                 changed = Changed.YES
286         else:
287             cache: Cache = {}
288             if write_back != WriteBack.DIFF:
289                 cache = read_cache(line_length, pyi, py36)
290                 src = src.resolve()
291                 if src in cache and cache[src] == get_cache_info(src):
292                     changed = Changed.CACHED
293             if changed is not Changed.CACHED and format_file_in_place(
294                 src,
295                 line_length=line_length,
296                 fast=fast,
297                 force_pyi=pyi,
298                 force_py36=py36,
299                 write_back=write_back,
300             ):
301                 changed = Changed.YES
302             if write_back == WriteBack.YES and changed is not Changed.NO:
303                 write_cache(cache, [src], line_length, pyi, py36)
304         report.done(src, changed)
305     except Exception as exc:
306         report.failed(src, str(exc))
307
308
309 async def schedule_formatting(
310     sources: List[Path],
311     line_length: int,
312     fast: bool,
313     pyi: bool,
314     py36: bool,
315     write_back: WriteBack,
316     report: "Report",
317     loop: BaseEventLoop,
318     executor: Executor,
319 ) -> None:
320     """Run formatting of `sources` in parallel using the provided `executor`.
321
322     (Use ProcessPoolExecutors for actual parallelism.)
323
324     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
325     :func:`format_file_in_place`.
326     """
327     cache: Cache = {}
328     if write_back != WriteBack.DIFF:
329         cache = read_cache(line_length, pyi, py36)
330         sources, cached = filter_cached(cache, sources)
331         for src in cached:
332             report.done(src, Changed.CACHED)
333     cancelled = []
334     formatted = []
335     if sources:
336         lock = None
337         if write_back == WriteBack.DIFF:
338             # For diff output, we need locks to ensure we don't interleave output
339             # from different processes.
340             manager = Manager()
341             lock = manager.Lock()
342         tasks = {
343             loop.run_in_executor(
344                 executor,
345                 format_file_in_place,
346                 src,
347                 line_length,
348                 fast,
349                 pyi,
350                 py36,
351                 write_back,
352                 lock,
353             ): src
354             for src in sorted(sources)
355         }
356         pending: Iterable[asyncio.Task] = tasks.keys()
357         try:
358             loop.add_signal_handler(signal.SIGINT, cancel, pending)
359             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
360         except NotImplementedError:
361             # There are no good alternatives for these on Windows
362             pass
363         while pending:
364             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
365             for task in done:
366                 src = tasks.pop(task)
367                 if task.cancelled():
368                     cancelled.append(task)
369                 elif task.exception():
370                     report.failed(src, str(task.exception()))
371                 else:
372                     formatted.append(src)
373                     report.done(src, Changed.YES if task.result() else Changed.NO)
374     if cancelled:
375         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
376     if write_back == WriteBack.YES and formatted:
377         write_cache(cache, formatted, line_length, pyi, py36)
378
379
380 def format_file_in_place(
381     src: Path,
382     line_length: int,
383     fast: bool,
384     force_pyi: bool = False,
385     force_py36: bool = False,
386     write_back: WriteBack = WriteBack.NO,
387     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
388 ) -> bool:
389     """Format file under `src` path. Return True if changed.
390
391     If `write_back` is True, write reformatted code back to stdout.
392     `line_length` and `fast` options are passed to :func:`format_file_contents`.
393     """
394     is_pyi = force_pyi or src.suffix == ".pyi"
395
396     with tokenize.open(src) as src_buffer:
397         src_contents = src_buffer.read()
398     try:
399         dst_contents = format_file_contents(
400             src_contents,
401             line_length=line_length,
402             fast=fast,
403             is_pyi=is_pyi,
404             force_py36=force_py36,
405         )
406     except NothingChanged:
407         return False
408
409     if write_back == write_back.YES:
410         with open(src, "w", encoding=src_buffer.encoding) as f:
411             f.write(dst_contents)
412     elif write_back == write_back.DIFF:
413         src_name = f"{src}  (original)"
414         dst_name = f"{src}  (formatted)"
415         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
416         if lock:
417             lock.acquire()
418         try:
419             sys.stdout.write(diff_contents)
420         finally:
421             if lock:
422                 lock.release()
423     return True
424
425
426 def format_stdin_to_stdout(
427     line_length: int,
428     fast: bool,
429     is_pyi: bool = False,
430     force_py36: bool = False,
431     write_back: WriteBack = WriteBack.NO,
432 ) -> bool:
433     """Format file on stdin. Return True if changed.
434
435     If `write_back` is True, write reformatted code back to stdout.
436     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
437     :func:`format_file_contents`.
438     """
439     src = sys.stdin.read()
440     dst = src
441     try:
442         dst = format_file_contents(
443             src,
444             line_length=line_length,
445             fast=fast,
446             is_pyi=is_pyi,
447             force_py36=force_py36,
448         )
449         return True
450
451     except NothingChanged:
452         return False
453
454     finally:
455         if write_back == WriteBack.YES:
456             sys.stdout.write(dst)
457         elif write_back == WriteBack.DIFF:
458             src_name = "<stdin>  (original)"
459             dst_name = "<stdin>  (formatted)"
460             sys.stdout.write(diff(src, dst, src_name, dst_name))
461
462
463 def format_file_contents(
464     src_contents: str,
465     *,
466     line_length: int,
467     fast: bool,
468     is_pyi: bool = False,
469     force_py36: bool = False,
470 ) -> FileContent:
471     """Reformat contents a file and return new contents.
472
473     If `fast` is False, additionally confirm that the reformatted code is
474     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
475     `line_length` is passed to :func:`format_str`.
476     """
477     if src_contents.strip() == "":
478         raise NothingChanged
479
480     dst_contents = format_str(
481         src_contents, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
482     )
483     if src_contents == dst_contents:
484         raise NothingChanged
485
486     if not fast:
487         assert_equivalent(src_contents, dst_contents)
488         assert_stable(
489             src_contents,
490             dst_contents,
491             line_length=line_length,
492             is_pyi=is_pyi,
493             force_py36=force_py36,
494         )
495     return dst_contents
496
497
498 def format_str(
499     src_contents: str,
500     line_length: int,
501     *,
502     is_pyi: bool = False,
503     force_py36: bool = False,
504 ) -> FileContent:
505     """Reformat a string and return new contents.
506
507     `line_length` determines how many characters per line are allowed.
508     """
509     src_node = lib2to3_parse(src_contents)
510     dst_contents = ""
511     future_imports = get_future_imports(src_node)
512     elt = EmptyLineTracker(is_pyi=is_pyi)
513     py36 = force_py36 or is_python36(src_node)
514     lines = LineGenerator(
515         remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
516     )
517     empty_line = Line()
518     after = 0
519     for current_line in lines.visit(src_node):
520         for _ in range(after):
521             dst_contents += str(empty_line)
522         before, after = elt.maybe_empty_lines(current_line)
523         for _ in range(before):
524             dst_contents += str(empty_line)
525         for line in split_line(current_line, line_length=line_length, py36=py36):
526             dst_contents += str(line)
527     return dst_contents
528
529
530 GRAMMARS = [
531     pygram.python_grammar_no_print_statement_no_exec_statement,
532     pygram.python_grammar_no_print_statement,
533     pygram.python_grammar,
534 ]
535
536
537 def lib2to3_parse(src_txt: str) -> Node:
538     """Given a string with source, return the lib2to3 Node."""
539     grammar = pygram.python_grammar_no_print_statement
540     if src_txt[-1] != "\n":
541         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
542         src_txt += nl
543     for grammar in GRAMMARS:
544         drv = driver.Driver(grammar, pytree.convert)
545         try:
546             result = drv.parse_string(src_txt, True)
547             break
548
549         except ParseError as pe:
550             lineno, column = pe.context[1]
551             lines = src_txt.splitlines()
552             try:
553                 faulty_line = lines[lineno - 1]
554             except IndexError:
555                 faulty_line = "<line number missing in source>"
556             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
557     else:
558         raise exc from None
559
560     if isinstance(result, Leaf):
561         result = Node(syms.file_input, [result])
562     return result
563
564
565 def lib2to3_unparse(node: Node) -> str:
566     """Given a lib2to3 node, return its string representation."""
567     code = str(node)
568     return code
569
570
571 T = TypeVar("T")
572
573
574 class Visitor(Generic[T]):
575     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
576
577     def visit(self, node: LN) -> Iterator[T]:
578         """Main method to visit `node` and its children.
579
580         It tries to find a `visit_*()` method for the given `node.type`, like
581         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
582         If no dedicated `visit_*()` method is found, chooses `visit_default()`
583         instead.
584
585         Then yields objects of type `T` from the selected visitor.
586         """
587         if node.type < 256:
588             name = token.tok_name[node.type]
589         else:
590             name = type_repr(node.type)
591         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
592
593     def visit_default(self, node: LN) -> Iterator[T]:
594         """Default `visit_*()` implementation. Recurses to children of `node`."""
595         if isinstance(node, Node):
596             for child in node.children:
597                 yield from self.visit(child)
598
599
600 @dataclass
601 class DebugVisitor(Visitor[T]):
602     tree_depth: int = 0
603
604     def visit_default(self, node: LN) -> Iterator[T]:
605         indent = " " * (2 * self.tree_depth)
606         if isinstance(node, Node):
607             _type = type_repr(node.type)
608             out(f"{indent}{_type}", fg="yellow")
609             self.tree_depth += 1
610             for child in node.children:
611                 yield from self.visit(child)
612
613             self.tree_depth -= 1
614             out(f"{indent}/{_type}", fg="yellow", bold=False)
615         else:
616             _type = token.tok_name.get(node.type, str(node.type))
617             out(f"{indent}{_type}", fg="blue", nl=False)
618             if node.prefix:
619                 # We don't have to handle prefixes for `Node` objects since
620                 # that delegates to the first child anyway.
621                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
622             out(f" {node.value!r}", fg="blue", bold=False)
623
624     @classmethod
625     def show(cls, code: str) -> None:
626         """Pretty-print the lib2to3 AST of a given string of `code`.
627
628         Convenience method for debugging.
629         """
630         v: DebugVisitor[None] = DebugVisitor()
631         list(v.visit(lib2to3_parse(code)))
632
633
634 KEYWORDS = set(keyword.kwlist)
635 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
636 FLOW_CONTROL = {"return", "raise", "break", "continue"}
637 STATEMENT = {
638     syms.if_stmt,
639     syms.while_stmt,
640     syms.for_stmt,
641     syms.try_stmt,
642     syms.except_clause,
643     syms.with_stmt,
644     syms.funcdef,
645     syms.classdef,
646 }
647 STANDALONE_COMMENT = 153
648 LOGIC_OPERATORS = {"and", "or"}
649 COMPARATORS = {
650     token.LESS,
651     token.GREATER,
652     token.EQEQUAL,
653     token.NOTEQUAL,
654     token.LESSEQUAL,
655     token.GREATEREQUAL,
656 }
657 MATH_OPERATORS = {
658     token.VBAR,
659     token.CIRCUMFLEX,
660     token.AMPER,
661     token.LEFTSHIFT,
662     token.RIGHTSHIFT,
663     token.PLUS,
664     token.MINUS,
665     token.STAR,
666     token.SLASH,
667     token.DOUBLESLASH,
668     token.PERCENT,
669     token.AT,
670     token.TILDE,
671     token.DOUBLESTAR,
672 }
673 STARS = {token.STAR, token.DOUBLESTAR}
674 VARARGS_PARENTS = {
675     syms.arglist,
676     syms.argument,  # double star in arglist
677     syms.trailer,  # single argument to call
678     syms.typedargslist,
679     syms.varargslist,  # lambdas
680 }
681 UNPACKING_PARENTS = {
682     syms.atom,  # single element of a list or set literal
683     syms.dictsetmaker,
684     syms.listmaker,
685     syms.testlist_gexp,
686 }
687 TEST_DESCENDANTS = {
688     syms.test,
689     syms.lambdef,
690     syms.or_test,
691     syms.and_test,
692     syms.not_test,
693     syms.comparison,
694     syms.star_expr,
695     syms.expr,
696     syms.xor_expr,
697     syms.and_expr,
698     syms.shift_expr,
699     syms.arith_expr,
700     syms.trailer,
701     syms.term,
702     syms.power,
703 }
704 ASSIGNMENTS = {
705     "=",
706     "+=",
707     "-=",
708     "*=",
709     "@=",
710     "/=",
711     "%=",
712     "&=",
713     "|=",
714     "^=",
715     "<<=",
716     ">>=",
717     "**=",
718     "//=",
719 }
720 COMPREHENSION_PRIORITY = 20
721 COMMA_PRIORITY = 18
722 TERNARY_PRIORITY = 16
723 LOGIC_PRIORITY = 14
724 STRING_PRIORITY = 12
725 COMPARATOR_PRIORITY = 10
726 MATH_PRIORITIES = {
727     token.VBAR: 9,
728     token.CIRCUMFLEX: 8,
729     token.AMPER: 7,
730     token.LEFTSHIFT: 6,
731     token.RIGHTSHIFT: 6,
732     token.PLUS: 5,
733     token.MINUS: 5,
734     token.STAR: 4,
735     token.SLASH: 4,
736     token.DOUBLESLASH: 4,
737     token.PERCENT: 4,
738     token.AT: 4,
739     token.TILDE: 3,
740     token.DOUBLESTAR: 2,
741 }
742 DOT_PRIORITY = 1
743
744
745 @dataclass
746 class BracketTracker:
747     """Keeps track of brackets on a line."""
748
749     depth: int = 0
750     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
751     delimiters: Dict[LeafID, Priority] = Factory(dict)
752     previous: Optional[Leaf] = None
753     _for_loop_variable: int = 0
754     _lambda_arguments: int = 0
755
756     def mark(self, leaf: Leaf) -> None:
757         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
758
759         All leaves receive an int `bracket_depth` field that stores how deep
760         within brackets a given leaf is. 0 means there are no enclosing brackets
761         that started on this line.
762
763         If a leaf is itself a closing bracket, it receives an `opening_bracket`
764         field that it forms a pair with. This is a one-directional link to
765         avoid reference cycles.
766
767         If a leaf is a delimiter (a token on which Black can split the line if
768         needed) and it's on depth 0, its `id()` is stored in the tracker's
769         `delimiters` field.
770         """
771         if leaf.type == token.COMMENT:
772             return
773
774         self.maybe_decrement_after_for_loop_variable(leaf)
775         self.maybe_decrement_after_lambda_arguments(leaf)
776         if leaf.type in CLOSING_BRACKETS:
777             self.depth -= 1
778             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
779             leaf.opening_bracket = opening_bracket
780         leaf.bracket_depth = self.depth
781         if self.depth == 0:
782             delim = is_split_before_delimiter(leaf, self.previous)
783             if delim and self.previous is not None:
784                 self.delimiters[id(self.previous)] = delim
785             else:
786                 delim = is_split_after_delimiter(leaf, self.previous)
787                 if delim:
788                     self.delimiters[id(leaf)] = delim
789         if leaf.type in OPENING_BRACKETS:
790             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
791             self.depth += 1
792         self.previous = leaf
793         self.maybe_increment_lambda_arguments(leaf)
794         self.maybe_increment_for_loop_variable(leaf)
795
796     def any_open_brackets(self) -> bool:
797         """Return True if there is an yet unmatched open bracket on the line."""
798         return bool(self.bracket_match)
799
800     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
801         """Return the highest priority of a delimiter found on the line.
802
803         Values are consistent with what `is_split_*_delimiter()` return.
804         Raises ValueError on no delimiters.
805         """
806         return max(v for k, v in self.delimiters.items() if k not in exclude)
807
808     def delimiter_count_with_priority(self, priority: int = 0) -> int:
809         """Return the number of delimiters with the given `priority`.
810
811         If no `priority` is passed, defaults to max priority on the line.
812         """
813         if not self.delimiters:
814             return 0
815
816         priority = priority or self.max_delimiter_priority()
817         return sum(1 for p in self.delimiters.values() if p == priority)
818
819     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
820         """In a for loop, or comprehension, the variables are often unpacks.
821
822         To avoid splitting on the comma in this situation, increase the depth of
823         tokens between `for` and `in`.
824         """
825         if leaf.type == token.NAME and leaf.value == "for":
826             self.depth += 1
827             self._for_loop_variable += 1
828             return True
829
830         return False
831
832     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
833         """See `maybe_increment_for_loop_variable` above for explanation."""
834         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
835             self.depth -= 1
836             self._for_loop_variable -= 1
837             return True
838
839         return False
840
841     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
842         """In a lambda expression, there might be more than one argument.
843
844         To avoid splitting on the comma in this situation, increase the depth of
845         tokens between `lambda` and `:`.
846         """
847         if leaf.type == token.NAME and leaf.value == "lambda":
848             self.depth += 1
849             self._lambda_arguments += 1
850             return True
851
852         return False
853
854     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
855         """See `maybe_increment_lambda_arguments` above for explanation."""
856         if self._lambda_arguments and leaf.type == token.COLON:
857             self.depth -= 1
858             self._lambda_arguments -= 1
859             return True
860
861         return False
862
863     def get_open_lsqb(self) -> Optional[Leaf]:
864         """Return the most recent opening square bracket (if any)."""
865         return self.bracket_match.get((self.depth - 1, token.RSQB))
866
867
868 @dataclass
869 class Line:
870     """Holds leaves and comments. Can be printed with `str(line)`."""
871
872     depth: int = 0
873     leaves: List[Leaf] = Factory(list)
874     comments: List[Tuple[Index, Leaf]] = Factory(list)
875     bracket_tracker: BracketTracker = Factory(BracketTracker)
876     inside_brackets: bool = False
877     should_explode: bool = False
878
879     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
880         """Add a new `leaf` to the end of the line.
881
882         Unless `preformatted` is True, the `leaf` will receive a new consistent
883         whitespace prefix and metadata applied by :class:`BracketTracker`.
884         Trailing commas are maybe removed, unpacked for loop variables are
885         demoted from being delimiters.
886
887         Inline comments are put aside.
888         """
889         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
890         if not has_value:
891             return
892
893         if token.COLON == leaf.type and self.is_class_paren_empty:
894             del self.leaves[-2:]
895         if self.leaves and not preformatted:
896             # Note: at this point leaf.prefix should be empty except for
897             # imports, for which we only preserve newlines.
898             leaf.prefix += whitespace(
899                 leaf, complex_subscript=self.is_complex_subscript(leaf)
900             )
901         if self.inside_brackets or not preformatted:
902             self.bracket_tracker.mark(leaf)
903             self.maybe_remove_trailing_comma(leaf)
904         if not self.append_comment(leaf):
905             self.leaves.append(leaf)
906
907     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
908         """Like :func:`append()` but disallow invalid standalone comment structure.
909
910         Raises ValueError when any `leaf` is appended after a standalone comment
911         or when a standalone comment is not the first leaf on the line.
912         """
913         if self.bracket_tracker.depth == 0:
914             if self.is_comment:
915                 raise ValueError("cannot append to standalone comments")
916
917             if self.leaves and leaf.type == STANDALONE_COMMENT:
918                 raise ValueError(
919                     "cannot append standalone comments to a populated line"
920                 )
921
922         self.append(leaf, preformatted=preformatted)
923
924     @property
925     def is_comment(self) -> bool:
926         """Is this line a standalone comment?"""
927         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
928
929     @property
930     def is_decorator(self) -> bool:
931         """Is this line a decorator?"""
932         return bool(self) and self.leaves[0].type == token.AT
933
934     @property
935     def is_import(self) -> bool:
936         """Is this an import line?"""
937         return bool(self) and is_import(self.leaves[0])
938
939     @property
940     def is_class(self) -> bool:
941         """Is this line a class definition?"""
942         return (
943             bool(self)
944             and self.leaves[0].type == token.NAME
945             and self.leaves[0].value == "class"
946         )
947
948     @property
949     def is_stub_class(self) -> bool:
950         """Is this line a class definition with a body consisting only of "..."?"""
951         return self.is_class and self.leaves[-3:] == [
952             Leaf(token.DOT, ".") for _ in range(3)
953         ]
954
955     @property
956     def is_def(self) -> bool:
957         """Is this a function definition? (Also returns True for async defs.)"""
958         try:
959             first_leaf = self.leaves[0]
960         except IndexError:
961             return False
962
963         try:
964             second_leaf: Optional[Leaf] = self.leaves[1]
965         except IndexError:
966             second_leaf = None
967         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
968             first_leaf.type == token.ASYNC
969             and second_leaf is not None
970             and second_leaf.type == token.NAME
971             and second_leaf.value == "def"
972         )
973
974     @property
975     def is_flow_control(self) -> bool:
976         """Is this line a flow control statement?
977
978         Those are `return`, `raise`, `break`, and `continue`.
979         """
980         return (
981             bool(self)
982             and self.leaves[0].type == token.NAME
983             and self.leaves[0].value in FLOW_CONTROL
984         )
985
986     @property
987     def is_yield(self) -> bool:
988         """Is this line a yield statement?"""
989         return (
990             bool(self)
991             and self.leaves[0].type == token.NAME
992             and self.leaves[0].value == "yield"
993         )
994
995     @property
996     def is_class_paren_empty(self) -> bool:
997         """Is this a class with no base classes but using parentheses?
998
999         Those are unnecessary and should be removed.
1000         """
1001         return (
1002             bool(self)
1003             and len(self.leaves) == 4
1004             and self.is_class
1005             and self.leaves[2].type == token.LPAR
1006             and self.leaves[2].value == "("
1007             and self.leaves[3].type == token.RPAR
1008             and self.leaves[3].value == ")"
1009         )
1010
1011     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1012         """If so, needs to be split before emitting."""
1013         for leaf in self.leaves:
1014             if leaf.type == STANDALONE_COMMENT:
1015                 if leaf.bracket_depth <= depth_limit:
1016                     return True
1017
1018         return False
1019
1020     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1021         """Remove trailing comma if there is one and it's safe."""
1022         if not (
1023             self.leaves
1024             and self.leaves[-1].type == token.COMMA
1025             and closing.type in CLOSING_BRACKETS
1026         ):
1027             return False
1028
1029         if closing.type == token.RBRACE:
1030             self.remove_trailing_comma()
1031             return True
1032
1033         if closing.type == token.RSQB:
1034             comma = self.leaves[-1]
1035             if comma.parent and comma.parent.type == syms.listmaker:
1036                 self.remove_trailing_comma()
1037                 return True
1038
1039         # For parens let's check if it's safe to remove the comma.
1040         # Imports are always safe.
1041         if self.is_import:
1042             self.remove_trailing_comma()
1043             return True
1044
1045         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1046         # change a tuple into a different type by removing the comma.
1047         depth = closing.bracket_depth + 1
1048         commas = 0
1049         opening = closing.opening_bracket
1050         for _opening_index, leaf in enumerate(self.leaves):
1051             if leaf is opening:
1052                 break
1053
1054         else:
1055             return False
1056
1057         for leaf in self.leaves[_opening_index + 1 :]:
1058             if leaf is closing:
1059                 break
1060
1061             bracket_depth = leaf.bracket_depth
1062             if bracket_depth == depth and leaf.type == token.COMMA:
1063                 commas += 1
1064                 if leaf.parent and leaf.parent.type == syms.arglist:
1065                     commas += 1
1066                     break
1067
1068         if commas > 1:
1069             self.remove_trailing_comma()
1070             return True
1071
1072         return False
1073
1074     def append_comment(self, comment: Leaf) -> bool:
1075         """Add an inline or standalone comment to the line."""
1076         if (
1077             comment.type == STANDALONE_COMMENT
1078             and self.bracket_tracker.any_open_brackets()
1079         ):
1080             comment.prefix = ""
1081             return False
1082
1083         if comment.type != token.COMMENT:
1084             return False
1085
1086         after = len(self.leaves) - 1
1087         if after == -1:
1088             comment.type = STANDALONE_COMMENT
1089             comment.prefix = ""
1090             return False
1091
1092         else:
1093             self.comments.append((after, comment))
1094             return True
1095
1096     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1097         """Generate comments that should appear directly after `leaf`.
1098
1099         Provide a non-negative leaf `_index` to speed up the function.
1100         """
1101         if _index == -1:
1102             for _index, _leaf in enumerate(self.leaves):
1103                 if leaf is _leaf:
1104                     break
1105
1106             else:
1107                 return
1108
1109         for index, comment_after in self.comments:
1110             if _index == index:
1111                 yield comment_after
1112
1113     def remove_trailing_comma(self) -> None:
1114         """Remove the trailing comma and moves the comments attached to it."""
1115         comma_index = len(self.leaves) - 1
1116         for i in range(len(self.comments)):
1117             comment_index, comment = self.comments[i]
1118             if comment_index == comma_index:
1119                 self.comments[i] = (comma_index - 1, comment)
1120         self.leaves.pop()
1121
1122     def is_complex_subscript(self, leaf: Leaf) -> bool:
1123         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1124         open_lsqb = (
1125             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1126         )
1127         if open_lsqb is None:
1128             return False
1129
1130         subscript_start = open_lsqb.next_sibling
1131         if (
1132             isinstance(subscript_start, Node)
1133             and subscript_start.type == syms.subscriptlist
1134         ):
1135             subscript_start = child_towards(subscript_start, leaf)
1136         return subscript_start is not None and any(
1137             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1138         )
1139
1140     def __str__(self) -> str:
1141         """Render the line."""
1142         if not self:
1143             return "\n"
1144
1145         indent = "    " * self.depth
1146         leaves = iter(self.leaves)
1147         first = next(leaves)
1148         res = f"{first.prefix}{indent}{first.value}"
1149         for leaf in leaves:
1150             res += str(leaf)
1151         for _, comment in self.comments:
1152             res += str(comment)
1153         return res + "\n"
1154
1155     def __bool__(self) -> bool:
1156         """Return True if the line has leaves or comments."""
1157         return bool(self.leaves or self.comments)
1158
1159
1160 class UnformattedLines(Line):
1161     """Just like :class:`Line` but stores lines which aren't reformatted."""
1162
1163     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1164         """Just add a new `leaf` to the end of the lines.
1165
1166         The `preformatted` argument is ignored.
1167
1168         Keeps track of indentation `depth`, which is useful when the user
1169         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1170         """
1171         try:
1172             list(generate_comments(leaf))
1173         except FormatOn as f_on:
1174             self.leaves.append(f_on.leaf_from_consumed(leaf))
1175             raise
1176
1177         self.leaves.append(leaf)
1178         if leaf.type == token.INDENT:
1179             self.depth += 1
1180         elif leaf.type == token.DEDENT:
1181             self.depth -= 1
1182
1183     def __str__(self) -> str:
1184         """Render unformatted lines from leaves which were added with `append()`.
1185
1186         `depth` is not used for indentation in this case.
1187         """
1188         if not self:
1189             return "\n"
1190
1191         res = ""
1192         for leaf in self.leaves:
1193             res += str(leaf)
1194         return res
1195
1196     def append_comment(self, comment: Leaf) -> bool:
1197         """Not implemented in this class. Raises `NotImplementedError`."""
1198         raise NotImplementedError("Unformatted lines don't store comments separately.")
1199
1200     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1201         """Does nothing and returns False."""
1202         return False
1203
1204     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1205         """Does nothing and returns False."""
1206         return False
1207
1208
1209 @dataclass
1210 class EmptyLineTracker:
1211     """Provides a stateful method that returns the number of potential extra
1212     empty lines needed before and after the currently processed line.
1213
1214     Note: this tracker works on lines that haven't been split yet.  It assumes
1215     the prefix of the first leaf consists of optional newlines.  Those newlines
1216     are consumed by `maybe_empty_lines()` and included in the computation.
1217     """
1218     is_pyi: bool = False
1219     previous_line: Optional[Line] = None
1220     previous_after: int = 0
1221     previous_defs: List[int] = Factory(list)
1222
1223     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1224         """Return the number of extra empty lines before and after the `current_line`.
1225
1226         This is for separating `def`, `async def` and `class` with extra empty
1227         lines (two on module-level), as well as providing an extra empty line
1228         after flow control keywords to make them more prominent.
1229         """
1230         if isinstance(current_line, UnformattedLines):
1231             return 0, 0
1232
1233         before, after = self._maybe_empty_lines(current_line)
1234         before -= self.previous_after
1235         self.previous_after = after
1236         self.previous_line = current_line
1237         return before, after
1238
1239     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1240         max_allowed = 1
1241         if current_line.depth == 0:
1242             max_allowed = 1 if self.is_pyi else 2
1243         if current_line.leaves:
1244             # Consume the first leaf's extra newlines.
1245             first_leaf = current_line.leaves[0]
1246             before = first_leaf.prefix.count("\n")
1247             before = min(before, max_allowed)
1248             first_leaf.prefix = ""
1249         else:
1250             before = 0
1251         depth = current_line.depth
1252         while self.previous_defs and self.previous_defs[-1] >= depth:
1253             self.previous_defs.pop()
1254             if self.is_pyi:
1255                 before = 0 if depth else 1
1256             else:
1257                 before = 1 if depth else 2
1258         is_decorator = current_line.is_decorator
1259         if is_decorator or current_line.is_def or current_line.is_class:
1260             if not is_decorator:
1261                 self.previous_defs.append(depth)
1262             if self.previous_line is None:
1263                 # Don't insert empty lines before the first line in the file.
1264                 return 0, 0
1265
1266             if self.previous_line.is_decorator:
1267                 return 0, 0
1268
1269             if (
1270                 self.previous_line.is_comment
1271                 and self.previous_line.depth == current_line.depth
1272                 and before == 0
1273             ):
1274                 return 0, 0
1275
1276             if self.is_pyi:
1277                 if self.previous_line.depth > current_line.depth:
1278                     newlines = 1
1279                 elif current_line.is_class or self.previous_line.is_class:
1280                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1281                         newlines = 0
1282                     else:
1283                         newlines = 1
1284                 else:
1285                     newlines = 0
1286             else:
1287                 newlines = 2
1288             if current_line.depth and newlines:
1289                 newlines -= 1
1290             return newlines, 0
1291
1292         if (
1293             self.previous_line
1294             and self.previous_line.is_import
1295             and not current_line.is_import
1296             and depth == self.previous_line.depth
1297         ):
1298             return (before or 1), 0
1299
1300         return before, 0
1301
1302
1303 @dataclass
1304 class LineGenerator(Visitor[Line]):
1305     """Generates reformatted Line objects.  Empty lines are not emitted.
1306
1307     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1308     in ways that will no longer stringify to valid Python code on the tree.
1309     """
1310     is_pyi: bool = False
1311     current_line: Line = Factory(Line)
1312     remove_u_prefix: bool = False
1313
1314     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1315         """Generate a line.
1316
1317         If the line is empty, only emit if it makes sense.
1318         If the line is too long, split it first and then generate.
1319
1320         If any lines were generated, set up a new current_line.
1321         """
1322         if not self.current_line:
1323             if self.current_line.__class__ == type:
1324                 self.current_line.depth += indent
1325             else:
1326                 self.current_line = type(depth=self.current_line.depth + indent)
1327             return  # Line is empty, don't emit. Creating a new one unnecessary.
1328
1329         complete_line = self.current_line
1330         self.current_line = type(depth=complete_line.depth + indent)
1331         yield complete_line
1332
1333     def visit(self, node: LN) -> Iterator[Line]:
1334         """Main method to visit `node` and its children.
1335
1336         Yields :class:`Line` objects.
1337         """
1338         if isinstance(self.current_line, UnformattedLines):
1339             # File contained `# fmt: off`
1340             yield from self.visit_unformatted(node)
1341
1342         else:
1343             yield from super().visit(node)
1344
1345     def visit_default(self, node: LN) -> Iterator[Line]:
1346         """Default `visit_*()` implementation. Recurses to children of `node`."""
1347         if isinstance(node, Leaf):
1348             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1349             try:
1350                 for comment in generate_comments(node):
1351                     if any_open_brackets:
1352                         # any comment within brackets is subject to splitting
1353                         self.current_line.append(comment)
1354                     elif comment.type == token.COMMENT:
1355                         # regular trailing comment
1356                         self.current_line.append(comment)
1357                         yield from self.line()
1358
1359                     else:
1360                         # regular standalone comment
1361                         yield from self.line()
1362
1363                         self.current_line.append(comment)
1364                         yield from self.line()
1365
1366             except FormatOff as f_off:
1367                 f_off.trim_prefix(node)
1368                 yield from self.line(type=UnformattedLines)
1369                 yield from self.visit(node)
1370
1371             except FormatOn as f_on:
1372                 # This only happens here if somebody says "fmt: on" multiple
1373                 # times in a row.
1374                 f_on.trim_prefix(node)
1375                 yield from self.visit_default(node)
1376
1377             else:
1378                 normalize_prefix(node, inside_brackets=any_open_brackets)
1379                 if node.type == token.STRING:
1380                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1381                     normalize_string_quotes(node)
1382                 if node.type not in WHITESPACE:
1383                     self.current_line.append(node)
1384         yield from super().visit_default(node)
1385
1386     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1387         """Increase indentation level, maybe yield a line."""
1388         # In blib2to3 INDENT never holds comments.
1389         yield from self.line(+1)
1390         yield from self.visit_default(node)
1391
1392     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1393         """Decrease indentation level, maybe yield a line."""
1394         # The current line might still wait for trailing comments.  At DEDENT time
1395         # there won't be any (they would be prefixes on the preceding NEWLINE).
1396         # Emit the line then.
1397         yield from self.line()
1398
1399         # While DEDENT has no value, its prefix may contain standalone comments
1400         # that belong to the current indentation level.  Get 'em.
1401         yield from self.visit_default(node)
1402
1403         # Finally, emit the dedent.
1404         yield from self.line(-1)
1405
1406     def visit_stmt(
1407         self, node: Node, keywords: Set[str], parens: Set[str]
1408     ) -> Iterator[Line]:
1409         """Visit a statement.
1410
1411         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1412         `def`, `with`, `class`, `assert` and assignments.
1413
1414         The relevant Python language `keywords` for a given statement will be
1415         NAME leaves within it. This methods puts those on a separate line.
1416
1417         `parens` holds a set of string leaf values immediately after which
1418         invisible parens should be put.
1419         """
1420         normalize_invisible_parens(node, parens_after=parens)
1421         for child in node.children:
1422             if child.type == token.NAME and child.value in keywords:  # type: ignore
1423                 yield from self.line()
1424
1425             yield from self.visit(child)
1426
1427     def visit_suite(self, node: Node) -> Iterator[Line]:
1428         """Visit a suite."""
1429         if self.is_pyi and is_stub_suite(node):
1430             yield from self.visit(node.children[2])
1431         else:
1432             yield from self.visit_default(node)
1433
1434     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1435         """Visit a statement without nested statements."""
1436         is_suite_like = node.parent and node.parent.type in STATEMENT
1437         if is_suite_like:
1438             if self.is_pyi and is_stub_body(node):
1439                 yield from self.visit_default(node)
1440             else:
1441                 yield from self.line(+1)
1442                 yield from self.visit_default(node)
1443                 yield from self.line(-1)
1444
1445         else:
1446             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1447                 yield from self.line()
1448             yield from self.visit_default(node)
1449
1450     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1451         """Visit `async def`, `async for`, `async with`."""
1452         yield from self.line()
1453
1454         children = iter(node.children)
1455         for child in children:
1456             yield from self.visit(child)
1457
1458             if child.type == token.ASYNC:
1459                 break
1460
1461         internal_stmt = next(children)
1462         for child in internal_stmt.children:
1463             yield from self.visit(child)
1464
1465     def visit_decorators(self, node: Node) -> Iterator[Line]:
1466         """Visit decorators."""
1467         for child in node.children:
1468             yield from self.line()
1469             yield from self.visit(child)
1470
1471     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1472         """Remove a semicolon and put the other statement on a separate line."""
1473         yield from self.line()
1474
1475     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1476         """End of file. Process outstanding comments and end with a newline."""
1477         yield from self.visit_default(leaf)
1478         yield from self.line()
1479
1480     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1481         """Used when file contained a `# fmt: off`."""
1482         if isinstance(node, Node):
1483             for child in node.children:
1484                 yield from self.visit(child)
1485
1486         else:
1487             try:
1488                 self.current_line.append(node)
1489             except FormatOn as f_on:
1490                 f_on.trim_prefix(node)
1491                 yield from self.line()
1492                 yield from self.visit(node)
1493
1494             if node.type == token.ENDMARKER:
1495                 # somebody decided not to put a final `# fmt: on`
1496                 yield from self.line()
1497
1498     def __attrs_post_init__(self) -> None:
1499         """You are in a twisty little maze of passages."""
1500         v = self.visit_stmt
1501         Ø: Set[str] = set()
1502         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1503         self.visit_if_stmt = partial(
1504             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1505         )
1506         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1507         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1508         self.visit_try_stmt = partial(
1509             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1510         )
1511         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1512         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1513         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1514         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1515         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1516         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1517         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1518         self.visit_async_funcdef = self.visit_async_stmt
1519         self.visit_decorated = self.visit_decorators
1520
1521
1522 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1523 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1524 OPENING_BRACKETS = set(BRACKET.keys())
1525 CLOSING_BRACKETS = set(BRACKET.values())
1526 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1527 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1528
1529
1530 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1531     """Return whitespace prefix if needed for the given `leaf`.
1532
1533     `complex_subscript` signals whether the given leaf is part of a subscription
1534     which has non-trivial arguments, like arithmetic expressions or function calls.
1535     """
1536     NO = ""
1537     SPACE = " "
1538     DOUBLESPACE = "  "
1539     t = leaf.type
1540     p = leaf.parent
1541     v = leaf.value
1542     if t in ALWAYS_NO_SPACE:
1543         return NO
1544
1545     if t == token.COMMENT:
1546         return DOUBLESPACE
1547
1548     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1549     if t == token.COLON and p.type not in {
1550         syms.subscript,
1551         syms.subscriptlist,
1552         syms.sliceop,
1553     }:
1554         return NO
1555
1556     prev = leaf.prev_sibling
1557     if not prev:
1558         prevp = preceding_leaf(p)
1559         if not prevp or prevp.type in OPENING_BRACKETS:
1560             return NO
1561
1562         if t == token.COLON:
1563             if prevp.type == token.COLON:
1564                 return NO
1565
1566             elif prevp.type != token.COMMA and not complex_subscript:
1567                 return NO
1568
1569             return SPACE
1570
1571         if prevp.type == token.EQUAL:
1572             if prevp.parent:
1573                 if prevp.parent.type in {
1574                     syms.arglist,
1575                     syms.argument,
1576                     syms.parameters,
1577                     syms.varargslist,
1578                 }:
1579                     return NO
1580
1581                 elif prevp.parent.type == syms.typedargslist:
1582                     # A bit hacky: if the equal sign has whitespace, it means we
1583                     # previously found it's a typed argument.  So, we're using
1584                     # that, too.
1585                     return prevp.prefix
1586
1587         elif prevp.type in STARS:
1588             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1589                 return NO
1590
1591         elif prevp.type == token.COLON:
1592             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1593                 return SPACE if complex_subscript else NO
1594
1595         elif (
1596             prevp.parent
1597             and prevp.parent.type == syms.factor
1598             and prevp.type in MATH_OPERATORS
1599         ):
1600             return NO
1601
1602         elif (
1603             prevp.type == token.RIGHTSHIFT
1604             and prevp.parent
1605             and prevp.parent.type == syms.shift_expr
1606             and prevp.prev_sibling
1607             and prevp.prev_sibling.type == token.NAME
1608             and prevp.prev_sibling.value == "print"  # type: ignore
1609         ):
1610             # Python 2 print chevron
1611             return NO
1612
1613     elif prev.type in OPENING_BRACKETS:
1614         return NO
1615
1616     if p.type in {syms.parameters, syms.arglist}:
1617         # untyped function signatures or calls
1618         if not prev or prev.type != token.COMMA:
1619             return NO
1620
1621     elif p.type == syms.varargslist:
1622         # lambdas
1623         if prev and prev.type != token.COMMA:
1624             return NO
1625
1626     elif p.type == syms.typedargslist:
1627         # typed function signatures
1628         if not prev:
1629             return NO
1630
1631         if t == token.EQUAL:
1632             if prev.type != syms.tname:
1633                 return NO
1634
1635         elif prev.type == token.EQUAL:
1636             # A bit hacky: if the equal sign has whitespace, it means we
1637             # previously found it's a typed argument.  So, we're using that, too.
1638             return prev.prefix
1639
1640         elif prev.type != token.COMMA:
1641             return NO
1642
1643     elif p.type == syms.tname:
1644         # type names
1645         if not prev:
1646             prevp = preceding_leaf(p)
1647             if not prevp or prevp.type != token.COMMA:
1648                 return NO
1649
1650     elif p.type == syms.trailer:
1651         # attributes and calls
1652         if t == token.LPAR or t == token.RPAR:
1653             return NO
1654
1655         if not prev:
1656             if t == token.DOT:
1657                 prevp = preceding_leaf(p)
1658                 if not prevp or prevp.type != token.NUMBER:
1659                     return NO
1660
1661             elif t == token.LSQB:
1662                 return NO
1663
1664         elif prev.type != token.COMMA:
1665             return NO
1666
1667     elif p.type == syms.argument:
1668         # single argument
1669         if t == token.EQUAL:
1670             return NO
1671
1672         if not prev:
1673             prevp = preceding_leaf(p)
1674             if not prevp or prevp.type == token.LPAR:
1675                 return NO
1676
1677         elif prev.type in {token.EQUAL} | STARS:
1678             return NO
1679
1680     elif p.type == syms.decorator:
1681         # decorators
1682         return NO
1683
1684     elif p.type == syms.dotted_name:
1685         if prev:
1686             return NO
1687
1688         prevp = preceding_leaf(p)
1689         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1690             return NO
1691
1692     elif p.type == syms.classdef:
1693         if t == token.LPAR:
1694             return NO
1695
1696         if prev and prev.type == token.LPAR:
1697             return NO
1698
1699     elif p.type in {syms.subscript, syms.sliceop}:
1700         # indexing
1701         if not prev:
1702             assert p.parent is not None, "subscripts are always parented"
1703             if p.parent.type == syms.subscriptlist:
1704                 return SPACE
1705
1706             return NO
1707
1708         elif not complex_subscript:
1709             return NO
1710
1711     elif p.type == syms.atom:
1712         if prev and t == token.DOT:
1713             # dots, but not the first one.
1714             return NO
1715
1716     elif p.type == syms.dictsetmaker:
1717         # dict unpacking
1718         if prev and prev.type == token.DOUBLESTAR:
1719             return NO
1720
1721     elif p.type in {syms.factor, syms.star_expr}:
1722         # unary ops
1723         if not prev:
1724             prevp = preceding_leaf(p)
1725             if not prevp or prevp.type in OPENING_BRACKETS:
1726                 return NO
1727
1728             prevp_parent = prevp.parent
1729             assert prevp_parent is not None
1730             if prevp.type == token.COLON and prevp_parent.type in {
1731                 syms.subscript,
1732                 syms.sliceop,
1733             }:
1734                 return NO
1735
1736             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1737                 return NO
1738
1739         elif t == token.NAME or t == token.NUMBER:
1740             return NO
1741
1742     elif p.type == syms.import_from:
1743         if t == token.DOT:
1744             if prev and prev.type == token.DOT:
1745                 return NO
1746
1747         elif t == token.NAME:
1748             if v == "import":
1749                 return SPACE
1750
1751             if prev and prev.type == token.DOT:
1752                 return NO
1753
1754     elif p.type == syms.sliceop:
1755         return NO
1756
1757     return SPACE
1758
1759
1760 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1761     """Return the first leaf that precedes `node`, if any."""
1762     while node:
1763         res = node.prev_sibling
1764         if res:
1765             if isinstance(res, Leaf):
1766                 return res
1767
1768             try:
1769                 return list(res.leaves())[-1]
1770
1771             except IndexError:
1772                 return None
1773
1774         node = node.parent
1775     return None
1776
1777
1778 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1779     """Return the child of `ancestor` that contains `descendant`."""
1780     node: Optional[LN] = descendant
1781     while node and node.parent != ancestor:
1782         node = node.parent
1783     return node
1784
1785
1786 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1787     """Return the priority of the `leaf` delimiter, given a line break after it.
1788
1789     The delimiter priorities returned here are from those delimiters that would
1790     cause a line break after themselves.
1791
1792     Higher numbers are higher priority.
1793     """
1794     if leaf.type == token.COMMA:
1795         return COMMA_PRIORITY
1796
1797     return 0
1798
1799
1800 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1801     """Return the priority of the `leaf` delimiter, given a line before after it.
1802
1803     The delimiter priorities returned here are from those delimiters that would
1804     cause a line break before themselves.
1805
1806     Higher numbers are higher priority.
1807     """
1808     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1809         # * and ** might also be MATH_OPERATORS but in this case they are not.
1810         # Don't treat them as a delimiter.
1811         return 0
1812
1813     if (
1814         leaf.type == token.DOT
1815         and leaf.parent
1816         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1817         and (previous is None or previous.type in CLOSING_BRACKETS)
1818     ):
1819         return DOT_PRIORITY
1820
1821     if (
1822         leaf.type in MATH_OPERATORS
1823         and leaf.parent
1824         and leaf.parent.type not in {syms.factor, syms.star_expr}
1825     ):
1826         return MATH_PRIORITIES[leaf.type]
1827
1828     if leaf.type in COMPARATORS:
1829         return COMPARATOR_PRIORITY
1830
1831     if (
1832         leaf.type == token.STRING
1833         and previous is not None
1834         and previous.type == token.STRING
1835     ):
1836         return STRING_PRIORITY
1837
1838     if leaf.type != token.NAME:
1839         return 0
1840
1841     if (
1842         leaf.value == "for"
1843         and leaf.parent
1844         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1845     ):
1846         return COMPREHENSION_PRIORITY
1847
1848     if (
1849         leaf.value == "if"
1850         and leaf.parent
1851         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1852     ):
1853         return COMPREHENSION_PRIORITY
1854
1855     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1856         return TERNARY_PRIORITY
1857
1858     if leaf.value == "is":
1859         return COMPARATOR_PRIORITY
1860
1861     if (
1862         leaf.value == "in"
1863         and leaf.parent
1864         and leaf.parent.type in {syms.comp_op, syms.comparison}
1865         and not (
1866             previous is not None
1867             and previous.type == token.NAME
1868             and previous.value == "not"
1869         )
1870     ):
1871         return COMPARATOR_PRIORITY
1872
1873     if (
1874         leaf.value == "not"
1875         and leaf.parent
1876         and leaf.parent.type == syms.comp_op
1877         and not (
1878             previous is not None
1879             and previous.type == token.NAME
1880             and previous.value == "is"
1881         )
1882     ):
1883         return COMPARATOR_PRIORITY
1884
1885     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1886         return LOGIC_PRIORITY
1887
1888     return 0
1889
1890
1891 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1892     """Clean the prefix of the `leaf` and generate comments from it, if any.
1893
1894     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1895     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1896     move because it does away with modifying the grammar to include all the
1897     possible places in which comments can be placed.
1898
1899     The sad consequence for us though is that comments don't "belong" anywhere.
1900     This is why this function generates simple parentless Leaf objects for
1901     comments.  We simply don't know what the correct parent should be.
1902
1903     No matter though, we can live without this.  We really only need to
1904     differentiate between inline and standalone comments.  The latter don't
1905     share the line with any code.
1906
1907     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1908     are emitted with a fake STANDALONE_COMMENT token identifier.
1909     """
1910     p = leaf.prefix
1911     if not p:
1912         return
1913
1914     if "#" not in p:
1915         return
1916
1917     consumed = 0
1918     nlines = 0
1919     for index, line in enumerate(p.split("\n")):
1920         consumed += len(line) + 1  # adding the length of the split '\n'
1921         line = line.lstrip()
1922         if not line:
1923             nlines += 1
1924         if not line.startswith("#"):
1925             continue
1926
1927         if index == 0 and leaf.type != token.ENDMARKER:
1928             comment_type = token.COMMENT  # simple trailing comment
1929         else:
1930             comment_type = STANDALONE_COMMENT
1931         comment = make_comment(line)
1932         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1933
1934         if comment in {"# fmt: on", "# yapf: enable"}:
1935             raise FormatOn(consumed)
1936
1937         if comment in {"# fmt: off", "# yapf: disable"}:
1938             if comment_type == STANDALONE_COMMENT:
1939                 raise FormatOff(consumed)
1940
1941             prev = preceding_leaf(leaf)
1942             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1943                 raise FormatOff(consumed)
1944
1945         nlines = 0
1946
1947
1948 def make_comment(content: str) -> str:
1949     """Return a consistently formatted comment from the given `content` string.
1950
1951     All comments (except for "##", "#!", "#:") should have a single space between
1952     the hash sign and the content.
1953
1954     If `content` didn't start with a hash sign, one is provided.
1955     """
1956     content = content.rstrip()
1957     if not content:
1958         return "#"
1959
1960     if content[0] == "#":
1961         content = content[1:]
1962     if content and content[0] not in " !:#":
1963         content = " " + content
1964     return "#" + content
1965
1966
1967 def split_line(
1968     line: Line, line_length: int, inner: bool = False, py36: bool = False
1969 ) -> Iterator[Line]:
1970     """Split a `line` into potentially many lines.
1971
1972     They should fit in the allotted `line_length` but might not be able to.
1973     `inner` signifies that there were a pair of brackets somewhere around the
1974     current `line`, possibly transitively. This means we can fallback to splitting
1975     by delimiters if the LHS/RHS don't yield any results.
1976
1977     If `py36` is True, splitting may generate syntax that is only compatible
1978     with Python 3.6 and later.
1979     """
1980     if isinstance(line, UnformattedLines) or line.is_comment:
1981         yield line
1982         return
1983
1984     line_str = str(line).strip("\n")
1985     if not line.should_explode and is_line_short_enough(
1986         line, line_length=line_length, line_str=line_str
1987     ):
1988         yield line
1989         return
1990
1991     split_funcs: List[SplitFunc]
1992     if line.is_def:
1993         split_funcs = [left_hand_split]
1994     else:
1995
1996         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
1997             for omit in generate_trailers_to_omit(line, line_length):
1998                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
1999                 if is_line_short_enough(lines[0], line_length=line_length):
2000                     yield from lines
2001                     return
2002
2003             # All splits failed, best effort split with no omits.
2004             # This mostly happens to multiline strings that are by definition
2005             # reported as not fitting a single line.
2006             yield from right_hand_split(line, py36)
2007
2008         if line.inside_brackets:
2009             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2010         else:
2011             split_funcs = [rhs]
2012     for split_func in split_funcs:
2013         # We are accumulating lines in `result` because we might want to abort
2014         # mission and return the original line in the end, or attempt a different
2015         # split altogether.
2016         result: List[Line] = []
2017         try:
2018             for l in split_func(line, py36):
2019                 if str(l).strip("\n") == line_str:
2020                     raise CannotSplit("Split function returned an unchanged result")
2021
2022                 result.extend(
2023                     split_line(l, line_length=line_length, inner=True, py36=py36)
2024                 )
2025         except CannotSplit as cs:
2026             continue
2027
2028         else:
2029             yield from result
2030             break
2031
2032     else:
2033         yield line
2034
2035
2036 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2037     """Split line into many lines, starting with the first matching bracket pair.
2038
2039     Note: this usually looks weird, only use this for function definitions.
2040     Prefer RHS otherwise.  This is why this function is not symmetrical with
2041     :func:`right_hand_split` which also handles optional parentheses.
2042     """
2043     head = Line(depth=line.depth)
2044     body = Line(depth=line.depth + 1, inside_brackets=True)
2045     tail = Line(depth=line.depth)
2046     tail_leaves: List[Leaf] = []
2047     body_leaves: List[Leaf] = []
2048     head_leaves: List[Leaf] = []
2049     current_leaves = head_leaves
2050     matching_bracket = None
2051     for leaf in line.leaves:
2052         if (
2053             current_leaves is body_leaves
2054             and leaf.type in CLOSING_BRACKETS
2055             and leaf.opening_bracket is matching_bracket
2056         ):
2057             current_leaves = tail_leaves if body_leaves else head_leaves
2058         current_leaves.append(leaf)
2059         if current_leaves is head_leaves:
2060             if leaf.type in OPENING_BRACKETS:
2061                 matching_bracket = leaf
2062                 current_leaves = body_leaves
2063     # Since body is a new indent level, remove spurious leading whitespace.
2064     if body_leaves:
2065         normalize_prefix(body_leaves[0], inside_brackets=True)
2066     # Build the new lines.
2067     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2068         for leaf in leaves:
2069             result.append(leaf, preformatted=True)
2070             for comment_after in line.comments_after(leaf):
2071                 result.append(comment_after, preformatted=True)
2072     bracket_split_succeeded_or_raise(head, body, tail)
2073     for result in (head, body, tail):
2074         if result:
2075             yield result
2076
2077
2078 def right_hand_split(
2079     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2080 ) -> Iterator[Line]:
2081     """Split line into many lines, starting with the last matching bracket pair.
2082
2083     If the split was by optional parentheses, attempt splitting without them, too.
2084     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2085     this split.
2086
2087     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2088     """
2089     head = Line(depth=line.depth)
2090     body = Line(depth=line.depth + 1, inside_brackets=True)
2091     tail = Line(depth=line.depth)
2092     tail_leaves: List[Leaf] = []
2093     body_leaves: List[Leaf] = []
2094     head_leaves: List[Leaf] = []
2095     current_leaves = tail_leaves
2096     opening_bracket = None
2097     closing_bracket = None
2098     for leaf in reversed(line.leaves):
2099         if current_leaves is body_leaves:
2100             if leaf is opening_bracket:
2101                 current_leaves = head_leaves if body_leaves else tail_leaves
2102         current_leaves.append(leaf)
2103         if current_leaves is tail_leaves:
2104             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2105                 opening_bracket = leaf.opening_bracket
2106                 closing_bracket = leaf
2107                 current_leaves = body_leaves
2108     tail_leaves.reverse()
2109     body_leaves.reverse()
2110     head_leaves.reverse()
2111     # Since body is a new indent level, remove spurious leading whitespace.
2112     if body_leaves:
2113         normalize_prefix(body_leaves[0], inside_brackets=True)
2114     if not head_leaves:
2115         # No `head` means the split failed. Either `tail` has all content or
2116         # the matching `opening_bracket` wasn't available on `line` anymore.
2117         raise CannotSplit("No brackets found")
2118
2119     # Build the new lines.
2120     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2121         for leaf in leaves:
2122             result.append(leaf, preformatted=True)
2123             for comment_after in line.comments_after(leaf):
2124                 result.append(comment_after, preformatted=True)
2125     bracket_split_succeeded_or_raise(head, body, tail)
2126     assert opening_bracket and closing_bracket
2127     if (
2128         # the opening bracket is an optional paren
2129         opening_bracket.type == token.LPAR
2130         and not opening_bracket.value
2131         # the closing bracket is an optional paren
2132         and closing_bracket.type == token.RPAR
2133         and not closing_bracket.value
2134         # there are no standalone comments in the body
2135         and not line.contains_standalone_comments(0)
2136         # and it's not an import (optional parens are the only thing we can split
2137         # on in this case; attempting a split without them is a waste of time)
2138         and not line.is_import
2139     ):
2140         omit = {id(closing_bracket), *omit}
2141         if can_omit_invisible_parens(body, line_length):
2142             try:
2143                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2144                 return
2145             except CannotSplit:
2146                 pass
2147
2148     ensure_visible(opening_bracket)
2149     ensure_visible(closing_bracket)
2150     body.should_explode = should_explode(body, opening_bracket)
2151     for result in (head, body, tail):
2152         if result:
2153             yield result
2154
2155
2156 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2157     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2158
2159     Do nothing otherwise.
2160
2161     A left- or right-hand split is based on a pair of brackets. Content before
2162     (and including) the opening bracket is left on one line, content inside the
2163     brackets is put on a separate line, and finally content starting with and
2164     following the closing bracket is put on a separate line.
2165
2166     Those are called `head`, `body`, and `tail`, respectively. If the split
2167     produced the same line (all content in `head`) or ended up with an empty `body`
2168     and the `tail` is just the closing bracket, then it's considered failed.
2169     """
2170     tail_len = len(str(tail).strip())
2171     if not body:
2172         if tail_len == 0:
2173             raise CannotSplit("Splitting brackets produced the same line")
2174
2175         elif tail_len < 3:
2176             raise CannotSplit(
2177                 f"Splitting brackets on an empty body to save "
2178                 f"{tail_len} characters is not worth it"
2179             )
2180
2181
2182 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2183     """Normalize prefix of the first leaf in every line returned by `split_func`.
2184
2185     This is a decorator over relevant split functions.
2186     """
2187
2188     @wraps(split_func)
2189     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2190         for l in split_func(line, py36):
2191             normalize_prefix(l.leaves[0], inside_brackets=True)
2192             yield l
2193
2194     return split_wrapper
2195
2196
2197 @dont_increase_indentation
2198 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2199     """Split according to delimiters of the highest priority.
2200
2201     If `py36` is True, the split will add trailing commas also in function
2202     signatures that contain `*` and `**`.
2203     """
2204     try:
2205         last_leaf = line.leaves[-1]
2206     except IndexError:
2207         raise CannotSplit("Line empty")
2208
2209     bt = line.bracket_tracker
2210     try:
2211         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2212     except ValueError:
2213         raise CannotSplit("No delimiters found")
2214
2215     if delimiter_priority == DOT_PRIORITY:
2216         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2217             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2218
2219     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2220     lowest_depth = sys.maxsize
2221     trailing_comma_safe = True
2222
2223     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2224         """Append `leaf` to current line or to new line if appending impossible."""
2225         nonlocal current_line
2226         try:
2227             current_line.append_safe(leaf, preformatted=True)
2228         except ValueError as ve:
2229             yield current_line
2230
2231             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2232             current_line.append(leaf)
2233
2234     for index, leaf in enumerate(line.leaves):
2235         yield from append_to_line(leaf)
2236
2237         for comment_after in line.comments_after(leaf, index):
2238             yield from append_to_line(comment_after)
2239
2240         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2241         if leaf.bracket_depth == lowest_depth and is_vararg(
2242             leaf, within=VARARGS_PARENTS
2243         ):
2244             trailing_comma_safe = trailing_comma_safe and py36
2245         leaf_priority = bt.delimiters.get(id(leaf))
2246         if leaf_priority == delimiter_priority:
2247             yield current_line
2248
2249             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2250     if current_line:
2251         if (
2252             trailing_comma_safe
2253             and delimiter_priority == COMMA_PRIORITY
2254             and current_line.leaves[-1].type != token.COMMA
2255             and current_line.leaves[-1].type != STANDALONE_COMMENT
2256         ):
2257             current_line.append(Leaf(token.COMMA, ","))
2258         yield current_line
2259
2260
2261 @dont_increase_indentation
2262 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2263     """Split standalone comments from the rest of the line."""
2264     if not line.contains_standalone_comments(0):
2265         raise CannotSplit("Line does not have any standalone comments")
2266
2267     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2268
2269     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2270         """Append `leaf` to current line or to new line if appending impossible."""
2271         nonlocal current_line
2272         try:
2273             current_line.append_safe(leaf, preformatted=True)
2274         except ValueError as ve:
2275             yield current_line
2276
2277             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2278             current_line.append(leaf)
2279
2280     for index, leaf in enumerate(line.leaves):
2281         yield from append_to_line(leaf)
2282
2283         for comment_after in line.comments_after(leaf, index):
2284             yield from append_to_line(comment_after)
2285
2286     if current_line:
2287         yield current_line
2288
2289
2290 def is_import(leaf: Leaf) -> bool:
2291     """Return True if the given leaf starts an import statement."""
2292     p = leaf.parent
2293     t = leaf.type
2294     v = leaf.value
2295     return bool(
2296         t == token.NAME
2297         and (
2298             (v == "import" and p and p.type == syms.import_name)
2299             or (v == "from" and p and p.type == syms.import_from)
2300         )
2301     )
2302
2303
2304 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2305     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2306     else.
2307
2308     Note: don't use backslashes for formatting or you'll lose your voting rights.
2309     """
2310     if not inside_brackets:
2311         spl = leaf.prefix.split("#")
2312         if "\\" not in spl[0]:
2313             nl_count = spl[-1].count("\n")
2314             if len(spl) > 1:
2315                 nl_count -= 1
2316             leaf.prefix = "\n" * nl_count
2317             return
2318
2319     leaf.prefix = ""
2320
2321
2322 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2323     """Make all string prefixes lowercase.
2324
2325     If remove_u_prefix is given, also removes any u prefix from the string.
2326
2327     Note: Mutates its argument.
2328     """
2329     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2330     assert match is not None, f"failed to match string {leaf.value!r}"
2331     orig_prefix = match.group(1)
2332     new_prefix = orig_prefix.lower()
2333     if remove_u_prefix:
2334         new_prefix = new_prefix.replace("u", "")
2335     leaf.value = f"{new_prefix}{match.group(2)}"
2336
2337
2338 def normalize_string_quotes(leaf: Leaf) -> None:
2339     """Prefer double quotes but only if it doesn't cause more escaping.
2340
2341     Adds or removes backslashes as appropriate. Doesn't parse and fix
2342     strings nested in f-strings (yet).
2343
2344     Note: Mutates its argument.
2345     """
2346     value = leaf.value.lstrip("furbFURB")
2347     if value[:3] == '"""':
2348         return
2349
2350     elif value[:3] == "'''":
2351         orig_quote = "'''"
2352         new_quote = '"""'
2353     elif value[0] == '"':
2354         orig_quote = '"'
2355         new_quote = "'"
2356     else:
2357         orig_quote = "'"
2358         new_quote = '"'
2359     first_quote_pos = leaf.value.find(orig_quote)
2360     if first_quote_pos == -1:
2361         return  # There's an internal error
2362
2363     prefix = leaf.value[:first_quote_pos]
2364     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2365     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2366     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2367     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2368     if "r" in prefix.casefold():
2369         if unescaped_new_quote.search(body):
2370             # There's at least one unescaped new_quote in this raw string
2371             # so converting is impossible
2372             return
2373
2374         # Do not introduce or remove backslashes in raw strings
2375         new_body = body
2376     else:
2377         # remove unnecessary quotes
2378         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2379         if body != new_body:
2380             # Consider the string without unnecessary quotes as the original
2381             body = new_body
2382             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2383         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2384         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2385     if new_quote == '"""' and new_body[-1] == '"':
2386         # edge case:
2387         new_body = new_body[:-1] + '\\"'
2388     orig_escape_count = body.count("\\")
2389     new_escape_count = new_body.count("\\")
2390     if new_escape_count > orig_escape_count:
2391         return  # Do not introduce more escaping
2392
2393     if new_escape_count == orig_escape_count and orig_quote == '"':
2394         return  # Prefer double quotes
2395
2396     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2397
2398
2399 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2400     """Make existing optional parentheses invisible or create new ones.
2401
2402     `parens_after` is a set of string leaf values immeditely after which parens
2403     should be put.
2404
2405     Standardizes on visible parentheses for single-element tuples, and keeps
2406     existing visible parentheses for other tuples and generator expressions.
2407     """
2408     try:
2409         list(generate_comments(node))
2410     except FormatOff:
2411         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2412
2413     check_lpar = False
2414     for index, child in enumerate(list(node.children)):
2415         if check_lpar:
2416             if child.type == syms.atom:
2417                 maybe_make_parens_invisible_in_atom(child)
2418             elif is_one_tuple(child):
2419                 # wrap child in visible parentheses
2420                 lpar = Leaf(token.LPAR, "(")
2421                 rpar = Leaf(token.RPAR, ")")
2422                 child.remove()
2423                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2424             elif node.type == syms.import_from:
2425                 # "import from" nodes store parentheses directly as part of
2426                 # the statement
2427                 if child.type == token.LPAR:
2428                     # make parentheses invisible
2429                     child.value = ""  # type: ignore
2430                     node.children[-1].value = ""  # type: ignore
2431                 elif child.type != token.STAR:
2432                     # insert invisible parentheses
2433                     node.insert_child(index, Leaf(token.LPAR, ""))
2434                     node.append_child(Leaf(token.RPAR, ""))
2435                 break
2436
2437             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2438                 # wrap child in invisible parentheses
2439                 lpar = Leaf(token.LPAR, "")
2440                 rpar = Leaf(token.RPAR, "")
2441                 index = child.remove() or 0
2442                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2443
2444         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2445
2446
2447 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2448     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2449     if (
2450         node.type != syms.atom
2451         or is_empty_tuple(node)
2452         or is_one_tuple(node)
2453         or is_yield(node)
2454         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2455     ):
2456         return False
2457
2458     first = node.children[0]
2459     last = node.children[-1]
2460     if first.type == token.LPAR and last.type == token.RPAR:
2461         # make parentheses invisible
2462         first.value = ""  # type: ignore
2463         last.value = ""  # type: ignore
2464         if len(node.children) > 1:
2465             maybe_make_parens_invisible_in_atom(node.children[1])
2466         return True
2467
2468     return False
2469
2470
2471 def is_empty_tuple(node: LN) -> bool:
2472     """Return True if `node` holds an empty tuple."""
2473     return (
2474         node.type == syms.atom
2475         and len(node.children) == 2
2476         and node.children[0].type == token.LPAR
2477         and node.children[1].type == token.RPAR
2478     )
2479
2480
2481 def is_one_tuple(node: LN) -> bool:
2482     """Return True if `node` holds a tuple with one element, with or without parens."""
2483     if node.type == syms.atom:
2484         if len(node.children) != 3:
2485             return False
2486
2487         lpar, gexp, rpar = node.children
2488         if not (
2489             lpar.type == token.LPAR
2490             and gexp.type == syms.testlist_gexp
2491             and rpar.type == token.RPAR
2492         ):
2493             return False
2494
2495         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2496
2497     return (
2498         node.type in IMPLICIT_TUPLE
2499         and len(node.children) == 2
2500         and node.children[1].type == token.COMMA
2501     )
2502
2503
2504 def is_yield(node: LN) -> bool:
2505     """Return True if `node` holds a `yield` or `yield from` expression."""
2506     if node.type == syms.yield_expr:
2507         return True
2508
2509     if node.type == token.NAME and node.value == "yield":  # type: ignore
2510         return True
2511
2512     if node.type != syms.atom:
2513         return False
2514
2515     if len(node.children) != 3:
2516         return False
2517
2518     lpar, expr, rpar = node.children
2519     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2520         return is_yield(expr)
2521
2522     return False
2523
2524
2525 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2526     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2527
2528     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2529     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2530     extended iterable unpacking (PEP 3132) and additional unpacking
2531     generalizations (PEP 448).
2532     """
2533     if leaf.type not in STARS or not leaf.parent:
2534         return False
2535
2536     p = leaf.parent
2537     if p.type == syms.star_expr:
2538         # Star expressions are also used as assignment targets in extended
2539         # iterable unpacking (PEP 3132).  See what its parent is instead.
2540         if not p.parent:
2541             return False
2542
2543         p = p.parent
2544
2545     return p.type in within
2546
2547
2548 def is_multiline_string(leaf: Leaf) -> bool:
2549     """Return True if `leaf` is a multiline string that actually spans many lines."""
2550     value = leaf.value.lstrip("furbFURB")
2551     return value[:3] in {'"""', "'''"} and "\n" in value
2552
2553
2554 def is_stub_suite(node: Node) -> bool:
2555     """Return True if `node` is a suite with a stub body."""
2556     if (
2557         len(node.children) != 4
2558         or node.children[0].type != token.NEWLINE
2559         or node.children[1].type != token.INDENT
2560         or node.children[3].type != token.DEDENT
2561     ):
2562         return False
2563
2564     return is_stub_body(node.children[2])
2565
2566
2567 def is_stub_body(node: LN) -> bool:
2568     """Return True if `node` is a simple statement containing an ellipsis."""
2569     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2570         return False
2571
2572     if len(node.children) != 2:
2573         return False
2574
2575     child = node.children[0]
2576     return (
2577         child.type == syms.atom
2578         and len(child.children) == 3
2579         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2580     )
2581
2582
2583 def max_delimiter_priority_in_atom(node: LN) -> int:
2584     """Return maximum delimiter priority inside `node`.
2585
2586     This is specific to atoms with contents contained in a pair of parentheses.
2587     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2588     """
2589     if node.type != syms.atom:
2590         return 0
2591
2592     first = node.children[0]
2593     last = node.children[-1]
2594     if not (first.type == token.LPAR and last.type == token.RPAR):
2595         return 0
2596
2597     bt = BracketTracker()
2598     for c in node.children[1:-1]:
2599         if isinstance(c, Leaf):
2600             bt.mark(c)
2601         else:
2602             for leaf in c.leaves():
2603                 bt.mark(leaf)
2604     try:
2605         return bt.max_delimiter_priority()
2606
2607     except ValueError:
2608         return 0
2609
2610
2611 def ensure_visible(leaf: Leaf) -> None:
2612     """Make sure parentheses are visible.
2613
2614     They could be invisible as part of some statements (see
2615     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2616     """
2617     if leaf.type == token.LPAR:
2618         leaf.value = "("
2619     elif leaf.type == token.RPAR:
2620         leaf.value = ")"
2621
2622
2623 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2624     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2625     if not (
2626         opening_bracket.parent
2627         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2628         and opening_bracket.value in "[{("
2629     ):
2630         return False
2631
2632     try:
2633         last_leaf = line.leaves[-1]
2634         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2635         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2636     except (IndexError, ValueError):
2637         return False
2638
2639     return max_priority == COMMA_PRIORITY
2640
2641
2642 def is_python36(node: Node) -> bool:
2643     """Return True if the current file is using Python 3.6+ features.
2644
2645     Currently looking for:
2646     - f-strings; and
2647     - trailing commas after * or ** in function signatures and calls.
2648     """
2649     for n in node.pre_order():
2650         if n.type == token.STRING:
2651             value_head = n.value[:2]  # type: ignore
2652             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2653                 return True
2654
2655         elif (
2656             n.type in {syms.typedargslist, syms.arglist}
2657             and n.children
2658             and n.children[-1].type == token.COMMA
2659         ):
2660             for ch in n.children:
2661                 if ch.type in STARS:
2662                     return True
2663
2664                 if ch.type == syms.argument:
2665                     for argch in ch.children:
2666                         if argch.type in STARS:
2667                             return True
2668
2669     return False
2670
2671
2672 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2673     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2674
2675     Brackets can be omitted if the entire trailer up to and including
2676     a preceding closing bracket fits in one line.
2677
2678     Yielded sets are cumulative (contain results of previous yields, too).  First
2679     set is empty.
2680     """
2681
2682     omit: Set[LeafID] = set()
2683     yield omit
2684
2685     length = 4 * line.depth
2686     opening_bracket = None
2687     closing_bracket = None
2688     optional_brackets: Set[LeafID] = set()
2689     inner_brackets: Set[LeafID] = set()
2690     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2691         length += leaf_length
2692         if length > line_length:
2693             break
2694
2695         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2696         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2697             break
2698
2699         optional_brackets.discard(id(leaf))
2700         if opening_bracket:
2701             if leaf is opening_bracket:
2702                 opening_bracket = None
2703             elif leaf.type in CLOSING_BRACKETS:
2704                 inner_brackets.add(id(leaf))
2705         elif leaf.type in CLOSING_BRACKETS:
2706             if not leaf.value:
2707                 optional_brackets.add(id(opening_bracket))
2708                 continue
2709
2710             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2711                 # Empty brackets would fail a split so treat them as "inner"
2712                 # brackets (e.g. only add them to the `omit` set if another
2713                 # pair of brackets was good enough.
2714                 inner_brackets.add(id(leaf))
2715                 continue
2716
2717             opening_bracket = leaf.opening_bracket
2718             if closing_bracket:
2719                 omit.add(id(closing_bracket))
2720                 omit.update(inner_brackets)
2721                 inner_brackets.clear()
2722                 yield omit
2723             closing_bracket = leaf
2724
2725
2726 def get_future_imports(node: Node) -> Set[str]:
2727     """Return a set of __future__ imports in the file."""
2728     imports = set()
2729     for child in node.children:
2730         if child.type != syms.simple_stmt:
2731             break
2732         first_child = child.children[0]
2733         if isinstance(first_child, Leaf):
2734             # Continue looking if we see a docstring; otherwise stop.
2735             if (
2736                 len(child.children) == 2
2737                 and first_child.type == token.STRING
2738                 and child.children[1].type == token.NEWLINE
2739             ):
2740                 continue
2741             else:
2742                 break
2743         elif first_child.type == syms.import_from:
2744             module_name = first_child.children[1]
2745             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2746                 break
2747             for import_from_child in first_child.children[3:]:
2748                 if isinstance(import_from_child, Leaf):
2749                     if import_from_child.type == token.NAME:
2750                         imports.add(import_from_child.value)
2751                 else:
2752                     assert import_from_child.type == syms.import_as_names
2753                     for leaf in import_from_child.children:
2754                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2755                             imports.add(leaf.value)
2756         else:
2757             break
2758     return imports
2759
2760
2761 PYTHON_EXTENSIONS = {".py", ".pyi"}
2762 BLACKLISTED_DIRECTORIES = {
2763     "build",
2764     "buck-out",
2765     "dist",
2766     "_build",
2767     ".git",
2768     ".hg",
2769     ".mypy_cache",
2770     ".tox",
2771     ".venv",
2772 }
2773
2774
2775 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2776     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2777     and have one of the PYTHON_EXTENSIONS.
2778     """
2779     for child in path.iterdir():
2780         if child.is_dir():
2781             if child.name in BLACKLISTED_DIRECTORIES:
2782                 continue
2783
2784             yield from gen_python_files_in_dir(child)
2785
2786         elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
2787             yield child
2788
2789
2790 @dataclass
2791 class Report:
2792     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2793     check: bool = False
2794     quiet: bool = False
2795     change_count: int = 0
2796     same_count: int = 0
2797     failure_count: int = 0
2798
2799     def done(self, src: Path, changed: Changed) -> None:
2800         """Increment the counter for successful reformatting. Write out a message."""
2801         if changed is Changed.YES:
2802             reformatted = "would reformat" if self.check else "reformatted"
2803             if not self.quiet:
2804                 out(f"{reformatted} {src}")
2805             self.change_count += 1
2806         else:
2807             if not self.quiet:
2808                 if changed is Changed.NO:
2809                     msg = f"{src} already well formatted, good job."
2810                 else:
2811                     msg = f"{src} wasn't modified on disk since last run."
2812                 out(msg, bold=False)
2813             self.same_count += 1
2814
2815     def failed(self, src: Path, message: str) -> None:
2816         """Increment the counter for failed reformatting. Write out a message."""
2817         err(f"error: cannot format {src}: {message}")
2818         self.failure_count += 1
2819
2820     @property
2821     def return_code(self) -> int:
2822         """Return the exit code that the app should use.
2823
2824         This considers the current state of changed files and failures:
2825         - if there were any failures, return 123;
2826         - if any files were changed and --check is being used, return 1;
2827         - otherwise return 0.
2828         """
2829         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2830         # 126 we have special returncodes reserved by the shell.
2831         if self.failure_count:
2832             return 123
2833
2834         elif self.change_count and self.check:
2835             return 1
2836
2837         return 0
2838
2839     def __str__(self) -> str:
2840         """Render a color report of the current state.
2841
2842         Use `click.unstyle` to remove colors.
2843         """
2844         if self.check:
2845             reformatted = "would be reformatted"
2846             unchanged = "would be left unchanged"
2847             failed = "would fail to reformat"
2848         else:
2849             reformatted = "reformatted"
2850             unchanged = "left unchanged"
2851             failed = "failed to reformat"
2852         report = []
2853         if self.change_count:
2854             s = "s" if self.change_count > 1 else ""
2855             report.append(
2856                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2857             )
2858         if self.same_count:
2859             s = "s" if self.same_count > 1 else ""
2860             report.append(f"{self.same_count} file{s} {unchanged}")
2861         if self.failure_count:
2862             s = "s" if self.failure_count > 1 else ""
2863             report.append(
2864                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2865             )
2866         return ", ".join(report) + "."
2867
2868
2869 def assert_equivalent(src: str, dst: str) -> None:
2870     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2871
2872     import ast
2873     import traceback
2874
2875     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2876         """Simple visitor generating strings to compare ASTs by content."""
2877         yield f"{'  ' * depth}{node.__class__.__name__}("
2878
2879         for field in sorted(node._fields):
2880             try:
2881                 value = getattr(node, field)
2882             except AttributeError:
2883                 continue
2884
2885             yield f"{'  ' * (depth+1)}{field}="
2886
2887             if isinstance(value, list):
2888                 for item in value:
2889                     if isinstance(item, ast.AST):
2890                         yield from _v(item, depth + 2)
2891
2892             elif isinstance(value, ast.AST):
2893                 yield from _v(value, depth + 2)
2894
2895             else:
2896                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2897
2898         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2899
2900     try:
2901         src_ast = ast.parse(src)
2902     except Exception as exc:
2903         major, minor = sys.version_info[:2]
2904         raise AssertionError(
2905             f"cannot use --safe with this file; failed to parse source file "
2906             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2907             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2908         )
2909
2910     try:
2911         dst_ast = ast.parse(dst)
2912     except Exception as exc:
2913         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2914         raise AssertionError(
2915             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2916             f"Please report a bug on https://github.com/ambv/black/issues.  "
2917             f"This invalid output might be helpful: {log}"
2918         ) from None
2919
2920     src_ast_str = "\n".join(_v(src_ast))
2921     dst_ast_str = "\n".join(_v(dst_ast))
2922     if src_ast_str != dst_ast_str:
2923         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2924         raise AssertionError(
2925             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2926             f"the source.  "
2927             f"Please report a bug on https://github.com/ambv/black/issues.  "
2928             f"This diff might be helpful: {log}"
2929         ) from None
2930
2931
2932 def assert_stable(
2933     src: str, dst: str, line_length: int, is_pyi: bool = False, force_py36: bool = False
2934 ) -> None:
2935     """Raise AssertionError if `dst` reformats differently the second time."""
2936     newdst = format_str(
2937         dst, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
2938     )
2939     if dst != newdst:
2940         log = dump_to_file(
2941             diff(src, dst, "source", "first pass"),
2942             diff(dst, newdst, "first pass", "second pass"),
2943         )
2944         raise AssertionError(
2945             f"INTERNAL ERROR: Black produced different code on the second pass "
2946             f"of the formatter.  "
2947             f"Please report a bug on https://github.com/ambv/black/issues.  "
2948             f"This diff might be helpful: {log}"
2949         ) from None
2950
2951
2952 def dump_to_file(*output: str) -> str:
2953     """Dump `output` to a temporary file. Return path to the file."""
2954     import tempfile
2955
2956     with tempfile.NamedTemporaryFile(
2957         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2958     ) as f:
2959         for lines in output:
2960             f.write(lines)
2961             if lines and lines[-1] != "\n":
2962                 f.write("\n")
2963     return f.name
2964
2965
2966 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2967     """Return a unified diff string between strings `a` and `b`."""
2968     import difflib
2969
2970     a_lines = [line + "\n" for line in a.split("\n")]
2971     b_lines = [line + "\n" for line in b.split("\n")]
2972     return "".join(
2973         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2974     )
2975
2976
2977 def cancel(tasks: Iterable[asyncio.Task]) -> None:
2978     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2979     err("Aborted!")
2980     for task in tasks:
2981         task.cancel()
2982
2983
2984 def shutdown(loop: BaseEventLoop) -> None:
2985     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2986     try:
2987         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2988         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2989         if not to_cancel:
2990             return
2991
2992         for task in to_cancel:
2993             task.cancel()
2994         loop.run_until_complete(
2995             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2996         )
2997     finally:
2998         # `concurrent.futures.Future` objects cannot be cancelled once they
2999         # are already running. There might be some when the `shutdown()` happened.
3000         # Silence their logger's spew about the event loop being closed.
3001         cf_logger = logging.getLogger("concurrent.futures")
3002         cf_logger.setLevel(logging.CRITICAL)
3003         loop.close()
3004
3005
3006 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3007     """Replace `regex` with `replacement` twice on `original`.
3008
3009     This is used by string normalization to perform replaces on
3010     overlapping matches.
3011     """
3012     return regex.sub(replacement, regex.sub(replacement, original))
3013
3014
3015 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3016     """Like `reversed(enumerate(sequence))` if that were possible."""
3017     index = len(sequence) - 1
3018     for element in reversed(sequence):
3019         yield (index, element)
3020         index -= 1
3021
3022
3023 def enumerate_with_length(
3024     line: Line, reversed: bool = False
3025 ) -> Iterator[Tuple[Index, Leaf, int]]:
3026     """Return an enumeration of leaves with their length.
3027
3028     Stops prematurely on multiline strings and standalone comments.
3029     """
3030     op = cast(
3031         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3032         enumerate_reversed if reversed else enumerate,
3033     )
3034     for index, leaf in op(line.leaves):
3035         length = len(leaf.prefix) + len(leaf.value)
3036         if "\n" in leaf.value:
3037             return  # Multiline strings, we can't continue.
3038
3039         comment: Optional[Leaf]
3040         for comment in line.comments_after(leaf, index):
3041             length += len(comment.value)
3042
3043         yield index, leaf, length
3044
3045
3046 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3047     """Return True if `line` is no longer than `line_length`.
3048
3049     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3050     """
3051     if not line_str:
3052         line_str = str(line).strip("\n")
3053     return (
3054         len(line_str) <= line_length
3055         and "\n" not in line_str  # multiline strings
3056         and not line.contains_standalone_comments()
3057     )
3058
3059
3060 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3061     """Does `line` have a shape safe to reformat without optional parens around it?
3062
3063     Returns True for only a subset of potentially nice looking formattings but
3064     the point is to not return false positives that end up producing lines that
3065     are too long.
3066     """
3067     bt = line.bracket_tracker
3068     if not bt.delimiters:
3069         # Without delimiters the optional parentheses are useless.
3070         return True
3071
3072     max_priority = bt.max_delimiter_priority()
3073     if bt.delimiter_count_with_priority(max_priority) > 1:
3074         # With more than one delimiter of a kind the optional parentheses read better.
3075         return False
3076
3077     if max_priority == DOT_PRIORITY:
3078         # A single stranded method call doesn't require optional parentheses.
3079         return True
3080
3081     assert len(line.leaves) >= 2, "Stranded delimiter"
3082
3083     first = line.leaves[0]
3084     second = line.leaves[1]
3085     penultimate = line.leaves[-2]
3086     last = line.leaves[-1]
3087
3088     # With a single delimiter, omit if the expression starts or ends with
3089     # a bracket.
3090     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3091         remainder = False
3092         length = 4 * line.depth
3093         for _index, leaf, leaf_length in enumerate_with_length(line):
3094             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3095                 remainder = True
3096             if remainder:
3097                 length += leaf_length
3098                 if length > line_length:
3099                     break
3100
3101                 if leaf.type in OPENING_BRACKETS:
3102                     # There are brackets we can further split on.
3103                     remainder = False
3104
3105         else:
3106             # checked the entire string and line length wasn't exceeded
3107             if len(line.leaves) == _index + 1:
3108                 return True
3109
3110         # Note: we are not returning False here because a line might have *both*
3111         # a leading opening bracket and a trailing closing bracket.  If the
3112         # opening bracket doesn't match our rule, maybe the closing will.
3113
3114     if (
3115         last.type == token.RPAR
3116         or last.type == token.RBRACE
3117         or (
3118             # don't use indexing for omitting optional parentheses;
3119             # it looks weird
3120             last.type == token.RSQB
3121             and last.parent
3122             and last.parent.type != syms.trailer
3123         )
3124     ):
3125         if penultimate.type in OPENING_BRACKETS:
3126             # Empty brackets don't help.
3127             return False
3128
3129         if is_multiline_string(first):
3130             # Additional wrapping of a multiline string in this situation is
3131             # unnecessary.
3132             return True
3133
3134         length = 4 * line.depth
3135         seen_other_brackets = False
3136         for _index, leaf, leaf_length in enumerate_with_length(line):
3137             length += leaf_length
3138             if leaf is last.opening_bracket:
3139                 if seen_other_brackets or length <= line_length:
3140                     return True
3141
3142             elif leaf.type in OPENING_BRACKETS:
3143                 # There are brackets we can further split on.
3144                 seen_other_brackets = True
3145
3146     return False
3147
3148
3149 def get_cache_file(line_length: int, pyi: bool = False, py36: bool = False) -> Path:
3150     return (
3151         CACHE_DIR
3152         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3153     )
3154
3155
3156 def read_cache(line_length: int, pyi: bool = False, py36: bool = False) -> Cache:
3157     """Read the cache if it exists and is well formed.
3158
3159     If it is not well formed, the call to write_cache later should resolve the issue.
3160     """
3161     cache_file = get_cache_file(line_length, pyi, py36)
3162     if not cache_file.exists():
3163         return {}
3164
3165     with cache_file.open("rb") as fobj:
3166         try:
3167             cache: Cache = pickle.load(fobj)
3168         except pickle.UnpicklingError:
3169             return {}
3170
3171     return cache
3172
3173
3174 def get_cache_info(path: Path) -> CacheInfo:
3175     """Return the information used to check if a file is already formatted or not."""
3176     stat = path.stat()
3177     return stat.st_mtime, stat.st_size
3178
3179
3180 def filter_cached(
3181     cache: Cache, sources: Iterable[Path]
3182 ) -> Tuple[List[Path], List[Path]]:
3183     """Split a list of paths into two.
3184
3185     The first list contains paths of files that modified on disk or are not in the
3186     cache. The other list contains paths to non-modified files.
3187     """
3188     todo, done = [], []
3189     for src in sources:
3190         src = src.resolve()
3191         if cache.get(src) != get_cache_info(src):
3192             todo.append(src)
3193         else:
3194             done.append(src)
3195     return todo, done
3196
3197
3198 def write_cache(
3199     cache: Cache,
3200     sources: List[Path],
3201     line_length: int,
3202     pyi: bool = False,
3203     py36: bool = False,
3204 ) -> None:
3205     """Update the cache file."""
3206     cache_file = get_cache_file(line_length, pyi, py36)
3207     try:
3208         if not CACHE_DIR.exists():
3209             CACHE_DIR.mkdir(parents=True)
3210         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3211         with cache_file.open("wb") as fobj:
3212             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3213     except OSError:
3214         pass
3215
3216
3217 if __name__ == "__main__":
3218     main()