]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Reorder command-line options
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum, Flag
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33     cast,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46
47 __version__ = "18.5b1"
48 DEFAULT_LINE_LENGTH = 88
49 DEFAULT_EXCLUDES = (
50     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
51 )
52 DEFAULT_INCLUDES = r"\.pyi?$"
53 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
54
55
56 # types
57 FileContent = str
58 Encoding = str
59 Depth = int
60 NodeType = int
61 LeafID = int
62 Priority = int
63 Index = int
64 LN = Union[Leaf, Node]
65 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
66 Timestamp = float
67 FileSize = int
68 CacheInfo = Tuple[Timestamp, FileSize]
69 Cache = Dict[Path, CacheInfo]
70 out = partial(click.secho, bold=True, err=True)
71 err = partial(click.secho, fg="red", err=True)
72
73 pygram.initialize(CACHE_DIR)
74 syms = pygram.python_symbols
75
76
77 class NothingChanged(UserWarning):
78     """Raised by :func:`format_file` when reformatted code is the same as source."""
79
80
81 class CannotSplit(Exception):
82     """A readable split that fits the allotted line length is impossible.
83
84     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
85     :func:`delimiter_split`.
86     """
87
88
89 class FormatError(Exception):
90     """Base exception for `# fmt: on` and `# fmt: off` handling.
91
92     It holds the number of bytes of the prefix consumed before the format
93     control comment appeared.
94     """
95
96     def __init__(self, consumed: int) -> None:
97         super().__init__(consumed)
98         self.consumed = consumed
99
100     def trim_prefix(self, leaf: Leaf) -> None:
101         leaf.prefix = leaf.prefix[self.consumed :]
102
103     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
104         """Returns a new Leaf from the consumed part of the prefix."""
105         unformatted_prefix = leaf.prefix[: self.consumed]
106         return Leaf(token.NEWLINE, unformatted_prefix)
107
108
109 class FormatOn(FormatError):
110     """Found a comment like `# fmt: on` in the file."""
111
112
113 class FormatOff(FormatError):
114     """Found a comment like `# fmt: off` in the file."""
115
116
117 class WriteBack(Enum):
118     NO = 0
119     YES = 1
120     DIFF = 2
121
122
123 class Changed(Enum):
124     NO = 0
125     CACHED = 1
126     YES = 2
127
128
129 class FileMode(Flag):
130     AUTO_DETECT = 0
131     PYTHON36 = 1
132     PYI = 2
133     NO_STRING_NORMALIZATION = 4
134
135
136 @click.command()
137 @click.option(
138     "-l",
139     "--line-length",
140     type=int,
141     default=DEFAULT_LINE_LENGTH,
142     help="How many character per line to allow.",
143     show_default=True,
144 )
145 @click.option(
146     "--py36",
147     is_flag=True,
148     help=(
149         "Allow using Python 3.6-only syntax on all input files.  This will put "
150         "trailing commas in function signatures and calls also after *args and "
151         "**kwargs.  [default: per-file auto-detection]"
152     ),
153 )
154 @click.option(
155     "--pyi",
156     is_flag=True,
157     help=(
158         "Format all input files like typing stubs regardless of file extension "
159         "(useful when piping source on standard input)."
160     ),
161 )
162 @click.option(
163     "-S",
164     "--skip-string-normalization",
165     is_flag=True,
166     help="Don't normalize string quotes or prefixes.",
167 )
168 @click.option(
169     "--check",
170     is_flag=True,
171     help=(
172         "Don't write the files back, just return the status.  Return code 0 "
173         "means nothing would change.  Return code 1 means some files would be "
174         "reformatted.  Return code 123 means there was an internal error."
175     ),
176 )
177 @click.option(
178     "--diff",
179     is_flag=True,
180     help="Don't write the files back, just output a diff for each file on stdout.",
181 )
182 @click.option(
183     "--fast/--safe",
184     is_flag=True,
185     help="If --fast given, skip temporary sanity checks. [default: --safe]",
186 )
187 @click.option(
188     "--include",
189     type=str,
190     default=DEFAULT_INCLUDES,
191     help=(
192         "A regular expression that matches files and directories that should be "
193         "included on recursive searches. On Windows, use forward slashes for "
194         "directories."
195     ),
196     show_default=True,
197 )
198 @click.option(
199     "--exclude",
200     type=str,
201     default=DEFAULT_EXCLUDES,
202     help=(
203         "A regular expression that matches files and directories that should be "
204         "excluded on recursive searches. On Windows, use forward slashes for "
205         "directories."
206     ),
207     show_default=True,
208 )
209 @click.option(
210     "-q",
211     "--quiet",
212     is_flag=True,
213     help=(
214         "Don't emit non-error messages to stderr. Errors are still emitted, "
215         "silence those with 2>/dev/null."
216     ),
217 )
218 @click.version_option(version=__version__)
219 @click.argument(
220     "src",
221     nargs=-1,
222     type=click.Path(
223         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
224     ),
225 )
226 @click.pass_context
227 def main(
228     ctx: click.Context,
229     line_length: int,
230     check: bool,
231     diff: bool,
232     fast: bool,
233     pyi: bool,
234     py36: bool,
235     skip_string_normalization: bool,
236     quiet: bool,
237     include: str,
238     exclude: str,
239     src: List[str],
240 ) -> None:
241     """The uncompromising code formatter."""
242     sources: List[Path] = []
243     try:
244         include_regex = re.compile(include)
245     except re.error:
246         err(f"Invalid regular expression for include given: {include!r}")
247         ctx.exit(2)
248     try:
249         exclude_regex = re.compile(exclude)
250     except re.error:
251         err(f"Invalid regular expression for exclude given: {exclude!r}")
252         ctx.exit(2)
253     for s in src:
254         p = Path(s)
255         if p.is_dir():
256             sources.extend(gen_python_files_in_dir(p, include_regex, exclude_regex))
257         elif p.is_file():
258             # if a file was explicitly given, we don't care about its extension
259             sources.append(p)
260         elif s == "-":
261             sources.append(Path("-"))
262         else:
263             err(f"invalid path: {s}")
264
265     if check and not diff:
266         write_back = WriteBack.NO
267     elif diff:
268         write_back = WriteBack.DIFF
269     else:
270         write_back = WriteBack.YES
271     mode = FileMode.AUTO_DETECT
272     if py36:
273         mode |= FileMode.PYTHON36
274     if pyi:
275         mode |= FileMode.PYI
276     if skip_string_normalization:
277         mode |= FileMode.NO_STRING_NORMALIZATION
278     report = Report(check=check, quiet=quiet)
279     if len(sources) == 0:
280         out("No paths given. Nothing to do 😴")
281         ctx.exit(0)
282         return
283
284     elif len(sources) == 1:
285         reformat_one(
286             src=sources[0],
287             line_length=line_length,
288             fast=fast,
289             write_back=write_back,
290             mode=mode,
291             report=report,
292         )
293     else:
294         loop = asyncio.get_event_loop()
295         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
296         try:
297             loop.run_until_complete(
298                 schedule_formatting(
299                     sources=sources,
300                     line_length=line_length,
301                     fast=fast,
302                     write_back=write_back,
303                     mode=mode,
304                     report=report,
305                     loop=loop,
306                     executor=executor,
307                 )
308             )
309         finally:
310             shutdown(loop)
311         if not quiet:
312             out("All done! ✨ 🍰 ✨")
313             click.echo(str(report))
314     ctx.exit(report.return_code)
315
316
317 def reformat_one(
318     src: Path,
319     line_length: int,
320     fast: bool,
321     write_back: WriteBack,
322     mode: FileMode,
323     report: "Report",
324 ) -> None:
325     """Reformat a single file under `src` without spawning child processes.
326
327     If `quiet` is True, non-error messages are not output. `line_length`,
328     `write_back`, `fast` and `pyi` options are passed to
329     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
330     """
331     try:
332         changed = Changed.NO
333         if not src.is_file() and str(src) == "-":
334             if format_stdin_to_stdout(
335                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
336             ):
337                 changed = Changed.YES
338         else:
339             cache: Cache = {}
340             if write_back != WriteBack.DIFF:
341                 cache = read_cache(line_length, mode)
342                 src = src.resolve()
343                 if src in cache and cache[src] == get_cache_info(src):
344                     changed = Changed.CACHED
345             if changed is not Changed.CACHED and format_file_in_place(
346                 src,
347                 line_length=line_length,
348                 fast=fast,
349                 write_back=write_back,
350                 mode=mode,
351             ):
352                 changed = Changed.YES
353             if write_back == WriteBack.YES and changed is not Changed.NO:
354                 write_cache(cache, [src], line_length, mode)
355         report.done(src, changed)
356     except Exception as exc:
357         report.failed(src, str(exc))
358
359
360 async def schedule_formatting(
361     sources: List[Path],
362     line_length: int,
363     fast: bool,
364     write_back: WriteBack,
365     mode: FileMode,
366     report: "Report",
367     loop: BaseEventLoop,
368     executor: Executor,
369 ) -> None:
370     """Run formatting of `sources` in parallel using the provided `executor`.
371
372     (Use ProcessPoolExecutors for actual parallelism.)
373
374     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
375     :func:`format_file_in_place`.
376     """
377     cache: Cache = {}
378     if write_back != WriteBack.DIFF:
379         cache = read_cache(line_length, mode)
380         sources, cached = filter_cached(cache, sources)
381         for src in cached:
382             report.done(src, Changed.CACHED)
383     cancelled = []
384     formatted = []
385     if sources:
386         lock = None
387         if write_back == WriteBack.DIFF:
388             # For diff output, we need locks to ensure we don't interleave output
389             # from different processes.
390             manager = Manager()
391             lock = manager.Lock()
392         tasks = {
393             loop.run_in_executor(
394                 executor,
395                 format_file_in_place,
396                 src,
397                 line_length,
398                 fast,
399                 write_back,
400                 mode,
401                 lock,
402             ): src
403             for src in sorted(sources)
404         }
405         pending: Iterable[asyncio.Task] = tasks.keys()
406         try:
407             loop.add_signal_handler(signal.SIGINT, cancel, pending)
408             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
409         except NotImplementedError:
410             # There are no good alternatives for these on Windows
411             pass
412         while pending:
413             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
414             for task in done:
415                 src = tasks.pop(task)
416                 if task.cancelled():
417                     cancelled.append(task)
418                 elif task.exception():
419                     report.failed(src, str(task.exception()))
420                 else:
421                     formatted.append(src)
422                     report.done(src, Changed.YES if task.result() else Changed.NO)
423     if cancelled:
424         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
425     if write_back == WriteBack.YES and formatted:
426         write_cache(cache, formatted, line_length, mode)
427
428
429 def format_file_in_place(
430     src: Path,
431     line_length: int,
432     fast: bool,
433     write_back: WriteBack = WriteBack.NO,
434     mode: FileMode = FileMode.AUTO_DETECT,
435     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
436 ) -> bool:
437     """Format file under `src` path. Return True if changed.
438
439     If `write_back` is True, write reformatted code back to stdout.
440     `line_length` and `fast` options are passed to :func:`format_file_contents`.
441     """
442     if src.suffix == ".pyi":
443         mode |= FileMode.PYI
444     with tokenize.open(src) as src_buffer:
445         src_contents = src_buffer.read()
446     try:
447         dst_contents = format_file_contents(
448             src_contents, line_length=line_length, fast=fast, mode=mode
449         )
450     except NothingChanged:
451         return False
452
453     if write_back == write_back.YES:
454         with open(src, "w", encoding=src_buffer.encoding) as f:
455             f.write(dst_contents)
456     elif write_back == write_back.DIFF:
457         src_name = f"{src}  (original)"
458         dst_name = f"{src}  (formatted)"
459         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
460         if lock:
461             lock.acquire()
462         try:
463             sys.stdout.write(diff_contents)
464         finally:
465             if lock:
466                 lock.release()
467     return True
468
469
470 def format_stdin_to_stdout(
471     line_length: int,
472     fast: bool,
473     write_back: WriteBack = WriteBack.NO,
474     mode: FileMode = FileMode.AUTO_DETECT,
475 ) -> bool:
476     """Format file on stdin. Return True if changed.
477
478     If `write_back` is True, write reformatted code back to stdout.
479     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
480     :func:`format_file_contents`.
481     """
482     src = sys.stdin.read()
483     dst = src
484     try:
485         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
486         return True
487
488     except NothingChanged:
489         return False
490
491     finally:
492         if write_back == WriteBack.YES:
493             sys.stdout.write(dst)
494         elif write_back == WriteBack.DIFF:
495             src_name = "<stdin>  (original)"
496             dst_name = "<stdin>  (formatted)"
497             sys.stdout.write(diff(src, dst, src_name, dst_name))
498
499
500 def format_file_contents(
501     src_contents: str,
502     *,
503     line_length: int,
504     fast: bool,
505     mode: FileMode = FileMode.AUTO_DETECT,
506 ) -> FileContent:
507     """Reformat contents a file and return new contents.
508
509     If `fast` is False, additionally confirm that the reformatted code is
510     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
511     `line_length` is passed to :func:`format_str`.
512     """
513     if src_contents.strip() == "":
514         raise NothingChanged
515
516     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
517     if src_contents == dst_contents:
518         raise NothingChanged
519
520     if not fast:
521         assert_equivalent(src_contents, dst_contents)
522         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
523     return dst_contents
524
525
526 def format_str(
527     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
528 ) -> FileContent:
529     """Reformat a string and return new contents.
530
531     `line_length` determines how many characters per line are allowed.
532     """
533     src_node = lib2to3_parse(src_contents)
534     dst_contents = ""
535     future_imports = get_future_imports(src_node)
536     is_pyi = bool(mode & FileMode.PYI)
537     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
538     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
539     lines = LineGenerator(
540         remove_u_prefix=py36 or "unicode_literals" in future_imports,
541         is_pyi=is_pyi,
542         normalize_strings=normalize_strings,
543     )
544     elt = EmptyLineTracker(is_pyi=is_pyi)
545     empty_line = Line()
546     after = 0
547     for current_line in lines.visit(src_node):
548         for _ in range(after):
549             dst_contents += str(empty_line)
550         before, after = elt.maybe_empty_lines(current_line)
551         for _ in range(before):
552             dst_contents += str(empty_line)
553         for line in split_line(current_line, line_length=line_length, py36=py36):
554             dst_contents += str(line)
555     return dst_contents
556
557
558 GRAMMARS = [
559     pygram.python_grammar_no_print_statement_no_exec_statement,
560     pygram.python_grammar_no_print_statement,
561     pygram.python_grammar,
562 ]
563
564
565 def lib2to3_parse(src_txt: str) -> Node:
566     """Given a string with source, return the lib2to3 Node."""
567     grammar = pygram.python_grammar_no_print_statement
568     if src_txt[-1] != "\n":
569         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
570         src_txt += nl
571     for grammar in GRAMMARS:
572         drv = driver.Driver(grammar, pytree.convert)
573         try:
574             result = drv.parse_string(src_txt, True)
575             break
576
577         except ParseError as pe:
578             lineno, column = pe.context[1]
579             lines = src_txt.splitlines()
580             try:
581                 faulty_line = lines[lineno - 1]
582             except IndexError:
583                 faulty_line = "<line number missing in source>"
584             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
585     else:
586         raise exc from None
587
588     if isinstance(result, Leaf):
589         result = Node(syms.file_input, [result])
590     return result
591
592
593 def lib2to3_unparse(node: Node) -> str:
594     """Given a lib2to3 node, return its string representation."""
595     code = str(node)
596     return code
597
598
599 T = TypeVar("T")
600
601
602 class Visitor(Generic[T]):
603     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
604
605     def visit(self, node: LN) -> Iterator[T]:
606         """Main method to visit `node` and its children.
607
608         It tries to find a `visit_*()` method for the given `node.type`, like
609         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
610         If no dedicated `visit_*()` method is found, chooses `visit_default()`
611         instead.
612
613         Then yields objects of type `T` from the selected visitor.
614         """
615         if node.type < 256:
616             name = token.tok_name[node.type]
617         else:
618             name = type_repr(node.type)
619         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
620
621     def visit_default(self, node: LN) -> Iterator[T]:
622         """Default `visit_*()` implementation. Recurses to children of `node`."""
623         if isinstance(node, Node):
624             for child in node.children:
625                 yield from self.visit(child)
626
627
628 @dataclass
629 class DebugVisitor(Visitor[T]):
630     tree_depth: int = 0
631
632     def visit_default(self, node: LN) -> Iterator[T]:
633         indent = " " * (2 * self.tree_depth)
634         if isinstance(node, Node):
635             _type = type_repr(node.type)
636             out(f"{indent}{_type}", fg="yellow")
637             self.tree_depth += 1
638             for child in node.children:
639                 yield from self.visit(child)
640
641             self.tree_depth -= 1
642             out(f"{indent}/{_type}", fg="yellow", bold=False)
643         else:
644             _type = token.tok_name.get(node.type, str(node.type))
645             out(f"{indent}{_type}", fg="blue", nl=False)
646             if node.prefix:
647                 # We don't have to handle prefixes for `Node` objects since
648                 # that delegates to the first child anyway.
649                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
650             out(f" {node.value!r}", fg="blue", bold=False)
651
652     @classmethod
653     def show(cls, code: str) -> None:
654         """Pretty-print the lib2to3 AST of a given string of `code`.
655
656         Convenience method for debugging.
657         """
658         v: DebugVisitor[None] = DebugVisitor()
659         list(v.visit(lib2to3_parse(code)))
660
661
662 KEYWORDS = set(keyword.kwlist)
663 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
664 FLOW_CONTROL = {"return", "raise", "break", "continue"}
665 STATEMENT = {
666     syms.if_stmt,
667     syms.while_stmt,
668     syms.for_stmt,
669     syms.try_stmt,
670     syms.except_clause,
671     syms.with_stmt,
672     syms.funcdef,
673     syms.classdef,
674 }
675 STANDALONE_COMMENT = 153
676 LOGIC_OPERATORS = {"and", "or"}
677 COMPARATORS = {
678     token.LESS,
679     token.GREATER,
680     token.EQEQUAL,
681     token.NOTEQUAL,
682     token.LESSEQUAL,
683     token.GREATEREQUAL,
684 }
685 MATH_OPERATORS = {
686     token.VBAR,
687     token.CIRCUMFLEX,
688     token.AMPER,
689     token.LEFTSHIFT,
690     token.RIGHTSHIFT,
691     token.PLUS,
692     token.MINUS,
693     token.STAR,
694     token.SLASH,
695     token.DOUBLESLASH,
696     token.PERCENT,
697     token.AT,
698     token.TILDE,
699     token.DOUBLESTAR,
700 }
701 STARS = {token.STAR, token.DOUBLESTAR}
702 VARARGS_PARENTS = {
703     syms.arglist,
704     syms.argument,  # double star in arglist
705     syms.trailer,  # single argument to call
706     syms.typedargslist,
707     syms.varargslist,  # lambdas
708 }
709 UNPACKING_PARENTS = {
710     syms.atom,  # single element of a list or set literal
711     syms.dictsetmaker,
712     syms.listmaker,
713     syms.testlist_gexp,
714 }
715 TEST_DESCENDANTS = {
716     syms.test,
717     syms.lambdef,
718     syms.or_test,
719     syms.and_test,
720     syms.not_test,
721     syms.comparison,
722     syms.star_expr,
723     syms.expr,
724     syms.xor_expr,
725     syms.and_expr,
726     syms.shift_expr,
727     syms.arith_expr,
728     syms.trailer,
729     syms.term,
730     syms.power,
731 }
732 ASSIGNMENTS = {
733     "=",
734     "+=",
735     "-=",
736     "*=",
737     "@=",
738     "/=",
739     "%=",
740     "&=",
741     "|=",
742     "^=",
743     "<<=",
744     ">>=",
745     "**=",
746     "//=",
747 }
748 COMPREHENSION_PRIORITY = 20
749 COMMA_PRIORITY = 18
750 TERNARY_PRIORITY = 16
751 LOGIC_PRIORITY = 14
752 STRING_PRIORITY = 12
753 COMPARATOR_PRIORITY = 10
754 MATH_PRIORITIES = {
755     token.VBAR: 9,
756     token.CIRCUMFLEX: 8,
757     token.AMPER: 7,
758     token.LEFTSHIFT: 6,
759     token.RIGHTSHIFT: 6,
760     token.PLUS: 5,
761     token.MINUS: 5,
762     token.STAR: 4,
763     token.SLASH: 4,
764     token.DOUBLESLASH: 4,
765     token.PERCENT: 4,
766     token.AT: 4,
767     token.TILDE: 3,
768     token.DOUBLESTAR: 2,
769 }
770 DOT_PRIORITY = 1
771
772
773 @dataclass
774 class BracketTracker:
775     """Keeps track of brackets on a line."""
776
777     depth: int = 0
778     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
779     delimiters: Dict[LeafID, Priority] = Factory(dict)
780     previous: Optional[Leaf] = None
781     _for_loop_variable: int = 0
782     _lambda_arguments: int = 0
783
784     def mark(self, leaf: Leaf) -> None:
785         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
786
787         All leaves receive an int `bracket_depth` field that stores how deep
788         within brackets a given leaf is. 0 means there are no enclosing brackets
789         that started on this line.
790
791         If a leaf is itself a closing bracket, it receives an `opening_bracket`
792         field that it forms a pair with. This is a one-directional link to
793         avoid reference cycles.
794
795         If a leaf is a delimiter (a token on which Black can split the line if
796         needed) and it's on depth 0, its `id()` is stored in the tracker's
797         `delimiters` field.
798         """
799         if leaf.type == token.COMMENT:
800             return
801
802         self.maybe_decrement_after_for_loop_variable(leaf)
803         self.maybe_decrement_after_lambda_arguments(leaf)
804         if leaf.type in CLOSING_BRACKETS:
805             self.depth -= 1
806             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
807             leaf.opening_bracket = opening_bracket
808         leaf.bracket_depth = self.depth
809         if self.depth == 0:
810             delim = is_split_before_delimiter(leaf, self.previous)
811             if delim and self.previous is not None:
812                 self.delimiters[id(self.previous)] = delim
813             else:
814                 delim = is_split_after_delimiter(leaf, self.previous)
815                 if delim:
816                     self.delimiters[id(leaf)] = delim
817         if leaf.type in OPENING_BRACKETS:
818             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
819             self.depth += 1
820         self.previous = leaf
821         self.maybe_increment_lambda_arguments(leaf)
822         self.maybe_increment_for_loop_variable(leaf)
823
824     def any_open_brackets(self) -> bool:
825         """Return True if there is an yet unmatched open bracket on the line."""
826         return bool(self.bracket_match)
827
828     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
829         """Return the highest priority of a delimiter found on the line.
830
831         Values are consistent with what `is_split_*_delimiter()` return.
832         Raises ValueError on no delimiters.
833         """
834         return max(v for k, v in self.delimiters.items() if k not in exclude)
835
836     def delimiter_count_with_priority(self, priority: int = 0) -> int:
837         """Return the number of delimiters with the given `priority`.
838
839         If no `priority` is passed, defaults to max priority on the line.
840         """
841         if not self.delimiters:
842             return 0
843
844         priority = priority or self.max_delimiter_priority()
845         return sum(1 for p in self.delimiters.values() if p == priority)
846
847     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
848         """In a for loop, or comprehension, the variables are often unpacks.
849
850         To avoid splitting on the comma in this situation, increase the depth of
851         tokens between `for` and `in`.
852         """
853         if leaf.type == token.NAME and leaf.value == "for":
854             self.depth += 1
855             self._for_loop_variable += 1
856             return True
857
858         return False
859
860     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
861         """See `maybe_increment_for_loop_variable` above for explanation."""
862         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
863             self.depth -= 1
864             self._for_loop_variable -= 1
865             return True
866
867         return False
868
869     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
870         """In a lambda expression, there might be more than one argument.
871
872         To avoid splitting on the comma in this situation, increase the depth of
873         tokens between `lambda` and `:`.
874         """
875         if leaf.type == token.NAME and leaf.value == "lambda":
876             self.depth += 1
877             self._lambda_arguments += 1
878             return True
879
880         return False
881
882     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
883         """See `maybe_increment_lambda_arguments` above for explanation."""
884         if self._lambda_arguments and leaf.type == token.COLON:
885             self.depth -= 1
886             self._lambda_arguments -= 1
887             return True
888
889         return False
890
891     def get_open_lsqb(self) -> Optional[Leaf]:
892         """Return the most recent opening square bracket (if any)."""
893         return self.bracket_match.get((self.depth - 1, token.RSQB))
894
895
896 @dataclass
897 class Line:
898     """Holds leaves and comments. Can be printed with `str(line)`."""
899
900     depth: int = 0
901     leaves: List[Leaf] = Factory(list)
902     comments: List[Tuple[Index, Leaf]] = Factory(list)
903     bracket_tracker: BracketTracker = Factory(BracketTracker)
904     inside_brackets: bool = False
905     should_explode: bool = False
906
907     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
908         """Add a new `leaf` to the end of the line.
909
910         Unless `preformatted` is True, the `leaf` will receive a new consistent
911         whitespace prefix and metadata applied by :class:`BracketTracker`.
912         Trailing commas are maybe removed, unpacked for loop variables are
913         demoted from being delimiters.
914
915         Inline comments are put aside.
916         """
917         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
918         if not has_value:
919             return
920
921         if token.COLON == leaf.type and self.is_class_paren_empty:
922             del self.leaves[-2:]
923         if self.leaves and not preformatted:
924             # Note: at this point leaf.prefix should be empty except for
925             # imports, for which we only preserve newlines.
926             leaf.prefix += whitespace(
927                 leaf, complex_subscript=self.is_complex_subscript(leaf)
928             )
929         if self.inside_brackets or not preformatted:
930             self.bracket_tracker.mark(leaf)
931             self.maybe_remove_trailing_comma(leaf)
932         if not self.append_comment(leaf):
933             self.leaves.append(leaf)
934
935     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
936         """Like :func:`append()` but disallow invalid standalone comment structure.
937
938         Raises ValueError when any `leaf` is appended after a standalone comment
939         or when a standalone comment is not the first leaf on the line.
940         """
941         if self.bracket_tracker.depth == 0:
942             if self.is_comment:
943                 raise ValueError("cannot append to standalone comments")
944
945             if self.leaves and leaf.type == STANDALONE_COMMENT:
946                 raise ValueError(
947                     "cannot append standalone comments to a populated line"
948                 )
949
950         self.append(leaf, preformatted=preformatted)
951
952     @property
953     def is_comment(self) -> bool:
954         """Is this line a standalone comment?"""
955         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
956
957     @property
958     def is_decorator(self) -> bool:
959         """Is this line a decorator?"""
960         return bool(self) and self.leaves[0].type == token.AT
961
962     @property
963     def is_import(self) -> bool:
964         """Is this an import line?"""
965         return bool(self) and is_import(self.leaves[0])
966
967     @property
968     def is_class(self) -> bool:
969         """Is this line a class definition?"""
970         return (
971             bool(self)
972             and self.leaves[0].type == token.NAME
973             and self.leaves[0].value == "class"
974         )
975
976     @property
977     def is_stub_class(self) -> bool:
978         """Is this line a class definition with a body consisting only of "..."?"""
979         return self.is_class and self.leaves[-3:] == [
980             Leaf(token.DOT, ".") for _ in range(3)
981         ]
982
983     @property
984     def is_def(self) -> bool:
985         """Is this a function definition? (Also returns True for async defs.)"""
986         try:
987             first_leaf = self.leaves[0]
988         except IndexError:
989             return False
990
991         try:
992             second_leaf: Optional[Leaf] = self.leaves[1]
993         except IndexError:
994             second_leaf = None
995         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
996             first_leaf.type == token.ASYNC
997             and second_leaf is not None
998             and second_leaf.type == token.NAME
999             and second_leaf.value == "def"
1000         )
1001
1002     @property
1003     def is_class_paren_empty(self) -> bool:
1004         """Is this a class with no base classes but using parentheses?
1005
1006         Those are unnecessary and should be removed.
1007         """
1008         return (
1009             bool(self)
1010             and len(self.leaves) == 4
1011             and self.is_class
1012             and self.leaves[2].type == token.LPAR
1013             and self.leaves[2].value == "("
1014             and self.leaves[3].type == token.RPAR
1015             and self.leaves[3].value == ")"
1016         )
1017
1018     @property
1019     def is_triple_quoted_string(self) -> bool:
1020         """Is the line a triple quoted string?"""
1021         return (
1022             bool(self)
1023             and self.leaves[0].type == token.STRING
1024             and self.leaves[0].value.startswith(('"""', "'''"))
1025         )
1026
1027     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1028         """If so, needs to be split before emitting."""
1029         for leaf in self.leaves:
1030             if leaf.type == STANDALONE_COMMENT:
1031                 if leaf.bracket_depth <= depth_limit:
1032                     return True
1033
1034         return False
1035
1036     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1037         """Remove trailing comma if there is one and it's safe."""
1038         if not (
1039             self.leaves
1040             and self.leaves[-1].type == token.COMMA
1041             and closing.type in CLOSING_BRACKETS
1042         ):
1043             return False
1044
1045         if closing.type == token.RBRACE:
1046             self.remove_trailing_comma()
1047             return True
1048
1049         if closing.type == token.RSQB:
1050             comma = self.leaves[-1]
1051             if comma.parent and comma.parent.type == syms.listmaker:
1052                 self.remove_trailing_comma()
1053                 return True
1054
1055         # For parens let's check if it's safe to remove the comma.
1056         # Imports are always safe.
1057         if self.is_import:
1058             self.remove_trailing_comma()
1059             return True
1060
1061         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1062         # change a tuple into a different type by removing the comma.
1063         depth = closing.bracket_depth + 1
1064         commas = 0
1065         opening = closing.opening_bracket
1066         for _opening_index, leaf in enumerate(self.leaves):
1067             if leaf is opening:
1068                 break
1069
1070         else:
1071             return False
1072
1073         for leaf in self.leaves[_opening_index + 1 :]:
1074             if leaf is closing:
1075                 break
1076
1077             bracket_depth = leaf.bracket_depth
1078             if bracket_depth == depth and leaf.type == token.COMMA:
1079                 commas += 1
1080                 if leaf.parent and leaf.parent.type == syms.arglist:
1081                     commas += 1
1082                     break
1083
1084         if commas > 1:
1085             self.remove_trailing_comma()
1086             return True
1087
1088         return False
1089
1090     def append_comment(self, comment: Leaf) -> bool:
1091         """Add an inline or standalone comment to the line."""
1092         if (
1093             comment.type == STANDALONE_COMMENT
1094             and self.bracket_tracker.any_open_brackets()
1095         ):
1096             comment.prefix = ""
1097             return False
1098
1099         if comment.type != token.COMMENT:
1100             return False
1101
1102         after = len(self.leaves) - 1
1103         if after == -1:
1104             comment.type = STANDALONE_COMMENT
1105             comment.prefix = ""
1106             return False
1107
1108         else:
1109             self.comments.append((after, comment))
1110             return True
1111
1112     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1113         """Generate comments that should appear directly after `leaf`.
1114
1115         Provide a non-negative leaf `_index` to speed up the function.
1116         """
1117         if _index == -1:
1118             for _index, _leaf in enumerate(self.leaves):
1119                 if leaf is _leaf:
1120                     break
1121
1122             else:
1123                 return
1124
1125         for index, comment_after in self.comments:
1126             if _index == index:
1127                 yield comment_after
1128
1129     def remove_trailing_comma(self) -> None:
1130         """Remove the trailing comma and moves the comments attached to it."""
1131         comma_index = len(self.leaves) - 1
1132         for i in range(len(self.comments)):
1133             comment_index, comment = self.comments[i]
1134             if comment_index == comma_index:
1135                 self.comments[i] = (comma_index - 1, comment)
1136         self.leaves.pop()
1137
1138     def is_complex_subscript(self, leaf: Leaf) -> bool:
1139         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1140         open_lsqb = (
1141             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1142         )
1143         if open_lsqb is None:
1144             return False
1145
1146         subscript_start = open_lsqb.next_sibling
1147         if (
1148             isinstance(subscript_start, Node)
1149             and subscript_start.type == syms.subscriptlist
1150         ):
1151             subscript_start = child_towards(subscript_start, leaf)
1152         return subscript_start is not None and any(
1153             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1154         )
1155
1156     def __str__(self) -> str:
1157         """Render the line."""
1158         if not self:
1159             return "\n"
1160
1161         indent = "    " * self.depth
1162         leaves = iter(self.leaves)
1163         first = next(leaves)
1164         res = f"{first.prefix}{indent}{first.value}"
1165         for leaf in leaves:
1166             res += str(leaf)
1167         for _, comment in self.comments:
1168             res += str(comment)
1169         return res + "\n"
1170
1171     def __bool__(self) -> bool:
1172         """Return True if the line has leaves or comments."""
1173         return bool(self.leaves or self.comments)
1174
1175
1176 class UnformattedLines(Line):
1177     """Just like :class:`Line` but stores lines which aren't reformatted."""
1178
1179     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1180         """Just add a new `leaf` to the end of the lines.
1181
1182         The `preformatted` argument is ignored.
1183
1184         Keeps track of indentation `depth`, which is useful when the user
1185         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1186         """
1187         try:
1188             list(generate_comments(leaf))
1189         except FormatOn as f_on:
1190             self.leaves.append(f_on.leaf_from_consumed(leaf))
1191             raise
1192
1193         self.leaves.append(leaf)
1194         if leaf.type == token.INDENT:
1195             self.depth += 1
1196         elif leaf.type == token.DEDENT:
1197             self.depth -= 1
1198
1199     def __str__(self) -> str:
1200         """Render unformatted lines from leaves which were added with `append()`.
1201
1202         `depth` is not used for indentation in this case.
1203         """
1204         if not self:
1205             return "\n"
1206
1207         res = ""
1208         for leaf in self.leaves:
1209             res += str(leaf)
1210         return res
1211
1212     def append_comment(self, comment: Leaf) -> bool:
1213         """Not implemented in this class. Raises `NotImplementedError`."""
1214         raise NotImplementedError("Unformatted lines don't store comments separately.")
1215
1216     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1217         """Does nothing and returns False."""
1218         return False
1219
1220     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1221         """Does nothing and returns False."""
1222         return False
1223
1224
1225 @dataclass
1226 class EmptyLineTracker:
1227     """Provides a stateful method that returns the number of potential extra
1228     empty lines needed before and after the currently processed line.
1229
1230     Note: this tracker works on lines that haven't been split yet.  It assumes
1231     the prefix of the first leaf consists of optional newlines.  Those newlines
1232     are consumed by `maybe_empty_lines()` and included in the computation.
1233     """
1234
1235     is_pyi: bool = False
1236     previous_line: Optional[Line] = None
1237     previous_after: int = 0
1238     previous_defs: List[int] = Factory(list)
1239
1240     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1241         """Return the number of extra empty lines before and after the `current_line`.
1242
1243         This is for separating `def`, `async def` and `class` with extra empty
1244         lines (two on module-level).
1245         """
1246         if isinstance(current_line, UnformattedLines):
1247             return 0, 0
1248
1249         before, after = self._maybe_empty_lines(current_line)
1250         before -= self.previous_after
1251         self.previous_after = after
1252         self.previous_line = current_line
1253         return before, after
1254
1255     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1256         max_allowed = 1
1257         if current_line.depth == 0:
1258             max_allowed = 1 if self.is_pyi else 2
1259         if current_line.leaves:
1260             # Consume the first leaf's extra newlines.
1261             first_leaf = current_line.leaves[0]
1262             before = first_leaf.prefix.count("\n")
1263             before = min(before, max_allowed)
1264             first_leaf.prefix = ""
1265         else:
1266             before = 0
1267         depth = current_line.depth
1268         while self.previous_defs and self.previous_defs[-1] >= depth:
1269             self.previous_defs.pop()
1270             if self.is_pyi:
1271                 before = 0 if depth else 1
1272             else:
1273                 before = 1 if depth else 2
1274         is_decorator = current_line.is_decorator
1275         if is_decorator or current_line.is_def or current_line.is_class:
1276             if not is_decorator:
1277                 self.previous_defs.append(depth)
1278             if self.previous_line is None:
1279                 # Don't insert empty lines before the first line in the file.
1280                 return 0, 0
1281
1282             if self.previous_line.is_decorator:
1283                 return 0, 0
1284
1285             if self.previous_line.depth < current_line.depth and (
1286                 self.previous_line.is_class or self.previous_line.is_def
1287             ):
1288                 return 0, 0
1289
1290             if (
1291                 self.previous_line.is_comment
1292                 and self.previous_line.depth == current_line.depth
1293                 and before == 0
1294             ):
1295                 return 0, 0
1296
1297             if self.is_pyi:
1298                 if self.previous_line.depth > current_line.depth:
1299                     newlines = 1
1300                 elif current_line.is_class or self.previous_line.is_class:
1301                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1302                         newlines = 0
1303                     else:
1304                         newlines = 1
1305                 else:
1306                     newlines = 0
1307             else:
1308                 newlines = 2
1309             if current_line.depth and newlines:
1310                 newlines -= 1
1311             return newlines, 0
1312
1313         if (
1314             self.previous_line
1315             and self.previous_line.is_import
1316             and not current_line.is_import
1317             and depth == self.previous_line.depth
1318         ):
1319             return (before or 1), 0
1320
1321         if (
1322             self.previous_line
1323             and self.previous_line.is_class
1324             and current_line.is_triple_quoted_string
1325         ):
1326             return before, 1
1327
1328         return before, 0
1329
1330
1331 @dataclass
1332 class LineGenerator(Visitor[Line]):
1333     """Generates reformatted Line objects.  Empty lines are not emitted.
1334
1335     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1336     in ways that will no longer stringify to valid Python code on the tree.
1337     """
1338
1339     is_pyi: bool = False
1340     normalize_strings: bool = True
1341     current_line: Line = Factory(Line)
1342     remove_u_prefix: bool = False
1343
1344     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1345         """Generate a line.
1346
1347         If the line is empty, only emit if it makes sense.
1348         If the line is too long, split it first and then generate.
1349
1350         If any lines were generated, set up a new current_line.
1351         """
1352         if not self.current_line:
1353             if self.current_line.__class__ == type:
1354                 self.current_line.depth += indent
1355             else:
1356                 self.current_line = type(depth=self.current_line.depth + indent)
1357             return  # Line is empty, don't emit. Creating a new one unnecessary.
1358
1359         complete_line = self.current_line
1360         self.current_line = type(depth=complete_line.depth + indent)
1361         yield complete_line
1362
1363     def visit(self, node: LN) -> Iterator[Line]:
1364         """Main method to visit `node` and its children.
1365
1366         Yields :class:`Line` objects.
1367         """
1368         if isinstance(self.current_line, UnformattedLines):
1369             # File contained `# fmt: off`
1370             yield from self.visit_unformatted(node)
1371
1372         else:
1373             yield from super().visit(node)
1374
1375     def visit_default(self, node: LN) -> Iterator[Line]:
1376         """Default `visit_*()` implementation. Recurses to children of `node`."""
1377         if isinstance(node, Leaf):
1378             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1379             try:
1380                 for comment in generate_comments(node):
1381                     if any_open_brackets:
1382                         # any comment within brackets is subject to splitting
1383                         self.current_line.append(comment)
1384                     elif comment.type == token.COMMENT:
1385                         # regular trailing comment
1386                         self.current_line.append(comment)
1387                         yield from self.line()
1388
1389                     else:
1390                         # regular standalone comment
1391                         yield from self.line()
1392
1393                         self.current_line.append(comment)
1394                         yield from self.line()
1395
1396             except FormatOff as f_off:
1397                 f_off.trim_prefix(node)
1398                 yield from self.line(type=UnformattedLines)
1399                 yield from self.visit(node)
1400
1401             except FormatOn as f_on:
1402                 # This only happens here if somebody says "fmt: on" multiple
1403                 # times in a row.
1404                 f_on.trim_prefix(node)
1405                 yield from self.visit_default(node)
1406
1407             else:
1408                 normalize_prefix(node, inside_brackets=any_open_brackets)
1409                 if self.normalize_strings and node.type == token.STRING:
1410                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1411                     normalize_string_quotes(node)
1412                 if node.type not in WHITESPACE:
1413                     self.current_line.append(node)
1414         yield from super().visit_default(node)
1415
1416     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1417         """Increase indentation level, maybe yield a line."""
1418         # In blib2to3 INDENT never holds comments.
1419         yield from self.line(+1)
1420         yield from self.visit_default(node)
1421
1422     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1423         """Decrease indentation level, maybe yield a line."""
1424         # The current line might still wait for trailing comments.  At DEDENT time
1425         # there won't be any (they would be prefixes on the preceding NEWLINE).
1426         # Emit the line then.
1427         yield from self.line()
1428
1429         # While DEDENT has no value, its prefix may contain standalone comments
1430         # that belong to the current indentation level.  Get 'em.
1431         yield from self.visit_default(node)
1432
1433         # Finally, emit the dedent.
1434         yield from self.line(-1)
1435
1436     def visit_stmt(
1437         self, node: Node, keywords: Set[str], parens: Set[str]
1438     ) -> Iterator[Line]:
1439         """Visit a statement.
1440
1441         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1442         `def`, `with`, `class`, `assert` and assignments.
1443
1444         The relevant Python language `keywords` for a given statement will be
1445         NAME leaves within it. This methods puts those on a separate line.
1446
1447         `parens` holds a set of string leaf values immediately after which
1448         invisible parens should be put.
1449         """
1450         normalize_invisible_parens(node, parens_after=parens)
1451         for child in node.children:
1452             if child.type == token.NAME and child.value in keywords:  # type: ignore
1453                 yield from self.line()
1454
1455             yield from self.visit(child)
1456
1457     def visit_suite(self, node: Node) -> Iterator[Line]:
1458         """Visit a suite."""
1459         if self.is_pyi and is_stub_suite(node):
1460             yield from self.visit(node.children[2])
1461         else:
1462             yield from self.visit_default(node)
1463
1464     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1465         """Visit a statement without nested statements."""
1466         is_suite_like = node.parent and node.parent.type in STATEMENT
1467         if is_suite_like:
1468             if self.is_pyi and is_stub_body(node):
1469                 yield from self.visit_default(node)
1470             else:
1471                 yield from self.line(+1)
1472                 yield from self.visit_default(node)
1473                 yield from self.line(-1)
1474
1475         else:
1476             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1477                 yield from self.line()
1478             yield from self.visit_default(node)
1479
1480     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1481         """Visit `async def`, `async for`, `async with`."""
1482         yield from self.line()
1483
1484         children = iter(node.children)
1485         for child in children:
1486             yield from self.visit(child)
1487
1488             if child.type == token.ASYNC:
1489                 break
1490
1491         internal_stmt = next(children)
1492         for child in internal_stmt.children:
1493             yield from self.visit(child)
1494
1495     def visit_decorators(self, node: Node) -> Iterator[Line]:
1496         """Visit decorators."""
1497         for child in node.children:
1498             yield from self.line()
1499             yield from self.visit(child)
1500
1501     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1502         """Remove a semicolon and put the other statement on a separate line."""
1503         yield from self.line()
1504
1505     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1506         """End of file. Process outstanding comments and end with a newline."""
1507         yield from self.visit_default(leaf)
1508         yield from self.line()
1509
1510     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1511         """Used when file contained a `# fmt: off`."""
1512         if isinstance(node, Node):
1513             for child in node.children:
1514                 yield from self.visit(child)
1515
1516         else:
1517             try:
1518                 self.current_line.append(node)
1519             except FormatOn as f_on:
1520                 f_on.trim_prefix(node)
1521                 yield from self.line()
1522                 yield from self.visit(node)
1523
1524             if node.type == token.ENDMARKER:
1525                 # somebody decided not to put a final `# fmt: on`
1526                 yield from self.line()
1527
1528     def __attrs_post_init__(self) -> None:
1529         """You are in a twisty little maze of passages."""
1530         v = self.visit_stmt
1531         Ø: Set[str] = set()
1532         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1533         self.visit_if_stmt = partial(
1534             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1535         )
1536         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1537         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1538         self.visit_try_stmt = partial(
1539             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1540         )
1541         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1542         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1543         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1544         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1545         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1546         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1547         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1548         self.visit_async_funcdef = self.visit_async_stmt
1549         self.visit_decorated = self.visit_decorators
1550
1551
1552 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1553 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1554 OPENING_BRACKETS = set(BRACKET.keys())
1555 CLOSING_BRACKETS = set(BRACKET.values())
1556 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1557 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1558
1559
1560 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1561     """Return whitespace prefix if needed for the given `leaf`.
1562
1563     `complex_subscript` signals whether the given leaf is part of a subscription
1564     which has non-trivial arguments, like arithmetic expressions or function calls.
1565     """
1566     NO = ""
1567     SPACE = " "
1568     DOUBLESPACE = "  "
1569     t = leaf.type
1570     p = leaf.parent
1571     v = leaf.value
1572     if t in ALWAYS_NO_SPACE:
1573         return NO
1574
1575     if t == token.COMMENT:
1576         return DOUBLESPACE
1577
1578     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1579     if t == token.COLON and p.type not in {
1580         syms.subscript,
1581         syms.subscriptlist,
1582         syms.sliceop,
1583     }:
1584         return NO
1585
1586     prev = leaf.prev_sibling
1587     if not prev:
1588         prevp = preceding_leaf(p)
1589         if not prevp or prevp.type in OPENING_BRACKETS:
1590             return NO
1591
1592         if t == token.COLON:
1593             if prevp.type == token.COLON:
1594                 return NO
1595
1596             elif prevp.type != token.COMMA and not complex_subscript:
1597                 return NO
1598
1599             return SPACE
1600
1601         if prevp.type == token.EQUAL:
1602             if prevp.parent:
1603                 if prevp.parent.type in {
1604                     syms.arglist,
1605                     syms.argument,
1606                     syms.parameters,
1607                     syms.varargslist,
1608                 }:
1609                     return NO
1610
1611                 elif prevp.parent.type == syms.typedargslist:
1612                     # A bit hacky: if the equal sign has whitespace, it means we
1613                     # previously found it's a typed argument.  So, we're using
1614                     # that, too.
1615                     return prevp.prefix
1616
1617         elif prevp.type in STARS:
1618             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1619                 return NO
1620
1621         elif prevp.type == token.COLON:
1622             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1623                 return SPACE if complex_subscript else NO
1624
1625         elif (
1626             prevp.parent
1627             and prevp.parent.type == syms.factor
1628             and prevp.type in MATH_OPERATORS
1629         ):
1630             return NO
1631
1632         elif (
1633             prevp.type == token.RIGHTSHIFT
1634             and prevp.parent
1635             and prevp.parent.type == syms.shift_expr
1636             and prevp.prev_sibling
1637             and prevp.prev_sibling.type == token.NAME
1638             and prevp.prev_sibling.value == "print"  # type: ignore
1639         ):
1640             # Python 2 print chevron
1641             return NO
1642
1643     elif prev.type in OPENING_BRACKETS:
1644         return NO
1645
1646     if p.type in {syms.parameters, syms.arglist}:
1647         # untyped function signatures or calls
1648         if not prev or prev.type != token.COMMA:
1649             return NO
1650
1651     elif p.type == syms.varargslist:
1652         # lambdas
1653         if prev and prev.type != token.COMMA:
1654             return NO
1655
1656     elif p.type == syms.typedargslist:
1657         # typed function signatures
1658         if not prev:
1659             return NO
1660
1661         if t == token.EQUAL:
1662             if prev.type != syms.tname:
1663                 return NO
1664
1665         elif prev.type == token.EQUAL:
1666             # A bit hacky: if the equal sign has whitespace, it means we
1667             # previously found it's a typed argument.  So, we're using that, too.
1668             return prev.prefix
1669
1670         elif prev.type != token.COMMA:
1671             return NO
1672
1673     elif p.type == syms.tname:
1674         # type names
1675         if not prev:
1676             prevp = preceding_leaf(p)
1677             if not prevp or prevp.type != token.COMMA:
1678                 return NO
1679
1680     elif p.type == syms.trailer:
1681         # attributes and calls
1682         if t == token.LPAR or t == token.RPAR:
1683             return NO
1684
1685         if not prev:
1686             if t == token.DOT:
1687                 prevp = preceding_leaf(p)
1688                 if not prevp or prevp.type != token.NUMBER:
1689                     return NO
1690
1691             elif t == token.LSQB:
1692                 return NO
1693
1694         elif prev.type != token.COMMA:
1695             return NO
1696
1697     elif p.type == syms.argument:
1698         # single argument
1699         if t == token.EQUAL:
1700             return NO
1701
1702         if not prev:
1703             prevp = preceding_leaf(p)
1704             if not prevp or prevp.type == token.LPAR:
1705                 return NO
1706
1707         elif prev.type in {token.EQUAL} | STARS:
1708             return NO
1709
1710     elif p.type == syms.decorator:
1711         # decorators
1712         return NO
1713
1714     elif p.type == syms.dotted_name:
1715         if prev:
1716             return NO
1717
1718         prevp = preceding_leaf(p)
1719         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1720             return NO
1721
1722     elif p.type == syms.classdef:
1723         if t == token.LPAR:
1724             return NO
1725
1726         if prev and prev.type == token.LPAR:
1727             return NO
1728
1729     elif p.type in {syms.subscript, syms.sliceop}:
1730         # indexing
1731         if not prev:
1732             assert p.parent is not None, "subscripts are always parented"
1733             if p.parent.type == syms.subscriptlist:
1734                 return SPACE
1735
1736             return NO
1737
1738         elif not complex_subscript:
1739             return NO
1740
1741     elif p.type == syms.atom:
1742         if prev and t == token.DOT:
1743             # dots, but not the first one.
1744             return NO
1745
1746     elif p.type == syms.dictsetmaker:
1747         # dict unpacking
1748         if prev and prev.type == token.DOUBLESTAR:
1749             return NO
1750
1751     elif p.type in {syms.factor, syms.star_expr}:
1752         # unary ops
1753         if not prev:
1754             prevp = preceding_leaf(p)
1755             if not prevp or prevp.type in OPENING_BRACKETS:
1756                 return NO
1757
1758             prevp_parent = prevp.parent
1759             assert prevp_parent is not None
1760             if prevp.type == token.COLON and prevp_parent.type in {
1761                 syms.subscript,
1762                 syms.sliceop,
1763             }:
1764                 return NO
1765
1766             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1767                 return NO
1768
1769         elif t == token.NAME or t == token.NUMBER:
1770             return NO
1771
1772     elif p.type == syms.import_from:
1773         if t == token.DOT:
1774             if prev and prev.type == token.DOT:
1775                 return NO
1776
1777         elif t == token.NAME:
1778             if v == "import":
1779                 return SPACE
1780
1781             if prev and prev.type == token.DOT:
1782                 return NO
1783
1784     elif p.type == syms.sliceop:
1785         return NO
1786
1787     return SPACE
1788
1789
1790 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1791     """Return the first leaf that precedes `node`, if any."""
1792     while node:
1793         res = node.prev_sibling
1794         if res:
1795             if isinstance(res, Leaf):
1796                 return res
1797
1798             try:
1799                 return list(res.leaves())[-1]
1800
1801             except IndexError:
1802                 return None
1803
1804         node = node.parent
1805     return None
1806
1807
1808 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1809     """Return the child of `ancestor` that contains `descendant`."""
1810     node: Optional[LN] = descendant
1811     while node and node.parent != ancestor:
1812         node = node.parent
1813     return node
1814
1815
1816 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1817     """Return the priority of the `leaf` delimiter, given a line break after it.
1818
1819     The delimiter priorities returned here are from those delimiters that would
1820     cause a line break after themselves.
1821
1822     Higher numbers are higher priority.
1823     """
1824     if leaf.type == token.COMMA:
1825         return COMMA_PRIORITY
1826
1827     return 0
1828
1829
1830 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1831     """Return the priority of the `leaf` delimiter, given a line before after it.
1832
1833     The delimiter priorities returned here are from those delimiters that would
1834     cause a line break before themselves.
1835
1836     Higher numbers are higher priority.
1837     """
1838     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1839         # * and ** might also be MATH_OPERATORS but in this case they are not.
1840         # Don't treat them as a delimiter.
1841         return 0
1842
1843     if (
1844         leaf.type == token.DOT
1845         and leaf.parent
1846         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1847         and (previous is None or previous.type in CLOSING_BRACKETS)
1848     ):
1849         return DOT_PRIORITY
1850
1851     if (
1852         leaf.type in MATH_OPERATORS
1853         and leaf.parent
1854         and leaf.parent.type not in {syms.factor, syms.star_expr}
1855     ):
1856         return MATH_PRIORITIES[leaf.type]
1857
1858     if leaf.type in COMPARATORS:
1859         return COMPARATOR_PRIORITY
1860
1861     if (
1862         leaf.type == token.STRING
1863         and previous is not None
1864         and previous.type == token.STRING
1865     ):
1866         return STRING_PRIORITY
1867
1868     if leaf.type != token.NAME:
1869         return 0
1870
1871     if (
1872         leaf.value == "for"
1873         and leaf.parent
1874         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1875     ):
1876         return COMPREHENSION_PRIORITY
1877
1878     if (
1879         leaf.value == "if"
1880         and leaf.parent
1881         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1882     ):
1883         return COMPREHENSION_PRIORITY
1884
1885     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1886         return TERNARY_PRIORITY
1887
1888     if leaf.value == "is":
1889         return COMPARATOR_PRIORITY
1890
1891     if (
1892         leaf.value == "in"
1893         and leaf.parent
1894         and leaf.parent.type in {syms.comp_op, syms.comparison}
1895         and not (
1896             previous is not None
1897             and previous.type == token.NAME
1898             and previous.value == "not"
1899         )
1900     ):
1901         return COMPARATOR_PRIORITY
1902
1903     if (
1904         leaf.value == "not"
1905         and leaf.parent
1906         and leaf.parent.type == syms.comp_op
1907         and not (
1908             previous is not None
1909             and previous.type == token.NAME
1910             and previous.value == "is"
1911         )
1912     ):
1913         return COMPARATOR_PRIORITY
1914
1915     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1916         return LOGIC_PRIORITY
1917
1918     return 0
1919
1920
1921 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1922     """Clean the prefix of the `leaf` and generate comments from it, if any.
1923
1924     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1925     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1926     move because it does away with modifying the grammar to include all the
1927     possible places in which comments can be placed.
1928
1929     The sad consequence for us though is that comments don't "belong" anywhere.
1930     This is why this function generates simple parentless Leaf objects for
1931     comments.  We simply don't know what the correct parent should be.
1932
1933     No matter though, we can live without this.  We really only need to
1934     differentiate between inline and standalone comments.  The latter don't
1935     share the line with any code.
1936
1937     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1938     are emitted with a fake STANDALONE_COMMENT token identifier.
1939     """
1940     p = leaf.prefix
1941     if not p:
1942         return
1943
1944     if "#" not in p:
1945         return
1946
1947     consumed = 0
1948     nlines = 0
1949     for index, line in enumerate(p.split("\n")):
1950         consumed += len(line) + 1  # adding the length of the split '\n'
1951         line = line.lstrip()
1952         if not line:
1953             nlines += 1
1954         if not line.startswith("#"):
1955             continue
1956
1957         if index == 0 and leaf.type != token.ENDMARKER:
1958             comment_type = token.COMMENT  # simple trailing comment
1959         else:
1960             comment_type = STANDALONE_COMMENT
1961         comment = make_comment(line)
1962         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1963
1964         if comment in {"# fmt: on", "# yapf: enable"}:
1965             raise FormatOn(consumed)
1966
1967         if comment in {"# fmt: off", "# yapf: disable"}:
1968             if comment_type == STANDALONE_COMMENT:
1969                 raise FormatOff(consumed)
1970
1971             prev = preceding_leaf(leaf)
1972             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1973                 raise FormatOff(consumed)
1974
1975         nlines = 0
1976
1977
1978 def make_comment(content: str) -> str:
1979     """Return a consistently formatted comment from the given `content` string.
1980
1981     All comments (except for "##", "#!", "#:") should have a single space between
1982     the hash sign and the content.
1983
1984     If `content` didn't start with a hash sign, one is provided.
1985     """
1986     content = content.rstrip()
1987     if not content:
1988         return "#"
1989
1990     if content[0] == "#":
1991         content = content[1:]
1992     if content and content[0] not in " !:#":
1993         content = " " + content
1994     return "#" + content
1995
1996
1997 def split_line(
1998     line: Line, line_length: int, inner: bool = False, py36: bool = False
1999 ) -> Iterator[Line]:
2000     """Split a `line` into potentially many lines.
2001
2002     They should fit in the allotted `line_length` but might not be able to.
2003     `inner` signifies that there were a pair of brackets somewhere around the
2004     current `line`, possibly transitively. This means we can fallback to splitting
2005     by delimiters if the LHS/RHS don't yield any results.
2006
2007     If `py36` is True, splitting may generate syntax that is only compatible
2008     with Python 3.6 and later.
2009     """
2010     if isinstance(line, UnformattedLines) or line.is_comment:
2011         yield line
2012         return
2013
2014     line_str = str(line).strip("\n")
2015     if not line.should_explode and is_line_short_enough(
2016         line, line_length=line_length, line_str=line_str
2017     ):
2018         yield line
2019         return
2020
2021     split_funcs: List[SplitFunc]
2022     if line.is_def:
2023         split_funcs = [left_hand_split]
2024     else:
2025
2026         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2027             for omit in generate_trailers_to_omit(line, line_length):
2028                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2029                 if is_line_short_enough(lines[0], line_length=line_length):
2030                     yield from lines
2031                     return
2032
2033             # All splits failed, best effort split with no omits.
2034             # This mostly happens to multiline strings that are by definition
2035             # reported as not fitting a single line.
2036             yield from right_hand_split(line, py36)
2037
2038         if line.inside_brackets:
2039             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2040         else:
2041             split_funcs = [rhs]
2042     for split_func in split_funcs:
2043         # We are accumulating lines in `result` because we might want to abort
2044         # mission and return the original line in the end, or attempt a different
2045         # split altogether.
2046         result: List[Line] = []
2047         try:
2048             for l in split_func(line, py36):
2049                 if str(l).strip("\n") == line_str:
2050                     raise CannotSplit("Split function returned an unchanged result")
2051
2052                 result.extend(
2053                     split_line(l, line_length=line_length, inner=True, py36=py36)
2054                 )
2055         except CannotSplit as cs:
2056             continue
2057
2058         else:
2059             yield from result
2060             break
2061
2062     else:
2063         yield line
2064
2065
2066 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2067     """Split line into many lines, starting with the first matching bracket pair.
2068
2069     Note: this usually looks weird, only use this for function definitions.
2070     Prefer RHS otherwise.  This is why this function is not symmetrical with
2071     :func:`right_hand_split` which also handles optional parentheses.
2072     """
2073     head = Line(depth=line.depth)
2074     body = Line(depth=line.depth + 1, inside_brackets=True)
2075     tail = Line(depth=line.depth)
2076     tail_leaves: List[Leaf] = []
2077     body_leaves: List[Leaf] = []
2078     head_leaves: List[Leaf] = []
2079     current_leaves = head_leaves
2080     matching_bracket = None
2081     for leaf in line.leaves:
2082         if (
2083             current_leaves is body_leaves
2084             and leaf.type in CLOSING_BRACKETS
2085             and leaf.opening_bracket is matching_bracket
2086         ):
2087             current_leaves = tail_leaves if body_leaves else head_leaves
2088         current_leaves.append(leaf)
2089         if current_leaves is head_leaves:
2090             if leaf.type in OPENING_BRACKETS:
2091                 matching_bracket = leaf
2092                 current_leaves = body_leaves
2093     # Since body is a new indent level, remove spurious leading whitespace.
2094     if body_leaves:
2095         normalize_prefix(body_leaves[0], inside_brackets=True)
2096     # Build the new lines.
2097     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2098         for leaf in leaves:
2099             result.append(leaf, preformatted=True)
2100             for comment_after in line.comments_after(leaf):
2101                 result.append(comment_after, preformatted=True)
2102     bracket_split_succeeded_or_raise(head, body, tail)
2103     for result in (head, body, tail):
2104         if result:
2105             yield result
2106
2107
2108 def right_hand_split(
2109     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2110 ) -> Iterator[Line]:
2111     """Split line into many lines, starting with the last matching bracket pair.
2112
2113     If the split was by optional parentheses, attempt splitting without them, too.
2114     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2115     this split.
2116
2117     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2118     """
2119     head = Line(depth=line.depth)
2120     body = Line(depth=line.depth + 1, inside_brackets=True)
2121     tail = Line(depth=line.depth)
2122     tail_leaves: List[Leaf] = []
2123     body_leaves: List[Leaf] = []
2124     head_leaves: List[Leaf] = []
2125     current_leaves = tail_leaves
2126     opening_bracket = None
2127     closing_bracket = None
2128     for leaf in reversed(line.leaves):
2129         if current_leaves is body_leaves:
2130             if leaf is opening_bracket:
2131                 current_leaves = head_leaves if body_leaves else tail_leaves
2132         current_leaves.append(leaf)
2133         if current_leaves is tail_leaves:
2134             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2135                 opening_bracket = leaf.opening_bracket
2136                 closing_bracket = leaf
2137                 current_leaves = body_leaves
2138     tail_leaves.reverse()
2139     body_leaves.reverse()
2140     head_leaves.reverse()
2141     # Since body is a new indent level, remove spurious leading whitespace.
2142     if body_leaves:
2143         normalize_prefix(body_leaves[0], inside_brackets=True)
2144     if not head_leaves:
2145         # No `head` means the split failed. Either `tail` has all content or
2146         # the matching `opening_bracket` wasn't available on `line` anymore.
2147         raise CannotSplit("No brackets found")
2148
2149     # Build the new lines.
2150     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2151         for leaf in leaves:
2152             result.append(leaf, preformatted=True)
2153             for comment_after in line.comments_after(leaf):
2154                 result.append(comment_after, preformatted=True)
2155     bracket_split_succeeded_or_raise(head, body, tail)
2156     assert opening_bracket and closing_bracket
2157     if (
2158         # the opening bracket is an optional paren
2159         opening_bracket.type == token.LPAR
2160         and not opening_bracket.value
2161         # the closing bracket is an optional paren
2162         and closing_bracket.type == token.RPAR
2163         and not closing_bracket.value
2164         # there are no standalone comments in the body
2165         and not line.contains_standalone_comments(0)
2166         # and it's not an import (optional parens are the only thing we can split
2167         # on in this case; attempting a split without them is a waste of time)
2168         and not line.is_import
2169     ):
2170         omit = {id(closing_bracket), *omit}
2171         if can_omit_invisible_parens(body, line_length):
2172             try:
2173                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2174                 return
2175             except CannotSplit:
2176                 pass
2177
2178     ensure_visible(opening_bracket)
2179     ensure_visible(closing_bracket)
2180     body.should_explode = should_explode(body, opening_bracket)
2181     for result in (head, body, tail):
2182         if result:
2183             yield result
2184
2185
2186 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2187     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2188
2189     Do nothing otherwise.
2190
2191     A left- or right-hand split is based on a pair of brackets. Content before
2192     (and including) the opening bracket is left on one line, content inside the
2193     brackets is put on a separate line, and finally content starting with and
2194     following the closing bracket is put on a separate line.
2195
2196     Those are called `head`, `body`, and `tail`, respectively. If the split
2197     produced the same line (all content in `head`) or ended up with an empty `body`
2198     and the `tail` is just the closing bracket, then it's considered failed.
2199     """
2200     tail_len = len(str(tail).strip())
2201     if not body:
2202         if tail_len == 0:
2203             raise CannotSplit("Splitting brackets produced the same line")
2204
2205         elif tail_len < 3:
2206             raise CannotSplit(
2207                 f"Splitting brackets on an empty body to save "
2208                 f"{tail_len} characters is not worth it"
2209             )
2210
2211
2212 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2213     """Normalize prefix of the first leaf in every line returned by `split_func`.
2214
2215     This is a decorator over relevant split functions.
2216     """
2217
2218     @wraps(split_func)
2219     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2220         for l in split_func(line, py36):
2221             normalize_prefix(l.leaves[0], inside_brackets=True)
2222             yield l
2223
2224     return split_wrapper
2225
2226
2227 @dont_increase_indentation
2228 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2229     """Split according to delimiters of the highest priority.
2230
2231     If `py36` is True, the split will add trailing commas also in function
2232     signatures that contain `*` and `**`.
2233     """
2234     try:
2235         last_leaf = line.leaves[-1]
2236     except IndexError:
2237         raise CannotSplit("Line empty")
2238
2239     bt = line.bracket_tracker
2240     try:
2241         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2242     except ValueError:
2243         raise CannotSplit("No delimiters found")
2244
2245     if delimiter_priority == DOT_PRIORITY:
2246         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2247             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2248
2249     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2250     lowest_depth = sys.maxsize
2251     trailing_comma_safe = True
2252
2253     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2254         """Append `leaf` to current line or to new line if appending impossible."""
2255         nonlocal current_line
2256         try:
2257             current_line.append_safe(leaf, preformatted=True)
2258         except ValueError as ve:
2259             yield current_line
2260
2261             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2262             current_line.append(leaf)
2263
2264     for index, leaf in enumerate(line.leaves):
2265         yield from append_to_line(leaf)
2266
2267         for comment_after in line.comments_after(leaf, index):
2268             yield from append_to_line(comment_after)
2269
2270         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2271         if leaf.bracket_depth == lowest_depth and is_vararg(
2272             leaf, within=VARARGS_PARENTS
2273         ):
2274             trailing_comma_safe = trailing_comma_safe and py36
2275         leaf_priority = bt.delimiters.get(id(leaf))
2276         if leaf_priority == delimiter_priority:
2277             yield current_line
2278
2279             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2280     if current_line:
2281         if (
2282             trailing_comma_safe
2283             and delimiter_priority == COMMA_PRIORITY
2284             and current_line.leaves[-1].type != token.COMMA
2285             and current_line.leaves[-1].type != STANDALONE_COMMENT
2286         ):
2287             current_line.append(Leaf(token.COMMA, ","))
2288         yield current_line
2289
2290
2291 @dont_increase_indentation
2292 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2293     """Split standalone comments from the rest of the line."""
2294     if not line.contains_standalone_comments(0):
2295         raise CannotSplit("Line does not have any standalone comments")
2296
2297     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2298
2299     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2300         """Append `leaf` to current line or to new line if appending impossible."""
2301         nonlocal current_line
2302         try:
2303             current_line.append_safe(leaf, preformatted=True)
2304         except ValueError as ve:
2305             yield current_line
2306
2307             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2308             current_line.append(leaf)
2309
2310     for index, leaf in enumerate(line.leaves):
2311         yield from append_to_line(leaf)
2312
2313         for comment_after in line.comments_after(leaf, index):
2314             yield from append_to_line(comment_after)
2315
2316     if current_line:
2317         yield current_line
2318
2319
2320 def is_import(leaf: Leaf) -> bool:
2321     """Return True if the given leaf starts an import statement."""
2322     p = leaf.parent
2323     t = leaf.type
2324     v = leaf.value
2325     return bool(
2326         t == token.NAME
2327         and (
2328             (v == "import" and p and p.type == syms.import_name)
2329             or (v == "from" and p and p.type == syms.import_from)
2330         )
2331     )
2332
2333
2334 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2335     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2336     else.
2337
2338     Note: don't use backslashes for formatting or you'll lose your voting rights.
2339     """
2340     if not inside_brackets:
2341         spl = leaf.prefix.split("#")
2342         if "\\" not in spl[0]:
2343             nl_count = spl[-1].count("\n")
2344             if len(spl) > 1:
2345                 nl_count -= 1
2346             leaf.prefix = "\n" * nl_count
2347             return
2348
2349     leaf.prefix = ""
2350
2351
2352 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2353     """Make all string prefixes lowercase.
2354
2355     If remove_u_prefix is given, also removes any u prefix from the string.
2356
2357     Note: Mutates its argument.
2358     """
2359     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2360     assert match is not None, f"failed to match string {leaf.value!r}"
2361     orig_prefix = match.group(1)
2362     new_prefix = orig_prefix.lower()
2363     if remove_u_prefix:
2364         new_prefix = new_prefix.replace("u", "")
2365     leaf.value = f"{new_prefix}{match.group(2)}"
2366
2367
2368 def normalize_string_quotes(leaf: Leaf) -> None:
2369     """Prefer double quotes but only if it doesn't cause more escaping.
2370
2371     Adds or removes backslashes as appropriate. Doesn't parse and fix
2372     strings nested in f-strings (yet).
2373
2374     Note: Mutates its argument.
2375     """
2376     value = leaf.value.lstrip("furbFURB")
2377     if value[:3] == '"""':
2378         return
2379
2380     elif value[:3] == "'''":
2381         orig_quote = "'''"
2382         new_quote = '"""'
2383     elif value[0] == '"':
2384         orig_quote = '"'
2385         new_quote = "'"
2386     else:
2387         orig_quote = "'"
2388         new_quote = '"'
2389     first_quote_pos = leaf.value.find(orig_quote)
2390     if first_quote_pos == -1:
2391         return  # There's an internal error
2392
2393     prefix = leaf.value[:first_quote_pos]
2394     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2395     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2396     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2397     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2398     if "r" in prefix.casefold():
2399         if unescaped_new_quote.search(body):
2400             # There's at least one unescaped new_quote in this raw string
2401             # so converting is impossible
2402             return
2403
2404         # Do not introduce or remove backslashes in raw strings
2405         new_body = body
2406     else:
2407         # remove unnecessary quotes
2408         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2409         if body != new_body:
2410             # Consider the string without unnecessary quotes as the original
2411             body = new_body
2412             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2413         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2414         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2415     if new_quote == '"""' and new_body[-1] == '"':
2416         # edge case:
2417         new_body = new_body[:-1] + '\\"'
2418     orig_escape_count = body.count("\\")
2419     new_escape_count = new_body.count("\\")
2420     if new_escape_count > orig_escape_count:
2421         return  # Do not introduce more escaping
2422
2423     if new_escape_count == orig_escape_count and orig_quote == '"':
2424         return  # Prefer double quotes
2425
2426     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2427
2428
2429 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2430     """Make existing optional parentheses invisible or create new ones.
2431
2432     `parens_after` is a set of string leaf values immeditely after which parens
2433     should be put.
2434
2435     Standardizes on visible parentheses for single-element tuples, and keeps
2436     existing visible parentheses for other tuples and generator expressions.
2437     """
2438     try:
2439         list(generate_comments(node))
2440     except FormatOff:
2441         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2442
2443     check_lpar = False
2444     for index, child in enumerate(list(node.children)):
2445         if check_lpar:
2446             if child.type == syms.atom:
2447                 maybe_make_parens_invisible_in_atom(child)
2448             elif is_one_tuple(child):
2449                 # wrap child in visible parentheses
2450                 lpar = Leaf(token.LPAR, "(")
2451                 rpar = Leaf(token.RPAR, ")")
2452                 child.remove()
2453                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2454             elif node.type == syms.import_from:
2455                 # "import from" nodes store parentheses directly as part of
2456                 # the statement
2457                 if child.type == token.LPAR:
2458                     # make parentheses invisible
2459                     child.value = ""  # type: ignore
2460                     node.children[-1].value = ""  # type: ignore
2461                 elif child.type != token.STAR:
2462                     # insert invisible parentheses
2463                     node.insert_child(index, Leaf(token.LPAR, ""))
2464                     node.append_child(Leaf(token.RPAR, ""))
2465                 break
2466
2467             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2468                 # wrap child in invisible parentheses
2469                 lpar = Leaf(token.LPAR, "")
2470                 rpar = Leaf(token.RPAR, "")
2471                 index = child.remove() or 0
2472                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2473
2474         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2475
2476
2477 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2478     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2479     if (
2480         node.type != syms.atom
2481         or is_empty_tuple(node)
2482         or is_one_tuple(node)
2483         or is_yield(node)
2484         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2485     ):
2486         return False
2487
2488     first = node.children[0]
2489     last = node.children[-1]
2490     if first.type == token.LPAR and last.type == token.RPAR:
2491         # make parentheses invisible
2492         first.value = ""  # type: ignore
2493         last.value = ""  # type: ignore
2494         if len(node.children) > 1:
2495             maybe_make_parens_invisible_in_atom(node.children[1])
2496         return True
2497
2498     return False
2499
2500
2501 def is_empty_tuple(node: LN) -> bool:
2502     """Return True if `node` holds an empty tuple."""
2503     return (
2504         node.type == syms.atom
2505         and len(node.children) == 2
2506         and node.children[0].type == token.LPAR
2507         and node.children[1].type == token.RPAR
2508     )
2509
2510
2511 def is_one_tuple(node: LN) -> bool:
2512     """Return True if `node` holds a tuple with one element, with or without parens."""
2513     if node.type == syms.atom:
2514         if len(node.children) != 3:
2515             return False
2516
2517         lpar, gexp, rpar = node.children
2518         if not (
2519             lpar.type == token.LPAR
2520             and gexp.type == syms.testlist_gexp
2521             and rpar.type == token.RPAR
2522         ):
2523             return False
2524
2525         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2526
2527     return (
2528         node.type in IMPLICIT_TUPLE
2529         and len(node.children) == 2
2530         and node.children[1].type == token.COMMA
2531     )
2532
2533
2534 def is_yield(node: LN) -> bool:
2535     """Return True if `node` holds a `yield` or `yield from` expression."""
2536     if node.type == syms.yield_expr:
2537         return True
2538
2539     if node.type == token.NAME and node.value == "yield":  # type: ignore
2540         return True
2541
2542     if node.type != syms.atom:
2543         return False
2544
2545     if len(node.children) != 3:
2546         return False
2547
2548     lpar, expr, rpar = node.children
2549     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2550         return is_yield(expr)
2551
2552     return False
2553
2554
2555 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2556     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2557
2558     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2559     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2560     extended iterable unpacking (PEP 3132) and additional unpacking
2561     generalizations (PEP 448).
2562     """
2563     if leaf.type not in STARS or not leaf.parent:
2564         return False
2565
2566     p = leaf.parent
2567     if p.type == syms.star_expr:
2568         # Star expressions are also used as assignment targets in extended
2569         # iterable unpacking (PEP 3132).  See what its parent is instead.
2570         if not p.parent:
2571             return False
2572
2573         p = p.parent
2574
2575     return p.type in within
2576
2577
2578 def is_multiline_string(leaf: Leaf) -> bool:
2579     """Return True if `leaf` is a multiline string that actually spans many lines."""
2580     value = leaf.value.lstrip("furbFURB")
2581     return value[:3] in {'"""', "'''"} and "\n" in value
2582
2583
2584 def is_stub_suite(node: Node) -> bool:
2585     """Return True if `node` is a suite with a stub body."""
2586     if (
2587         len(node.children) != 4
2588         or node.children[0].type != token.NEWLINE
2589         or node.children[1].type != token.INDENT
2590         or node.children[3].type != token.DEDENT
2591     ):
2592         return False
2593
2594     return is_stub_body(node.children[2])
2595
2596
2597 def is_stub_body(node: LN) -> bool:
2598     """Return True if `node` is a simple statement containing an ellipsis."""
2599     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2600         return False
2601
2602     if len(node.children) != 2:
2603         return False
2604
2605     child = node.children[0]
2606     return (
2607         child.type == syms.atom
2608         and len(child.children) == 3
2609         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2610     )
2611
2612
2613 def max_delimiter_priority_in_atom(node: LN) -> int:
2614     """Return maximum delimiter priority inside `node`.
2615
2616     This is specific to atoms with contents contained in a pair of parentheses.
2617     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2618     """
2619     if node.type != syms.atom:
2620         return 0
2621
2622     first = node.children[0]
2623     last = node.children[-1]
2624     if not (first.type == token.LPAR and last.type == token.RPAR):
2625         return 0
2626
2627     bt = BracketTracker()
2628     for c in node.children[1:-1]:
2629         if isinstance(c, Leaf):
2630             bt.mark(c)
2631         else:
2632             for leaf in c.leaves():
2633                 bt.mark(leaf)
2634     try:
2635         return bt.max_delimiter_priority()
2636
2637     except ValueError:
2638         return 0
2639
2640
2641 def ensure_visible(leaf: Leaf) -> None:
2642     """Make sure parentheses are visible.
2643
2644     They could be invisible as part of some statements (see
2645     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2646     """
2647     if leaf.type == token.LPAR:
2648         leaf.value = "("
2649     elif leaf.type == token.RPAR:
2650         leaf.value = ")"
2651
2652
2653 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2654     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2655     if not (
2656         opening_bracket.parent
2657         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2658         and opening_bracket.value in "[{("
2659     ):
2660         return False
2661
2662     try:
2663         last_leaf = line.leaves[-1]
2664         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2665         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2666     except (IndexError, ValueError):
2667         return False
2668
2669     return max_priority == COMMA_PRIORITY
2670
2671
2672 def is_python36(node: Node) -> bool:
2673     """Return True if the current file is using Python 3.6+ features.
2674
2675     Currently looking for:
2676     - f-strings; and
2677     - trailing commas after * or ** in function signatures and calls.
2678     """
2679     for n in node.pre_order():
2680         if n.type == token.STRING:
2681             value_head = n.value[:2]  # type: ignore
2682             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2683                 return True
2684
2685         elif (
2686             n.type in {syms.typedargslist, syms.arglist}
2687             and n.children
2688             and n.children[-1].type == token.COMMA
2689         ):
2690             for ch in n.children:
2691                 if ch.type in STARS:
2692                     return True
2693
2694                 if ch.type == syms.argument:
2695                     for argch in ch.children:
2696                         if argch.type in STARS:
2697                             return True
2698
2699     return False
2700
2701
2702 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2703     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2704
2705     Brackets can be omitted if the entire trailer up to and including
2706     a preceding closing bracket fits in one line.
2707
2708     Yielded sets are cumulative (contain results of previous yields, too).  First
2709     set is empty.
2710     """
2711
2712     omit: Set[LeafID] = set()
2713     yield omit
2714
2715     length = 4 * line.depth
2716     opening_bracket = None
2717     closing_bracket = None
2718     optional_brackets: Set[LeafID] = set()
2719     inner_brackets: Set[LeafID] = set()
2720     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2721         length += leaf_length
2722         if length > line_length:
2723             break
2724
2725         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2726         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2727             break
2728
2729         optional_brackets.discard(id(leaf))
2730         if opening_bracket:
2731             if leaf is opening_bracket:
2732                 opening_bracket = None
2733             elif leaf.type in CLOSING_BRACKETS:
2734                 inner_brackets.add(id(leaf))
2735         elif leaf.type in CLOSING_BRACKETS:
2736             if not leaf.value:
2737                 optional_brackets.add(id(opening_bracket))
2738                 continue
2739
2740             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2741                 # Empty brackets would fail a split so treat them as "inner"
2742                 # brackets (e.g. only add them to the `omit` set if another
2743                 # pair of brackets was good enough.
2744                 inner_brackets.add(id(leaf))
2745                 continue
2746
2747             opening_bracket = leaf.opening_bracket
2748             if closing_bracket:
2749                 omit.add(id(closing_bracket))
2750                 omit.update(inner_brackets)
2751                 inner_brackets.clear()
2752                 yield omit
2753             closing_bracket = leaf
2754
2755
2756 def get_future_imports(node: Node) -> Set[str]:
2757     """Return a set of __future__ imports in the file."""
2758     imports = set()
2759     for child in node.children:
2760         if child.type != syms.simple_stmt:
2761             break
2762         first_child = child.children[0]
2763         if isinstance(first_child, Leaf):
2764             # Continue looking if we see a docstring; otherwise stop.
2765             if (
2766                 len(child.children) == 2
2767                 and first_child.type == token.STRING
2768                 and child.children[1].type == token.NEWLINE
2769             ):
2770                 continue
2771             else:
2772                 break
2773         elif first_child.type == syms.import_from:
2774             module_name = first_child.children[1]
2775             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2776                 break
2777             for import_from_child in first_child.children[3:]:
2778                 if isinstance(import_from_child, Leaf):
2779                     if import_from_child.type == token.NAME:
2780                         imports.add(import_from_child.value)
2781                 else:
2782                     assert import_from_child.type == syms.import_as_names
2783                     for leaf in import_from_child.children:
2784                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2785                             imports.add(leaf.value)
2786         else:
2787             break
2788     return imports
2789
2790
2791 def gen_python_files_in_dir(
2792     path: Path, include: Pattern[str], exclude: Pattern[str]
2793 ) -> Iterator[Path]:
2794     """Generate all files under `path` whose paths are not excluded by the
2795     `exclude` regex, but are included by the `include` regex.
2796     """
2797
2798     for child in path.iterdir():
2799         searchable_path = str(child.as_posix())
2800         if Path(child.parts[0]).is_dir():
2801             searchable_path = "/" + searchable_path
2802         if child.is_dir():
2803             searchable_path = searchable_path + "/"
2804             exclude_match = exclude.search(searchable_path)
2805             if exclude_match and len(exclude_match.group()) > 0:
2806                 continue
2807
2808             yield from gen_python_files_in_dir(child, include, exclude)
2809
2810         else:
2811             include_match = include.search(searchable_path)
2812             exclude_match = exclude.search(searchable_path)
2813             if (
2814                 child.is_file()
2815                 and include_match
2816                 and len(include_match.group()) > 0
2817                 and (not exclude_match or len(exclude_match.group()) == 0)
2818             ):
2819                 yield child
2820
2821
2822 @dataclass
2823 class Report:
2824     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2825
2826     check: bool = False
2827     quiet: bool = False
2828     change_count: int = 0
2829     same_count: int = 0
2830     failure_count: int = 0
2831
2832     def done(self, src: Path, changed: Changed) -> None:
2833         """Increment the counter for successful reformatting. Write out a message."""
2834         if changed is Changed.YES:
2835             reformatted = "would reformat" if self.check else "reformatted"
2836             if not self.quiet:
2837                 out(f"{reformatted} {src}")
2838             self.change_count += 1
2839         else:
2840             if not self.quiet:
2841                 if changed is Changed.NO:
2842                     msg = f"{src} already well formatted, good job."
2843                 else:
2844                     msg = f"{src} wasn't modified on disk since last run."
2845                 out(msg, bold=False)
2846             self.same_count += 1
2847
2848     def failed(self, src: Path, message: str) -> None:
2849         """Increment the counter for failed reformatting. Write out a message."""
2850         err(f"error: cannot format {src}: {message}")
2851         self.failure_count += 1
2852
2853     @property
2854     def return_code(self) -> int:
2855         """Return the exit code that the app should use.
2856
2857         This considers the current state of changed files and failures:
2858         - if there were any failures, return 123;
2859         - if any files were changed and --check is being used, return 1;
2860         - otherwise return 0.
2861         """
2862         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2863         # 126 we have special returncodes reserved by the shell.
2864         if self.failure_count:
2865             return 123
2866
2867         elif self.change_count and self.check:
2868             return 1
2869
2870         return 0
2871
2872     def __str__(self) -> str:
2873         """Render a color report of the current state.
2874
2875         Use `click.unstyle` to remove colors.
2876         """
2877         if self.check:
2878             reformatted = "would be reformatted"
2879             unchanged = "would be left unchanged"
2880             failed = "would fail to reformat"
2881         else:
2882             reformatted = "reformatted"
2883             unchanged = "left unchanged"
2884             failed = "failed to reformat"
2885         report = []
2886         if self.change_count:
2887             s = "s" if self.change_count > 1 else ""
2888             report.append(
2889                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2890             )
2891         if self.same_count:
2892             s = "s" if self.same_count > 1 else ""
2893             report.append(f"{self.same_count} file{s} {unchanged}")
2894         if self.failure_count:
2895             s = "s" if self.failure_count > 1 else ""
2896             report.append(
2897                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2898             )
2899         return ", ".join(report) + "."
2900
2901
2902 def assert_equivalent(src: str, dst: str) -> None:
2903     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2904
2905     import ast
2906     import traceback
2907
2908     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2909         """Simple visitor generating strings to compare ASTs by content."""
2910         yield f"{'  ' * depth}{node.__class__.__name__}("
2911
2912         for field in sorted(node._fields):
2913             try:
2914                 value = getattr(node, field)
2915             except AttributeError:
2916                 continue
2917
2918             yield f"{'  ' * (depth+1)}{field}="
2919
2920             if isinstance(value, list):
2921                 for item in value:
2922                     if isinstance(item, ast.AST):
2923                         yield from _v(item, depth + 2)
2924
2925             elif isinstance(value, ast.AST):
2926                 yield from _v(value, depth + 2)
2927
2928             else:
2929                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2930
2931         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2932
2933     try:
2934         src_ast = ast.parse(src)
2935     except Exception as exc:
2936         major, minor = sys.version_info[:2]
2937         raise AssertionError(
2938             f"cannot use --safe with this file; failed to parse source file "
2939             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2940             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2941         )
2942
2943     try:
2944         dst_ast = ast.parse(dst)
2945     except Exception as exc:
2946         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2947         raise AssertionError(
2948             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2949             f"Please report a bug on https://github.com/ambv/black/issues.  "
2950             f"This invalid output might be helpful: {log}"
2951         ) from None
2952
2953     src_ast_str = "\n".join(_v(src_ast))
2954     dst_ast_str = "\n".join(_v(dst_ast))
2955     if src_ast_str != dst_ast_str:
2956         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2957         raise AssertionError(
2958             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2959             f"the source.  "
2960             f"Please report a bug on https://github.com/ambv/black/issues.  "
2961             f"This diff might be helpful: {log}"
2962         ) from None
2963
2964
2965 def assert_stable(
2966     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
2967 ) -> None:
2968     """Raise AssertionError if `dst` reformats differently the second time."""
2969     newdst = format_str(dst, line_length=line_length, mode=mode)
2970     if dst != newdst:
2971         log = dump_to_file(
2972             diff(src, dst, "source", "first pass"),
2973             diff(dst, newdst, "first pass", "second pass"),
2974         )
2975         raise AssertionError(
2976             f"INTERNAL ERROR: Black produced different code on the second pass "
2977             f"of the formatter.  "
2978             f"Please report a bug on https://github.com/ambv/black/issues.  "
2979             f"This diff might be helpful: {log}"
2980         ) from None
2981
2982
2983 def dump_to_file(*output: str) -> str:
2984     """Dump `output` to a temporary file. Return path to the file."""
2985     import tempfile
2986
2987     with tempfile.NamedTemporaryFile(
2988         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2989     ) as f:
2990         for lines in output:
2991             f.write(lines)
2992             if lines and lines[-1] != "\n":
2993                 f.write("\n")
2994     return f.name
2995
2996
2997 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2998     """Return a unified diff string between strings `a` and `b`."""
2999     import difflib
3000
3001     a_lines = [line + "\n" for line in a.split("\n")]
3002     b_lines = [line + "\n" for line in b.split("\n")]
3003     return "".join(
3004         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3005     )
3006
3007
3008 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3009     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3010     err("Aborted!")
3011     for task in tasks:
3012         task.cancel()
3013
3014
3015 def shutdown(loop: BaseEventLoop) -> None:
3016     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3017     try:
3018         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3019         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3020         if not to_cancel:
3021             return
3022
3023         for task in to_cancel:
3024             task.cancel()
3025         loop.run_until_complete(
3026             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3027         )
3028     finally:
3029         # `concurrent.futures.Future` objects cannot be cancelled once they
3030         # are already running. There might be some when the `shutdown()` happened.
3031         # Silence their logger's spew about the event loop being closed.
3032         cf_logger = logging.getLogger("concurrent.futures")
3033         cf_logger.setLevel(logging.CRITICAL)
3034         loop.close()
3035
3036
3037 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3038     """Replace `regex` with `replacement` twice on `original`.
3039
3040     This is used by string normalization to perform replaces on
3041     overlapping matches.
3042     """
3043     return regex.sub(replacement, regex.sub(replacement, original))
3044
3045
3046 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3047     """Like `reversed(enumerate(sequence))` if that were possible."""
3048     index = len(sequence) - 1
3049     for element in reversed(sequence):
3050         yield (index, element)
3051         index -= 1
3052
3053
3054 def enumerate_with_length(
3055     line: Line, reversed: bool = False
3056 ) -> Iterator[Tuple[Index, Leaf, int]]:
3057     """Return an enumeration of leaves with their length.
3058
3059     Stops prematurely on multiline strings and standalone comments.
3060     """
3061     op = cast(
3062         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3063         enumerate_reversed if reversed else enumerate,
3064     )
3065     for index, leaf in op(line.leaves):
3066         length = len(leaf.prefix) + len(leaf.value)
3067         if "\n" in leaf.value:
3068             return  # Multiline strings, we can't continue.
3069
3070         comment: Optional[Leaf]
3071         for comment in line.comments_after(leaf, index):
3072             length += len(comment.value)
3073
3074         yield index, leaf, length
3075
3076
3077 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3078     """Return True if `line` is no longer than `line_length`.
3079
3080     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3081     """
3082     if not line_str:
3083         line_str = str(line).strip("\n")
3084     return (
3085         len(line_str) <= line_length
3086         and "\n" not in line_str  # multiline strings
3087         and not line.contains_standalone_comments()
3088     )
3089
3090
3091 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3092     """Does `line` have a shape safe to reformat without optional parens around it?
3093
3094     Returns True for only a subset of potentially nice looking formattings but
3095     the point is to not return false positives that end up producing lines that
3096     are too long.
3097     """
3098     bt = line.bracket_tracker
3099     if not bt.delimiters:
3100         # Without delimiters the optional parentheses are useless.
3101         return True
3102
3103     max_priority = bt.max_delimiter_priority()
3104     if bt.delimiter_count_with_priority(max_priority) > 1:
3105         # With more than one delimiter of a kind the optional parentheses read better.
3106         return False
3107
3108     if max_priority == DOT_PRIORITY:
3109         # A single stranded method call doesn't require optional parentheses.
3110         return True
3111
3112     assert len(line.leaves) >= 2, "Stranded delimiter"
3113
3114     first = line.leaves[0]
3115     second = line.leaves[1]
3116     penultimate = line.leaves[-2]
3117     last = line.leaves[-1]
3118
3119     # With a single delimiter, omit if the expression starts or ends with
3120     # a bracket.
3121     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3122         remainder = False
3123         length = 4 * line.depth
3124         for _index, leaf, leaf_length in enumerate_with_length(line):
3125             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3126                 remainder = True
3127             if remainder:
3128                 length += leaf_length
3129                 if length > line_length:
3130                     break
3131
3132                 if leaf.type in OPENING_BRACKETS:
3133                     # There are brackets we can further split on.
3134                     remainder = False
3135
3136         else:
3137             # checked the entire string and line length wasn't exceeded
3138             if len(line.leaves) == _index + 1:
3139                 return True
3140
3141         # Note: we are not returning False here because a line might have *both*
3142         # a leading opening bracket and a trailing closing bracket.  If the
3143         # opening bracket doesn't match our rule, maybe the closing will.
3144
3145     if (
3146         last.type == token.RPAR
3147         or last.type == token.RBRACE
3148         or (
3149             # don't use indexing for omitting optional parentheses;
3150             # it looks weird
3151             last.type == token.RSQB
3152             and last.parent
3153             and last.parent.type != syms.trailer
3154         )
3155     ):
3156         if penultimate.type in OPENING_BRACKETS:
3157             # Empty brackets don't help.
3158             return False
3159
3160         if is_multiline_string(first):
3161             # Additional wrapping of a multiline string in this situation is
3162             # unnecessary.
3163             return True
3164
3165         length = 4 * line.depth
3166         seen_other_brackets = False
3167         for _index, leaf, leaf_length in enumerate_with_length(line):
3168             length += leaf_length
3169             if leaf is last.opening_bracket:
3170                 if seen_other_brackets or length <= line_length:
3171                     return True
3172
3173             elif leaf.type in OPENING_BRACKETS:
3174                 # There are brackets we can further split on.
3175                 seen_other_brackets = True
3176
3177     return False
3178
3179
3180 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3181     pyi = bool(mode & FileMode.PYI)
3182     py36 = bool(mode & FileMode.PYTHON36)
3183     return (
3184         CACHE_DIR
3185         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3186     )
3187
3188
3189 def read_cache(line_length: int, mode: FileMode) -> Cache:
3190     """Read the cache if it exists and is well formed.
3191
3192     If it is not well formed, the call to write_cache later should resolve the issue.
3193     """
3194     cache_file = get_cache_file(line_length, mode)
3195     if not cache_file.exists():
3196         return {}
3197
3198     with cache_file.open("rb") as fobj:
3199         try:
3200             cache: Cache = pickle.load(fobj)
3201         except pickle.UnpicklingError:
3202             return {}
3203
3204     return cache
3205
3206
3207 def get_cache_info(path: Path) -> CacheInfo:
3208     """Return the information used to check if a file is already formatted or not."""
3209     stat = path.stat()
3210     return stat.st_mtime, stat.st_size
3211
3212
3213 def filter_cached(
3214     cache: Cache, sources: Iterable[Path]
3215 ) -> Tuple[List[Path], List[Path]]:
3216     """Split a list of paths into two.
3217
3218     The first list contains paths of files that modified on disk or are not in the
3219     cache. The other list contains paths to non-modified files.
3220     """
3221     todo, done = [], []
3222     for src in sources:
3223         src = src.resolve()
3224         if cache.get(src) != get_cache_info(src):
3225             todo.append(src)
3226         else:
3227             done.append(src)
3228     return todo, done
3229
3230
3231 def write_cache(
3232     cache: Cache, sources: List[Path], line_length: int, mode: FileMode
3233 ) -> None:
3234     """Update the cache file."""
3235     cache_file = get_cache_file(line_length, mode)
3236     try:
3237         if not CACHE_DIR.exists():
3238             CACHE_DIR.mkdir(parents=True)
3239         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3240         with cache_file.open("wb") as fobj:
3241             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3242     except OSError:
3243         pass
3244
3245
3246 if __name__ == "__main__":
3247     main()