]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

[trivial] Simplify `mode` and `write_back` calculation in main()
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum, Flag
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33     cast,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46
47 __version__ = "18.5b1"
48 DEFAULT_LINE_LENGTH = 88
49 DEFAULT_EXCLUDES = (
50     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
51 )
52 DEFAULT_INCLUDES = r"\.pyi?$"
53 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
54
55
56 # types
57 FileContent = str
58 Encoding = str
59 Depth = int
60 NodeType = int
61 LeafID = int
62 Priority = int
63 Index = int
64 LN = Union[Leaf, Node]
65 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
66 Timestamp = float
67 FileSize = int
68 CacheInfo = Tuple[Timestamp, FileSize]
69 Cache = Dict[Path, CacheInfo]
70 out = partial(click.secho, bold=True, err=True)
71 err = partial(click.secho, fg="red", err=True)
72
73 pygram.initialize(CACHE_DIR)
74 syms = pygram.python_symbols
75
76
77 class NothingChanged(UserWarning):
78     """Raised by :func:`format_file` when reformatted code is the same as source."""
79
80
81 class CannotSplit(Exception):
82     """A readable split that fits the allotted line length is impossible.
83
84     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
85     :func:`delimiter_split`.
86     """
87
88
89 class FormatError(Exception):
90     """Base exception for `# fmt: on` and `# fmt: off` handling.
91
92     It holds the number of bytes of the prefix consumed before the format
93     control comment appeared.
94     """
95
96     def __init__(self, consumed: int) -> None:
97         super().__init__(consumed)
98         self.consumed = consumed
99
100     def trim_prefix(self, leaf: Leaf) -> None:
101         leaf.prefix = leaf.prefix[self.consumed :]
102
103     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
104         """Returns a new Leaf from the consumed part of the prefix."""
105         unformatted_prefix = leaf.prefix[: self.consumed]
106         return Leaf(token.NEWLINE, unformatted_prefix)
107
108
109 class FormatOn(FormatError):
110     """Found a comment like `# fmt: on` in the file."""
111
112
113 class FormatOff(FormatError):
114     """Found a comment like `# fmt: off` in the file."""
115
116
117 class WriteBack(Enum):
118     NO = 0
119     YES = 1
120     DIFF = 2
121
122     @classmethod
123     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
124         if check and not diff:
125             return cls.NO
126
127         return cls.DIFF if diff else cls.YES
128
129
130 class Changed(Enum):
131     NO = 0
132     CACHED = 1
133     YES = 2
134
135
136 class FileMode(Flag):
137     AUTO_DETECT = 0
138     PYTHON36 = 1
139     PYI = 2
140     NO_STRING_NORMALIZATION = 4
141
142     @classmethod
143     def from_configuration(
144         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
145     ) -> "FileMode":
146         mode = cls.AUTO_DETECT
147         if py36:
148             mode |= cls.PYTHON36
149         if pyi:
150             mode |= cls.PYI
151         if skip_string_normalization:
152             mode |= cls.NO_STRING_NORMALIZATION
153         return mode
154
155
156 @click.command()
157 @click.option(
158     "-l",
159     "--line-length",
160     type=int,
161     default=DEFAULT_LINE_LENGTH,
162     help="How many character per line to allow.",
163     show_default=True,
164 )
165 @click.option(
166     "--py36",
167     is_flag=True,
168     help=(
169         "Allow using Python 3.6-only syntax on all input files.  This will put "
170         "trailing commas in function signatures and calls also after *args and "
171         "**kwargs.  [default: per-file auto-detection]"
172     ),
173 )
174 @click.option(
175     "--pyi",
176     is_flag=True,
177     help=(
178         "Format all input files like typing stubs regardless of file extension "
179         "(useful when piping source on standard input)."
180     ),
181 )
182 @click.option(
183     "-S",
184     "--skip-string-normalization",
185     is_flag=True,
186     help="Don't normalize string quotes or prefixes.",
187 )
188 @click.option(
189     "--check",
190     is_flag=True,
191     help=(
192         "Don't write the files back, just return the status.  Return code 0 "
193         "means nothing would change.  Return code 1 means some files would be "
194         "reformatted.  Return code 123 means there was an internal error."
195     ),
196 )
197 @click.option(
198     "--diff",
199     is_flag=True,
200     help="Don't write the files back, just output a diff for each file on stdout.",
201 )
202 @click.option(
203     "--fast/--safe",
204     is_flag=True,
205     help="If --fast given, skip temporary sanity checks. [default: --safe]",
206 )
207 @click.option(
208     "--include",
209     type=str,
210     default=DEFAULT_INCLUDES,
211     help=(
212         "A regular expression that matches files and directories that should be "
213         "included on recursive searches.  An empty value means all files are "
214         "included regardless of the name.  Use forward slashes for directories on "
215         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
216         "later."
217     ),
218     show_default=True,
219 )
220 @click.option(
221     "--exclude",
222     type=str,
223     default=DEFAULT_EXCLUDES,
224     help=(
225         "A regular expression that matches files and directories that should be "
226         "excluded on recursive searches.  An empty value means no paths are excluded. "
227         "Use forward slashes for directories on all platforms (Windows, too).  "
228         "Exclusions are calculated first, inclusions later."
229     ),
230     show_default=True,
231 )
232 @click.option(
233     "-q",
234     "--quiet",
235     is_flag=True,
236     help=(
237         "Don't emit non-error messages to stderr. Errors are still emitted, "
238         "silence those with 2>/dev/null."
239     ),
240 )
241 @click.version_option(version=__version__)
242 @click.argument(
243     "src",
244     nargs=-1,
245     type=click.Path(
246         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
247     ),
248 )
249 @click.pass_context
250 def main(
251     ctx: click.Context,
252     line_length: int,
253     check: bool,
254     diff: bool,
255     fast: bool,
256     pyi: bool,
257     py36: bool,
258     skip_string_normalization: bool,
259     quiet: bool,
260     include: str,
261     exclude: str,
262     src: List[str],
263 ) -> None:
264     """The uncompromising code formatter."""
265     write_back = WriteBack.from_configuration(check=check, diff=diff)
266     mode = FileMode.from_configuration(
267         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
268     )
269     report = Report(check=check, quiet=quiet)
270     sources: List[Path] = []
271     try:
272         include_regex = re.compile(include)
273     except re.error:
274         err(f"Invalid regular expression for include given: {include!r}")
275         ctx.exit(2)
276     try:
277         exclude_regex = re.compile(exclude)
278     except re.error:
279         err(f"Invalid regular expression for exclude given: {exclude!r}")
280         ctx.exit(2)
281     root = find_project_root(src)
282     for s in src:
283         p = Path(s)
284         if p.is_dir():
285             sources.extend(
286                 gen_python_files_in_dir(p, root, include_regex, exclude_regex)
287             )
288         elif p.is_file() or s == "-":
289             # if a file was explicitly given, we don't care about its extension
290             sources.append(p)
291         else:
292             err(f"invalid path: {s}")
293     if len(sources) == 0:
294         out("No paths given. Nothing to do 😴")
295         ctx.exit(0)
296         return
297
298     elif len(sources) == 1:
299         reformat_one(
300             src=sources[0],
301             line_length=line_length,
302             fast=fast,
303             write_back=write_back,
304             mode=mode,
305             report=report,
306         )
307     else:
308         loop = asyncio.get_event_loop()
309         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
310         try:
311             loop.run_until_complete(
312                 schedule_formatting(
313                     sources=sources,
314                     line_length=line_length,
315                     fast=fast,
316                     write_back=write_back,
317                     mode=mode,
318                     report=report,
319                     loop=loop,
320                     executor=executor,
321                 )
322             )
323         finally:
324             shutdown(loop)
325         if not quiet:
326             out("All done! ✨ 🍰 ✨")
327             click.echo(str(report))
328     ctx.exit(report.return_code)
329
330
331 def reformat_one(
332     src: Path,
333     line_length: int,
334     fast: bool,
335     write_back: WriteBack,
336     mode: FileMode,
337     report: "Report",
338 ) -> None:
339     """Reformat a single file under `src` without spawning child processes.
340
341     If `quiet` is True, non-error messages are not output. `line_length`,
342     `write_back`, `fast` and `pyi` options are passed to
343     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
344     """
345     try:
346         changed = Changed.NO
347         if not src.is_file() and str(src) == "-":
348             if format_stdin_to_stdout(
349                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
350             ):
351                 changed = Changed.YES
352         else:
353             cache: Cache = {}
354             if write_back != WriteBack.DIFF:
355                 cache = read_cache(line_length, mode)
356                 res_src = src.resolve()
357                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
358                     changed = Changed.CACHED
359             if changed is not Changed.CACHED and format_file_in_place(
360                 src,
361                 line_length=line_length,
362                 fast=fast,
363                 write_back=write_back,
364                 mode=mode,
365             ):
366                 changed = Changed.YES
367             if write_back == WriteBack.YES and changed is not Changed.NO:
368                 write_cache(cache, [src], line_length, mode)
369         report.done(src, changed)
370     except Exception as exc:
371         report.failed(src, str(exc))
372
373
374 async def schedule_formatting(
375     sources: List[Path],
376     line_length: int,
377     fast: bool,
378     write_back: WriteBack,
379     mode: FileMode,
380     report: "Report",
381     loop: BaseEventLoop,
382     executor: Executor,
383 ) -> None:
384     """Run formatting of `sources` in parallel using the provided `executor`.
385
386     (Use ProcessPoolExecutors for actual parallelism.)
387
388     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
389     :func:`format_file_in_place`.
390     """
391     cache: Cache = {}
392     if write_back != WriteBack.DIFF:
393         cache = read_cache(line_length, mode)
394         sources, cached = filter_cached(cache, sources)
395         for src in cached:
396             report.done(src, Changed.CACHED)
397     cancelled = []
398     formatted = []
399     if sources:
400         lock = None
401         if write_back == WriteBack.DIFF:
402             # For diff output, we need locks to ensure we don't interleave output
403             # from different processes.
404             manager = Manager()
405             lock = manager.Lock()
406         tasks = {
407             loop.run_in_executor(
408                 executor,
409                 format_file_in_place,
410                 src,
411                 line_length,
412                 fast,
413                 write_back,
414                 mode,
415                 lock,
416             ): src
417             for src in sorted(sources)
418         }
419         pending: Iterable[asyncio.Task] = tasks.keys()
420         try:
421             loop.add_signal_handler(signal.SIGINT, cancel, pending)
422             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
423         except NotImplementedError:
424             # There are no good alternatives for these on Windows
425             pass
426         while pending:
427             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
428             for task in done:
429                 src = tasks.pop(task)
430                 if task.cancelled():
431                     cancelled.append(task)
432                 elif task.exception():
433                     report.failed(src, str(task.exception()))
434                 else:
435                     formatted.append(src)
436                     report.done(src, Changed.YES if task.result() else Changed.NO)
437     if cancelled:
438         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
439     if write_back == WriteBack.YES and formatted:
440         write_cache(cache, formatted, line_length, mode)
441
442
443 def format_file_in_place(
444     src: Path,
445     line_length: int,
446     fast: bool,
447     write_back: WriteBack = WriteBack.NO,
448     mode: FileMode = FileMode.AUTO_DETECT,
449     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
450 ) -> bool:
451     """Format file under `src` path. Return True if changed.
452
453     If `write_back` is True, write reformatted code back to stdout.
454     `line_length` and `fast` options are passed to :func:`format_file_contents`.
455     """
456     if src.suffix == ".pyi":
457         mode |= FileMode.PYI
458     with tokenize.open(src) as src_buffer:
459         src_contents = src_buffer.read()
460     try:
461         dst_contents = format_file_contents(
462             src_contents, line_length=line_length, fast=fast, mode=mode
463         )
464     except NothingChanged:
465         return False
466
467     if write_back == write_back.YES:
468         with open(src, "w", encoding=src_buffer.encoding) as f:
469             f.write(dst_contents)
470     elif write_back == write_back.DIFF:
471         src_name = f"{src}  (original)"
472         dst_name = f"{src}  (formatted)"
473         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
474         if lock:
475             lock.acquire()
476         try:
477             sys.stdout.write(diff_contents)
478         finally:
479             if lock:
480                 lock.release()
481     return True
482
483
484 def format_stdin_to_stdout(
485     line_length: int,
486     fast: bool,
487     write_back: WriteBack = WriteBack.NO,
488     mode: FileMode = FileMode.AUTO_DETECT,
489 ) -> bool:
490     """Format file on stdin. Return True if changed.
491
492     If `write_back` is True, write reformatted code back to stdout.
493     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
494     :func:`format_file_contents`.
495     """
496     src = sys.stdin.read()
497     dst = src
498     try:
499         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
500         return True
501
502     except NothingChanged:
503         return False
504
505     finally:
506         if write_back == WriteBack.YES:
507             sys.stdout.write(dst)
508         elif write_back == WriteBack.DIFF:
509             src_name = "<stdin>  (original)"
510             dst_name = "<stdin>  (formatted)"
511             sys.stdout.write(diff(src, dst, src_name, dst_name))
512
513
514 def format_file_contents(
515     src_contents: str,
516     *,
517     line_length: int,
518     fast: bool,
519     mode: FileMode = FileMode.AUTO_DETECT,
520 ) -> FileContent:
521     """Reformat contents a file and return new contents.
522
523     If `fast` is False, additionally confirm that the reformatted code is
524     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
525     `line_length` is passed to :func:`format_str`.
526     """
527     if src_contents.strip() == "":
528         raise NothingChanged
529
530     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
531     if src_contents == dst_contents:
532         raise NothingChanged
533
534     if not fast:
535         assert_equivalent(src_contents, dst_contents)
536         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
537     return dst_contents
538
539
540 def format_str(
541     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
542 ) -> FileContent:
543     """Reformat a string and return new contents.
544
545     `line_length` determines how many characters per line are allowed.
546     """
547     src_node = lib2to3_parse(src_contents)
548     dst_contents = ""
549     future_imports = get_future_imports(src_node)
550     is_pyi = bool(mode & FileMode.PYI)
551     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
552     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
553     lines = LineGenerator(
554         remove_u_prefix=py36 or "unicode_literals" in future_imports,
555         is_pyi=is_pyi,
556         normalize_strings=normalize_strings,
557     )
558     elt = EmptyLineTracker(is_pyi=is_pyi)
559     empty_line = Line()
560     after = 0
561     for current_line in lines.visit(src_node):
562         for _ in range(after):
563             dst_contents += str(empty_line)
564         before, after = elt.maybe_empty_lines(current_line)
565         for _ in range(before):
566             dst_contents += str(empty_line)
567         for line in split_line(current_line, line_length=line_length, py36=py36):
568             dst_contents += str(line)
569     return dst_contents
570
571
572 GRAMMARS = [
573     pygram.python_grammar_no_print_statement_no_exec_statement,
574     pygram.python_grammar_no_print_statement,
575     pygram.python_grammar,
576 ]
577
578
579 def lib2to3_parse(src_txt: str) -> Node:
580     """Given a string with source, return the lib2to3 Node."""
581     grammar = pygram.python_grammar_no_print_statement
582     if src_txt[-1] != "\n":
583         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
584         src_txt += nl
585     for grammar in GRAMMARS:
586         drv = driver.Driver(grammar, pytree.convert)
587         try:
588             result = drv.parse_string(src_txt, True)
589             break
590
591         except ParseError as pe:
592             lineno, column = pe.context[1]
593             lines = src_txt.splitlines()
594             try:
595                 faulty_line = lines[lineno - 1]
596             except IndexError:
597                 faulty_line = "<line number missing in source>"
598             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
599     else:
600         raise exc from None
601
602     if isinstance(result, Leaf):
603         result = Node(syms.file_input, [result])
604     return result
605
606
607 def lib2to3_unparse(node: Node) -> str:
608     """Given a lib2to3 node, return its string representation."""
609     code = str(node)
610     return code
611
612
613 T = TypeVar("T")
614
615
616 class Visitor(Generic[T]):
617     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
618
619     def visit(self, node: LN) -> Iterator[T]:
620         """Main method to visit `node` and its children.
621
622         It tries to find a `visit_*()` method for the given `node.type`, like
623         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
624         If no dedicated `visit_*()` method is found, chooses `visit_default()`
625         instead.
626
627         Then yields objects of type `T` from the selected visitor.
628         """
629         if node.type < 256:
630             name = token.tok_name[node.type]
631         else:
632             name = type_repr(node.type)
633         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
634
635     def visit_default(self, node: LN) -> Iterator[T]:
636         """Default `visit_*()` implementation. Recurses to children of `node`."""
637         if isinstance(node, Node):
638             for child in node.children:
639                 yield from self.visit(child)
640
641
642 @dataclass
643 class DebugVisitor(Visitor[T]):
644     tree_depth: int = 0
645
646     def visit_default(self, node: LN) -> Iterator[T]:
647         indent = " " * (2 * self.tree_depth)
648         if isinstance(node, Node):
649             _type = type_repr(node.type)
650             out(f"{indent}{_type}", fg="yellow")
651             self.tree_depth += 1
652             for child in node.children:
653                 yield from self.visit(child)
654
655             self.tree_depth -= 1
656             out(f"{indent}/{_type}", fg="yellow", bold=False)
657         else:
658             _type = token.tok_name.get(node.type, str(node.type))
659             out(f"{indent}{_type}", fg="blue", nl=False)
660             if node.prefix:
661                 # We don't have to handle prefixes for `Node` objects since
662                 # that delegates to the first child anyway.
663                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
664             out(f" {node.value!r}", fg="blue", bold=False)
665
666     @classmethod
667     def show(cls, code: str) -> None:
668         """Pretty-print the lib2to3 AST of a given string of `code`.
669
670         Convenience method for debugging.
671         """
672         v: DebugVisitor[None] = DebugVisitor()
673         list(v.visit(lib2to3_parse(code)))
674
675
676 KEYWORDS = set(keyword.kwlist)
677 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
678 FLOW_CONTROL = {"return", "raise", "break", "continue"}
679 STATEMENT = {
680     syms.if_stmt,
681     syms.while_stmt,
682     syms.for_stmt,
683     syms.try_stmt,
684     syms.except_clause,
685     syms.with_stmt,
686     syms.funcdef,
687     syms.classdef,
688 }
689 STANDALONE_COMMENT = 153
690 LOGIC_OPERATORS = {"and", "or"}
691 COMPARATORS = {
692     token.LESS,
693     token.GREATER,
694     token.EQEQUAL,
695     token.NOTEQUAL,
696     token.LESSEQUAL,
697     token.GREATEREQUAL,
698 }
699 MATH_OPERATORS = {
700     token.VBAR,
701     token.CIRCUMFLEX,
702     token.AMPER,
703     token.LEFTSHIFT,
704     token.RIGHTSHIFT,
705     token.PLUS,
706     token.MINUS,
707     token.STAR,
708     token.SLASH,
709     token.DOUBLESLASH,
710     token.PERCENT,
711     token.AT,
712     token.TILDE,
713     token.DOUBLESTAR,
714 }
715 STARS = {token.STAR, token.DOUBLESTAR}
716 VARARGS_PARENTS = {
717     syms.arglist,
718     syms.argument,  # double star in arglist
719     syms.trailer,  # single argument to call
720     syms.typedargslist,
721     syms.varargslist,  # lambdas
722 }
723 UNPACKING_PARENTS = {
724     syms.atom,  # single element of a list or set literal
725     syms.dictsetmaker,
726     syms.listmaker,
727     syms.testlist_gexp,
728 }
729 TEST_DESCENDANTS = {
730     syms.test,
731     syms.lambdef,
732     syms.or_test,
733     syms.and_test,
734     syms.not_test,
735     syms.comparison,
736     syms.star_expr,
737     syms.expr,
738     syms.xor_expr,
739     syms.and_expr,
740     syms.shift_expr,
741     syms.arith_expr,
742     syms.trailer,
743     syms.term,
744     syms.power,
745 }
746 ASSIGNMENTS = {
747     "=",
748     "+=",
749     "-=",
750     "*=",
751     "@=",
752     "/=",
753     "%=",
754     "&=",
755     "|=",
756     "^=",
757     "<<=",
758     ">>=",
759     "**=",
760     "//=",
761 }
762 COMPREHENSION_PRIORITY = 20
763 COMMA_PRIORITY = 18
764 TERNARY_PRIORITY = 16
765 LOGIC_PRIORITY = 14
766 STRING_PRIORITY = 12
767 COMPARATOR_PRIORITY = 10
768 MATH_PRIORITIES = {
769     token.VBAR: 9,
770     token.CIRCUMFLEX: 8,
771     token.AMPER: 7,
772     token.LEFTSHIFT: 6,
773     token.RIGHTSHIFT: 6,
774     token.PLUS: 5,
775     token.MINUS: 5,
776     token.STAR: 4,
777     token.SLASH: 4,
778     token.DOUBLESLASH: 4,
779     token.PERCENT: 4,
780     token.AT: 4,
781     token.TILDE: 3,
782     token.DOUBLESTAR: 2,
783 }
784 DOT_PRIORITY = 1
785
786
787 @dataclass
788 class BracketTracker:
789     """Keeps track of brackets on a line."""
790
791     depth: int = 0
792     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
793     delimiters: Dict[LeafID, Priority] = Factory(dict)
794     previous: Optional[Leaf] = None
795     _for_loop_variable: int = 0
796     _lambda_arguments: int = 0
797
798     def mark(self, leaf: Leaf) -> None:
799         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
800
801         All leaves receive an int `bracket_depth` field that stores how deep
802         within brackets a given leaf is. 0 means there are no enclosing brackets
803         that started on this line.
804
805         If a leaf is itself a closing bracket, it receives an `opening_bracket`
806         field that it forms a pair with. This is a one-directional link to
807         avoid reference cycles.
808
809         If a leaf is a delimiter (a token on which Black can split the line if
810         needed) and it's on depth 0, its `id()` is stored in the tracker's
811         `delimiters` field.
812         """
813         if leaf.type == token.COMMENT:
814             return
815
816         self.maybe_decrement_after_for_loop_variable(leaf)
817         self.maybe_decrement_after_lambda_arguments(leaf)
818         if leaf.type in CLOSING_BRACKETS:
819             self.depth -= 1
820             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
821             leaf.opening_bracket = opening_bracket
822         leaf.bracket_depth = self.depth
823         if self.depth == 0:
824             delim = is_split_before_delimiter(leaf, self.previous)
825             if delim and self.previous is not None:
826                 self.delimiters[id(self.previous)] = delim
827             else:
828                 delim = is_split_after_delimiter(leaf, self.previous)
829                 if delim:
830                     self.delimiters[id(leaf)] = delim
831         if leaf.type in OPENING_BRACKETS:
832             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
833             self.depth += 1
834         self.previous = leaf
835         self.maybe_increment_lambda_arguments(leaf)
836         self.maybe_increment_for_loop_variable(leaf)
837
838     def any_open_brackets(self) -> bool:
839         """Return True if there is an yet unmatched open bracket on the line."""
840         return bool(self.bracket_match)
841
842     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
843         """Return the highest priority of a delimiter found on the line.
844
845         Values are consistent with what `is_split_*_delimiter()` return.
846         Raises ValueError on no delimiters.
847         """
848         return max(v for k, v in self.delimiters.items() if k not in exclude)
849
850     def delimiter_count_with_priority(self, priority: int = 0) -> int:
851         """Return the number of delimiters with the given `priority`.
852
853         If no `priority` is passed, defaults to max priority on the line.
854         """
855         if not self.delimiters:
856             return 0
857
858         priority = priority or self.max_delimiter_priority()
859         return sum(1 for p in self.delimiters.values() if p == priority)
860
861     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
862         """In a for loop, or comprehension, the variables are often unpacks.
863
864         To avoid splitting on the comma in this situation, increase the depth of
865         tokens between `for` and `in`.
866         """
867         if leaf.type == token.NAME and leaf.value == "for":
868             self.depth += 1
869             self._for_loop_variable += 1
870             return True
871
872         return False
873
874     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
875         """See `maybe_increment_for_loop_variable` above for explanation."""
876         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
877             self.depth -= 1
878             self._for_loop_variable -= 1
879             return True
880
881         return False
882
883     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
884         """In a lambda expression, there might be more than one argument.
885
886         To avoid splitting on the comma in this situation, increase the depth of
887         tokens between `lambda` and `:`.
888         """
889         if leaf.type == token.NAME and leaf.value == "lambda":
890             self.depth += 1
891             self._lambda_arguments += 1
892             return True
893
894         return False
895
896     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
897         """See `maybe_increment_lambda_arguments` above for explanation."""
898         if self._lambda_arguments and leaf.type == token.COLON:
899             self.depth -= 1
900             self._lambda_arguments -= 1
901             return True
902
903         return False
904
905     def get_open_lsqb(self) -> Optional[Leaf]:
906         """Return the most recent opening square bracket (if any)."""
907         return self.bracket_match.get((self.depth - 1, token.RSQB))
908
909
910 @dataclass
911 class Line:
912     """Holds leaves and comments. Can be printed with `str(line)`."""
913
914     depth: int = 0
915     leaves: List[Leaf] = Factory(list)
916     comments: List[Tuple[Index, Leaf]] = Factory(list)
917     bracket_tracker: BracketTracker = Factory(BracketTracker)
918     inside_brackets: bool = False
919     should_explode: bool = False
920
921     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
922         """Add a new `leaf` to the end of the line.
923
924         Unless `preformatted` is True, the `leaf` will receive a new consistent
925         whitespace prefix and metadata applied by :class:`BracketTracker`.
926         Trailing commas are maybe removed, unpacked for loop variables are
927         demoted from being delimiters.
928
929         Inline comments are put aside.
930         """
931         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
932         if not has_value:
933             return
934
935         if token.COLON == leaf.type and self.is_class_paren_empty:
936             del self.leaves[-2:]
937         if self.leaves and not preformatted:
938             # Note: at this point leaf.prefix should be empty except for
939             # imports, for which we only preserve newlines.
940             leaf.prefix += whitespace(
941                 leaf, complex_subscript=self.is_complex_subscript(leaf)
942             )
943         if self.inside_brackets or not preformatted:
944             self.bracket_tracker.mark(leaf)
945             self.maybe_remove_trailing_comma(leaf)
946         if not self.append_comment(leaf):
947             self.leaves.append(leaf)
948
949     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
950         """Like :func:`append()` but disallow invalid standalone comment structure.
951
952         Raises ValueError when any `leaf` is appended after a standalone comment
953         or when a standalone comment is not the first leaf on the line.
954         """
955         if self.bracket_tracker.depth == 0:
956             if self.is_comment:
957                 raise ValueError("cannot append to standalone comments")
958
959             if self.leaves and leaf.type == STANDALONE_COMMENT:
960                 raise ValueError(
961                     "cannot append standalone comments to a populated line"
962                 )
963
964         self.append(leaf, preformatted=preformatted)
965
966     @property
967     def is_comment(self) -> bool:
968         """Is this line a standalone comment?"""
969         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
970
971     @property
972     def is_decorator(self) -> bool:
973         """Is this line a decorator?"""
974         return bool(self) and self.leaves[0].type == token.AT
975
976     @property
977     def is_import(self) -> bool:
978         """Is this an import line?"""
979         return bool(self) and is_import(self.leaves[0])
980
981     @property
982     def is_class(self) -> bool:
983         """Is this line a class definition?"""
984         return (
985             bool(self)
986             and self.leaves[0].type == token.NAME
987             and self.leaves[0].value == "class"
988         )
989
990     @property
991     def is_stub_class(self) -> bool:
992         """Is this line a class definition with a body consisting only of "..."?"""
993         return self.is_class and self.leaves[-3:] == [
994             Leaf(token.DOT, ".") for _ in range(3)
995         ]
996
997     @property
998     def is_def(self) -> bool:
999         """Is this a function definition? (Also returns True for async defs.)"""
1000         try:
1001             first_leaf = self.leaves[0]
1002         except IndexError:
1003             return False
1004
1005         try:
1006             second_leaf: Optional[Leaf] = self.leaves[1]
1007         except IndexError:
1008             second_leaf = None
1009         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1010             first_leaf.type == token.ASYNC
1011             and second_leaf is not None
1012             and second_leaf.type == token.NAME
1013             and second_leaf.value == "def"
1014         )
1015
1016     @property
1017     def is_class_paren_empty(self) -> bool:
1018         """Is this a class with no base classes but using parentheses?
1019
1020         Those are unnecessary and should be removed.
1021         """
1022         return (
1023             bool(self)
1024             and len(self.leaves) == 4
1025             and self.is_class
1026             and self.leaves[2].type == token.LPAR
1027             and self.leaves[2].value == "("
1028             and self.leaves[3].type == token.RPAR
1029             and self.leaves[3].value == ")"
1030         )
1031
1032     @property
1033     def is_triple_quoted_string(self) -> bool:
1034         """Is the line a triple quoted string?"""
1035         return (
1036             bool(self)
1037             and self.leaves[0].type == token.STRING
1038             and self.leaves[0].value.startswith(('"""', "'''"))
1039         )
1040
1041     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1042         """If so, needs to be split before emitting."""
1043         for leaf in self.leaves:
1044             if leaf.type == STANDALONE_COMMENT:
1045                 if leaf.bracket_depth <= depth_limit:
1046                     return True
1047
1048         return False
1049
1050     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1051         """Remove trailing comma if there is one and it's safe."""
1052         if not (
1053             self.leaves
1054             and self.leaves[-1].type == token.COMMA
1055             and closing.type in CLOSING_BRACKETS
1056         ):
1057             return False
1058
1059         if closing.type == token.RBRACE:
1060             self.remove_trailing_comma()
1061             return True
1062
1063         if closing.type == token.RSQB:
1064             comma = self.leaves[-1]
1065             if comma.parent and comma.parent.type == syms.listmaker:
1066                 self.remove_trailing_comma()
1067                 return True
1068
1069         # For parens let's check if it's safe to remove the comma.
1070         # Imports are always safe.
1071         if self.is_import:
1072             self.remove_trailing_comma()
1073             return True
1074
1075         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1076         # change a tuple into a different type by removing the comma.
1077         depth = closing.bracket_depth + 1
1078         commas = 0
1079         opening = closing.opening_bracket
1080         for _opening_index, leaf in enumerate(self.leaves):
1081             if leaf is opening:
1082                 break
1083
1084         else:
1085             return False
1086
1087         for leaf in self.leaves[_opening_index + 1 :]:
1088             if leaf is closing:
1089                 break
1090
1091             bracket_depth = leaf.bracket_depth
1092             if bracket_depth == depth and leaf.type == token.COMMA:
1093                 commas += 1
1094                 if leaf.parent and leaf.parent.type == syms.arglist:
1095                     commas += 1
1096                     break
1097
1098         if commas > 1:
1099             self.remove_trailing_comma()
1100             return True
1101
1102         return False
1103
1104     def append_comment(self, comment: Leaf) -> bool:
1105         """Add an inline or standalone comment to the line."""
1106         if (
1107             comment.type == STANDALONE_COMMENT
1108             and self.bracket_tracker.any_open_brackets()
1109         ):
1110             comment.prefix = ""
1111             return False
1112
1113         if comment.type != token.COMMENT:
1114             return False
1115
1116         after = len(self.leaves) - 1
1117         if after == -1:
1118             comment.type = STANDALONE_COMMENT
1119             comment.prefix = ""
1120             return False
1121
1122         else:
1123             self.comments.append((after, comment))
1124             return True
1125
1126     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1127         """Generate comments that should appear directly after `leaf`.
1128
1129         Provide a non-negative leaf `_index` to speed up the function.
1130         """
1131         if _index == -1:
1132             for _index, _leaf in enumerate(self.leaves):
1133                 if leaf is _leaf:
1134                     break
1135
1136             else:
1137                 return
1138
1139         for index, comment_after in self.comments:
1140             if _index == index:
1141                 yield comment_after
1142
1143     def remove_trailing_comma(self) -> None:
1144         """Remove the trailing comma and moves the comments attached to it."""
1145         comma_index = len(self.leaves) - 1
1146         for i in range(len(self.comments)):
1147             comment_index, comment = self.comments[i]
1148             if comment_index == comma_index:
1149                 self.comments[i] = (comma_index - 1, comment)
1150         self.leaves.pop()
1151
1152     def is_complex_subscript(self, leaf: Leaf) -> bool:
1153         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1154         open_lsqb = (
1155             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1156         )
1157         if open_lsqb is None:
1158             return False
1159
1160         subscript_start = open_lsqb.next_sibling
1161         if (
1162             isinstance(subscript_start, Node)
1163             and subscript_start.type == syms.subscriptlist
1164         ):
1165             subscript_start = child_towards(subscript_start, leaf)
1166         return subscript_start is not None and any(
1167             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1168         )
1169
1170     def __str__(self) -> str:
1171         """Render the line."""
1172         if not self:
1173             return "\n"
1174
1175         indent = "    " * self.depth
1176         leaves = iter(self.leaves)
1177         first = next(leaves)
1178         res = f"{first.prefix}{indent}{first.value}"
1179         for leaf in leaves:
1180             res += str(leaf)
1181         for _, comment in self.comments:
1182             res += str(comment)
1183         return res + "\n"
1184
1185     def __bool__(self) -> bool:
1186         """Return True if the line has leaves or comments."""
1187         return bool(self.leaves or self.comments)
1188
1189
1190 class UnformattedLines(Line):
1191     """Just like :class:`Line` but stores lines which aren't reformatted."""
1192
1193     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1194         """Just add a new `leaf` to the end of the lines.
1195
1196         The `preformatted` argument is ignored.
1197
1198         Keeps track of indentation `depth`, which is useful when the user
1199         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1200         """
1201         try:
1202             list(generate_comments(leaf))
1203         except FormatOn as f_on:
1204             self.leaves.append(f_on.leaf_from_consumed(leaf))
1205             raise
1206
1207         self.leaves.append(leaf)
1208         if leaf.type == token.INDENT:
1209             self.depth += 1
1210         elif leaf.type == token.DEDENT:
1211             self.depth -= 1
1212
1213     def __str__(self) -> str:
1214         """Render unformatted lines from leaves which were added with `append()`.
1215
1216         `depth` is not used for indentation in this case.
1217         """
1218         if not self:
1219             return "\n"
1220
1221         res = ""
1222         for leaf in self.leaves:
1223             res += str(leaf)
1224         return res
1225
1226     def append_comment(self, comment: Leaf) -> bool:
1227         """Not implemented in this class. Raises `NotImplementedError`."""
1228         raise NotImplementedError("Unformatted lines don't store comments separately.")
1229
1230     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1231         """Does nothing and returns False."""
1232         return False
1233
1234     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1235         """Does nothing and returns False."""
1236         return False
1237
1238
1239 @dataclass
1240 class EmptyLineTracker:
1241     """Provides a stateful method that returns the number of potential extra
1242     empty lines needed before and after the currently processed line.
1243
1244     Note: this tracker works on lines that haven't been split yet.  It assumes
1245     the prefix of the first leaf consists of optional newlines.  Those newlines
1246     are consumed by `maybe_empty_lines()` and included in the computation.
1247     """
1248
1249     is_pyi: bool = False
1250     previous_line: Optional[Line] = None
1251     previous_after: int = 0
1252     previous_defs: List[int] = Factory(list)
1253
1254     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1255         """Return the number of extra empty lines before and after the `current_line`.
1256
1257         This is for separating `def`, `async def` and `class` with extra empty
1258         lines (two on module-level).
1259         """
1260         if isinstance(current_line, UnformattedLines):
1261             return 0, 0
1262
1263         before, after = self._maybe_empty_lines(current_line)
1264         before -= self.previous_after
1265         self.previous_after = after
1266         self.previous_line = current_line
1267         return before, after
1268
1269     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1270         max_allowed = 1
1271         if current_line.depth == 0:
1272             max_allowed = 1 if self.is_pyi else 2
1273         if current_line.leaves:
1274             # Consume the first leaf's extra newlines.
1275             first_leaf = current_line.leaves[0]
1276             before = first_leaf.prefix.count("\n")
1277             before = min(before, max_allowed)
1278             first_leaf.prefix = ""
1279         else:
1280             before = 0
1281         depth = current_line.depth
1282         while self.previous_defs and self.previous_defs[-1] >= depth:
1283             self.previous_defs.pop()
1284             if self.is_pyi:
1285                 before = 0 if depth else 1
1286             else:
1287                 before = 1 if depth else 2
1288         is_decorator = current_line.is_decorator
1289         if is_decorator or current_line.is_def or current_line.is_class:
1290             if not is_decorator:
1291                 self.previous_defs.append(depth)
1292             if self.previous_line is None:
1293                 # Don't insert empty lines before the first line in the file.
1294                 return 0, 0
1295
1296             if self.previous_line.is_decorator:
1297                 return 0, 0
1298
1299             if self.previous_line.depth < current_line.depth and (
1300                 self.previous_line.is_class or self.previous_line.is_def
1301             ):
1302                 return 0, 0
1303
1304             if (
1305                 self.previous_line.is_comment
1306                 and self.previous_line.depth == current_line.depth
1307                 and before == 0
1308             ):
1309                 return 0, 0
1310
1311             if self.is_pyi:
1312                 if self.previous_line.depth > current_line.depth:
1313                     newlines = 1
1314                 elif current_line.is_class or self.previous_line.is_class:
1315                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1316                         newlines = 0
1317                     else:
1318                         newlines = 1
1319                 else:
1320                     newlines = 0
1321             else:
1322                 newlines = 2
1323             if current_line.depth and newlines:
1324                 newlines -= 1
1325             return newlines, 0
1326
1327         if (
1328             self.previous_line
1329             and self.previous_line.is_import
1330             and not current_line.is_import
1331             and depth == self.previous_line.depth
1332         ):
1333             return (before or 1), 0
1334
1335         if (
1336             self.previous_line
1337             and self.previous_line.is_class
1338             and current_line.is_triple_quoted_string
1339         ):
1340             return before, 1
1341
1342         return before, 0
1343
1344
1345 @dataclass
1346 class LineGenerator(Visitor[Line]):
1347     """Generates reformatted Line objects.  Empty lines are not emitted.
1348
1349     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1350     in ways that will no longer stringify to valid Python code on the tree.
1351     """
1352
1353     is_pyi: bool = False
1354     normalize_strings: bool = True
1355     current_line: Line = Factory(Line)
1356     remove_u_prefix: bool = False
1357
1358     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1359         """Generate a line.
1360
1361         If the line is empty, only emit if it makes sense.
1362         If the line is too long, split it first and then generate.
1363
1364         If any lines were generated, set up a new current_line.
1365         """
1366         if not self.current_line:
1367             if self.current_line.__class__ == type:
1368                 self.current_line.depth += indent
1369             else:
1370                 self.current_line = type(depth=self.current_line.depth + indent)
1371             return  # Line is empty, don't emit. Creating a new one unnecessary.
1372
1373         complete_line = self.current_line
1374         self.current_line = type(depth=complete_line.depth + indent)
1375         yield complete_line
1376
1377     def visit(self, node: LN) -> Iterator[Line]:
1378         """Main method to visit `node` and its children.
1379
1380         Yields :class:`Line` objects.
1381         """
1382         if isinstance(self.current_line, UnformattedLines):
1383             # File contained `# fmt: off`
1384             yield from self.visit_unformatted(node)
1385
1386         else:
1387             yield from super().visit(node)
1388
1389     def visit_default(self, node: LN) -> Iterator[Line]:
1390         """Default `visit_*()` implementation. Recurses to children of `node`."""
1391         if isinstance(node, Leaf):
1392             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1393             try:
1394                 for comment in generate_comments(node):
1395                     if any_open_brackets:
1396                         # any comment within brackets is subject to splitting
1397                         self.current_line.append(comment)
1398                     elif comment.type == token.COMMENT:
1399                         # regular trailing comment
1400                         self.current_line.append(comment)
1401                         yield from self.line()
1402
1403                     else:
1404                         # regular standalone comment
1405                         yield from self.line()
1406
1407                         self.current_line.append(comment)
1408                         yield from self.line()
1409
1410             except FormatOff as f_off:
1411                 f_off.trim_prefix(node)
1412                 yield from self.line(type=UnformattedLines)
1413                 yield from self.visit(node)
1414
1415             except FormatOn as f_on:
1416                 # This only happens here if somebody says "fmt: on" multiple
1417                 # times in a row.
1418                 f_on.trim_prefix(node)
1419                 yield from self.visit_default(node)
1420
1421             else:
1422                 normalize_prefix(node, inside_brackets=any_open_brackets)
1423                 if self.normalize_strings and node.type == token.STRING:
1424                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1425                     normalize_string_quotes(node)
1426                 if node.type not in WHITESPACE:
1427                     self.current_line.append(node)
1428         yield from super().visit_default(node)
1429
1430     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1431         """Increase indentation level, maybe yield a line."""
1432         # In blib2to3 INDENT never holds comments.
1433         yield from self.line(+1)
1434         yield from self.visit_default(node)
1435
1436     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1437         """Decrease indentation level, maybe yield a line."""
1438         # The current line might still wait for trailing comments.  At DEDENT time
1439         # there won't be any (they would be prefixes on the preceding NEWLINE).
1440         # Emit the line then.
1441         yield from self.line()
1442
1443         # While DEDENT has no value, its prefix may contain standalone comments
1444         # that belong to the current indentation level.  Get 'em.
1445         yield from self.visit_default(node)
1446
1447         # Finally, emit the dedent.
1448         yield from self.line(-1)
1449
1450     def visit_stmt(
1451         self, node: Node, keywords: Set[str], parens: Set[str]
1452     ) -> Iterator[Line]:
1453         """Visit a statement.
1454
1455         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1456         `def`, `with`, `class`, `assert` and assignments.
1457
1458         The relevant Python language `keywords` for a given statement will be
1459         NAME leaves within it. This methods puts those on a separate line.
1460
1461         `parens` holds a set of string leaf values immediately after which
1462         invisible parens should be put.
1463         """
1464         normalize_invisible_parens(node, parens_after=parens)
1465         for child in node.children:
1466             if child.type == token.NAME and child.value in keywords:  # type: ignore
1467                 yield from self.line()
1468
1469             yield from self.visit(child)
1470
1471     def visit_suite(self, node: Node) -> Iterator[Line]:
1472         """Visit a suite."""
1473         if self.is_pyi and is_stub_suite(node):
1474             yield from self.visit(node.children[2])
1475         else:
1476             yield from self.visit_default(node)
1477
1478     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1479         """Visit a statement without nested statements."""
1480         is_suite_like = node.parent and node.parent.type in STATEMENT
1481         if is_suite_like:
1482             if self.is_pyi and is_stub_body(node):
1483                 yield from self.visit_default(node)
1484             else:
1485                 yield from self.line(+1)
1486                 yield from self.visit_default(node)
1487                 yield from self.line(-1)
1488
1489         else:
1490             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1491                 yield from self.line()
1492             yield from self.visit_default(node)
1493
1494     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1495         """Visit `async def`, `async for`, `async with`."""
1496         yield from self.line()
1497
1498         children = iter(node.children)
1499         for child in children:
1500             yield from self.visit(child)
1501
1502             if child.type == token.ASYNC:
1503                 break
1504
1505         internal_stmt = next(children)
1506         for child in internal_stmt.children:
1507             yield from self.visit(child)
1508
1509     def visit_decorators(self, node: Node) -> Iterator[Line]:
1510         """Visit decorators."""
1511         for child in node.children:
1512             yield from self.line()
1513             yield from self.visit(child)
1514
1515     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1516         """Remove a semicolon and put the other statement on a separate line."""
1517         yield from self.line()
1518
1519     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1520         """End of file. Process outstanding comments and end with a newline."""
1521         yield from self.visit_default(leaf)
1522         yield from self.line()
1523
1524     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1525         """Used when file contained a `# fmt: off`."""
1526         if isinstance(node, Node):
1527             for child in node.children:
1528                 yield from self.visit(child)
1529
1530         else:
1531             try:
1532                 self.current_line.append(node)
1533             except FormatOn as f_on:
1534                 f_on.trim_prefix(node)
1535                 yield from self.line()
1536                 yield from self.visit(node)
1537
1538             if node.type == token.ENDMARKER:
1539                 # somebody decided not to put a final `# fmt: on`
1540                 yield from self.line()
1541
1542     def __attrs_post_init__(self) -> None:
1543         """You are in a twisty little maze of passages."""
1544         v = self.visit_stmt
1545         Ø: Set[str] = set()
1546         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1547         self.visit_if_stmt = partial(
1548             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1549         )
1550         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1551         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1552         self.visit_try_stmt = partial(
1553             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1554         )
1555         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1556         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1557         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1558         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1559         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1560         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1561         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1562         self.visit_async_funcdef = self.visit_async_stmt
1563         self.visit_decorated = self.visit_decorators
1564
1565
1566 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1567 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1568 OPENING_BRACKETS = set(BRACKET.keys())
1569 CLOSING_BRACKETS = set(BRACKET.values())
1570 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1571 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1572
1573
1574 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1575     """Return whitespace prefix if needed for the given `leaf`.
1576
1577     `complex_subscript` signals whether the given leaf is part of a subscription
1578     which has non-trivial arguments, like arithmetic expressions or function calls.
1579     """
1580     NO = ""
1581     SPACE = " "
1582     DOUBLESPACE = "  "
1583     t = leaf.type
1584     p = leaf.parent
1585     v = leaf.value
1586     if t in ALWAYS_NO_SPACE:
1587         return NO
1588
1589     if t == token.COMMENT:
1590         return DOUBLESPACE
1591
1592     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1593     if t == token.COLON and p.type not in {
1594         syms.subscript,
1595         syms.subscriptlist,
1596         syms.sliceop,
1597     }:
1598         return NO
1599
1600     prev = leaf.prev_sibling
1601     if not prev:
1602         prevp = preceding_leaf(p)
1603         if not prevp or prevp.type in OPENING_BRACKETS:
1604             return NO
1605
1606         if t == token.COLON:
1607             if prevp.type == token.COLON:
1608                 return NO
1609
1610             elif prevp.type != token.COMMA and not complex_subscript:
1611                 return NO
1612
1613             return SPACE
1614
1615         if prevp.type == token.EQUAL:
1616             if prevp.parent:
1617                 if prevp.parent.type in {
1618                     syms.arglist,
1619                     syms.argument,
1620                     syms.parameters,
1621                     syms.varargslist,
1622                 }:
1623                     return NO
1624
1625                 elif prevp.parent.type == syms.typedargslist:
1626                     # A bit hacky: if the equal sign has whitespace, it means we
1627                     # previously found it's a typed argument.  So, we're using
1628                     # that, too.
1629                     return prevp.prefix
1630
1631         elif prevp.type in STARS:
1632             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1633                 return NO
1634
1635         elif prevp.type == token.COLON:
1636             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1637                 return SPACE if complex_subscript else NO
1638
1639         elif (
1640             prevp.parent
1641             and prevp.parent.type == syms.factor
1642             and prevp.type in MATH_OPERATORS
1643         ):
1644             return NO
1645
1646         elif (
1647             prevp.type == token.RIGHTSHIFT
1648             and prevp.parent
1649             and prevp.parent.type == syms.shift_expr
1650             and prevp.prev_sibling
1651             and prevp.prev_sibling.type == token.NAME
1652             and prevp.prev_sibling.value == "print"  # type: ignore
1653         ):
1654             # Python 2 print chevron
1655             return NO
1656
1657     elif prev.type in OPENING_BRACKETS:
1658         return NO
1659
1660     if p.type in {syms.parameters, syms.arglist}:
1661         # untyped function signatures or calls
1662         if not prev or prev.type != token.COMMA:
1663             return NO
1664
1665     elif p.type == syms.varargslist:
1666         # lambdas
1667         if prev and prev.type != token.COMMA:
1668             return NO
1669
1670     elif p.type == syms.typedargslist:
1671         # typed function signatures
1672         if not prev:
1673             return NO
1674
1675         if t == token.EQUAL:
1676             if prev.type != syms.tname:
1677                 return NO
1678
1679         elif prev.type == token.EQUAL:
1680             # A bit hacky: if the equal sign has whitespace, it means we
1681             # previously found it's a typed argument.  So, we're using that, too.
1682             return prev.prefix
1683
1684         elif prev.type != token.COMMA:
1685             return NO
1686
1687     elif p.type == syms.tname:
1688         # type names
1689         if not prev:
1690             prevp = preceding_leaf(p)
1691             if not prevp or prevp.type != token.COMMA:
1692                 return NO
1693
1694     elif p.type == syms.trailer:
1695         # attributes and calls
1696         if t == token.LPAR or t == token.RPAR:
1697             return NO
1698
1699         if not prev:
1700             if t == token.DOT:
1701                 prevp = preceding_leaf(p)
1702                 if not prevp or prevp.type != token.NUMBER:
1703                     return NO
1704
1705             elif t == token.LSQB:
1706                 return NO
1707
1708         elif prev.type != token.COMMA:
1709             return NO
1710
1711     elif p.type == syms.argument:
1712         # single argument
1713         if t == token.EQUAL:
1714             return NO
1715
1716         if not prev:
1717             prevp = preceding_leaf(p)
1718             if not prevp or prevp.type == token.LPAR:
1719                 return NO
1720
1721         elif prev.type in {token.EQUAL} | STARS:
1722             return NO
1723
1724     elif p.type == syms.decorator:
1725         # decorators
1726         return NO
1727
1728     elif p.type == syms.dotted_name:
1729         if prev:
1730             return NO
1731
1732         prevp = preceding_leaf(p)
1733         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1734             return NO
1735
1736     elif p.type == syms.classdef:
1737         if t == token.LPAR:
1738             return NO
1739
1740         if prev and prev.type == token.LPAR:
1741             return NO
1742
1743     elif p.type in {syms.subscript, syms.sliceop}:
1744         # indexing
1745         if not prev:
1746             assert p.parent is not None, "subscripts are always parented"
1747             if p.parent.type == syms.subscriptlist:
1748                 return SPACE
1749
1750             return NO
1751
1752         elif not complex_subscript:
1753             return NO
1754
1755     elif p.type == syms.atom:
1756         if prev and t == token.DOT:
1757             # dots, but not the first one.
1758             return NO
1759
1760     elif p.type == syms.dictsetmaker:
1761         # dict unpacking
1762         if prev and prev.type == token.DOUBLESTAR:
1763             return NO
1764
1765     elif p.type in {syms.factor, syms.star_expr}:
1766         # unary ops
1767         if not prev:
1768             prevp = preceding_leaf(p)
1769             if not prevp or prevp.type in OPENING_BRACKETS:
1770                 return NO
1771
1772             prevp_parent = prevp.parent
1773             assert prevp_parent is not None
1774             if prevp.type == token.COLON and prevp_parent.type in {
1775                 syms.subscript,
1776                 syms.sliceop,
1777             }:
1778                 return NO
1779
1780             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1781                 return NO
1782
1783         elif t == token.NAME or t == token.NUMBER:
1784             return NO
1785
1786     elif p.type == syms.import_from:
1787         if t == token.DOT:
1788             if prev and prev.type == token.DOT:
1789                 return NO
1790
1791         elif t == token.NAME:
1792             if v == "import":
1793                 return SPACE
1794
1795             if prev and prev.type == token.DOT:
1796                 return NO
1797
1798     elif p.type == syms.sliceop:
1799         return NO
1800
1801     return SPACE
1802
1803
1804 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1805     """Return the first leaf that precedes `node`, if any."""
1806     while node:
1807         res = node.prev_sibling
1808         if res:
1809             if isinstance(res, Leaf):
1810                 return res
1811
1812             try:
1813                 return list(res.leaves())[-1]
1814
1815             except IndexError:
1816                 return None
1817
1818         node = node.parent
1819     return None
1820
1821
1822 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1823     """Return the child of `ancestor` that contains `descendant`."""
1824     node: Optional[LN] = descendant
1825     while node and node.parent != ancestor:
1826         node = node.parent
1827     return node
1828
1829
1830 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1831     """Return the priority of the `leaf` delimiter, given a line break after it.
1832
1833     The delimiter priorities returned here are from those delimiters that would
1834     cause a line break after themselves.
1835
1836     Higher numbers are higher priority.
1837     """
1838     if leaf.type == token.COMMA:
1839         return COMMA_PRIORITY
1840
1841     return 0
1842
1843
1844 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1845     """Return the priority of the `leaf` delimiter, given a line before after it.
1846
1847     The delimiter priorities returned here are from those delimiters that would
1848     cause a line break before themselves.
1849
1850     Higher numbers are higher priority.
1851     """
1852     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1853         # * and ** might also be MATH_OPERATORS but in this case they are not.
1854         # Don't treat them as a delimiter.
1855         return 0
1856
1857     if (
1858         leaf.type == token.DOT
1859         and leaf.parent
1860         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1861         and (previous is None or previous.type in CLOSING_BRACKETS)
1862     ):
1863         return DOT_PRIORITY
1864
1865     if (
1866         leaf.type in MATH_OPERATORS
1867         and leaf.parent
1868         and leaf.parent.type not in {syms.factor, syms.star_expr}
1869     ):
1870         return MATH_PRIORITIES[leaf.type]
1871
1872     if leaf.type in COMPARATORS:
1873         return COMPARATOR_PRIORITY
1874
1875     if (
1876         leaf.type == token.STRING
1877         and previous is not None
1878         and previous.type == token.STRING
1879     ):
1880         return STRING_PRIORITY
1881
1882     if leaf.type != token.NAME:
1883         return 0
1884
1885     if (
1886         leaf.value == "for"
1887         and leaf.parent
1888         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1889     ):
1890         return COMPREHENSION_PRIORITY
1891
1892     if (
1893         leaf.value == "if"
1894         and leaf.parent
1895         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1896     ):
1897         return COMPREHENSION_PRIORITY
1898
1899     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1900         return TERNARY_PRIORITY
1901
1902     if leaf.value == "is":
1903         return COMPARATOR_PRIORITY
1904
1905     if (
1906         leaf.value == "in"
1907         and leaf.parent
1908         and leaf.parent.type in {syms.comp_op, syms.comparison}
1909         and not (
1910             previous is not None
1911             and previous.type == token.NAME
1912             and previous.value == "not"
1913         )
1914     ):
1915         return COMPARATOR_PRIORITY
1916
1917     if (
1918         leaf.value == "not"
1919         and leaf.parent
1920         and leaf.parent.type == syms.comp_op
1921         and not (
1922             previous is not None
1923             and previous.type == token.NAME
1924             and previous.value == "is"
1925         )
1926     ):
1927         return COMPARATOR_PRIORITY
1928
1929     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1930         return LOGIC_PRIORITY
1931
1932     return 0
1933
1934
1935 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1936     """Clean the prefix of the `leaf` and generate comments from it, if any.
1937
1938     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1939     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1940     move because it does away with modifying the grammar to include all the
1941     possible places in which comments can be placed.
1942
1943     The sad consequence for us though is that comments don't "belong" anywhere.
1944     This is why this function generates simple parentless Leaf objects for
1945     comments.  We simply don't know what the correct parent should be.
1946
1947     No matter though, we can live without this.  We really only need to
1948     differentiate between inline and standalone comments.  The latter don't
1949     share the line with any code.
1950
1951     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1952     are emitted with a fake STANDALONE_COMMENT token identifier.
1953     """
1954     p = leaf.prefix
1955     if not p:
1956         return
1957
1958     if "#" not in p:
1959         return
1960
1961     consumed = 0
1962     nlines = 0
1963     for index, line in enumerate(p.split("\n")):
1964         consumed += len(line) + 1  # adding the length of the split '\n'
1965         line = line.lstrip()
1966         if not line:
1967             nlines += 1
1968         if not line.startswith("#"):
1969             continue
1970
1971         if index == 0 and leaf.type != token.ENDMARKER:
1972             comment_type = token.COMMENT  # simple trailing comment
1973         else:
1974             comment_type = STANDALONE_COMMENT
1975         comment = make_comment(line)
1976         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1977
1978         if comment in {"# fmt: on", "# yapf: enable"}:
1979             raise FormatOn(consumed)
1980
1981         if comment in {"# fmt: off", "# yapf: disable"}:
1982             if comment_type == STANDALONE_COMMENT:
1983                 raise FormatOff(consumed)
1984
1985             prev = preceding_leaf(leaf)
1986             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1987                 raise FormatOff(consumed)
1988
1989         nlines = 0
1990
1991
1992 def make_comment(content: str) -> str:
1993     """Return a consistently formatted comment from the given `content` string.
1994
1995     All comments (except for "##", "#!", "#:") should have a single space between
1996     the hash sign and the content.
1997
1998     If `content` didn't start with a hash sign, one is provided.
1999     """
2000     content = content.rstrip()
2001     if not content:
2002         return "#"
2003
2004     if content[0] == "#":
2005         content = content[1:]
2006     if content and content[0] not in " !:#":
2007         content = " " + content
2008     return "#" + content
2009
2010
2011 def split_line(
2012     line: Line, line_length: int, inner: bool = False, py36: bool = False
2013 ) -> Iterator[Line]:
2014     """Split a `line` into potentially many lines.
2015
2016     They should fit in the allotted `line_length` but might not be able to.
2017     `inner` signifies that there were a pair of brackets somewhere around the
2018     current `line`, possibly transitively. This means we can fallback to splitting
2019     by delimiters if the LHS/RHS don't yield any results.
2020
2021     If `py36` is True, splitting may generate syntax that is only compatible
2022     with Python 3.6 and later.
2023     """
2024     if isinstance(line, UnformattedLines) or line.is_comment:
2025         yield line
2026         return
2027
2028     line_str = str(line).strip("\n")
2029     if not line.should_explode and is_line_short_enough(
2030         line, line_length=line_length, line_str=line_str
2031     ):
2032         yield line
2033         return
2034
2035     split_funcs: List[SplitFunc]
2036     if line.is_def:
2037         split_funcs = [left_hand_split]
2038     else:
2039
2040         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2041             for omit in generate_trailers_to_omit(line, line_length):
2042                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2043                 if is_line_short_enough(lines[0], line_length=line_length):
2044                     yield from lines
2045                     return
2046
2047             # All splits failed, best effort split with no omits.
2048             # This mostly happens to multiline strings that are by definition
2049             # reported as not fitting a single line.
2050             yield from right_hand_split(line, py36)
2051
2052         if line.inside_brackets:
2053             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2054         else:
2055             split_funcs = [rhs]
2056     for split_func in split_funcs:
2057         # We are accumulating lines in `result` because we might want to abort
2058         # mission and return the original line in the end, or attempt a different
2059         # split altogether.
2060         result: List[Line] = []
2061         try:
2062             for l in split_func(line, py36):
2063                 if str(l).strip("\n") == line_str:
2064                     raise CannotSplit("Split function returned an unchanged result")
2065
2066                 result.extend(
2067                     split_line(l, line_length=line_length, inner=True, py36=py36)
2068                 )
2069         except CannotSplit as cs:
2070             continue
2071
2072         else:
2073             yield from result
2074             break
2075
2076     else:
2077         yield line
2078
2079
2080 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2081     """Split line into many lines, starting with the first matching bracket pair.
2082
2083     Note: this usually looks weird, only use this for function definitions.
2084     Prefer RHS otherwise.  This is why this function is not symmetrical with
2085     :func:`right_hand_split` which also handles optional parentheses.
2086     """
2087     head = Line(depth=line.depth)
2088     body = Line(depth=line.depth + 1, inside_brackets=True)
2089     tail = Line(depth=line.depth)
2090     tail_leaves: List[Leaf] = []
2091     body_leaves: List[Leaf] = []
2092     head_leaves: List[Leaf] = []
2093     current_leaves = head_leaves
2094     matching_bracket = None
2095     for leaf in line.leaves:
2096         if (
2097             current_leaves is body_leaves
2098             and leaf.type in CLOSING_BRACKETS
2099             and leaf.opening_bracket is matching_bracket
2100         ):
2101             current_leaves = tail_leaves if body_leaves else head_leaves
2102         current_leaves.append(leaf)
2103         if current_leaves is head_leaves:
2104             if leaf.type in OPENING_BRACKETS:
2105                 matching_bracket = leaf
2106                 current_leaves = body_leaves
2107     # Since body is a new indent level, remove spurious leading whitespace.
2108     if body_leaves:
2109         normalize_prefix(body_leaves[0], inside_brackets=True)
2110     # Build the new lines.
2111     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2112         for leaf in leaves:
2113             result.append(leaf, preformatted=True)
2114             for comment_after in line.comments_after(leaf):
2115                 result.append(comment_after, preformatted=True)
2116     bracket_split_succeeded_or_raise(head, body, tail)
2117     for result in (head, body, tail):
2118         if result:
2119             yield result
2120
2121
2122 def right_hand_split(
2123     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2124 ) -> Iterator[Line]:
2125     """Split line into many lines, starting with the last matching bracket pair.
2126
2127     If the split was by optional parentheses, attempt splitting without them, too.
2128     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2129     this split.
2130
2131     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2132     """
2133     head = Line(depth=line.depth)
2134     body = Line(depth=line.depth + 1, inside_brackets=True)
2135     tail = Line(depth=line.depth)
2136     tail_leaves: List[Leaf] = []
2137     body_leaves: List[Leaf] = []
2138     head_leaves: List[Leaf] = []
2139     current_leaves = tail_leaves
2140     opening_bracket = None
2141     closing_bracket = None
2142     for leaf in reversed(line.leaves):
2143         if current_leaves is body_leaves:
2144             if leaf is opening_bracket:
2145                 current_leaves = head_leaves if body_leaves else tail_leaves
2146         current_leaves.append(leaf)
2147         if current_leaves is tail_leaves:
2148             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2149                 opening_bracket = leaf.opening_bracket
2150                 closing_bracket = leaf
2151                 current_leaves = body_leaves
2152     tail_leaves.reverse()
2153     body_leaves.reverse()
2154     head_leaves.reverse()
2155     # Since body is a new indent level, remove spurious leading whitespace.
2156     if body_leaves:
2157         normalize_prefix(body_leaves[0], inside_brackets=True)
2158     if not head_leaves:
2159         # No `head` means the split failed. Either `tail` has all content or
2160         # the matching `opening_bracket` wasn't available on `line` anymore.
2161         raise CannotSplit("No brackets found")
2162
2163     # Build the new lines.
2164     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2165         for leaf in leaves:
2166             result.append(leaf, preformatted=True)
2167             for comment_after in line.comments_after(leaf):
2168                 result.append(comment_after, preformatted=True)
2169     bracket_split_succeeded_or_raise(head, body, tail)
2170     assert opening_bracket and closing_bracket
2171     if (
2172         # the opening bracket is an optional paren
2173         opening_bracket.type == token.LPAR
2174         and not opening_bracket.value
2175         # the closing bracket is an optional paren
2176         and closing_bracket.type == token.RPAR
2177         and not closing_bracket.value
2178         # there are no standalone comments in the body
2179         and not line.contains_standalone_comments(0)
2180         # and it's not an import (optional parens are the only thing we can split
2181         # on in this case; attempting a split without them is a waste of time)
2182         and not line.is_import
2183     ):
2184         omit = {id(closing_bracket), *omit}
2185         if can_omit_invisible_parens(body, line_length):
2186             try:
2187                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2188                 return
2189             except CannotSplit:
2190                 pass
2191
2192     ensure_visible(opening_bracket)
2193     ensure_visible(closing_bracket)
2194     body.should_explode = should_explode(body, opening_bracket)
2195     for result in (head, body, tail):
2196         if result:
2197             yield result
2198
2199
2200 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2201     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2202
2203     Do nothing otherwise.
2204
2205     A left- or right-hand split is based on a pair of brackets. Content before
2206     (and including) the opening bracket is left on one line, content inside the
2207     brackets is put on a separate line, and finally content starting with and
2208     following the closing bracket is put on a separate line.
2209
2210     Those are called `head`, `body`, and `tail`, respectively. If the split
2211     produced the same line (all content in `head`) or ended up with an empty `body`
2212     and the `tail` is just the closing bracket, then it's considered failed.
2213     """
2214     tail_len = len(str(tail).strip())
2215     if not body:
2216         if tail_len == 0:
2217             raise CannotSplit("Splitting brackets produced the same line")
2218
2219         elif tail_len < 3:
2220             raise CannotSplit(
2221                 f"Splitting brackets on an empty body to save "
2222                 f"{tail_len} characters is not worth it"
2223             )
2224
2225
2226 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2227     """Normalize prefix of the first leaf in every line returned by `split_func`.
2228
2229     This is a decorator over relevant split functions.
2230     """
2231
2232     @wraps(split_func)
2233     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2234         for l in split_func(line, py36):
2235             normalize_prefix(l.leaves[0], inside_brackets=True)
2236             yield l
2237
2238     return split_wrapper
2239
2240
2241 @dont_increase_indentation
2242 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2243     """Split according to delimiters of the highest priority.
2244
2245     If `py36` is True, the split will add trailing commas also in function
2246     signatures that contain `*` and `**`.
2247     """
2248     try:
2249         last_leaf = line.leaves[-1]
2250     except IndexError:
2251         raise CannotSplit("Line empty")
2252
2253     bt = line.bracket_tracker
2254     try:
2255         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2256     except ValueError:
2257         raise CannotSplit("No delimiters found")
2258
2259     if delimiter_priority == DOT_PRIORITY:
2260         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2261             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2262
2263     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2264     lowest_depth = sys.maxsize
2265     trailing_comma_safe = True
2266
2267     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2268         """Append `leaf` to current line or to new line if appending impossible."""
2269         nonlocal current_line
2270         try:
2271             current_line.append_safe(leaf, preformatted=True)
2272         except ValueError as ve:
2273             yield current_line
2274
2275             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2276             current_line.append(leaf)
2277
2278     for index, leaf in enumerate(line.leaves):
2279         yield from append_to_line(leaf)
2280
2281         for comment_after in line.comments_after(leaf, index):
2282             yield from append_to_line(comment_after)
2283
2284         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2285         if leaf.bracket_depth == lowest_depth and is_vararg(
2286             leaf, within=VARARGS_PARENTS
2287         ):
2288             trailing_comma_safe = trailing_comma_safe and py36
2289         leaf_priority = bt.delimiters.get(id(leaf))
2290         if leaf_priority == delimiter_priority:
2291             yield current_line
2292
2293             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2294     if current_line:
2295         if (
2296             trailing_comma_safe
2297             and delimiter_priority == COMMA_PRIORITY
2298             and current_line.leaves[-1].type != token.COMMA
2299             and current_line.leaves[-1].type != STANDALONE_COMMENT
2300         ):
2301             current_line.append(Leaf(token.COMMA, ","))
2302         yield current_line
2303
2304
2305 @dont_increase_indentation
2306 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2307     """Split standalone comments from the rest of the line."""
2308     if not line.contains_standalone_comments(0):
2309         raise CannotSplit("Line does not have any standalone comments")
2310
2311     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2312
2313     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2314         """Append `leaf` to current line or to new line if appending impossible."""
2315         nonlocal current_line
2316         try:
2317             current_line.append_safe(leaf, preformatted=True)
2318         except ValueError as ve:
2319             yield current_line
2320
2321             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2322             current_line.append(leaf)
2323
2324     for index, leaf in enumerate(line.leaves):
2325         yield from append_to_line(leaf)
2326
2327         for comment_after in line.comments_after(leaf, index):
2328             yield from append_to_line(comment_after)
2329
2330     if current_line:
2331         yield current_line
2332
2333
2334 def is_import(leaf: Leaf) -> bool:
2335     """Return True if the given leaf starts an import statement."""
2336     p = leaf.parent
2337     t = leaf.type
2338     v = leaf.value
2339     return bool(
2340         t == token.NAME
2341         and (
2342             (v == "import" and p and p.type == syms.import_name)
2343             or (v == "from" and p and p.type == syms.import_from)
2344         )
2345     )
2346
2347
2348 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2349     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2350     else.
2351
2352     Note: don't use backslashes for formatting or you'll lose your voting rights.
2353     """
2354     if not inside_brackets:
2355         spl = leaf.prefix.split("#")
2356         if "\\" not in spl[0]:
2357             nl_count = spl[-1].count("\n")
2358             if len(spl) > 1:
2359                 nl_count -= 1
2360             leaf.prefix = "\n" * nl_count
2361             return
2362
2363     leaf.prefix = ""
2364
2365
2366 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2367     """Make all string prefixes lowercase.
2368
2369     If remove_u_prefix is given, also removes any u prefix from the string.
2370
2371     Note: Mutates its argument.
2372     """
2373     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2374     assert match is not None, f"failed to match string {leaf.value!r}"
2375     orig_prefix = match.group(1)
2376     new_prefix = orig_prefix.lower()
2377     if remove_u_prefix:
2378         new_prefix = new_prefix.replace("u", "")
2379     leaf.value = f"{new_prefix}{match.group(2)}"
2380
2381
2382 def normalize_string_quotes(leaf: Leaf) -> None:
2383     """Prefer double quotes but only if it doesn't cause more escaping.
2384
2385     Adds or removes backslashes as appropriate. Doesn't parse and fix
2386     strings nested in f-strings (yet).
2387
2388     Note: Mutates its argument.
2389     """
2390     value = leaf.value.lstrip("furbFURB")
2391     if value[:3] == '"""':
2392         return
2393
2394     elif value[:3] == "'''":
2395         orig_quote = "'''"
2396         new_quote = '"""'
2397     elif value[0] == '"':
2398         orig_quote = '"'
2399         new_quote = "'"
2400     else:
2401         orig_quote = "'"
2402         new_quote = '"'
2403     first_quote_pos = leaf.value.find(orig_quote)
2404     if first_quote_pos == -1:
2405         return  # There's an internal error
2406
2407     prefix = leaf.value[:first_quote_pos]
2408     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2409     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2410     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2411     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2412     if "r" in prefix.casefold():
2413         if unescaped_new_quote.search(body):
2414             # There's at least one unescaped new_quote in this raw string
2415             # so converting is impossible
2416             return
2417
2418         # Do not introduce or remove backslashes in raw strings
2419         new_body = body
2420     else:
2421         # remove unnecessary quotes
2422         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2423         if body != new_body:
2424             # Consider the string without unnecessary quotes as the original
2425             body = new_body
2426             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2427         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2428         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2429     if new_quote == '"""' and new_body[-1] == '"':
2430         # edge case:
2431         new_body = new_body[:-1] + '\\"'
2432     orig_escape_count = body.count("\\")
2433     new_escape_count = new_body.count("\\")
2434     if new_escape_count > orig_escape_count:
2435         return  # Do not introduce more escaping
2436
2437     if new_escape_count == orig_escape_count and orig_quote == '"':
2438         return  # Prefer double quotes
2439
2440     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2441
2442
2443 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2444     """Make existing optional parentheses invisible or create new ones.
2445
2446     `parens_after` is a set of string leaf values immeditely after which parens
2447     should be put.
2448
2449     Standardizes on visible parentheses for single-element tuples, and keeps
2450     existing visible parentheses for other tuples and generator expressions.
2451     """
2452     try:
2453         list(generate_comments(node))
2454     except FormatOff:
2455         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2456
2457     check_lpar = False
2458     for index, child in enumerate(list(node.children)):
2459         if check_lpar:
2460             if child.type == syms.atom:
2461                 maybe_make_parens_invisible_in_atom(child)
2462             elif is_one_tuple(child):
2463                 # wrap child in visible parentheses
2464                 lpar = Leaf(token.LPAR, "(")
2465                 rpar = Leaf(token.RPAR, ")")
2466                 child.remove()
2467                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2468             elif node.type == syms.import_from:
2469                 # "import from" nodes store parentheses directly as part of
2470                 # the statement
2471                 if child.type == token.LPAR:
2472                     # make parentheses invisible
2473                     child.value = ""  # type: ignore
2474                     node.children[-1].value = ""  # type: ignore
2475                 elif child.type != token.STAR:
2476                     # insert invisible parentheses
2477                     node.insert_child(index, Leaf(token.LPAR, ""))
2478                     node.append_child(Leaf(token.RPAR, ""))
2479                 break
2480
2481             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2482                 # wrap child in invisible parentheses
2483                 lpar = Leaf(token.LPAR, "")
2484                 rpar = Leaf(token.RPAR, "")
2485                 index = child.remove() or 0
2486                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2487
2488         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2489
2490
2491 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2492     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2493     if (
2494         node.type != syms.atom
2495         or is_empty_tuple(node)
2496         or is_one_tuple(node)
2497         or is_yield(node)
2498         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2499     ):
2500         return False
2501
2502     first = node.children[0]
2503     last = node.children[-1]
2504     if first.type == token.LPAR and last.type == token.RPAR:
2505         # make parentheses invisible
2506         first.value = ""  # type: ignore
2507         last.value = ""  # type: ignore
2508         if len(node.children) > 1:
2509             maybe_make_parens_invisible_in_atom(node.children[1])
2510         return True
2511
2512     return False
2513
2514
2515 def is_empty_tuple(node: LN) -> bool:
2516     """Return True if `node` holds an empty tuple."""
2517     return (
2518         node.type == syms.atom
2519         and len(node.children) == 2
2520         and node.children[0].type == token.LPAR
2521         and node.children[1].type == token.RPAR
2522     )
2523
2524
2525 def is_one_tuple(node: LN) -> bool:
2526     """Return True if `node` holds a tuple with one element, with or without parens."""
2527     if node.type == syms.atom:
2528         if len(node.children) != 3:
2529             return False
2530
2531         lpar, gexp, rpar = node.children
2532         if not (
2533             lpar.type == token.LPAR
2534             and gexp.type == syms.testlist_gexp
2535             and rpar.type == token.RPAR
2536         ):
2537             return False
2538
2539         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2540
2541     return (
2542         node.type in IMPLICIT_TUPLE
2543         and len(node.children) == 2
2544         and node.children[1].type == token.COMMA
2545     )
2546
2547
2548 def is_yield(node: LN) -> bool:
2549     """Return True if `node` holds a `yield` or `yield from` expression."""
2550     if node.type == syms.yield_expr:
2551         return True
2552
2553     if node.type == token.NAME and node.value == "yield":  # type: ignore
2554         return True
2555
2556     if node.type != syms.atom:
2557         return False
2558
2559     if len(node.children) != 3:
2560         return False
2561
2562     lpar, expr, rpar = node.children
2563     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2564         return is_yield(expr)
2565
2566     return False
2567
2568
2569 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2570     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2571
2572     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2573     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2574     extended iterable unpacking (PEP 3132) and additional unpacking
2575     generalizations (PEP 448).
2576     """
2577     if leaf.type not in STARS or not leaf.parent:
2578         return False
2579
2580     p = leaf.parent
2581     if p.type == syms.star_expr:
2582         # Star expressions are also used as assignment targets in extended
2583         # iterable unpacking (PEP 3132).  See what its parent is instead.
2584         if not p.parent:
2585             return False
2586
2587         p = p.parent
2588
2589     return p.type in within
2590
2591
2592 def is_multiline_string(leaf: Leaf) -> bool:
2593     """Return True if `leaf` is a multiline string that actually spans many lines."""
2594     value = leaf.value.lstrip("furbFURB")
2595     return value[:3] in {'"""', "'''"} and "\n" in value
2596
2597
2598 def is_stub_suite(node: Node) -> bool:
2599     """Return True if `node` is a suite with a stub body."""
2600     if (
2601         len(node.children) != 4
2602         or node.children[0].type != token.NEWLINE
2603         or node.children[1].type != token.INDENT
2604         or node.children[3].type != token.DEDENT
2605     ):
2606         return False
2607
2608     return is_stub_body(node.children[2])
2609
2610
2611 def is_stub_body(node: LN) -> bool:
2612     """Return True if `node` is a simple statement containing an ellipsis."""
2613     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2614         return False
2615
2616     if len(node.children) != 2:
2617         return False
2618
2619     child = node.children[0]
2620     return (
2621         child.type == syms.atom
2622         and len(child.children) == 3
2623         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2624     )
2625
2626
2627 def max_delimiter_priority_in_atom(node: LN) -> int:
2628     """Return maximum delimiter priority inside `node`.
2629
2630     This is specific to atoms with contents contained in a pair of parentheses.
2631     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2632     """
2633     if node.type != syms.atom:
2634         return 0
2635
2636     first = node.children[0]
2637     last = node.children[-1]
2638     if not (first.type == token.LPAR and last.type == token.RPAR):
2639         return 0
2640
2641     bt = BracketTracker()
2642     for c in node.children[1:-1]:
2643         if isinstance(c, Leaf):
2644             bt.mark(c)
2645         else:
2646             for leaf in c.leaves():
2647                 bt.mark(leaf)
2648     try:
2649         return bt.max_delimiter_priority()
2650
2651     except ValueError:
2652         return 0
2653
2654
2655 def ensure_visible(leaf: Leaf) -> None:
2656     """Make sure parentheses are visible.
2657
2658     They could be invisible as part of some statements (see
2659     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2660     """
2661     if leaf.type == token.LPAR:
2662         leaf.value = "("
2663     elif leaf.type == token.RPAR:
2664         leaf.value = ")"
2665
2666
2667 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2668     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2669     if not (
2670         opening_bracket.parent
2671         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2672         and opening_bracket.value in "[{("
2673     ):
2674         return False
2675
2676     try:
2677         last_leaf = line.leaves[-1]
2678         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2679         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2680     except (IndexError, ValueError):
2681         return False
2682
2683     return max_priority == COMMA_PRIORITY
2684
2685
2686 def is_python36(node: Node) -> bool:
2687     """Return True if the current file is using Python 3.6+ features.
2688
2689     Currently looking for:
2690     - f-strings; and
2691     - trailing commas after * or ** in function signatures and calls.
2692     """
2693     for n in node.pre_order():
2694         if n.type == token.STRING:
2695             value_head = n.value[:2]  # type: ignore
2696             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2697                 return True
2698
2699         elif (
2700             n.type in {syms.typedargslist, syms.arglist}
2701             and n.children
2702             and n.children[-1].type == token.COMMA
2703         ):
2704             for ch in n.children:
2705                 if ch.type in STARS:
2706                     return True
2707
2708                 if ch.type == syms.argument:
2709                     for argch in ch.children:
2710                         if argch.type in STARS:
2711                             return True
2712
2713     return False
2714
2715
2716 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2717     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2718
2719     Brackets can be omitted if the entire trailer up to and including
2720     a preceding closing bracket fits in one line.
2721
2722     Yielded sets are cumulative (contain results of previous yields, too).  First
2723     set is empty.
2724     """
2725
2726     omit: Set[LeafID] = set()
2727     yield omit
2728
2729     length = 4 * line.depth
2730     opening_bracket = None
2731     closing_bracket = None
2732     optional_brackets: Set[LeafID] = set()
2733     inner_brackets: Set[LeafID] = set()
2734     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2735         length += leaf_length
2736         if length > line_length:
2737             break
2738
2739         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2740         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2741             break
2742
2743         optional_brackets.discard(id(leaf))
2744         if opening_bracket:
2745             if leaf is opening_bracket:
2746                 opening_bracket = None
2747             elif leaf.type in CLOSING_BRACKETS:
2748                 inner_brackets.add(id(leaf))
2749         elif leaf.type in CLOSING_BRACKETS:
2750             if not leaf.value:
2751                 optional_brackets.add(id(opening_bracket))
2752                 continue
2753
2754             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2755                 # Empty brackets would fail a split so treat them as "inner"
2756                 # brackets (e.g. only add them to the `omit` set if another
2757                 # pair of brackets was good enough.
2758                 inner_brackets.add(id(leaf))
2759                 continue
2760
2761             opening_bracket = leaf.opening_bracket
2762             if closing_bracket:
2763                 omit.add(id(closing_bracket))
2764                 omit.update(inner_brackets)
2765                 inner_brackets.clear()
2766                 yield omit
2767             closing_bracket = leaf
2768
2769
2770 def get_future_imports(node: Node) -> Set[str]:
2771     """Return a set of __future__ imports in the file."""
2772     imports = set()
2773     for child in node.children:
2774         if child.type != syms.simple_stmt:
2775             break
2776         first_child = child.children[0]
2777         if isinstance(first_child, Leaf):
2778             # Continue looking if we see a docstring; otherwise stop.
2779             if (
2780                 len(child.children) == 2
2781                 and first_child.type == token.STRING
2782                 and child.children[1].type == token.NEWLINE
2783             ):
2784                 continue
2785             else:
2786                 break
2787         elif first_child.type == syms.import_from:
2788             module_name = first_child.children[1]
2789             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2790                 break
2791             for import_from_child in first_child.children[3:]:
2792                 if isinstance(import_from_child, Leaf):
2793                     if import_from_child.type == token.NAME:
2794                         imports.add(import_from_child.value)
2795                 else:
2796                     assert import_from_child.type == syms.import_as_names
2797                     for leaf in import_from_child.children:
2798                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2799                             imports.add(leaf.value)
2800         else:
2801             break
2802     return imports
2803
2804
2805 def gen_python_files_in_dir(
2806     path: Path, root: Path, include: Pattern[str], exclude: Pattern[str]
2807 ) -> Iterator[Path]:
2808     """Generate all files under `path` whose paths are not excluded by the
2809     `exclude` regex, but are included by the `include` regex.
2810     """
2811     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
2812     for child in path.iterdir():
2813         normalized_path = child.resolve().relative_to(root).as_posix()
2814         if child.is_dir():
2815             normalized_path += "/"
2816         exclude_match = exclude.search(normalized_path)
2817         if exclude_match and exclude_match.group(0):
2818             continue
2819
2820         if child.is_dir():
2821             yield from gen_python_files_in_dir(child, root, include, exclude)
2822
2823         elif child.is_file():
2824             include_match = include.search(normalized_path)
2825             if include_match:
2826                 yield child
2827
2828
2829 def find_project_root(srcs: List[str]) -> Path:
2830     """Return a directory containing .git, .hg, or pyproject.toml.
2831
2832     That directory can be one of the directories passed in `srcs` or their
2833     common parent.
2834
2835     If no directory in the tree contains a marker that would specify it's the
2836     project root, the root of the file system is returned.
2837     """
2838     if not srcs:
2839         return Path("/").resolve()
2840
2841     common_base = min(Path(src).resolve() for src in srcs)
2842     if common_base.is_dir():
2843         # Append a fake file so `parents` below returns `common_base_dir`, too.
2844         common_base /= "fake-file"
2845     for directory in common_base.parents:
2846         if (directory / ".git").is_dir():
2847             return directory
2848
2849         if (directory / ".hg").is_dir():
2850             return directory
2851
2852         if (directory / "pyproject.toml").is_file():
2853             return directory
2854
2855     return directory
2856
2857
2858 @dataclass
2859 class Report:
2860     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2861
2862     check: bool = False
2863     quiet: bool = False
2864     change_count: int = 0
2865     same_count: int = 0
2866     failure_count: int = 0
2867
2868     def done(self, src: Path, changed: Changed) -> None:
2869         """Increment the counter for successful reformatting. Write out a message."""
2870         if changed is Changed.YES:
2871             reformatted = "would reformat" if self.check else "reformatted"
2872             if not self.quiet:
2873                 out(f"{reformatted} {src}")
2874             self.change_count += 1
2875         else:
2876             if not self.quiet:
2877                 if changed is Changed.NO:
2878                     msg = f"{src} already well formatted, good job."
2879                 else:
2880                     msg = f"{src} wasn't modified on disk since last run."
2881                 out(msg, bold=False)
2882             self.same_count += 1
2883
2884     def failed(self, src: Path, message: str) -> None:
2885         """Increment the counter for failed reformatting. Write out a message."""
2886         err(f"error: cannot format {src}: {message}")
2887         self.failure_count += 1
2888
2889     @property
2890     def return_code(self) -> int:
2891         """Return the exit code that the app should use.
2892
2893         This considers the current state of changed files and failures:
2894         - if there were any failures, return 123;
2895         - if any files were changed and --check is being used, return 1;
2896         - otherwise return 0.
2897         """
2898         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2899         # 126 we have special returncodes reserved by the shell.
2900         if self.failure_count:
2901             return 123
2902
2903         elif self.change_count and self.check:
2904             return 1
2905
2906         return 0
2907
2908     def __str__(self) -> str:
2909         """Render a color report of the current state.
2910
2911         Use `click.unstyle` to remove colors.
2912         """
2913         if self.check:
2914             reformatted = "would be reformatted"
2915             unchanged = "would be left unchanged"
2916             failed = "would fail to reformat"
2917         else:
2918             reformatted = "reformatted"
2919             unchanged = "left unchanged"
2920             failed = "failed to reformat"
2921         report = []
2922         if self.change_count:
2923             s = "s" if self.change_count > 1 else ""
2924             report.append(
2925                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2926             )
2927         if self.same_count:
2928             s = "s" if self.same_count > 1 else ""
2929             report.append(f"{self.same_count} file{s} {unchanged}")
2930         if self.failure_count:
2931             s = "s" if self.failure_count > 1 else ""
2932             report.append(
2933                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2934             )
2935         return ", ".join(report) + "."
2936
2937
2938 def assert_equivalent(src: str, dst: str) -> None:
2939     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2940
2941     import ast
2942     import traceback
2943
2944     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2945         """Simple visitor generating strings to compare ASTs by content."""
2946         yield f"{'  ' * depth}{node.__class__.__name__}("
2947
2948         for field in sorted(node._fields):
2949             try:
2950                 value = getattr(node, field)
2951             except AttributeError:
2952                 continue
2953
2954             yield f"{'  ' * (depth+1)}{field}="
2955
2956             if isinstance(value, list):
2957                 for item in value:
2958                     if isinstance(item, ast.AST):
2959                         yield from _v(item, depth + 2)
2960
2961             elif isinstance(value, ast.AST):
2962                 yield from _v(value, depth + 2)
2963
2964             else:
2965                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2966
2967         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2968
2969     try:
2970         src_ast = ast.parse(src)
2971     except Exception as exc:
2972         major, minor = sys.version_info[:2]
2973         raise AssertionError(
2974             f"cannot use --safe with this file; failed to parse source file "
2975             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2976             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2977         )
2978
2979     try:
2980         dst_ast = ast.parse(dst)
2981     except Exception as exc:
2982         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2983         raise AssertionError(
2984             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2985             f"Please report a bug on https://github.com/ambv/black/issues.  "
2986             f"This invalid output might be helpful: {log}"
2987         ) from None
2988
2989     src_ast_str = "\n".join(_v(src_ast))
2990     dst_ast_str = "\n".join(_v(dst_ast))
2991     if src_ast_str != dst_ast_str:
2992         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2993         raise AssertionError(
2994             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2995             f"the source.  "
2996             f"Please report a bug on https://github.com/ambv/black/issues.  "
2997             f"This diff might be helpful: {log}"
2998         ) from None
2999
3000
3001 def assert_stable(
3002     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3003 ) -> None:
3004     """Raise AssertionError if `dst` reformats differently the second time."""
3005     newdst = format_str(dst, line_length=line_length, mode=mode)
3006     if dst != newdst:
3007         log = dump_to_file(
3008             diff(src, dst, "source", "first pass"),
3009             diff(dst, newdst, "first pass", "second pass"),
3010         )
3011         raise AssertionError(
3012             f"INTERNAL ERROR: Black produced different code on the second pass "
3013             f"of the formatter.  "
3014             f"Please report a bug on https://github.com/ambv/black/issues.  "
3015             f"This diff might be helpful: {log}"
3016         ) from None
3017
3018
3019 def dump_to_file(*output: str) -> str:
3020     """Dump `output` to a temporary file. Return path to the file."""
3021     import tempfile
3022
3023     with tempfile.NamedTemporaryFile(
3024         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3025     ) as f:
3026         for lines in output:
3027             f.write(lines)
3028             if lines and lines[-1] != "\n":
3029                 f.write("\n")
3030     return f.name
3031
3032
3033 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3034     """Return a unified diff string between strings `a` and `b`."""
3035     import difflib
3036
3037     a_lines = [line + "\n" for line in a.split("\n")]
3038     b_lines = [line + "\n" for line in b.split("\n")]
3039     return "".join(
3040         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3041     )
3042
3043
3044 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3045     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3046     err("Aborted!")
3047     for task in tasks:
3048         task.cancel()
3049
3050
3051 def shutdown(loop: BaseEventLoop) -> None:
3052     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3053     try:
3054         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3055         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3056         if not to_cancel:
3057             return
3058
3059         for task in to_cancel:
3060             task.cancel()
3061         loop.run_until_complete(
3062             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3063         )
3064     finally:
3065         # `concurrent.futures.Future` objects cannot be cancelled once they
3066         # are already running. There might be some when the `shutdown()` happened.
3067         # Silence their logger's spew about the event loop being closed.
3068         cf_logger = logging.getLogger("concurrent.futures")
3069         cf_logger.setLevel(logging.CRITICAL)
3070         loop.close()
3071
3072
3073 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3074     """Replace `regex` with `replacement` twice on `original`.
3075
3076     This is used by string normalization to perform replaces on
3077     overlapping matches.
3078     """
3079     return regex.sub(replacement, regex.sub(replacement, original))
3080
3081
3082 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3083     """Like `reversed(enumerate(sequence))` if that were possible."""
3084     index = len(sequence) - 1
3085     for element in reversed(sequence):
3086         yield (index, element)
3087         index -= 1
3088
3089
3090 def enumerate_with_length(
3091     line: Line, reversed: bool = False
3092 ) -> Iterator[Tuple[Index, Leaf, int]]:
3093     """Return an enumeration of leaves with their length.
3094
3095     Stops prematurely on multiline strings and standalone comments.
3096     """
3097     op = cast(
3098         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3099         enumerate_reversed if reversed else enumerate,
3100     )
3101     for index, leaf in op(line.leaves):
3102         length = len(leaf.prefix) + len(leaf.value)
3103         if "\n" in leaf.value:
3104             return  # Multiline strings, we can't continue.
3105
3106         comment: Optional[Leaf]
3107         for comment in line.comments_after(leaf, index):
3108             length += len(comment.value)
3109
3110         yield index, leaf, length
3111
3112
3113 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3114     """Return True if `line` is no longer than `line_length`.
3115
3116     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3117     """
3118     if not line_str:
3119         line_str = str(line).strip("\n")
3120     return (
3121         len(line_str) <= line_length
3122         and "\n" not in line_str  # multiline strings
3123         and not line.contains_standalone_comments()
3124     )
3125
3126
3127 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3128     """Does `line` have a shape safe to reformat without optional parens around it?
3129
3130     Returns True for only a subset of potentially nice looking formattings but
3131     the point is to not return false positives that end up producing lines that
3132     are too long.
3133     """
3134     bt = line.bracket_tracker
3135     if not bt.delimiters:
3136         # Without delimiters the optional parentheses are useless.
3137         return True
3138
3139     max_priority = bt.max_delimiter_priority()
3140     if bt.delimiter_count_with_priority(max_priority) > 1:
3141         # With more than one delimiter of a kind the optional parentheses read better.
3142         return False
3143
3144     if max_priority == DOT_PRIORITY:
3145         # A single stranded method call doesn't require optional parentheses.
3146         return True
3147
3148     assert len(line.leaves) >= 2, "Stranded delimiter"
3149
3150     first = line.leaves[0]
3151     second = line.leaves[1]
3152     penultimate = line.leaves[-2]
3153     last = line.leaves[-1]
3154
3155     # With a single delimiter, omit if the expression starts or ends with
3156     # a bracket.
3157     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3158         remainder = False
3159         length = 4 * line.depth
3160         for _index, leaf, leaf_length in enumerate_with_length(line):
3161             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3162                 remainder = True
3163             if remainder:
3164                 length += leaf_length
3165                 if length > line_length:
3166                     break
3167
3168                 if leaf.type in OPENING_BRACKETS:
3169                     # There are brackets we can further split on.
3170                     remainder = False
3171
3172         else:
3173             # checked the entire string and line length wasn't exceeded
3174             if len(line.leaves) == _index + 1:
3175                 return True
3176
3177         # Note: we are not returning False here because a line might have *both*
3178         # a leading opening bracket and a trailing closing bracket.  If the
3179         # opening bracket doesn't match our rule, maybe the closing will.
3180
3181     if (
3182         last.type == token.RPAR
3183         or last.type == token.RBRACE
3184         or (
3185             # don't use indexing for omitting optional parentheses;
3186             # it looks weird
3187             last.type == token.RSQB
3188             and last.parent
3189             and last.parent.type != syms.trailer
3190         )
3191     ):
3192         if penultimate.type in OPENING_BRACKETS:
3193             # Empty brackets don't help.
3194             return False
3195
3196         if is_multiline_string(first):
3197             # Additional wrapping of a multiline string in this situation is
3198             # unnecessary.
3199             return True
3200
3201         length = 4 * line.depth
3202         seen_other_brackets = False
3203         for _index, leaf, leaf_length in enumerate_with_length(line):
3204             length += leaf_length
3205             if leaf is last.opening_bracket:
3206                 if seen_other_brackets or length <= line_length:
3207                     return True
3208
3209             elif leaf.type in OPENING_BRACKETS:
3210                 # There are brackets we can further split on.
3211                 seen_other_brackets = True
3212
3213     return False
3214
3215
3216 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3217     pyi = bool(mode & FileMode.PYI)
3218     py36 = bool(mode & FileMode.PYTHON36)
3219     return (
3220         CACHE_DIR
3221         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3222     )
3223
3224
3225 def read_cache(line_length: int, mode: FileMode) -> Cache:
3226     """Read the cache if it exists and is well formed.
3227
3228     If it is not well formed, the call to write_cache later should resolve the issue.
3229     """
3230     cache_file = get_cache_file(line_length, mode)
3231     if not cache_file.exists():
3232         return {}
3233
3234     with cache_file.open("rb") as fobj:
3235         try:
3236             cache: Cache = pickle.load(fobj)
3237         except pickle.UnpicklingError:
3238             return {}
3239
3240     return cache
3241
3242
3243 def get_cache_info(path: Path) -> CacheInfo:
3244     """Return the information used to check if a file is already formatted or not."""
3245     stat = path.stat()
3246     return stat.st_mtime, stat.st_size
3247
3248
3249 def filter_cached(
3250     cache: Cache, sources: Iterable[Path]
3251 ) -> Tuple[List[Path], List[Path]]:
3252     """Split a list of paths into two.
3253
3254     The first list contains paths of files that modified on disk or are not in the
3255     cache. The other list contains paths to non-modified files.
3256     """
3257     todo, done = [], []
3258     for src in sources:
3259         src = src.resolve()
3260         if cache.get(src) != get_cache_info(src):
3261             todo.append(src)
3262         else:
3263             done.append(src)
3264     return todo, done
3265
3266
3267 def write_cache(
3268     cache: Cache, sources: List[Path], line_length: int, mode: FileMode
3269 ) -> None:
3270     """Update the cache file."""
3271     cache_file = get_cache_file(line_length, mode)
3272     try:
3273         if not CACHE_DIR.exists():
3274             CACHE_DIR.mkdir(parents=True)
3275         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3276         with cache_file.open("wb") as fobj:
3277             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3278     except OSError:
3279         pass
3280
3281
3282 if __name__ == "__main__":
3283     main()