]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

36a180da702a3a003276dc3a8e1f25b96b4abd01
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum, Flag
6 from functools import lru_cache, partial, wraps
7 import io
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tokenize
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generator,
24     Generic,
25     Iterable,
26     Iterator,
27     List,
28     Optional,
29     Pattern,
30     Sequence,
31     Set,
32     Tuple,
33     TypeVar,
34     Union,
35     cast,
36 )
37
38 from appdirs import user_cache_dir
39 from attr import dataclass, Factory
40 import click
41 import toml
42
43 # lib2to3 fork
44 from blib2to3.pytree import Node, Leaf, type_repr
45 from blib2to3 import pygram, pytree
46 from blib2to3.pgen2 import driver, token
47 from blib2to3.pgen2.parse import ParseError
48
49
50 __version__ = "18.6b4"
51 DEFAULT_LINE_LENGTH = 88
52 DEFAULT_EXCLUDES = (
53     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
54 )
55 DEFAULT_INCLUDES = r"\.pyi?$"
56 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
57
58
59 # types
60 FileContent = str
61 Encoding = str
62 NewLine = str
63 Depth = int
64 NodeType = int
65 LeafID = int
66 Priority = int
67 Index = int
68 LN = Union[Leaf, Node]
69 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
70 Timestamp = float
71 FileSize = int
72 CacheInfo = Tuple[Timestamp, FileSize]
73 Cache = Dict[Path, CacheInfo]
74 out = partial(click.secho, bold=True, err=True)
75 err = partial(click.secho, fg="red", err=True)
76
77 pygram.initialize(CACHE_DIR)
78 syms = pygram.python_symbols
79
80
81 class NothingChanged(UserWarning):
82     """Raised by :func:`format_file` when reformatted code is the same as source."""
83
84
85 class CannotSplit(Exception):
86     """A readable split that fits the allotted line length is impossible.
87
88     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
89     :func:`delimiter_split`.
90     """
91
92
93 class WriteBack(Enum):
94     NO = 0
95     YES = 1
96     DIFF = 2
97
98     @classmethod
99     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
100         if check and not diff:
101             return cls.NO
102
103         return cls.DIFF if diff else cls.YES
104
105
106 class Changed(Enum):
107     NO = 0
108     CACHED = 1
109     YES = 2
110
111
112 class FileMode(Flag):
113     AUTO_DETECT = 0
114     PYTHON36 = 1
115     PYI = 2
116     NO_STRING_NORMALIZATION = 4
117
118     @classmethod
119     def from_configuration(
120         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
121     ) -> "FileMode":
122         mode = cls.AUTO_DETECT
123         if py36:
124             mode |= cls.PYTHON36
125         if pyi:
126             mode |= cls.PYI
127         if skip_string_normalization:
128             mode |= cls.NO_STRING_NORMALIZATION
129         return mode
130
131
132 def read_pyproject_toml(
133     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
134 ) -> Optional[str]:
135     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
136
137     Returns the path to a successfully found and read configuration file, None
138     otherwise.
139     """
140     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
141     if not value:
142         root = find_project_root(ctx.params.get("src", ()))
143         path = root / "pyproject.toml"
144         if path.is_file():
145             value = str(path)
146         else:
147             return None
148
149     try:
150         pyproject_toml = toml.load(value)
151         config = pyproject_toml.get("tool", {}).get("black", {})
152     except (toml.TomlDecodeError, OSError) as e:
153         raise click.BadOptionUsage(f"Error reading configuration file: {e}", ctx)
154
155     if not config:
156         return None
157
158     if ctx.default_map is None:
159         ctx.default_map = {}
160     ctx.default_map.update(  # type: ignore  # bad types in .pyi
161         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
162     )
163     return value
164
165
166 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
167 @click.option(
168     "-l",
169     "--line-length",
170     type=int,
171     default=DEFAULT_LINE_LENGTH,
172     help="How many character per line to allow.",
173     show_default=True,
174 )
175 @click.option(
176     "--py36",
177     is_flag=True,
178     help=(
179         "Allow using Python 3.6-only syntax on all input files.  This will put "
180         "trailing commas in function signatures and calls also after *args and "
181         "**kwargs.  [default: per-file auto-detection]"
182     ),
183 )
184 @click.option(
185     "--pyi",
186     is_flag=True,
187     help=(
188         "Format all input files like typing stubs regardless of file extension "
189         "(useful when piping source on standard input)."
190     ),
191 )
192 @click.option(
193     "-S",
194     "--skip-string-normalization",
195     is_flag=True,
196     help="Don't normalize string quotes or prefixes.",
197 )
198 @click.option(
199     "--check",
200     is_flag=True,
201     help=(
202         "Don't write the files back, just return the status.  Return code 0 "
203         "means nothing would change.  Return code 1 means some files would be "
204         "reformatted.  Return code 123 means there was an internal error."
205     ),
206 )
207 @click.option(
208     "--diff",
209     is_flag=True,
210     help="Don't write the files back, just output a diff for each file on stdout.",
211 )
212 @click.option(
213     "--fast/--safe",
214     is_flag=True,
215     help="If --fast given, skip temporary sanity checks. [default: --safe]",
216 )
217 @click.option(
218     "--include",
219     type=str,
220     default=DEFAULT_INCLUDES,
221     help=(
222         "A regular expression that matches files and directories that should be "
223         "included on recursive searches.  An empty value means all files are "
224         "included regardless of the name.  Use forward slashes for directories on "
225         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
226         "later."
227     ),
228     show_default=True,
229 )
230 @click.option(
231     "--exclude",
232     type=str,
233     default=DEFAULT_EXCLUDES,
234     help=(
235         "A regular expression that matches files and directories that should be "
236         "excluded on recursive searches.  An empty value means no paths are excluded. "
237         "Use forward slashes for directories on all platforms (Windows, too).  "
238         "Exclusions are calculated first, inclusions later."
239     ),
240     show_default=True,
241 )
242 @click.option(
243     "-q",
244     "--quiet",
245     is_flag=True,
246     help=(
247         "Don't emit non-error messages to stderr. Errors are still emitted, "
248         "silence those with 2>/dev/null."
249     ),
250 )
251 @click.option(
252     "-v",
253     "--verbose",
254     is_flag=True,
255     help=(
256         "Also emit messages to stderr about files that were not changed or were "
257         "ignored due to --exclude=."
258     ),
259 )
260 @click.version_option(version=__version__)
261 @click.argument(
262     "src",
263     nargs=-1,
264     type=click.Path(
265         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
266     ),
267     is_eager=True,
268 )
269 @click.option(
270     "--config",
271     type=click.Path(
272         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
273     ),
274     is_eager=True,
275     callback=read_pyproject_toml,
276     help="Read configuration from PATH.",
277 )
278 @click.pass_context
279 def main(
280     ctx: click.Context,
281     line_length: int,
282     check: bool,
283     diff: bool,
284     fast: bool,
285     pyi: bool,
286     py36: bool,
287     skip_string_normalization: bool,
288     quiet: bool,
289     verbose: bool,
290     include: str,
291     exclude: str,
292     src: Tuple[str],
293     config: Optional[str],
294 ) -> None:
295     """The uncompromising code formatter."""
296     write_back = WriteBack.from_configuration(check=check, diff=diff)
297     mode = FileMode.from_configuration(
298         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
299     )
300     if config and verbose:
301         out(f"Using configuration from {config}.", bold=False, fg="blue")
302     try:
303         include_regex = re_compile_maybe_verbose(include)
304     except re.error:
305         err(f"Invalid regular expression for include given: {include!r}")
306         ctx.exit(2)
307     try:
308         exclude_regex = re_compile_maybe_verbose(exclude)
309     except re.error:
310         err(f"Invalid regular expression for exclude given: {exclude!r}")
311         ctx.exit(2)
312     report = Report(check=check, quiet=quiet, verbose=verbose)
313     root = find_project_root(src)
314     sources: Set[Path] = set()
315     for s in src:
316         p = Path(s)
317         if p.is_dir():
318             sources.update(
319                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
320             )
321         elif p.is_file() or s == "-":
322             # if a file was explicitly given, we don't care about its extension
323             sources.add(p)
324         else:
325             err(f"invalid path: {s}")
326     if len(sources) == 0:
327         if verbose or not quiet:
328             out("No paths given. Nothing to do 😴")
329         ctx.exit(0)
330
331     if len(sources) == 1:
332         reformat_one(
333             src=sources.pop(),
334             line_length=line_length,
335             fast=fast,
336             write_back=write_back,
337             mode=mode,
338             report=report,
339         )
340     else:
341         loop = asyncio.get_event_loop()
342         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
343         try:
344             loop.run_until_complete(
345                 schedule_formatting(
346                     sources=sources,
347                     line_length=line_length,
348                     fast=fast,
349                     write_back=write_back,
350                     mode=mode,
351                     report=report,
352                     loop=loop,
353                     executor=executor,
354                 )
355             )
356         finally:
357             shutdown(loop)
358     if verbose or not quiet:
359         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
360         out(f"All done! {bang}")
361         click.secho(str(report), err=True)
362     ctx.exit(report.return_code)
363
364
365 def reformat_one(
366     src: Path,
367     line_length: int,
368     fast: bool,
369     write_back: WriteBack,
370     mode: FileMode,
371     report: "Report",
372 ) -> None:
373     """Reformat a single file under `src` without spawning child processes.
374
375     If `quiet` is True, non-error messages are not output. `line_length`,
376     `write_back`, `fast` and `pyi` options are passed to
377     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
378     """
379     try:
380         changed = Changed.NO
381         if not src.is_file() and str(src) == "-":
382             if format_stdin_to_stdout(
383                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
384             ):
385                 changed = Changed.YES
386         else:
387             cache: Cache = {}
388             if write_back != WriteBack.DIFF:
389                 cache = read_cache(line_length, mode)
390                 res_src = src.resolve()
391                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
392                     changed = Changed.CACHED
393             if changed is not Changed.CACHED and format_file_in_place(
394                 src,
395                 line_length=line_length,
396                 fast=fast,
397                 write_back=write_back,
398                 mode=mode,
399             ):
400                 changed = Changed.YES
401             if write_back == WriteBack.YES and changed is not Changed.NO:
402                 write_cache(cache, [src], line_length, mode)
403         report.done(src, changed)
404     except Exception as exc:
405         report.failed(src, str(exc))
406
407
408 async def schedule_formatting(
409     sources: Set[Path],
410     line_length: int,
411     fast: bool,
412     write_back: WriteBack,
413     mode: FileMode,
414     report: "Report",
415     loop: BaseEventLoop,
416     executor: Executor,
417 ) -> None:
418     """Run formatting of `sources` in parallel using the provided `executor`.
419
420     (Use ProcessPoolExecutors for actual parallelism.)
421
422     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
423     :func:`format_file_in_place`.
424     """
425     cache: Cache = {}
426     if write_back != WriteBack.DIFF:
427         cache = read_cache(line_length, mode)
428         sources, cached = filter_cached(cache, sources)
429         for src in sorted(cached):
430             report.done(src, Changed.CACHED)
431     cancelled = []
432     formatted = []
433     if sources:
434         lock = None
435         if write_back == WriteBack.DIFF:
436             # For diff output, we need locks to ensure we don't interleave output
437             # from different processes.
438             manager = Manager()
439             lock = manager.Lock()
440         tasks = {
441             loop.run_in_executor(
442                 executor,
443                 format_file_in_place,
444                 src,
445                 line_length,
446                 fast,
447                 write_back,
448                 mode,
449                 lock,
450             ): src
451             for src in sorted(sources)
452         }
453         pending: Iterable[asyncio.Task] = tasks.keys()
454         try:
455             loop.add_signal_handler(signal.SIGINT, cancel, pending)
456             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
457         except NotImplementedError:
458             # There are no good alternatives for these on Windows
459             pass
460         while pending:
461             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
462             for task in done:
463                 src = tasks.pop(task)
464                 if task.cancelled():
465                     cancelled.append(task)
466                 elif task.exception():
467                     report.failed(src, str(task.exception()))
468                 else:
469                     formatted.append(src)
470                     report.done(src, Changed.YES if task.result() else Changed.NO)
471     if cancelled:
472         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
473     if write_back == WriteBack.YES and formatted:
474         write_cache(cache, formatted, line_length, mode)
475
476
477 def format_file_in_place(
478     src: Path,
479     line_length: int,
480     fast: bool,
481     write_back: WriteBack = WriteBack.NO,
482     mode: FileMode = FileMode.AUTO_DETECT,
483     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
484 ) -> bool:
485     """Format file under `src` path. Return True if changed.
486
487     If `write_back` is True, write reformatted code back to stdout.
488     `line_length` and `fast` options are passed to :func:`format_file_contents`.
489     """
490     if src.suffix == ".pyi":
491         mode |= FileMode.PYI
492
493     then = datetime.utcfromtimestamp(src.stat().st_mtime)
494     with open(src, "rb") as buf:
495         src_contents, encoding, newline = decode_bytes(buf.read())
496     try:
497         dst_contents = format_file_contents(
498             src_contents, line_length=line_length, fast=fast, mode=mode
499         )
500     except NothingChanged:
501         return False
502
503     if write_back == write_back.YES:
504         with open(src, "w", encoding=encoding, newline=newline) as f:
505             f.write(dst_contents)
506     elif write_back == write_back.DIFF:
507         now = datetime.utcnow()
508         src_name = f"{src}\t{then} +0000"
509         dst_name = f"{src}\t{now} +0000"
510         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
511         if lock:
512             lock.acquire()
513         try:
514             f = io.TextIOWrapper(
515                 sys.stdout.buffer,
516                 encoding=encoding,
517                 newline=newline,
518                 write_through=True,
519             )
520             f.write(diff_contents)
521             f.detach()
522         finally:
523             if lock:
524                 lock.release()
525     return True
526
527
528 def format_stdin_to_stdout(
529     line_length: int,
530     fast: bool,
531     write_back: WriteBack = WriteBack.NO,
532     mode: FileMode = FileMode.AUTO_DETECT,
533 ) -> bool:
534     """Format file on stdin. Return True if changed.
535
536     If `write_back` is True, write reformatted code back to stdout.
537     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
538     :func:`format_file_contents`.
539     """
540     then = datetime.utcnow()
541     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
542     dst = src
543     try:
544         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
545         return True
546
547     except NothingChanged:
548         return False
549
550     finally:
551         f = io.TextIOWrapper(
552             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
553         )
554         if write_back == WriteBack.YES:
555             f.write(dst)
556         elif write_back == WriteBack.DIFF:
557             now = datetime.utcnow()
558             src_name = f"STDIN\t{then} +0000"
559             dst_name = f"STDOUT\t{now} +0000"
560             f.write(diff(src, dst, src_name, dst_name))
561         f.detach()
562
563
564 def format_file_contents(
565     src_contents: str,
566     *,
567     line_length: int,
568     fast: bool,
569     mode: FileMode = FileMode.AUTO_DETECT,
570 ) -> FileContent:
571     """Reformat contents a file and return new contents.
572
573     If `fast` is False, additionally confirm that the reformatted code is
574     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
575     `line_length` is passed to :func:`format_str`.
576     """
577     if src_contents.strip() == "":
578         raise NothingChanged
579
580     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
581     if src_contents == dst_contents:
582         raise NothingChanged
583
584     if not fast:
585         assert_equivalent(src_contents, dst_contents)
586         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
587     return dst_contents
588
589
590 def format_str(
591     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
592 ) -> FileContent:
593     """Reformat a string and return new contents.
594
595     `line_length` determines how many characters per line are allowed.
596     """
597     src_node = lib2to3_parse(src_contents)
598     dst_contents = ""
599     future_imports = get_future_imports(src_node)
600     is_pyi = bool(mode & FileMode.PYI)
601     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
602     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
603     normalize_fmt_off(src_node)
604     lines = LineGenerator(
605         remove_u_prefix=py36 or "unicode_literals" in future_imports,
606         is_pyi=is_pyi,
607         normalize_strings=normalize_strings,
608     )
609     elt = EmptyLineTracker(is_pyi=is_pyi)
610     empty_line = Line()
611     after = 0
612     for current_line in lines.visit(src_node):
613         for _ in range(after):
614             dst_contents += str(empty_line)
615         before, after = elt.maybe_empty_lines(current_line)
616         for _ in range(before):
617             dst_contents += str(empty_line)
618         for line in split_line(current_line, line_length=line_length, py36=py36):
619             dst_contents += str(line)
620     return dst_contents
621
622
623 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
624     """Return a tuple of (decoded_contents, encoding, newline).
625
626     `newline` is either CRLF or LF but `decoded_contents` is decoded with
627     universal newlines (i.e. only contains LF).
628     """
629     srcbuf = io.BytesIO(src)
630     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
631     if not lines:
632         return "", encoding, "\n"
633
634     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
635     srcbuf.seek(0)
636     with io.TextIOWrapper(srcbuf, encoding) as tiow:
637         return tiow.read(), encoding, newline
638
639
640 GRAMMARS = [
641     pygram.python_grammar_no_print_statement_no_exec_statement,
642     pygram.python_grammar_no_print_statement,
643     pygram.python_grammar,
644 ]
645
646
647 def lib2to3_parse(src_txt: str) -> Node:
648     """Given a string with source, return the lib2to3 Node."""
649     grammar = pygram.python_grammar_no_print_statement
650     if src_txt[-1:] != "\n":
651         src_txt += "\n"
652     for grammar in GRAMMARS:
653         drv = driver.Driver(grammar, pytree.convert)
654         try:
655             result = drv.parse_string(src_txt, True)
656             break
657
658         except ParseError as pe:
659             lineno, column = pe.context[1]
660             lines = src_txt.splitlines()
661             try:
662                 faulty_line = lines[lineno - 1]
663             except IndexError:
664                 faulty_line = "<line number missing in source>"
665             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
666     else:
667         raise exc from None
668
669     if isinstance(result, Leaf):
670         result = Node(syms.file_input, [result])
671     return result
672
673
674 def lib2to3_unparse(node: Node) -> str:
675     """Given a lib2to3 node, return its string representation."""
676     code = str(node)
677     return code
678
679
680 T = TypeVar("T")
681
682
683 class Visitor(Generic[T]):
684     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
685
686     def visit(self, node: LN) -> Iterator[T]:
687         """Main method to visit `node` and its children.
688
689         It tries to find a `visit_*()` method for the given `node.type`, like
690         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
691         If no dedicated `visit_*()` method is found, chooses `visit_default()`
692         instead.
693
694         Then yields objects of type `T` from the selected visitor.
695         """
696         if node.type < 256:
697             name = token.tok_name[node.type]
698         else:
699             name = type_repr(node.type)
700         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
701
702     def visit_default(self, node: LN) -> Iterator[T]:
703         """Default `visit_*()` implementation. Recurses to children of `node`."""
704         if isinstance(node, Node):
705             for child in node.children:
706                 yield from self.visit(child)
707
708
709 @dataclass
710 class DebugVisitor(Visitor[T]):
711     tree_depth: int = 0
712
713     def visit_default(self, node: LN) -> Iterator[T]:
714         indent = " " * (2 * self.tree_depth)
715         if isinstance(node, Node):
716             _type = type_repr(node.type)
717             out(f"{indent}{_type}", fg="yellow")
718             self.tree_depth += 1
719             for child in node.children:
720                 yield from self.visit(child)
721
722             self.tree_depth -= 1
723             out(f"{indent}/{_type}", fg="yellow", bold=False)
724         else:
725             _type = token.tok_name.get(node.type, str(node.type))
726             out(f"{indent}{_type}", fg="blue", nl=False)
727             if node.prefix:
728                 # We don't have to handle prefixes for `Node` objects since
729                 # that delegates to the first child anyway.
730                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
731             out(f" {node.value!r}", fg="blue", bold=False)
732
733     @classmethod
734     def show(cls, code: Union[str, Leaf, Node]) -> None:
735         """Pretty-print the lib2to3 AST of a given string of `code`.
736
737         Convenience method for debugging.
738         """
739         v: DebugVisitor[None] = DebugVisitor()
740         if isinstance(code, str):
741             code = lib2to3_parse(code)
742         list(v.visit(code))
743
744
745 KEYWORDS = set(keyword.kwlist)
746 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
747 FLOW_CONTROL = {"return", "raise", "break", "continue"}
748 STATEMENT = {
749     syms.if_stmt,
750     syms.while_stmt,
751     syms.for_stmt,
752     syms.try_stmt,
753     syms.except_clause,
754     syms.with_stmt,
755     syms.funcdef,
756     syms.classdef,
757 }
758 STANDALONE_COMMENT = 153
759 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
760 LOGIC_OPERATORS = {"and", "or"}
761 COMPARATORS = {
762     token.LESS,
763     token.GREATER,
764     token.EQEQUAL,
765     token.NOTEQUAL,
766     token.LESSEQUAL,
767     token.GREATEREQUAL,
768 }
769 MATH_OPERATORS = {
770     token.VBAR,
771     token.CIRCUMFLEX,
772     token.AMPER,
773     token.LEFTSHIFT,
774     token.RIGHTSHIFT,
775     token.PLUS,
776     token.MINUS,
777     token.STAR,
778     token.SLASH,
779     token.DOUBLESLASH,
780     token.PERCENT,
781     token.AT,
782     token.TILDE,
783     token.DOUBLESTAR,
784 }
785 STARS = {token.STAR, token.DOUBLESTAR}
786 VARARGS_PARENTS = {
787     syms.arglist,
788     syms.argument,  # double star in arglist
789     syms.trailer,  # single argument to call
790     syms.typedargslist,
791     syms.varargslist,  # lambdas
792 }
793 UNPACKING_PARENTS = {
794     syms.atom,  # single element of a list or set literal
795     syms.dictsetmaker,
796     syms.listmaker,
797     syms.testlist_gexp,
798     syms.testlist_star_expr,
799 }
800 SURROUNDED_BY_BRACKETS = {
801     syms.typedargslist,
802     syms.arglist,
803     syms.subscriptlist,
804     syms.vfplist,
805     syms.import_as_names,
806     syms.yield_expr,
807     syms.testlist_gexp,
808     syms.testlist_star_expr,
809     syms.listmaker,
810     syms.dictsetmaker,
811 }
812 TEST_DESCENDANTS = {
813     syms.test,
814     syms.lambdef,
815     syms.or_test,
816     syms.and_test,
817     syms.not_test,
818     syms.comparison,
819     syms.star_expr,
820     syms.expr,
821     syms.xor_expr,
822     syms.and_expr,
823     syms.shift_expr,
824     syms.arith_expr,
825     syms.trailer,
826     syms.term,
827     syms.power,
828 }
829 ASSIGNMENTS = {
830     "=",
831     "+=",
832     "-=",
833     "*=",
834     "@=",
835     "/=",
836     "%=",
837     "&=",
838     "|=",
839     "^=",
840     "<<=",
841     ">>=",
842     "**=",
843     "//=",
844 }
845 COMPREHENSION_PRIORITY = 20
846 COMMA_PRIORITY = 18
847 TERNARY_PRIORITY = 16
848 LOGIC_PRIORITY = 14
849 STRING_PRIORITY = 12
850 COMPARATOR_PRIORITY = 10
851 MATH_PRIORITIES = {
852     token.VBAR: 9,
853     token.CIRCUMFLEX: 8,
854     token.AMPER: 7,
855     token.LEFTSHIFT: 6,
856     token.RIGHTSHIFT: 6,
857     token.PLUS: 5,
858     token.MINUS: 5,
859     token.STAR: 4,
860     token.SLASH: 4,
861     token.DOUBLESLASH: 4,
862     token.PERCENT: 4,
863     token.AT: 4,
864     token.TILDE: 3,
865     token.DOUBLESTAR: 2,
866 }
867 DOT_PRIORITY = 1
868
869
870 @dataclass
871 class BracketTracker:
872     """Keeps track of brackets on a line."""
873
874     depth: int = 0
875     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
876     delimiters: Dict[LeafID, Priority] = Factory(dict)
877     previous: Optional[Leaf] = None
878     _for_loop_variable: int = 0
879     _lambda_arguments: int = 0
880
881     def mark(self, leaf: Leaf) -> None:
882         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
883
884         All leaves receive an int `bracket_depth` field that stores how deep
885         within brackets a given leaf is. 0 means there are no enclosing brackets
886         that started on this line.
887
888         If a leaf is itself a closing bracket, it receives an `opening_bracket`
889         field that it forms a pair with. This is a one-directional link to
890         avoid reference cycles.
891
892         If a leaf is a delimiter (a token on which Black can split the line if
893         needed) and it's on depth 0, its `id()` is stored in the tracker's
894         `delimiters` field.
895         """
896         if leaf.type == token.COMMENT:
897             return
898
899         self.maybe_decrement_after_for_loop_variable(leaf)
900         self.maybe_decrement_after_lambda_arguments(leaf)
901         if leaf.type in CLOSING_BRACKETS:
902             self.depth -= 1
903             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
904             leaf.opening_bracket = opening_bracket
905         leaf.bracket_depth = self.depth
906         if self.depth == 0:
907             delim = is_split_before_delimiter(leaf, self.previous)
908             if delim and self.previous is not None:
909                 self.delimiters[id(self.previous)] = delim
910             else:
911                 delim = is_split_after_delimiter(leaf, self.previous)
912                 if delim:
913                     self.delimiters[id(leaf)] = delim
914         if leaf.type in OPENING_BRACKETS:
915             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
916             self.depth += 1
917         self.previous = leaf
918         self.maybe_increment_lambda_arguments(leaf)
919         self.maybe_increment_for_loop_variable(leaf)
920
921     def any_open_brackets(self) -> bool:
922         """Return True if there is an yet unmatched open bracket on the line."""
923         return bool(self.bracket_match)
924
925     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
926         """Return the highest priority of a delimiter found on the line.
927
928         Values are consistent with what `is_split_*_delimiter()` return.
929         Raises ValueError on no delimiters.
930         """
931         return max(v for k, v in self.delimiters.items() if k not in exclude)
932
933     def delimiter_count_with_priority(self, priority: int = 0) -> int:
934         """Return the number of delimiters with the given `priority`.
935
936         If no `priority` is passed, defaults to max priority on the line.
937         """
938         if not self.delimiters:
939             return 0
940
941         priority = priority or self.max_delimiter_priority()
942         return sum(1 for p in self.delimiters.values() if p == priority)
943
944     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
945         """In a for loop, or comprehension, the variables are often unpacks.
946
947         To avoid splitting on the comma in this situation, increase the depth of
948         tokens between `for` and `in`.
949         """
950         if leaf.type == token.NAME and leaf.value == "for":
951             self.depth += 1
952             self._for_loop_variable += 1
953             return True
954
955         return False
956
957     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
958         """See `maybe_increment_for_loop_variable` above for explanation."""
959         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
960             self.depth -= 1
961             self._for_loop_variable -= 1
962             return True
963
964         return False
965
966     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
967         """In a lambda expression, there might be more than one argument.
968
969         To avoid splitting on the comma in this situation, increase the depth of
970         tokens between `lambda` and `:`.
971         """
972         if leaf.type == token.NAME and leaf.value == "lambda":
973             self.depth += 1
974             self._lambda_arguments += 1
975             return True
976
977         return False
978
979     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
980         """See `maybe_increment_lambda_arguments` above for explanation."""
981         if self._lambda_arguments and leaf.type == token.COLON:
982             self.depth -= 1
983             self._lambda_arguments -= 1
984             return True
985
986         return False
987
988     def get_open_lsqb(self) -> Optional[Leaf]:
989         """Return the most recent opening square bracket (if any)."""
990         return self.bracket_match.get((self.depth - 1, token.RSQB))
991
992
993 @dataclass
994 class Line:
995     """Holds leaves and comments. Can be printed with `str(line)`."""
996
997     depth: int = 0
998     leaves: List[Leaf] = Factory(list)
999     comments: List[Tuple[Index, Leaf]] = Factory(list)
1000     bracket_tracker: BracketTracker = Factory(BracketTracker)
1001     inside_brackets: bool = False
1002     should_explode: bool = False
1003
1004     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1005         """Add a new `leaf` to the end of the line.
1006
1007         Unless `preformatted` is True, the `leaf` will receive a new consistent
1008         whitespace prefix and metadata applied by :class:`BracketTracker`.
1009         Trailing commas are maybe removed, unpacked for loop variables are
1010         demoted from being delimiters.
1011
1012         Inline comments are put aside.
1013         """
1014         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1015         if not has_value:
1016             return
1017
1018         if token.COLON == leaf.type and self.is_class_paren_empty:
1019             del self.leaves[-2:]
1020         if self.leaves and not preformatted:
1021             # Note: at this point leaf.prefix should be empty except for
1022             # imports, for which we only preserve newlines.
1023             leaf.prefix += whitespace(
1024                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1025             )
1026         if self.inside_brackets or not preformatted:
1027             self.bracket_tracker.mark(leaf)
1028             self.maybe_remove_trailing_comma(leaf)
1029         if not self.append_comment(leaf):
1030             self.leaves.append(leaf)
1031
1032     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1033         """Like :func:`append()` but disallow invalid standalone comment structure.
1034
1035         Raises ValueError when any `leaf` is appended after a standalone comment
1036         or when a standalone comment is not the first leaf on the line.
1037         """
1038         if self.bracket_tracker.depth == 0:
1039             if self.is_comment:
1040                 raise ValueError("cannot append to standalone comments")
1041
1042             if self.leaves and leaf.type == STANDALONE_COMMENT:
1043                 raise ValueError(
1044                     "cannot append standalone comments to a populated line"
1045                 )
1046
1047         self.append(leaf, preformatted=preformatted)
1048
1049     @property
1050     def is_comment(self) -> bool:
1051         """Is this line a standalone comment?"""
1052         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1053
1054     @property
1055     def is_decorator(self) -> bool:
1056         """Is this line a decorator?"""
1057         return bool(self) and self.leaves[0].type == token.AT
1058
1059     @property
1060     def is_import(self) -> bool:
1061         """Is this an import line?"""
1062         return bool(self) and is_import(self.leaves[0])
1063
1064     @property
1065     def is_class(self) -> bool:
1066         """Is this line a class definition?"""
1067         return (
1068             bool(self)
1069             and self.leaves[0].type == token.NAME
1070             and self.leaves[0].value == "class"
1071         )
1072
1073     @property
1074     def is_stub_class(self) -> bool:
1075         """Is this line a class definition with a body consisting only of "..."?"""
1076         return self.is_class and self.leaves[-3:] == [
1077             Leaf(token.DOT, ".") for _ in range(3)
1078         ]
1079
1080     @property
1081     def is_def(self) -> bool:
1082         """Is this a function definition? (Also returns True for async defs.)"""
1083         try:
1084             first_leaf = self.leaves[0]
1085         except IndexError:
1086             return False
1087
1088         try:
1089             second_leaf: Optional[Leaf] = self.leaves[1]
1090         except IndexError:
1091             second_leaf = None
1092         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1093             first_leaf.type == token.ASYNC
1094             and second_leaf is not None
1095             and second_leaf.type == token.NAME
1096             and second_leaf.value == "def"
1097         )
1098
1099     @property
1100     def is_class_paren_empty(self) -> bool:
1101         """Is this a class with no base classes but using parentheses?
1102
1103         Those are unnecessary and should be removed.
1104         """
1105         return (
1106             bool(self)
1107             and len(self.leaves) == 4
1108             and self.is_class
1109             and self.leaves[2].type == token.LPAR
1110             and self.leaves[2].value == "("
1111             and self.leaves[3].type == token.RPAR
1112             and self.leaves[3].value == ")"
1113         )
1114
1115     @property
1116     def is_triple_quoted_string(self) -> bool:
1117         """Is the line a triple quoted string?"""
1118         return (
1119             bool(self)
1120             and self.leaves[0].type == token.STRING
1121             and self.leaves[0].value.startswith(('"""', "'''"))
1122         )
1123
1124     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1125         """If so, needs to be split before emitting."""
1126         for leaf in self.leaves:
1127             if leaf.type == STANDALONE_COMMENT:
1128                 if leaf.bracket_depth <= depth_limit:
1129                     return True
1130
1131         return False
1132
1133     def contains_multiline_strings(self) -> bool:
1134         for leaf in self.leaves:
1135             if is_multiline_string(leaf):
1136                 return True
1137
1138         return False
1139
1140     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1141         """Remove trailing comma if there is one and it's safe."""
1142         if not (
1143             self.leaves
1144             and self.leaves[-1].type == token.COMMA
1145             and closing.type in CLOSING_BRACKETS
1146         ):
1147             return False
1148
1149         if closing.type == token.RBRACE:
1150             self.remove_trailing_comma()
1151             return True
1152
1153         if closing.type == token.RSQB:
1154             comma = self.leaves[-1]
1155             if comma.parent and comma.parent.type == syms.listmaker:
1156                 self.remove_trailing_comma()
1157                 return True
1158
1159         # For parens let's check if it's safe to remove the comma.
1160         # Imports are always safe.
1161         if self.is_import:
1162             self.remove_trailing_comma()
1163             return True
1164
1165         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1166         # change a tuple into a different type by removing the comma.
1167         depth = closing.bracket_depth + 1
1168         commas = 0
1169         opening = closing.opening_bracket
1170         for _opening_index, leaf in enumerate(self.leaves):
1171             if leaf is opening:
1172                 break
1173
1174         else:
1175             return False
1176
1177         for leaf in self.leaves[_opening_index + 1 :]:
1178             if leaf is closing:
1179                 break
1180
1181             bracket_depth = leaf.bracket_depth
1182             if bracket_depth == depth and leaf.type == token.COMMA:
1183                 commas += 1
1184                 if leaf.parent and leaf.parent.type == syms.arglist:
1185                     commas += 1
1186                     break
1187
1188         if commas > 1:
1189             self.remove_trailing_comma()
1190             return True
1191
1192         return False
1193
1194     def append_comment(self, comment: Leaf) -> bool:
1195         """Add an inline or standalone comment to the line."""
1196         if (
1197             comment.type == STANDALONE_COMMENT
1198             and self.bracket_tracker.any_open_brackets()
1199         ):
1200             comment.prefix = ""
1201             return False
1202
1203         if comment.type != token.COMMENT:
1204             return False
1205
1206         after = len(self.leaves) - 1
1207         if after == -1:
1208             comment.type = STANDALONE_COMMENT
1209             comment.prefix = ""
1210             return False
1211
1212         else:
1213             self.comments.append((after, comment))
1214             return True
1215
1216     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1217         """Generate comments that should appear directly after `leaf`.
1218
1219         Provide a non-negative leaf `_index` to speed up the function.
1220         """
1221         if not self.comments:
1222             return
1223
1224         if _index == -1:
1225             for _index, _leaf in enumerate(self.leaves):
1226                 if leaf is _leaf:
1227                     break
1228
1229             else:
1230                 return
1231
1232         for index, comment_after in self.comments:
1233             if _index == index:
1234                 yield comment_after
1235
1236     def remove_trailing_comma(self) -> None:
1237         """Remove the trailing comma and moves the comments attached to it."""
1238         comma_index = len(self.leaves) - 1
1239         for i in range(len(self.comments)):
1240             comment_index, comment = self.comments[i]
1241             if comment_index == comma_index:
1242                 self.comments[i] = (comma_index - 1, comment)
1243         self.leaves.pop()
1244
1245     def is_complex_subscript(self, leaf: Leaf) -> bool:
1246         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1247         open_lsqb = self.bracket_tracker.get_open_lsqb()
1248         if open_lsqb is None:
1249             return False
1250
1251         subscript_start = open_lsqb.next_sibling
1252
1253         if isinstance(subscript_start, Node):
1254             if subscript_start.type == syms.listmaker:
1255                 return False
1256
1257             if subscript_start.type == syms.subscriptlist:
1258                 subscript_start = child_towards(subscript_start, leaf)
1259         return subscript_start is not None and any(
1260             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1261         )
1262
1263     def __str__(self) -> str:
1264         """Render the line."""
1265         if not self:
1266             return "\n"
1267
1268         indent = "    " * self.depth
1269         leaves = iter(self.leaves)
1270         first = next(leaves)
1271         res = f"{first.prefix}{indent}{first.value}"
1272         for leaf in leaves:
1273             res += str(leaf)
1274         for _, comment in self.comments:
1275             res += str(comment)
1276         return res + "\n"
1277
1278     def __bool__(self) -> bool:
1279         """Return True if the line has leaves or comments."""
1280         return bool(self.leaves or self.comments)
1281
1282
1283 @dataclass
1284 class EmptyLineTracker:
1285     """Provides a stateful method that returns the number of potential extra
1286     empty lines needed before and after the currently processed line.
1287
1288     Note: this tracker works on lines that haven't been split yet.  It assumes
1289     the prefix of the first leaf consists of optional newlines.  Those newlines
1290     are consumed by `maybe_empty_lines()` and included in the computation.
1291     """
1292
1293     is_pyi: bool = False
1294     previous_line: Optional[Line] = None
1295     previous_after: int = 0
1296     previous_defs: List[int] = Factory(list)
1297
1298     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1299         """Return the number of extra empty lines before and after the `current_line`.
1300
1301         This is for separating `def`, `async def` and `class` with extra empty
1302         lines (two on module-level).
1303         """
1304         before, after = self._maybe_empty_lines(current_line)
1305         before -= self.previous_after
1306         self.previous_after = after
1307         self.previous_line = current_line
1308         return before, after
1309
1310     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1311         max_allowed = 1
1312         if current_line.depth == 0:
1313             max_allowed = 1 if self.is_pyi else 2
1314         if current_line.leaves:
1315             # Consume the first leaf's extra newlines.
1316             first_leaf = current_line.leaves[0]
1317             before = first_leaf.prefix.count("\n")
1318             before = min(before, max_allowed)
1319             first_leaf.prefix = ""
1320         else:
1321             before = 0
1322         depth = current_line.depth
1323         while self.previous_defs and self.previous_defs[-1] >= depth:
1324             self.previous_defs.pop()
1325             if self.is_pyi:
1326                 before = 0 if depth else 1
1327             else:
1328                 before = 1 if depth else 2
1329         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1330             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1331
1332         if (
1333             self.previous_line
1334             and self.previous_line.is_import
1335             and not current_line.is_import
1336             and depth == self.previous_line.depth
1337         ):
1338             return (before or 1), 0
1339
1340         if (
1341             self.previous_line
1342             and self.previous_line.is_class
1343             and current_line.is_triple_quoted_string
1344         ):
1345             return before, 1
1346
1347         return before, 0
1348
1349     def _maybe_empty_lines_for_class_or_def(
1350         self, current_line: Line, before: int
1351     ) -> Tuple[int, int]:
1352         if not current_line.is_decorator:
1353             self.previous_defs.append(current_line.depth)
1354         if self.previous_line is None:
1355             # Don't insert empty lines before the first line in the file.
1356             return 0, 0
1357
1358         if self.previous_line.is_decorator:
1359             return 0, 0
1360
1361         if self.previous_line.depth < current_line.depth and (
1362             self.previous_line.is_class or self.previous_line.is_def
1363         ):
1364             return 0, 0
1365
1366         if (
1367             self.previous_line.is_comment
1368             and self.previous_line.depth == current_line.depth
1369             and before == 0
1370         ):
1371             return 0, 0
1372
1373         if self.is_pyi:
1374             if self.previous_line.depth > current_line.depth:
1375                 newlines = 1
1376             elif current_line.is_class or self.previous_line.is_class:
1377                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1378                     # No blank line between classes with an emty body
1379                     newlines = 0
1380                 else:
1381                     newlines = 1
1382             elif current_line.is_def and not self.previous_line.is_def:
1383                 # Blank line between a block of functions and a block of non-functions
1384                 newlines = 1
1385             else:
1386                 newlines = 0
1387         else:
1388             newlines = 2
1389         if current_line.depth and newlines:
1390             newlines -= 1
1391         return newlines, 0
1392
1393
1394 @dataclass
1395 class LineGenerator(Visitor[Line]):
1396     """Generates reformatted Line objects.  Empty lines are not emitted.
1397
1398     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1399     in ways that will no longer stringify to valid Python code on the tree.
1400     """
1401
1402     is_pyi: bool = False
1403     normalize_strings: bool = True
1404     current_line: Line = Factory(Line)
1405     remove_u_prefix: bool = False
1406
1407     def line(self, indent: int = 0) -> Iterator[Line]:
1408         """Generate a line.
1409
1410         If the line is empty, only emit if it makes sense.
1411         If the line is too long, split it first and then generate.
1412
1413         If any lines were generated, set up a new current_line.
1414         """
1415         if not self.current_line:
1416             self.current_line.depth += indent
1417             return  # Line is empty, don't emit. Creating a new one unnecessary.
1418
1419         complete_line = self.current_line
1420         self.current_line = Line(depth=complete_line.depth + indent)
1421         yield complete_line
1422
1423     def visit_default(self, node: LN) -> Iterator[Line]:
1424         """Default `visit_*()` implementation. Recurses to children of `node`."""
1425         if isinstance(node, Leaf):
1426             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1427             for comment in generate_comments(node):
1428                 if any_open_brackets:
1429                     # any comment within brackets is subject to splitting
1430                     self.current_line.append(comment)
1431                 elif comment.type == token.COMMENT:
1432                     # regular trailing comment
1433                     self.current_line.append(comment)
1434                     yield from self.line()
1435
1436                 else:
1437                     # regular standalone comment
1438                     yield from self.line()
1439
1440                     self.current_line.append(comment)
1441                     yield from self.line()
1442
1443             normalize_prefix(node, inside_brackets=any_open_brackets)
1444             if self.normalize_strings and node.type == token.STRING:
1445                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1446                 normalize_string_quotes(node)
1447             if node.type not in WHITESPACE:
1448                 self.current_line.append(node)
1449         yield from super().visit_default(node)
1450
1451     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1452         """Increase indentation level, maybe yield a line."""
1453         # In blib2to3 INDENT never holds comments.
1454         yield from self.line(+1)
1455         yield from self.visit_default(node)
1456
1457     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1458         """Decrease indentation level, maybe yield a line."""
1459         # The current line might still wait for trailing comments.  At DEDENT time
1460         # there won't be any (they would be prefixes on the preceding NEWLINE).
1461         # Emit the line then.
1462         yield from self.line()
1463
1464         # While DEDENT has no value, its prefix may contain standalone comments
1465         # that belong to the current indentation level.  Get 'em.
1466         yield from self.visit_default(node)
1467
1468         # Finally, emit the dedent.
1469         yield from self.line(-1)
1470
1471     def visit_stmt(
1472         self, node: Node, keywords: Set[str], parens: Set[str]
1473     ) -> Iterator[Line]:
1474         """Visit a statement.
1475
1476         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1477         `def`, `with`, `class`, `assert` and assignments.
1478
1479         The relevant Python language `keywords` for a given statement will be
1480         NAME leaves within it. This methods puts those on a separate line.
1481
1482         `parens` holds a set of string leaf values immediately after which
1483         invisible parens should be put.
1484         """
1485         normalize_invisible_parens(node, parens_after=parens)
1486         for child in node.children:
1487             if child.type == token.NAME and child.value in keywords:  # type: ignore
1488                 yield from self.line()
1489
1490             yield from self.visit(child)
1491
1492     def visit_suite(self, node: Node) -> Iterator[Line]:
1493         """Visit a suite."""
1494         if self.is_pyi and is_stub_suite(node):
1495             yield from self.visit(node.children[2])
1496         else:
1497             yield from self.visit_default(node)
1498
1499     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1500         """Visit a statement without nested statements."""
1501         is_suite_like = node.parent and node.parent.type in STATEMENT
1502         if is_suite_like:
1503             if self.is_pyi and is_stub_body(node):
1504                 yield from self.visit_default(node)
1505             else:
1506                 yield from self.line(+1)
1507                 yield from self.visit_default(node)
1508                 yield from self.line(-1)
1509
1510         else:
1511             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1512                 yield from self.line()
1513             yield from self.visit_default(node)
1514
1515     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1516         """Visit `async def`, `async for`, `async with`."""
1517         yield from self.line()
1518
1519         children = iter(node.children)
1520         for child in children:
1521             yield from self.visit(child)
1522
1523             if child.type == token.ASYNC:
1524                 break
1525
1526         internal_stmt = next(children)
1527         for child in internal_stmt.children:
1528             yield from self.visit(child)
1529
1530     def visit_decorators(self, node: Node) -> Iterator[Line]:
1531         """Visit decorators."""
1532         for child in node.children:
1533             yield from self.line()
1534             yield from self.visit(child)
1535
1536     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1537         """Remove a semicolon and put the other statement on a separate line."""
1538         yield from self.line()
1539
1540     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1541         """End of file. Process outstanding comments and end with a newline."""
1542         yield from self.visit_default(leaf)
1543         yield from self.line()
1544
1545     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1546         if not self.current_line.bracket_tracker.any_open_brackets():
1547             yield from self.line()
1548         yield from self.visit_default(leaf)
1549
1550     def __attrs_post_init__(self) -> None:
1551         """You are in a twisty little maze of passages."""
1552         v = self.visit_stmt
1553         Ø: Set[str] = set()
1554         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1555         self.visit_if_stmt = partial(
1556             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1557         )
1558         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1559         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1560         self.visit_try_stmt = partial(
1561             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1562         )
1563         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1564         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1565         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1566         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1567         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1568         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1569         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1570         self.visit_async_funcdef = self.visit_async_stmt
1571         self.visit_decorated = self.visit_decorators
1572
1573
1574 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1575 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1576 OPENING_BRACKETS = set(BRACKET.keys())
1577 CLOSING_BRACKETS = set(BRACKET.values())
1578 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1579 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1580
1581
1582 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1583     """Return whitespace prefix if needed for the given `leaf`.
1584
1585     `complex_subscript` signals whether the given leaf is part of a subscription
1586     which has non-trivial arguments, like arithmetic expressions or function calls.
1587     """
1588     NO = ""
1589     SPACE = " "
1590     DOUBLESPACE = "  "
1591     t = leaf.type
1592     p = leaf.parent
1593     v = leaf.value
1594     if t in ALWAYS_NO_SPACE:
1595         return NO
1596
1597     if t == token.COMMENT:
1598         return DOUBLESPACE
1599
1600     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1601     if t == token.COLON and p.type not in {
1602         syms.subscript,
1603         syms.subscriptlist,
1604         syms.sliceop,
1605     }:
1606         return NO
1607
1608     prev = leaf.prev_sibling
1609     if not prev:
1610         prevp = preceding_leaf(p)
1611         if not prevp or prevp.type in OPENING_BRACKETS:
1612             return NO
1613
1614         if t == token.COLON:
1615             if prevp.type == token.COLON:
1616                 return NO
1617
1618             elif prevp.type != token.COMMA and not complex_subscript:
1619                 return NO
1620
1621             return SPACE
1622
1623         if prevp.type == token.EQUAL:
1624             if prevp.parent:
1625                 if prevp.parent.type in {
1626                     syms.arglist,
1627                     syms.argument,
1628                     syms.parameters,
1629                     syms.varargslist,
1630                 }:
1631                     return NO
1632
1633                 elif prevp.parent.type == syms.typedargslist:
1634                     # A bit hacky: if the equal sign has whitespace, it means we
1635                     # previously found it's a typed argument.  So, we're using
1636                     # that, too.
1637                     return prevp.prefix
1638
1639         elif prevp.type in STARS:
1640             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1641                 return NO
1642
1643         elif prevp.type == token.COLON:
1644             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1645                 return SPACE if complex_subscript else NO
1646
1647         elif (
1648             prevp.parent
1649             and prevp.parent.type == syms.factor
1650             and prevp.type in MATH_OPERATORS
1651         ):
1652             return NO
1653
1654         elif (
1655             prevp.type == token.RIGHTSHIFT
1656             and prevp.parent
1657             and prevp.parent.type == syms.shift_expr
1658             and prevp.prev_sibling
1659             and prevp.prev_sibling.type == token.NAME
1660             and prevp.prev_sibling.value == "print"  # type: ignore
1661         ):
1662             # Python 2 print chevron
1663             return NO
1664
1665     elif prev.type in OPENING_BRACKETS:
1666         return NO
1667
1668     if p.type in {syms.parameters, syms.arglist}:
1669         # untyped function signatures or calls
1670         if not prev or prev.type != token.COMMA:
1671             return NO
1672
1673     elif p.type == syms.varargslist:
1674         # lambdas
1675         if prev and prev.type != token.COMMA:
1676             return NO
1677
1678     elif p.type == syms.typedargslist:
1679         # typed function signatures
1680         if not prev:
1681             return NO
1682
1683         if t == token.EQUAL:
1684             if prev.type != syms.tname:
1685                 return NO
1686
1687         elif prev.type == token.EQUAL:
1688             # A bit hacky: if the equal sign has whitespace, it means we
1689             # previously found it's a typed argument.  So, we're using that, too.
1690             return prev.prefix
1691
1692         elif prev.type != token.COMMA:
1693             return NO
1694
1695     elif p.type == syms.tname:
1696         # type names
1697         if not prev:
1698             prevp = preceding_leaf(p)
1699             if not prevp or prevp.type != token.COMMA:
1700                 return NO
1701
1702     elif p.type == syms.trailer:
1703         # attributes and calls
1704         if t == token.LPAR or t == token.RPAR:
1705             return NO
1706
1707         if not prev:
1708             if t == token.DOT:
1709                 prevp = preceding_leaf(p)
1710                 if not prevp or prevp.type != token.NUMBER:
1711                     return NO
1712
1713             elif t == token.LSQB:
1714                 return NO
1715
1716         elif prev.type != token.COMMA:
1717             return NO
1718
1719     elif p.type == syms.argument:
1720         # single argument
1721         if t == token.EQUAL:
1722             return NO
1723
1724         if not prev:
1725             prevp = preceding_leaf(p)
1726             if not prevp or prevp.type == token.LPAR:
1727                 return NO
1728
1729         elif prev.type in {token.EQUAL} | STARS:
1730             return NO
1731
1732     elif p.type == syms.decorator:
1733         # decorators
1734         return NO
1735
1736     elif p.type == syms.dotted_name:
1737         if prev:
1738             return NO
1739
1740         prevp = preceding_leaf(p)
1741         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1742             return NO
1743
1744     elif p.type == syms.classdef:
1745         if t == token.LPAR:
1746             return NO
1747
1748         if prev and prev.type == token.LPAR:
1749             return NO
1750
1751     elif p.type in {syms.subscript, syms.sliceop}:
1752         # indexing
1753         if not prev:
1754             assert p.parent is not None, "subscripts are always parented"
1755             if p.parent.type == syms.subscriptlist:
1756                 return SPACE
1757
1758             return NO
1759
1760         elif not complex_subscript:
1761             return NO
1762
1763     elif p.type == syms.atom:
1764         if prev and t == token.DOT:
1765             # dots, but not the first one.
1766             return NO
1767
1768     elif p.type == syms.dictsetmaker:
1769         # dict unpacking
1770         if prev and prev.type == token.DOUBLESTAR:
1771             return NO
1772
1773     elif p.type in {syms.factor, syms.star_expr}:
1774         # unary ops
1775         if not prev:
1776             prevp = preceding_leaf(p)
1777             if not prevp or prevp.type in OPENING_BRACKETS:
1778                 return NO
1779
1780             prevp_parent = prevp.parent
1781             assert prevp_parent is not None
1782             if prevp.type == token.COLON and prevp_parent.type in {
1783                 syms.subscript,
1784                 syms.sliceop,
1785             }:
1786                 return NO
1787
1788             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1789                 return NO
1790
1791         elif t in {token.NAME, token.NUMBER, token.STRING}:
1792             return NO
1793
1794     elif p.type == syms.import_from:
1795         if t == token.DOT:
1796             if prev and prev.type == token.DOT:
1797                 return NO
1798
1799         elif t == token.NAME:
1800             if v == "import":
1801                 return SPACE
1802
1803             if prev and prev.type == token.DOT:
1804                 return NO
1805
1806     elif p.type == syms.sliceop:
1807         return NO
1808
1809     return SPACE
1810
1811
1812 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1813     """Return the first leaf that precedes `node`, if any."""
1814     while node:
1815         res = node.prev_sibling
1816         if res:
1817             if isinstance(res, Leaf):
1818                 return res
1819
1820             try:
1821                 return list(res.leaves())[-1]
1822
1823             except IndexError:
1824                 return None
1825
1826         node = node.parent
1827     return None
1828
1829
1830 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1831     """Return the child of `ancestor` that contains `descendant`."""
1832     node: Optional[LN] = descendant
1833     while node and node.parent != ancestor:
1834         node = node.parent
1835     return node
1836
1837
1838 def container_of(leaf: Leaf) -> LN:
1839     """Return `leaf` or one of its ancestors that is the topmost container of it.
1840
1841     By "container" we mean a node where `leaf` is the very first child.
1842     """
1843     same_prefix = leaf.prefix
1844     container: LN = leaf
1845     while container:
1846         parent = container.parent
1847         if parent is None:
1848             break
1849
1850         if parent.children[0].prefix != same_prefix:
1851             break
1852
1853         if parent.type == syms.file_input:
1854             break
1855
1856         if parent.type in SURROUNDED_BY_BRACKETS:
1857             break
1858
1859         container = parent
1860     return container
1861
1862
1863 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1864     """Return the priority of the `leaf` delimiter, given a line break after it.
1865
1866     The delimiter priorities returned here are from those delimiters that would
1867     cause a line break after themselves.
1868
1869     Higher numbers are higher priority.
1870     """
1871     if leaf.type == token.COMMA:
1872         return COMMA_PRIORITY
1873
1874     return 0
1875
1876
1877 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1878     """Return the priority of the `leaf` delimiter, given a line before after it.
1879
1880     The delimiter priorities returned here are from those delimiters that would
1881     cause a line break before themselves.
1882
1883     Higher numbers are higher priority.
1884     """
1885     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1886         # * and ** might also be MATH_OPERATORS but in this case they are not.
1887         # Don't treat them as a delimiter.
1888         return 0
1889
1890     if (
1891         leaf.type == token.DOT
1892         and leaf.parent
1893         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1894         and (previous is None or previous.type in CLOSING_BRACKETS)
1895     ):
1896         return DOT_PRIORITY
1897
1898     if (
1899         leaf.type in MATH_OPERATORS
1900         and leaf.parent
1901         and leaf.parent.type not in {syms.factor, syms.star_expr}
1902     ):
1903         return MATH_PRIORITIES[leaf.type]
1904
1905     if leaf.type in COMPARATORS:
1906         return COMPARATOR_PRIORITY
1907
1908     if (
1909         leaf.type == token.STRING
1910         and previous is not None
1911         and previous.type == token.STRING
1912     ):
1913         return STRING_PRIORITY
1914
1915     if leaf.type != token.NAME:
1916         return 0
1917
1918     if (
1919         leaf.value == "for"
1920         and leaf.parent
1921         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1922     ):
1923         return COMPREHENSION_PRIORITY
1924
1925     if (
1926         leaf.value == "if"
1927         and leaf.parent
1928         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1929     ):
1930         return COMPREHENSION_PRIORITY
1931
1932     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1933         return TERNARY_PRIORITY
1934
1935     if leaf.value == "is":
1936         return COMPARATOR_PRIORITY
1937
1938     if (
1939         leaf.value == "in"
1940         and leaf.parent
1941         and leaf.parent.type in {syms.comp_op, syms.comparison}
1942         and not (
1943             previous is not None
1944             and previous.type == token.NAME
1945             and previous.value == "not"
1946         )
1947     ):
1948         return COMPARATOR_PRIORITY
1949
1950     if (
1951         leaf.value == "not"
1952         and leaf.parent
1953         and leaf.parent.type == syms.comp_op
1954         and not (
1955             previous is not None
1956             and previous.type == token.NAME
1957             and previous.value == "is"
1958         )
1959     ):
1960         return COMPARATOR_PRIORITY
1961
1962     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1963         return LOGIC_PRIORITY
1964
1965     return 0
1966
1967
1968 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
1969 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
1970
1971
1972 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1973     """Clean the prefix of the `leaf` and generate comments from it, if any.
1974
1975     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1976     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1977     move because it does away with modifying the grammar to include all the
1978     possible places in which comments can be placed.
1979
1980     The sad consequence for us though is that comments don't "belong" anywhere.
1981     This is why this function generates simple parentless Leaf objects for
1982     comments.  We simply don't know what the correct parent should be.
1983
1984     No matter though, we can live without this.  We really only need to
1985     differentiate between inline and standalone comments.  The latter don't
1986     share the line with any code.
1987
1988     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1989     are emitted with a fake STANDALONE_COMMENT token identifier.
1990     """
1991     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
1992         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
1993
1994
1995 @dataclass
1996 class ProtoComment:
1997     type: int  # token.COMMENT or STANDALONE_COMMENT
1998     value: str  # content of the comment
1999     newlines: int  # how many newlines before the comment
2000     consumed: int  # how many characters of the original leaf's prefix did we consume
2001
2002
2003 @lru_cache(maxsize=4096)
2004 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2005     result: List[ProtoComment] = []
2006     if not prefix or "#" not in prefix:
2007         return result
2008
2009     consumed = 0
2010     nlines = 0
2011     for index, line in enumerate(prefix.split("\n")):
2012         consumed += len(line) + 1  # adding the length of the split '\n'
2013         line = line.lstrip()
2014         if not line:
2015             nlines += 1
2016         if not line.startswith("#"):
2017             continue
2018
2019         if index == 0 and not is_endmarker:
2020             comment_type = token.COMMENT  # simple trailing comment
2021         else:
2022             comment_type = STANDALONE_COMMENT
2023         comment = make_comment(line)
2024         result.append(
2025             ProtoComment(
2026                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2027             )
2028         )
2029         nlines = 0
2030     return result
2031
2032
2033 def make_comment(content: str) -> str:
2034     """Return a consistently formatted comment from the given `content` string.
2035
2036     All comments (except for "##", "#!", "#:") should have a single space between
2037     the hash sign and the content.
2038
2039     If `content` didn't start with a hash sign, one is provided.
2040     """
2041     content = content.rstrip()
2042     if not content:
2043         return "#"
2044
2045     if content[0] == "#":
2046         content = content[1:]
2047     if content and content[0] not in " !:#":
2048         content = " " + content
2049     return "#" + content
2050
2051
2052 def split_line(
2053     line: Line, line_length: int, inner: bool = False, py36: bool = False
2054 ) -> Iterator[Line]:
2055     """Split a `line` into potentially many lines.
2056
2057     They should fit in the allotted `line_length` but might not be able to.
2058     `inner` signifies that there were a pair of brackets somewhere around the
2059     current `line`, possibly transitively. This means we can fallback to splitting
2060     by delimiters if the LHS/RHS don't yield any results.
2061
2062     If `py36` is True, splitting may generate syntax that is only compatible
2063     with Python 3.6 and later.
2064     """
2065     if line.is_comment:
2066         yield line
2067         return
2068
2069     line_str = str(line).strip("\n")
2070     if not line.should_explode and is_line_short_enough(
2071         line, line_length=line_length, line_str=line_str
2072     ):
2073         yield line
2074         return
2075
2076     split_funcs: List[SplitFunc]
2077     if line.is_def:
2078         split_funcs = [left_hand_split]
2079     else:
2080
2081         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2082             for omit in generate_trailers_to_omit(line, line_length):
2083                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2084                 if is_line_short_enough(lines[0], line_length=line_length):
2085                     yield from lines
2086                     return
2087
2088             # All splits failed, best effort split with no omits.
2089             # This mostly happens to multiline strings that are by definition
2090             # reported as not fitting a single line.
2091             yield from right_hand_split(line, py36)
2092
2093         if line.inside_brackets:
2094             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2095         else:
2096             split_funcs = [rhs]
2097     for split_func in split_funcs:
2098         # We are accumulating lines in `result` because we might want to abort
2099         # mission and return the original line in the end, or attempt a different
2100         # split altogether.
2101         result: List[Line] = []
2102         try:
2103             for l in split_func(line, py36):
2104                 if str(l).strip("\n") == line_str:
2105                     raise CannotSplit("Split function returned an unchanged result")
2106
2107                 result.extend(
2108                     split_line(l, line_length=line_length, inner=True, py36=py36)
2109                 )
2110         except CannotSplit as cs:
2111             continue
2112
2113         else:
2114             yield from result
2115             break
2116
2117     else:
2118         yield line
2119
2120
2121 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2122     """Split line into many lines, starting with the first matching bracket pair.
2123
2124     Note: this usually looks weird, only use this for function definitions.
2125     Prefer RHS otherwise.  This is why this function is not symmetrical with
2126     :func:`right_hand_split` which also handles optional parentheses.
2127     """
2128     head = Line(depth=line.depth)
2129     body = Line(depth=line.depth + 1, inside_brackets=True)
2130     tail = Line(depth=line.depth)
2131     tail_leaves: List[Leaf] = []
2132     body_leaves: List[Leaf] = []
2133     head_leaves: List[Leaf] = []
2134     current_leaves = head_leaves
2135     matching_bracket = None
2136     for leaf in line.leaves:
2137         if (
2138             current_leaves is body_leaves
2139             and leaf.type in CLOSING_BRACKETS
2140             and leaf.opening_bracket is matching_bracket
2141         ):
2142             current_leaves = tail_leaves if body_leaves else head_leaves
2143         current_leaves.append(leaf)
2144         if current_leaves is head_leaves:
2145             if leaf.type in OPENING_BRACKETS:
2146                 matching_bracket = leaf
2147                 current_leaves = body_leaves
2148     # Since body is a new indent level, remove spurious leading whitespace.
2149     if body_leaves:
2150         normalize_prefix(body_leaves[0], inside_brackets=True)
2151     # Build the new lines.
2152     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2153         for leaf in leaves:
2154             result.append(leaf, preformatted=True)
2155             for comment_after in line.comments_after(leaf):
2156                 result.append(comment_after, preformatted=True)
2157     bracket_split_succeeded_or_raise(head, body, tail)
2158     for result in (head, body, tail):
2159         if result:
2160             yield result
2161
2162
2163 def right_hand_split(
2164     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2165 ) -> Iterator[Line]:
2166     """Split line into many lines, starting with the last matching bracket pair.
2167
2168     If the split was by optional parentheses, attempt splitting without them, too.
2169     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2170     this split.
2171
2172     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2173     """
2174     head = Line(depth=line.depth)
2175     body = Line(depth=line.depth + 1, inside_brackets=True)
2176     tail = Line(depth=line.depth)
2177     tail_leaves: List[Leaf] = []
2178     body_leaves: List[Leaf] = []
2179     head_leaves: List[Leaf] = []
2180     current_leaves = tail_leaves
2181     opening_bracket = None
2182     closing_bracket = None
2183     for leaf in reversed(line.leaves):
2184         if current_leaves is body_leaves:
2185             if leaf is opening_bracket:
2186                 current_leaves = head_leaves if body_leaves else tail_leaves
2187         current_leaves.append(leaf)
2188         if current_leaves is tail_leaves:
2189             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2190                 opening_bracket = leaf.opening_bracket
2191                 closing_bracket = leaf
2192                 current_leaves = body_leaves
2193     tail_leaves.reverse()
2194     body_leaves.reverse()
2195     head_leaves.reverse()
2196     # Since body is a new indent level, remove spurious leading whitespace.
2197     if body_leaves:
2198         normalize_prefix(body_leaves[0], inside_brackets=True)
2199     if not head_leaves:
2200         # No `head` means the split failed. Either `tail` has all content or
2201         # the matching `opening_bracket` wasn't available on `line` anymore.
2202         raise CannotSplit("No brackets found")
2203
2204     # Build the new lines.
2205     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2206         for leaf in leaves:
2207             result.append(leaf, preformatted=True)
2208             for comment_after in line.comments_after(leaf):
2209                 result.append(comment_after, preformatted=True)
2210     assert opening_bracket and closing_bracket
2211     body.should_explode = should_explode(body, opening_bracket)
2212     bracket_split_succeeded_or_raise(head, body, tail)
2213     if (
2214         # the body shouldn't be exploded
2215         not body.should_explode
2216         # the opening bracket is an optional paren
2217         and opening_bracket.type == token.LPAR
2218         and not opening_bracket.value
2219         # the closing bracket is an optional paren
2220         and closing_bracket.type == token.RPAR
2221         and not closing_bracket.value
2222         # it's not an import (optional parens are the only thing we can split on
2223         # in this case; attempting a split without them is a waste of time)
2224         and not line.is_import
2225         # there are no standalone comments in the body
2226         and not body.contains_standalone_comments(0)
2227         # and we can actually remove the parens
2228         and can_omit_invisible_parens(body, line_length)
2229     ):
2230         omit = {id(closing_bracket), *omit}
2231         try:
2232             yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2233             return
2234
2235         except CannotSplit:
2236             if not (
2237                 can_be_split(body)
2238                 or is_line_short_enough(body, line_length=line_length)
2239             ):
2240                 raise CannotSplit(
2241                     "Splitting failed, body is still too long and can't be split."
2242                 )
2243
2244             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2245                 raise CannotSplit(
2246                     "The current optional pair of parentheses is bound to fail to "
2247                     "satisfy the splitting algorithm because the head or the tail "
2248                     "contains multiline strings which by definition never fit one "
2249                     "line."
2250                 )
2251
2252     ensure_visible(opening_bracket)
2253     ensure_visible(closing_bracket)
2254     for result in (head, body, tail):
2255         if result:
2256             yield result
2257
2258
2259 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2260     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2261
2262     Do nothing otherwise.
2263
2264     A left- or right-hand split is based on a pair of brackets. Content before
2265     (and including) the opening bracket is left on one line, content inside the
2266     brackets is put on a separate line, and finally content starting with and
2267     following the closing bracket is put on a separate line.
2268
2269     Those are called `head`, `body`, and `tail`, respectively. If the split
2270     produced the same line (all content in `head`) or ended up with an empty `body`
2271     and the `tail` is just the closing bracket, then it's considered failed.
2272     """
2273     tail_len = len(str(tail).strip())
2274     if not body:
2275         if tail_len == 0:
2276             raise CannotSplit("Splitting brackets produced the same line")
2277
2278         elif tail_len < 3:
2279             raise CannotSplit(
2280                 f"Splitting brackets on an empty body to save "
2281                 f"{tail_len} characters is not worth it"
2282             )
2283
2284
2285 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2286     """Normalize prefix of the first leaf in every line returned by `split_func`.
2287
2288     This is a decorator over relevant split functions.
2289     """
2290
2291     @wraps(split_func)
2292     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2293         for l in split_func(line, py36):
2294             normalize_prefix(l.leaves[0], inside_brackets=True)
2295             yield l
2296
2297     return split_wrapper
2298
2299
2300 @dont_increase_indentation
2301 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2302     """Split according to delimiters of the highest priority.
2303
2304     If `py36` is True, the split will add trailing commas also in function
2305     signatures that contain `*` and `**`.
2306     """
2307     try:
2308         last_leaf = line.leaves[-1]
2309     except IndexError:
2310         raise CannotSplit("Line empty")
2311
2312     bt = line.bracket_tracker
2313     try:
2314         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2315     except ValueError:
2316         raise CannotSplit("No delimiters found")
2317
2318     if delimiter_priority == DOT_PRIORITY:
2319         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2320             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2321
2322     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2323     lowest_depth = sys.maxsize
2324     trailing_comma_safe = True
2325
2326     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2327         """Append `leaf` to current line or to new line if appending impossible."""
2328         nonlocal current_line
2329         try:
2330             current_line.append_safe(leaf, preformatted=True)
2331         except ValueError as ve:
2332             yield current_line
2333
2334             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2335             current_line.append(leaf)
2336
2337     for index, leaf in enumerate(line.leaves):
2338         yield from append_to_line(leaf)
2339
2340         for comment_after in line.comments_after(leaf, index):
2341             yield from append_to_line(comment_after)
2342
2343         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2344         if leaf.bracket_depth == lowest_depth and is_vararg(
2345             leaf, within=VARARGS_PARENTS
2346         ):
2347             trailing_comma_safe = trailing_comma_safe and py36
2348         leaf_priority = bt.delimiters.get(id(leaf))
2349         if leaf_priority == delimiter_priority:
2350             yield current_line
2351
2352             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2353     if current_line:
2354         if (
2355             trailing_comma_safe
2356             and delimiter_priority == COMMA_PRIORITY
2357             and current_line.leaves[-1].type != token.COMMA
2358             and current_line.leaves[-1].type != STANDALONE_COMMENT
2359         ):
2360             current_line.append(Leaf(token.COMMA, ","))
2361         yield current_line
2362
2363
2364 @dont_increase_indentation
2365 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2366     """Split standalone comments from the rest of the line."""
2367     if not line.contains_standalone_comments(0):
2368         raise CannotSplit("Line does not have any standalone comments")
2369
2370     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2371
2372     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2373         """Append `leaf` to current line or to new line if appending impossible."""
2374         nonlocal current_line
2375         try:
2376             current_line.append_safe(leaf, preformatted=True)
2377         except ValueError as ve:
2378             yield current_line
2379
2380             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2381             current_line.append(leaf)
2382
2383     for index, leaf in enumerate(line.leaves):
2384         yield from append_to_line(leaf)
2385
2386         for comment_after in line.comments_after(leaf, index):
2387             yield from append_to_line(comment_after)
2388
2389     if current_line:
2390         yield current_line
2391
2392
2393 def is_import(leaf: Leaf) -> bool:
2394     """Return True if the given leaf starts an import statement."""
2395     p = leaf.parent
2396     t = leaf.type
2397     v = leaf.value
2398     return bool(
2399         t == token.NAME
2400         and (
2401             (v == "import" and p and p.type == syms.import_name)
2402             or (v == "from" and p and p.type == syms.import_from)
2403         )
2404     )
2405
2406
2407 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2408     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2409     else.
2410
2411     Note: don't use backslashes for formatting or you'll lose your voting rights.
2412     """
2413     if not inside_brackets:
2414         spl = leaf.prefix.split("#")
2415         if "\\" not in spl[0]:
2416             nl_count = spl[-1].count("\n")
2417             if len(spl) > 1:
2418                 nl_count -= 1
2419             leaf.prefix = "\n" * nl_count
2420             return
2421
2422     leaf.prefix = ""
2423
2424
2425 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2426     """Make all string prefixes lowercase.
2427
2428     If remove_u_prefix is given, also removes any u prefix from the string.
2429
2430     Note: Mutates its argument.
2431     """
2432     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2433     assert match is not None, f"failed to match string {leaf.value!r}"
2434     orig_prefix = match.group(1)
2435     new_prefix = orig_prefix.lower()
2436     if remove_u_prefix:
2437         new_prefix = new_prefix.replace("u", "")
2438     leaf.value = f"{new_prefix}{match.group(2)}"
2439
2440
2441 def normalize_string_quotes(leaf: Leaf) -> None:
2442     """Prefer double quotes but only if it doesn't cause more escaping.
2443
2444     Adds or removes backslashes as appropriate. Doesn't parse and fix
2445     strings nested in f-strings (yet).
2446
2447     Note: Mutates its argument.
2448     """
2449     value = leaf.value.lstrip("furbFURB")
2450     if value[:3] == '"""':
2451         return
2452
2453     elif value[:3] == "'''":
2454         orig_quote = "'''"
2455         new_quote = '"""'
2456     elif value[0] == '"':
2457         orig_quote = '"'
2458         new_quote = "'"
2459     else:
2460         orig_quote = "'"
2461         new_quote = '"'
2462     first_quote_pos = leaf.value.find(orig_quote)
2463     if first_quote_pos == -1:
2464         return  # There's an internal error
2465
2466     prefix = leaf.value[:first_quote_pos]
2467     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2468     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2469     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2470     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2471     if "r" in prefix.casefold():
2472         if unescaped_new_quote.search(body):
2473             # There's at least one unescaped new_quote in this raw string
2474             # so converting is impossible
2475             return
2476
2477         # Do not introduce or remove backslashes in raw strings
2478         new_body = body
2479     else:
2480         # remove unnecessary escapes
2481         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2482         if body != new_body:
2483             # Consider the string without unnecessary escapes as the original
2484             body = new_body
2485             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2486         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2487         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2488     if "f" in prefix.casefold():
2489         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2490         for m in matches:
2491             if "\\" in str(m):
2492                 # Do not introduce backslashes in interpolated expressions
2493                 return
2494     if new_quote == '"""' and new_body[-1:] == '"':
2495         # edge case:
2496         new_body = new_body[:-1] + '\\"'
2497     orig_escape_count = body.count("\\")
2498     new_escape_count = new_body.count("\\")
2499     if new_escape_count > orig_escape_count:
2500         return  # Do not introduce more escaping
2501
2502     if new_escape_count == orig_escape_count and orig_quote == '"':
2503         return  # Prefer double quotes
2504
2505     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2506
2507
2508 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2509     """Make existing optional parentheses invisible or create new ones.
2510
2511     `parens_after` is a set of string leaf values immeditely after which parens
2512     should be put.
2513
2514     Standardizes on visible parentheses for single-element tuples, and keeps
2515     existing visible parentheses for other tuples and generator expressions.
2516     """
2517     for pc in list_comments(node.prefix, is_endmarker=False):
2518         if pc.value in FMT_OFF:
2519             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2520             return
2521
2522     check_lpar = False
2523     for index, child in enumerate(list(node.children)):
2524         if check_lpar:
2525             if child.type == syms.atom:
2526                 maybe_make_parens_invisible_in_atom(child)
2527             elif is_one_tuple(child):
2528                 # wrap child in visible parentheses
2529                 lpar = Leaf(token.LPAR, "(")
2530                 rpar = Leaf(token.RPAR, ")")
2531                 child.remove()
2532                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2533             elif node.type == syms.import_from:
2534                 # "import from" nodes store parentheses directly as part of
2535                 # the statement
2536                 if child.type == token.LPAR:
2537                     # make parentheses invisible
2538                     child.value = ""  # type: ignore
2539                     node.children[-1].value = ""  # type: ignore
2540                 elif child.type != token.STAR:
2541                     # insert invisible parentheses
2542                     node.insert_child(index, Leaf(token.LPAR, ""))
2543                     node.append_child(Leaf(token.RPAR, ""))
2544                 break
2545
2546             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2547                 # wrap child in invisible parentheses
2548                 lpar = Leaf(token.LPAR, "")
2549                 rpar = Leaf(token.RPAR, "")
2550                 index = child.remove() or 0
2551                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2552
2553         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2554
2555
2556 def normalize_fmt_off(node: Node) -> None:
2557     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2558     try_again = True
2559     while try_again:
2560         try_again = convert_one_fmt_off_pair(node)
2561
2562
2563 def convert_one_fmt_off_pair(node: Node) -> bool:
2564     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2565
2566     Returns True if a pair was converted.
2567     """
2568     for leaf in node.leaves():
2569         previous_consumed = 0
2570         for comment in list_comments(leaf.prefix, is_endmarker=False):
2571             if comment.value in FMT_OFF:
2572                 # We only want standalone comments. If there's no previous leaf or
2573                 # the previous leaf is indentation, it's a standalone comment in
2574                 # disguise.
2575                 if comment.type != STANDALONE_COMMENT:
2576                     prev = preceding_leaf(leaf)
2577                     if prev and prev.type not in WHITESPACE:
2578                         continue
2579
2580                 ignored_nodes = list(generate_ignored_nodes(leaf))
2581                 if not ignored_nodes:
2582                     continue
2583
2584                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2585                 parent = first.parent
2586                 prefix = first.prefix
2587                 first.prefix = prefix[comment.consumed :]
2588                 hidden_value = (
2589                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2590                 )
2591                 if hidden_value.endswith("\n"):
2592                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2593                     # leaf (possibly followed by a DEDENT).
2594                     hidden_value = hidden_value[:-1]
2595                 first_idx = None
2596                 for ignored in ignored_nodes:
2597                     index = ignored.remove()
2598                     if first_idx is None:
2599                         first_idx = index
2600                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2601                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2602                 parent.insert_child(
2603                     first_idx,
2604                     Leaf(
2605                         STANDALONE_COMMENT,
2606                         hidden_value,
2607                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2608                     ),
2609                 )
2610                 return True
2611
2612             previous_consumed = comment.consumed
2613
2614     return False
2615
2616
2617 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2618     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2619
2620     Stops at the end of the block.
2621     """
2622     container: Optional[LN] = container_of(leaf)
2623     while container is not None and container.type != token.ENDMARKER:
2624         for comment in list_comments(container.prefix, is_endmarker=False):
2625             if comment.value in FMT_ON:
2626                 return
2627
2628         yield container
2629
2630         container = container.next_sibling
2631
2632
2633 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2634     """If it's safe, make the parens in the atom `node` invisible, recursively."""
2635     if (
2636         node.type != syms.atom
2637         or is_empty_tuple(node)
2638         or is_one_tuple(node)
2639         or is_yield(node)
2640         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2641     ):
2642         return False
2643
2644     first = node.children[0]
2645     last = node.children[-1]
2646     if first.type == token.LPAR and last.type == token.RPAR:
2647         # make parentheses invisible
2648         first.value = ""  # type: ignore
2649         last.value = ""  # type: ignore
2650         if len(node.children) > 1:
2651             maybe_make_parens_invisible_in_atom(node.children[1])
2652         return True
2653
2654     return False
2655
2656
2657 def is_empty_tuple(node: LN) -> bool:
2658     """Return True if `node` holds an empty tuple."""
2659     return (
2660         node.type == syms.atom
2661         and len(node.children) == 2
2662         and node.children[0].type == token.LPAR
2663         and node.children[1].type == token.RPAR
2664     )
2665
2666
2667 def is_one_tuple(node: LN) -> bool:
2668     """Return True if `node` holds a tuple with one element, with or without parens."""
2669     if node.type == syms.atom:
2670         if len(node.children) != 3:
2671             return False
2672
2673         lpar, gexp, rpar = node.children
2674         if not (
2675             lpar.type == token.LPAR
2676             and gexp.type == syms.testlist_gexp
2677             and rpar.type == token.RPAR
2678         ):
2679             return False
2680
2681         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2682
2683     return (
2684         node.type in IMPLICIT_TUPLE
2685         and len(node.children) == 2
2686         and node.children[1].type == token.COMMA
2687     )
2688
2689
2690 def is_yield(node: LN) -> bool:
2691     """Return True if `node` holds a `yield` or `yield from` expression."""
2692     if node.type == syms.yield_expr:
2693         return True
2694
2695     if node.type == token.NAME and node.value == "yield":  # type: ignore
2696         return True
2697
2698     if node.type != syms.atom:
2699         return False
2700
2701     if len(node.children) != 3:
2702         return False
2703
2704     lpar, expr, rpar = node.children
2705     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2706         return is_yield(expr)
2707
2708     return False
2709
2710
2711 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2712     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2713
2714     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2715     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2716     extended iterable unpacking (PEP 3132) and additional unpacking
2717     generalizations (PEP 448).
2718     """
2719     if leaf.type not in STARS or not leaf.parent:
2720         return False
2721
2722     p = leaf.parent
2723     if p.type == syms.star_expr:
2724         # Star expressions are also used as assignment targets in extended
2725         # iterable unpacking (PEP 3132).  See what its parent is instead.
2726         if not p.parent:
2727             return False
2728
2729         p = p.parent
2730
2731     return p.type in within
2732
2733
2734 def is_multiline_string(leaf: Leaf) -> bool:
2735     """Return True if `leaf` is a multiline string that actually spans many lines."""
2736     value = leaf.value.lstrip("furbFURB")
2737     return value[:3] in {'"""', "'''"} and "\n" in value
2738
2739
2740 def is_stub_suite(node: Node) -> bool:
2741     """Return True if `node` is a suite with a stub body."""
2742     if (
2743         len(node.children) != 4
2744         or node.children[0].type != token.NEWLINE
2745         or node.children[1].type != token.INDENT
2746         or node.children[3].type != token.DEDENT
2747     ):
2748         return False
2749
2750     return is_stub_body(node.children[2])
2751
2752
2753 def is_stub_body(node: LN) -> bool:
2754     """Return True if `node` is a simple statement containing an ellipsis."""
2755     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2756         return False
2757
2758     if len(node.children) != 2:
2759         return False
2760
2761     child = node.children[0]
2762     return (
2763         child.type == syms.atom
2764         and len(child.children) == 3
2765         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2766     )
2767
2768
2769 def max_delimiter_priority_in_atom(node: LN) -> int:
2770     """Return maximum delimiter priority inside `node`.
2771
2772     This is specific to atoms with contents contained in a pair of parentheses.
2773     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2774     """
2775     if node.type != syms.atom:
2776         return 0
2777
2778     first = node.children[0]
2779     last = node.children[-1]
2780     if not (first.type == token.LPAR and last.type == token.RPAR):
2781         return 0
2782
2783     bt = BracketTracker()
2784     for c in node.children[1:-1]:
2785         if isinstance(c, Leaf):
2786             bt.mark(c)
2787         else:
2788             for leaf in c.leaves():
2789                 bt.mark(leaf)
2790     try:
2791         return bt.max_delimiter_priority()
2792
2793     except ValueError:
2794         return 0
2795
2796
2797 def ensure_visible(leaf: Leaf) -> None:
2798     """Make sure parentheses are visible.
2799
2800     They could be invisible as part of some statements (see
2801     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2802     """
2803     if leaf.type == token.LPAR:
2804         leaf.value = "("
2805     elif leaf.type == token.RPAR:
2806         leaf.value = ")"
2807
2808
2809 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2810     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2811     if not (
2812         opening_bracket.parent
2813         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2814         and opening_bracket.value in "[{("
2815     ):
2816         return False
2817
2818     try:
2819         last_leaf = line.leaves[-1]
2820         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2821         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2822     except (IndexError, ValueError):
2823         return False
2824
2825     return max_priority == COMMA_PRIORITY
2826
2827
2828 def is_python36(node: Node) -> bool:
2829     """Return True if the current file is using Python 3.6+ features.
2830
2831     Currently looking for:
2832     - f-strings; and
2833     - trailing commas after * or ** in function signatures and calls.
2834     """
2835     for n in node.pre_order():
2836         if n.type == token.STRING:
2837             value_head = n.value[:2]  # type: ignore
2838             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2839                 return True
2840
2841         elif (
2842             n.type in {syms.typedargslist, syms.arglist}
2843             and n.children
2844             and n.children[-1].type == token.COMMA
2845         ):
2846             for ch in n.children:
2847                 if ch.type in STARS:
2848                     return True
2849
2850                 if ch.type == syms.argument:
2851                     for argch in ch.children:
2852                         if argch.type in STARS:
2853                             return True
2854
2855     return False
2856
2857
2858 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2859     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2860
2861     Brackets can be omitted if the entire trailer up to and including
2862     a preceding closing bracket fits in one line.
2863
2864     Yielded sets are cumulative (contain results of previous yields, too).  First
2865     set is empty.
2866     """
2867
2868     omit: Set[LeafID] = set()
2869     yield omit
2870
2871     length = 4 * line.depth
2872     opening_bracket = None
2873     closing_bracket = None
2874     optional_brackets: Set[LeafID] = set()
2875     inner_brackets: Set[LeafID] = set()
2876     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2877         length += leaf_length
2878         if length > line_length:
2879             break
2880
2881         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2882         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2883             break
2884
2885         optional_brackets.discard(id(leaf))
2886         if opening_bracket:
2887             if leaf is opening_bracket:
2888                 opening_bracket = None
2889             elif leaf.type in CLOSING_BRACKETS:
2890                 inner_brackets.add(id(leaf))
2891         elif leaf.type in CLOSING_BRACKETS:
2892             if not leaf.value:
2893                 optional_brackets.add(id(opening_bracket))
2894                 continue
2895
2896             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2897                 # Empty brackets would fail a split so treat them as "inner"
2898                 # brackets (e.g. only add them to the `omit` set if another
2899                 # pair of brackets was good enough.
2900                 inner_brackets.add(id(leaf))
2901                 continue
2902
2903             opening_bracket = leaf.opening_bracket
2904             if closing_bracket:
2905                 omit.add(id(closing_bracket))
2906                 omit.update(inner_brackets)
2907                 inner_brackets.clear()
2908                 yield omit
2909             closing_bracket = leaf
2910
2911
2912 def get_future_imports(node: Node) -> Set[str]:
2913     """Return a set of __future__ imports in the file."""
2914     imports: Set[str] = set()
2915
2916     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
2917         for child in children:
2918             if isinstance(child, Leaf):
2919                 if child.type == token.NAME:
2920                     yield child.value
2921             elif child.type == syms.import_as_name:
2922                 orig_name = child.children[0]
2923                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
2924                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
2925                 yield orig_name.value
2926             elif child.type == syms.import_as_names:
2927                 yield from get_imports_from_children(child.children)
2928             else:
2929                 assert False, "Invalid syntax parsing imports"
2930
2931     for child in node.children:
2932         if child.type != syms.simple_stmt:
2933             break
2934         first_child = child.children[0]
2935         if isinstance(first_child, Leaf):
2936             # Continue looking if we see a docstring; otherwise stop.
2937             if (
2938                 len(child.children) == 2
2939                 and first_child.type == token.STRING
2940                 and child.children[1].type == token.NEWLINE
2941             ):
2942                 continue
2943             else:
2944                 break
2945         elif first_child.type == syms.import_from:
2946             module_name = first_child.children[1]
2947             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2948                 break
2949             imports |= set(get_imports_from_children(first_child.children[3:]))
2950         else:
2951             break
2952     return imports
2953
2954
2955 def gen_python_files_in_dir(
2956     path: Path,
2957     root: Path,
2958     include: Pattern[str],
2959     exclude: Pattern[str],
2960     report: "Report",
2961 ) -> Iterator[Path]:
2962     """Generate all files under `path` whose paths are not excluded by the
2963     `exclude` regex, but are included by the `include` regex.
2964
2965     Symbolic links pointing outside of the `root` directory are ignored.
2966
2967     `report` is where output about exclusions goes.
2968     """
2969     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
2970     for child in path.iterdir():
2971         try:
2972             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
2973         except ValueError:
2974             if child.is_symlink():
2975                 report.path_ignored(
2976                     child, f"is a symbolic link that points outside {root}"
2977                 )
2978                 continue
2979
2980             raise
2981
2982         if child.is_dir():
2983             normalized_path += "/"
2984         exclude_match = exclude.search(normalized_path)
2985         if exclude_match and exclude_match.group(0):
2986             report.path_ignored(child, f"matches the --exclude regular expression")
2987             continue
2988
2989         if child.is_dir():
2990             yield from gen_python_files_in_dir(child, root, include, exclude, report)
2991
2992         elif child.is_file():
2993             include_match = include.search(normalized_path)
2994             if include_match:
2995                 yield child
2996
2997
2998 @lru_cache()
2999 def find_project_root(srcs: Iterable[str]) -> Path:
3000     """Return a directory containing .git, .hg, or pyproject.toml.
3001
3002     That directory can be one of the directories passed in `srcs` or their
3003     common parent.
3004
3005     If no directory in the tree contains a marker that would specify it's the
3006     project root, the root of the file system is returned.
3007     """
3008     if not srcs:
3009         return Path("/").resolve()
3010
3011     common_base = min(Path(src).resolve() for src in srcs)
3012     if common_base.is_dir():
3013         # Append a fake file so `parents` below returns `common_base_dir`, too.
3014         common_base /= "fake-file"
3015     for directory in common_base.parents:
3016         if (directory / ".git").is_dir():
3017             return directory
3018
3019         if (directory / ".hg").is_dir():
3020             return directory
3021
3022         if (directory / "pyproject.toml").is_file():
3023             return directory
3024
3025     return directory
3026
3027
3028 @dataclass
3029 class Report:
3030     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3031
3032     check: bool = False
3033     quiet: bool = False
3034     verbose: bool = False
3035     change_count: int = 0
3036     same_count: int = 0
3037     failure_count: int = 0
3038
3039     def done(self, src: Path, changed: Changed) -> None:
3040         """Increment the counter for successful reformatting. Write out a message."""
3041         if changed is Changed.YES:
3042             reformatted = "would reformat" if self.check else "reformatted"
3043             if self.verbose or not self.quiet:
3044                 out(f"{reformatted} {src}")
3045             self.change_count += 1
3046         else:
3047             if self.verbose:
3048                 if changed is Changed.NO:
3049                     msg = f"{src} already well formatted, good job."
3050                 else:
3051                     msg = f"{src} wasn't modified on disk since last run."
3052                 out(msg, bold=False)
3053             self.same_count += 1
3054
3055     def failed(self, src: Path, message: str) -> None:
3056         """Increment the counter for failed reformatting. Write out a message."""
3057         err(f"error: cannot format {src}: {message}")
3058         self.failure_count += 1
3059
3060     def path_ignored(self, path: Path, message: str) -> None:
3061         if self.verbose:
3062             out(f"{path} ignored: {message}", bold=False)
3063
3064     @property
3065     def return_code(self) -> int:
3066         """Return the exit code that the app should use.
3067
3068         This considers the current state of changed files and failures:
3069         - if there were any failures, return 123;
3070         - if any files were changed and --check is being used, return 1;
3071         - otherwise return 0.
3072         """
3073         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3074         # 126 we have special returncodes reserved by the shell.
3075         if self.failure_count:
3076             return 123
3077
3078         elif self.change_count and self.check:
3079             return 1
3080
3081         return 0
3082
3083     def __str__(self) -> str:
3084         """Render a color report of the current state.
3085
3086         Use `click.unstyle` to remove colors.
3087         """
3088         if self.check:
3089             reformatted = "would be reformatted"
3090             unchanged = "would be left unchanged"
3091             failed = "would fail to reformat"
3092         else:
3093             reformatted = "reformatted"
3094             unchanged = "left unchanged"
3095             failed = "failed to reformat"
3096         report = []
3097         if self.change_count:
3098             s = "s" if self.change_count > 1 else ""
3099             report.append(
3100                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3101             )
3102         if self.same_count:
3103             s = "s" if self.same_count > 1 else ""
3104             report.append(f"{self.same_count} file{s} {unchanged}")
3105         if self.failure_count:
3106             s = "s" if self.failure_count > 1 else ""
3107             report.append(
3108                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3109             )
3110         return ", ".join(report) + "."
3111
3112
3113 def assert_equivalent(src: str, dst: str) -> None:
3114     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3115
3116     import ast
3117     import traceback
3118
3119     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3120         """Simple visitor generating strings to compare ASTs by content."""
3121         yield f"{'  ' * depth}{node.__class__.__name__}("
3122
3123         for field in sorted(node._fields):
3124             try:
3125                 value = getattr(node, field)
3126             except AttributeError:
3127                 continue
3128
3129             yield f"{'  ' * (depth+1)}{field}="
3130
3131             if isinstance(value, list):
3132                 for item in value:
3133                     if isinstance(item, ast.AST):
3134                         yield from _v(item, depth + 2)
3135
3136             elif isinstance(value, ast.AST):
3137                 yield from _v(value, depth + 2)
3138
3139             else:
3140                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3141
3142         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3143
3144     try:
3145         src_ast = ast.parse(src)
3146     except Exception as exc:
3147         major, minor = sys.version_info[:2]
3148         raise AssertionError(
3149             f"cannot use --safe with this file; failed to parse source file "
3150             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3151             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3152         )
3153
3154     try:
3155         dst_ast = ast.parse(dst)
3156     except Exception as exc:
3157         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3158         raise AssertionError(
3159             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3160             f"Please report a bug on https://github.com/ambv/black/issues.  "
3161             f"This invalid output might be helpful: {log}"
3162         ) from None
3163
3164     src_ast_str = "\n".join(_v(src_ast))
3165     dst_ast_str = "\n".join(_v(dst_ast))
3166     if src_ast_str != dst_ast_str:
3167         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3168         raise AssertionError(
3169             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3170             f"the source.  "
3171             f"Please report a bug on https://github.com/ambv/black/issues.  "
3172             f"This diff might be helpful: {log}"
3173         ) from None
3174
3175
3176 def assert_stable(
3177     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3178 ) -> None:
3179     """Raise AssertionError if `dst` reformats differently the second time."""
3180     newdst = format_str(dst, line_length=line_length, mode=mode)
3181     if dst != newdst:
3182         log = dump_to_file(
3183             diff(src, dst, "source", "first pass"),
3184             diff(dst, newdst, "first pass", "second pass"),
3185         )
3186         raise AssertionError(
3187             f"INTERNAL ERROR: Black produced different code on the second pass "
3188             f"of the formatter.  "
3189             f"Please report a bug on https://github.com/ambv/black/issues.  "
3190             f"This diff might be helpful: {log}"
3191         ) from None
3192
3193
3194 def dump_to_file(*output: str) -> str:
3195     """Dump `output` to a temporary file. Return path to the file."""
3196     import tempfile
3197
3198     with tempfile.NamedTemporaryFile(
3199         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3200     ) as f:
3201         for lines in output:
3202             f.write(lines)
3203             if lines and lines[-1] != "\n":
3204                 f.write("\n")
3205     return f.name
3206
3207
3208 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3209     """Return a unified diff string between strings `a` and `b`."""
3210     import difflib
3211
3212     a_lines = [line + "\n" for line in a.split("\n")]
3213     b_lines = [line + "\n" for line in b.split("\n")]
3214     return "".join(
3215         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3216     )
3217
3218
3219 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3220     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3221     err("Aborted!")
3222     for task in tasks:
3223         task.cancel()
3224
3225
3226 def shutdown(loop: BaseEventLoop) -> None:
3227     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3228     try:
3229         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3230         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3231         if not to_cancel:
3232             return
3233
3234         for task in to_cancel:
3235             task.cancel()
3236         loop.run_until_complete(
3237             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3238         )
3239     finally:
3240         # `concurrent.futures.Future` objects cannot be cancelled once they
3241         # are already running. There might be some when the `shutdown()` happened.
3242         # Silence their logger's spew about the event loop being closed.
3243         cf_logger = logging.getLogger("concurrent.futures")
3244         cf_logger.setLevel(logging.CRITICAL)
3245         loop.close()
3246
3247
3248 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3249     """Replace `regex` with `replacement` twice on `original`.
3250
3251     This is used by string normalization to perform replaces on
3252     overlapping matches.
3253     """
3254     return regex.sub(replacement, regex.sub(replacement, original))
3255
3256
3257 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3258     """Compile a regular expression string in `regex`.
3259
3260     If it contains newlines, use verbose mode.
3261     """
3262     if "\n" in regex:
3263         regex = "(?x)" + regex
3264     return re.compile(regex)
3265
3266
3267 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3268     """Like `reversed(enumerate(sequence))` if that were possible."""
3269     index = len(sequence) - 1
3270     for element in reversed(sequence):
3271         yield (index, element)
3272         index -= 1
3273
3274
3275 def enumerate_with_length(
3276     line: Line, reversed: bool = False
3277 ) -> Iterator[Tuple[Index, Leaf, int]]:
3278     """Return an enumeration of leaves with their length.
3279
3280     Stops prematurely on multiline strings and standalone comments.
3281     """
3282     op = cast(
3283         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3284         enumerate_reversed if reversed else enumerate,
3285     )
3286     for index, leaf in op(line.leaves):
3287         length = len(leaf.prefix) + len(leaf.value)
3288         if "\n" in leaf.value:
3289             return  # Multiline strings, we can't continue.
3290
3291         comment: Optional[Leaf]
3292         for comment in line.comments_after(leaf, index):
3293             length += len(comment.value)
3294
3295         yield index, leaf, length
3296
3297
3298 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3299     """Return True if `line` is no longer than `line_length`.
3300
3301     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3302     """
3303     if not line_str:
3304         line_str = str(line).strip("\n")
3305     return (
3306         len(line_str) <= line_length
3307         and "\n" not in line_str  # multiline strings
3308         and not line.contains_standalone_comments()
3309     )
3310
3311
3312 def can_be_split(line: Line) -> bool:
3313     """Return False if the line cannot be split *for sure*.
3314
3315     This is not an exhaustive search but a cheap heuristic that we can use to
3316     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3317     in unnecessary parentheses).
3318     """
3319     leaves = line.leaves
3320     if len(leaves) < 2:
3321         return False
3322
3323     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3324         call_count = 0
3325         dot_count = 0
3326         next = leaves[-1]
3327         for leaf in leaves[-2::-1]:
3328             if leaf.type in OPENING_BRACKETS:
3329                 if next.type not in CLOSING_BRACKETS:
3330                     return False
3331
3332                 call_count += 1
3333             elif leaf.type == token.DOT:
3334                 dot_count += 1
3335             elif leaf.type == token.NAME:
3336                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3337                     return False
3338
3339             elif leaf.type not in CLOSING_BRACKETS:
3340                 return False
3341
3342             if dot_count > 1 and call_count > 1:
3343                 return False
3344
3345     return True
3346
3347
3348 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3349     """Does `line` have a shape safe to reformat without optional parens around it?
3350
3351     Returns True for only a subset of potentially nice looking formattings but
3352     the point is to not return false positives that end up producing lines that
3353     are too long.
3354     """
3355     bt = line.bracket_tracker
3356     if not bt.delimiters:
3357         # Without delimiters the optional parentheses are useless.
3358         return True
3359
3360     max_priority = bt.max_delimiter_priority()
3361     if bt.delimiter_count_with_priority(max_priority) > 1:
3362         # With more than one delimiter of a kind the optional parentheses read better.
3363         return False
3364
3365     if max_priority == DOT_PRIORITY:
3366         # A single stranded method call doesn't require optional parentheses.
3367         return True
3368
3369     assert len(line.leaves) >= 2, "Stranded delimiter"
3370
3371     first = line.leaves[0]
3372     second = line.leaves[1]
3373     penultimate = line.leaves[-2]
3374     last = line.leaves[-1]
3375
3376     # With a single delimiter, omit if the expression starts or ends with
3377     # a bracket.
3378     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3379         remainder = False
3380         length = 4 * line.depth
3381         for _index, leaf, leaf_length in enumerate_with_length(line):
3382             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3383                 remainder = True
3384             if remainder:
3385                 length += leaf_length
3386                 if length > line_length:
3387                     break
3388
3389                 if leaf.type in OPENING_BRACKETS:
3390                     # There are brackets we can further split on.
3391                     remainder = False
3392
3393         else:
3394             # checked the entire string and line length wasn't exceeded
3395             if len(line.leaves) == _index + 1:
3396                 return True
3397
3398         # Note: we are not returning False here because a line might have *both*
3399         # a leading opening bracket and a trailing closing bracket.  If the
3400         # opening bracket doesn't match our rule, maybe the closing will.
3401
3402     if (
3403         last.type == token.RPAR
3404         or last.type == token.RBRACE
3405         or (
3406             # don't use indexing for omitting optional parentheses;
3407             # it looks weird
3408             last.type == token.RSQB
3409             and last.parent
3410             and last.parent.type != syms.trailer
3411         )
3412     ):
3413         if penultimate.type in OPENING_BRACKETS:
3414             # Empty brackets don't help.
3415             return False
3416
3417         if is_multiline_string(first):
3418             # Additional wrapping of a multiline string in this situation is
3419             # unnecessary.
3420             return True
3421
3422         length = 4 * line.depth
3423         seen_other_brackets = False
3424         for _index, leaf, leaf_length in enumerate_with_length(line):
3425             length += leaf_length
3426             if leaf is last.opening_bracket:
3427                 if seen_other_brackets or length <= line_length:
3428                     return True
3429
3430             elif leaf.type in OPENING_BRACKETS:
3431                 # There are brackets we can further split on.
3432                 seen_other_brackets = True
3433
3434     return False
3435
3436
3437 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3438     return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
3439
3440
3441 def read_cache(line_length: int, mode: FileMode) -> Cache:
3442     """Read the cache if it exists and is well formed.
3443
3444     If it is not well formed, the call to write_cache later should resolve the issue.
3445     """
3446     cache_file = get_cache_file(line_length, mode)
3447     if not cache_file.exists():
3448         return {}
3449
3450     with cache_file.open("rb") as fobj:
3451         try:
3452             cache: Cache = pickle.load(fobj)
3453         except pickle.UnpicklingError:
3454             return {}
3455
3456     return cache
3457
3458
3459 def get_cache_info(path: Path) -> CacheInfo:
3460     """Return the information used to check if a file is already formatted or not."""
3461     stat = path.stat()
3462     return stat.st_mtime, stat.st_size
3463
3464
3465 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3466     """Split an iterable of paths in `sources` into two sets.
3467
3468     The first contains paths of files that modified on disk or are not in the
3469     cache. The other contains paths to non-modified files.
3470     """
3471     todo, done = set(), set()
3472     for src in sources:
3473         src = src.resolve()
3474         if cache.get(src) != get_cache_info(src):
3475             todo.add(src)
3476         else:
3477             done.add(src)
3478     return todo, done
3479
3480
3481 def write_cache(
3482     cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
3483 ) -> None:
3484     """Update the cache file."""
3485     cache_file = get_cache_file(line_length, mode)
3486     try:
3487         if not CACHE_DIR.exists():
3488             CACHE_DIR.mkdir(parents=True)
3489         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3490         with cache_file.open("wb") as fobj:
3491             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3492     except OSError:
3493         pass
3494
3495
3496 def patch_click() -> None:
3497     """Make Click not crash.
3498
3499     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3500     default which restricts paths that it can access during the lifetime of the
3501     application.  Click refuses to work in this scenario by raising a RuntimeError.
3502
3503     In case of Black the likelihood that non-ASCII characters are going to be used in
3504     file paths is minimal since it's Python source code.  Moreover, this crash was
3505     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3506     """
3507     try:
3508         from click import core
3509         from click import _unicodefun  # type: ignore
3510     except ModuleNotFoundError:
3511         return
3512
3513     for module in (core, _unicodefun):
3514         if hasattr(module, "_verify_python3_env"):
3515             module._verify_python3_env = lambda: None
3516
3517
3518 if __name__ == "__main__":
3519     patch_click()
3520     main()