]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Introduce "project root" as a concept
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum, Flag
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33     cast,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46
47 __version__ = "18.5b1"
48 DEFAULT_LINE_LENGTH = 88
49 DEFAULT_EXCLUDES = (
50     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
51 )
52 DEFAULT_INCLUDES = r"\.pyi?$"
53 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
54
55
56 # types
57 FileContent = str
58 Encoding = str
59 Depth = int
60 NodeType = int
61 LeafID = int
62 Priority = int
63 Index = int
64 LN = Union[Leaf, Node]
65 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
66 Timestamp = float
67 FileSize = int
68 CacheInfo = Tuple[Timestamp, FileSize]
69 Cache = Dict[Path, CacheInfo]
70 out = partial(click.secho, bold=True, err=True)
71 err = partial(click.secho, fg="red", err=True)
72
73 pygram.initialize(CACHE_DIR)
74 syms = pygram.python_symbols
75
76
77 class NothingChanged(UserWarning):
78     """Raised by :func:`format_file` when reformatted code is the same as source."""
79
80
81 class CannotSplit(Exception):
82     """A readable split that fits the allotted line length is impossible.
83
84     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
85     :func:`delimiter_split`.
86     """
87
88
89 class FormatError(Exception):
90     """Base exception for `# fmt: on` and `# fmt: off` handling.
91
92     It holds the number of bytes of the prefix consumed before the format
93     control comment appeared.
94     """
95
96     def __init__(self, consumed: int) -> None:
97         super().__init__(consumed)
98         self.consumed = consumed
99
100     def trim_prefix(self, leaf: Leaf) -> None:
101         leaf.prefix = leaf.prefix[self.consumed :]
102
103     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
104         """Returns a new Leaf from the consumed part of the prefix."""
105         unformatted_prefix = leaf.prefix[: self.consumed]
106         return Leaf(token.NEWLINE, unformatted_prefix)
107
108
109 class FormatOn(FormatError):
110     """Found a comment like `# fmt: on` in the file."""
111
112
113 class FormatOff(FormatError):
114     """Found a comment like `# fmt: off` in the file."""
115
116
117 class WriteBack(Enum):
118     NO = 0
119     YES = 1
120     DIFF = 2
121
122
123 class Changed(Enum):
124     NO = 0
125     CACHED = 1
126     YES = 2
127
128
129 class FileMode(Flag):
130     AUTO_DETECT = 0
131     PYTHON36 = 1
132     PYI = 2
133     NO_STRING_NORMALIZATION = 4
134
135
136 @click.command()
137 @click.option(
138     "-l",
139     "--line-length",
140     type=int,
141     default=DEFAULT_LINE_LENGTH,
142     help="How many character per line to allow.",
143     show_default=True,
144 )
145 @click.option(
146     "--py36",
147     is_flag=True,
148     help=(
149         "Allow using Python 3.6-only syntax on all input files.  This will put "
150         "trailing commas in function signatures and calls also after *args and "
151         "**kwargs.  [default: per-file auto-detection]"
152     ),
153 )
154 @click.option(
155     "--pyi",
156     is_flag=True,
157     help=(
158         "Format all input files like typing stubs regardless of file extension "
159         "(useful when piping source on standard input)."
160     ),
161 )
162 @click.option(
163     "-S",
164     "--skip-string-normalization",
165     is_flag=True,
166     help="Don't normalize string quotes or prefixes.",
167 )
168 @click.option(
169     "--check",
170     is_flag=True,
171     help=(
172         "Don't write the files back, just return the status.  Return code 0 "
173         "means nothing would change.  Return code 1 means some files would be "
174         "reformatted.  Return code 123 means there was an internal error."
175     ),
176 )
177 @click.option(
178     "--diff",
179     is_flag=True,
180     help="Don't write the files back, just output a diff for each file on stdout.",
181 )
182 @click.option(
183     "--fast/--safe",
184     is_flag=True,
185     help="If --fast given, skip temporary sanity checks. [default: --safe]",
186 )
187 @click.option(
188     "--include",
189     type=str,
190     default=DEFAULT_INCLUDES,
191     help=(
192         "A regular expression that matches files and directories that should be "
193         "included on recursive searches.  An empty value means all files are "
194         "included regardless of the name.  Use forward slashes for directories on "
195         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
196         "later."
197     ),
198     show_default=True,
199 )
200 @click.option(
201     "--exclude",
202     type=str,
203     default=DEFAULT_EXCLUDES,
204     help=(
205         "A regular expression that matches files and directories that should be "
206         "excluded on recursive searches.  An empty value means no paths are excluded. "
207         "Use forward slashes for directories on all platforms (Windows, too).  "
208         "Exclusions are calculated first, inclusions later."
209     ),
210     show_default=True,
211 )
212 @click.option(
213     "-q",
214     "--quiet",
215     is_flag=True,
216     help=(
217         "Don't emit non-error messages to stderr. Errors are still emitted, "
218         "silence those with 2>/dev/null."
219     ),
220 )
221 @click.version_option(version=__version__)
222 @click.argument(
223     "src",
224     nargs=-1,
225     type=click.Path(
226         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
227     ),
228 )
229 @click.pass_context
230 def main(
231     ctx: click.Context,
232     line_length: int,
233     check: bool,
234     diff: bool,
235     fast: bool,
236     pyi: bool,
237     py36: bool,
238     skip_string_normalization: bool,
239     quiet: bool,
240     include: str,
241     exclude: str,
242     src: List[str],
243 ) -> None:
244     """The uncompromising code formatter."""
245     sources: List[Path] = []
246     try:
247         include_regex = re.compile(include)
248     except re.error:
249         err(f"Invalid regular expression for include given: {include!r}")
250         ctx.exit(2)
251     try:
252         exclude_regex = re.compile(exclude)
253     except re.error:
254         err(f"Invalid regular expression for exclude given: {exclude!r}")
255         ctx.exit(2)
256     root = find_project_root(src)
257     for s in src:
258         p = Path(s)
259         if p.is_dir():
260             sources.extend(
261                 gen_python_files_in_dir(p, root, include_regex, exclude_regex)
262             )
263         elif p.is_file():
264             # if a file was explicitly given, we don't care about its extension
265             sources.append(p)
266         elif s == "-":
267             sources.append(Path("-"))
268         else:
269             err(f"invalid path: {s}")
270
271     if check and not diff:
272         write_back = WriteBack.NO
273     elif diff:
274         write_back = WriteBack.DIFF
275     else:
276         write_back = WriteBack.YES
277     mode = FileMode.AUTO_DETECT
278     if py36:
279         mode |= FileMode.PYTHON36
280     if pyi:
281         mode |= FileMode.PYI
282     if skip_string_normalization:
283         mode |= FileMode.NO_STRING_NORMALIZATION
284     report = Report(check=check, quiet=quiet)
285     if len(sources) == 0:
286         out("No paths given. Nothing to do 😴")
287         ctx.exit(0)
288         return
289
290     elif len(sources) == 1:
291         reformat_one(
292             src=sources[0],
293             line_length=line_length,
294             fast=fast,
295             write_back=write_back,
296             mode=mode,
297             report=report,
298         )
299     else:
300         loop = asyncio.get_event_loop()
301         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
302         try:
303             loop.run_until_complete(
304                 schedule_formatting(
305                     sources=sources,
306                     line_length=line_length,
307                     fast=fast,
308                     write_back=write_back,
309                     mode=mode,
310                     report=report,
311                     loop=loop,
312                     executor=executor,
313                 )
314             )
315         finally:
316             shutdown(loop)
317         if not quiet:
318             out("All done! ✨ 🍰 ✨")
319             click.echo(str(report))
320     ctx.exit(report.return_code)
321
322
323 def reformat_one(
324     src: Path,
325     line_length: int,
326     fast: bool,
327     write_back: WriteBack,
328     mode: FileMode,
329     report: "Report",
330 ) -> None:
331     """Reformat a single file under `src` without spawning child processes.
332
333     If `quiet` is True, non-error messages are not output. `line_length`,
334     `write_back`, `fast` and `pyi` options are passed to
335     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
336     """
337     try:
338         changed = Changed.NO
339         if not src.is_file() and str(src) == "-":
340             if format_stdin_to_stdout(
341                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
342             ):
343                 changed = Changed.YES
344         else:
345             cache: Cache = {}
346             if write_back != WriteBack.DIFF:
347                 cache = read_cache(line_length, mode)
348                 res_src = src.resolve()
349                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
350                     changed = Changed.CACHED
351             if changed is not Changed.CACHED and format_file_in_place(
352                 src,
353                 line_length=line_length,
354                 fast=fast,
355                 write_back=write_back,
356                 mode=mode,
357             ):
358                 changed = Changed.YES
359             if write_back == WriteBack.YES and changed is not Changed.NO:
360                 write_cache(cache, [src], line_length, mode)
361         report.done(src, changed)
362     except Exception as exc:
363         report.failed(src, str(exc))
364
365
366 async def schedule_formatting(
367     sources: List[Path],
368     line_length: int,
369     fast: bool,
370     write_back: WriteBack,
371     mode: FileMode,
372     report: "Report",
373     loop: BaseEventLoop,
374     executor: Executor,
375 ) -> None:
376     """Run formatting of `sources` in parallel using the provided `executor`.
377
378     (Use ProcessPoolExecutors for actual parallelism.)
379
380     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
381     :func:`format_file_in_place`.
382     """
383     cache: Cache = {}
384     if write_back != WriteBack.DIFF:
385         cache = read_cache(line_length, mode)
386         sources, cached = filter_cached(cache, sources)
387         for src in cached:
388             report.done(src, Changed.CACHED)
389     cancelled = []
390     formatted = []
391     if sources:
392         lock = None
393         if write_back == WriteBack.DIFF:
394             # For diff output, we need locks to ensure we don't interleave output
395             # from different processes.
396             manager = Manager()
397             lock = manager.Lock()
398         tasks = {
399             loop.run_in_executor(
400                 executor,
401                 format_file_in_place,
402                 src,
403                 line_length,
404                 fast,
405                 write_back,
406                 mode,
407                 lock,
408             ): src
409             for src in sorted(sources)
410         }
411         pending: Iterable[asyncio.Task] = tasks.keys()
412         try:
413             loop.add_signal_handler(signal.SIGINT, cancel, pending)
414             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
415         except NotImplementedError:
416             # There are no good alternatives for these on Windows
417             pass
418         while pending:
419             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
420             for task in done:
421                 src = tasks.pop(task)
422                 if task.cancelled():
423                     cancelled.append(task)
424                 elif task.exception():
425                     report.failed(src, str(task.exception()))
426                 else:
427                     formatted.append(src)
428                     report.done(src, Changed.YES if task.result() else Changed.NO)
429     if cancelled:
430         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
431     if write_back == WriteBack.YES and formatted:
432         write_cache(cache, formatted, line_length, mode)
433
434
435 def format_file_in_place(
436     src: Path,
437     line_length: int,
438     fast: bool,
439     write_back: WriteBack = WriteBack.NO,
440     mode: FileMode = FileMode.AUTO_DETECT,
441     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
442 ) -> bool:
443     """Format file under `src` path. Return True if changed.
444
445     If `write_back` is True, write reformatted code back to stdout.
446     `line_length` and `fast` options are passed to :func:`format_file_contents`.
447     """
448     if src.suffix == ".pyi":
449         mode |= FileMode.PYI
450     with tokenize.open(src) as src_buffer:
451         src_contents = src_buffer.read()
452     try:
453         dst_contents = format_file_contents(
454             src_contents, line_length=line_length, fast=fast, mode=mode
455         )
456     except NothingChanged:
457         return False
458
459     if write_back == write_back.YES:
460         with open(src, "w", encoding=src_buffer.encoding) as f:
461             f.write(dst_contents)
462     elif write_back == write_back.DIFF:
463         src_name = f"{src}  (original)"
464         dst_name = f"{src}  (formatted)"
465         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
466         if lock:
467             lock.acquire()
468         try:
469             sys.stdout.write(diff_contents)
470         finally:
471             if lock:
472                 lock.release()
473     return True
474
475
476 def format_stdin_to_stdout(
477     line_length: int,
478     fast: bool,
479     write_back: WriteBack = WriteBack.NO,
480     mode: FileMode = FileMode.AUTO_DETECT,
481 ) -> bool:
482     """Format file on stdin. Return True if changed.
483
484     If `write_back` is True, write reformatted code back to stdout.
485     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
486     :func:`format_file_contents`.
487     """
488     src = sys.stdin.read()
489     dst = src
490     try:
491         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
492         return True
493
494     except NothingChanged:
495         return False
496
497     finally:
498         if write_back == WriteBack.YES:
499             sys.stdout.write(dst)
500         elif write_back == WriteBack.DIFF:
501             src_name = "<stdin>  (original)"
502             dst_name = "<stdin>  (formatted)"
503             sys.stdout.write(diff(src, dst, src_name, dst_name))
504
505
506 def format_file_contents(
507     src_contents: str,
508     *,
509     line_length: int,
510     fast: bool,
511     mode: FileMode = FileMode.AUTO_DETECT,
512 ) -> FileContent:
513     """Reformat contents a file and return new contents.
514
515     If `fast` is False, additionally confirm that the reformatted code is
516     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
517     `line_length` is passed to :func:`format_str`.
518     """
519     if src_contents.strip() == "":
520         raise NothingChanged
521
522     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
523     if src_contents == dst_contents:
524         raise NothingChanged
525
526     if not fast:
527         assert_equivalent(src_contents, dst_contents)
528         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
529     return dst_contents
530
531
532 def format_str(
533     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
534 ) -> FileContent:
535     """Reformat a string and return new contents.
536
537     `line_length` determines how many characters per line are allowed.
538     """
539     src_node = lib2to3_parse(src_contents)
540     dst_contents = ""
541     future_imports = get_future_imports(src_node)
542     is_pyi = bool(mode & FileMode.PYI)
543     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
544     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
545     lines = LineGenerator(
546         remove_u_prefix=py36 or "unicode_literals" in future_imports,
547         is_pyi=is_pyi,
548         normalize_strings=normalize_strings,
549     )
550     elt = EmptyLineTracker(is_pyi=is_pyi)
551     empty_line = Line()
552     after = 0
553     for current_line in lines.visit(src_node):
554         for _ in range(after):
555             dst_contents += str(empty_line)
556         before, after = elt.maybe_empty_lines(current_line)
557         for _ in range(before):
558             dst_contents += str(empty_line)
559         for line in split_line(current_line, line_length=line_length, py36=py36):
560             dst_contents += str(line)
561     return dst_contents
562
563
564 GRAMMARS = [
565     pygram.python_grammar_no_print_statement_no_exec_statement,
566     pygram.python_grammar_no_print_statement,
567     pygram.python_grammar,
568 ]
569
570
571 def lib2to3_parse(src_txt: str) -> Node:
572     """Given a string with source, return the lib2to3 Node."""
573     grammar = pygram.python_grammar_no_print_statement
574     if src_txt[-1] != "\n":
575         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
576         src_txt += nl
577     for grammar in GRAMMARS:
578         drv = driver.Driver(grammar, pytree.convert)
579         try:
580             result = drv.parse_string(src_txt, True)
581             break
582
583         except ParseError as pe:
584             lineno, column = pe.context[1]
585             lines = src_txt.splitlines()
586             try:
587                 faulty_line = lines[lineno - 1]
588             except IndexError:
589                 faulty_line = "<line number missing in source>"
590             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
591     else:
592         raise exc from None
593
594     if isinstance(result, Leaf):
595         result = Node(syms.file_input, [result])
596     return result
597
598
599 def lib2to3_unparse(node: Node) -> str:
600     """Given a lib2to3 node, return its string representation."""
601     code = str(node)
602     return code
603
604
605 T = TypeVar("T")
606
607
608 class Visitor(Generic[T]):
609     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
610
611     def visit(self, node: LN) -> Iterator[T]:
612         """Main method to visit `node` and its children.
613
614         It tries to find a `visit_*()` method for the given `node.type`, like
615         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
616         If no dedicated `visit_*()` method is found, chooses `visit_default()`
617         instead.
618
619         Then yields objects of type `T` from the selected visitor.
620         """
621         if node.type < 256:
622             name = token.tok_name[node.type]
623         else:
624             name = type_repr(node.type)
625         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
626
627     def visit_default(self, node: LN) -> Iterator[T]:
628         """Default `visit_*()` implementation. Recurses to children of `node`."""
629         if isinstance(node, Node):
630             for child in node.children:
631                 yield from self.visit(child)
632
633
634 @dataclass
635 class DebugVisitor(Visitor[T]):
636     tree_depth: int = 0
637
638     def visit_default(self, node: LN) -> Iterator[T]:
639         indent = " " * (2 * self.tree_depth)
640         if isinstance(node, Node):
641             _type = type_repr(node.type)
642             out(f"{indent}{_type}", fg="yellow")
643             self.tree_depth += 1
644             for child in node.children:
645                 yield from self.visit(child)
646
647             self.tree_depth -= 1
648             out(f"{indent}/{_type}", fg="yellow", bold=False)
649         else:
650             _type = token.tok_name.get(node.type, str(node.type))
651             out(f"{indent}{_type}", fg="blue", nl=False)
652             if node.prefix:
653                 # We don't have to handle prefixes for `Node` objects since
654                 # that delegates to the first child anyway.
655                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
656             out(f" {node.value!r}", fg="blue", bold=False)
657
658     @classmethod
659     def show(cls, code: str) -> None:
660         """Pretty-print the lib2to3 AST of a given string of `code`.
661
662         Convenience method for debugging.
663         """
664         v: DebugVisitor[None] = DebugVisitor()
665         list(v.visit(lib2to3_parse(code)))
666
667
668 KEYWORDS = set(keyword.kwlist)
669 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
670 FLOW_CONTROL = {"return", "raise", "break", "continue"}
671 STATEMENT = {
672     syms.if_stmt,
673     syms.while_stmt,
674     syms.for_stmt,
675     syms.try_stmt,
676     syms.except_clause,
677     syms.with_stmt,
678     syms.funcdef,
679     syms.classdef,
680 }
681 STANDALONE_COMMENT = 153
682 LOGIC_OPERATORS = {"and", "or"}
683 COMPARATORS = {
684     token.LESS,
685     token.GREATER,
686     token.EQEQUAL,
687     token.NOTEQUAL,
688     token.LESSEQUAL,
689     token.GREATEREQUAL,
690 }
691 MATH_OPERATORS = {
692     token.VBAR,
693     token.CIRCUMFLEX,
694     token.AMPER,
695     token.LEFTSHIFT,
696     token.RIGHTSHIFT,
697     token.PLUS,
698     token.MINUS,
699     token.STAR,
700     token.SLASH,
701     token.DOUBLESLASH,
702     token.PERCENT,
703     token.AT,
704     token.TILDE,
705     token.DOUBLESTAR,
706 }
707 STARS = {token.STAR, token.DOUBLESTAR}
708 VARARGS_PARENTS = {
709     syms.arglist,
710     syms.argument,  # double star in arglist
711     syms.trailer,  # single argument to call
712     syms.typedargslist,
713     syms.varargslist,  # lambdas
714 }
715 UNPACKING_PARENTS = {
716     syms.atom,  # single element of a list or set literal
717     syms.dictsetmaker,
718     syms.listmaker,
719     syms.testlist_gexp,
720 }
721 TEST_DESCENDANTS = {
722     syms.test,
723     syms.lambdef,
724     syms.or_test,
725     syms.and_test,
726     syms.not_test,
727     syms.comparison,
728     syms.star_expr,
729     syms.expr,
730     syms.xor_expr,
731     syms.and_expr,
732     syms.shift_expr,
733     syms.arith_expr,
734     syms.trailer,
735     syms.term,
736     syms.power,
737 }
738 ASSIGNMENTS = {
739     "=",
740     "+=",
741     "-=",
742     "*=",
743     "@=",
744     "/=",
745     "%=",
746     "&=",
747     "|=",
748     "^=",
749     "<<=",
750     ">>=",
751     "**=",
752     "//=",
753 }
754 COMPREHENSION_PRIORITY = 20
755 COMMA_PRIORITY = 18
756 TERNARY_PRIORITY = 16
757 LOGIC_PRIORITY = 14
758 STRING_PRIORITY = 12
759 COMPARATOR_PRIORITY = 10
760 MATH_PRIORITIES = {
761     token.VBAR: 9,
762     token.CIRCUMFLEX: 8,
763     token.AMPER: 7,
764     token.LEFTSHIFT: 6,
765     token.RIGHTSHIFT: 6,
766     token.PLUS: 5,
767     token.MINUS: 5,
768     token.STAR: 4,
769     token.SLASH: 4,
770     token.DOUBLESLASH: 4,
771     token.PERCENT: 4,
772     token.AT: 4,
773     token.TILDE: 3,
774     token.DOUBLESTAR: 2,
775 }
776 DOT_PRIORITY = 1
777
778
779 @dataclass
780 class BracketTracker:
781     """Keeps track of brackets on a line."""
782
783     depth: int = 0
784     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
785     delimiters: Dict[LeafID, Priority] = Factory(dict)
786     previous: Optional[Leaf] = None
787     _for_loop_variable: int = 0
788     _lambda_arguments: int = 0
789
790     def mark(self, leaf: Leaf) -> None:
791         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
792
793         All leaves receive an int `bracket_depth` field that stores how deep
794         within brackets a given leaf is. 0 means there are no enclosing brackets
795         that started on this line.
796
797         If a leaf is itself a closing bracket, it receives an `opening_bracket`
798         field that it forms a pair with. This is a one-directional link to
799         avoid reference cycles.
800
801         If a leaf is a delimiter (a token on which Black can split the line if
802         needed) and it's on depth 0, its `id()` is stored in the tracker's
803         `delimiters` field.
804         """
805         if leaf.type == token.COMMENT:
806             return
807
808         self.maybe_decrement_after_for_loop_variable(leaf)
809         self.maybe_decrement_after_lambda_arguments(leaf)
810         if leaf.type in CLOSING_BRACKETS:
811             self.depth -= 1
812             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
813             leaf.opening_bracket = opening_bracket
814         leaf.bracket_depth = self.depth
815         if self.depth == 0:
816             delim = is_split_before_delimiter(leaf, self.previous)
817             if delim and self.previous is not None:
818                 self.delimiters[id(self.previous)] = delim
819             else:
820                 delim = is_split_after_delimiter(leaf, self.previous)
821                 if delim:
822                     self.delimiters[id(leaf)] = delim
823         if leaf.type in OPENING_BRACKETS:
824             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
825             self.depth += 1
826         self.previous = leaf
827         self.maybe_increment_lambda_arguments(leaf)
828         self.maybe_increment_for_loop_variable(leaf)
829
830     def any_open_brackets(self) -> bool:
831         """Return True if there is an yet unmatched open bracket on the line."""
832         return bool(self.bracket_match)
833
834     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
835         """Return the highest priority of a delimiter found on the line.
836
837         Values are consistent with what `is_split_*_delimiter()` return.
838         Raises ValueError on no delimiters.
839         """
840         return max(v for k, v in self.delimiters.items() if k not in exclude)
841
842     def delimiter_count_with_priority(self, priority: int = 0) -> int:
843         """Return the number of delimiters with the given `priority`.
844
845         If no `priority` is passed, defaults to max priority on the line.
846         """
847         if not self.delimiters:
848             return 0
849
850         priority = priority or self.max_delimiter_priority()
851         return sum(1 for p in self.delimiters.values() if p == priority)
852
853     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
854         """In a for loop, or comprehension, the variables are often unpacks.
855
856         To avoid splitting on the comma in this situation, increase the depth of
857         tokens between `for` and `in`.
858         """
859         if leaf.type == token.NAME and leaf.value == "for":
860             self.depth += 1
861             self._for_loop_variable += 1
862             return True
863
864         return False
865
866     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
867         """See `maybe_increment_for_loop_variable` above for explanation."""
868         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
869             self.depth -= 1
870             self._for_loop_variable -= 1
871             return True
872
873         return False
874
875     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
876         """In a lambda expression, there might be more than one argument.
877
878         To avoid splitting on the comma in this situation, increase the depth of
879         tokens between `lambda` and `:`.
880         """
881         if leaf.type == token.NAME and leaf.value == "lambda":
882             self.depth += 1
883             self._lambda_arguments += 1
884             return True
885
886         return False
887
888     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
889         """See `maybe_increment_lambda_arguments` above for explanation."""
890         if self._lambda_arguments and leaf.type == token.COLON:
891             self.depth -= 1
892             self._lambda_arguments -= 1
893             return True
894
895         return False
896
897     def get_open_lsqb(self) -> Optional[Leaf]:
898         """Return the most recent opening square bracket (if any)."""
899         return self.bracket_match.get((self.depth - 1, token.RSQB))
900
901
902 @dataclass
903 class Line:
904     """Holds leaves and comments. Can be printed with `str(line)`."""
905
906     depth: int = 0
907     leaves: List[Leaf] = Factory(list)
908     comments: List[Tuple[Index, Leaf]] = Factory(list)
909     bracket_tracker: BracketTracker = Factory(BracketTracker)
910     inside_brackets: bool = False
911     should_explode: bool = False
912
913     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
914         """Add a new `leaf` to the end of the line.
915
916         Unless `preformatted` is True, the `leaf` will receive a new consistent
917         whitespace prefix and metadata applied by :class:`BracketTracker`.
918         Trailing commas are maybe removed, unpacked for loop variables are
919         demoted from being delimiters.
920
921         Inline comments are put aside.
922         """
923         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
924         if not has_value:
925             return
926
927         if token.COLON == leaf.type and self.is_class_paren_empty:
928             del self.leaves[-2:]
929         if self.leaves and not preformatted:
930             # Note: at this point leaf.prefix should be empty except for
931             # imports, for which we only preserve newlines.
932             leaf.prefix += whitespace(
933                 leaf, complex_subscript=self.is_complex_subscript(leaf)
934             )
935         if self.inside_brackets or not preformatted:
936             self.bracket_tracker.mark(leaf)
937             self.maybe_remove_trailing_comma(leaf)
938         if not self.append_comment(leaf):
939             self.leaves.append(leaf)
940
941     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
942         """Like :func:`append()` but disallow invalid standalone comment structure.
943
944         Raises ValueError when any `leaf` is appended after a standalone comment
945         or when a standalone comment is not the first leaf on the line.
946         """
947         if self.bracket_tracker.depth == 0:
948             if self.is_comment:
949                 raise ValueError("cannot append to standalone comments")
950
951             if self.leaves and leaf.type == STANDALONE_COMMENT:
952                 raise ValueError(
953                     "cannot append standalone comments to a populated line"
954                 )
955
956         self.append(leaf, preformatted=preformatted)
957
958     @property
959     def is_comment(self) -> bool:
960         """Is this line a standalone comment?"""
961         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
962
963     @property
964     def is_decorator(self) -> bool:
965         """Is this line a decorator?"""
966         return bool(self) and self.leaves[0].type == token.AT
967
968     @property
969     def is_import(self) -> bool:
970         """Is this an import line?"""
971         return bool(self) and is_import(self.leaves[0])
972
973     @property
974     def is_class(self) -> bool:
975         """Is this line a class definition?"""
976         return (
977             bool(self)
978             and self.leaves[0].type == token.NAME
979             and self.leaves[0].value == "class"
980         )
981
982     @property
983     def is_stub_class(self) -> bool:
984         """Is this line a class definition with a body consisting only of "..."?"""
985         return self.is_class and self.leaves[-3:] == [
986             Leaf(token.DOT, ".") for _ in range(3)
987         ]
988
989     @property
990     def is_def(self) -> bool:
991         """Is this a function definition? (Also returns True for async defs.)"""
992         try:
993             first_leaf = self.leaves[0]
994         except IndexError:
995             return False
996
997         try:
998             second_leaf: Optional[Leaf] = self.leaves[1]
999         except IndexError:
1000             second_leaf = None
1001         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1002             first_leaf.type == token.ASYNC
1003             and second_leaf is not None
1004             and second_leaf.type == token.NAME
1005             and second_leaf.value == "def"
1006         )
1007
1008     @property
1009     def is_class_paren_empty(self) -> bool:
1010         """Is this a class with no base classes but using parentheses?
1011
1012         Those are unnecessary and should be removed.
1013         """
1014         return (
1015             bool(self)
1016             and len(self.leaves) == 4
1017             and self.is_class
1018             and self.leaves[2].type == token.LPAR
1019             and self.leaves[2].value == "("
1020             and self.leaves[3].type == token.RPAR
1021             and self.leaves[3].value == ")"
1022         )
1023
1024     @property
1025     def is_triple_quoted_string(self) -> bool:
1026         """Is the line a triple quoted string?"""
1027         return (
1028             bool(self)
1029             and self.leaves[0].type == token.STRING
1030             and self.leaves[0].value.startswith(('"""', "'''"))
1031         )
1032
1033     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1034         """If so, needs to be split before emitting."""
1035         for leaf in self.leaves:
1036             if leaf.type == STANDALONE_COMMENT:
1037                 if leaf.bracket_depth <= depth_limit:
1038                     return True
1039
1040         return False
1041
1042     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1043         """Remove trailing comma if there is one and it's safe."""
1044         if not (
1045             self.leaves
1046             and self.leaves[-1].type == token.COMMA
1047             and closing.type in CLOSING_BRACKETS
1048         ):
1049             return False
1050
1051         if closing.type == token.RBRACE:
1052             self.remove_trailing_comma()
1053             return True
1054
1055         if closing.type == token.RSQB:
1056             comma = self.leaves[-1]
1057             if comma.parent and comma.parent.type == syms.listmaker:
1058                 self.remove_trailing_comma()
1059                 return True
1060
1061         # For parens let's check if it's safe to remove the comma.
1062         # Imports are always safe.
1063         if self.is_import:
1064             self.remove_trailing_comma()
1065             return True
1066
1067         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1068         # change a tuple into a different type by removing the comma.
1069         depth = closing.bracket_depth + 1
1070         commas = 0
1071         opening = closing.opening_bracket
1072         for _opening_index, leaf in enumerate(self.leaves):
1073             if leaf is opening:
1074                 break
1075
1076         else:
1077             return False
1078
1079         for leaf in self.leaves[_opening_index + 1 :]:
1080             if leaf is closing:
1081                 break
1082
1083             bracket_depth = leaf.bracket_depth
1084             if bracket_depth == depth and leaf.type == token.COMMA:
1085                 commas += 1
1086                 if leaf.parent and leaf.parent.type == syms.arglist:
1087                     commas += 1
1088                     break
1089
1090         if commas > 1:
1091             self.remove_trailing_comma()
1092             return True
1093
1094         return False
1095
1096     def append_comment(self, comment: Leaf) -> bool:
1097         """Add an inline or standalone comment to the line."""
1098         if (
1099             comment.type == STANDALONE_COMMENT
1100             and self.bracket_tracker.any_open_brackets()
1101         ):
1102             comment.prefix = ""
1103             return False
1104
1105         if comment.type != token.COMMENT:
1106             return False
1107
1108         after = len(self.leaves) - 1
1109         if after == -1:
1110             comment.type = STANDALONE_COMMENT
1111             comment.prefix = ""
1112             return False
1113
1114         else:
1115             self.comments.append((after, comment))
1116             return True
1117
1118     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1119         """Generate comments that should appear directly after `leaf`.
1120
1121         Provide a non-negative leaf `_index` to speed up the function.
1122         """
1123         if _index == -1:
1124             for _index, _leaf in enumerate(self.leaves):
1125                 if leaf is _leaf:
1126                     break
1127
1128             else:
1129                 return
1130
1131         for index, comment_after in self.comments:
1132             if _index == index:
1133                 yield comment_after
1134
1135     def remove_trailing_comma(self) -> None:
1136         """Remove the trailing comma and moves the comments attached to it."""
1137         comma_index = len(self.leaves) - 1
1138         for i in range(len(self.comments)):
1139             comment_index, comment = self.comments[i]
1140             if comment_index == comma_index:
1141                 self.comments[i] = (comma_index - 1, comment)
1142         self.leaves.pop()
1143
1144     def is_complex_subscript(self, leaf: Leaf) -> bool:
1145         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1146         open_lsqb = (
1147             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1148         )
1149         if open_lsqb is None:
1150             return False
1151
1152         subscript_start = open_lsqb.next_sibling
1153         if (
1154             isinstance(subscript_start, Node)
1155             and subscript_start.type == syms.subscriptlist
1156         ):
1157             subscript_start = child_towards(subscript_start, leaf)
1158         return subscript_start is not None and any(
1159             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1160         )
1161
1162     def __str__(self) -> str:
1163         """Render the line."""
1164         if not self:
1165             return "\n"
1166
1167         indent = "    " * self.depth
1168         leaves = iter(self.leaves)
1169         first = next(leaves)
1170         res = f"{first.prefix}{indent}{first.value}"
1171         for leaf in leaves:
1172             res += str(leaf)
1173         for _, comment in self.comments:
1174             res += str(comment)
1175         return res + "\n"
1176
1177     def __bool__(self) -> bool:
1178         """Return True if the line has leaves or comments."""
1179         return bool(self.leaves or self.comments)
1180
1181
1182 class UnformattedLines(Line):
1183     """Just like :class:`Line` but stores lines which aren't reformatted."""
1184
1185     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1186         """Just add a new `leaf` to the end of the lines.
1187
1188         The `preformatted` argument is ignored.
1189
1190         Keeps track of indentation `depth`, which is useful when the user
1191         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1192         """
1193         try:
1194             list(generate_comments(leaf))
1195         except FormatOn as f_on:
1196             self.leaves.append(f_on.leaf_from_consumed(leaf))
1197             raise
1198
1199         self.leaves.append(leaf)
1200         if leaf.type == token.INDENT:
1201             self.depth += 1
1202         elif leaf.type == token.DEDENT:
1203             self.depth -= 1
1204
1205     def __str__(self) -> str:
1206         """Render unformatted lines from leaves which were added with `append()`.
1207
1208         `depth` is not used for indentation in this case.
1209         """
1210         if not self:
1211             return "\n"
1212
1213         res = ""
1214         for leaf in self.leaves:
1215             res += str(leaf)
1216         return res
1217
1218     def append_comment(self, comment: Leaf) -> bool:
1219         """Not implemented in this class. Raises `NotImplementedError`."""
1220         raise NotImplementedError("Unformatted lines don't store comments separately.")
1221
1222     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1223         """Does nothing and returns False."""
1224         return False
1225
1226     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1227         """Does nothing and returns False."""
1228         return False
1229
1230
1231 @dataclass
1232 class EmptyLineTracker:
1233     """Provides a stateful method that returns the number of potential extra
1234     empty lines needed before and after the currently processed line.
1235
1236     Note: this tracker works on lines that haven't been split yet.  It assumes
1237     the prefix of the first leaf consists of optional newlines.  Those newlines
1238     are consumed by `maybe_empty_lines()` and included in the computation.
1239     """
1240
1241     is_pyi: bool = False
1242     previous_line: Optional[Line] = None
1243     previous_after: int = 0
1244     previous_defs: List[int] = Factory(list)
1245
1246     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1247         """Return the number of extra empty lines before and after the `current_line`.
1248
1249         This is for separating `def`, `async def` and `class` with extra empty
1250         lines (two on module-level).
1251         """
1252         if isinstance(current_line, UnformattedLines):
1253             return 0, 0
1254
1255         before, after = self._maybe_empty_lines(current_line)
1256         before -= self.previous_after
1257         self.previous_after = after
1258         self.previous_line = current_line
1259         return before, after
1260
1261     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1262         max_allowed = 1
1263         if current_line.depth == 0:
1264             max_allowed = 1 if self.is_pyi else 2
1265         if current_line.leaves:
1266             # Consume the first leaf's extra newlines.
1267             first_leaf = current_line.leaves[0]
1268             before = first_leaf.prefix.count("\n")
1269             before = min(before, max_allowed)
1270             first_leaf.prefix = ""
1271         else:
1272             before = 0
1273         depth = current_line.depth
1274         while self.previous_defs and self.previous_defs[-1] >= depth:
1275             self.previous_defs.pop()
1276             if self.is_pyi:
1277                 before = 0 if depth else 1
1278             else:
1279                 before = 1 if depth else 2
1280         is_decorator = current_line.is_decorator
1281         if is_decorator or current_line.is_def or current_line.is_class:
1282             if not is_decorator:
1283                 self.previous_defs.append(depth)
1284             if self.previous_line is None:
1285                 # Don't insert empty lines before the first line in the file.
1286                 return 0, 0
1287
1288             if self.previous_line.is_decorator:
1289                 return 0, 0
1290
1291             if self.previous_line.depth < current_line.depth and (
1292                 self.previous_line.is_class or self.previous_line.is_def
1293             ):
1294                 return 0, 0
1295
1296             if (
1297                 self.previous_line.is_comment
1298                 and self.previous_line.depth == current_line.depth
1299                 and before == 0
1300             ):
1301                 return 0, 0
1302
1303             if self.is_pyi:
1304                 if self.previous_line.depth > current_line.depth:
1305                     newlines = 1
1306                 elif current_line.is_class or self.previous_line.is_class:
1307                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1308                         newlines = 0
1309                     else:
1310                         newlines = 1
1311                 else:
1312                     newlines = 0
1313             else:
1314                 newlines = 2
1315             if current_line.depth and newlines:
1316                 newlines -= 1
1317             return newlines, 0
1318
1319         if (
1320             self.previous_line
1321             and self.previous_line.is_import
1322             and not current_line.is_import
1323             and depth == self.previous_line.depth
1324         ):
1325             return (before or 1), 0
1326
1327         if (
1328             self.previous_line
1329             and self.previous_line.is_class
1330             and current_line.is_triple_quoted_string
1331         ):
1332             return before, 1
1333
1334         return before, 0
1335
1336
1337 @dataclass
1338 class LineGenerator(Visitor[Line]):
1339     """Generates reformatted Line objects.  Empty lines are not emitted.
1340
1341     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1342     in ways that will no longer stringify to valid Python code on the tree.
1343     """
1344
1345     is_pyi: bool = False
1346     normalize_strings: bool = True
1347     current_line: Line = Factory(Line)
1348     remove_u_prefix: bool = False
1349
1350     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1351         """Generate a line.
1352
1353         If the line is empty, only emit if it makes sense.
1354         If the line is too long, split it first and then generate.
1355
1356         If any lines were generated, set up a new current_line.
1357         """
1358         if not self.current_line:
1359             if self.current_line.__class__ == type:
1360                 self.current_line.depth += indent
1361             else:
1362                 self.current_line = type(depth=self.current_line.depth + indent)
1363             return  # Line is empty, don't emit. Creating a new one unnecessary.
1364
1365         complete_line = self.current_line
1366         self.current_line = type(depth=complete_line.depth + indent)
1367         yield complete_line
1368
1369     def visit(self, node: LN) -> Iterator[Line]:
1370         """Main method to visit `node` and its children.
1371
1372         Yields :class:`Line` objects.
1373         """
1374         if isinstance(self.current_line, UnformattedLines):
1375             # File contained `# fmt: off`
1376             yield from self.visit_unformatted(node)
1377
1378         else:
1379             yield from super().visit(node)
1380
1381     def visit_default(self, node: LN) -> Iterator[Line]:
1382         """Default `visit_*()` implementation. Recurses to children of `node`."""
1383         if isinstance(node, Leaf):
1384             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1385             try:
1386                 for comment in generate_comments(node):
1387                     if any_open_brackets:
1388                         # any comment within brackets is subject to splitting
1389                         self.current_line.append(comment)
1390                     elif comment.type == token.COMMENT:
1391                         # regular trailing comment
1392                         self.current_line.append(comment)
1393                         yield from self.line()
1394
1395                     else:
1396                         # regular standalone comment
1397                         yield from self.line()
1398
1399                         self.current_line.append(comment)
1400                         yield from self.line()
1401
1402             except FormatOff as f_off:
1403                 f_off.trim_prefix(node)
1404                 yield from self.line(type=UnformattedLines)
1405                 yield from self.visit(node)
1406
1407             except FormatOn as f_on:
1408                 # This only happens here if somebody says "fmt: on" multiple
1409                 # times in a row.
1410                 f_on.trim_prefix(node)
1411                 yield from self.visit_default(node)
1412
1413             else:
1414                 normalize_prefix(node, inside_brackets=any_open_brackets)
1415                 if self.normalize_strings and node.type == token.STRING:
1416                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1417                     normalize_string_quotes(node)
1418                 if node.type not in WHITESPACE:
1419                     self.current_line.append(node)
1420         yield from super().visit_default(node)
1421
1422     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1423         """Increase indentation level, maybe yield a line."""
1424         # In blib2to3 INDENT never holds comments.
1425         yield from self.line(+1)
1426         yield from self.visit_default(node)
1427
1428     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1429         """Decrease indentation level, maybe yield a line."""
1430         # The current line might still wait for trailing comments.  At DEDENT time
1431         # there won't be any (they would be prefixes on the preceding NEWLINE).
1432         # Emit the line then.
1433         yield from self.line()
1434
1435         # While DEDENT has no value, its prefix may contain standalone comments
1436         # that belong to the current indentation level.  Get 'em.
1437         yield from self.visit_default(node)
1438
1439         # Finally, emit the dedent.
1440         yield from self.line(-1)
1441
1442     def visit_stmt(
1443         self, node: Node, keywords: Set[str], parens: Set[str]
1444     ) -> Iterator[Line]:
1445         """Visit a statement.
1446
1447         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1448         `def`, `with`, `class`, `assert` and assignments.
1449
1450         The relevant Python language `keywords` for a given statement will be
1451         NAME leaves within it. This methods puts those on a separate line.
1452
1453         `parens` holds a set of string leaf values immediately after which
1454         invisible parens should be put.
1455         """
1456         normalize_invisible_parens(node, parens_after=parens)
1457         for child in node.children:
1458             if child.type == token.NAME and child.value in keywords:  # type: ignore
1459                 yield from self.line()
1460
1461             yield from self.visit(child)
1462
1463     def visit_suite(self, node: Node) -> Iterator[Line]:
1464         """Visit a suite."""
1465         if self.is_pyi and is_stub_suite(node):
1466             yield from self.visit(node.children[2])
1467         else:
1468             yield from self.visit_default(node)
1469
1470     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1471         """Visit a statement without nested statements."""
1472         is_suite_like = node.parent and node.parent.type in STATEMENT
1473         if is_suite_like:
1474             if self.is_pyi and is_stub_body(node):
1475                 yield from self.visit_default(node)
1476             else:
1477                 yield from self.line(+1)
1478                 yield from self.visit_default(node)
1479                 yield from self.line(-1)
1480
1481         else:
1482             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1483                 yield from self.line()
1484             yield from self.visit_default(node)
1485
1486     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1487         """Visit `async def`, `async for`, `async with`."""
1488         yield from self.line()
1489
1490         children = iter(node.children)
1491         for child in children:
1492             yield from self.visit(child)
1493
1494             if child.type == token.ASYNC:
1495                 break
1496
1497         internal_stmt = next(children)
1498         for child in internal_stmt.children:
1499             yield from self.visit(child)
1500
1501     def visit_decorators(self, node: Node) -> Iterator[Line]:
1502         """Visit decorators."""
1503         for child in node.children:
1504             yield from self.line()
1505             yield from self.visit(child)
1506
1507     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1508         """Remove a semicolon and put the other statement on a separate line."""
1509         yield from self.line()
1510
1511     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1512         """End of file. Process outstanding comments and end with a newline."""
1513         yield from self.visit_default(leaf)
1514         yield from self.line()
1515
1516     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1517         """Used when file contained a `# fmt: off`."""
1518         if isinstance(node, Node):
1519             for child in node.children:
1520                 yield from self.visit(child)
1521
1522         else:
1523             try:
1524                 self.current_line.append(node)
1525             except FormatOn as f_on:
1526                 f_on.trim_prefix(node)
1527                 yield from self.line()
1528                 yield from self.visit(node)
1529
1530             if node.type == token.ENDMARKER:
1531                 # somebody decided not to put a final `# fmt: on`
1532                 yield from self.line()
1533
1534     def __attrs_post_init__(self) -> None:
1535         """You are in a twisty little maze of passages."""
1536         v = self.visit_stmt
1537         Ø: Set[str] = set()
1538         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1539         self.visit_if_stmt = partial(
1540             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1541         )
1542         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1543         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1544         self.visit_try_stmt = partial(
1545             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1546         )
1547         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1548         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1549         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1550         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1551         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1552         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1553         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1554         self.visit_async_funcdef = self.visit_async_stmt
1555         self.visit_decorated = self.visit_decorators
1556
1557
1558 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1559 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1560 OPENING_BRACKETS = set(BRACKET.keys())
1561 CLOSING_BRACKETS = set(BRACKET.values())
1562 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1563 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1564
1565
1566 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1567     """Return whitespace prefix if needed for the given `leaf`.
1568
1569     `complex_subscript` signals whether the given leaf is part of a subscription
1570     which has non-trivial arguments, like arithmetic expressions or function calls.
1571     """
1572     NO = ""
1573     SPACE = " "
1574     DOUBLESPACE = "  "
1575     t = leaf.type
1576     p = leaf.parent
1577     v = leaf.value
1578     if t in ALWAYS_NO_SPACE:
1579         return NO
1580
1581     if t == token.COMMENT:
1582         return DOUBLESPACE
1583
1584     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1585     if t == token.COLON and p.type not in {
1586         syms.subscript,
1587         syms.subscriptlist,
1588         syms.sliceop,
1589     }:
1590         return NO
1591
1592     prev = leaf.prev_sibling
1593     if not prev:
1594         prevp = preceding_leaf(p)
1595         if not prevp or prevp.type in OPENING_BRACKETS:
1596             return NO
1597
1598         if t == token.COLON:
1599             if prevp.type == token.COLON:
1600                 return NO
1601
1602             elif prevp.type != token.COMMA and not complex_subscript:
1603                 return NO
1604
1605             return SPACE
1606
1607         if prevp.type == token.EQUAL:
1608             if prevp.parent:
1609                 if prevp.parent.type in {
1610                     syms.arglist,
1611                     syms.argument,
1612                     syms.parameters,
1613                     syms.varargslist,
1614                 }:
1615                     return NO
1616
1617                 elif prevp.parent.type == syms.typedargslist:
1618                     # A bit hacky: if the equal sign has whitespace, it means we
1619                     # previously found it's a typed argument.  So, we're using
1620                     # that, too.
1621                     return prevp.prefix
1622
1623         elif prevp.type in STARS:
1624             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1625                 return NO
1626
1627         elif prevp.type == token.COLON:
1628             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1629                 return SPACE if complex_subscript else NO
1630
1631         elif (
1632             prevp.parent
1633             and prevp.parent.type == syms.factor
1634             and prevp.type in MATH_OPERATORS
1635         ):
1636             return NO
1637
1638         elif (
1639             prevp.type == token.RIGHTSHIFT
1640             and prevp.parent
1641             and prevp.parent.type == syms.shift_expr
1642             and prevp.prev_sibling
1643             and prevp.prev_sibling.type == token.NAME
1644             and prevp.prev_sibling.value == "print"  # type: ignore
1645         ):
1646             # Python 2 print chevron
1647             return NO
1648
1649     elif prev.type in OPENING_BRACKETS:
1650         return NO
1651
1652     if p.type in {syms.parameters, syms.arglist}:
1653         # untyped function signatures or calls
1654         if not prev or prev.type != token.COMMA:
1655             return NO
1656
1657     elif p.type == syms.varargslist:
1658         # lambdas
1659         if prev and prev.type != token.COMMA:
1660             return NO
1661
1662     elif p.type == syms.typedargslist:
1663         # typed function signatures
1664         if not prev:
1665             return NO
1666
1667         if t == token.EQUAL:
1668             if prev.type != syms.tname:
1669                 return NO
1670
1671         elif prev.type == token.EQUAL:
1672             # A bit hacky: if the equal sign has whitespace, it means we
1673             # previously found it's a typed argument.  So, we're using that, too.
1674             return prev.prefix
1675
1676         elif prev.type != token.COMMA:
1677             return NO
1678
1679     elif p.type == syms.tname:
1680         # type names
1681         if not prev:
1682             prevp = preceding_leaf(p)
1683             if not prevp or prevp.type != token.COMMA:
1684                 return NO
1685
1686     elif p.type == syms.trailer:
1687         # attributes and calls
1688         if t == token.LPAR or t == token.RPAR:
1689             return NO
1690
1691         if not prev:
1692             if t == token.DOT:
1693                 prevp = preceding_leaf(p)
1694                 if not prevp or prevp.type != token.NUMBER:
1695                     return NO
1696
1697             elif t == token.LSQB:
1698                 return NO
1699
1700         elif prev.type != token.COMMA:
1701             return NO
1702
1703     elif p.type == syms.argument:
1704         # single argument
1705         if t == token.EQUAL:
1706             return NO
1707
1708         if not prev:
1709             prevp = preceding_leaf(p)
1710             if not prevp or prevp.type == token.LPAR:
1711                 return NO
1712
1713         elif prev.type in {token.EQUAL} | STARS:
1714             return NO
1715
1716     elif p.type == syms.decorator:
1717         # decorators
1718         return NO
1719
1720     elif p.type == syms.dotted_name:
1721         if prev:
1722             return NO
1723
1724         prevp = preceding_leaf(p)
1725         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1726             return NO
1727
1728     elif p.type == syms.classdef:
1729         if t == token.LPAR:
1730             return NO
1731
1732         if prev and prev.type == token.LPAR:
1733             return NO
1734
1735     elif p.type in {syms.subscript, syms.sliceop}:
1736         # indexing
1737         if not prev:
1738             assert p.parent is not None, "subscripts are always parented"
1739             if p.parent.type == syms.subscriptlist:
1740                 return SPACE
1741
1742             return NO
1743
1744         elif not complex_subscript:
1745             return NO
1746
1747     elif p.type == syms.atom:
1748         if prev and t == token.DOT:
1749             # dots, but not the first one.
1750             return NO
1751
1752     elif p.type == syms.dictsetmaker:
1753         # dict unpacking
1754         if prev and prev.type == token.DOUBLESTAR:
1755             return NO
1756
1757     elif p.type in {syms.factor, syms.star_expr}:
1758         # unary ops
1759         if not prev:
1760             prevp = preceding_leaf(p)
1761             if not prevp or prevp.type in OPENING_BRACKETS:
1762                 return NO
1763
1764             prevp_parent = prevp.parent
1765             assert prevp_parent is not None
1766             if prevp.type == token.COLON and prevp_parent.type in {
1767                 syms.subscript,
1768                 syms.sliceop,
1769             }:
1770                 return NO
1771
1772             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1773                 return NO
1774
1775         elif t == token.NAME or t == token.NUMBER:
1776             return NO
1777
1778     elif p.type == syms.import_from:
1779         if t == token.DOT:
1780             if prev and prev.type == token.DOT:
1781                 return NO
1782
1783         elif t == token.NAME:
1784             if v == "import":
1785                 return SPACE
1786
1787             if prev and prev.type == token.DOT:
1788                 return NO
1789
1790     elif p.type == syms.sliceop:
1791         return NO
1792
1793     return SPACE
1794
1795
1796 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1797     """Return the first leaf that precedes `node`, if any."""
1798     while node:
1799         res = node.prev_sibling
1800         if res:
1801             if isinstance(res, Leaf):
1802                 return res
1803
1804             try:
1805                 return list(res.leaves())[-1]
1806
1807             except IndexError:
1808                 return None
1809
1810         node = node.parent
1811     return None
1812
1813
1814 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1815     """Return the child of `ancestor` that contains `descendant`."""
1816     node: Optional[LN] = descendant
1817     while node and node.parent != ancestor:
1818         node = node.parent
1819     return node
1820
1821
1822 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1823     """Return the priority of the `leaf` delimiter, given a line break after it.
1824
1825     The delimiter priorities returned here are from those delimiters that would
1826     cause a line break after themselves.
1827
1828     Higher numbers are higher priority.
1829     """
1830     if leaf.type == token.COMMA:
1831         return COMMA_PRIORITY
1832
1833     return 0
1834
1835
1836 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1837     """Return the priority of the `leaf` delimiter, given a line before after it.
1838
1839     The delimiter priorities returned here are from those delimiters that would
1840     cause a line break before themselves.
1841
1842     Higher numbers are higher priority.
1843     """
1844     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1845         # * and ** might also be MATH_OPERATORS but in this case they are not.
1846         # Don't treat them as a delimiter.
1847         return 0
1848
1849     if (
1850         leaf.type == token.DOT
1851         and leaf.parent
1852         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1853         and (previous is None or previous.type in CLOSING_BRACKETS)
1854     ):
1855         return DOT_PRIORITY
1856
1857     if (
1858         leaf.type in MATH_OPERATORS
1859         and leaf.parent
1860         and leaf.parent.type not in {syms.factor, syms.star_expr}
1861     ):
1862         return MATH_PRIORITIES[leaf.type]
1863
1864     if leaf.type in COMPARATORS:
1865         return COMPARATOR_PRIORITY
1866
1867     if (
1868         leaf.type == token.STRING
1869         and previous is not None
1870         and previous.type == token.STRING
1871     ):
1872         return STRING_PRIORITY
1873
1874     if leaf.type != token.NAME:
1875         return 0
1876
1877     if (
1878         leaf.value == "for"
1879         and leaf.parent
1880         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1881     ):
1882         return COMPREHENSION_PRIORITY
1883
1884     if (
1885         leaf.value == "if"
1886         and leaf.parent
1887         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1888     ):
1889         return COMPREHENSION_PRIORITY
1890
1891     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1892         return TERNARY_PRIORITY
1893
1894     if leaf.value == "is":
1895         return COMPARATOR_PRIORITY
1896
1897     if (
1898         leaf.value == "in"
1899         and leaf.parent
1900         and leaf.parent.type in {syms.comp_op, syms.comparison}
1901         and not (
1902             previous is not None
1903             and previous.type == token.NAME
1904             and previous.value == "not"
1905         )
1906     ):
1907         return COMPARATOR_PRIORITY
1908
1909     if (
1910         leaf.value == "not"
1911         and leaf.parent
1912         and leaf.parent.type == syms.comp_op
1913         and not (
1914             previous is not None
1915             and previous.type == token.NAME
1916             and previous.value == "is"
1917         )
1918     ):
1919         return COMPARATOR_PRIORITY
1920
1921     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1922         return LOGIC_PRIORITY
1923
1924     return 0
1925
1926
1927 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1928     """Clean the prefix of the `leaf` and generate comments from it, if any.
1929
1930     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1931     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1932     move because it does away with modifying the grammar to include all the
1933     possible places in which comments can be placed.
1934
1935     The sad consequence for us though is that comments don't "belong" anywhere.
1936     This is why this function generates simple parentless Leaf objects for
1937     comments.  We simply don't know what the correct parent should be.
1938
1939     No matter though, we can live without this.  We really only need to
1940     differentiate between inline and standalone comments.  The latter don't
1941     share the line with any code.
1942
1943     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1944     are emitted with a fake STANDALONE_COMMENT token identifier.
1945     """
1946     p = leaf.prefix
1947     if not p:
1948         return
1949
1950     if "#" not in p:
1951         return
1952
1953     consumed = 0
1954     nlines = 0
1955     for index, line in enumerate(p.split("\n")):
1956         consumed += len(line) + 1  # adding the length of the split '\n'
1957         line = line.lstrip()
1958         if not line:
1959             nlines += 1
1960         if not line.startswith("#"):
1961             continue
1962
1963         if index == 0 and leaf.type != token.ENDMARKER:
1964             comment_type = token.COMMENT  # simple trailing comment
1965         else:
1966             comment_type = STANDALONE_COMMENT
1967         comment = make_comment(line)
1968         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1969
1970         if comment in {"# fmt: on", "# yapf: enable"}:
1971             raise FormatOn(consumed)
1972
1973         if comment in {"# fmt: off", "# yapf: disable"}:
1974             if comment_type == STANDALONE_COMMENT:
1975                 raise FormatOff(consumed)
1976
1977             prev = preceding_leaf(leaf)
1978             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1979                 raise FormatOff(consumed)
1980
1981         nlines = 0
1982
1983
1984 def make_comment(content: str) -> str:
1985     """Return a consistently formatted comment from the given `content` string.
1986
1987     All comments (except for "##", "#!", "#:") should have a single space between
1988     the hash sign and the content.
1989
1990     If `content` didn't start with a hash sign, one is provided.
1991     """
1992     content = content.rstrip()
1993     if not content:
1994         return "#"
1995
1996     if content[0] == "#":
1997         content = content[1:]
1998     if content and content[0] not in " !:#":
1999         content = " " + content
2000     return "#" + content
2001
2002
2003 def split_line(
2004     line: Line, line_length: int, inner: bool = False, py36: bool = False
2005 ) -> Iterator[Line]:
2006     """Split a `line` into potentially many lines.
2007
2008     They should fit in the allotted `line_length` but might not be able to.
2009     `inner` signifies that there were a pair of brackets somewhere around the
2010     current `line`, possibly transitively. This means we can fallback to splitting
2011     by delimiters if the LHS/RHS don't yield any results.
2012
2013     If `py36` is True, splitting may generate syntax that is only compatible
2014     with Python 3.6 and later.
2015     """
2016     if isinstance(line, UnformattedLines) or line.is_comment:
2017         yield line
2018         return
2019
2020     line_str = str(line).strip("\n")
2021     if not line.should_explode and is_line_short_enough(
2022         line, line_length=line_length, line_str=line_str
2023     ):
2024         yield line
2025         return
2026
2027     split_funcs: List[SplitFunc]
2028     if line.is_def:
2029         split_funcs = [left_hand_split]
2030     else:
2031
2032         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2033             for omit in generate_trailers_to_omit(line, line_length):
2034                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2035                 if is_line_short_enough(lines[0], line_length=line_length):
2036                     yield from lines
2037                     return
2038
2039             # All splits failed, best effort split with no omits.
2040             # This mostly happens to multiline strings that are by definition
2041             # reported as not fitting a single line.
2042             yield from right_hand_split(line, py36)
2043
2044         if line.inside_brackets:
2045             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2046         else:
2047             split_funcs = [rhs]
2048     for split_func in split_funcs:
2049         # We are accumulating lines in `result` because we might want to abort
2050         # mission and return the original line in the end, or attempt a different
2051         # split altogether.
2052         result: List[Line] = []
2053         try:
2054             for l in split_func(line, py36):
2055                 if str(l).strip("\n") == line_str:
2056                     raise CannotSplit("Split function returned an unchanged result")
2057
2058                 result.extend(
2059                     split_line(l, line_length=line_length, inner=True, py36=py36)
2060                 )
2061         except CannotSplit as cs:
2062             continue
2063
2064         else:
2065             yield from result
2066             break
2067
2068     else:
2069         yield line
2070
2071
2072 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2073     """Split line into many lines, starting with the first matching bracket pair.
2074
2075     Note: this usually looks weird, only use this for function definitions.
2076     Prefer RHS otherwise.  This is why this function is not symmetrical with
2077     :func:`right_hand_split` which also handles optional parentheses.
2078     """
2079     head = Line(depth=line.depth)
2080     body = Line(depth=line.depth + 1, inside_brackets=True)
2081     tail = Line(depth=line.depth)
2082     tail_leaves: List[Leaf] = []
2083     body_leaves: List[Leaf] = []
2084     head_leaves: List[Leaf] = []
2085     current_leaves = head_leaves
2086     matching_bracket = None
2087     for leaf in line.leaves:
2088         if (
2089             current_leaves is body_leaves
2090             and leaf.type in CLOSING_BRACKETS
2091             and leaf.opening_bracket is matching_bracket
2092         ):
2093             current_leaves = tail_leaves if body_leaves else head_leaves
2094         current_leaves.append(leaf)
2095         if current_leaves is head_leaves:
2096             if leaf.type in OPENING_BRACKETS:
2097                 matching_bracket = leaf
2098                 current_leaves = body_leaves
2099     # Since body is a new indent level, remove spurious leading whitespace.
2100     if body_leaves:
2101         normalize_prefix(body_leaves[0], inside_brackets=True)
2102     # Build the new lines.
2103     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2104         for leaf in leaves:
2105             result.append(leaf, preformatted=True)
2106             for comment_after in line.comments_after(leaf):
2107                 result.append(comment_after, preformatted=True)
2108     bracket_split_succeeded_or_raise(head, body, tail)
2109     for result in (head, body, tail):
2110         if result:
2111             yield result
2112
2113
2114 def right_hand_split(
2115     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2116 ) -> Iterator[Line]:
2117     """Split line into many lines, starting with the last matching bracket pair.
2118
2119     If the split was by optional parentheses, attempt splitting without them, too.
2120     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2121     this split.
2122
2123     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2124     """
2125     head = Line(depth=line.depth)
2126     body = Line(depth=line.depth + 1, inside_brackets=True)
2127     tail = Line(depth=line.depth)
2128     tail_leaves: List[Leaf] = []
2129     body_leaves: List[Leaf] = []
2130     head_leaves: List[Leaf] = []
2131     current_leaves = tail_leaves
2132     opening_bracket = None
2133     closing_bracket = None
2134     for leaf in reversed(line.leaves):
2135         if current_leaves is body_leaves:
2136             if leaf is opening_bracket:
2137                 current_leaves = head_leaves if body_leaves else tail_leaves
2138         current_leaves.append(leaf)
2139         if current_leaves is tail_leaves:
2140             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2141                 opening_bracket = leaf.opening_bracket
2142                 closing_bracket = leaf
2143                 current_leaves = body_leaves
2144     tail_leaves.reverse()
2145     body_leaves.reverse()
2146     head_leaves.reverse()
2147     # Since body is a new indent level, remove spurious leading whitespace.
2148     if body_leaves:
2149         normalize_prefix(body_leaves[0], inside_brackets=True)
2150     if not head_leaves:
2151         # No `head` means the split failed. Either `tail` has all content or
2152         # the matching `opening_bracket` wasn't available on `line` anymore.
2153         raise CannotSplit("No brackets found")
2154
2155     # Build the new lines.
2156     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2157         for leaf in leaves:
2158             result.append(leaf, preformatted=True)
2159             for comment_after in line.comments_after(leaf):
2160                 result.append(comment_after, preformatted=True)
2161     bracket_split_succeeded_or_raise(head, body, tail)
2162     assert opening_bracket and closing_bracket
2163     if (
2164         # the opening bracket is an optional paren
2165         opening_bracket.type == token.LPAR
2166         and not opening_bracket.value
2167         # the closing bracket is an optional paren
2168         and closing_bracket.type == token.RPAR
2169         and not closing_bracket.value
2170         # there are no standalone comments in the body
2171         and not line.contains_standalone_comments(0)
2172         # and it's not an import (optional parens are the only thing we can split
2173         # on in this case; attempting a split without them is a waste of time)
2174         and not line.is_import
2175     ):
2176         omit = {id(closing_bracket), *omit}
2177         if can_omit_invisible_parens(body, line_length):
2178             try:
2179                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2180                 return
2181             except CannotSplit:
2182                 pass
2183
2184     ensure_visible(opening_bracket)
2185     ensure_visible(closing_bracket)
2186     body.should_explode = should_explode(body, opening_bracket)
2187     for result in (head, body, tail):
2188         if result:
2189             yield result
2190
2191
2192 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2193     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2194
2195     Do nothing otherwise.
2196
2197     A left- or right-hand split is based on a pair of brackets. Content before
2198     (and including) the opening bracket is left on one line, content inside the
2199     brackets is put on a separate line, and finally content starting with and
2200     following the closing bracket is put on a separate line.
2201
2202     Those are called `head`, `body`, and `tail`, respectively. If the split
2203     produced the same line (all content in `head`) or ended up with an empty `body`
2204     and the `tail` is just the closing bracket, then it's considered failed.
2205     """
2206     tail_len = len(str(tail).strip())
2207     if not body:
2208         if tail_len == 0:
2209             raise CannotSplit("Splitting brackets produced the same line")
2210
2211         elif tail_len < 3:
2212             raise CannotSplit(
2213                 f"Splitting brackets on an empty body to save "
2214                 f"{tail_len} characters is not worth it"
2215             )
2216
2217
2218 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2219     """Normalize prefix of the first leaf in every line returned by `split_func`.
2220
2221     This is a decorator over relevant split functions.
2222     """
2223
2224     @wraps(split_func)
2225     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2226         for l in split_func(line, py36):
2227             normalize_prefix(l.leaves[0], inside_brackets=True)
2228             yield l
2229
2230     return split_wrapper
2231
2232
2233 @dont_increase_indentation
2234 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2235     """Split according to delimiters of the highest priority.
2236
2237     If `py36` is True, the split will add trailing commas also in function
2238     signatures that contain `*` and `**`.
2239     """
2240     try:
2241         last_leaf = line.leaves[-1]
2242     except IndexError:
2243         raise CannotSplit("Line empty")
2244
2245     bt = line.bracket_tracker
2246     try:
2247         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2248     except ValueError:
2249         raise CannotSplit("No delimiters found")
2250
2251     if delimiter_priority == DOT_PRIORITY:
2252         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2253             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2254
2255     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2256     lowest_depth = sys.maxsize
2257     trailing_comma_safe = True
2258
2259     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2260         """Append `leaf` to current line or to new line if appending impossible."""
2261         nonlocal current_line
2262         try:
2263             current_line.append_safe(leaf, preformatted=True)
2264         except ValueError as ve:
2265             yield current_line
2266
2267             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2268             current_line.append(leaf)
2269
2270     for index, leaf in enumerate(line.leaves):
2271         yield from append_to_line(leaf)
2272
2273         for comment_after in line.comments_after(leaf, index):
2274             yield from append_to_line(comment_after)
2275
2276         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2277         if leaf.bracket_depth == lowest_depth and is_vararg(
2278             leaf, within=VARARGS_PARENTS
2279         ):
2280             trailing_comma_safe = trailing_comma_safe and py36
2281         leaf_priority = bt.delimiters.get(id(leaf))
2282         if leaf_priority == delimiter_priority:
2283             yield current_line
2284
2285             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2286     if current_line:
2287         if (
2288             trailing_comma_safe
2289             and delimiter_priority == COMMA_PRIORITY
2290             and current_line.leaves[-1].type != token.COMMA
2291             and current_line.leaves[-1].type != STANDALONE_COMMENT
2292         ):
2293             current_line.append(Leaf(token.COMMA, ","))
2294         yield current_line
2295
2296
2297 @dont_increase_indentation
2298 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2299     """Split standalone comments from the rest of the line."""
2300     if not line.contains_standalone_comments(0):
2301         raise CannotSplit("Line does not have any standalone comments")
2302
2303     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2304
2305     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2306         """Append `leaf` to current line or to new line if appending impossible."""
2307         nonlocal current_line
2308         try:
2309             current_line.append_safe(leaf, preformatted=True)
2310         except ValueError as ve:
2311             yield current_line
2312
2313             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2314             current_line.append(leaf)
2315
2316     for index, leaf in enumerate(line.leaves):
2317         yield from append_to_line(leaf)
2318
2319         for comment_after in line.comments_after(leaf, index):
2320             yield from append_to_line(comment_after)
2321
2322     if current_line:
2323         yield current_line
2324
2325
2326 def is_import(leaf: Leaf) -> bool:
2327     """Return True if the given leaf starts an import statement."""
2328     p = leaf.parent
2329     t = leaf.type
2330     v = leaf.value
2331     return bool(
2332         t == token.NAME
2333         and (
2334             (v == "import" and p and p.type == syms.import_name)
2335             or (v == "from" and p and p.type == syms.import_from)
2336         )
2337     )
2338
2339
2340 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2341     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2342     else.
2343
2344     Note: don't use backslashes for formatting or you'll lose your voting rights.
2345     """
2346     if not inside_brackets:
2347         spl = leaf.prefix.split("#")
2348         if "\\" not in spl[0]:
2349             nl_count = spl[-1].count("\n")
2350             if len(spl) > 1:
2351                 nl_count -= 1
2352             leaf.prefix = "\n" * nl_count
2353             return
2354
2355     leaf.prefix = ""
2356
2357
2358 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2359     """Make all string prefixes lowercase.
2360
2361     If remove_u_prefix is given, also removes any u prefix from the string.
2362
2363     Note: Mutates its argument.
2364     """
2365     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2366     assert match is not None, f"failed to match string {leaf.value!r}"
2367     orig_prefix = match.group(1)
2368     new_prefix = orig_prefix.lower()
2369     if remove_u_prefix:
2370         new_prefix = new_prefix.replace("u", "")
2371     leaf.value = f"{new_prefix}{match.group(2)}"
2372
2373
2374 def normalize_string_quotes(leaf: Leaf) -> None:
2375     """Prefer double quotes but only if it doesn't cause more escaping.
2376
2377     Adds or removes backslashes as appropriate. Doesn't parse and fix
2378     strings nested in f-strings (yet).
2379
2380     Note: Mutates its argument.
2381     """
2382     value = leaf.value.lstrip("furbFURB")
2383     if value[:3] == '"""':
2384         return
2385
2386     elif value[:3] == "'''":
2387         orig_quote = "'''"
2388         new_quote = '"""'
2389     elif value[0] == '"':
2390         orig_quote = '"'
2391         new_quote = "'"
2392     else:
2393         orig_quote = "'"
2394         new_quote = '"'
2395     first_quote_pos = leaf.value.find(orig_quote)
2396     if first_quote_pos == -1:
2397         return  # There's an internal error
2398
2399     prefix = leaf.value[:first_quote_pos]
2400     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2401     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2402     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2403     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2404     if "r" in prefix.casefold():
2405         if unescaped_new_quote.search(body):
2406             # There's at least one unescaped new_quote in this raw string
2407             # so converting is impossible
2408             return
2409
2410         # Do not introduce or remove backslashes in raw strings
2411         new_body = body
2412     else:
2413         # remove unnecessary quotes
2414         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2415         if body != new_body:
2416             # Consider the string without unnecessary quotes as the original
2417             body = new_body
2418             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2419         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2420         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2421     if new_quote == '"""' and new_body[-1] == '"':
2422         # edge case:
2423         new_body = new_body[:-1] + '\\"'
2424     orig_escape_count = body.count("\\")
2425     new_escape_count = new_body.count("\\")
2426     if new_escape_count > orig_escape_count:
2427         return  # Do not introduce more escaping
2428
2429     if new_escape_count == orig_escape_count and orig_quote == '"':
2430         return  # Prefer double quotes
2431
2432     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2433
2434
2435 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2436     """Make existing optional parentheses invisible or create new ones.
2437
2438     `parens_after` is a set of string leaf values immeditely after which parens
2439     should be put.
2440
2441     Standardizes on visible parentheses for single-element tuples, and keeps
2442     existing visible parentheses for other tuples and generator expressions.
2443     """
2444     try:
2445         list(generate_comments(node))
2446     except FormatOff:
2447         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2448
2449     check_lpar = False
2450     for index, child in enumerate(list(node.children)):
2451         if check_lpar:
2452             if child.type == syms.atom:
2453                 maybe_make_parens_invisible_in_atom(child)
2454             elif is_one_tuple(child):
2455                 # wrap child in visible parentheses
2456                 lpar = Leaf(token.LPAR, "(")
2457                 rpar = Leaf(token.RPAR, ")")
2458                 child.remove()
2459                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2460             elif node.type == syms.import_from:
2461                 # "import from" nodes store parentheses directly as part of
2462                 # the statement
2463                 if child.type == token.LPAR:
2464                     # make parentheses invisible
2465                     child.value = ""  # type: ignore
2466                     node.children[-1].value = ""  # type: ignore
2467                 elif child.type != token.STAR:
2468                     # insert invisible parentheses
2469                     node.insert_child(index, Leaf(token.LPAR, ""))
2470                     node.append_child(Leaf(token.RPAR, ""))
2471                 break
2472
2473             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2474                 # wrap child in invisible parentheses
2475                 lpar = Leaf(token.LPAR, "")
2476                 rpar = Leaf(token.RPAR, "")
2477                 index = child.remove() or 0
2478                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2479
2480         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2481
2482
2483 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2484     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2485     if (
2486         node.type != syms.atom
2487         or is_empty_tuple(node)
2488         or is_one_tuple(node)
2489         or is_yield(node)
2490         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2491     ):
2492         return False
2493
2494     first = node.children[0]
2495     last = node.children[-1]
2496     if first.type == token.LPAR and last.type == token.RPAR:
2497         # make parentheses invisible
2498         first.value = ""  # type: ignore
2499         last.value = ""  # type: ignore
2500         if len(node.children) > 1:
2501             maybe_make_parens_invisible_in_atom(node.children[1])
2502         return True
2503
2504     return False
2505
2506
2507 def is_empty_tuple(node: LN) -> bool:
2508     """Return True if `node` holds an empty tuple."""
2509     return (
2510         node.type == syms.atom
2511         and len(node.children) == 2
2512         and node.children[0].type == token.LPAR
2513         and node.children[1].type == token.RPAR
2514     )
2515
2516
2517 def is_one_tuple(node: LN) -> bool:
2518     """Return True if `node` holds a tuple with one element, with or without parens."""
2519     if node.type == syms.atom:
2520         if len(node.children) != 3:
2521             return False
2522
2523         lpar, gexp, rpar = node.children
2524         if not (
2525             lpar.type == token.LPAR
2526             and gexp.type == syms.testlist_gexp
2527             and rpar.type == token.RPAR
2528         ):
2529             return False
2530
2531         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2532
2533     return (
2534         node.type in IMPLICIT_TUPLE
2535         and len(node.children) == 2
2536         and node.children[1].type == token.COMMA
2537     )
2538
2539
2540 def is_yield(node: LN) -> bool:
2541     """Return True if `node` holds a `yield` or `yield from` expression."""
2542     if node.type == syms.yield_expr:
2543         return True
2544
2545     if node.type == token.NAME and node.value == "yield":  # type: ignore
2546         return True
2547
2548     if node.type != syms.atom:
2549         return False
2550
2551     if len(node.children) != 3:
2552         return False
2553
2554     lpar, expr, rpar = node.children
2555     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2556         return is_yield(expr)
2557
2558     return False
2559
2560
2561 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2562     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2563
2564     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2565     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2566     extended iterable unpacking (PEP 3132) and additional unpacking
2567     generalizations (PEP 448).
2568     """
2569     if leaf.type not in STARS or not leaf.parent:
2570         return False
2571
2572     p = leaf.parent
2573     if p.type == syms.star_expr:
2574         # Star expressions are also used as assignment targets in extended
2575         # iterable unpacking (PEP 3132).  See what its parent is instead.
2576         if not p.parent:
2577             return False
2578
2579         p = p.parent
2580
2581     return p.type in within
2582
2583
2584 def is_multiline_string(leaf: Leaf) -> bool:
2585     """Return True if `leaf` is a multiline string that actually spans many lines."""
2586     value = leaf.value.lstrip("furbFURB")
2587     return value[:3] in {'"""', "'''"} and "\n" in value
2588
2589
2590 def is_stub_suite(node: Node) -> bool:
2591     """Return True if `node` is a suite with a stub body."""
2592     if (
2593         len(node.children) != 4
2594         or node.children[0].type != token.NEWLINE
2595         or node.children[1].type != token.INDENT
2596         or node.children[3].type != token.DEDENT
2597     ):
2598         return False
2599
2600     return is_stub_body(node.children[2])
2601
2602
2603 def is_stub_body(node: LN) -> bool:
2604     """Return True if `node` is a simple statement containing an ellipsis."""
2605     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2606         return False
2607
2608     if len(node.children) != 2:
2609         return False
2610
2611     child = node.children[0]
2612     return (
2613         child.type == syms.atom
2614         and len(child.children) == 3
2615         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2616     )
2617
2618
2619 def max_delimiter_priority_in_atom(node: LN) -> int:
2620     """Return maximum delimiter priority inside `node`.
2621
2622     This is specific to atoms with contents contained in a pair of parentheses.
2623     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2624     """
2625     if node.type != syms.atom:
2626         return 0
2627
2628     first = node.children[0]
2629     last = node.children[-1]
2630     if not (first.type == token.LPAR and last.type == token.RPAR):
2631         return 0
2632
2633     bt = BracketTracker()
2634     for c in node.children[1:-1]:
2635         if isinstance(c, Leaf):
2636             bt.mark(c)
2637         else:
2638             for leaf in c.leaves():
2639                 bt.mark(leaf)
2640     try:
2641         return bt.max_delimiter_priority()
2642
2643     except ValueError:
2644         return 0
2645
2646
2647 def ensure_visible(leaf: Leaf) -> None:
2648     """Make sure parentheses are visible.
2649
2650     They could be invisible as part of some statements (see
2651     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2652     """
2653     if leaf.type == token.LPAR:
2654         leaf.value = "("
2655     elif leaf.type == token.RPAR:
2656         leaf.value = ")"
2657
2658
2659 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2660     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2661     if not (
2662         opening_bracket.parent
2663         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2664         and opening_bracket.value in "[{("
2665     ):
2666         return False
2667
2668     try:
2669         last_leaf = line.leaves[-1]
2670         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2671         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2672     except (IndexError, ValueError):
2673         return False
2674
2675     return max_priority == COMMA_PRIORITY
2676
2677
2678 def is_python36(node: Node) -> bool:
2679     """Return True if the current file is using Python 3.6+ features.
2680
2681     Currently looking for:
2682     - f-strings; and
2683     - trailing commas after * or ** in function signatures and calls.
2684     """
2685     for n in node.pre_order():
2686         if n.type == token.STRING:
2687             value_head = n.value[:2]  # type: ignore
2688             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2689                 return True
2690
2691         elif (
2692             n.type in {syms.typedargslist, syms.arglist}
2693             and n.children
2694             and n.children[-1].type == token.COMMA
2695         ):
2696             for ch in n.children:
2697                 if ch.type in STARS:
2698                     return True
2699
2700                 if ch.type == syms.argument:
2701                     for argch in ch.children:
2702                         if argch.type in STARS:
2703                             return True
2704
2705     return False
2706
2707
2708 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2709     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2710
2711     Brackets can be omitted if the entire trailer up to and including
2712     a preceding closing bracket fits in one line.
2713
2714     Yielded sets are cumulative (contain results of previous yields, too).  First
2715     set is empty.
2716     """
2717
2718     omit: Set[LeafID] = set()
2719     yield omit
2720
2721     length = 4 * line.depth
2722     opening_bracket = None
2723     closing_bracket = None
2724     optional_brackets: Set[LeafID] = set()
2725     inner_brackets: Set[LeafID] = set()
2726     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2727         length += leaf_length
2728         if length > line_length:
2729             break
2730
2731         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2732         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2733             break
2734
2735         optional_brackets.discard(id(leaf))
2736         if opening_bracket:
2737             if leaf is opening_bracket:
2738                 opening_bracket = None
2739             elif leaf.type in CLOSING_BRACKETS:
2740                 inner_brackets.add(id(leaf))
2741         elif leaf.type in CLOSING_BRACKETS:
2742             if not leaf.value:
2743                 optional_brackets.add(id(opening_bracket))
2744                 continue
2745
2746             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2747                 # Empty brackets would fail a split so treat them as "inner"
2748                 # brackets (e.g. only add them to the `omit` set if another
2749                 # pair of brackets was good enough.
2750                 inner_brackets.add(id(leaf))
2751                 continue
2752
2753             opening_bracket = leaf.opening_bracket
2754             if closing_bracket:
2755                 omit.add(id(closing_bracket))
2756                 omit.update(inner_brackets)
2757                 inner_brackets.clear()
2758                 yield omit
2759             closing_bracket = leaf
2760
2761
2762 def get_future_imports(node: Node) -> Set[str]:
2763     """Return a set of __future__ imports in the file."""
2764     imports = set()
2765     for child in node.children:
2766         if child.type != syms.simple_stmt:
2767             break
2768         first_child = child.children[0]
2769         if isinstance(first_child, Leaf):
2770             # Continue looking if we see a docstring; otherwise stop.
2771             if (
2772                 len(child.children) == 2
2773                 and first_child.type == token.STRING
2774                 and child.children[1].type == token.NEWLINE
2775             ):
2776                 continue
2777             else:
2778                 break
2779         elif first_child.type == syms.import_from:
2780             module_name = first_child.children[1]
2781             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2782                 break
2783             for import_from_child in first_child.children[3:]:
2784                 if isinstance(import_from_child, Leaf):
2785                     if import_from_child.type == token.NAME:
2786                         imports.add(import_from_child.value)
2787                 else:
2788                     assert import_from_child.type == syms.import_as_names
2789                     for leaf in import_from_child.children:
2790                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2791                             imports.add(leaf.value)
2792         else:
2793             break
2794     return imports
2795
2796
2797 def gen_python_files_in_dir(
2798     path: Path, root: Path, include: Pattern[str], exclude: Pattern[str]
2799 ) -> Iterator[Path]:
2800     """Generate all files under `path` whose paths are not excluded by the
2801     `exclude` regex, but are included by the `include` regex.
2802     """
2803     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
2804     for child in path.iterdir():
2805         normalized_path = child.resolve().relative_to(root).as_posix()
2806         if child.is_dir():
2807             normalized_path += "/"
2808         exclude_match = exclude.search(normalized_path)
2809         if exclude_match and exclude_match.group(0):
2810             continue
2811
2812         if child.is_dir():
2813             yield from gen_python_files_in_dir(child, root, include, exclude)
2814
2815         elif child.is_file():
2816             include_match = include.search(normalized_path)
2817             if include_match:
2818                 yield child
2819
2820
2821 def find_project_root(srcs: List[str]) -> Path:
2822     """Return a directory containing .git, .hg, or pyproject.toml.
2823
2824     That directory can be one of the directories passed in `srcs` or their
2825     common parent.
2826
2827     If no directory in the tree contains a marker that would specify it's the
2828     project root, the root of the file system is returned.
2829     """
2830     if not srcs:
2831         return Path("/").resolve()
2832
2833     common_base = min(Path(src).resolve() for src in srcs)
2834     if common_base.is_dir():
2835         # Append a fake file so `parents` below returns `common_base_dir`, too.
2836         common_base /= "fake-file"
2837     for directory in common_base.parents:
2838         if (directory / ".git").is_dir():
2839             return directory
2840
2841         if (directory / ".hg").is_dir():
2842             return directory
2843
2844         if (directory / "pyproject.toml").is_file():
2845             return directory
2846
2847     return directory
2848
2849
2850 @dataclass
2851 class Report:
2852     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2853
2854     check: bool = False
2855     quiet: bool = False
2856     change_count: int = 0
2857     same_count: int = 0
2858     failure_count: int = 0
2859
2860     def done(self, src: Path, changed: Changed) -> None:
2861         """Increment the counter for successful reformatting. Write out a message."""
2862         if changed is Changed.YES:
2863             reformatted = "would reformat" if self.check else "reformatted"
2864             if not self.quiet:
2865                 out(f"{reformatted} {src}")
2866             self.change_count += 1
2867         else:
2868             if not self.quiet:
2869                 if changed is Changed.NO:
2870                     msg = f"{src} already well formatted, good job."
2871                 else:
2872                     msg = f"{src} wasn't modified on disk since last run."
2873                 out(msg, bold=False)
2874             self.same_count += 1
2875
2876     def failed(self, src: Path, message: str) -> None:
2877         """Increment the counter for failed reformatting. Write out a message."""
2878         err(f"error: cannot format {src}: {message}")
2879         self.failure_count += 1
2880
2881     @property
2882     def return_code(self) -> int:
2883         """Return the exit code that the app should use.
2884
2885         This considers the current state of changed files and failures:
2886         - if there were any failures, return 123;
2887         - if any files were changed and --check is being used, return 1;
2888         - otherwise return 0.
2889         """
2890         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2891         # 126 we have special returncodes reserved by the shell.
2892         if self.failure_count:
2893             return 123
2894
2895         elif self.change_count and self.check:
2896             return 1
2897
2898         return 0
2899
2900     def __str__(self) -> str:
2901         """Render a color report of the current state.
2902
2903         Use `click.unstyle` to remove colors.
2904         """
2905         if self.check:
2906             reformatted = "would be reformatted"
2907             unchanged = "would be left unchanged"
2908             failed = "would fail to reformat"
2909         else:
2910             reformatted = "reformatted"
2911             unchanged = "left unchanged"
2912             failed = "failed to reformat"
2913         report = []
2914         if self.change_count:
2915             s = "s" if self.change_count > 1 else ""
2916             report.append(
2917                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2918             )
2919         if self.same_count:
2920             s = "s" if self.same_count > 1 else ""
2921             report.append(f"{self.same_count} file{s} {unchanged}")
2922         if self.failure_count:
2923             s = "s" if self.failure_count > 1 else ""
2924             report.append(
2925                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2926             )
2927         return ", ".join(report) + "."
2928
2929
2930 def assert_equivalent(src: str, dst: str) -> None:
2931     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2932
2933     import ast
2934     import traceback
2935
2936     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2937         """Simple visitor generating strings to compare ASTs by content."""
2938         yield f"{'  ' * depth}{node.__class__.__name__}("
2939
2940         for field in sorted(node._fields):
2941             try:
2942                 value = getattr(node, field)
2943             except AttributeError:
2944                 continue
2945
2946             yield f"{'  ' * (depth+1)}{field}="
2947
2948             if isinstance(value, list):
2949                 for item in value:
2950                     if isinstance(item, ast.AST):
2951                         yield from _v(item, depth + 2)
2952
2953             elif isinstance(value, ast.AST):
2954                 yield from _v(value, depth + 2)
2955
2956             else:
2957                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2958
2959         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2960
2961     try:
2962         src_ast = ast.parse(src)
2963     except Exception as exc:
2964         major, minor = sys.version_info[:2]
2965         raise AssertionError(
2966             f"cannot use --safe with this file; failed to parse source file "
2967             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2968             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2969         )
2970
2971     try:
2972         dst_ast = ast.parse(dst)
2973     except Exception as exc:
2974         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2975         raise AssertionError(
2976             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2977             f"Please report a bug on https://github.com/ambv/black/issues.  "
2978             f"This invalid output might be helpful: {log}"
2979         ) from None
2980
2981     src_ast_str = "\n".join(_v(src_ast))
2982     dst_ast_str = "\n".join(_v(dst_ast))
2983     if src_ast_str != dst_ast_str:
2984         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2985         raise AssertionError(
2986             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2987             f"the source.  "
2988             f"Please report a bug on https://github.com/ambv/black/issues.  "
2989             f"This diff might be helpful: {log}"
2990         ) from None
2991
2992
2993 def assert_stable(
2994     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
2995 ) -> None:
2996     """Raise AssertionError if `dst` reformats differently the second time."""
2997     newdst = format_str(dst, line_length=line_length, mode=mode)
2998     if dst != newdst:
2999         log = dump_to_file(
3000             diff(src, dst, "source", "first pass"),
3001             diff(dst, newdst, "first pass", "second pass"),
3002         )
3003         raise AssertionError(
3004             f"INTERNAL ERROR: Black produced different code on the second pass "
3005             f"of the formatter.  "
3006             f"Please report a bug on https://github.com/ambv/black/issues.  "
3007             f"This diff might be helpful: {log}"
3008         ) from None
3009
3010
3011 def dump_to_file(*output: str) -> str:
3012     """Dump `output` to a temporary file. Return path to the file."""
3013     import tempfile
3014
3015     with tempfile.NamedTemporaryFile(
3016         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3017     ) as f:
3018         for lines in output:
3019             f.write(lines)
3020             if lines and lines[-1] != "\n":
3021                 f.write("\n")
3022     return f.name
3023
3024
3025 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3026     """Return a unified diff string between strings `a` and `b`."""
3027     import difflib
3028
3029     a_lines = [line + "\n" for line in a.split("\n")]
3030     b_lines = [line + "\n" for line in b.split("\n")]
3031     return "".join(
3032         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3033     )
3034
3035
3036 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3037     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3038     err("Aborted!")
3039     for task in tasks:
3040         task.cancel()
3041
3042
3043 def shutdown(loop: BaseEventLoop) -> None:
3044     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3045     try:
3046         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3047         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3048         if not to_cancel:
3049             return
3050
3051         for task in to_cancel:
3052             task.cancel()
3053         loop.run_until_complete(
3054             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3055         )
3056     finally:
3057         # `concurrent.futures.Future` objects cannot be cancelled once they
3058         # are already running. There might be some when the `shutdown()` happened.
3059         # Silence their logger's spew about the event loop being closed.
3060         cf_logger = logging.getLogger("concurrent.futures")
3061         cf_logger.setLevel(logging.CRITICAL)
3062         loop.close()
3063
3064
3065 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3066     """Replace `regex` with `replacement` twice on `original`.
3067
3068     This is used by string normalization to perform replaces on
3069     overlapping matches.
3070     """
3071     return regex.sub(replacement, regex.sub(replacement, original))
3072
3073
3074 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3075     """Like `reversed(enumerate(sequence))` if that were possible."""
3076     index = len(sequence) - 1
3077     for element in reversed(sequence):
3078         yield (index, element)
3079         index -= 1
3080
3081
3082 def enumerate_with_length(
3083     line: Line, reversed: bool = False
3084 ) -> Iterator[Tuple[Index, Leaf, int]]:
3085     """Return an enumeration of leaves with their length.
3086
3087     Stops prematurely on multiline strings and standalone comments.
3088     """
3089     op = cast(
3090         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3091         enumerate_reversed if reversed else enumerate,
3092     )
3093     for index, leaf in op(line.leaves):
3094         length = len(leaf.prefix) + len(leaf.value)
3095         if "\n" in leaf.value:
3096             return  # Multiline strings, we can't continue.
3097
3098         comment: Optional[Leaf]
3099         for comment in line.comments_after(leaf, index):
3100             length += len(comment.value)
3101
3102         yield index, leaf, length
3103
3104
3105 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3106     """Return True if `line` is no longer than `line_length`.
3107
3108     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3109     """
3110     if not line_str:
3111         line_str = str(line).strip("\n")
3112     return (
3113         len(line_str) <= line_length
3114         and "\n" not in line_str  # multiline strings
3115         and not line.contains_standalone_comments()
3116     )
3117
3118
3119 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3120     """Does `line` have a shape safe to reformat without optional parens around it?
3121
3122     Returns True for only a subset of potentially nice looking formattings but
3123     the point is to not return false positives that end up producing lines that
3124     are too long.
3125     """
3126     bt = line.bracket_tracker
3127     if not bt.delimiters:
3128         # Without delimiters the optional parentheses are useless.
3129         return True
3130
3131     max_priority = bt.max_delimiter_priority()
3132     if bt.delimiter_count_with_priority(max_priority) > 1:
3133         # With more than one delimiter of a kind the optional parentheses read better.
3134         return False
3135
3136     if max_priority == DOT_PRIORITY:
3137         # A single stranded method call doesn't require optional parentheses.
3138         return True
3139
3140     assert len(line.leaves) >= 2, "Stranded delimiter"
3141
3142     first = line.leaves[0]
3143     second = line.leaves[1]
3144     penultimate = line.leaves[-2]
3145     last = line.leaves[-1]
3146
3147     # With a single delimiter, omit if the expression starts or ends with
3148     # a bracket.
3149     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3150         remainder = False
3151         length = 4 * line.depth
3152         for _index, leaf, leaf_length in enumerate_with_length(line):
3153             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3154                 remainder = True
3155             if remainder:
3156                 length += leaf_length
3157                 if length > line_length:
3158                     break
3159
3160                 if leaf.type in OPENING_BRACKETS:
3161                     # There are brackets we can further split on.
3162                     remainder = False
3163
3164         else:
3165             # checked the entire string and line length wasn't exceeded
3166             if len(line.leaves) == _index + 1:
3167                 return True
3168
3169         # Note: we are not returning False here because a line might have *both*
3170         # a leading opening bracket and a trailing closing bracket.  If the
3171         # opening bracket doesn't match our rule, maybe the closing will.
3172
3173     if (
3174         last.type == token.RPAR
3175         or last.type == token.RBRACE
3176         or (
3177             # don't use indexing for omitting optional parentheses;
3178             # it looks weird
3179             last.type == token.RSQB
3180             and last.parent
3181             and last.parent.type != syms.trailer
3182         )
3183     ):
3184         if penultimate.type in OPENING_BRACKETS:
3185             # Empty brackets don't help.
3186             return False
3187
3188         if is_multiline_string(first):
3189             # Additional wrapping of a multiline string in this situation is
3190             # unnecessary.
3191             return True
3192
3193         length = 4 * line.depth
3194         seen_other_brackets = False
3195         for _index, leaf, leaf_length in enumerate_with_length(line):
3196             length += leaf_length
3197             if leaf is last.opening_bracket:
3198                 if seen_other_brackets or length <= line_length:
3199                     return True
3200
3201             elif leaf.type in OPENING_BRACKETS:
3202                 # There are brackets we can further split on.
3203                 seen_other_brackets = True
3204
3205     return False
3206
3207
3208 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3209     pyi = bool(mode & FileMode.PYI)
3210     py36 = bool(mode & FileMode.PYTHON36)
3211     return (
3212         CACHE_DIR
3213         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3214     )
3215
3216
3217 def read_cache(line_length: int, mode: FileMode) -> Cache:
3218     """Read the cache if it exists and is well formed.
3219
3220     If it is not well formed, the call to write_cache later should resolve the issue.
3221     """
3222     cache_file = get_cache_file(line_length, mode)
3223     if not cache_file.exists():
3224         return {}
3225
3226     with cache_file.open("rb") as fobj:
3227         try:
3228             cache: Cache = pickle.load(fobj)
3229         except pickle.UnpicklingError:
3230             return {}
3231
3232     return cache
3233
3234
3235 def get_cache_info(path: Path) -> CacheInfo:
3236     """Return the information used to check if a file is already formatted or not."""
3237     stat = path.stat()
3238     return stat.st_mtime, stat.st_size
3239
3240
3241 def filter_cached(
3242     cache: Cache, sources: Iterable[Path]
3243 ) -> Tuple[List[Path], List[Path]]:
3244     """Split a list of paths into two.
3245
3246     The first list contains paths of files that modified on disk or are not in the
3247     cache. The other list contains paths to non-modified files.
3248     """
3249     todo, done = [], []
3250     for src in sources:
3251         src = src.resolve()
3252         if cache.get(src) != get_cache_info(src):
3253             todo.append(src)
3254         else:
3255             done.append(src)
3256     return todo, done
3257
3258
3259 def write_cache(
3260     cache: Cache, sources: List[Path], line_length: int, mode: FileMode
3261 ) -> None:
3262     """Update the cache file."""
3263     cache_file = get_cache_file(line_length, mode)
3264     try:
3265         if not CACHE_DIR.exists():
3266             CACHE_DIR.mkdir(parents=True)
3267         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3268         with cache_file.open("wb") as fobj:
3269             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3270     except OSError:
3271         pass
3272
3273
3274 if __name__ == "__main__":
3275     main()