]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

4599bdd98a9ae536bb25c27d4b3d1511e0523f17
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum, Flag
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33     cast,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46
47 __version__ = "18.5b1"
48 DEFAULT_LINE_LENGTH = 88
49 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
50
51
52 # types
53 FileContent = str
54 Encoding = str
55 Depth = int
56 NodeType = int
57 LeafID = int
58 Priority = int
59 Index = int
60 LN = Union[Leaf, Node]
61 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
62 Timestamp = float
63 FileSize = int
64 CacheInfo = Tuple[Timestamp, FileSize]
65 Cache = Dict[Path, CacheInfo]
66 out = partial(click.secho, bold=True, err=True)
67 err = partial(click.secho, fg="red", err=True)
68
69 pygram.initialize(CACHE_DIR)
70 syms = pygram.python_symbols
71
72
73 class NothingChanged(UserWarning):
74     """Raised by :func:`format_file` when reformatted code is the same as source."""
75
76
77 class CannotSplit(Exception):
78     """A readable split that fits the allotted line length is impossible.
79
80     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
81     :func:`delimiter_split`.
82     """
83
84
85 class FormatError(Exception):
86     """Base exception for `# fmt: on` and `# fmt: off` handling.
87
88     It holds the number of bytes of the prefix consumed before the format
89     control comment appeared.
90     """
91
92     def __init__(self, consumed: int) -> None:
93         super().__init__(consumed)
94         self.consumed = consumed
95
96     def trim_prefix(self, leaf: Leaf) -> None:
97         leaf.prefix = leaf.prefix[self.consumed :]
98
99     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
100         """Returns a new Leaf from the consumed part of the prefix."""
101         unformatted_prefix = leaf.prefix[: self.consumed]
102         return Leaf(token.NEWLINE, unformatted_prefix)
103
104
105 class FormatOn(FormatError):
106     """Found a comment like `# fmt: on` in the file."""
107
108
109 class FormatOff(FormatError):
110     """Found a comment like `# fmt: off` in the file."""
111
112
113 class WriteBack(Enum):
114     NO = 0
115     YES = 1
116     DIFF = 2
117
118
119 class Changed(Enum):
120     NO = 0
121     CACHED = 1
122     YES = 2
123
124
125 class FileMode(Flag):
126     AUTO_DETECT = 0
127     PYTHON36 = 1
128     PYI = 2
129     NO_STRING_NORMALIZATION = 4
130
131
132 @click.command()
133 @click.option(
134     "-l",
135     "--line-length",
136     type=int,
137     default=DEFAULT_LINE_LENGTH,
138     help="How many character per line to allow.",
139     show_default=True,
140 )
141 @click.option(
142     "--check",
143     is_flag=True,
144     help=(
145         "Don't write the files back, just return the status.  Return code 0 "
146         "means nothing would change.  Return code 1 means some files would be "
147         "reformatted.  Return code 123 means there was an internal error."
148     ),
149 )
150 @click.option(
151     "--diff",
152     is_flag=True,
153     help="Don't write the files back, just output a diff for each file on stdout.",
154 )
155 @click.option(
156     "--fast/--safe",
157     is_flag=True,
158     help="If --fast given, skip temporary sanity checks. [default: --safe]",
159 )
160 @click.option(
161     "-q",
162     "--quiet",
163     is_flag=True,
164     help=(
165         "Don't emit non-error messages to stderr. Errors are still emitted, "
166         "silence those with 2>/dev/null."
167     ),
168 )
169 @click.option(
170     "--pyi",
171     is_flag=True,
172     help=(
173         "Consider all input files typing stubs regardless of file extension "
174         "(useful when piping source on standard input)."
175     ),
176 )
177 @click.option(
178     "--py36",
179     is_flag=True,
180     help=(
181         "Allow using Python 3.6-only syntax on all input files.  This will put "
182         "trailing commas in function signatures and calls also after *args and "
183         "**kwargs.  [default: per-file auto-detection]"
184     ),
185 )
186 @click.option(
187     "-S",
188     "--skip-string-normalization",
189     is_flag=True,
190     help="Don't normalize string quotes or prefixes.",
191 )
192 @click.version_option(version=__version__)
193 @click.argument(
194     "src",
195     nargs=-1,
196     type=click.Path(
197         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
198     ),
199 )
200 @click.pass_context
201 def main(
202     ctx: click.Context,
203     line_length: int,
204     check: bool,
205     diff: bool,
206     fast: bool,
207     pyi: bool,
208     py36: bool,
209     skip_string_normalization: bool,
210     quiet: bool,
211     src: List[str],
212 ) -> None:
213     """The uncompromising code formatter."""
214     sources: List[Path] = []
215     for s in src:
216         p = Path(s)
217         if p.is_dir():
218             sources.extend(gen_python_files_in_dir(p))
219         elif p.is_file():
220             # if a file was explicitly given, we don't care about its extension
221             sources.append(p)
222         elif s == "-":
223             sources.append(Path("-"))
224         else:
225             err(f"invalid path: {s}")
226
227     if check and not diff:
228         write_back = WriteBack.NO
229     elif diff:
230         write_back = WriteBack.DIFF
231     else:
232         write_back = WriteBack.YES
233     mode = FileMode.AUTO_DETECT
234     if py36:
235         mode |= FileMode.PYTHON36
236     if pyi:
237         mode |= FileMode.PYI
238     if skip_string_normalization:
239         mode |= FileMode.NO_STRING_NORMALIZATION
240     report = Report(check=check, quiet=quiet)
241     if len(sources) == 0:
242         out("No paths given. Nothing to do 😴")
243         ctx.exit(0)
244         return
245
246     elif len(sources) == 1:
247         reformat_one(
248             src=sources[0],
249             line_length=line_length,
250             fast=fast,
251             write_back=write_back,
252             mode=mode,
253             report=report,
254         )
255     else:
256         loop = asyncio.get_event_loop()
257         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
258         try:
259             loop.run_until_complete(
260                 schedule_formatting(
261                     sources=sources,
262                     line_length=line_length,
263                     fast=fast,
264                     write_back=write_back,
265                     mode=mode,
266                     report=report,
267                     loop=loop,
268                     executor=executor,
269                 )
270             )
271         finally:
272             shutdown(loop)
273         if not quiet:
274             out("All done! ✨ 🍰 ✨")
275             click.echo(str(report))
276     ctx.exit(report.return_code)
277
278
279 def reformat_one(
280     src: Path,
281     line_length: int,
282     fast: bool,
283     write_back: WriteBack,
284     mode: FileMode,
285     report: "Report",
286 ) -> None:
287     """Reformat a single file under `src` without spawning child processes.
288
289     If `quiet` is True, non-error messages are not output. `line_length`,
290     `write_back`, `fast` and `pyi` options are passed to
291     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
292     """
293     try:
294         changed = Changed.NO
295         if not src.is_file() and str(src) == "-":
296             if format_stdin_to_stdout(
297                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
298             ):
299                 changed = Changed.YES
300         else:
301             cache: Cache = {}
302             if write_back != WriteBack.DIFF:
303                 cache = read_cache(line_length, mode)
304                 src = src.resolve()
305                 if src in cache and cache[src] == get_cache_info(src):
306                     changed = Changed.CACHED
307             if changed is not Changed.CACHED and format_file_in_place(
308                 src,
309                 line_length=line_length,
310                 fast=fast,
311                 write_back=write_back,
312                 mode=mode,
313             ):
314                 changed = Changed.YES
315             if write_back == WriteBack.YES and changed is not Changed.NO:
316                 write_cache(cache, [src], line_length, mode)
317         report.done(src, changed)
318     except Exception as exc:
319         report.failed(src, str(exc))
320
321
322 async def schedule_formatting(
323     sources: List[Path],
324     line_length: int,
325     fast: bool,
326     write_back: WriteBack,
327     mode: FileMode,
328     report: "Report",
329     loop: BaseEventLoop,
330     executor: Executor,
331 ) -> None:
332     """Run formatting of `sources` in parallel using the provided `executor`.
333
334     (Use ProcessPoolExecutors for actual parallelism.)
335
336     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
337     :func:`format_file_in_place`.
338     """
339     cache: Cache = {}
340     if write_back != WriteBack.DIFF:
341         cache = read_cache(line_length, mode)
342         sources, cached = filter_cached(cache, sources)
343         for src in cached:
344             report.done(src, Changed.CACHED)
345     cancelled = []
346     formatted = []
347     if sources:
348         lock = None
349         if write_back == WriteBack.DIFF:
350             # For diff output, we need locks to ensure we don't interleave output
351             # from different processes.
352             manager = Manager()
353             lock = manager.Lock()
354         tasks = {
355             loop.run_in_executor(
356                 executor,
357                 format_file_in_place,
358                 src,
359                 line_length,
360                 fast,
361                 write_back,
362                 mode,
363                 lock,
364             ): src
365             for src in sorted(sources)
366         }
367         pending: Iterable[asyncio.Task] = tasks.keys()
368         try:
369             loop.add_signal_handler(signal.SIGINT, cancel, pending)
370             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
371         except NotImplementedError:
372             # There are no good alternatives for these on Windows
373             pass
374         while pending:
375             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
376             for task in done:
377                 src = tasks.pop(task)
378                 if task.cancelled():
379                     cancelled.append(task)
380                 elif task.exception():
381                     report.failed(src, str(task.exception()))
382                 else:
383                     formatted.append(src)
384                     report.done(src, Changed.YES if task.result() else Changed.NO)
385     if cancelled:
386         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
387     if write_back == WriteBack.YES and formatted:
388         write_cache(cache, formatted, line_length, mode)
389
390
391 def format_file_in_place(
392     src: Path,
393     line_length: int,
394     fast: bool,
395     write_back: WriteBack = WriteBack.NO,
396     mode: FileMode = FileMode.AUTO_DETECT,
397     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
398 ) -> bool:
399     """Format file under `src` path. Return True if changed.
400
401     If `write_back` is True, write reformatted code back to stdout.
402     `line_length` and `fast` options are passed to :func:`format_file_contents`.
403     """
404     if src.suffix == ".pyi":
405         mode |= FileMode.PYI
406     with tokenize.open(src) as src_buffer:
407         src_contents = src_buffer.read()
408     try:
409         dst_contents = format_file_contents(
410             src_contents, line_length=line_length, fast=fast, mode=mode
411         )
412     except NothingChanged:
413         return False
414
415     if write_back == write_back.YES:
416         with open(src, "w", encoding=src_buffer.encoding) as f:
417             f.write(dst_contents)
418     elif write_back == write_back.DIFF:
419         src_name = f"{src}  (original)"
420         dst_name = f"{src}  (formatted)"
421         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
422         if lock:
423             lock.acquire()
424         try:
425             sys.stdout.write(diff_contents)
426         finally:
427             if lock:
428                 lock.release()
429     return True
430
431
432 def format_stdin_to_stdout(
433     line_length: int,
434     fast: bool,
435     write_back: WriteBack = WriteBack.NO,
436     mode: FileMode = FileMode.AUTO_DETECT,
437 ) -> bool:
438     """Format file on stdin. Return True if changed.
439
440     If `write_back` is True, write reformatted code back to stdout.
441     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
442     :func:`format_file_contents`.
443     """
444     src = sys.stdin.read()
445     dst = src
446     try:
447         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
448         return True
449
450     except NothingChanged:
451         return False
452
453     finally:
454         if write_back == WriteBack.YES:
455             sys.stdout.write(dst)
456         elif write_back == WriteBack.DIFF:
457             src_name = "<stdin>  (original)"
458             dst_name = "<stdin>  (formatted)"
459             sys.stdout.write(diff(src, dst, src_name, dst_name))
460
461
462 def format_file_contents(
463     src_contents: str,
464     *,
465     line_length: int,
466     fast: bool,
467     mode: FileMode = FileMode.AUTO_DETECT,
468 ) -> FileContent:
469     """Reformat contents a file and return new contents.
470
471     If `fast` is False, additionally confirm that the reformatted code is
472     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
473     `line_length` is passed to :func:`format_str`.
474     """
475     if src_contents.strip() == "":
476         raise NothingChanged
477
478     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
479     if src_contents == dst_contents:
480         raise NothingChanged
481
482     if not fast:
483         assert_equivalent(src_contents, dst_contents)
484         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
485     return dst_contents
486
487
488 def format_str(
489     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
490 ) -> FileContent:
491     """Reformat a string and return new contents.
492
493     `line_length` determines how many characters per line are allowed.
494     """
495     src_node = lib2to3_parse(src_contents)
496     dst_contents = ""
497     future_imports = get_future_imports(src_node)
498     is_pyi = bool(mode & FileMode.PYI)
499     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
500     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
501     lines = LineGenerator(
502         remove_u_prefix=py36 or "unicode_literals" in future_imports,
503         is_pyi=is_pyi,
504         normalize_strings=normalize_strings,
505     )
506     elt = EmptyLineTracker(is_pyi=is_pyi)
507     empty_line = Line()
508     after = 0
509     for current_line in lines.visit(src_node):
510         for _ in range(after):
511             dst_contents += str(empty_line)
512         before, after = elt.maybe_empty_lines(current_line)
513         for _ in range(before):
514             dst_contents += str(empty_line)
515         for line in split_line(current_line, line_length=line_length, py36=py36):
516             dst_contents += str(line)
517     return dst_contents
518
519
520 GRAMMARS = [
521     pygram.python_grammar_no_print_statement_no_exec_statement,
522     pygram.python_grammar_no_print_statement,
523     pygram.python_grammar,
524 ]
525
526
527 def lib2to3_parse(src_txt: str) -> Node:
528     """Given a string with source, return the lib2to3 Node."""
529     grammar = pygram.python_grammar_no_print_statement
530     if src_txt[-1] != "\n":
531         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
532         src_txt += nl
533     for grammar in GRAMMARS:
534         drv = driver.Driver(grammar, pytree.convert)
535         try:
536             result = drv.parse_string(src_txt, True)
537             break
538
539         except ParseError as pe:
540             lineno, column = pe.context[1]
541             lines = src_txt.splitlines()
542             try:
543                 faulty_line = lines[lineno - 1]
544             except IndexError:
545                 faulty_line = "<line number missing in source>"
546             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
547     else:
548         raise exc from None
549
550     if isinstance(result, Leaf):
551         result = Node(syms.file_input, [result])
552     return result
553
554
555 def lib2to3_unparse(node: Node) -> str:
556     """Given a lib2to3 node, return its string representation."""
557     code = str(node)
558     return code
559
560
561 T = TypeVar("T")
562
563
564 class Visitor(Generic[T]):
565     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
566
567     def visit(self, node: LN) -> Iterator[T]:
568         """Main method to visit `node` and its children.
569
570         It tries to find a `visit_*()` method for the given `node.type`, like
571         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
572         If no dedicated `visit_*()` method is found, chooses `visit_default()`
573         instead.
574
575         Then yields objects of type `T` from the selected visitor.
576         """
577         if node.type < 256:
578             name = token.tok_name[node.type]
579         else:
580             name = type_repr(node.type)
581         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
582
583     def visit_default(self, node: LN) -> Iterator[T]:
584         """Default `visit_*()` implementation. Recurses to children of `node`."""
585         if isinstance(node, Node):
586             for child in node.children:
587                 yield from self.visit(child)
588
589
590 @dataclass
591 class DebugVisitor(Visitor[T]):
592     tree_depth: int = 0
593
594     def visit_default(self, node: LN) -> Iterator[T]:
595         indent = " " * (2 * self.tree_depth)
596         if isinstance(node, Node):
597             _type = type_repr(node.type)
598             out(f"{indent}{_type}", fg="yellow")
599             self.tree_depth += 1
600             for child in node.children:
601                 yield from self.visit(child)
602
603             self.tree_depth -= 1
604             out(f"{indent}/{_type}", fg="yellow", bold=False)
605         else:
606             _type = token.tok_name.get(node.type, str(node.type))
607             out(f"{indent}{_type}", fg="blue", nl=False)
608             if node.prefix:
609                 # We don't have to handle prefixes for `Node` objects since
610                 # that delegates to the first child anyway.
611                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
612             out(f" {node.value!r}", fg="blue", bold=False)
613
614     @classmethod
615     def show(cls, code: str) -> None:
616         """Pretty-print the lib2to3 AST of a given string of `code`.
617
618         Convenience method for debugging.
619         """
620         v: DebugVisitor[None] = DebugVisitor()
621         list(v.visit(lib2to3_parse(code)))
622
623
624 KEYWORDS = set(keyword.kwlist)
625 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
626 FLOW_CONTROL = {"return", "raise", "break", "continue"}
627 STATEMENT = {
628     syms.if_stmt,
629     syms.while_stmt,
630     syms.for_stmt,
631     syms.try_stmt,
632     syms.except_clause,
633     syms.with_stmt,
634     syms.funcdef,
635     syms.classdef,
636 }
637 STANDALONE_COMMENT = 153
638 LOGIC_OPERATORS = {"and", "or"}
639 COMPARATORS = {
640     token.LESS,
641     token.GREATER,
642     token.EQEQUAL,
643     token.NOTEQUAL,
644     token.LESSEQUAL,
645     token.GREATEREQUAL,
646 }
647 MATH_OPERATORS = {
648     token.VBAR,
649     token.CIRCUMFLEX,
650     token.AMPER,
651     token.LEFTSHIFT,
652     token.RIGHTSHIFT,
653     token.PLUS,
654     token.MINUS,
655     token.STAR,
656     token.SLASH,
657     token.DOUBLESLASH,
658     token.PERCENT,
659     token.AT,
660     token.TILDE,
661     token.DOUBLESTAR,
662 }
663 STARS = {token.STAR, token.DOUBLESTAR}
664 VARARGS_PARENTS = {
665     syms.arglist,
666     syms.argument,  # double star in arglist
667     syms.trailer,  # single argument to call
668     syms.typedargslist,
669     syms.varargslist,  # lambdas
670 }
671 UNPACKING_PARENTS = {
672     syms.atom,  # single element of a list or set literal
673     syms.dictsetmaker,
674     syms.listmaker,
675     syms.testlist_gexp,
676 }
677 TEST_DESCENDANTS = {
678     syms.test,
679     syms.lambdef,
680     syms.or_test,
681     syms.and_test,
682     syms.not_test,
683     syms.comparison,
684     syms.star_expr,
685     syms.expr,
686     syms.xor_expr,
687     syms.and_expr,
688     syms.shift_expr,
689     syms.arith_expr,
690     syms.trailer,
691     syms.term,
692     syms.power,
693 }
694 ASSIGNMENTS = {
695     "=",
696     "+=",
697     "-=",
698     "*=",
699     "@=",
700     "/=",
701     "%=",
702     "&=",
703     "|=",
704     "^=",
705     "<<=",
706     ">>=",
707     "**=",
708     "//=",
709 }
710 COMPREHENSION_PRIORITY = 20
711 COMMA_PRIORITY = 18
712 TERNARY_PRIORITY = 16
713 LOGIC_PRIORITY = 14
714 STRING_PRIORITY = 12
715 COMPARATOR_PRIORITY = 10
716 MATH_PRIORITIES = {
717     token.VBAR: 9,
718     token.CIRCUMFLEX: 8,
719     token.AMPER: 7,
720     token.LEFTSHIFT: 6,
721     token.RIGHTSHIFT: 6,
722     token.PLUS: 5,
723     token.MINUS: 5,
724     token.STAR: 4,
725     token.SLASH: 4,
726     token.DOUBLESLASH: 4,
727     token.PERCENT: 4,
728     token.AT: 4,
729     token.TILDE: 3,
730     token.DOUBLESTAR: 2,
731 }
732 DOT_PRIORITY = 1
733
734
735 @dataclass
736 class BracketTracker:
737     """Keeps track of brackets on a line."""
738
739     depth: int = 0
740     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
741     delimiters: Dict[LeafID, Priority] = Factory(dict)
742     previous: Optional[Leaf] = None
743     _for_loop_variable: int = 0
744     _lambda_arguments: int = 0
745
746     def mark(self, leaf: Leaf) -> None:
747         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
748
749         All leaves receive an int `bracket_depth` field that stores how deep
750         within brackets a given leaf is. 0 means there are no enclosing brackets
751         that started on this line.
752
753         If a leaf is itself a closing bracket, it receives an `opening_bracket`
754         field that it forms a pair with. This is a one-directional link to
755         avoid reference cycles.
756
757         If a leaf is a delimiter (a token on which Black can split the line if
758         needed) and it's on depth 0, its `id()` is stored in the tracker's
759         `delimiters` field.
760         """
761         if leaf.type == token.COMMENT:
762             return
763
764         self.maybe_decrement_after_for_loop_variable(leaf)
765         self.maybe_decrement_after_lambda_arguments(leaf)
766         if leaf.type in CLOSING_BRACKETS:
767             self.depth -= 1
768             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
769             leaf.opening_bracket = opening_bracket
770         leaf.bracket_depth = self.depth
771         if self.depth == 0:
772             delim = is_split_before_delimiter(leaf, self.previous)
773             if delim and self.previous is not None:
774                 self.delimiters[id(self.previous)] = delim
775             else:
776                 delim = is_split_after_delimiter(leaf, self.previous)
777                 if delim:
778                     self.delimiters[id(leaf)] = delim
779         if leaf.type in OPENING_BRACKETS:
780             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
781             self.depth += 1
782         self.previous = leaf
783         self.maybe_increment_lambda_arguments(leaf)
784         self.maybe_increment_for_loop_variable(leaf)
785
786     def any_open_brackets(self) -> bool:
787         """Return True if there is an yet unmatched open bracket on the line."""
788         return bool(self.bracket_match)
789
790     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
791         """Return the highest priority of a delimiter found on the line.
792
793         Values are consistent with what `is_split_*_delimiter()` return.
794         Raises ValueError on no delimiters.
795         """
796         return max(v for k, v in self.delimiters.items() if k not in exclude)
797
798     def delimiter_count_with_priority(self, priority: int = 0) -> int:
799         """Return the number of delimiters with the given `priority`.
800
801         If no `priority` is passed, defaults to max priority on the line.
802         """
803         if not self.delimiters:
804             return 0
805
806         priority = priority or self.max_delimiter_priority()
807         return sum(1 for p in self.delimiters.values() if p == priority)
808
809     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
810         """In a for loop, or comprehension, the variables are often unpacks.
811
812         To avoid splitting on the comma in this situation, increase the depth of
813         tokens between `for` and `in`.
814         """
815         if leaf.type == token.NAME and leaf.value == "for":
816             self.depth += 1
817             self._for_loop_variable += 1
818             return True
819
820         return False
821
822     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
823         """See `maybe_increment_for_loop_variable` above for explanation."""
824         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
825             self.depth -= 1
826             self._for_loop_variable -= 1
827             return True
828
829         return False
830
831     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
832         """In a lambda expression, there might be more than one argument.
833
834         To avoid splitting on the comma in this situation, increase the depth of
835         tokens between `lambda` and `:`.
836         """
837         if leaf.type == token.NAME and leaf.value == "lambda":
838             self.depth += 1
839             self._lambda_arguments += 1
840             return True
841
842         return False
843
844     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
845         """See `maybe_increment_lambda_arguments` above for explanation."""
846         if self._lambda_arguments and leaf.type == token.COLON:
847             self.depth -= 1
848             self._lambda_arguments -= 1
849             return True
850
851         return False
852
853     def get_open_lsqb(self) -> Optional[Leaf]:
854         """Return the most recent opening square bracket (if any)."""
855         return self.bracket_match.get((self.depth - 1, token.RSQB))
856
857
858 @dataclass
859 class Line:
860     """Holds leaves and comments. Can be printed with `str(line)`."""
861
862     depth: int = 0
863     leaves: List[Leaf] = Factory(list)
864     comments: List[Tuple[Index, Leaf]] = Factory(list)
865     bracket_tracker: BracketTracker = Factory(BracketTracker)
866     inside_brackets: bool = False
867     should_explode: bool = False
868
869     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
870         """Add a new `leaf` to the end of the line.
871
872         Unless `preformatted` is True, the `leaf` will receive a new consistent
873         whitespace prefix and metadata applied by :class:`BracketTracker`.
874         Trailing commas are maybe removed, unpacked for loop variables are
875         demoted from being delimiters.
876
877         Inline comments are put aside.
878         """
879         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
880         if not has_value:
881             return
882
883         if token.COLON == leaf.type and self.is_class_paren_empty:
884             del self.leaves[-2:]
885         if self.leaves and not preformatted:
886             # Note: at this point leaf.prefix should be empty except for
887             # imports, for which we only preserve newlines.
888             leaf.prefix += whitespace(
889                 leaf, complex_subscript=self.is_complex_subscript(leaf)
890             )
891         if self.inside_brackets or not preformatted:
892             self.bracket_tracker.mark(leaf)
893             self.maybe_remove_trailing_comma(leaf)
894         if not self.append_comment(leaf):
895             self.leaves.append(leaf)
896
897     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
898         """Like :func:`append()` but disallow invalid standalone comment structure.
899
900         Raises ValueError when any `leaf` is appended after a standalone comment
901         or when a standalone comment is not the first leaf on the line.
902         """
903         if self.bracket_tracker.depth == 0:
904             if self.is_comment:
905                 raise ValueError("cannot append to standalone comments")
906
907             if self.leaves and leaf.type == STANDALONE_COMMENT:
908                 raise ValueError(
909                     "cannot append standalone comments to a populated line"
910                 )
911
912         self.append(leaf, preformatted=preformatted)
913
914     @property
915     def is_comment(self) -> bool:
916         """Is this line a standalone comment?"""
917         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
918
919     @property
920     def is_decorator(self) -> bool:
921         """Is this line a decorator?"""
922         return bool(self) and self.leaves[0].type == token.AT
923
924     @property
925     def is_import(self) -> bool:
926         """Is this an import line?"""
927         return bool(self) and is_import(self.leaves[0])
928
929     @property
930     def is_class(self) -> bool:
931         """Is this line a class definition?"""
932         return (
933             bool(self)
934             and self.leaves[0].type == token.NAME
935             and self.leaves[0].value == "class"
936         )
937
938     @property
939     def is_stub_class(self) -> bool:
940         """Is this line a class definition with a body consisting only of "..."?"""
941         return self.is_class and self.leaves[-3:] == [
942             Leaf(token.DOT, ".") for _ in range(3)
943         ]
944
945     @property
946     def is_def(self) -> bool:
947         """Is this a function definition? (Also returns True for async defs.)"""
948         try:
949             first_leaf = self.leaves[0]
950         except IndexError:
951             return False
952
953         try:
954             second_leaf: Optional[Leaf] = self.leaves[1]
955         except IndexError:
956             second_leaf = None
957         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
958             first_leaf.type == token.ASYNC
959             and second_leaf is not None
960             and second_leaf.type == token.NAME
961             and second_leaf.value == "def"
962         )
963
964     @property
965     def is_class_paren_empty(self) -> bool:
966         """Is this a class with no base classes but using parentheses?
967
968         Those are unnecessary and should be removed.
969         """
970         return (
971             bool(self)
972             and len(self.leaves) == 4
973             and self.is_class
974             and self.leaves[2].type == token.LPAR
975             and self.leaves[2].value == "("
976             and self.leaves[3].type == token.RPAR
977             and self.leaves[3].value == ")"
978         )
979
980     @property
981     def is_triple_quoted_string(self) -> bool:
982         """Is the line a triple quoted string?"""
983         return (
984             bool(self)
985             and self.leaves[0].type == token.STRING
986             and self.leaves[0].value.startswith(('"""', "'''"))
987         )
988
989     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
990         """If so, needs to be split before emitting."""
991         for leaf in self.leaves:
992             if leaf.type == STANDALONE_COMMENT:
993                 if leaf.bracket_depth <= depth_limit:
994                     return True
995
996         return False
997
998     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
999         """Remove trailing comma if there is one and it's safe."""
1000         if not (
1001             self.leaves
1002             and self.leaves[-1].type == token.COMMA
1003             and closing.type in CLOSING_BRACKETS
1004         ):
1005             return False
1006
1007         if closing.type == token.RBRACE:
1008             self.remove_trailing_comma()
1009             return True
1010
1011         if closing.type == token.RSQB:
1012             comma = self.leaves[-1]
1013             if comma.parent and comma.parent.type == syms.listmaker:
1014                 self.remove_trailing_comma()
1015                 return True
1016
1017         # For parens let's check if it's safe to remove the comma.
1018         # Imports are always safe.
1019         if self.is_import:
1020             self.remove_trailing_comma()
1021             return True
1022
1023         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1024         # change a tuple into a different type by removing the comma.
1025         depth = closing.bracket_depth + 1
1026         commas = 0
1027         opening = closing.opening_bracket
1028         for _opening_index, leaf in enumerate(self.leaves):
1029             if leaf is opening:
1030                 break
1031
1032         else:
1033             return False
1034
1035         for leaf in self.leaves[_opening_index + 1 :]:
1036             if leaf is closing:
1037                 break
1038
1039             bracket_depth = leaf.bracket_depth
1040             if bracket_depth == depth and leaf.type == token.COMMA:
1041                 commas += 1
1042                 if leaf.parent and leaf.parent.type == syms.arglist:
1043                     commas += 1
1044                     break
1045
1046         if commas > 1:
1047             self.remove_trailing_comma()
1048             return True
1049
1050         return False
1051
1052     def append_comment(self, comment: Leaf) -> bool:
1053         """Add an inline or standalone comment to the line."""
1054         if (
1055             comment.type == STANDALONE_COMMENT
1056             and self.bracket_tracker.any_open_brackets()
1057         ):
1058             comment.prefix = ""
1059             return False
1060
1061         if comment.type != token.COMMENT:
1062             return False
1063
1064         after = len(self.leaves) - 1
1065         if after == -1:
1066             comment.type = STANDALONE_COMMENT
1067             comment.prefix = ""
1068             return False
1069
1070         else:
1071             self.comments.append((after, comment))
1072             return True
1073
1074     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1075         """Generate comments that should appear directly after `leaf`.
1076
1077         Provide a non-negative leaf `_index` to speed up the function.
1078         """
1079         if _index == -1:
1080             for _index, _leaf in enumerate(self.leaves):
1081                 if leaf is _leaf:
1082                     break
1083
1084             else:
1085                 return
1086
1087         for index, comment_after in self.comments:
1088             if _index == index:
1089                 yield comment_after
1090
1091     def remove_trailing_comma(self) -> None:
1092         """Remove the trailing comma and moves the comments attached to it."""
1093         comma_index = len(self.leaves) - 1
1094         for i in range(len(self.comments)):
1095             comment_index, comment = self.comments[i]
1096             if comment_index == comma_index:
1097                 self.comments[i] = (comma_index - 1, comment)
1098         self.leaves.pop()
1099
1100     def is_complex_subscript(self, leaf: Leaf) -> bool:
1101         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1102         open_lsqb = (
1103             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1104         )
1105         if open_lsqb is None:
1106             return False
1107
1108         subscript_start = open_lsqb.next_sibling
1109         if (
1110             isinstance(subscript_start, Node)
1111             and subscript_start.type == syms.subscriptlist
1112         ):
1113             subscript_start = child_towards(subscript_start, leaf)
1114         return subscript_start is not None and any(
1115             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1116         )
1117
1118     def __str__(self) -> str:
1119         """Render the line."""
1120         if not self:
1121             return "\n"
1122
1123         indent = "    " * self.depth
1124         leaves = iter(self.leaves)
1125         first = next(leaves)
1126         res = f"{first.prefix}{indent}{first.value}"
1127         for leaf in leaves:
1128             res += str(leaf)
1129         for _, comment in self.comments:
1130             res += str(comment)
1131         return res + "\n"
1132
1133     def __bool__(self) -> bool:
1134         """Return True if the line has leaves or comments."""
1135         return bool(self.leaves or self.comments)
1136
1137
1138 class UnformattedLines(Line):
1139     """Just like :class:`Line` but stores lines which aren't reformatted."""
1140
1141     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1142         """Just add a new `leaf` to the end of the lines.
1143
1144         The `preformatted` argument is ignored.
1145
1146         Keeps track of indentation `depth`, which is useful when the user
1147         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1148         """
1149         try:
1150             list(generate_comments(leaf))
1151         except FormatOn as f_on:
1152             self.leaves.append(f_on.leaf_from_consumed(leaf))
1153             raise
1154
1155         self.leaves.append(leaf)
1156         if leaf.type == token.INDENT:
1157             self.depth += 1
1158         elif leaf.type == token.DEDENT:
1159             self.depth -= 1
1160
1161     def __str__(self) -> str:
1162         """Render unformatted lines from leaves which were added with `append()`.
1163
1164         `depth` is not used for indentation in this case.
1165         """
1166         if not self:
1167             return "\n"
1168
1169         res = ""
1170         for leaf in self.leaves:
1171             res += str(leaf)
1172         return res
1173
1174     def append_comment(self, comment: Leaf) -> bool:
1175         """Not implemented in this class. Raises `NotImplementedError`."""
1176         raise NotImplementedError("Unformatted lines don't store comments separately.")
1177
1178     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1179         """Does nothing and returns False."""
1180         return False
1181
1182     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1183         """Does nothing and returns False."""
1184         return False
1185
1186
1187 @dataclass
1188 class EmptyLineTracker:
1189     """Provides a stateful method that returns the number of potential extra
1190     empty lines needed before and after the currently processed line.
1191
1192     Note: this tracker works on lines that haven't been split yet.  It assumes
1193     the prefix of the first leaf consists of optional newlines.  Those newlines
1194     are consumed by `maybe_empty_lines()` and included in the computation.
1195     """
1196
1197     is_pyi: bool = False
1198     previous_line: Optional[Line] = None
1199     previous_after: int = 0
1200     previous_defs: List[int] = Factory(list)
1201
1202     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1203         """Return the number of extra empty lines before and after the `current_line`.
1204
1205         This is for separating `def`, `async def` and `class` with extra empty
1206         lines (two on module-level).
1207         """
1208         if isinstance(current_line, UnformattedLines):
1209             return 0, 0
1210
1211         before, after = self._maybe_empty_lines(current_line)
1212         before -= self.previous_after
1213         self.previous_after = after
1214         self.previous_line = current_line
1215         return before, after
1216
1217     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1218         max_allowed = 1
1219         if current_line.depth == 0:
1220             max_allowed = 1 if self.is_pyi else 2
1221         if current_line.leaves:
1222             # Consume the first leaf's extra newlines.
1223             first_leaf = current_line.leaves[0]
1224             before = first_leaf.prefix.count("\n")
1225             before = min(before, max_allowed)
1226             first_leaf.prefix = ""
1227         else:
1228             before = 0
1229         depth = current_line.depth
1230         while self.previous_defs and self.previous_defs[-1] >= depth:
1231             self.previous_defs.pop()
1232             if self.is_pyi:
1233                 before = 0 if depth else 1
1234             else:
1235                 before = 1 if depth else 2
1236         is_decorator = current_line.is_decorator
1237         if is_decorator or current_line.is_def or current_line.is_class:
1238             if not is_decorator:
1239                 self.previous_defs.append(depth)
1240             if self.previous_line is None:
1241                 # Don't insert empty lines before the first line in the file.
1242                 return 0, 0
1243
1244             if self.previous_line.is_decorator:
1245                 return 0, 0
1246
1247             if self.previous_line.depth < current_line.depth and (
1248                 self.previous_line.is_class or self.previous_line.is_def
1249             ):
1250                 return 0, 0
1251
1252             if (
1253                 self.previous_line.is_comment
1254                 and self.previous_line.depth == current_line.depth
1255                 and before == 0
1256             ):
1257                 return 0, 0
1258
1259             if self.is_pyi:
1260                 if self.previous_line.depth > current_line.depth:
1261                     newlines = 1
1262                 elif current_line.is_class or self.previous_line.is_class:
1263                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1264                         newlines = 0
1265                     else:
1266                         newlines = 1
1267                 else:
1268                     newlines = 0
1269             else:
1270                 newlines = 2
1271             if current_line.depth and newlines:
1272                 newlines -= 1
1273             return newlines, 0
1274
1275         if (
1276             self.previous_line
1277             and self.previous_line.is_import
1278             and not current_line.is_import
1279             and depth == self.previous_line.depth
1280         ):
1281             return (before or 1), 0
1282
1283         if (
1284             self.previous_line
1285             and self.previous_line.is_class
1286             and current_line.is_triple_quoted_string
1287         ):
1288             return before, 1
1289
1290         return before, 0
1291
1292
1293 @dataclass
1294 class LineGenerator(Visitor[Line]):
1295     """Generates reformatted Line objects.  Empty lines are not emitted.
1296
1297     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1298     in ways that will no longer stringify to valid Python code on the tree.
1299     """
1300
1301     is_pyi: bool = False
1302     normalize_strings: bool = True
1303     current_line: Line = Factory(Line)
1304     remove_u_prefix: bool = False
1305
1306     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1307         """Generate a line.
1308
1309         If the line is empty, only emit if it makes sense.
1310         If the line is too long, split it first and then generate.
1311
1312         If any lines were generated, set up a new current_line.
1313         """
1314         if not self.current_line:
1315             if self.current_line.__class__ == type:
1316                 self.current_line.depth += indent
1317             else:
1318                 self.current_line = type(depth=self.current_line.depth + indent)
1319             return  # Line is empty, don't emit. Creating a new one unnecessary.
1320
1321         complete_line = self.current_line
1322         self.current_line = type(depth=complete_line.depth + indent)
1323         yield complete_line
1324
1325     def visit(self, node: LN) -> Iterator[Line]:
1326         """Main method to visit `node` and its children.
1327
1328         Yields :class:`Line` objects.
1329         """
1330         if isinstance(self.current_line, UnformattedLines):
1331             # File contained `# fmt: off`
1332             yield from self.visit_unformatted(node)
1333
1334         else:
1335             yield from super().visit(node)
1336
1337     def visit_default(self, node: LN) -> Iterator[Line]:
1338         """Default `visit_*()` implementation. Recurses to children of `node`."""
1339         if isinstance(node, Leaf):
1340             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1341             try:
1342                 for comment in generate_comments(node):
1343                     if any_open_brackets:
1344                         # any comment within brackets is subject to splitting
1345                         self.current_line.append(comment)
1346                     elif comment.type == token.COMMENT:
1347                         # regular trailing comment
1348                         self.current_line.append(comment)
1349                         yield from self.line()
1350
1351                     else:
1352                         # regular standalone comment
1353                         yield from self.line()
1354
1355                         self.current_line.append(comment)
1356                         yield from self.line()
1357
1358             except FormatOff as f_off:
1359                 f_off.trim_prefix(node)
1360                 yield from self.line(type=UnformattedLines)
1361                 yield from self.visit(node)
1362
1363             except FormatOn as f_on:
1364                 # This only happens here if somebody says "fmt: on" multiple
1365                 # times in a row.
1366                 f_on.trim_prefix(node)
1367                 yield from self.visit_default(node)
1368
1369             else:
1370                 normalize_prefix(node, inside_brackets=any_open_brackets)
1371                 if self.normalize_strings and node.type == token.STRING:
1372                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1373                     normalize_string_quotes(node)
1374                 if node.type not in WHITESPACE:
1375                     self.current_line.append(node)
1376         yield from super().visit_default(node)
1377
1378     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1379         """Increase indentation level, maybe yield a line."""
1380         # In blib2to3 INDENT never holds comments.
1381         yield from self.line(+1)
1382         yield from self.visit_default(node)
1383
1384     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1385         """Decrease indentation level, maybe yield a line."""
1386         # The current line might still wait for trailing comments.  At DEDENT time
1387         # there won't be any (they would be prefixes on the preceding NEWLINE).
1388         # Emit the line then.
1389         yield from self.line()
1390
1391         # While DEDENT has no value, its prefix may contain standalone comments
1392         # that belong to the current indentation level.  Get 'em.
1393         yield from self.visit_default(node)
1394
1395         # Finally, emit the dedent.
1396         yield from self.line(-1)
1397
1398     def visit_stmt(
1399         self, node: Node, keywords: Set[str], parens: Set[str]
1400     ) -> Iterator[Line]:
1401         """Visit a statement.
1402
1403         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1404         `def`, `with`, `class`, `assert` and assignments.
1405
1406         The relevant Python language `keywords` for a given statement will be
1407         NAME leaves within it. This methods puts those on a separate line.
1408
1409         `parens` holds a set of string leaf values immediately after which
1410         invisible parens should be put.
1411         """
1412         normalize_invisible_parens(node, parens_after=parens)
1413         for child in node.children:
1414             if child.type == token.NAME and child.value in keywords:  # type: ignore
1415                 yield from self.line()
1416
1417             yield from self.visit(child)
1418
1419     def visit_suite(self, node: Node) -> Iterator[Line]:
1420         """Visit a suite."""
1421         if self.is_pyi and is_stub_suite(node):
1422             yield from self.visit(node.children[2])
1423         else:
1424             yield from self.visit_default(node)
1425
1426     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1427         """Visit a statement without nested statements."""
1428         is_suite_like = node.parent and node.parent.type in STATEMENT
1429         if is_suite_like:
1430             if self.is_pyi and is_stub_body(node):
1431                 yield from self.visit_default(node)
1432             else:
1433                 yield from self.line(+1)
1434                 yield from self.visit_default(node)
1435                 yield from self.line(-1)
1436
1437         else:
1438             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1439                 yield from self.line()
1440             yield from self.visit_default(node)
1441
1442     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1443         """Visit `async def`, `async for`, `async with`."""
1444         yield from self.line()
1445
1446         children = iter(node.children)
1447         for child in children:
1448             yield from self.visit(child)
1449
1450             if child.type == token.ASYNC:
1451                 break
1452
1453         internal_stmt = next(children)
1454         for child in internal_stmt.children:
1455             yield from self.visit(child)
1456
1457     def visit_decorators(self, node: Node) -> Iterator[Line]:
1458         """Visit decorators."""
1459         for child in node.children:
1460             yield from self.line()
1461             yield from self.visit(child)
1462
1463     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1464         """Remove a semicolon and put the other statement on a separate line."""
1465         yield from self.line()
1466
1467     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1468         """End of file. Process outstanding comments and end with a newline."""
1469         yield from self.visit_default(leaf)
1470         yield from self.line()
1471
1472     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1473         """Used when file contained a `# fmt: off`."""
1474         if isinstance(node, Node):
1475             for child in node.children:
1476                 yield from self.visit(child)
1477
1478         else:
1479             try:
1480                 self.current_line.append(node)
1481             except FormatOn as f_on:
1482                 f_on.trim_prefix(node)
1483                 yield from self.line()
1484                 yield from self.visit(node)
1485
1486             if node.type == token.ENDMARKER:
1487                 # somebody decided not to put a final `# fmt: on`
1488                 yield from self.line()
1489
1490     def __attrs_post_init__(self) -> None:
1491         """You are in a twisty little maze of passages."""
1492         v = self.visit_stmt
1493         Ø: Set[str] = set()
1494         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1495         self.visit_if_stmt = partial(
1496             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1497         )
1498         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1499         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1500         self.visit_try_stmt = partial(
1501             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1502         )
1503         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1504         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1505         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1506         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1507         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1508         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1509         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1510         self.visit_async_funcdef = self.visit_async_stmt
1511         self.visit_decorated = self.visit_decorators
1512
1513
1514 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1515 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1516 OPENING_BRACKETS = set(BRACKET.keys())
1517 CLOSING_BRACKETS = set(BRACKET.values())
1518 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1519 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1520
1521
1522 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1523     """Return whitespace prefix if needed for the given `leaf`.
1524
1525     `complex_subscript` signals whether the given leaf is part of a subscription
1526     which has non-trivial arguments, like arithmetic expressions or function calls.
1527     """
1528     NO = ""
1529     SPACE = " "
1530     DOUBLESPACE = "  "
1531     t = leaf.type
1532     p = leaf.parent
1533     v = leaf.value
1534     if t in ALWAYS_NO_SPACE:
1535         return NO
1536
1537     if t == token.COMMENT:
1538         return DOUBLESPACE
1539
1540     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1541     if t == token.COLON and p.type not in {
1542         syms.subscript,
1543         syms.subscriptlist,
1544         syms.sliceop,
1545     }:
1546         return NO
1547
1548     prev = leaf.prev_sibling
1549     if not prev:
1550         prevp = preceding_leaf(p)
1551         if not prevp or prevp.type in OPENING_BRACKETS:
1552             return NO
1553
1554         if t == token.COLON:
1555             if prevp.type == token.COLON:
1556                 return NO
1557
1558             elif prevp.type != token.COMMA and not complex_subscript:
1559                 return NO
1560
1561             return SPACE
1562
1563         if prevp.type == token.EQUAL:
1564             if prevp.parent:
1565                 if prevp.parent.type in {
1566                     syms.arglist,
1567                     syms.argument,
1568                     syms.parameters,
1569                     syms.varargslist,
1570                 }:
1571                     return NO
1572
1573                 elif prevp.parent.type == syms.typedargslist:
1574                     # A bit hacky: if the equal sign has whitespace, it means we
1575                     # previously found it's a typed argument.  So, we're using
1576                     # that, too.
1577                     return prevp.prefix
1578
1579         elif prevp.type in STARS:
1580             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1581                 return NO
1582
1583         elif prevp.type == token.COLON:
1584             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1585                 return SPACE if complex_subscript else NO
1586
1587         elif (
1588             prevp.parent
1589             and prevp.parent.type == syms.factor
1590             and prevp.type in MATH_OPERATORS
1591         ):
1592             return NO
1593
1594         elif (
1595             prevp.type == token.RIGHTSHIFT
1596             and prevp.parent
1597             and prevp.parent.type == syms.shift_expr
1598             and prevp.prev_sibling
1599             and prevp.prev_sibling.type == token.NAME
1600             and prevp.prev_sibling.value == "print"  # type: ignore
1601         ):
1602             # Python 2 print chevron
1603             return NO
1604
1605     elif prev.type in OPENING_BRACKETS:
1606         return NO
1607
1608     if p.type in {syms.parameters, syms.arglist}:
1609         # untyped function signatures or calls
1610         if not prev or prev.type != token.COMMA:
1611             return NO
1612
1613     elif p.type == syms.varargslist:
1614         # lambdas
1615         if prev and prev.type != token.COMMA:
1616             return NO
1617
1618     elif p.type == syms.typedargslist:
1619         # typed function signatures
1620         if not prev:
1621             return NO
1622
1623         if t == token.EQUAL:
1624             if prev.type != syms.tname:
1625                 return NO
1626
1627         elif prev.type == token.EQUAL:
1628             # A bit hacky: if the equal sign has whitespace, it means we
1629             # previously found it's a typed argument.  So, we're using that, too.
1630             return prev.prefix
1631
1632         elif prev.type != token.COMMA:
1633             return NO
1634
1635     elif p.type == syms.tname:
1636         # type names
1637         if not prev:
1638             prevp = preceding_leaf(p)
1639             if not prevp or prevp.type != token.COMMA:
1640                 return NO
1641
1642     elif p.type == syms.trailer:
1643         # attributes and calls
1644         if t == token.LPAR or t == token.RPAR:
1645             return NO
1646
1647         if not prev:
1648             if t == token.DOT:
1649                 prevp = preceding_leaf(p)
1650                 if not prevp or prevp.type != token.NUMBER:
1651                     return NO
1652
1653             elif t == token.LSQB:
1654                 return NO
1655
1656         elif prev.type != token.COMMA:
1657             return NO
1658
1659     elif p.type == syms.argument:
1660         # single argument
1661         if t == token.EQUAL:
1662             return NO
1663
1664         if not prev:
1665             prevp = preceding_leaf(p)
1666             if not prevp or prevp.type == token.LPAR:
1667                 return NO
1668
1669         elif prev.type in {token.EQUAL} | STARS:
1670             return NO
1671
1672     elif p.type == syms.decorator:
1673         # decorators
1674         return NO
1675
1676     elif p.type == syms.dotted_name:
1677         if prev:
1678             return NO
1679
1680         prevp = preceding_leaf(p)
1681         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1682             return NO
1683
1684     elif p.type == syms.classdef:
1685         if t == token.LPAR:
1686             return NO
1687
1688         if prev and prev.type == token.LPAR:
1689             return NO
1690
1691     elif p.type in {syms.subscript, syms.sliceop}:
1692         # indexing
1693         if not prev:
1694             assert p.parent is not None, "subscripts are always parented"
1695             if p.parent.type == syms.subscriptlist:
1696                 return SPACE
1697
1698             return NO
1699
1700         elif not complex_subscript:
1701             return NO
1702
1703     elif p.type == syms.atom:
1704         if prev and t == token.DOT:
1705             # dots, but not the first one.
1706             return NO
1707
1708     elif p.type == syms.dictsetmaker:
1709         # dict unpacking
1710         if prev and prev.type == token.DOUBLESTAR:
1711             return NO
1712
1713     elif p.type in {syms.factor, syms.star_expr}:
1714         # unary ops
1715         if not prev:
1716             prevp = preceding_leaf(p)
1717             if not prevp or prevp.type in OPENING_BRACKETS:
1718                 return NO
1719
1720             prevp_parent = prevp.parent
1721             assert prevp_parent is not None
1722             if prevp.type == token.COLON and prevp_parent.type in {
1723                 syms.subscript,
1724                 syms.sliceop,
1725             }:
1726                 return NO
1727
1728             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1729                 return NO
1730
1731         elif t == token.NAME or t == token.NUMBER:
1732             return NO
1733
1734     elif p.type == syms.import_from:
1735         if t == token.DOT:
1736             if prev and prev.type == token.DOT:
1737                 return NO
1738
1739         elif t == token.NAME:
1740             if v == "import":
1741                 return SPACE
1742
1743             if prev and prev.type == token.DOT:
1744                 return NO
1745
1746     elif p.type == syms.sliceop:
1747         return NO
1748
1749     return SPACE
1750
1751
1752 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1753     """Return the first leaf that precedes `node`, if any."""
1754     while node:
1755         res = node.prev_sibling
1756         if res:
1757             if isinstance(res, Leaf):
1758                 return res
1759
1760             try:
1761                 return list(res.leaves())[-1]
1762
1763             except IndexError:
1764                 return None
1765
1766         node = node.parent
1767     return None
1768
1769
1770 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1771     """Return the child of `ancestor` that contains `descendant`."""
1772     node: Optional[LN] = descendant
1773     while node and node.parent != ancestor:
1774         node = node.parent
1775     return node
1776
1777
1778 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1779     """Return the priority of the `leaf` delimiter, given a line break after it.
1780
1781     The delimiter priorities returned here are from those delimiters that would
1782     cause a line break after themselves.
1783
1784     Higher numbers are higher priority.
1785     """
1786     if leaf.type == token.COMMA:
1787         return COMMA_PRIORITY
1788
1789     return 0
1790
1791
1792 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1793     """Return the priority of the `leaf` delimiter, given a line before after it.
1794
1795     The delimiter priorities returned here are from those delimiters that would
1796     cause a line break before themselves.
1797
1798     Higher numbers are higher priority.
1799     """
1800     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1801         # * and ** might also be MATH_OPERATORS but in this case they are not.
1802         # Don't treat them as a delimiter.
1803         return 0
1804
1805     if (
1806         leaf.type == token.DOT
1807         and leaf.parent
1808         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1809         and (previous is None or previous.type in CLOSING_BRACKETS)
1810     ):
1811         return DOT_PRIORITY
1812
1813     if (
1814         leaf.type in MATH_OPERATORS
1815         and leaf.parent
1816         and leaf.parent.type not in {syms.factor, syms.star_expr}
1817     ):
1818         return MATH_PRIORITIES[leaf.type]
1819
1820     if leaf.type in COMPARATORS:
1821         return COMPARATOR_PRIORITY
1822
1823     if (
1824         leaf.type == token.STRING
1825         and previous is not None
1826         and previous.type == token.STRING
1827     ):
1828         return STRING_PRIORITY
1829
1830     if leaf.type != token.NAME:
1831         return 0
1832
1833     if (
1834         leaf.value == "for"
1835         and leaf.parent
1836         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1837     ):
1838         return COMPREHENSION_PRIORITY
1839
1840     if (
1841         leaf.value == "if"
1842         and leaf.parent
1843         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1844     ):
1845         return COMPREHENSION_PRIORITY
1846
1847     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1848         return TERNARY_PRIORITY
1849
1850     if leaf.value == "is":
1851         return COMPARATOR_PRIORITY
1852
1853     if (
1854         leaf.value == "in"
1855         and leaf.parent
1856         and leaf.parent.type in {syms.comp_op, syms.comparison}
1857         and not (
1858             previous is not None
1859             and previous.type == token.NAME
1860             and previous.value == "not"
1861         )
1862     ):
1863         return COMPARATOR_PRIORITY
1864
1865     if (
1866         leaf.value == "not"
1867         and leaf.parent
1868         and leaf.parent.type == syms.comp_op
1869         and not (
1870             previous is not None
1871             and previous.type == token.NAME
1872             and previous.value == "is"
1873         )
1874     ):
1875         return COMPARATOR_PRIORITY
1876
1877     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1878         return LOGIC_PRIORITY
1879
1880     return 0
1881
1882
1883 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1884     """Clean the prefix of the `leaf` and generate comments from it, if any.
1885
1886     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1887     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1888     move because it does away with modifying the grammar to include all the
1889     possible places in which comments can be placed.
1890
1891     The sad consequence for us though is that comments don't "belong" anywhere.
1892     This is why this function generates simple parentless Leaf objects for
1893     comments.  We simply don't know what the correct parent should be.
1894
1895     No matter though, we can live without this.  We really only need to
1896     differentiate between inline and standalone comments.  The latter don't
1897     share the line with any code.
1898
1899     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1900     are emitted with a fake STANDALONE_COMMENT token identifier.
1901     """
1902     p = leaf.prefix
1903     if not p:
1904         return
1905
1906     if "#" not in p:
1907         return
1908
1909     consumed = 0
1910     nlines = 0
1911     for index, line in enumerate(p.split("\n")):
1912         consumed += len(line) + 1  # adding the length of the split '\n'
1913         line = line.lstrip()
1914         if not line:
1915             nlines += 1
1916         if not line.startswith("#"):
1917             continue
1918
1919         if index == 0 and leaf.type != token.ENDMARKER:
1920             comment_type = token.COMMENT  # simple trailing comment
1921         else:
1922             comment_type = STANDALONE_COMMENT
1923         comment = make_comment(line)
1924         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1925
1926         if comment in {"# fmt: on", "# yapf: enable"}:
1927             raise FormatOn(consumed)
1928
1929         if comment in {"# fmt: off", "# yapf: disable"}:
1930             if comment_type == STANDALONE_COMMENT:
1931                 raise FormatOff(consumed)
1932
1933             prev = preceding_leaf(leaf)
1934             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1935                 raise FormatOff(consumed)
1936
1937         nlines = 0
1938
1939
1940 def make_comment(content: str) -> str:
1941     """Return a consistently formatted comment from the given `content` string.
1942
1943     All comments (except for "##", "#!", "#:") should have a single space between
1944     the hash sign and the content.
1945
1946     If `content` didn't start with a hash sign, one is provided.
1947     """
1948     content = content.rstrip()
1949     if not content:
1950         return "#"
1951
1952     if content[0] == "#":
1953         content = content[1:]
1954     if content and content[0] not in " !:#":
1955         content = " " + content
1956     return "#" + content
1957
1958
1959 def split_line(
1960     line: Line, line_length: int, inner: bool = False, py36: bool = False
1961 ) -> Iterator[Line]:
1962     """Split a `line` into potentially many lines.
1963
1964     They should fit in the allotted `line_length` but might not be able to.
1965     `inner` signifies that there were a pair of brackets somewhere around the
1966     current `line`, possibly transitively. This means we can fallback to splitting
1967     by delimiters if the LHS/RHS don't yield any results.
1968
1969     If `py36` is True, splitting may generate syntax that is only compatible
1970     with Python 3.6 and later.
1971     """
1972     if isinstance(line, UnformattedLines) or line.is_comment:
1973         yield line
1974         return
1975
1976     line_str = str(line).strip("\n")
1977     if not line.should_explode and is_line_short_enough(
1978         line, line_length=line_length, line_str=line_str
1979     ):
1980         yield line
1981         return
1982
1983     split_funcs: List[SplitFunc]
1984     if line.is_def:
1985         split_funcs = [left_hand_split]
1986     else:
1987
1988         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
1989             for omit in generate_trailers_to_omit(line, line_length):
1990                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
1991                 if is_line_short_enough(lines[0], line_length=line_length):
1992                     yield from lines
1993                     return
1994
1995             # All splits failed, best effort split with no omits.
1996             # This mostly happens to multiline strings that are by definition
1997             # reported as not fitting a single line.
1998             yield from right_hand_split(line, py36)
1999
2000         if line.inside_brackets:
2001             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2002         else:
2003             split_funcs = [rhs]
2004     for split_func in split_funcs:
2005         # We are accumulating lines in `result` because we might want to abort
2006         # mission and return the original line in the end, or attempt a different
2007         # split altogether.
2008         result: List[Line] = []
2009         try:
2010             for l in split_func(line, py36):
2011                 if str(l).strip("\n") == line_str:
2012                     raise CannotSplit("Split function returned an unchanged result")
2013
2014                 result.extend(
2015                     split_line(l, line_length=line_length, inner=True, py36=py36)
2016                 )
2017         except CannotSplit as cs:
2018             continue
2019
2020         else:
2021             yield from result
2022             break
2023
2024     else:
2025         yield line
2026
2027
2028 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2029     """Split line into many lines, starting with the first matching bracket pair.
2030
2031     Note: this usually looks weird, only use this for function definitions.
2032     Prefer RHS otherwise.  This is why this function is not symmetrical with
2033     :func:`right_hand_split` which also handles optional parentheses.
2034     """
2035     head = Line(depth=line.depth)
2036     body = Line(depth=line.depth + 1, inside_brackets=True)
2037     tail = Line(depth=line.depth)
2038     tail_leaves: List[Leaf] = []
2039     body_leaves: List[Leaf] = []
2040     head_leaves: List[Leaf] = []
2041     current_leaves = head_leaves
2042     matching_bracket = None
2043     for leaf in line.leaves:
2044         if (
2045             current_leaves is body_leaves
2046             and leaf.type in CLOSING_BRACKETS
2047             and leaf.opening_bracket is matching_bracket
2048         ):
2049             current_leaves = tail_leaves if body_leaves else head_leaves
2050         current_leaves.append(leaf)
2051         if current_leaves is head_leaves:
2052             if leaf.type in OPENING_BRACKETS:
2053                 matching_bracket = leaf
2054                 current_leaves = body_leaves
2055     # Since body is a new indent level, remove spurious leading whitespace.
2056     if body_leaves:
2057         normalize_prefix(body_leaves[0], inside_brackets=True)
2058     # Build the new lines.
2059     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2060         for leaf in leaves:
2061             result.append(leaf, preformatted=True)
2062             for comment_after in line.comments_after(leaf):
2063                 result.append(comment_after, preformatted=True)
2064     bracket_split_succeeded_or_raise(head, body, tail)
2065     for result in (head, body, tail):
2066         if result:
2067             yield result
2068
2069
2070 def right_hand_split(
2071     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2072 ) -> Iterator[Line]:
2073     """Split line into many lines, starting with the last matching bracket pair.
2074
2075     If the split was by optional parentheses, attempt splitting without them, too.
2076     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2077     this split.
2078
2079     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2080     """
2081     head = Line(depth=line.depth)
2082     body = Line(depth=line.depth + 1, inside_brackets=True)
2083     tail = Line(depth=line.depth)
2084     tail_leaves: List[Leaf] = []
2085     body_leaves: List[Leaf] = []
2086     head_leaves: List[Leaf] = []
2087     current_leaves = tail_leaves
2088     opening_bracket = None
2089     closing_bracket = None
2090     for leaf in reversed(line.leaves):
2091         if current_leaves is body_leaves:
2092             if leaf is opening_bracket:
2093                 current_leaves = head_leaves if body_leaves else tail_leaves
2094         current_leaves.append(leaf)
2095         if current_leaves is tail_leaves:
2096             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2097                 opening_bracket = leaf.opening_bracket
2098                 closing_bracket = leaf
2099                 current_leaves = body_leaves
2100     tail_leaves.reverse()
2101     body_leaves.reverse()
2102     head_leaves.reverse()
2103     # Since body is a new indent level, remove spurious leading whitespace.
2104     if body_leaves:
2105         normalize_prefix(body_leaves[0], inside_brackets=True)
2106     if not head_leaves:
2107         # No `head` means the split failed. Either `tail` has all content or
2108         # the matching `opening_bracket` wasn't available on `line` anymore.
2109         raise CannotSplit("No brackets found")
2110
2111     # Build the new lines.
2112     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2113         for leaf in leaves:
2114             result.append(leaf, preformatted=True)
2115             for comment_after in line.comments_after(leaf):
2116                 result.append(comment_after, preformatted=True)
2117     bracket_split_succeeded_or_raise(head, body, tail)
2118     assert opening_bracket and closing_bracket
2119     if (
2120         # the opening bracket is an optional paren
2121         opening_bracket.type == token.LPAR
2122         and not opening_bracket.value
2123         # the closing bracket is an optional paren
2124         and closing_bracket.type == token.RPAR
2125         and not closing_bracket.value
2126         # there are no standalone comments in the body
2127         and not line.contains_standalone_comments(0)
2128         # and it's not an import (optional parens are the only thing we can split
2129         # on in this case; attempting a split without them is a waste of time)
2130         and not line.is_import
2131     ):
2132         omit = {id(closing_bracket), *omit}
2133         if can_omit_invisible_parens(body, line_length):
2134             try:
2135                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2136                 return
2137             except CannotSplit:
2138                 pass
2139
2140     ensure_visible(opening_bracket)
2141     ensure_visible(closing_bracket)
2142     body.should_explode = should_explode(body, opening_bracket)
2143     for result in (head, body, tail):
2144         if result:
2145             yield result
2146
2147
2148 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2149     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2150
2151     Do nothing otherwise.
2152
2153     A left- or right-hand split is based on a pair of brackets. Content before
2154     (and including) the opening bracket is left on one line, content inside the
2155     brackets is put on a separate line, and finally content starting with and
2156     following the closing bracket is put on a separate line.
2157
2158     Those are called `head`, `body`, and `tail`, respectively. If the split
2159     produced the same line (all content in `head`) or ended up with an empty `body`
2160     and the `tail` is just the closing bracket, then it's considered failed.
2161     """
2162     tail_len = len(str(tail).strip())
2163     if not body:
2164         if tail_len == 0:
2165             raise CannotSplit("Splitting brackets produced the same line")
2166
2167         elif tail_len < 3:
2168             raise CannotSplit(
2169                 f"Splitting brackets on an empty body to save "
2170                 f"{tail_len} characters is not worth it"
2171             )
2172
2173
2174 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2175     """Normalize prefix of the first leaf in every line returned by `split_func`.
2176
2177     This is a decorator over relevant split functions.
2178     """
2179
2180     @wraps(split_func)
2181     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2182         for l in split_func(line, py36):
2183             normalize_prefix(l.leaves[0], inside_brackets=True)
2184             yield l
2185
2186     return split_wrapper
2187
2188
2189 @dont_increase_indentation
2190 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2191     """Split according to delimiters of the highest priority.
2192
2193     If `py36` is True, the split will add trailing commas also in function
2194     signatures that contain `*` and `**`.
2195     """
2196     try:
2197         last_leaf = line.leaves[-1]
2198     except IndexError:
2199         raise CannotSplit("Line empty")
2200
2201     bt = line.bracket_tracker
2202     try:
2203         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2204     except ValueError:
2205         raise CannotSplit("No delimiters found")
2206
2207     if delimiter_priority == DOT_PRIORITY:
2208         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2209             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2210
2211     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2212     lowest_depth = sys.maxsize
2213     trailing_comma_safe = True
2214
2215     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2216         """Append `leaf` to current line or to new line if appending impossible."""
2217         nonlocal current_line
2218         try:
2219             current_line.append_safe(leaf, preformatted=True)
2220         except ValueError as ve:
2221             yield current_line
2222
2223             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2224             current_line.append(leaf)
2225
2226     for index, leaf in enumerate(line.leaves):
2227         yield from append_to_line(leaf)
2228
2229         for comment_after in line.comments_after(leaf, index):
2230             yield from append_to_line(comment_after)
2231
2232         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2233         if leaf.bracket_depth == lowest_depth and is_vararg(
2234             leaf, within=VARARGS_PARENTS
2235         ):
2236             trailing_comma_safe = trailing_comma_safe and py36
2237         leaf_priority = bt.delimiters.get(id(leaf))
2238         if leaf_priority == delimiter_priority:
2239             yield current_line
2240
2241             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2242     if current_line:
2243         if (
2244             trailing_comma_safe
2245             and delimiter_priority == COMMA_PRIORITY
2246             and current_line.leaves[-1].type != token.COMMA
2247             and current_line.leaves[-1].type != STANDALONE_COMMENT
2248         ):
2249             current_line.append(Leaf(token.COMMA, ","))
2250         yield current_line
2251
2252
2253 @dont_increase_indentation
2254 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2255     """Split standalone comments from the rest of the line."""
2256     if not line.contains_standalone_comments(0):
2257         raise CannotSplit("Line does not have any standalone comments")
2258
2259     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2260
2261     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2262         """Append `leaf` to current line or to new line if appending impossible."""
2263         nonlocal current_line
2264         try:
2265             current_line.append_safe(leaf, preformatted=True)
2266         except ValueError as ve:
2267             yield current_line
2268
2269             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2270             current_line.append(leaf)
2271
2272     for index, leaf in enumerate(line.leaves):
2273         yield from append_to_line(leaf)
2274
2275         for comment_after in line.comments_after(leaf, index):
2276             yield from append_to_line(comment_after)
2277
2278     if current_line:
2279         yield current_line
2280
2281
2282 def is_import(leaf: Leaf) -> bool:
2283     """Return True if the given leaf starts an import statement."""
2284     p = leaf.parent
2285     t = leaf.type
2286     v = leaf.value
2287     return bool(
2288         t == token.NAME
2289         and (
2290             (v == "import" and p and p.type == syms.import_name)
2291             or (v == "from" and p and p.type == syms.import_from)
2292         )
2293     )
2294
2295
2296 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2297     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2298     else.
2299
2300     Note: don't use backslashes for formatting or you'll lose your voting rights.
2301     """
2302     if not inside_brackets:
2303         spl = leaf.prefix.split("#")
2304         if "\\" not in spl[0]:
2305             nl_count = spl[-1].count("\n")
2306             if len(spl) > 1:
2307                 nl_count -= 1
2308             leaf.prefix = "\n" * nl_count
2309             return
2310
2311     leaf.prefix = ""
2312
2313
2314 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2315     """Make all string prefixes lowercase.
2316
2317     If remove_u_prefix is given, also removes any u prefix from the string.
2318
2319     Note: Mutates its argument.
2320     """
2321     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2322     assert match is not None, f"failed to match string {leaf.value!r}"
2323     orig_prefix = match.group(1)
2324     new_prefix = orig_prefix.lower()
2325     if remove_u_prefix:
2326         new_prefix = new_prefix.replace("u", "")
2327     leaf.value = f"{new_prefix}{match.group(2)}"
2328
2329
2330 def normalize_string_quotes(leaf: Leaf) -> None:
2331     """Prefer double quotes but only if it doesn't cause more escaping.
2332
2333     Adds or removes backslashes as appropriate. Doesn't parse and fix
2334     strings nested in f-strings (yet).
2335
2336     Note: Mutates its argument.
2337     """
2338     value = leaf.value.lstrip("furbFURB")
2339     if value[:3] == '"""':
2340         return
2341
2342     elif value[:3] == "'''":
2343         orig_quote = "'''"
2344         new_quote = '"""'
2345     elif value[0] == '"':
2346         orig_quote = '"'
2347         new_quote = "'"
2348     else:
2349         orig_quote = "'"
2350         new_quote = '"'
2351     first_quote_pos = leaf.value.find(orig_quote)
2352     if first_quote_pos == -1:
2353         return  # There's an internal error
2354
2355     prefix = leaf.value[:first_quote_pos]
2356     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2357     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2358     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2359     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2360     if "r" in prefix.casefold():
2361         if unescaped_new_quote.search(body):
2362             # There's at least one unescaped new_quote in this raw string
2363             # so converting is impossible
2364             return
2365
2366         # Do not introduce or remove backslashes in raw strings
2367         new_body = body
2368     else:
2369         # remove unnecessary quotes
2370         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2371         if body != new_body:
2372             # Consider the string without unnecessary quotes as the original
2373             body = new_body
2374             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2375         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2376         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2377     if new_quote == '"""' and new_body[-1] == '"':
2378         # edge case:
2379         new_body = new_body[:-1] + '\\"'
2380     orig_escape_count = body.count("\\")
2381     new_escape_count = new_body.count("\\")
2382     if new_escape_count > orig_escape_count:
2383         return  # Do not introduce more escaping
2384
2385     if new_escape_count == orig_escape_count and orig_quote == '"':
2386         return  # Prefer double quotes
2387
2388     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2389
2390
2391 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2392     """Make existing optional parentheses invisible or create new ones.
2393
2394     `parens_after` is a set of string leaf values immeditely after which parens
2395     should be put.
2396
2397     Standardizes on visible parentheses for single-element tuples, and keeps
2398     existing visible parentheses for other tuples and generator expressions.
2399     """
2400     try:
2401         list(generate_comments(node))
2402     except FormatOff:
2403         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2404
2405     check_lpar = False
2406     for index, child in enumerate(list(node.children)):
2407         if check_lpar:
2408             if child.type == syms.atom:
2409                 maybe_make_parens_invisible_in_atom(child)
2410             elif is_one_tuple(child):
2411                 # wrap child in visible parentheses
2412                 lpar = Leaf(token.LPAR, "(")
2413                 rpar = Leaf(token.RPAR, ")")
2414                 child.remove()
2415                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2416             elif node.type == syms.import_from:
2417                 # "import from" nodes store parentheses directly as part of
2418                 # the statement
2419                 if child.type == token.LPAR:
2420                     # make parentheses invisible
2421                     child.value = ""  # type: ignore
2422                     node.children[-1].value = ""  # type: ignore
2423                 elif child.type != token.STAR:
2424                     # insert invisible parentheses
2425                     node.insert_child(index, Leaf(token.LPAR, ""))
2426                     node.append_child(Leaf(token.RPAR, ""))
2427                 break
2428
2429             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2430                 # wrap child in invisible parentheses
2431                 lpar = Leaf(token.LPAR, "")
2432                 rpar = Leaf(token.RPAR, "")
2433                 index = child.remove() or 0
2434                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2435
2436         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2437
2438
2439 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2440     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2441     if (
2442         node.type != syms.atom
2443         or is_empty_tuple(node)
2444         or is_one_tuple(node)
2445         or is_yield(node)
2446         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2447     ):
2448         return False
2449
2450     first = node.children[0]
2451     last = node.children[-1]
2452     if first.type == token.LPAR and last.type == token.RPAR:
2453         # make parentheses invisible
2454         first.value = ""  # type: ignore
2455         last.value = ""  # type: ignore
2456         if len(node.children) > 1:
2457             maybe_make_parens_invisible_in_atom(node.children[1])
2458         return True
2459
2460     return False
2461
2462
2463 def is_empty_tuple(node: LN) -> bool:
2464     """Return True if `node` holds an empty tuple."""
2465     return (
2466         node.type == syms.atom
2467         and len(node.children) == 2
2468         and node.children[0].type == token.LPAR
2469         and node.children[1].type == token.RPAR
2470     )
2471
2472
2473 def is_one_tuple(node: LN) -> bool:
2474     """Return True if `node` holds a tuple with one element, with or without parens."""
2475     if node.type == syms.atom:
2476         if len(node.children) != 3:
2477             return False
2478
2479         lpar, gexp, rpar = node.children
2480         if not (
2481             lpar.type == token.LPAR
2482             and gexp.type == syms.testlist_gexp
2483             and rpar.type == token.RPAR
2484         ):
2485             return False
2486
2487         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2488
2489     return (
2490         node.type in IMPLICIT_TUPLE
2491         and len(node.children) == 2
2492         and node.children[1].type == token.COMMA
2493     )
2494
2495
2496 def is_yield(node: LN) -> bool:
2497     """Return True if `node` holds a `yield` or `yield from` expression."""
2498     if node.type == syms.yield_expr:
2499         return True
2500
2501     if node.type == token.NAME and node.value == "yield":  # type: ignore
2502         return True
2503
2504     if node.type != syms.atom:
2505         return False
2506
2507     if len(node.children) != 3:
2508         return False
2509
2510     lpar, expr, rpar = node.children
2511     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2512         return is_yield(expr)
2513
2514     return False
2515
2516
2517 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2518     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2519
2520     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2521     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2522     extended iterable unpacking (PEP 3132) and additional unpacking
2523     generalizations (PEP 448).
2524     """
2525     if leaf.type not in STARS or not leaf.parent:
2526         return False
2527
2528     p = leaf.parent
2529     if p.type == syms.star_expr:
2530         # Star expressions are also used as assignment targets in extended
2531         # iterable unpacking (PEP 3132).  See what its parent is instead.
2532         if not p.parent:
2533             return False
2534
2535         p = p.parent
2536
2537     return p.type in within
2538
2539
2540 def is_multiline_string(leaf: Leaf) -> bool:
2541     """Return True if `leaf` is a multiline string that actually spans many lines."""
2542     value = leaf.value.lstrip("furbFURB")
2543     return value[:3] in {'"""', "'''"} and "\n" in value
2544
2545
2546 def is_stub_suite(node: Node) -> bool:
2547     """Return True if `node` is a suite with a stub body."""
2548     if (
2549         len(node.children) != 4
2550         or node.children[0].type != token.NEWLINE
2551         or node.children[1].type != token.INDENT
2552         or node.children[3].type != token.DEDENT
2553     ):
2554         return False
2555
2556     return is_stub_body(node.children[2])
2557
2558
2559 def is_stub_body(node: LN) -> bool:
2560     """Return True if `node` is a simple statement containing an ellipsis."""
2561     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2562         return False
2563
2564     if len(node.children) != 2:
2565         return False
2566
2567     child = node.children[0]
2568     return (
2569         child.type == syms.atom
2570         and len(child.children) == 3
2571         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2572     )
2573
2574
2575 def max_delimiter_priority_in_atom(node: LN) -> int:
2576     """Return maximum delimiter priority inside `node`.
2577
2578     This is specific to atoms with contents contained in a pair of parentheses.
2579     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2580     """
2581     if node.type != syms.atom:
2582         return 0
2583
2584     first = node.children[0]
2585     last = node.children[-1]
2586     if not (first.type == token.LPAR and last.type == token.RPAR):
2587         return 0
2588
2589     bt = BracketTracker()
2590     for c in node.children[1:-1]:
2591         if isinstance(c, Leaf):
2592             bt.mark(c)
2593         else:
2594             for leaf in c.leaves():
2595                 bt.mark(leaf)
2596     try:
2597         return bt.max_delimiter_priority()
2598
2599     except ValueError:
2600         return 0
2601
2602
2603 def ensure_visible(leaf: Leaf) -> None:
2604     """Make sure parentheses are visible.
2605
2606     They could be invisible as part of some statements (see
2607     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2608     """
2609     if leaf.type == token.LPAR:
2610         leaf.value = "("
2611     elif leaf.type == token.RPAR:
2612         leaf.value = ")"
2613
2614
2615 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2616     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2617     if not (
2618         opening_bracket.parent
2619         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2620         and opening_bracket.value in "[{("
2621     ):
2622         return False
2623
2624     try:
2625         last_leaf = line.leaves[-1]
2626         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2627         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2628     except (IndexError, ValueError):
2629         return False
2630
2631     return max_priority == COMMA_PRIORITY
2632
2633
2634 def is_python36(node: Node) -> bool:
2635     """Return True if the current file is using Python 3.6+ features.
2636
2637     Currently looking for:
2638     - f-strings; and
2639     - trailing commas after * or ** in function signatures and calls.
2640     """
2641     for n in node.pre_order():
2642         if n.type == token.STRING:
2643             value_head = n.value[:2]  # type: ignore
2644             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2645                 return True
2646
2647         elif (
2648             n.type in {syms.typedargslist, syms.arglist}
2649             and n.children
2650             and n.children[-1].type == token.COMMA
2651         ):
2652             for ch in n.children:
2653                 if ch.type in STARS:
2654                     return True
2655
2656                 if ch.type == syms.argument:
2657                     for argch in ch.children:
2658                         if argch.type in STARS:
2659                             return True
2660
2661     return False
2662
2663
2664 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2665     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2666
2667     Brackets can be omitted if the entire trailer up to and including
2668     a preceding closing bracket fits in one line.
2669
2670     Yielded sets are cumulative (contain results of previous yields, too).  First
2671     set is empty.
2672     """
2673
2674     omit: Set[LeafID] = set()
2675     yield omit
2676
2677     length = 4 * line.depth
2678     opening_bracket = None
2679     closing_bracket = None
2680     optional_brackets: Set[LeafID] = set()
2681     inner_brackets: Set[LeafID] = set()
2682     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2683         length += leaf_length
2684         if length > line_length:
2685             break
2686
2687         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2688         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2689             break
2690
2691         optional_brackets.discard(id(leaf))
2692         if opening_bracket:
2693             if leaf is opening_bracket:
2694                 opening_bracket = None
2695             elif leaf.type in CLOSING_BRACKETS:
2696                 inner_brackets.add(id(leaf))
2697         elif leaf.type in CLOSING_BRACKETS:
2698             if not leaf.value:
2699                 optional_brackets.add(id(opening_bracket))
2700                 continue
2701
2702             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2703                 # Empty brackets would fail a split so treat them as "inner"
2704                 # brackets (e.g. only add them to the `omit` set if another
2705                 # pair of brackets was good enough.
2706                 inner_brackets.add(id(leaf))
2707                 continue
2708
2709             opening_bracket = leaf.opening_bracket
2710             if closing_bracket:
2711                 omit.add(id(closing_bracket))
2712                 omit.update(inner_brackets)
2713                 inner_brackets.clear()
2714                 yield omit
2715             closing_bracket = leaf
2716
2717
2718 def get_future_imports(node: Node) -> Set[str]:
2719     """Return a set of __future__ imports in the file."""
2720     imports = set()
2721     for child in node.children:
2722         if child.type != syms.simple_stmt:
2723             break
2724         first_child = child.children[0]
2725         if isinstance(first_child, Leaf):
2726             # Continue looking if we see a docstring; otherwise stop.
2727             if (
2728                 len(child.children) == 2
2729                 and first_child.type == token.STRING
2730                 and child.children[1].type == token.NEWLINE
2731             ):
2732                 continue
2733             else:
2734                 break
2735         elif first_child.type == syms.import_from:
2736             module_name = first_child.children[1]
2737             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2738                 break
2739             for import_from_child in first_child.children[3:]:
2740                 if isinstance(import_from_child, Leaf):
2741                     if import_from_child.type == token.NAME:
2742                         imports.add(import_from_child.value)
2743                 else:
2744                     assert import_from_child.type == syms.import_as_names
2745                     for leaf in import_from_child.children:
2746                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2747                             imports.add(leaf.value)
2748         else:
2749             break
2750     return imports
2751
2752
2753 PYTHON_EXTENSIONS = {".py", ".pyi"}
2754 BLACKLISTED_DIRECTORIES = {
2755     "build",
2756     "buck-out",
2757     "dist",
2758     "_build",
2759     ".git",
2760     ".hg",
2761     ".mypy_cache",
2762     ".tox",
2763     ".venv",
2764 }
2765
2766
2767 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2768     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2769     and have one of the PYTHON_EXTENSIONS.
2770     """
2771     for child in path.iterdir():
2772         if child.is_dir():
2773             if child.name in BLACKLISTED_DIRECTORIES:
2774                 continue
2775
2776             yield from gen_python_files_in_dir(child)
2777
2778         elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
2779             yield child
2780
2781
2782 @dataclass
2783 class Report:
2784     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2785
2786     check: bool = False
2787     quiet: bool = False
2788     change_count: int = 0
2789     same_count: int = 0
2790     failure_count: int = 0
2791
2792     def done(self, src: Path, changed: Changed) -> None:
2793         """Increment the counter for successful reformatting. Write out a message."""
2794         if changed is Changed.YES:
2795             reformatted = "would reformat" if self.check else "reformatted"
2796             if not self.quiet:
2797                 out(f"{reformatted} {src}")
2798             self.change_count += 1
2799         else:
2800             if not self.quiet:
2801                 if changed is Changed.NO:
2802                     msg = f"{src} already well formatted, good job."
2803                 else:
2804                     msg = f"{src} wasn't modified on disk since last run."
2805                 out(msg, bold=False)
2806             self.same_count += 1
2807
2808     def failed(self, src: Path, message: str) -> None:
2809         """Increment the counter for failed reformatting. Write out a message."""
2810         err(f"error: cannot format {src}: {message}")
2811         self.failure_count += 1
2812
2813     @property
2814     def return_code(self) -> int:
2815         """Return the exit code that the app should use.
2816
2817         This considers the current state of changed files and failures:
2818         - if there were any failures, return 123;
2819         - if any files were changed and --check is being used, return 1;
2820         - otherwise return 0.
2821         """
2822         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2823         # 126 we have special returncodes reserved by the shell.
2824         if self.failure_count:
2825             return 123
2826
2827         elif self.change_count and self.check:
2828             return 1
2829
2830         return 0
2831
2832     def __str__(self) -> str:
2833         """Render a color report of the current state.
2834
2835         Use `click.unstyle` to remove colors.
2836         """
2837         if self.check:
2838             reformatted = "would be reformatted"
2839             unchanged = "would be left unchanged"
2840             failed = "would fail to reformat"
2841         else:
2842             reformatted = "reformatted"
2843             unchanged = "left unchanged"
2844             failed = "failed to reformat"
2845         report = []
2846         if self.change_count:
2847             s = "s" if self.change_count > 1 else ""
2848             report.append(
2849                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2850             )
2851         if self.same_count:
2852             s = "s" if self.same_count > 1 else ""
2853             report.append(f"{self.same_count} file{s} {unchanged}")
2854         if self.failure_count:
2855             s = "s" if self.failure_count > 1 else ""
2856             report.append(
2857                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2858             )
2859         return ", ".join(report) + "."
2860
2861
2862 def assert_equivalent(src: str, dst: str) -> None:
2863     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2864
2865     import ast
2866     import traceback
2867
2868     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2869         """Simple visitor generating strings to compare ASTs by content."""
2870         yield f"{'  ' * depth}{node.__class__.__name__}("
2871
2872         for field in sorted(node._fields):
2873             try:
2874                 value = getattr(node, field)
2875             except AttributeError:
2876                 continue
2877
2878             yield f"{'  ' * (depth+1)}{field}="
2879
2880             if isinstance(value, list):
2881                 for item in value:
2882                     if isinstance(item, ast.AST):
2883                         yield from _v(item, depth + 2)
2884
2885             elif isinstance(value, ast.AST):
2886                 yield from _v(value, depth + 2)
2887
2888             else:
2889                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2890
2891         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2892
2893     try:
2894         src_ast = ast.parse(src)
2895     except Exception as exc:
2896         major, minor = sys.version_info[:2]
2897         raise AssertionError(
2898             f"cannot use --safe with this file; failed to parse source file "
2899             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2900             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2901         )
2902
2903     try:
2904         dst_ast = ast.parse(dst)
2905     except Exception as exc:
2906         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2907         raise AssertionError(
2908             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2909             f"Please report a bug on https://github.com/ambv/black/issues.  "
2910             f"This invalid output might be helpful: {log}"
2911         ) from None
2912
2913     src_ast_str = "\n".join(_v(src_ast))
2914     dst_ast_str = "\n".join(_v(dst_ast))
2915     if src_ast_str != dst_ast_str:
2916         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2917         raise AssertionError(
2918             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2919             f"the source.  "
2920             f"Please report a bug on https://github.com/ambv/black/issues.  "
2921             f"This diff might be helpful: {log}"
2922         ) from None
2923
2924
2925 def assert_stable(
2926     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
2927 ) -> None:
2928     """Raise AssertionError if `dst` reformats differently the second time."""
2929     newdst = format_str(dst, line_length=line_length, mode=mode)
2930     if dst != newdst:
2931         log = dump_to_file(
2932             diff(src, dst, "source", "first pass"),
2933             diff(dst, newdst, "first pass", "second pass"),
2934         )
2935         raise AssertionError(
2936             f"INTERNAL ERROR: Black produced different code on the second pass "
2937             f"of the formatter.  "
2938             f"Please report a bug on https://github.com/ambv/black/issues.  "
2939             f"This diff might be helpful: {log}"
2940         ) from None
2941
2942
2943 def dump_to_file(*output: str) -> str:
2944     """Dump `output` to a temporary file. Return path to the file."""
2945     import tempfile
2946
2947     with tempfile.NamedTemporaryFile(
2948         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2949     ) as f:
2950         for lines in output:
2951             f.write(lines)
2952             if lines and lines[-1] != "\n":
2953                 f.write("\n")
2954     return f.name
2955
2956
2957 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2958     """Return a unified diff string between strings `a` and `b`."""
2959     import difflib
2960
2961     a_lines = [line + "\n" for line in a.split("\n")]
2962     b_lines = [line + "\n" for line in b.split("\n")]
2963     return "".join(
2964         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2965     )
2966
2967
2968 def cancel(tasks: Iterable[asyncio.Task]) -> None:
2969     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2970     err("Aborted!")
2971     for task in tasks:
2972         task.cancel()
2973
2974
2975 def shutdown(loop: BaseEventLoop) -> None:
2976     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2977     try:
2978         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2979         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2980         if not to_cancel:
2981             return
2982
2983         for task in to_cancel:
2984             task.cancel()
2985         loop.run_until_complete(
2986             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2987         )
2988     finally:
2989         # `concurrent.futures.Future` objects cannot be cancelled once they
2990         # are already running. There might be some when the `shutdown()` happened.
2991         # Silence their logger's spew about the event loop being closed.
2992         cf_logger = logging.getLogger("concurrent.futures")
2993         cf_logger.setLevel(logging.CRITICAL)
2994         loop.close()
2995
2996
2997 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2998     """Replace `regex` with `replacement` twice on `original`.
2999
3000     This is used by string normalization to perform replaces on
3001     overlapping matches.
3002     """
3003     return regex.sub(replacement, regex.sub(replacement, original))
3004
3005
3006 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3007     """Like `reversed(enumerate(sequence))` if that were possible."""
3008     index = len(sequence) - 1
3009     for element in reversed(sequence):
3010         yield (index, element)
3011         index -= 1
3012
3013
3014 def enumerate_with_length(
3015     line: Line, reversed: bool = False
3016 ) -> Iterator[Tuple[Index, Leaf, int]]:
3017     """Return an enumeration of leaves with their length.
3018
3019     Stops prematurely on multiline strings and standalone comments.
3020     """
3021     op = cast(
3022         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3023         enumerate_reversed if reversed else enumerate,
3024     )
3025     for index, leaf in op(line.leaves):
3026         length = len(leaf.prefix) + len(leaf.value)
3027         if "\n" in leaf.value:
3028             return  # Multiline strings, we can't continue.
3029
3030         comment: Optional[Leaf]
3031         for comment in line.comments_after(leaf, index):
3032             length += len(comment.value)
3033
3034         yield index, leaf, length
3035
3036
3037 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3038     """Return True if `line` is no longer than `line_length`.
3039
3040     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3041     """
3042     if not line_str:
3043         line_str = str(line).strip("\n")
3044     return (
3045         len(line_str) <= line_length
3046         and "\n" not in line_str  # multiline strings
3047         and not line.contains_standalone_comments()
3048     )
3049
3050
3051 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3052     """Does `line` have a shape safe to reformat without optional parens around it?
3053
3054     Returns True for only a subset of potentially nice looking formattings but
3055     the point is to not return false positives that end up producing lines that
3056     are too long.
3057     """
3058     bt = line.bracket_tracker
3059     if not bt.delimiters:
3060         # Without delimiters the optional parentheses are useless.
3061         return True
3062
3063     max_priority = bt.max_delimiter_priority()
3064     if bt.delimiter_count_with_priority(max_priority) > 1:
3065         # With more than one delimiter of a kind the optional parentheses read better.
3066         return False
3067
3068     if max_priority == DOT_PRIORITY:
3069         # A single stranded method call doesn't require optional parentheses.
3070         return True
3071
3072     assert len(line.leaves) >= 2, "Stranded delimiter"
3073
3074     first = line.leaves[0]
3075     second = line.leaves[1]
3076     penultimate = line.leaves[-2]
3077     last = line.leaves[-1]
3078
3079     # With a single delimiter, omit if the expression starts or ends with
3080     # a bracket.
3081     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3082         remainder = False
3083         length = 4 * line.depth
3084         for _index, leaf, leaf_length in enumerate_with_length(line):
3085             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3086                 remainder = True
3087             if remainder:
3088                 length += leaf_length
3089                 if length > line_length:
3090                     break
3091
3092                 if leaf.type in OPENING_BRACKETS:
3093                     # There are brackets we can further split on.
3094                     remainder = False
3095
3096         else:
3097             # checked the entire string and line length wasn't exceeded
3098             if len(line.leaves) == _index + 1:
3099                 return True
3100
3101         # Note: we are not returning False here because a line might have *both*
3102         # a leading opening bracket and a trailing closing bracket.  If the
3103         # opening bracket doesn't match our rule, maybe the closing will.
3104
3105     if (
3106         last.type == token.RPAR
3107         or last.type == token.RBRACE
3108         or (
3109             # don't use indexing for omitting optional parentheses;
3110             # it looks weird
3111             last.type == token.RSQB
3112             and last.parent
3113             and last.parent.type != syms.trailer
3114         )
3115     ):
3116         if penultimate.type in OPENING_BRACKETS:
3117             # Empty brackets don't help.
3118             return False
3119
3120         if is_multiline_string(first):
3121             # Additional wrapping of a multiline string in this situation is
3122             # unnecessary.
3123             return True
3124
3125         length = 4 * line.depth
3126         seen_other_brackets = False
3127         for _index, leaf, leaf_length in enumerate_with_length(line):
3128             length += leaf_length
3129             if leaf is last.opening_bracket:
3130                 if seen_other_brackets or length <= line_length:
3131                     return True
3132
3133             elif leaf.type in OPENING_BRACKETS:
3134                 # There are brackets we can further split on.
3135                 seen_other_brackets = True
3136
3137     return False
3138
3139
3140 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3141     pyi = bool(mode & FileMode.PYI)
3142     py36 = bool(mode & FileMode.PYTHON36)
3143     return (
3144         CACHE_DIR
3145         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3146     )
3147
3148
3149 def read_cache(line_length: int, mode: FileMode) -> Cache:
3150     """Read the cache if it exists and is well formed.
3151
3152     If it is not well formed, the call to write_cache later should resolve the issue.
3153     """
3154     cache_file = get_cache_file(line_length, mode)
3155     if not cache_file.exists():
3156         return {}
3157
3158     with cache_file.open("rb") as fobj:
3159         try:
3160             cache: Cache = pickle.load(fobj)
3161         except pickle.UnpicklingError:
3162             return {}
3163
3164     return cache
3165
3166
3167 def get_cache_info(path: Path) -> CacheInfo:
3168     """Return the information used to check if a file is already formatted or not."""
3169     stat = path.stat()
3170     return stat.st_mtime, stat.st_size
3171
3172
3173 def filter_cached(
3174     cache: Cache, sources: Iterable[Path]
3175 ) -> Tuple[List[Path], List[Path]]:
3176     """Split a list of paths into two.
3177
3178     The first list contains paths of files that modified on disk or are not in the
3179     cache. The other list contains paths to non-modified files.
3180     """
3181     todo, done = [], []
3182     for src in sources:
3183         src = src.resolve()
3184         if cache.get(src) != get_cache_info(src):
3185             todo.append(src)
3186         else:
3187             done.append(src)
3188     return todo, done
3189
3190
3191 def write_cache(
3192     cache: Cache, sources: List[Path], line_length: int, mode: FileMode
3193 ) -> None:
3194     """Update the cache file."""
3195     cache_file = get_cache_file(line_length, mode)
3196     try:
3197         if not CACHE_DIR.exists():
3198             CACHE_DIR.mkdir(parents=True)
3199         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3200         with cache_file.open("wb") as fobj:
3201             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3202     except OSError:
3203         pass
3204
3205
3206 if __name__ == "__main__":
3207     main()