]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Improve doc regarding PyCharm keyboard shortcut (#271)
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum, Flag
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33     cast,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46
47 __version__ = "18.5b1"
48 DEFAULT_LINE_LENGTH = 88
49 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
50
51
52 # types
53 FileContent = str
54 Encoding = str
55 Depth = int
56 NodeType = int
57 LeafID = int
58 Priority = int
59 Index = int
60 LN = Union[Leaf, Node]
61 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
62 Timestamp = float
63 FileSize = int
64 CacheInfo = Tuple[Timestamp, FileSize]
65 Cache = Dict[Path, CacheInfo]
66 out = partial(click.secho, bold=True, err=True)
67 err = partial(click.secho, fg="red", err=True)
68
69 pygram.initialize(CACHE_DIR)
70 syms = pygram.python_symbols
71
72
73 class NothingChanged(UserWarning):
74     """Raised by :func:`format_file` when reformatted code is the same as source."""
75
76
77 class CannotSplit(Exception):
78     """A readable split that fits the allotted line length is impossible.
79
80     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
81     :func:`delimiter_split`.
82     """
83
84
85 class FormatError(Exception):
86     """Base exception for `# fmt: on` and `# fmt: off` handling.
87
88     It holds the number of bytes of the prefix consumed before the format
89     control comment appeared.
90     """
91
92     def __init__(self, consumed: int) -> None:
93         super().__init__(consumed)
94         self.consumed = consumed
95
96     def trim_prefix(self, leaf: Leaf) -> None:
97         leaf.prefix = leaf.prefix[self.consumed :]
98
99     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
100         """Returns a new Leaf from the consumed part of the prefix."""
101         unformatted_prefix = leaf.prefix[: self.consumed]
102         return Leaf(token.NEWLINE, unformatted_prefix)
103
104
105 class FormatOn(FormatError):
106     """Found a comment like `# fmt: on` in the file."""
107
108
109 class FormatOff(FormatError):
110     """Found a comment like `# fmt: off` in the file."""
111
112
113 class WriteBack(Enum):
114     NO = 0
115     YES = 1
116     DIFF = 2
117
118
119 class Changed(Enum):
120     NO = 0
121     CACHED = 1
122     YES = 2
123
124
125 class FileMode(Flag):
126     AUTO_DETECT = 0
127     PYTHON36 = 1
128     PYI = 2
129
130
131 @click.command()
132 @click.option(
133     "-l",
134     "--line-length",
135     type=int,
136     default=DEFAULT_LINE_LENGTH,
137     help="How many character per line to allow.",
138     show_default=True,
139 )
140 @click.option(
141     "--check",
142     is_flag=True,
143     help=(
144         "Don't write the files back, just return the status.  Return code 0 "
145         "means nothing would change.  Return code 1 means some files would be "
146         "reformatted.  Return code 123 means there was an internal error."
147     ),
148 )
149 @click.option(
150     "--diff",
151     is_flag=True,
152     help="Don't write the files back, just output a diff for each file on stdout.",
153 )
154 @click.option(
155     "--fast/--safe",
156     is_flag=True,
157     help="If --fast given, skip temporary sanity checks. [default: --safe]",
158 )
159 @click.option(
160     "-q",
161     "--quiet",
162     is_flag=True,
163     help=(
164         "Don't emit non-error messages to stderr. Errors are still emitted, "
165         "silence those with 2>/dev/null."
166     ),
167 )
168 @click.option(
169     "--pyi",
170     is_flag=True,
171     help=(
172         "Consider all input files typing stubs regardless of file extension "
173         "(useful when piping source on standard input)."
174     ),
175 )
176 @click.option(
177     "--py36",
178     is_flag=True,
179     help=(
180         "Allow using Python 3.6-only syntax on all input files.  This will put "
181         "trailing commas in function signatures and calls also after *args and "
182         "**kwargs.  [default: per-file auto-detection]"
183     ),
184 )
185 @click.version_option(version=__version__)
186 @click.argument(
187     "src",
188     nargs=-1,
189     type=click.Path(
190         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
191     ),
192 )
193 @click.pass_context
194 def main(
195     ctx: click.Context,
196     line_length: int,
197     check: bool,
198     diff: bool,
199     fast: bool,
200     pyi: bool,
201     py36: bool,
202     quiet: bool,
203     src: List[str],
204 ) -> None:
205     """The uncompromising code formatter."""
206     sources: List[Path] = []
207     for s in src:
208         p = Path(s)
209         if p.is_dir():
210             sources.extend(gen_python_files_in_dir(p))
211         elif p.is_file():
212             # if a file was explicitly given, we don't care about its extension
213             sources.append(p)
214         elif s == "-":
215             sources.append(Path("-"))
216         else:
217             err(f"invalid path: {s}")
218
219     if check and not diff:
220         write_back = WriteBack.NO
221     elif diff:
222         write_back = WriteBack.DIFF
223     else:
224         write_back = WriteBack.YES
225     mode = FileMode.AUTO_DETECT
226     if py36:
227         mode |= FileMode.PYTHON36
228     if pyi:
229         mode |= FileMode.PYI
230     report = Report(check=check, quiet=quiet)
231     if len(sources) == 0:
232         out("No paths given. Nothing to do 😴")
233         ctx.exit(0)
234         return
235
236     elif len(sources) == 1:
237         reformat_one(
238             src=sources[0],
239             line_length=line_length,
240             fast=fast,
241             write_back=write_back,
242             mode=mode,
243             report=report,
244         )
245     else:
246         loop = asyncio.get_event_loop()
247         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
248         try:
249             loop.run_until_complete(
250                 schedule_formatting(
251                     sources=sources,
252                     line_length=line_length,
253                     fast=fast,
254                     write_back=write_back,
255                     mode=mode,
256                     report=report,
257                     loop=loop,
258                     executor=executor,
259                 )
260             )
261         finally:
262             shutdown(loop)
263         if not quiet:
264             out("All done! ✨ 🍰 ✨")
265             click.echo(str(report))
266     ctx.exit(report.return_code)
267
268
269 def reformat_one(
270     src: Path,
271     line_length: int,
272     fast: bool,
273     write_back: WriteBack,
274     mode: FileMode,
275     report: "Report",
276 ) -> None:
277     """Reformat a single file under `src` without spawning child processes.
278
279     If `quiet` is True, non-error messages are not output. `line_length`,
280     `write_back`, `fast` and `pyi` options are passed to
281     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
282     """
283     try:
284         changed = Changed.NO
285         if not src.is_file() and str(src) == "-":
286             if format_stdin_to_stdout(
287                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
288             ):
289                 changed = Changed.YES
290         else:
291             cache: Cache = {}
292             if write_back != WriteBack.DIFF:
293                 cache = read_cache(line_length, mode)
294                 src = src.resolve()
295                 if src in cache and cache[src] == get_cache_info(src):
296                     changed = Changed.CACHED
297             if changed is not Changed.CACHED and format_file_in_place(
298                 src,
299                 line_length=line_length,
300                 fast=fast,
301                 write_back=write_back,
302                 mode=mode,
303             ):
304                 changed = Changed.YES
305             if write_back == WriteBack.YES and changed is not Changed.NO:
306                 write_cache(cache, [src], line_length, mode)
307         report.done(src, changed)
308     except Exception as exc:
309         report.failed(src, str(exc))
310
311
312 async def schedule_formatting(
313     sources: List[Path],
314     line_length: int,
315     fast: bool,
316     write_back: WriteBack,
317     mode: FileMode,
318     report: "Report",
319     loop: BaseEventLoop,
320     executor: Executor,
321 ) -> None:
322     """Run formatting of `sources` in parallel using the provided `executor`.
323
324     (Use ProcessPoolExecutors for actual parallelism.)
325
326     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
327     :func:`format_file_in_place`.
328     """
329     cache: Cache = {}
330     if write_back != WriteBack.DIFF:
331         cache = read_cache(line_length, mode)
332         sources, cached = filter_cached(cache, sources)
333         for src in cached:
334             report.done(src, Changed.CACHED)
335     cancelled = []
336     formatted = []
337     if sources:
338         lock = None
339         if write_back == WriteBack.DIFF:
340             # For diff output, we need locks to ensure we don't interleave output
341             # from different processes.
342             manager = Manager()
343             lock = manager.Lock()
344         tasks = {
345             loop.run_in_executor(
346                 executor,
347                 format_file_in_place,
348                 src,
349                 line_length,
350                 fast,
351                 write_back,
352                 mode,
353                 lock,
354             ): src
355             for src in sorted(sources)
356         }
357         pending: Iterable[asyncio.Task] = tasks.keys()
358         try:
359             loop.add_signal_handler(signal.SIGINT, cancel, pending)
360             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
361         except NotImplementedError:
362             # There are no good alternatives for these on Windows
363             pass
364         while pending:
365             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
366             for task in done:
367                 src = tasks.pop(task)
368                 if task.cancelled():
369                     cancelled.append(task)
370                 elif task.exception():
371                     report.failed(src, str(task.exception()))
372                 else:
373                     formatted.append(src)
374                     report.done(src, Changed.YES if task.result() else Changed.NO)
375     if cancelled:
376         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
377     if write_back == WriteBack.YES and formatted:
378         write_cache(cache, formatted, line_length, mode)
379
380
381 def format_file_in_place(
382     src: Path,
383     line_length: int,
384     fast: bool,
385     write_back: WriteBack = WriteBack.NO,
386     mode: FileMode = FileMode.AUTO_DETECT,
387     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
388 ) -> bool:
389     """Format file under `src` path. Return True if changed.
390
391     If `write_back` is True, write reformatted code back to stdout.
392     `line_length` and `fast` options are passed to :func:`format_file_contents`.
393     """
394     if src.suffix == ".pyi":
395         mode |= FileMode.PYI
396     with tokenize.open(src) as src_buffer:
397         src_contents = src_buffer.read()
398     try:
399         dst_contents = format_file_contents(
400             src_contents, line_length=line_length, fast=fast, mode=mode
401         )
402     except NothingChanged:
403         return False
404
405     if write_back == write_back.YES:
406         with open(src, "w", encoding=src_buffer.encoding) as f:
407             f.write(dst_contents)
408     elif write_back == write_back.DIFF:
409         src_name = f"{src}  (original)"
410         dst_name = f"{src}  (formatted)"
411         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
412         if lock:
413             lock.acquire()
414         try:
415             sys.stdout.write(diff_contents)
416         finally:
417             if lock:
418                 lock.release()
419     return True
420
421
422 def format_stdin_to_stdout(
423     line_length: int,
424     fast: bool,
425     write_back: WriteBack = WriteBack.NO,
426     mode: FileMode = FileMode.AUTO_DETECT,
427 ) -> bool:
428     """Format file on stdin. Return True if changed.
429
430     If `write_back` is True, write reformatted code back to stdout.
431     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
432     :func:`format_file_contents`.
433     """
434     src = sys.stdin.read()
435     dst = src
436     try:
437         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
438         return True
439
440     except NothingChanged:
441         return False
442
443     finally:
444         if write_back == WriteBack.YES:
445             sys.stdout.write(dst)
446         elif write_back == WriteBack.DIFF:
447             src_name = "<stdin>  (original)"
448             dst_name = "<stdin>  (formatted)"
449             sys.stdout.write(diff(src, dst, src_name, dst_name))
450
451
452 def format_file_contents(
453     src_contents: str,
454     *,
455     line_length: int,
456     fast: bool,
457     mode: FileMode = FileMode.AUTO_DETECT,
458 ) -> FileContent:
459     """Reformat contents a file and return new contents.
460
461     If `fast` is False, additionally confirm that the reformatted code is
462     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
463     `line_length` is passed to :func:`format_str`.
464     """
465     if src_contents.strip() == "":
466         raise NothingChanged
467
468     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
469     if src_contents == dst_contents:
470         raise NothingChanged
471
472     if not fast:
473         assert_equivalent(src_contents, dst_contents)
474         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
475     return dst_contents
476
477
478 def format_str(
479     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
480 ) -> FileContent:
481     """Reformat a string and return new contents.
482
483     `line_length` determines how many characters per line are allowed.
484     """
485     src_node = lib2to3_parse(src_contents)
486     dst_contents = ""
487     future_imports = get_future_imports(src_node)
488     is_pyi = bool(mode & FileMode.PYI)
489     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
490     lines = LineGenerator(
491         remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
492     )
493     elt = EmptyLineTracker(is_pyi=is_pyi)
494     empty_line = Line()
495     after = 0
496     for current_line in lines.visit(src_node):
497         for _ in range(after):
498             dst_contents += str(empty_line)
499         before, after = elt.maybe_empty_lines(current_line)
500         for _ in range(before):
501             dst_contents += str(empty_line)
502         for line in split_line(current_line, line_length=line_length, py36=py36):
503             dst_contents += str(line)
504     return dst_contents
505
506
507 GRAMMARS = [
508     pygram.python_grammar_no_print_statement_no_exec_statement,
509     pygram.python_grammar_no_print_statement,
510     pygram.python_grammar,
511 ]
512
513
514 def lib2to3_parse(src_txt: str) -> Node:
515     """Given a string with source, return the lib2to3 Node."""
516     grammar = pygram.python_grammar_no_print_statement
517     if src_txt[-1] != "\n":
518         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
519         src_txt += nl
520     for grammar in GRAMMARS:
521         drv = driver.Driver(grammar, pytree.convert)
522         try:
523             result = drv.parse_string(src_txt, True)
524             break
525
526         except ParseError as pe:
527             lineno, column = pe.context[1]
528             lines = src_txt.splitlines()
529             try:
530                 faulty_line = lines[lineno - 1]
531             except IndexError:
532                 faulty_line = "<line number missing in source>"
533             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
534     else:
535         raise exc from None
536
537     if isinstance(result, Leaf):
538         result = Node(syms.file_input, [result])
539     return result
540
541
542 def lib2to3_unparse(node: Node) -> str:
543     """Given a lib2to3 node, return its string representation."""
544     code = str(node)
545     return code
546
547
548 T = TypeVar("T")
549
550
551 class Visitor(Generic[T]):
552     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
553
554     def visit(self, node: LN) -> Iterator[T]:
555         """Main method to visit `node` and its children.
556
557         It tries to find a `visit_*()` method for the given `node.type`, like
558         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
559         If no dedicated `visit_*()` method is found, chooses `visit_default()`
560         instead.
561
562         Then yields objects of type `T` from the selected visitor.
563         """
564         if node.type < 256:
565             name = token.tok_name[node.type]
566         else:
567             name = type_repr(node.type)
568         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
569
570     def visit_default(self, node: LN) -> Iterator[T]:
571         """Default `visit_*()` implementation. Recurses to children of `node`."""
572         if isinstance(node, Node):
573             for child in node.children:
574                 yield from self.visit(child)
575
576
577 @dataclass
578 class DebugVisitor(Visitor[T]):
579     tree_depth: int = 0
580
581     def visit_default(self, node: LN) -> Iterator[T]:
582         indent = " " * (2 * self.tree_depth)
583         if isinstance(node, Node):
584             _type = type_repr(node.type)
585             out(f"{indent}{_type}", fg="yellow")
586             self.tree_depth += 1
587             for child in node.children:
588                 yield from self.visit(child)
589
590             self.tree_depth -= 1
591             out(f"{indent}/{_type}", fg="yellow", bold=False)
592         else:
593             _type = token.tok_name.get(node.type, str(node.type))
594             out(f"{indent}{_type}", fg="blue", nl=False)
595             if node.prefix:
596                 # We don't have to handle prefixes for `Node` objects since
597                 # that delegates to the first child anyway.
598                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
599             out(f" {node.value!r}", fg="blue", bold=False)
600
601     @classmethod
602     def show(cls, code: str) -> None:
603         """Pretty-print the lib2to3 AST of a given string of `code`.
604
605         Convenience method for debugging.
606         """
607         v: DebugVisitor[None] = DebugVisitor()
608         list(v.visit(lib2to3_parse(code)))
609
610
611 KEYWORDS = set(keyword.kwlist)
612 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
613 FLOW_CONTROL = {"return", "raise", "break", "continue"}
614 STATEMENT = {
615     syms.if_stmt,
616     syms.while_stmt,
617     syms.for_stmt,
618     syms.try_stmt,
619     syms.except_clause,
620     syms.with_stmt,
621     syms.funcdef,
622     syms.classdef,
623 }
624 STANDALONE_COMMENT = 153
625 LOGIC_OPERATORS = {"and", "or"}
626 COMPARATORS = {
627     token.LESS,
628     token.GREATER,
629     token.EQEQUAL,
630     token.NOTEQUAL,
631     token.LESSEQUAL,
632     token.GREATEREQUAL,
633 }
634 MATH_OPERATORS = {
635     token.VBAR,
636     token.CIRCUMFLEX,
637     token.AMPER,
638     token.LEFTSHIFT,
639     token.RIGHTSHIFT,
640     token.PLUS,
641     token.MINUS,
642     token.STAR,
643     token.SLASH,
644     token.DOUBLESLASH,
645     token.PERCENT,
646     token.AT,
647     token.TILDE,
648     token.DOUBLESTAR,
649 }
650 STARS = {token.STAR, token.DOUBLESTAR}
651 VARARGS_PARENTS = {
652     syms.arglist,
653     syms.argument,  # double star in arglist
654     syms.trailer,  # single argument to call
655     syms.typedargslist,
656     syms.varargslist,  # lambdas
657 }
658 UNPACKING_PARENTS = {
659     syms.atom,  # single element of a list or set literal
660     syms.dictsetmaker,
661     syms.listmaker,
662     syms.testlist_gexp,
663 }
664 TEST_DESCENDANTS = {
665     syms.test,
666     syms.lambdef,
667     syms.or_test,
668     syms.and_test,
669     syms.not_test,
670     syms.comparison,
671     syms.star_expr,
672     syms.expr,
673     syms.xor_expr,
674     syms.and_expr,
675     syms.shift_expr,
676     syms.arith_expr,
677     syms.trailer,
678     syms.term,
679     syms.power,
680 }
681 ASSIGNMENTS = {
682     "=",
683     "+=",
684     "-=",
685     "*=",
686     "@=",
687     "/=",
688     "%=",
689     "&=",
690     "|=",
691     "^=",
692     "<<=",
693     ">>=",
694     "**=",
695     "//=",
696 }
697 COMPREHENSION_PRIORITY = 20
698 COMMA_PRIORITY = 18
699 TERNARY_PRIORITY = 16
700 LOGIC_PRIORITY = 14
701 STRING_PRIORITY = 12
702 COMPARATOR_PRIORITY = 10
703 MATH_PRIORITIES = {
704     token.VBAR: 9,
705     token.CIRCUMFLEX: 8,
706     token.AMPER: 7,
707     token.LEFTSHIFT: 6,
708     token.RIGHTSHIFT: 6,
709     token.PLUS: 5,
710     token.MINUS: 5,
711     token.STAR: 4,
712     token.SLASH: 4,
713     token.DOUBLESLASH: 4,
714     token.PERCENT: 4,
715     token.AT: 4,
716     token.TILDE: 3,
717     token.DOUBLESTAR: 2,
718 }
719 DOT_PRIORITY = 1
720
721
722 @dataclass
723 class BracketTracker:
724     """Keeps track of brackets on a line."""
725
726     depth: int = 0
727     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
728     delimiters: Dict[LeafID, Priority] = Factory(dict)
729     previous: Optional[Leaf] = None
730     _for_loop_variable: int = 0
731     _lambda_arguments: int = 0
732
733     def mark(self, leaf: Leaf) -> None:
734         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
735
736         All leaves receive an int `bracket_depth` field that stores how deep
737         within brackets a given leaf is. 0 means there are no enclosing brackets
738         that started on this line.
739
740         If a leaf is itself a closing bracket, it receives an `opening_bracket`
741         field that it forms a pair with. This is a one-directional link to
742         avoid reference cycles.
743
744         If a leaf is a delimiter (a token on which Black can split the line if
745         needed) and it's on depth 0, its `id()` is stored in the tracker's
746         `delimiters` field.
747         """
748         if leaf.type == token.COMMENT:
749             return
750
751         self.maybe_decrement_after_for_loop_variable(leaf)
752         self.maybe_decrement_after_lambda_arguments(leaf)
753         if leaf.type in CLOSING_BRACKETS:
754             self.depth -= 1
755             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
756             leaf.opening_bracket = opening_bracket
757         leaf.bracket_depth = self.depth
758         if self.depth == 0:
759             delim = is_split_before_delimiter(leaf, self.previous)
760             if delim and self.previous is not None:
761                 self.delimiters[id(self.previous)] = delim
762             else:
763                 delim = is_split_after_delimiter(leaf, self.previous)
764                 if delim:
765                     self.delimiters[id(leaf)] = delim
766         if leaf.type in OPENING_BRACKETS:
767             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
768             self.depth += 1
769         self.previous = leaf
770         self.maybe_increment_lambda_arguments(leaf)
771         self.maybe_increment_for_loop_variable(leaf)
772
773     def any_open_brackets(self) -> bool:
774         """Return True if there is an yet unmatched open bracket on the line."""
775         return bool(self.bracket_match)
776
777     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
778         """Return the highest priority of a delimiter found on the line.
779
780         Values are consistent with what `is_split_*_delimiter()` return.
781         Raises ValueError on no delimiters.
782         """
783         return max(v for k, v in self.delimiters.items() if k not in exclude)
784
785     def delimiter_count_with_priority(self, priority: int = 0) -> int:
786         """Return the number of delimiters with the given `priority`.
787
788         If no `priority` is passed, defaults to max priority on the line.
789         """
790         if not self.delimiters:
791             return 0
792
793         priority = priority or self.max_delimiter_priority()
794         return sum(1 for p in self.delimiters.values() if p == priority)
795
796     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
797         """In a for loop, or comprehension, the variables are often unpacks.
798
799         To avoid splitting on the comma in this situation, increase the depth of
800         tokens between `for` and `in`.
801         """
802         if leaf.type == token.NAME and leaf.value == "for":
803             self.depth += 1
804             self._for_loop_variable += 1
805             return True
806
807         return False
808
809     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
810         """See `maybe_increment_for_loop_variable` above for explanation."""
811         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
812             self.depth -= 1
813             self._for_loop_variable -= 1
814             return True
815
816         return False
817
818     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
819         """In a lambda expression, there might be more than one argument.
820
821         To avoid splitting on the comma in this situation, increase the depth of
822         tokens between `lambda` and `:`.
823         """
824         if leaf.type == token.NAME and leaf.value == "lambda":
825             self.depth += 1
826             self._lambda_arguments += 1
827             return True
828
829         return False
830
831     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
832         """See `maybe_increment_lambda_arguments` above for explanation."""
833         if self._lambda_arguments and leaf.type == token.COLON:
834             self.depth -= 1
835             self._lambda_arguments -= 1
836             return True
837
838         return False
839
840     def get_open_lsqb(self) -> Optional[Leaf]:
841         """Return the most recent opening square bracket (if any)."""
842         return self.bracket_match.get((self.depth - 1, token.RSQB))
843
844
845 @dataclass
846 class Line:
847     """Holds leaves and comments. Can be printed with `str(line)`."""
848
849     depth: int = 0
850     leaves: List[Leaf] = Factory(list)
851     comments: List[Tuple[Index, Leaf]] = Factory(list)
852     bracket_tracker: BracketTracker = Factory(BracketTracker)
853     inside_brackets: bool = False
854     should_explode: bool = False
855
856     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
857         """Add a new `leaf` to the end of the line.
858
859         Unless `preformatted` is True, the `leaf` will receive a new consistent
860         whitespace prefix and metadata applied by :class:`BracketTracker`.
861         Trailing commas are maybe removed, unpacked for loop variables are
862         demoted from being delimiters.
863
864         Inline comments are put aside.
865         """
866         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
867         if not has_value:
868             return
869
870         if token.COLON == leaf.type and self.is_class_paren_empty:
871             del self.leaves[-2:]
872         if self.leaves and not preformatted:
873             # Note: at this point leaf.prefix should be empty except for
874             # imports, for which we only preserve newlines.
875             leaf.prefix += whitespace(
876                 leaf, complex_subscript=self.is_complex_subscript(leaf)
877             )
878         if self.inside_brackets or not preformatted:
879             self.bracket_tracker.mark(leaf)
880             self.maybe_remove_trailing_comma(leaf)
881         if not self.append_comment(leaf):
882             self.leaves.append(leaf)
883
884     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
885         """Like :func:`append()` but disallow invalid standalone comment structure.
886
887         Raises ValueError when any `leaf` is appended after a standalone comment
888         or when a standalone comment is not the first leaf on the line.
889         """
890         if self.bracket_tracker.depth == 0:
891             if self.is_comment:
892                 raise ValueError("cannot append to standalone comments")
893
894             if self.leaves and leaf.type == STANDALONE_COMMENT:
895                 raise ValueError(
896                     "cannot append standalone comments to a populated line"
897                 )
898
899         self.append(leaf, preformatted=preformatted)
900
901     @property
902     def is_comment(self) -> bool:
903         """Is this line a standalone comment?"""
904         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
905
906     @property
907     def is_decorator(self) -> bool:
908         """Is this line a decorator?"""
909         return bool(self) and self.leaves[0].type == token.AT
910
911     @property
912     def is_import(self) -> bool:
913         """Is this an import line?"""
914         return bool(self) and is_import(self.leaves[0])
915
916     @property
917     def is_class(self) -> bool:
918         """Is this line a class definition?"""
919         return (
920             bool(self)
921             and self.leaves[0].type == token.NAME
922             and self.leaves[0].value == "class"
923         )
924
925     @property
926     def is_stub_class(self) -> bool:
927         """Is this line a class definition with a body consisting only of "..."?"""
928         return self.is_class and self.leaves[-3:] == [
929             Leaf(token.DOT, ".") for _ in range(3)
930         ]
931
932     @property
933     def is_def(self) -> bool:
934         """Is this a function definition? (Also returns True for async defs.)"""
935         try:
936             first_leaf = self.leaves[0]
937         except IndexError:
938             return False
939
940         try:
941             second_leaf: Optional[Leaf] = self.leaves[1]
942         except IndexError:
943             second_leaf = None
944         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
945             first_leaf.type == token.ASYNC
946             and second_leaf is not None
947             and second_leaf.type == token.NAME
948             and second_leaf.value == "def"
949         )
950
951     @property
952     def is_class_paren_empty(self) -> bool:
953         """Is this a class with no base classes but using parentheses?
954
955         Those are unnecessary and should be removed.
956         """
957         return (
958             bool(self)
959             and len(self.leaves) == 4
960             and self.is_class
961             and self.leaves[2].type == token.LPAR
962             and self.leaves[2].value == "("
963             and self.leaves[3].type == token.RPAR
964             and self.leaves[3].value == ")"
965         )
966
967     @property
968     def is_triple_quoted_string(self) -> bool:
969         """Is the line a triple quoted string?"""
970         return (
971             bool(self)
972             and self.leaves[0].type == token.STRING
973             and self.leaves[0].value.startswith(('"""', "'''"))
974         )
975
976     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
977         """If so, needs to be split before emitting."""
978         for leaf in self.leaves:
979             if leaf.type == STANDALONE_COMMENT:
980                 if leaf.bracket_depth <= depth_limit:
981                     return True
982
983         return False
984
985     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
986         """Remove trailing comma if there is one and it's safe."""
987         if not (
988             self.leaves
989             and self.leaves[-1].type == token.COMMA
990             and closing.type in CLOSING_BRACKETS
991         ):
992             return False
993
994         if closing.type == token.RBRACE:
995             self.remove_trailing_comma()
996             return True
997
998         if closing.type == token.RSQB:
999             comma = self.leaves[-1]
1000             if comma.parent and comma.parent.type == syms.listmaker:
1001                 self.remove_trailing_comma()
1002                 return True
1003
1004         # For parens let's check if it's safe to remove the comma.
1005         # Imports are always safe.
1006         if self.is_import:
1007             self.remove_trailing_comma()
1008             return True
1009
1010         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1011         # change a tuple into a different type by removing the comma.
1012         depth = closing.bracket_depth + 1
1013         commas = 0
1014         opening = closing.opening_bracket
1015         for _opening_index, leaf in enumerate(self.leaves):
1016             if leaf is opening:
1017                 break
1018
1019         else:
1020             return False
1021
1022         for leaf in self.leaves[_opening_index + 1 :]:
1023             if leaf is closing:
1024                 break
1025
1026             bracket_depth = leaf.bracket_depth
1027             if bracket_depth == depth and leaf.type == token.COMMA:
1028                 commas += 1
1029                 if leaf.parent and leaf.parent.type == syms.arglist:
1030                     commas += 1
1031                     break
1032
1033         if commas > 1:
1034             self.remove_trailing_comma()
1035             return True
1036
1037         return False
1038
1039     def append_comment(self, comment: Leaf) -> bool:
1040         """Add an inline or standalone comment to the line."""
1041         if (
1042             comment.type == STANDALONE_COMMENT
1043             and self.bracket_tracker.any_open_brackets()
1044         ):
1045             comment.prefix = ""
1046             return False
1047
1048         if comment.type != token.COMMENT:
1049             return False
1050
1051         after = len(self.leaves) - 1
1052         if after == -1:
1053             comment.type = STANDALONE_COMMENT
1054             comment.prefix = ""
1055             return False
1056
1057         else:
1058             self.comments.append((after, comment))
1059             return True
1060
1061     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1062         """Generate comments that should appear directly after `leaf`.
1063
1064         Provide a non-negative leaf `_index` to speed up the function.
1065         """
1066         if _index == -1:
1067             for _index, _leaf in enumerate(self.leaves):
1068                 if leaf is _leaf:
1069                     break
1070
1071             else:
1072                 return
1073
1074         for index, comment_after in self.comments:
1075             if _index == index:
1076                 yield comment_after
1077
1078     def remove_trailing_comma(self) -> None:
1079         """Remove the trailing comma and moves the comments attached to it."""
1080         comma_index = len(self.leaves) - 1
1081         for i in range(len(self.comments)):
1082             comment_index, comment = self.comments[i]
1083             if comment_index == comma_index:
1084                 self.comments[i] = (comma_index - 1, comment)
1085         self.leaves.pop()
1086
1087     def is_complex_subscript(self, leaf: Leaf) -> bool:
1088         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1089         open_lsqb = (
1090             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1091         )
1092         if open_lsqb is None:
1093             return False
1094
1095         subscript_start = open_lsqb.next_sibling
1096         if (
1097             isinstance(subscript_start, Node)
1098             and subscript_start.type == syms.subscriptlist
1099         ):
1100             subscript_start = child_towards(subscript_start, leaf)
1101         return subscript_start is not None and any(
1102             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1103         )
1104
1105     def __str__(self) -> str:
1106         """Render the line."""
1107         if not self:
1108             return "\n"
1109
1110         indent = "    " * self.depth
1111         leaves = iter(self.leaves)
1112         first = next(leaves)
1113         res = f"{first.prefix}{indent}{first.value}"
1114         for leaf in leaves:
1115             res += str(leaf)
1116         for _, comment in self.comments:
1117             res += str(comment)
1118         return res + "\n"
1119
1120     def __bool__(self) -> bool:
1121         """Return True if the line has leaves or comments."""
1122         return bool(self.leaves or self.comments)
1123
1124
1125 class UnformattedLines(Line):
1126     """Just like :class:`Line` but stores lines which aren't reformatted."""
1127
1128     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1129         """Just add a new `leaf` to the end of the lines.
1130
1131         The `preformatted` argument is ignored.
1132
1133         Keeps track of indentation `depth`, which is useful when the user
1134         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1135         """
1136         try:
1137             list(generate_comments(leaf))
1138         except FormatOn as f_on:
1139             self.leaves.append(f_on.leaf_from_consumed(leaf))
1140             raise
1141
1142         self.leaves.append(leaf)
1143         if leaf.type == token.INDENT:
1144             self.depth += 1
1145         elif leaf.type == token.DEDENT:
1146             self.depth -= 1
1147
1148     def __str__(self) -> str:
1149         """Render unformatted lines from leaves which were added with `append()`.
1150
1151         `depth` is not used for indentation in this case.
1152         """
1153         if not self:
1154             return "\n"
1155
1156         res = ""
1157         for leaf in self.leaves:
1158             res += str(leaf)
1159         return res
1160
1161     def append_comment(self, comment: Leaf) -> bool:
1162         """Not implemented in this class. Raises `NotImplementedError`."""
1163         raise NotImplementedError("Unformatted lines don't store comments separately.")
1164
1165     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1166         """Does nothing and returns False."""
1167         return False
1168
1169     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1170         """Does nothing and returns False."""
1171         return False
1172
1173
1174 @dataclass
1175 class EmptyLineTracker:
1176     """Provides a stateful method that returns the number of potential extra
1177     empty lines needed before and after the currently processed line.
1178
1179     Note: this tracker works on lines that haven't been split yet.  It assumes
1180     the prefix of the first leaf consists of optional newlines.  Those newlines
1181     are consumed by `maybe_empty_lines()` and included in the computation.
1182     """
1183
1184     is_pyi: bool = False
1185     previous_line: Optional[Line] = None
1186     previous_after: int = 0
1187     previous_defs: List[int] = Factory(list)
1188
1189     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1190         """Return the number of extra empty lines before and after the `current_line`.
1191
1192         This is for separating `def`, `async def` and `class` with extra empty
1193         lines (two on module-level).
1194         """
1195         if isinstance(current_line, UnformattedLines):
1196             return 0, 0
1197
1198         before, after = self._maybe_empty_lines(current_line)
1199         before -= self.previous_after
1200         self.previous_after = after
1201         self.previous_line = current_line
1202         return before, after
1203
1204     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1205         max_allowed = 1
1206         if current_line.depth == 0:
1207             max_allowed = 1 if self.is_pyi else 2
1208         if current_line.leaves:
1209             # Consume the first leaf's extra newlines.
1210             first_leaf = current_line.leaves[0]
1211             before = first_leaf.prefix.count("\n")
1212             before = min(before, max_allowed)
1213             first_leaf.prefix = ""
1214         else:
1215             before = 0
1216         depth = current_line.depth
1217         while self.previous_defs and self.previous_defs[-1] >= depth:
1218             self.previous_defs.pop()
1219             if self.is_pyi:
1220                 before = 0 if depth else 1
1221             else:
1222                 before = 1 if depth else 2
1223         is_decorator = current_line.is_decorator
1224         if is_decorator or current_line.is_def or current_line.is_class:
1225             if not is_decorator:
1226                 self.previous_defs.append(depth)
1227             if self.previous_line is None:
1228                 # Don't insert empty lines before the first line in the file.
1229                 return 0, 0
1230
1231             if self.previous_line.is_decorator:
1232                 return 0, 0
1233
1234             if self.previous_line.depth < current_line.depth and (
1235                 self.previous_line.is_class or self.previous_line.is_def
1236             ):
1237                 return 0, 0
1238
1239             if (
1240                 self.previous_line.is_comment
1241                 and self.previous_line.depth == current_line.depth
1242                 and before == 0
1243             ):
1244                 return 0, 0
1245
1246             if self.is_pyi:
1247                 if self.previous_line.depth > current_line.depth:
1248                     newlines = 1
1249                 elif current_line.is_class or self.previous_line.is_class:
1250                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1251                         newlines = 0
1252                     else:
1253                         newlines = 1
1254                 else:
1255                     newlines = 0
1256             else:
1257                 newlines = 2
1258             if current_line.depth and newlines:
1259                 newlines -= 1
1260             return newlines, 0
1261
1262         if (
1263             self.previous_line
1264             and self.previous_line.is_import
1265             and not current_line.is_import
1266             and depth == self.previous_line.depth
1267         ):
1268             return (before or 1), 0
1269
1270         if (
1271             self.previous_line
1272             and self.previous_line.is_class
1273             and current_line.is_triple_quoted_string
1274         ):
1275             return before, 1
1276
1277         return before, 0
1278
1279
1280 @dataclass
1281 class LineGenerator(Visitor[Line]):
1282     """Generates reformatted Line objects.  Empty lines are not emitted.
1283
1284     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1285     in ways that will no longer stringify to valid Python code on the tree.
1286     """
1287
1288     is_pyi: bool = False
1289     current_line: Line = Factory(Line)
1290     remove_u_prefix: bool = False
1291
1292     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1293         """Generate a line.
1294
1295         If the line is empty, only emit if it makes sense.
1296         If the line is too long, split it first and then generate.
1297
1298         If any lines were generated, set up a new current_line.
1299         """
1300         if not self.current_line:
1301             if self.current_line.__class__ == type:
1302                 self.current_line.depth += indent
1303             else:
1304                 self.current_line = type(depth=self.current_line.depth + indent)
1305             return  # Line is empty, don't emit. Creating a new one unnecessary.
1306
1307         complete_line = self.current_line
1308         self.current_line = type(depth=complete_line.depth + indent)
1309         yield complete_line
1310
1311     def visit(self, node: LN) -> Iterator[Line]:
1312         """Main method to visit `node` and its children.
1313
1314         Yields :class:`Line` objects.
1315         """
1316         if isinstance(self.current_line, UnformattedLines):
1317             # File contained `# fmt: off`
1318             yield from self.visit_unformatted(node)
1319
1320         else:
1321             yield from super().visit(node)
1322
1323     def visit_default(self, node: LN) -> Iterator[Line]:
1324         """Default `visit_*()` implementation. Recurses to children of `node`."""
1325         if isinstance(node, Leaf):
1326             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1327             try:
1328                 for comment in generate_comments(node):
1329                     if any_open_brackets:
1330                         # any comment within brackets is subject to splitting
1331                         self.current_line.append(comment)
1332                     elif comment.type == token.COMMENT:
1333                         # regular trailing comment
1334                         self.current_line.append(comment)
1335                         yield from self.line()
1336
1337                     else:
1338                         # regular standalone comment
1339                         yield from self.line()
1340
1341                         self.current_line.append(comment)
1342                         yield from self.line()
1343
1344             except FormatOff as f_off:
1345                 f_off.trim_prefix(node)
1346                 yield from self.line(type=UnformattedLines)
1347                 yield from self.visit(node)
1348
1349             except FormatOn as f_on:
1350                 # This only happens here if somebody says "fmt: on" multiple
1351                 # times in a row.
1352                 f_on.trim_prefix(node)
1353                 yield from self.visit_default(node)
1354
1355             else:
1356                 normalize_prefix(node, inside_brackets=any_open_brackets)
1357                 if node.type == token.STRING:
1358                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1359                     normalize_string_quotes(node)
1360                 if node.type not in WHITESPACE:
1361                     self.current_line.append(node)
1362         yield from super().visit_default(node)
1363
1364     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1365         """Increase indentation level, maybe yield a line."""
1366         # In blib2to3 INDENT never holds comments.
1367         yield from self.line(+1)
1368         yield from self.visit_default(node)
1369
1370     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1371         """Decrease indentation level, maybe yield a line."""
1372         # The current line might still wait for trailing comments.  At DEDENT time
1373         # there won't be any (they would be prefixes on the preceding NEWLINE).
1374         # Emit the line then.
1375         yield from self.line()
1376
1377         # While DEDENT has no value, its prefix may contain standalone comments
1378         # that belong to the current indentation level.  Get 'em.
1379         yield from self.visit_default(node)
1380
1381         # Finally, emit the dedent.
1382         yield from self.line(-1)
1383
1384     def visit_stmt(
1385         self, node: Node, keywords: Set[str], parens: Set[str]
1386     ) -> Iterator[Line]:
1387         """Visit a statement.
1388
1389         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1390         `def`, `with`, `class`, `assert` and assignments.
1391
1392         The relevant Python language `keywords` for a given statement will be
1393         NAME leaves within it. This methods puts those on a separate line.
1394
1395         `parens` holds a set of string leaf values immediately after which
1396         invisible parens should be put.
1397         """
1398         normalize_invisible_parens(node, parens_after=parens)
1399         for child in node.children:
1400             if child.type == token.NAME and child.value in keywords:  # type: ignore
1401                 yield from self.line()
1402
1403             yield from self.visit(child)
1404
1405     def visit_suite(self, node: Node) -> Iterator[Line]:
1406         """Visit a suite."""
1407         if self.is_pyi and is_stub_suite(node):
1408             yield from self.visit(node.children[2])
1409         else:
1410             yield from self.visit_default(node)
1411
1412     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1413         """Visit a statement without nested statements."""
1414         is_suite_like = node.parent and node.parent.type in STATEMENT
1415         if is_suite_like:
1416             if self.is_pyi and is_stub_body(node):
1417                 yield from self.visit_default(node)
1418             else:
1419                 yield from self.line(+1)
1420                 yield from self.visit_default(node)
1421                 yield from self.line(-1)
1422
1423         else:
1424             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1425                 yield from self.line()
1426             yield from self.visit_default(node)
1427
1428     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1429         """Visit `async def`, `async for`, `async with`."""
1430         yield from self.line()
1431
1432         children = iter(node.children)
1433         for child in children:
1434             yield from self.visit(child)
1435
1436             if child.type == token.ASYNC:
1437                 break
1438
1439         internal_stmt = next(children)
1440         for child in internal_stmt.children:
1441             yield from self.visit(child)
1442
1443     def visit_decorators(self, node: Node) -> Iterator[Line]:
1444         """Visit decorators."""
1445         for child in node.children:
1446             yield from self.line()
1447             yield from self.visit(child)
1448
1449     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1450         """Remove a semicolon and put the other statement on a separate line."""
1451         yield from self.line()
1452
1453     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1454         """End of file. Process outstanding comments and end with a newline."""
1455         yield from self.visit_default(leaf)
1456         yield from self.line()
1457
1458     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1459         """Used when file contained a `# fmt: off`."""
1460         if isinstance(node, Node):
1461             for child in node.children:
1462                 yield from self.visit(child)
1463
1464         else:
1465             try:
1466                 self.current_line.append(node)
1467             except FormatOn as f_on:
1468                 f_on.trim_prefix(node)
1469                 yield from self.line()
1470                 yield from self.visit(node)
1471
1472             if node.type == token.ENDMARKER:
1473                 # somebody decided not to put a final `# fmt: on`
1474                 yield from self.line()
1475
1476     def __attrs_post_init__(self) -> None:
1477         """You are in a twisty little maze of passages."""
1478         v = self.visit_stmt
1479         Ø: Set[str] = set()
1480         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1481         self.visit_if_stmt = partial(
1482             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1483         )
1484         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1485         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1486         self.visit_try_stmt = partial(
1487             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1488         )
1489         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1490         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1491         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1492         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1493         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1494         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1495         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1496         self.visit_async_funcdef = self.visit_async_stmt
1497         self.visit_decorated = self.visit_decorators
1498
1499
1500 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1501 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1502 OPENING_BRACKETS = set(BRACKET.keys())
1503 CLOSING_BRACKETS = set(BRACKET.values())
1504 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1505 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1506
1507
1508 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1509     """Return whitespace prefix if needed for the given `leaf`.
1510
1511     `complex_subscript` signals whether the given leaf is part of a subscription
1512     which has non-trivial arguments, like arithmetic expressions or function calls.
1513     """
1514     NO = ""
1515     SPACE = " "
1516     DOUBLESPACE = "  "
1517     t = leaf.type
1518     p = leaf.parent
1519     v = leaf.value
1520     if t in ALWAYS_NO_SPACE:
1521         return NO
1522
1523     if t == token.COMMENT:
1524         return DOUBLESPACE
1525
1526     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1527     if t == token.COLON and p.type not in {
1528         syms.subscript,
1529         syms.subscriptlist,
1530         syms.sliceop,
1531     }:
1532         return NO
1533
1534     prev = leaf.prev_sibling
1535     if not prev:
1536         prevp = preceding_leaf(p)
1537         if not prevp or prevp.type in OPENING_BRACKETS:
1538             return NO
1539
1540         if t == token.COLON:
1541             if prevp.type == token.COLON:
1542                 return NO
1543
1544             elif prevp.type != token.COMMA and not complex_subscript:
1545                 return NO
1546
1547             return SPACE
1548
1549         if prevp.type == token.EQUAL:
1550             if prevp.parent:
1551                 if prevp.parent.type in {
1552                     syms.arglist,
1553                     syms.argument,
1554                     syms.parameters,
1555                     syms.varargslist,
1556                 }:
1557                     return NO
1558
1559                 elif prevp.parent.type == syms.typedargslist:
1560                     # A bit hacky: if the equal sign has whitespace, it means we
1561                     # previously found it's a typed argument.  So, we're using
1562                     # that, too.
1563                     return prevp.prefix
1564
1565         elif prevp.type in STARS:
1566             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1567                 return NO
1568
1569         elif prevp.type == token.COLON:
1570             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1571                 return SPACE if complex_subscript else NO
1572
1573         elif (
1574             prevp.parent
1575             and prevp.parent.type == syms.factor
1576             and prevp.type in MATH_OPERATORS
1577         ):
1578             return NO
1579
1580         elif (
1581             prevp.type == token.RIGHTSHIFT
1582             and prevp.parent
1583             and prevp.parent.type == syms.shift_expr
1584             and prevp.prev_sibling
1585             and prevp.prev_sibling.type == token.NAME
1586             and prevp.prev_sibling.value == "print"  # type: ignore
1587         ):
1588             # Python 2 print chevron
1589             return NO
1590
1591     elif prev.type in OPENING_BRACKETS:
1592         return NO
1593
1594     if p.type in {syms.parameters, syms.arglist}:
1595         # untyped function signatures or calls
1596         if not prev or prev.type != token.COMMA:
1597             return NO
1598
1599     elif p.type == syms.varargslist:
1600         # lambdas
1601         if prev and prev.type != token.COMMA:
1602             return NO
1603
1604     elif p.type == syms.typedargslist:
1605         # typed function signatures
1606         if not prev:
1607             return NO
1608
1609         if t == token.EQUAL:
1610             if prev.type != syms.tname:
1611                 return NO
1612
1613         elif prev.type == token.EQUAL:
1614             # A bit hacky: if the equal sign has whitespace, it means we
1615             # previously found it's a typed argument.  So, we're using that, too.
1616             return prev.prefix
1617
1618         elif prev.type != token.COMMA:
1619             return NO
1620
1621     elif p.type == syms.tname:
1622         # type names
1623         if not prev:
1624             prevp = preceding_leaf(p)
1625             if not prevp or prevp.type != token.COMMA:
1626                 return NO
1627
1628     elif p.type == syms.trailer:
1629         # attributes and calls
1630         if t == token.LPAR or t == token.RPAR:
1631             return NO
1632
1633         if not prev:
1634             if t == token.DOT:
1635                 prevp = preceding_leaf(p)
1636                 if not prevp or prevp.type != token.NUMBER:
1637                     return NO
1638
1639             elif t == token.LSQB:
1640                 return NO
1641
1642         elif prev.type != token.COMMA:
1643             return NO
1644
1645     elif p.type == syms.argument:
1646         # single argument
1647         if t == token.EQUAL:
1648             return NO
1649
1650         if not prev:
1651             prevp = preceding_leaf(p)
1652             if not prevp or prevp.type == token.LPAR:
1653                 return NO
1654
1655         elif prev.type in {token.EQUAL} | STARS:
1656             return NO
1657
1658     elif p.type == syms.decorator:
1659         # decorators
1660         return NO
1661
1662     elif p.type == syms.dotted_name:
1663         if prev:
1664             return NO
1665
1666         prevp = preceding_leaf(p)
1667         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1668             return NO
1669
1670     elif p.type == syms.classdef:
1671         if t == token.LPAR:
1672             return NO
1673
1674         if prev and prev.type == token.LPAR:
1675             return NO
1676
1677     elif p.type in {syms.subscript, syms.sliceop}:
1678         # indexing
1679         if not prev:
1680             assert p.parent is not None, "subscripts are always parented"
1681             if p.parent.type == syms.subscriptlist:
1682                 return SPACE
1683
1684             return NO
1685
1686         elif not complex_subscript:
1687             return NO
1688
1689     elif p.type == syms.atom:
1690         if prev and t == token.DOT:
1691             # dots, but not the first one.
1692             return NO
1693
1694     elif p.type == syms.dictsetmaker:
1695         # dict unpacking
1696         if prev and prev.type == token.DOUBLESTAR:
1697             return NO
1698
1699     elif p.type in {syms.factor, syms.star_expr}:
1700         # unary ops
1701         if not prev:
1702             prevp = preceding_leaf(p)
1703             if not prevp or prevp.type in OPENING_BRACKETS:
1704                 return NO
1705
1706             prevp_parent = prevp.parent
1707             assert prevp_parent is not None
1708             if prevp.type == token.COLON and prevp_parent.type in {
1709                 syms.subscript,
1710                 syms.sliceop,
1711             }:
1712                 return NO
1713
1714             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1715                 return NO
1716
1717         elif t == token.NAME or t == token.NUMBER:
1718             return NO
1719
1720     elif p.type == syms.import_from:
1721         if t == token.DOT:
1722             if prev and prev.type == token.DOT:
1723                 return NO
1724
1725         elif t == token.NAME:
1726             if v == "import":
1727                 return SPACE
1728
1729             if prev and prev.type == token.DOT:
1730                 return NO
1731
1732     elif p.type == syms.sliceop:
1733         return NO
1734
1735     return SPACE
1736
1737
1738 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1739     """Return the first leaf that precedes `node`, if any."""
1740     while node:
1741         res = node.prev_sibling
1742         if res:
1743             if isinstance(res, Leaf):
1744                 return res
1745
1746             try:
1747                 return list(res.leaves())[-1]
1748
1749             except IndexError:
1750                 return None
1751
1752         node = node.parent
1753     return None
1754
1755
1756 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1757     """Return the child of `ancestor` that contains `descendant`."""
1758     node: Optional[LN] = descendant
1759     while node and node.parent != ancestor:
1760         node = node.parent
1761     return node
1762
1763
1764 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1765     """Return the priority of the `leaf` delimiter, given a line break after it.
1766
1767     The delimiter priorities returned here are from those delimiters that would
1768     cause a line break after themselves.
1769
1770     Higher numbers are higher priority.
1771     """
1772     if leaf.type == token.COMMA:
1773         return COMMA_PRIORITY
1774
1775     return 0
1776
1777
1778 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1779     """Return the priority of the `leaf` delimiter, given a line before after it.
1780
1781     The delimiter priorities returned here are from those delimiters that would
1782     cause a line break before themselves.
1783
1784     Higher numbers are higher priority.
1785     """
1786     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1787         # * and ** might also be MATH_OPERATORS but in this case they are not.
1788         # Don't treat them as a delimiter.
1789         return 0
1790
1791     if (
1792         leaf.type == token.DOT
1793         and leaf.parent
1794         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1795         and (previous is None or previous.type in CLOSING_BRACKETS)
1796     ):
1797         return DOT_PRIORITY
1798
1799     if (
1800         leaf.type in MATH_OPERATORS
1801         and leaf.parent
1802         and leaf.parent.type not in {syms.factor, syms.star_expr}
1803     ):
1804         return MATH_PRIORITIES[leaf.type]
1805
1806     if leaf.type in COMPARATORS:
1807         return COMPARATOR_PRIORITY
1808
1809     if (
1810         leaf.type == token.STRING
1811         and previous is not None
1812         and previous.type == token.STRING
1813     ):
1814         return STRING_PRIORITY
1815
1816     if leaf.type != token.NAME:
1817         return 0
1818
1819     if (
1820         leaf.value == "for"
1821         and leaf.parent
1822         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1823     ):
1824         return COMPREHENSION_PRIORITY
1825
1826     if (
1827         leaf.value == "if"
1828         and leaf.parent
1829         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1830     ):
1831         return COMPREHENSION_PRIORITY
1832
1833     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1834         return TERNARY_PRIORITY
1835
1836     if leaf.value == "is":
1837         return COMPARATOR_PRIORITY
1838
1839     if (
1840         leaf.value == "in"
1841         and leaf.parent
1842         and leaf.parent.type in {syms.comp_op, syms.comparison}
1843         and not (
1844             previous is not None
1845             and previous.type == token.NAME
1846             and previous.value == "not"
1847         )
1848     ):
1849         return COMPARATOR_PRIORITY
1850
1851     if (
1852         leaf.value == "not"
1853         and leaf.parent
1854         and leaf.parent.type == syms.comp_op
1855         and not (
1856             previous is not None
1857             and previous.type == token.NAME
1858             and previous.value == "is"
1859         )
1860     ):
1861         return COMPARATOR_PRIORITY
1862
1863     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1864         return LOGIC_PRIORITY
1865
1866     return 0
1867
1868
1869 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1870     """Clean the prefix of the `leaf` and generate comments from it, if any.
1871
1872     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1873     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1874     move because it does away with modifying the grammar to include all the
1875     possible places in which comments can be placed.
1876
1877     The sad consequence for us though is that comments don't "belong" anywhere.
1878     This is why this function generates simple parentless Leaf objects for
1879     comments.  We simply don't know what the correct parent should be.
1880
1881     No matter though, we can live without this.  We really only need to
1882     differentiate between inline and standalone comments.  The latter don't
1883     share the line with any code.
1884
1885     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1886     are emitted with a fake STANDALONE_COMMENT token identifier.
1887     """
1888     p = leaf.prefix
1889     if not p:
1890         return
1891
1892     if "#" not in p:
1893         return
1894
1895     consumed = 0
1896     nlines = 0
1897     for index, line in enumerate(p.split("\n")):
1898         consumed += len(line) + 1  # adding the length of the split '\n'
1899         line = line.lstrip()
1900         if not line:
1901             nlines += 1
1902         if not line.startswith("#"):
1903             continue
1904
1905         if index == 0 and leaf.type != token.ENDMARKER:
1906             comment_type = token.COMMENT  # simple trailing comment
1907         else:
1908             comment_type = STANDALONE_COMMENT
1909         comment = make_comment(line)
1910         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1911
1912         if comment in {"# fmt: on", "# yapf: enable"}:
1913             raise FormatOn(consumed)
1914
1915         if comment in {"# fmt: off", "# yapf: disable"}:
1916             if comment_type == STANDALONE_COMMENT:
1917                 raise FormatOff(consumed)
1918
1919             prev = preceding_leaf(leaf)
1920             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1921                 raise FormatOff(consumed)
1922
1923         nlines = 0
1924
1925
1926 def make_comment(content: str) -> str:
1927     """Return a consistently formatted comment from the given `content` string.
1928
1929     All comments (except for "##", "#!", "#:") should have a single space between
1930     the hash sign and the content.
1931
1932     If `content` didn't start with a hash sign, one is provided.
1933     """
1934     content = content.rstrip()
1935     if not content:
1936         return "#"
1937
1938     if content[0] == "#":
1939         content = content[1:]
1940     if content and content[0] not in " !:#":
1941         content = " " + content
1942     return "#" + content
1943
1944
1945 def split_line(
1946     line: Line, line_length: int, inner: bool = False, py36: bool = False
1947 ) -> Iterator[Line]:
1948     """Split a `line` into potentially many lines.
1949
1950     They should fit in the allotted `line_length` but might not be able to.
1951     `inner` signifies that there were a pair of brackets somewhere around the
1952     current `line`, possibly transitively. This means we can fallback to splitting
1953     by delimiters if the LHS/RHS don't yield any results.
1954
1955     If `py36` is True, splitting may generate syntax that is only compatible
1956     with Python 3.6 and later.
1957     """
1958     if isinstance(line, UnformattedLines) or line.is_comment:
1959         yield line
1960         return
1961
1962     line_str = str(line).strip("\n")
1963     if not line.should_explode and is_line_short_enough(
1964         line, line_length=line_length, line_str=line_str
1965     ):
1966         yield line
1967         return
1968
1969     split_funcs: List[SplitFunc]
1970     if line.is_def:
1971         split_funcs = [left_hand_split]
1972     else:
1973
1974         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
1975             for omit in generate_trailers_to_omit(line, line_length):
1976                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
1977                 if is_line_short_enough(lines[0], line_length=line_length):
1978                     yield from lines
1979                     return
1980
1981             # All splits failed, best effort split with no omits.
1982             # This mostly happens to multiline strings that are by definition
1983             # reported as not fitting a single line.
1984             yield from right_hand_split(line, py36)
1985
1986         if line.inside_brackets:
1987             split_funcs = [delimiter_split, standalone_comment_split, rhs]
1988         else:
1989             split_funcs = [rhs]
1990     for split_func in split_funcs:
1991         # We are accumulating lines in `result` because we might want to abort
1992         # mission and return the original line in the end, or attempt a different
1993         # split altogether.
1994         result: List[Line] = []
1995         try:
1996             for l in split_func(line, py36):
1997                 if str(l).strip("\n") == line_str:
1998                     raise CannotSplit("Split function returned an unchanged result")
1999
2000                 result.extend(
2001                     split_line(l, line_length=line_length, inner=True, py36=py36)
2002                 )
2003         except CannotSplit as cs:
2004             continue
2005
2006         else:
2007             yield from result
2008             break
2009
2010     else:
2011         yield line
2012
2013
2014 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2015     """Split line into many lines, starting with the first matching bracket pair.
2016
2017     Note: this usually looks weird, only use this for function definitions.
2018     Prefer RHS otherwise.  This is why this function is not symmetrical with
2019     :func:`right_hand_split` which also handles optional parentheses.
2020     """
2021     head = Line(depth=line.depth)
2022     body = Line(depth=line.depth + 1, inside_brackets=True)
2023     tail = Line(depth=line.depth)
2024     tail_leaves: List[Leaf] = []
2025     body_leaves: List[Leaf] = []
2026     head_leaves: List[Leaf] = []
2027     current_leaves = head_leaves
2028     matching_bracket = None
2029     for leaf in line.leaves:
2030         if (
2031             current_leaves is body_leaves
2032             and leaf.type in CLOSING_BRACKETS
2033             and leaf.opening_bracket is matching_bracket
2034         ):
2035             current_leaves = tail_leaves if body_leaves else head_leaves
2036         current_leaves.append(leaf)
2037         if current_leaves is head_leaves:
2038             if leaf.type in OPENING_BRACKETS:
2039                 matching_bracket = leaf
2040                 current_leaves = body_leaves
2041     # Since body is a new indent level, remove spurious leading whitespace.
2042     if body_leaves:
2043         normalize_prefix(body_leaves[0], inside_brackets=True)
2044     # Build the new lines.
2045     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2046         for leaf in leaves:
2047             result.append(leaf, preformatted=True)
2048             for comment_after in line.comments_after(leaf):
2049                 result.append(comment_after, preformatted=True)
2050     bracket_split_succeeded_or_raise(head, body, tail)
2051     for result in (head, body, tail):
2052         if result:
2053             yield result
2054
2055
2056 def right_hand_split(
2057     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2058 ) -> Iterator[Line]:
2059     """Split line into many lines, starting with the last matching bracket pair.
2060
2061     If the split was by optional parentheses, attempt splitting without them, too.
2062     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2063     this split.
2064
2065     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2066     """
2067     head = Line(depth=line.depth)
2068     body = Line(depth=line.depth + 1, inside_brackets=True)
2069     tail = Line(depth=line.depth)
2070     tail_leaves: List[Leaf] = []
2071     body_leaves: List[Leaf] = []
2072     head_leaves: List[Leaf] = []
2073     current_leaves = tail_leaves
2074     opening_bracket = None
2075     closing_bracket = None
2076     for leaf in reversed(line.leaves):
2077         if current_leaves is body_leaves:
2078             if leaf is opening_bracket:
2079                 current_leaves = head_leaves if body_leaves else tail_leaves
2080         current_leaves.append(leaf)
2081         if current_leaves is tail_leaves:
2082             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2083                 opening_bracket = leaf.opening_bracket
2084                 closing_bracket = leaf
2085                 current_leaves = body_leaves
2086     tail_leaves.reverse()
2087     body_leaves.reverse()
2088     head_leaves.reverse()
2089     # Since body is a new indent level, remove spurious leading whitespace.
2090     if body_leaves:
2091         normalize_prefix(body_leaves[0], inside_brackets=True)
2092     if not head_leaves:
2093         # No `head` means the split failed. Either `tail` has all content or
2094         # the matching `opening_bracket` wasn't available on `line` anymore.
2095         raise CannotSplit("No brackets found")
2096
2097     # Build the new lines.
2098     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2099         for leaf in leaves:
2100             result.append(leaf, preformatted=True)
2101             for comment_after in line.comments_after(leaf):
2102                 result.append(comment_after, preformatted=True)
2103     bracket_split_succeeded_or_raise(head, body, tail)
2104     assert opening_bracket and closing_bracket
2105     if (
2106         # the opening bracket is an optional paren
2107         opening_bracket.type == token.LPAR
2108         and not opening_bracket.value
2109         # the closing bracket is an optional paren
2110         and closing_bracket.type == token.RPAR
2111         and not closing_bracket.value
2112         # there are no standalone comments in the body
2113         and not line.contains_standalone_comments(0)
2114         # and it's not an import (optional parens are the only thing we can split
2115         # on in this case; attempting a split without them is a waste of time)
2116         and not line.is_import
2117     ):
2118         omit = {id(closing_bracket), *omit}
2119         if can_omit_invisible_parens(body, line_length):
2120             try:
2121                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2122                 return
2123             except CannotSplit:
2124                 pass
2125
2126     ensure_visible(opening_bracket)
2127     ensure_visible(closing_bracket)
2128     body.should_explode = should_explode(body, opening_bracket)
2129     for result in (head, body, tail):
2130         if result:
2131             yield result
2132
2133
2134 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2135     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2136
2137     Do nothing otherwise.
2138
2139     A left- or right-hand split is based on a pair of brackets. Content before
2140     (and including) the opening bracket is left on one line, content inside the
2141     brackets is put on a separate line, and finally content starting with and
2142     following the closing bracket is put on a separate line.
2143
2144     Those are called `head`, `body`, and `tail`, respectively. If the split
2145     produced the same line (all content in `head`) or ended up with an empty `body`
2146     and the `tail` is just the closing bracket, then it's considered failed.
2147     """
2148     tail_len = len(str(tail).strip())
2149     if not body:
2150         if tail_len == 0:
2151             raise CannotSplit("Splitting brackets produced the same line")
2152
2153         elif tail_len < 3:
2154             raise CannotSplit(
2155                 f"Splitting brackets on an empty body to save "
2156                 f"{tail_len} characters is not worth it"
2157             )
2158
2159
2160 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2161     """Normalize prefix of the first leaf in every line returned by `split_func`.
2162
2163     This is a decorator over relevant split functions.
2164     """
2165
2166     @wraps(split_func)
2167     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2168         for l in split_func(line, py36):
2169             normalize_prefix(l.leaves[0], inside_brackets=True)
2170             yield l
2171
2172     return split_wrapper
2173
2174
2175 @dont_increase_indentation
2176 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2177     """Split according to delimiters of the highest priority.
2178
2179     If `py36` is True, the split will add trailing commas also in function
2180     signatures that contain `*` and `**`.
2181     """
2182     try:
2183         last_leaf = line.leaves[-1]
2184     except IndexError:
2185         raise CannotSplit("Line empty")
2186
2187     bt = line.bracket_tracker
2188     try:
2189         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2190     except ValueError:
2191         raise CannotSplit("No delimiters found")
2192
2193     if delimiter_priority == DOT_PRIORITY:
2194         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2195             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2196
2197     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2198     lowest_depth = sys.maxsize
2199     trailing_comma_safe = True
2200
2201     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2202         """Append `leaf` to current line or to new line if appending impossible."""
2203         nonlocal current_line
2204         try:
2205             current_line.append_safe(leaf, preformatted=True)
2206         except ValueError as ve:
2207             yield current_line
2208
2209             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2210             current_line.append(leaf)
2211
2212     for index, leaf in enumerate(line.leaves):
2213         yield from append_to_line(leaf)
2214
2215         for comment_after in line.comments_after(leaf, index):
2216             yield from append_to_line(comment_after)
2217
2218         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2219         if leaf.bracket_depth == lowest_depth and is_vararg(
2220             leaf, within=VARARGS_PARENTS
2221         ):
2222             trailing_comma_safe = trailing_comma_safe and py36
2223         leaf_priority = bt.delimiters.get(id(leaf))
2224         if leaf_priority == delimiter_priority:
2225             yield current_line
2226
2227             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2228     if current_line:
2229         if (
2230             trailing_comma_safe
2231             and delimiter_priority == COMMA_PRIORITY
2232             and current_line.leaves[-1].type != token.COMMA
2233             and current_line.leaves[-1].type != STANDALONE_COMMENT
2234         ):
2235             current_line.append(Leaf(token.COMMA, ","))
2236         yield current_line
2237
2238
2239 @dont_increase_indentation
2240 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2241     """Split standalone comments from the rest of the line."""
2242     if not line.contains_standalone_comments(0):
2243         raise CannotSplit("Line does not have any standalone comments")
2244
2245     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2246
2247     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2248         """Append `leaf` to current line or to new line if appending impossible."""
2249         nonlocal current_line
2250         try:
2251             current_line.append_safe(leaf, preformatted=True)
2252         except ValueError as ve:
2253             yield current_line
2254
2255             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2256             current_line.append(leaf)
2257
2258     for index, leaf in enumerate(line.leaves):
2259         yield from append_to_line(leaf)
2260
2261         for comment_after in line.comments_after(leaf, index):
2262             yield from append_to_line(comment_after)
2263
2264     if current_line:
2265         yield current_line
2266
2267
2268 def is_import(leaf: Leaf) -> bool:
2269     """Return True if the given leaf starts an import statement."""
2270     p = leaf.parent
2271     t = leaf.type
2272     v = leaf.value
2273     return bool(
2274         t == token.NAME
2275         and (
2276             (v == "import" and p and p.type == syms.import_name)
2277             or (v == "from" and p and p.type == syms.import_from)
2278         )
2279     )
2280
2281
2282 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2283     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2284     else.
2285
2286     Note: don't use backslashes for formatting or you'll lose your voting rights.
2287     """
2288     if not inside_brackets:
2289         spl = leaf.prefix.split("#")
2290         if "\\" not in spl[0]:
2291             nl_count = spl[-1].count("\n")
2292             if len(spl) > 1:
2293                 nl_count -= 1
2294             leaf.prefix = "\n" * nl_count
2295             return
2296
2297     leaf.prefix = ""
2298
2299
2300 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2301     """Make all string prefixes lowercase.
2302
2303     If remove_u_prefix is given, also removes any u prefix from the string.
2304
2305     Note: Mutates its argument.
2306     """
2307     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2308     assert match is not None, f"failed to match string {leaf.value!r}"
2309     orig_prefix = match.group(1)
2310     new_prefix = orig_prefix.lower()
2311     if remove_u_prefix:
2312         new_prefix = new_prefix.replace("u", "")
2313     leaf.value = f"{new_prefix}{match.group(2)}"
2314
2315
2316 def normalize_string_quotes(leaf: Leaf) -> None:
2317     """Prefer double quotes but only if it doesn't cause more escaping.
2318
2319     Adds or removes backslashes as appropriate. Doesn't parse and fix
2320     strings nested in f-strings (yet).
2321
2322     Note: Mutates its argument.
2323     """
2324     value = leaf.value.lstrip("furbFURB")
2325     if value[:3] == '"""':
2326         return
2327
2328     elif value[:3] == "'''":
2329         orig_quote = "'''"
2330         new_quote = '"""'
2331     elif value[0] == '"':
2332         orig_quote = '"'
2333         new_quote = "'"
2334     else:
2335         orig_quote = "'"
2336         new_quote = '"'
2337     first_quote_pos = leaf.value.find(orig_quote)
2338     if first_quote_pos == -1:
2339         return  # There's an internal error
2340
2341     prefix = leaf.value[:first_quote_pos]
2342     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2343     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2344     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2345     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2346     if "r" in prefix.casefold():
2347         if unescaped_new_quote.search(body):
2348             # There's at least one unescaped new_quote in this raw string
2349             # so converting is impossible
2350             return
2351
2352         # Do not introduce or remove backslashes in raw strings
2353         new_body = body
2354     else:
2355         # remove unnecessary quotes
2356         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2357         if body != new_body:
2358             # Consider the string without unnecessary quotes as the original
2359             body = new_body
2360             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2361         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2362         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2363     if new_quote == '"""' and new_body[-1] == '"':
2364         # edge case:
2365         new_body = new_body[:-1] + '\\"'
2366     orig_escape_count = body.count("\\")
2367     new_escape_count = new_body.count("\\")
2368     if new_escape_count > orig_escape_count:
2369         return  # Do not introduce more escaping
2370
2371     if new_escape_count == orig_escape_count and orig_quote == '"':
2372         return  # Prefer double quotes
2373
2374     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2375
2376
2377 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2378     """Make existing optional parentheses invisible or create new ones.
2379
2380     `parens_after` is a set of string leaf values immeditely after which parens
2381     should be put.
2382
2383     Standardizes on visible parentheses for single-element tuples, and keeps
2384     existing visible parentheses for other tuples and generator expressions.
2385     """
2386     try:
2387         list(generate_comments(node))
2388     except FormatOff:
2389         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2390
2391     check_lpar = False
2392     for index, child in enumerate(list(node.children)):
2393         if check_lpar:
2394             if child.type == syms.atom:
2395                 maybe_make_parens_invisible_in_atom(child)
2396             elif is_one_tuple(child):
2397                 # wrap child in visible parentheses
2398                 lpar = Leaf(token.LPAR, "(")
2399                 rpar = Leaf(token.RPAR, ")")
2400                 child.remove()
2401                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2402             elif node.type == syms.import_from:
2403                 # "import from" nodes store parentheses directly as part of
2404                 # the statement
2405                 if child.type == token.LPAR:
2406                     # make parentheses invisible
2407                     child.value = ""  # type: ignore
2408                     node.children[-1].value = ""  # type: ignore
2409                 elif child.type != token.STAR:
2410                     # insert invisible parentheses
2411                     node.insert_child(index, Leaf(token.LPAR, ""))
2412                     node.append_child(Leaf(token.RPAR, ""))
2413                 break
2414
2415             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2416                 # wrap child in invisible parentheses
2417                 lpar = Leaf(token.LPAR, "")
2418                 rpar = Leaf(token.RPAR, "")
2419                 index = child.remove() or 0
2420                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2421
2422         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2423
2424
2425 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2426     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2427     if (
2428         node.type != syms.atom
2429         or is_empty_tuple(node)
2430         or is_one_tuple(node)
2431         or is_yield(node)
2432         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2433     ):
2434         return False
2435
2436     first = node.children[0]
2437     last = node.children[-1]
2438     if first.type == token.LPAR and last.type == token.RPAR:
2439         # make parentheses invisible
2440         first.value = ""  # type: ignore
2441         last.value = ""  # type: ignore
2442         if len(node.children) > 1:
2443             maybe_make_parens_invisible_in_atom(node.children[1])
2444         return True
2445
2446     return False
2447
2448
2449 def is_empty_tuple(node: LN) -> bool:
2450     """Return True if `node` holds an empty tuple."""
2451     return (
2452         node.type == syms.atom
2453         and len(node.children) == 2
2454         and node.children[0].type == token.LPAR
2455         and node.children[1].type == token.RPAR
2456     )
2457
2458
2459 def is_one_tuple(node: LN) -> bool:
2460     """Return True if `node` holds a tuple with one element, with or without parens."""
2461     if node.type == syms.atom:
2462         if len(node.children) != 3:
2463             return False
2464
2465         lpar, gexp, rpar = node.children
2466         if not (
2467             lpar.type == token.LPAR
2468             and gexp.type == syms.testlist_gexp
2469             and rpar.type == token.RPAR
2470         ):
2471             return False
2472
2473         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2474
2475     return (
2476         node.type in IMPLICIT_TUPLE
2477         and len(node.children) == 2
2478         and node.children[1].type == token.COMMA
2479     )
2480
2481
2482 def is_yield(node: LN) -> bool:
2483     """Return True if `node` holds a `yield` or `yield from` expression."""
2484     if node.type == syms.yield_expr:
2485         return True
2486
2487     if node.type == token.NAME and node.value == "yield":  # type: ignore
2488         return True
2489
2490     if node.type != syms.atom:
2491         return False
2492
2493     if len(node.children) != 3:
2494         return False
2495
2496     lpar, expr, rpar = node.children
2497     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2498         return is_yield(expr)
2499
2500     return False
2501
2502
2503 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2504     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2505
2506     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2507     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2508     extended iterable unpacking (PEP 3132) and additional unpacking
2509     generalizations (PEP 448).
2510     """
2511     if leaf.type not in STARS or not leaf.parent:
2512         return False
2513
2514     p = leaf.parent
2515     if p.type == syms.star_expr:
2516         # Star expressions are also used as assignment targets in extended
2517         # iterable unpacking (PEP 3132).  See what its parent is instead.
2518         if not p.parent:
2519             return False
2520
2521         p = p.parent
2522
2523     return p.type in within
2524
2525
2526 def is_multiline_string(leaf: Leaf) -> bool:
2527     """Return True if `leaf` is a multiline string that actually spans many lines."""
2528     value = leaf.value.lstrip("furbFURB")
2529     return value[:3] in {'"""', "'''"} and "\n" in value
2530
2531
2532 def is_stub_suite(node: Node) -> bool:
2533     """Return True if `node` is a suite with a stub body."""
2534     if (
2535         len(node.children) != 4
2536         or node.children[0].type != token.NEWLINE
2537         or node.children[1].type != token.INDENT
2538         or node.children[3].type != token.DEDENT
2539     ):
2540         return False
2541
2542     return is_stub_body(node.children[2])
2543
2544
2545 def is_stub_body(node: LN) -> bool:
2546     """Return True if `node` is a simple statement containing an ellipsis."""
2547     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2548         return False
2549
2550     if len(node.children) != 2:
2551         return False
2552
2553     child = node.children[0]
2554     return (
2555         child.type == syms.atom
2556         and len(child.children) == 3
2557         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2558     )
2559
2560
2561 def max_delimiter_priority_in_atom(node: LN) -> int:
2562     """Return maximum delimiter priority inside `node`.
2563
2564     This is specific to atoms with contents contained in a pair of parentheses.
2565     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2566     """
2567     if node.type != syms.atom:
2568         return 0
2569
2570     first = node.children[0]
2571     last = node.children[-1]
2572     if not (first.type == token.LPAR and last.type == token.RPAR):
2573         return 0
2574
2575     bt = BracketTracker()
2576     for c in node.children[1:-1]:
2577         if isinstance(c, Leaf):
2578             bt.mark(c)
2579         else:
2580             for leaf in c.leaves():
2581                 bt.mark(leaf)
2582     try:
2583         return bt.max_delimiter_priority()
2584
2585     except ValueError:
2586         return 0
2587
2588
2589 def ensure_visible(leaf: Leaf) -> None:
2590     """Make sure parentheses are visible.
2591
2592     They could be invisible as part of some statements (see
2593     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2594     """
2595     if leaf.type == token.LPAR:
2596         leaf.value = "("
2597     elif leaf.type == token.RPAR:
2598         leaf.value = ")"
2599
2600
2601 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2602     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2603     if not (
2604         opening_bracket.parent
2605         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2606         and opening_bracket.value in "[{("
2607     ):
2608         return False
2609
2610     try:
2611         last_leaf = line.leaves[-1]
2612         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2613         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2614     except (IndexError, ValueError):
2615         return False
2616
2617     return max_priority == COMMA_PRIORITY
2618
2619
2620 def is_python36(node: Node) -> bool:
2621     """Return True if the current file is using Python 3.6+ features.
2622
2623     Currently looking for:
2624     - f-strings; and
2625     - trailing commas after * or ** in function signatures and calls.
2626     """
2627     for n in node.pre_order():
2628         if n.type == token.STRING:
2629             value_head = n.value[:2]  # type: ignore
2630             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2631                 return True
2632
2633         elif (
2634             n.type in {syms.typedargslist, syms.arglist}
2635             and n.children
2636             and n.children[-1].type == token.COMMA
2637         ):
2638             for ch in n.children:
2639                 if ch.type in STARS:
2640                     return True
2641
2642                 if ch.type == syms.argument:
2643                     for argch in ch.children:
2644                         if argch.type in STARS:
2645                             return True
2646
2647     return False
2648
2649
2650 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2651     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2652
2653     Brackets can be omitted if the entire trailer up to and including
2654     a preceding closing bracket fits in one line.
2655
2656     Yielded sets are cumulative (contain results of previous yields, too).  First
2657     set is empty.
2658     """
2659
2660     omit: Set[LeafID] = set()
2661     yield omit
2662
2663     length = 4 * line.depth
2664     opening_bracket = None
2665     closing_bracket = None
2666     optional_brackets: Set[LeafID] = set()
2667     inner_brackets: Set[LeafID] = set()
2668     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2669         length += leaf_length
2670         if length > line_length:
2671             break
2672
2673         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2674         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2675             break
2676
2677         optional_brackets.discard(id(leaf))
2678         if opening_bracket:
2679             if leaf is opening_bracket:
2680                 opening_bracket = None
2681             elif leaf.type in CLOSING_BRACKETS:
2682                 inner_brackets.add(id(leaf))
2683         elif leaf.type in CLOSING_BRACKETS:
2684             if not leaf.value:
2685                 optional_brackets.add(id(opening_bracket))
2686                 continue
2687
2688             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2689                 # Empty brackets would fail a split so treat them as "inner"
2690                 # brackets (e.g. only add them to the `omit` set if another
2691                 # pair of brackets was good enough.
2692                 inner_brackets.add(id(leaf))
2693                 continue
2694
2695             opening_bracket = leaf.opening_bracket
2696             if closing_bracket:
2697                 omit.add(id(closing_bracket))
2698                 omit.update(inner_brackets)
2699                 inner_brackets.clear()
2700                 yield omit
2701             closing_bracket = leaf
2702
2703
2704 def get_future_imports(node: Node) -> Set[str]:
2705     """Return a set of __future__ imports in the file."""
2706     imports = set()
2707     for child in node.children:
2708         if child.type != syms.simple_stmt:
2709             break
2710         first_child = child.children[0]
2711         if isinstance(first_child, Leaf):
2712             # Continue looking if we see a docstring; otherwise stop.
2713             if (
2714                 len(child.children) == 2
2715                 and first_child.type == token.STRING
2716                 and child.children[1].type == token.NEWLINE
2717             ):
2718                 continue
2719             else:
2720                 break
2721         elif first_child.type == syms.import_from:
2722             module_name = first_child.children[1]
2723             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2724                 break
2725             for import_from_child in first_child.children[3:]:
2726                 if isinstance(import_from_child, Leaf):
2727                     if import_from_child.type == token.NAME:
2728                         imports.add(import_from_child.value)
2729                 else:
2730                     assert import_from_child.type == syms.import_as_names
2731                     for leaf in import_from_child.children:
2732                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2733                             imports.add(leaf.value)
2734         else:
2735             break
2736     return imports
2737
2738
2739 PYTHON_EXTENSIONS = {".py", ".pyi"}
2740 BLACKLISTED_DIRECTORIES = {
2741     "build",
2742     "buck-out",
2743     "dist",
2744     "_build",
2745     ".git",
2746     ".hg",
2747     ".mypy_cache",
2748     ".tox",
2749     ".venv",
2750 }
2751
2752
2753 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2754     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2755     and have one of the PYTHON_EXTENSIONS.
2756     """
2757     for child in path.iterdir():
2758         if child.is_dir():
2759             if child.name in BLACKLISTED_DIRECTORIES:
2760                 continue
2761
2762             yield from gen_python_files_in_dir(child)
2763
2764         elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
2765             yield child
2766
2767
2768 @dataclass
2769 class Report:
2770     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2771
2772     check: bool = False
2773     quiet: bool = False
2774     change_count: int = 0
2775     same_count: int = 0
2776     failure_count: int = 0
2777
2778     def done(self, src: Path, changed: Changed) -> None:
2779         """Increment the counter for successful reformatting. Write out a message."""
2780         if changed is Changed.YES:
2781             reformatted = "would reformat" if self.check else "reformatted"
2782             if not self.quiet:
2783                 out(f"{reformatted} {src}")
2784             self.change_count += 1
2785         else:
2786             if not self.quiet:
2787                 if changed is Changed.NO:
2788                     msg = f"{src} already well formatted, good job."
2789                 else:
2790                     msg = f"{src} wasn't modified on disk since last run."
2791                 out(msg, bold=False)
2792             self.same_count += 1
2793
2794     def failed(self, src: Path, message: str) -> None:
2795         """Increment the counter for failed reformatting. Write out a message."""
2796         err(f"error: cannot format {src}: {message}")
2797         self.failure_count += 1
2798
2799     @property
2800     def return_code(self) -> int:
2801         """Return the exit code that the app should use.
2802
2803         This considers the current state of changed files and failures:
2804         - if there were any failures, return 123;
2805         - if any files were changed and --check is being used, return 1;
2806         - otherwise return 0.
2807         """
2808         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2809         # 126 we have special returncodes reserved by the shell.
2810         if self.failure_count:
2811             return 123
2812
2813         elif self.change_count and self.check:
2814             return 1
2815
2816         return 0
2817
2818     def __str__(self) -> str:
2819         """Render a color report of the current state.
2820
2821         Use `click.unstyle` to remove colors.
2822         """
2823         if self.check:
2824             reformatted = "would be reformatted"
2825             unchanged = "would be left unchanged"
2826             failed = "would fail to reformat"
2827         else:
2828             reformatted = "reformatted"
2829             unchanged = "left unchanged"
2830             failed = "failed to reformat"
2831         report = []
2832         if self.change_count:
2833             s = "s" if self.change_count > 1 else ""
2834             report.append(
2835                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2836             )
2837         if self.same_count:
2838             s = "s" if self.same_count > 1 else ""
2839             report.append(f"{self.same_count} file{s} {unchanged}")
2840         if self.failure_count:
2841             s = "s" if self.failure_count > 1 else ""
2842             report.append(
2843                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2844             )
2845         return ", ".join(report) + "."
2846
2847
2848 def assert_equivalent(src: str, dst: str) -> None:
2849     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2850
2851     import ast
2852     import traceback
2853
2854     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2855         """Simple visitor generating strings to compare ASTs by content."""
2856         yield f"{'  ' * depth}{node.__class__.__name__}("
2857
2858         for field in sorted(node._fields):
2859             try:
2860                 value = getattr(node, field)
2861             except AttributeError:
2862                 continue
2863
2864             yield f"{'  ' * (depth+1)}{field}="
2865
2866             if isinstance(value, list):
2867                 for item in value:
2868                     if isinstance(item, ast.AST):
2869                         yield from _v(item, depth + 2)
2870
2871             elif isinstance(value, ast.AST):
2872                 yield from _v(value, depth + 2)
2873
2874             else:
2875                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2876
2877         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2878
2879     try:
2880         src_ast = ast.parse(src)
2881     except Exception as exc:
2882         major, minor = sys.version_info[:2]
2883         raise AssertionError(
2884             f"cannot use --safe with this file; failed to parse source file "
2885             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2886             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2887         )
2888
2889     try:
2890         dst_ast = ast.parse(dst)
2891     except Exception as exc:
2892         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2893         raise AssertionError(
2894             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2895             f"Please report a bug on https://github.com/ambv/black/issues.  "
2896             f"This invalid output might be helpful: {log}"
2897         ) from None
2898
2899     src_ast_str = "\n".join(_v(src_ast))
2900     dst_ast_str = "\n".join(_v(dst_ast))
2901     if src_ast_str != dst_ast_str:
2902         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2903         raise AssertionError(
2904             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2905             f"the source.  "
2906             f"Please report a bug on https://github.com/ambv/black/issues.  "
2907             f"This diff might be helpful: {log}"
2908         ) from None
2909
2910
2911 def assert_stable(
2912     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
2913 ) -> None:
2914     """Raise AssertionError if `dst` reformats differently the second time."""
2915     newdst = format_str(dst, line_length=line_length, mode=mode)
2916     if dst != newdst:
2917         log = dump_to_file(
2918             diff(src, dst, "source", "first pass"),
2919             diff(dst, newdst, "first pass", "second pass"),
2920         )
2921         raise AssertionError(
2922             f"INTERNAL ERROR: Black produced different code on the second pass "
2923             f"of the formatter.  "
2924             f"Please report a bug on https://github.com/ambv/black/issues.  "
2925             f"This diff might be helpful: {log}"
2926         ) from None
2927
2928
2929 def dump_to_file(*output: str) -> str:
2930     """Dump `output` to a temporary file. Return path to the file."""
2931     import tempfile
2932
2933     with tempfile.NamedTemporaryFile(
2934         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2935     ) as f:
2936         for lines in output:
2937             f.write(lines)
2938             if lines and lines[-1] != "\n":
2939                 f.write("\n")
2940     return f.name
2941
2942
2943 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2944     """Return a unified diff string between strings `a` and `b`."""
2945     import difflib
2946
2947     a_lines = [line + "\n" for line in a.split("\n")]
2948     b_lines = [line + "\n" for line in b.split("\n")]
2949     return "".join(
2950         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2951     )
2952
2953
2954 def cancel(tasks: Iterable[asyncio.Task]) -> None:
2955     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2956     err("Aborted!")
2957     for task in tasks:
2958         task.cancel()
2959
2960
2961 def shutdown(loop: BaseEventLoop) -> None:
2962     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2963     try:
2964         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2965         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2966         if not to_cancel:
2967             return
2968
2969         for task in to_cancel:
2970             task.cancel()
2971         loop.run_until_complete(
2972             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2973         )
2974     finally:
2975         # `concurrent.futures.Future` objects cannot be cancelled once they
2976         # are already running. There might be some when the `shutdown()` happened.
2977         # Silence their logger's spew about the event loop being closed.
2978         cf_logger = logging.getLogger("concurrent.futures")
2979         cf_logger.setLevel(logging.CRITICAL)
2980         loop.close()
2981
2982
2983 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2984     """Replace `regex` with `replacement` twice on `original`.
2985
2986     This is used by string normalization to perform replaces on
2987     overlapping matches.
2988     """
2989     return regex.sub(replacement, regex.sub(replacement, original))
2990
2991
2992 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
2993     """Like `reversed(enumerate(sequence))` if that were possible."""
2994     index = len(sequence) - 1
2995     for element in reversed(sequence):
2996         yield (index, element)
2997         index -= 1
2998
2999
3000 def enumerate_with_length(
3001     line: Line, reversed: bool = False
3002 ) -> Iterator[Tuple[Index, Leaf, int]]:
3003     """Return an enumeration of leaves with their length.
3004
3005     Stops prematurely on multiline strings and standalone comments.
3006     """
3007     op = cast(
3008         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3009         enumerate_reversed if reversed else enumerate,
3010     )
3011     for index, leaf in op(line.leaves):
3012         length = len(leaf.prefix) + len(leaf.value)
3013         if "\n" in leaf.value:
3014             return  # Multiline strings, we can't continue.
3015
3016         comment: Optional[Leaf]
3017         for comment in line.comments_after(leaf, index):
3018             length += len(comment.value)
3019
3020         yield index, leaf, length
3021
3022
3023 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3024     """Return True if `line` is no longer than `line_length`.
3025
3026     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3027     """
3028     if not line_str:
3029         line_str = str(line).strip("\n")
3030     return (
3031         len(line_str) <= line_length
3032         and "\n" not in line_str  # multiline strings
3033         and not line.contains_standalone_comments()
3034     )
3035
3036
3037 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3038     """Does `line` have a shape safe to reformat without optional parens around it?
3039
3040     Returns True for only a subset of potentially nice looking formattings but
3041     the point is to not return false positives that end up producing lines that
3042     are too long.
3043     """
3044     bt = line.bracket_tracker
3045     if not bt.delimiters:
3046         # Without delimiters the optional parentheses are useless.
3047         return True
3048
3049     max_priority = bt.max_delimiter_priority()
3050     if bt.delimiter_count_with_priority(max_priority) > 1:
3051         # With more than one delimiter of a kind the optional parentheses read better.
3052         return False
3053
3054     if max_priority == DOT_PRIORITY:
3055         # A single stranded method call doesn't require optional parentheses.
3056         return True
3057
3058     assert len(line.leaves) >= 2, "Stranded delimiter"
3059
3060     first = line.leaves[0]
3061     second = line.leaves[1]
3062     penultimate = line.leaves[-2]
3063     last = line.leaves[-1]
3064
3065     # With a single delimiter, omit if the expression starts or ends with
3066     # a bracket.
3067     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3068         remainder = False
3069         length = 4 * line.depth
3070         for _index, leaf, leaf_length in enumerate_with_length(line):
3071             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3072                 remainder = True
3073             if remainder:
3074                 length += leaf_length
3075                 if length > line_length:
3076                     break
3077
3078                 if leaf.type in OPENING_BRACKETS:
3079                     # There are brackets we can further split on.
3080                     remainder = False
3081
3082         else:
3083             # checked the entire string and line length wasn't exceeded
3084             if len(line.leaves) == _index + 1:
3085                 return True
3086
3087         # Note: we are not returning False here because a line might have *both*
3088         # a leading opening bracket and a trailing closing bracket.  If the
3089         # opening bracket doesn't match our rule, maybe the closing will.
3090
3091     if (
3092         last.type == token.RPAR
3093         or last.type == token.RBRACE
3094         or (
3095             # don't use indexing for omitting optional parentheses;
3096             # it looks weird
3097             last.type == token.RSQB
3098             and last.parent
3099             and last.parent.type != syms.trailer
3100         )
3101     ):
3102         if penultimate.type in OPENING_BRACKETS:
3103             # Empty brackets don't help.
3104             return False
3105
3106         if is_multiline_string(first):
3107             # Additional wrapping of a multiline string in this situation is
3108             # unnecessary.
3109             return True
3110
3111         length = 4 * line.depth
3112         seen_other_brackets = False
3113         for _index, leaf, leaf_length in enumerate_with_length(line):
3114             length += leaf_length
3115             if leaf is last.opening_bracket:
3116                 if seen_other_brackets or length <= line_length:
3117                     return True
3118
3119             elif leaf.type in OPENING_BRACKETS:
3120                 # There are brackets we can further split on.
3121                 seen_other_brackets = True
3122
3123     return False
3124
3125
3126 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3127     pyi = bool(mode & FileMode.PYI)
3128     py36 = bool(mode & FileMode.PYTHON36)
3129     return (
3130         CACHE_DIR
3131         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3132     )
3133
3134
3135 def read_cache(line_length: int, mode: FileMode) -> Cache:
3136     """Read the cache if it exists and is well formed.
3137
3138     If it is not well formed, the call to write_cache later should resolve the issue.
3139     """
3140     cache_file = get_cache_file(line_length, mode)
3141     if not cache_file.exists():
3142         return {}
3143
3144     with cache_file.open("rb") as fobj:
3145         try:
3146             cache: Cache = pickle.load(fobj)
3147         except pickle.UnpicklingError:
3148             return {}
3149
3150     return cache
3151
3152
3153 def get_cache_info(path: Path) -> CacheInfo:
3154     """Return the information used to check if a file is already formatted or not."""
3155     stat = path.stat()
3156     return stat.st_mtime, stat.st_size
3157
3158
3159 def filter_cached(
3160     cache: Cache, sources: Iterable[Path]
3161 ) -> Tuple[List[Path], List[Path]]:
3162     """Split a list of paths into two.
3163
3164     The first list contains paths of files that modified on disk or are not in the
3165     cache. The other list contains paths to non-modified files.
3166     """
3167     todo, done = [], []
3168     for src in sources:
3169         src = src.resolve()
3170         if cache.get(src) != get_cache_info(src):
3171             todo.append(src)
3172         else:
3173             done.append(src)
3174     return todo, done
3175
3176
3177 def write_cache(
3178     cache: Cache, sources: List[Path], line_length: int, mode: FileMode
3179 ) -> None:
3180     """Update the cache file."""
3181     cache_file = get_cache_file(line_length, mode)
3182     try:
3183         if not CACHE_DIR.exists():
3184             CACHE_DIR.mkdir(parents=True)
3185         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3186         with cache_file.open("wb") as fobj:
3187             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3188     except OSError:
3189         pass
3190
3191
3192 if __name__ == "__main__":
3193     main()