]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

b9eca0adbaf29bc170e77f595b5c905945819e82
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum, Flag
6 from functools import lru_cache, partial, wraps
7 import io
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tokenize
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Sequence,
30     Set,
31     Tuple,
32     TypeVar,
33     Union,
34     cast,
35 )
36
37 from appdirs import user_cache_dir
38 from attr import dataclass, Factory
39 import click
40 import toml
41
42 # lib2to3 fork
43 from blib2to3.pytree import Node, Leaf, type_repr
44 from blib2to3 import pygram, pytree
45 from blib2to3.pgen2 import driver, token
46 from blib2to3.pgen2.parse import ParseError
47
48
49 __version__ = "18.6b2"
50 DEFAULT_LINE_LENGTH = 88
51 DEFAULT_EXCLUDES = (
52     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
53 )
54 DEFAULT_INCLUDES = r"\.pyi?$"
55 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
56
57
58 # types
59 FileContent = str
60 Encoding = str
61 NewLine = str
62 Depth = int
63 NodeType = int
64 LeafID = int
65 Priority = int
66 Index = int
67 LN = Union[Leaf, Node]
68 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
69 Timestamp = float
70 FileSize = int
71 CacheInfo = Tuple[Timestamp, FileSize]
72 Cache = Dict[Path, CacheInfo]
73 out = partial(click.secho, bold=True, err=True)
74 err = partial(click.secho, fg="red", err=True)
75
76 pygram.initialize(CACHE_DIR)
77 syms = pygram.python_symbols
78
79
80 class NothingChanged(UserWarning):
81     """Raised by :func:`format_file` when reformatted code is the same as source."""
82
83
84 class CannotSplit(Exception):
85     """A readable split that fits the allotted line length is impossible.
86
87     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
88     :func:`delimiter_split`.
89     """
90
91
92 class WriteBack(Enum):
93     NO = 0
94     YES = 1
95     DIFF = 2
96
97     @classmethod
98     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
99         if check and not diff:
100             return cls.NO
101
102         return cls.DIFF if diff else cls.YES
103
104
105 class Changed(Enum):
106     NO = 0
107     CACHED = 1
108     YES = 2
109
110
111 class FileMode(Flag):
112     AUTO_DETECT = 0
113     PYTHON36 = 1
114     PYI = 2
115     NO_STRING_NORMALIZATION = 4
116
117     @classmethod
118     def from_configuration(
119         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
120     ) -> "FileMode":
121         mode = cls.AUTO_DETECT
122         if py36:
123             mode |= cls.PYTHON36
124         if pyi:
125             mode |= cls.PYI
126         if skip_string_normalization:
127             mode |= cls.NO_STRING_NORMALIZATION
128         return mode
129
130
131 def read_pyproject_toml(
132     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
133 ) -> Optional[str]:
134     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
135
136     Returns the path to a successfully found and read configuration file, None
137     otherwise.
138     """
139     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
140     if not value:
141         root = find_project_root(ctx.params.get("src", ()))
142         path = root / "pyproject.toml"
143         if path.is_file():
144             value = str(path)
145         else:
146             return None
147
148     try:
149         pyproject_toml = toml.load(value)
150         config = pyproject_toml.get("tool", {}).get("black", {})
151     except (toml.TomlDecodeError, OSError) as e:
152         raise click.BadOptionUsage(f"Error reading configuration file: {e}", ctx)
153
154     if not config:
155         return None
156
157     if ctx.default_map is None:
158         ctx.default_map = {}
159     ctx.default_map.update(  # type: ignore  # bad types in .pyi
160         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
161     )
162     return value
163
164
165 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
166 @click.option(
167     "-l",
168     "--line-length",
169     type=int,
170     default=DEFAULT_LINE_LENGTH,
171     help="How many character per line to allow.",
172     show_default=True,
173 )
174 @click.option(
175     "--py36",
176     is_flag=True,
177     help=(
178         "Allow using Python 3.6-only syntax on all input files.  This will put "
179         "trailing commas in function signatures and calls also after *args and "
180         "**kwargs.  [default: per-file auto-detection]"
181     ),
182 )
183 @click.option(
184     "--pyi",
185     is_flag=True,
186     help=(
187         "Format all input files like typing stubs regardless of file extension "
188         "(useful when piping source on standard input)."
189     ),
190 )
191 @click.option(
192     "-S",
193     "--skip-string-normalization",
194     is_flag=True,
195     help="Don't normalize string quotes or prefixes.",
196 )
197 @click.option(
198     "--check",
199     is_flag=True,
200     help=(
201         "Don't write the files back, just return the status.  Return code 0 "
202         "means nothing would change.  Return code 1 means some files would be "
203         "reformatted.  Return code 123 means there was an internal error."
204     ),
205 )
206 @click.option(
207     "--diff",
208     is_flag=True,
209     help="Don't write the files back, just output a diff for each file on stdout.",
210 )
211 @click.option(
212     "--fast/--safe",
213     is_flag=True,
214     help="If --fast given, skip temporary sanity checks. [default: --safe]",
215 )
216 @click.option(
217     "--include",
218     type=str,
219     default=DEFAULT_INCLUDES,
220     help=(
221         "A regular expression that matches files and directories that should be "
222         "included on recursive searches.  An empty value means all files are "
223         "included regardless of the name.  Use forward slashes for directories on "
224         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
225         "later."
226     ),
227     show_default=True,
228 )
229 @click.option(
230     "--exclude",
231     type=str,
232     default=DEFAULT_EXCLUDES,
233     help=(
234         "A regular expression that matches files and directories that should be "
235         "excluded on recursive searches.  An empty value means no paths are excluded. "
236         "Use forward slashes for directories on all platforms (Windows, too).  "
237         "Exclusions are calculated first, inclusions later."
238     ),
239     show_default=True,
240 )
241 @click.option(
242     "-q",
243     "--quiet",
244     is_flag=True,
245     help=(
246         "Don't emit non-error messages to stderr. Errors are still emitted, "
247         "silence those with 2>/dev/null."
248     ),
249 )
250 @click.option(
251     "-v",
252     "--verbose",
253     is_flag=True,
254     help=(
255         "Also emit messages to stderr about files that were not changed or were "
256         "ignored due to --exclude=."
257     ),
258 )
259 @click.version_option(version=__version__)
260 @click.argument(
261     "src",
262     nargs=-1,
263     type=click.Path(
264         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
265     ),
266     is_eager=True,
267 )
268 @click.option(
269     "--config",
270     type=click.Path(
271         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
272     ),
273     is_eager=True,
274     callback=read_pyproject_toml,
275     help="Read configuration from PATH.",
276 )
277 @click.pass_context
278 def main(
279     ctx: click.Context,
280     line_length: int,
281     check: bool,
282     diff: bool,
283     fast: bool,
284     pyi: bool,
285     py36: bool,
286     skip_string_normalization: bool,
287     quiet: bool,
288     verbose: bool,
289     include: str,
290     exclude: str,
291     src: Tuple[str],
292     config: Optional[str],
293 ) -> None:
294     """The uncompromising code formatter."""
295     write_back = WriteBack.from_configuration(check=check, diff=diff)
296     mode = FileMode.from_configuration(
297         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
298     )
299     if config and verbose:
300         out(f"Using configuration from {config}.", bold=False, fg="blue")
301     try:
302         include_regex = re_compile_maybe_verbose(include)
303     except re.error:
304         err(f"Invalid regular expression for include given: {include!r}")
305         ctx.exit(2)
306     try:
307         exclude_regex = re_compile_maybe_verbose(exclude)
308     except re.error:
309         err(f"Invalid regular expression for exclude given: {exclude!r}")
310         ctx.exit(2)
311     report = Report(check=check, quiet=quiet, verbose=verbose)
312     root = find_project_root(src)
313     sources: Set[Path] = set()
314     for s in src:
315         p = Path(s)
316         if p.is_dir():
317             sources.update(
318                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
319             )
320         elif p.is_file() or s == "-":
321             # if a file was explicitly given, we don't care about its extension
322             sources.add(p)
323         else:
324             err(f"invalid path: {s}")
325     if len(sources) == 0:
326         if verbose or not quiet:
327             out("No paths given. Nothing to do 😴")
328         ctx.exit(0)
329
330     if len(sources) == 1:
331         reformat_one(
332             src=sources.pop(),
333             line_length=line_length,
334             fast=fast,
335             write_back=write_back,
336             mode=mode,
337             report=report,
338         )
339     else:
340         loop = asyncio.get_event_loop()
341         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
342         try:
343             loop.run_until_complete(
344                 schedule_formatting(
345                     sources=sources,
346                     line_length=line_length,
347                     fast=fast,
348                     write_back=write_back,
349                     mode=mode,
350                     report=report,
351                     loop=loop,
352                     executor=executor,
353                 )
354             )
355         finally:
356             shutdown(loop)
357     if verbose or not quiet:
358         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
359         out(f"All done! {bang}")
360         click.secho(str(report), err=True)
361     ctx.exit(report.return_code)
362
363
364 def reformat_one(
365     src: Path,
366     line_length: int,
367     fast: bool,
368     write_back: WriteBack,
369     mode: FileMode,
370     report: "Report",
371 ) -> None:
372     """Reformat a single file under `src` without spawning child processes.
373
374     If `quiet` is True, non-error messages are not output. `line_length`,
375     `write_back`, `fast` and `pyi` options are passed to
376     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
377     """
378     try:
379         changed = Changed.NO
380         if not src.is_file() and str(src) == "-":
381             if format_stdin_to_stdout(
382                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
383             ):
384                 changed = Changed.YES
385         else:
386             cache: Cache = {}
387             if write_back != WriteBack.DIFF:
388                 cache = read_cache(line_length, mode)
389                 res_src = src.resolve()
390                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
391                     changed = Changed.CACHED
392             if changed is not Changed.CACHED and format_file_in_place(
393                 src,
394                 line_length=line_length,
395                 fast=fast,
396                 write_back=write_back,
397                 mode=mode,
398             ):
399                 changed = Changed.YES
400             if write_back == WriteBack.YES and changed is not Changed.NO:
401                 write_cache(cache, [src], line_length, mode)
402         report.done(src, changed)
403     except Exception as exc:
404         report.failed(src, str(exc))
405
406
407 async def schedule_formatting(
408     sources: Set[Path],
409     line_length: int,
410     fast: bool,
411     write_back: WriteBack,
412     mode: FileMode,
413     report: "Report",
414     loop: BaseEventLoop,
415     executor: Executor,
416 ) -> None:
417     """Run formatting of `sources` in parallel using the provided `executor`.
418
419     (Use ProcessPoolExecutors for actual parallelism.)
420
421     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
422     :func:`format_file_in_place`.
423     """
424     cache: Cache = {}
425     if write_back != WriteBack.DIFF:
426         cache = read_cache(line_length, mode)
427         sources, cached = filter_cached(cache, sources)
428         for src in sorted(cached):
429             report.done(src, Changed.CACHED)
430     cancelled = []
431     formatted = []
432     if sources:
433         lock = None
434         if write_back == WriteBack.DIFF:
435             # For diff output, we need locks to ensure we don't interleave output
436             # from different processes.
437             manager = Manager()
438             lock = manager.Lock()
439         tasks = {
440             loop.run_in_executor(
441                 executor,
442                 format_file_in_place,
443                 src,
444                 line_length,
445                 fast,
446                 write_back,
447                 mode,
448                 lock,
449             ): src
450             for src in sorted(sources)
451         }
452         pending: Iterable[asyncio.Task] = tasks.keys()
453         try:
454             loop.add_signal_handler(signal.SIGINT, cancel, pending)
455             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
456         except NotImplementedError:
457             # There are no good alternatives for these on Windows
458             pass
459         while pending:
460             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
461             for task in done:
462                 src = tasks.pop(task)
463                 if task.cancelled():
464                     cancelled.append(task)
465                 elif task.exception():
466                     report.failed(src, str(task.exception()))
467                 else:
468                     formatted.append(src)
469                     report.done(src, Changed.YES if task.result() else Changed.NO)
470     if cancelled:
471         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
472     if write_back == WriteBack.YES and formatted:
473         write_cache(cache, formatted, line_length, mode)
474
475
476 def format_file_in_place(
477     src: Path,
478     line_length: int,
479     fast: bool,
480     write_back: WriteBack = WriteBack.NO,
481     mode: FileMode = FileMode.AUTO_DETECT,
482     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
483 ) -> bool:
484     """Format file under `src` path. Return True if changed.
485
486     If `write_back` is True, write reformatted code back to stdout.
487     `line_length` and `fast` options are passed to :func:`format_file_contents`.
488     """
489     if src.suffix == ".pyi":
490         mode |= FileMode.PYI
491
492     then = datetime.utcfromtimestamp(src.stat().st_mtime)
493     with open(src, "rb") as buf:
494         src_contents, encoding, newline = decode_bytes(buf.read())
495     try:
496         dst_contents = format_file_contents(
497             src_contents, line_length=line_length, fast=fast, mode=mode
498         )
499     except NothingChanged:
500         return False
501
502     if write_back == write_back.YES:
503         with open(src, "w", encoding=encoding, newline=newline) as f:
504             f.write(dst_contents)
505     elif write_back == write_back.DIFF:
506         now = datetime.utcnow()
507         src_name = f"{src}\t{then} +0000"
508         dst_name = f"{src}\t{now} +0000"
509         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
510         if lock:
511             lock.acquire()
512         try:
513             f = io.TextIOWrapper(
514                 sys.stdout.buffer,
515                 encoding=encoding,
516                 newline=newline,
517                 write_through=True,
518             )
519             f.write(diff_contents)
520             f.detach()
521         finally:
522             if lock:
523                 lock.release()
524     return True
525
526
527 def format_stdin_to_stdout(
528     line_length: int,
529     fast: bool,
530     write_back: WriteBack = WriteBack.NO,
531     mode: FileMode = FileMode.AUTO_DETECT,
532 ) -> bool:
533     """Format file on stdin. Return True if changed.
534
535     If `write_back` is True, write reformatted code back to stdout.
536     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
537     :func:`format_file_contents`.
538     """
539     then = datetime.utcnow()
540     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
541     dst = src
542     try:
543         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
544         return True
545
546     except NothingChanged:
547         return False
548
549     finally:
550         f = io.TextIOWrapper(
551             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
552         )
553         if write_back == WriteBack.YES:
554             f.write(dst)
555         elif write_back == WriteBack.DIFF:
556             now = datetime.utcnow()
557             src_name = f"STDIN\t{then} +0000"
558             dst_name = f"STDOUT\t{now} +0000"
559             f.write(diff(src, dst, src_name, dst_name))
560         f.detach()
561
562
563 def format_file_contents(
564     src_contents: str,
565     *,
566     line_length: int,
567     fast: bool,
568     mode: FileMode = FileMode.AUTO_DETECT,
569 ) -> FileContent:
570     """Reformat contents a file and return new contents.
571
572     If `fast` is False, additionally confirm that the reformatted code is
573     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
574     `line_length` is passed to :func:`format_str`.
575     """
576     if src_contents.strip() == "":
577         raise NothingChanged
578
579     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
580     if src_contents == dst_contents:
581         raise NothingChanged
582
583     if not fast:
584         assert_equivalent(src_contents, dst_contents)
585         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
586     return dst_contents
587
588
589 def format_str(
590     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
591 ) -> FileContent:
592     """Reformat a string and return new contents.
593
594     `line_length` determines how many characters per line are allowed.
595     """
596     src_node = lib2to3_parse(src_contents)
597     dst_contents = ""
598     future_imports = get_future_imports(src_node)
599     is_pyi = bool(mode & FileMode.PYI)
600     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
601     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
602     normalize_fmt_off(src_node)
603     lines = LineGenerator(
604         remove_u_prefix=py36 or "unicode_literals" in future_imports,
605         is_pyi=is_pyi,
606         normalize_strings=normalize_strings,
607     )
608     elt = EmptyLineTracker(is_pyi=is_pyi)
609     empty_line = Line()
610     after = 0
611     for current_line in lines.visit(src_node):
612         for _ in range(after):
613             dst_contents += str(empty_line)
614         before, after = elt.maybe_empty_lines(current_line)
615         for _ in range(before):
616             dst_contents += str(empty_line)
617         for line in split_line(current_line, line_length=line_length, py36=py36):
618             dst_contents += str(line)
619     return dst_contents
620
621
622 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
623     """Return a tuple of (decoded_contents, encoding, newline).
624
625     `newline` is either CRLF or LF but `decoded_contents` is decoded with
626     universal newlines (i.e. only contains LF).
627     """
628     srcbuf = io.BytesIO(src)
629     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
630     if not lines:
631         return "", encoding, "\n"
632
633     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
634     srcbuf.seek(0)
635     with io.TextIOWrapper(srcbuf, encoding) as tiow:
636         return tiow.read(), encoding, newline
637
638
639 GRAMMARS = [
640     pygram.python_grammar_no_print_statement_no_exec_statement,
641     pygram.python_grammar_no_print_statement,
642     pygram.python_grammar,
643 ]
644
645
646 def lib2to3_parse(src_txt: str) -> Node:
647     """Given a string with source, return the lib2to3 Node."""
648     grammar = pygram.python_grammar_no_print_statement
649     if src_txt[-1:] != "\n":
650         src_txt += "\n"
651     for grammar in GRAMMARS:
652         drv = driver.Driver(grammar, pytree.convert)
653         try:
654             result = drv.parse_string(src_txt, True)
655             break
656
657         except ParseError as pe:
658             lineno, column = pe.context[1]
659             lines = src_txt.splitlines()
660             try:
661                 faulty_line = lines[lineno - 1]
662             except IndexError:
663                 faulty_line = "<line number missing in source>"
664             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
665     else:
666         raise exc from None
667
668     if isinstance(result, Leaf):
669         result = Node(syms.file_input, [result])
670     return result
671
672
673 def lib2to3_unparse(node: Node) -> str:
674     """Given a lib2to3 node, return its string representation."""
675     code = str(node)
676     return code
677
678
679 T = TypeVar("T")
680
681
682 class Visitor(Generic[T]):
683     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
684
685     def visit(self, node: LN) -> Iterator[T]:
686         """Main method to visit `node` and its children.
687
688         It tries to find a `visit_*()` method for the given `node.type`, like
689         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
690         If no dedicated `visit_*()` method is found, chooses `visit_default()`
691         instead.
692
693         Then yields objects of type `T` from the selected visitor.
694         """
695         if node.type < 256:
696             name = token.tok_name[node.type]
697         else:
698             name = type_repr(node.type)
699         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
700
701     def visit_default(self, node: LN) -> Iterator[T]:
702         """Default `visit_*()` implementation. Recurses to children of `node`."""
703         if isinstance(node, Node):
704             for child in node.children:
705                 yield from self.visit(child)
706
707
708 @dataclass
709 class DebugVisitor(Visitor[T]):
710     tree_depth: int = 0
711
712     def visit_default(self, node: LN) -> Iterator[T]:
713         indent = " " * (2 * self.tree_depth)
714         if isinstance(node, Node):
715             _type = type_repr(node.type)
716             out(f"{indent}{_type}", fg="yellow")
717             self.tree_depth += 1
718             for child in node.children:
719                 yield from self.visit(child)
720
721             self.tree_depth -= 1
722             out(f"{indent}/{_type}", fg="yellow", bold=False)
723         else:
724             _type = token.tok_name.get(node.type, str(node.type))
725             out(f"{indent}{_type}", fg="blue", nl=False)
726             if node.prefix:
727                 # We don't have to handle prefixes for `Node` objects since
728                 # that delegates to the first child anyway.
729                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
730             out(f" {node.value!r}", fg="blue", bold=False)
731
732     @classmethod
733     def show(cls, code: Union[str, Leaf, Node]) -> None:
734         """Pretty-print the lib2to3 AST of a given string of `code`.
735
736         Convenience method for debugging.
737         """
738         v: DebugVisitor[None] = DebugVisitor()
739         if isinstance(code, str):
740             code = lib2to3_parse(code)
741         list(v.visit(code))
742
743
744 KEYWORDS = set(keyword.kwlist)
745 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
746 FLOW_CONTROL = {"return", "raise", "break", "continue"}
747 STATEMENT = {
748     syms.if_stmt,
749     syms.while_stmt,
750     syms.for_stmt,
751     syms.try_stmt,
752     syms.except_clause,
753     syms.with_stmt,
754     syms.funcdef,
755     syms.classdef,
756 }
757 STANDALONE_COMMENT = 153
758 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
759 LOGIC_OPERATORS = {"and", "or"}
760 COMPARATORS = {
761     token.LESS,
762     token.GREATER,
763     token.EQEQUAL,
764     token.NOTEQUAL,
765     token.LESSEQUAL,
766     token.GREATEREQUAL,
767 }
768 MATH_OPERATORS = {
769     token.VBAR,
770     token.CIRCUMFLEX,
771     token.AMPER,
772     token.LEFTSHIFT,
773     token.RIGHTSHIFT,
774     token.PLUS,
775     token.MINUS,
776     token.STAR,
777     token.SLASH,
778     token.DOUBLESLASH,
779     token.PERCENT,
780     token.AT,
781     token.TILDE,
782     token.DOUBLESTAR,
783 }
784 STARS = {token.STAR, token.DOUBLESTAR}
785 VARARGS_PARENTS = {
786     syms.arglist,
787     syms.argument,  # double star in arglist
788     syms.trailer,  # single argument to call
789     syms.typedargslist,
790     syms.varargslist,  # lambdas
791 }
792 UNPACKING_PARENTS = {
793     syms.atom,  # single element of a list or set literal
794     syms.dictsetmaker,
795     syms.listmaker,
796     syms.testlist_gexp,
797     syms.testlist_star_expr,
798 }
799 SURROUNDED_BY_BRACKETS = {
800     syms.typedargslist,
801     syms.arglist,
802     syms.subscriptlist,
803     syms.vfplist,
804     syms.import_as_names,
805     syms.yield_expr,
806     syms.testlist_gexp,
807     syms.testlist_star_expr,
808     syms.listmaker,
809     syms.dictsetmaker,
810 }
811 TEST_DESCENDANTS = {
812     syms.test,
813     syms.lambdef,
814     syms.or_test,
815     syms.and_test,
816     syms.not_test,
817     syms.comparison,
818     syms.star_expr,
819     syms.expr,
820     syms.xor_expr,
821     syms.and_expr,
822     syms.shift_expr,
823     syms.arith_expr,
824     syms.trailer,
825     syms.term,
826     syms.power,
827 }
828 ASSIGNMENTS = {
829     "=",
830     "+=",
831     "-=",
832     "*=",
833     "@=",
834     "/=",
835     "%=",
836     "&=",
837     "|=",
838     "^=",
839     "<<=",
840     ">>=",
841     "**=",
842     "//=",
843 }
844 COMPREHENSION_PRIORITY = 20
845 COMMA_PRIORITY = 18
846 TERNARY_PRIORITY = 16
847 LOGIC_PRIORITY = 14
848 STRING_PRIORITY = 12
849 COMPARATOR_PRIORITY = 10
850 MATH_PRIORITIES = {
851     token.VBAR: 9,
852     token.CIRCUMFLEX: 8,
853     token.AMPER: 7,
854     token.LEFTSHIFT: 6,
855     token.RIGHTSHIFT: 6,
856     token.PLUS: 5,
857     token.MINUS: 5,
858     token.STAR: 4,
859     token.SLASH: 4,
860     token.DOUBLESLASH: 4,
861     token.PERCENT: 4,
862     token.AT: 4,
863     token.TILDE: 3,
864     token.DOUBLESTAR: 2,
865 }
866 DOT_PRIORITY = 1
867
868
869 @dataclass
870 class BracketTracker:
871     """Keeps track of brackets on a line."""
872
873     depth: int = 0
874     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
875     delimiters: Dict[LeafID, Priority] = Factory(dict)
876     previous: Optional[Leaf] = None
877     _for_loop_variable: int = 0
878     _lambda_arguments: int = 0
879
880     def mark(self, leaf: Leaf) -> None:
881         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
882
883         All leaves receive an int `bracket_depth` field that stores how deep
884         within brackets a given leaf is. 0 means there are no enclosing brackets
885         that started on this line.
886
887         If a leaf is itself a closing bracket, it receives an `opening_bracket`
888         field that it forms a pair with. This is a one-directional link to
889         avoid reference cycles.
890
891         If a leaf is a delimiter (a token on which Black can split the line if
892         needed) and it's on depth 0, its `id()` is stored in the tracker's
893         `delimiters` field.
894         """
895         if leaf.type == token.COMMENT:
896             return
897
898         self.maybe_decrement_after_for_loop_variable(leaf)
899         self.maybe_decrement_after_lambda_arguments(leaf)
900         if leaf.type in CLOSING_BRACKETS:
901             self.depth -= 1
902             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
903             leaf.opening_bracket = opening_bracket
904         leaf.bracket_depth = self.depth
905         if self.depth == 0:
906             delim = is_split_before_delimiter(leaf, self.previous)
907             if delim and self.previous is not None:
908                 self.delimiters[id(self.previous)] = delim
909             else:
910                 delim = is_split_after_delimiter(leaf, self.previous)
911                 if delim:
912                     self.delimiters[id(leaf)] = delim
913         if leaf.type in OPENING_BRACKETS:
914             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
915             self.depth += 1
916         self.previous = leaf
917         self.maybe_increment_lambda_arguments(leaf)
918         self.maybe_increment_for_loop_variable(leaf)
919
920     def any_open_brackets(self) -> bool:
921         """Return True if there is an yet unmatched open bracket on the line."""
922         return bool(self.bracket_match)
923
924     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
925         """Return the highest priority of a delimiter found on the line.
926
927         Values are consistent with what `is_split_*_delimiter()` return.
928         Raises ValueError on no delimiters.
929         """
930         return max(v for k, v in self.delimiters.items() if k not in exclude)
931
932     def delimiter_count_with_priority(self, priority: int = 0) -> int:
933         """Return the number of delimiters with the given `priority`.
934
935         If no `priority` is passed, defaults to max priority on the line.
936         """
937         if not self.delimiters:
938             return 0
939
940         priority = priority or self.max_delimiter_priority()
941         return sum(1 for p in self.delimiters.values() if p == priority)
942
943     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
944         """In a for loop, or comprehension, the variables are often unpacks.
945
946         To avoid splitting on the comma in this situation, increase the depth of
947         tokens between `for` and `in`.
948         """
949         if leaf.type == token.NAME and leaf.value == "for":
950             self.depth += 1
951             self._for_loop_variable += 1
952             return True
953
954         return False
955
956     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
957         """See `maybe_increment_for_loop_variable` above for explanation."""
958         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
959             self.depth -= 1
960             self._for_loop_variable -= 1
961             return True
962
963         return False
964
965     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
966         """In a lambda expression, there might be more than one argument.
967
968         To avoid splitting on the comma in this situation, increase the depth of
969         tokens between `lambda` and `:`.
970         """
971         if leaf.type == token.NAME and leaf.value == "lambda":
972             self.depth += 1
973             self._lambda_arguments += 1
974             return True
975
976         return False
977
978     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
979         """See `maybe_increment_lambda_arguments` above for explanation."""
980         if self._lambda_arguments and leaf.type == token.COLON:
981             self.depth -= 1
982             self._lambda_arguments -= 1
983             return True
984
985         return False
986
987     def get_open_lsqb(self) -> Optional[Leaf]:
988         """Return the most recent opening square bracket (if any)."""
989         return self.bracket_match.get((self.depth - 1, token.RSQB))
990
991
992 @dataclass
993 class Line:
994     """Holds leaves and comments. Can be printed with `str(line)`."""
995
996     depth: int = 0
997     leaves: List[Leaf] = Factory(list)
998     comments: List[Tuple[Index, Leaf]] = Factory(list)
999     bracket_tracker: BracketTracker = Factory(BracketTracker)
1000     inside_brackets: bool = False
1001     should_explode: bool = False
1002
1003     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1004         """Add a new `leaf` to the end of the line.
1005
1006         Unless `preformatted` is True, the `leaf` will receive a new consistent
1007         whitespace prefix and metadata applied by :class:`BracketTracker`.
1008         Trailing commas are maybe removed, unpacked for loop variables are
1009         demoted from being delimiters.
1010
1011         Inline comments are put aside.
1012         """
1013         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1014         if not has_value:
1015             return
1016
1017         if token.COLON == leaf.type and self.is_class_paren_empty:
1018             del self.leaves[-2:]
1019         if self.leaves and not preformatted:
1020             # Note: at this point leaf.prefix should be empty except for
1021             # imports, for which we only preserve newlines.
1022             leaf.prefix += whitespace(
1023                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1024             )
1025         if self.inside_brackets or not preformatted:
1026             self.bracket_tracker.mark(leaf)
1027             self.maybe_remove_trailing_comma(leaf)
1028         if not self.append_comment(leaf):
1029             self.leaves.append(leaf)
1030
1031     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1032         """Like :func:`append()` but disallow invalid standalone comment structure.
1033
1034         Raises ValueError when any `leaf` is appended after a standalone comment
1035         or when a standalone comment is not the first leaf on the line.
1036         """
1037         if self.bracket_tracker.depth == 0:
1038             if self.is_comment:
1039                 raise ValueError("cannot append to standalone comments")
1040
1041             if self.leaves and leaf.type == STANDALONE_COMMENT:
1042                 raise ValueError(
1043                     "cannot append standalone comments to a populated line"
1044                 )
1045
1046         self.append(leaf, preformatted=preformatted)
1047
1048     @property
1049     def is_comment(self) -> bool:
1050         """Is this line a standalone comment?"""
1051         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1052
1053     @property
1054     def is_decorator(self) -> bool:
1055         """Is this line a decorator?"""
1056         return bool(self) and self.leaves[0].type == token.AT
1057
1058     @property
1059     def is_import(self) -> bool:
1060         """Is this an import line?"""
1061         return bool(self) and is_import(self.leaves[0])
1062
1063     @property
1064     def is_class(self) -> bool:
1065         """Is this line a class definition?"""
1066         return (
1067             bool(self)
1068             and self.leaves[0].type == token.NAME
1069             and self.leaves[0].value == "class"
1070         )
1071
1072     @property
1073     def is_stub_class(self) -> bool:
1074         """Is this line a class definition with a body consisting only of "..."?"""
1075         return self.is_class and self.leaves[-3:] == [
1076             Leaf(token.DOT, ".") for _ in range(3)
1077         ]
1078
1079     @property
1080     def is_def(self) -> bool:
1081         """Is this a function definition? (Also returns True for async defs.)"""
1082         try:
1083             first_leaf = self.leaves[0]
1084         except IndexError:
1085             return False
1086
1087         try:
1088             second_leaf: Optional[Leaf] = self.leaves[1]
1089         except IndexError:
1090             second_leaf = None
1091         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1092             first_leaf.type == token.ASYNC
1093             and second_leaf is not None
1094             and second_leaf.type == token.NAME
1095             and second_leaf.value == "def"
1096         )
1097
1098     @property
1099     def is_class_paren_empty(self) -> bool:
1100         """Is this a class with no base classes but using parentheses?
1101
1102         Those are unnecessary and should be removed.
1103         """
1104         return (
1105             bool(self)
1106             and len(self.leaves) == 4
1107             and self.is_class
1108             and self.leaves[2].type == token.LPAR
1109             and self.leaves[2].value == "("
1110             and self.leaves[3].type == token.RPAR
1111             and self.leaves[3].value == ")"
1112         )
1113
1114     @property
1115     def is_triple_quoted_string(self) -> bool:
1116         """Is the line a triple quoted string?"""
1117         return (
1118             bool(self)
1119             and self.leaves[0].type == token.STRING
1120             and self.leaves[0].value.startswith(('"""', "'''"))
1121         )
1122
1123     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1124         """If so, needs to be split before emitting."""
1125         for leaf in self.leaves:
1126             if leaf.type == STANDALONE_COMMENT:
1127                 if leaf.bracket_depth <= depth_limit:
1128                     return True
1129
1130         return False
1131
1132     def contains_multiline_strings(self) -> bool:
1133         for leaf in self.leaves:
1134             if is_multiline_string(leaf):
1135                 return True
1136
1137         return False
1138
1139     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1140         """Remove trailing comma if there is one and it's safe."""
1141         if not (
1142             self.leaves
1143             and self.leaves[-1].type == token.COMMA
1144             and closing.type in CLOSING_BRACKETS
1145         ):
1146             return False
1147
1148         if closing.type == token.RBRACE:
1149             self.remove_trailing_comma()
1150             return True
1151
1152         if closing.type == token.RSQB:
1153             comma = self.leaves[-1]
1154             if comma.parent and comma.parent.type == syms.listmaker:
1155                 self.remove_trailing_comma()
1156                 return True
1157
1158         # For parens let's check if it's safe to remove the comma.
1159         # Imports are always safe.
1160         if self.is_import:
1161             self.remove_trailing_comma()
1162             return True
1163
1164         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1165         # change a tuple into a different type by removing the comma.
1166         depth = closing.bracket_depth + 1
1167         commas = 0
1168         opening = closing.opening_bracket
1169         for _opening_index, leaf in enumerate(self.leaves):
1170             if leaf is opening:
1171                 break
1172
1173         else:
1174             return False
1175
1176         for leaf in self.leaves[_opening_index + 1 :]:
1177             if leaf is closing:
1178                 break
1179
1180             bracket_depth = leaf.bracket_depth
1181             if bracket_depth == depth and leaf.type == token.COMMA:
1182                 commas += 1
1183                 if leaf.parent and leaf.parent.type == syms.arglist:
1184                     commas += 1
1185                     break
1186
1187         if commas > 1:
1188             self.remove_trailing_comma()
1189             return True
1190
1191         return False
1192
1193     def append_comment(self, comment: Leaf) -> bool:
1194         """Add an inline or standalone comment to the line."""
1195         if (
1196             comment.type == STANDALONE_COMMENT
1197             and self.bracket_tracker.any_open_brackets()
1198         ):
1199             comment.prefix = ""
1200             return False
1201
1202         if comment.type != token.COMMENT:
1203             return False
1204
1205         after = len(self.leaves) - 1
1206         if after == -1:
1207             comment.type = STANDALONE_COMMENT
1208             comment.prefix = ""
1209             return False
1210
1211         else:
1212             self.comments.append((after, comment))
1213             return True
1214
1215     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1216         """Generate comments that should appear directly after `leaf`.
1217
1218         Provide a non-negative leaf `_index` to speed up the function.
1219         """
1220         if not self.comments:
1221             return
1222
1223         if _index == -1:
1224             for _index, _leaf in enumerate(self.leaves):
1225                 if leaf is _leaf:
1226                     break
1227
1228             else:
1229                 return
1230
1231         for index, comment_after in self.comments:
1232             if _index == index:
1233                 yield comment_after
1234
1235     def remove_trailing_comma(self) -> None:
1236         """Remove the trailing comma and moves the comments attached to it."""
1237         comma_index = len(self.leaves) - 1
1238         for i in range(len(self.comments)):
1239             comment_index, comment = self.comments[i]
1240             if comment_index == comma_index:
1241                 self.comments[i] = (comma_index - 1, comment)
1242         self.leaves.pop()
1243
1244     def is_complex_subscript(self, leaf: Leaf) -> bool:
1245         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1246         open_lsqb = self.bracket_tracker.get_open_lsqb()
1247         if open_lsqb is None:
1248             return False
1249
1250         subscript_start = open_lsqb.next_sibling
1251
1252         if isinstance(subscript_start, Node):
1253             if subscript_start.type == syms.listmaker:
1254                 return False
1255
1256             if subscript_start.type == syms.subscriptlist:
1257                 subscript_start = child_towards(subscript_start, leaf)
1258         return subscript_start is not None and any(
1259             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1260         )
1261
1262     def __str__(self) -> str:
1263         """Render the line."""
1264         if not self:
1265             return "\n"
1266
1267         indent = "    " * self.depth
1268         leaves = iter(self.leaves)
1269         first = next(leaves)
1270         res = f"{first.prefix}{indent}{first.value}"
1271         for leaf in leaves:
1272             res += str(leaf)
1273         for _, comment in self.comments:
1274             res += str(comment)
1275         return res + "\n"
1276
1277     def __bool__(self) -> bool:
1278         """Return True if the line has leaves or comments."""
1279         return bool(self.leaves or self.comments)
1280
1281
1282 @dataclass
1283 class EmptyLineTracker:
1284     """Provides a stateful method that returns the number of potential extra
1285     empty lines needed before and after the currently processed line.
1286
1287     Note: this tracker works on lines that haven't been split yet.  It assumes
1288     the prefix of the first leaf consists of optional newlines.  Those newlines
1289     are consumed by `maybe_empty_lines()` and included in the computation.
1290     """
1291
1292     is_pyi: bool = False
1293     previous_line: Optional[Line] = None
1294     previous_after: int = 0
1295     previous_defs: List[int] = Factory(list)
1296
1297     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1298         """Return the number of extra empty lines before and after the `current_line`.
1299
1300         This is for separating `def`, `async def` and `class` with extra empty
1301         lines (two on module-level).
1302         """
1303         before, after = self._maybe_empty_lines(current_line)
1304         before -= self.previous_after
1305         self.previous_after = after
1306         self.previous_line = current_line
1307         return before, after
1308
1309     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1310         max_allowed = 1
1311         if current_line.depth == 0:
1312             max_allowed = 1 if self.is_pyi else 2
1313         if current_line.leaves:
1314             # Consume the first leaf's extra newlines.
1315             first_leaf = current_line.leaves[0]
1316             before = first_leaf.prefix.count("\n")
1317             before = min(before, max_allowed)
1318             first_leaf.prefix = ""
1319         else:
1320             before = 0
1321         depth = current_line.depth
1322         while self.previous_defs and self.previous_defs[-1] >= depth:
1323             self.previous_defs.pop()
1324             if self.is_pyi:
1325                 before = 0 if depth else 1
1326             else:
1327                 before = 1 if depth else 2
1328         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1329             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1330
1331         if (
1332             self.previous_line
1333             and self.previous_line.is_import
1334             and not current_line.is_import
1335             and depth == self.previous_line.depth
1336         ):
1337             return (before or 1), 0
1338
1339         if (
1340             self.previous_line
1341             and self.previous_line.is_class
1342             and current_line.is_triple_quoted_string
1343         ):
1344             return before, 1
1345
1346         return before, 0
1347
1348     def _maybe_empty_lines_for_class_or_def(
1349         self, current_line: Line, before: int
1350     ) -> Tuple[int, int]:
1351         if not current_line.is_decorator:
1352             self.previous_defs.append(current_line.depth)
1353         if self.previous_line is None:
1354             # Don't insert empty lines before the first line in the file.
1355             return 0, 0
1356
1357         if self.previous_line.is_decorator:
1358             return 0, 0
1359
1360         if self.previous_line.depth < current_line.depth and (
1361             self.previous_line.is_class or self.previous_line.is_def
1362         ):
1363             return 0, 0
1364
1365         if (
1366             self.previous_line.is_comment
1367             and self.previous_line.depth == current_line.depth
1368             and before == 0
1369         ):
1370             return 0, 0
1371
1372         if self.is_pyi:
1373             if self.previous_line.depth > current_line.depth:
1374                 newlines = 1
1375             elif current_line.is_class or self.previous_line.is_class:
1376                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1377                     # No blank line between classes with an emty body
1378                     newlines = 0
1379                 else:
1380                     newlines = 1
1381             elif current_line.is_def and not self.previous_line.is_def:
1382                 # Blank line between a block of functions and a block of non-functions
1383                 newlines = 1
1384             else:
1385                 newlines = 0
1386         else:
1387             newlines = 2
1388         if current_line.depth and newlines:
1389             newlines -= 1
1390         return newlines, 0
1391
1392
1393 @dataclass
1394 class LineGenerator(Visitor[Line]):
1395     """Generates reformatted Line objects.  Empty lines are not emitted.
1396
1397     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1398     in ways that will no longer stringify to valid Python code on the tree.
1399     """
1400
1401     is_pyi: bool = False
1402     normalize_strings: bool = True
1403     current_line: Line = Factory(Line)
1404     remove_u_prefix: bool = False
1405
1406     def line(self, indent: int = 0) -> Iterator[Line]:
1407         """Generate a line.
1408
1409         If the line is empty, only emit if it makes sense.
1410         If the line is too long, split it first and then generate.
1411
1412         If any lines were generated, set up a new current_line.
1413         """
1414         if not self.current_line:
1415             self.current_line.depth += indent
1416             return  # Line is empty, don't emit. Creating a new one unnecessary.
1417
1418         complete_line = self.current_line
1419         self.current_line = Line(depth=complete_line.depth + indent)
1420         yield complete_line
1421
1422     def visit_default(self, node: LN) -> Iterator[Line]:
1423         """Default `visit_*()` implementation. Recurses to children of `node`."""
1424         if isinstance(node, Leaf):
1425             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1426             for comment in generate_comments(node):
1427                 if any_open_brackets:
1428                     # any comment within brackets is subject to splitting
1429                     self.current_line.append(comment)
1430                 elif comment.type == token.COMMENT:
1431                     # regular trailing comment
1432                     self.current_line.append(comment)
1433                     yield from self.line()
1434
1435                 else:
1436                     # regular standalone comment
1437                     yield from self.line()
1438
1439                     self.current_line.append(comment)
1440                     yield from self.line()
1441
1442             normalize_prefix(node, inside_brackets=any_open_brackets)
1443             if self.normalize_strings and node.type == token.STRING:
1444                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1445                 normalize_string_quotes(node)
1446             if node.type not in WHITESPACE:
1447                 self.current_line.append(node)
1448         yield from super().visit_default(node)
1449
1450     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1451         """Increase indentation level, maybe yield a line."""
1452         # In blib2to3 INDENT never holds comments.
1453         yield from self.line(+1)
1454         yield from self.visit_default(node)
1455
1456     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1457         """Decrease indentation level, maybe yield a line."""
1458         # The current line might still wait for trailing comments.  At DEDENT time
1459         # there won't be any (they would be prefixes on the preceding NEWLINE).
1460         # Emit the line then.
1461         yield from self.line()
1462
1463         # While DEDENT has no value, its prefix may contain standalone comments
1464         # that belong to the current indentation level.  Get 'em.
1465         yield from self.visit_default(node)
1466
1467         # Finally, emit the dedent.
1468         yield from self.line(-1)
1469
1470     def visit_stmt(
1471         self, node: Node, keywords: Set[str], parens: Set[str]
1472     ) -> Iterator[Line]:
1473         """Visit a statement.
1474
1475         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1476         `def`, `with`, `class`, `assert` and assignments.
1477
1478         The relevant Python language `keywords` for a given statement will be
1479         NAME leaves within it. This methods puts those on a separate line.
1480
1481         `parens` holds a set of string leaf values immediately after which
1482         invisible parens should be put.
1483         """
1484         normalize_invisible_parens(node, parens_after=parens)
1485         for child in node.children:
1486             if child.type == token.NAME and child.value in keywords:  # type: ignore
1487                 yield from self.line()
1488
1489             yield from self.visit(child)
1490
1491     def visit_suite(self, node: Node) -> Iterator[Line]:
1492         """Visit a suite."""
1493         if self.is_pyi and is_stub_suite(node):
1494             yield from self.visit(node.children[2])
1495         else:
1496             yield from self.visit_default(node)
1497
1498     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1499         """Visit a statement without nested statements."""
1500         is_suite_like = node.parent and node.parent.type in STATEMENT
1501         if is_suite_like:
1502             if self.is_pyi and is_stub_body(node):
1503                 yield from self.visit_default(node)
1504             else:
1505                 yield from self.line(+1)
1506                 yield from self.visit_default(node)
1507                 yield from self.line(-1)
1508
1509         else:
1510             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1511                 yield from self.line()
1512             yield from self.visit_default(node)
1513
1514     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1515         """Visit `async def`, `async for`, `async with`."""
1516         yield from self.line()
1517
1518         children = iter(node.children)
1519         for child in children:
1520             yield from self.visit(child)
1521
1522             if child.type == token.ASYNC:
1523                 break
1524
1525         internal_stmt = next(children)
1526         for child in internal_stmt.children:
1527             yield from self.visit(child)
1528
1529     def visit_decorators(self, node: Node) -> Iterator[Line]:
1530         """Visit decorators."""
1531         for child in node.children:
1532             yield from self.line()
1533             yield from self.visit(child)
1534
1535     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1536         """Remove a semicolon and put the other statement on a separate line."""
1537         yield from self.line()
1538
1539     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1540         """End of file. Process outstanding comments and end with a newline."""
1541         yield from self.visit_default(leaf)
1542         yield from self.line()
1543
1544     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1545         if not self.current_line.bracket_tracker.any_open_brackets():
1546             yield from self.line()
1547         yield from self.visit_default(leaf)
1548
1549     def __attrs_post_init__(self) -> None:
1550         """You are in a twisty little maze of passages."""
1551         v = self.visit_stmt
1552         Ø: Set[str] = set()
1553         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1554         self.visit_if_stmt = partial(
1555             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1556         )
1557         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1558         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1559         self.visit_try_stmt = partial(
1560             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1561         )
1562         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1563         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1564         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1565         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1566         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1567         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1568         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1569         self.visit_async_funcdef = self.visit_async_stmt
1570         self.visit_decorated = self.visit_decorators
1571
1572
1573 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1574 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1575 OPENING_BRACKETS = set(BRACKET.keys())
1576 CLOSING_BRACKETS = set(BRACKET.values())
1577 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1578 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1579
1580
1581 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1582     """Return whitespace prefix if needed for the given `leaf`.
1583
1584     `complex_subscript` signals whether the given leaf is part of a subscription
1585     which has non-trivial arguments, like arithmetic expressions or function calls.
1586     """
1587     NO = ""
1588     SPACE = " "
1589     DOUBLESPACE = "  "
1590     t = leaf.type
1591     p = leaf.parent
1592     v = leaf.value
1593     if t in ALWAYS_NO_SPACE:
1594         return NO
1595
1596     if t == token.COMMENT:
1597         return DOUBLESPACE
1598
1599     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1600     if t == token.COLON and p.type not in {
1601         syms.subscript,
1602         syms.subscriptlist,
1603         syms.sliceop,
1604     }:
1605         return NO
1606
1607     prev = leaf.prev_sibling
1608     if not prev:
1609         prevp = preceding_leaf(p)
1610         if not prevp or prevp.type in OPENING_BRACKETS:
1611             return NO
1612
1613         if t == token.COLON:
1614             if prevp.type == token.COLON:
1615                 return NO
1616
1617             elif prevp.type != token.COMMA and not complex_subscript:
1618                 return NO
1619
1620             return SPACE
1621
1622         if prevp.type == token.EQUAL:
1623             if prevp.parent:
1624                 if prevp.parent.type in {
1625                     syms.arglist,
1626                     syms.argument,
1627                     syms.parameters,
1628                     syms.varargslist,
1629                 }:
1630                     return NO
1631
1632                 elif prevp.parent.type == syms.typedargslist:
1633                     # A bit hacky: if the equal sign has whitespace, it means we
1634                     # previously found it's a typed argument.  So, we're using
1635                     # that, too.
1636                     return prevp.prefix
1637
1638         elif prevp.type in STARS:
1639             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1640                 return NO
1641
1642         elif prevp.type == token.COLON:
1643             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1644                 return SPACE if complex_subscript else NO
1645
1646         elif (
1647             prevp.parent
1648             and prevp.parent.type == syms.factor
1649             and prevp.type in MATH_OPERATORS
1650         ):
1651             return NO
1652
1653         elif (
1654             prevp.type == token.RIGHTSHIFT
1655             and prevp.parent
1656             and prevp.parent.type == syms.shift_expr
1657             and prevp.prev_sibling
1658             and prevp.prev_sibling.type == token.NAME
1659             and prevp.prev_sibling.value == "print"  # type: ignore
1660         ):
1661             # Python 2 print chevron
1662             return NO
1663
1664     elif prev.type in OPENING_BRACKETS:
1665         return NO
1666
1667     if p.type in {syms.parameters, syms.arglist}:
1668         # untyped function signatures or calls
1669         if not prev or prev.type != token.COMMA:
1670             return NO
1671
1672     elif p.type == syms.varargslist:
1673         # lambdas
1674         if prev and prev.type != token.COMMA:
1675             return NO
1676
1677     elif p.type == syms.typedargslist:
1678         # typed function signatures
1679         if not prev:
1680             return NO
1681
1682         if t == token.EQUAL:
1683             if prev.type != syms.tname:
1684                 return NO
1685
1686         elif prev.type == token.EQUAL:
1687             # A bit hacky: if the equal sign has whitespace, it means we
1688             # previously found it's a typed argument.  So, we're using that, too.
1689             return prev.prefix
1690
1691         elif prev.type != token.COMMA:
1692             return NO
1693
1694     elif p.type == syms.tname:
1695         # type names
1696         if not prev:
1697             prevp = preceding_leaf(p)
1698             if not prevp or prevp.type != token.COMMA:
1699                 return NO
1700
1701     elif p.type == syms.trailer:
1702         # attributes and calls
1703         if t == token.LPAR or t == token.RPAR:
1704             return NO
1705
1706         if not prev:
1707             if t == token.DOT:
1708                 prevp = preceding_leaf(p)
1709                 if not prevp or prevp.type != token.NUMBER:
1710                     return NO
1711
1712             elif t == token.LSQB:
1713                 return NO
1714
1715         elif prev.type != token.COMMA:
1716             return NO
1717
1718     elif p.type == syms.argument:
1719         # single argument
1720         if t == token.EQUAL:
1721             return NO
1722
1723         if not prev:
1724             prevp = preceding_leaf(p)
1725             if not prevp or prevp.type == token.LPAR:
1726                 return NO
1727
1728         elif prev.type in {token.EQUAL} | STARS:
1729             return NO
1730
1731     elif p.type == syms.decorator:
1732         # decorators
1733         return NO
1734
1735     elif p.type == syms.dotted_name:
1736         if prev:
1737             return NO
1738
1739         prevp = preceding_leaf(p)
1740         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1741             return NO
1742
1743     elif p.type == syms.classdef:
1744         if t == token.LPAR:
1745             return NO
1746
1747         if prev and prev.type == token.LPAR:
1748             return NO
1749
1750     elif p.type in {syms.subscript, syms.sliceop}:
1751         # indexing
1752         if not prev:
1753             assert p.parent is not None, "subscripts are always parented"
1754             if p.parent.type == syms.subscriptlist:
1755                 return SPACE
1756
1757             return NO
1758
1759         elif not complex_subscript:
1760             return NO
1761
1762     elif p.type == syms.atom:
1763         if prev and t == token.DOT:
1764             # dots, but not the first one.
1765             return NO
1766
1767     elif p.type == syms.dictsetmaker:
1768         # dict unpacking
1769         if prev and prev.type == token.DOUBLESTAR:
1770             return NO
1771
1772     elif p.type in {syms.factor, syms.star_expr}:
1773         # unary ops
1774         if not prev:
1775             prevp = preceding_leaf(p)
1776             if not prevp or prevp.type in OPENING_BRACKETS:
1777                 return NO
1778
1779             prevp_parent = prevp.parent
1780             assert prevp_parent is not None
1781             if prevp.type == token.COLON and prevp_parent.type in {
1782                 syms.subscript,
1783                 syms.sliceop,
1784             }:
1785                 return NO
1786
1787             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1788                 return NO
1789
1790         elif t in {token.NAME, token.NUMBER, token.STRING}:
1791             return NO
1792
1793     elif p.type == syms.import_from:
1794         if t == token.DOT:
1795             if prev and prev.type == token.DOT:
1796                 return NO
1797
1798         elif t == token.NAME:
1799             if v == "import":
1800                 return SPACE
1801
1802             if prev and prev.type == token.DOT:
1803                 return NO
1804
1805     elif p.type == syms.sliceop:
1806         return NO
1807
1808     return SPACE
1809
1810
1811 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1812     """Return the first leaf that precedes `node`, if any."""
1813     while node:
1814         res = node.prev_sibling
1815         if res:
1816             if isinstance(res, Leaf):
1817                 return res
1818
1819             try:
1820                 return list(res.leaves())[-1]
1821
1822             except IndexError:
1823                 return None
1824
1825         node = node.parent
1826     return None
1827
1828
1829 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1830     """Return the child of `ancestor` that contains `descendant`."""
1831     node: Optional[LN] = descendant
1832     while node and node.parent != ancestor:
1833         node = node.parent
1834     return node
1835
1836
1837 def container_of(leaf: Leaf) -> LN:
1838     """Return `leaf` or one of its ancestors that is the topmost container of it.
1839
1840     By "container" we mean a node where `leaf` is the very first child.
1841     """
1842     same_prefix = leaf.prefix
1843     container: LN = leaf
1844     while container:
1845         parent = container.parent
1846         if parent is None:
1847             break
1848
1849         if parent.children[0].prefix != same_prefix:
1850             break
1851
1852         if parent.type == syms.file_input:
1853             break
1854
1855         if parent.type in SURROUNDED_BY_BRACKETS:
1856             break
1857
1858         container = parent
1859     return container
1860
1861
1862 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1863     """Return the priority of the `leaf` delimiter, given a line break after it.
1864
1865     The delimiter priorities returned here are from those delimiters that would
1866     cause a line break after themselves.
1867
1868     Higher numbers are higher priority.
1869     """
1870     if leaf.type == token.COMMA:
1871         return COMMA_PRIORITY
1872
1873     return 0
1874
1875
1876 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1877     """Return the priority of the `leaf` delimiter, given a line before after it.
1878
1879     The delimiter priorities returned here are from those delimiters that would
1880     cause a line break before themselves.
1881
1882     Higher numbers are higher priority.
1883     """
1884     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1885         # * and ** might also be MATH_OPERATORS but in this case they are not.
1886         # Don't treat them as a delimiter.
1887         return 0
1888
1889     if (
1890         leaf.type == token.DOT
1891         and leaf.parent
1892         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1893         and (previous is None or previous.type in CLOSING_BRACKETS)
1894     ):
1895         return DOT_PRIORITY
1896
1897     if (
1898         leaf.type in MATH_OPERATORS
1899         and leaf.parent
1900         and leaf.parent.type not in {syms.factor, syms.star_expr}
1901     ):
1902         return MATH_PRIORITIES[leaf.type]
1903
1904     if leaf.type in COMPARATORS:
1905         return COMPARATOR_PRIORITY
1906
1907     if (
1908         leaf.type == token.STRING
1909         and previous is not None
1910         and previous.type == token.STRING
1911     ):
1912         return STRING_PRIORITY
1913
1914     if leaf.type != token.NAME:
1915         return 0
1916
1917     if (
1918         leaf.value == "for"
1919         and leaf.parent
1920         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1921     ):
1922         return COMPREHENSION_PRIORITY
1923
1924     if (
1925         leaf.value == "if"
1926         and leaf.parent
1927         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1928     ):
1929         return COMPREHENSION_PRIORITY
1930
1931     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1932         return TERNARY_PRIORITY
1933
1934     if leaf.value == "is":
1935         return COMPARATOR_PRIORITY
1936
1937     if (
1938         leaf.value == "in"
1939         and leaf.parent
1940         and leaf.parent.type in {syms.comp_op, syms.comparison}
1941         and not (
1942             previous is not None
1943             and previous.type == token.NAME
1944             and previous.value == "not"
1945         )
1946     ):
1947         return COMPARATOR_PRIORITY
1948
1949     if (
1950         leaf.value == "not"
1951         and leaf.parent
1952         and leaf.parent.type == syms.comp_op
1953         and not (
1954             previous is not None
1955             and previous.type == token.NAME
1956             and previous.value == "is"
1957         )
1958     ):
1959         return COMPARATOR_PRIORITY
1960
1961     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1962         return LOGIC_PRIORITY
1963
1964     return 0
1965
1966
1967 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
1968 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
1969
1970
1971 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1972     """Clean the prefix of the `leaf` and generate comments from it, if any.
1973
1974     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1975     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1976     move because it does away with modifying the grammar to include all the
1977     possible places in which comments can be placed.
1978
1979     The sad consequence for us though is that comments don't "belong" anywhere.
1980     This is why this function generates simple parentless Leaf objects for
1981     comments.  We simply don't know what the correct parent should be.
1982
1983     No matter though, we can live without this.  We really only need to
1984     differentiate between inline and standalone comments.  The latter don't
1985     share the line with any code.
1986
1987     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1988     are emitted with a fake STANDALONE_COMMENT token identifier.
1989     """
1990     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
1991         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
1992
1993
1994 @dataclass
1995 class ProtoComment:
1996     type: int  # token.COMMENT or STANDALONE_COMMENT
1997     value: str  # content of the comment
1998     newlines: int  # how many newlines before the comment
1999     consumed: int  # how many characters of the original leaf's prefix did we consume
2000
2001
2002 @lru_cache(maxsize=4096)
2003 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2004     result: List[ProtoComment] = []
2005     if not prefix or "#" not in prefix:
2006         return result
2007
2008     consumed = 0
2009     nlines = 0
2010     for index, line in enumerate(prefix.split("\n")):
2011         consumed += len(line) + 1  # adding the length of the split '\n'
2012         line = line.lstrip()
2013         if not line:
2014             nlines += 1
2015         if not line.startswith("#"):
2016             continue
2017
2018         if index == 0 and not is_endmarker:
2019             comment_type = token.COMMENT  # simple trailing comment
2020         else:
2021             comment_type = STANDALONE_COMMENT
2022         comment = make_comment(line)
2023         result.append(
2024             ProtoComment(
2025                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2026             )
2027         )
2028         nlines = 0
2029     return result
2030
2031
2032 def make_comment(content: str) -> str:
2033     """Return a consistently formatted comment from the given `content` string.
2034
2035     All comments (except for "##", "#!", "#:") should have a single space between
2036     the hash sign and the content.
2037
2038     If `content` didn't start with a hash sign, one is provided.
2039     """
2040     content = content.rstrip()
2041     if not content:
2042         return "#"
2043
2044     if content[0] == "#":
2045         content = content[1:]
2046     if content and content[0] not in " !:#":
2047         content = " " + content
2048     return "#" + content
2049
2050
2051 def split_line(
2052     line: Line, line_length: int, inner: bool = False, py36: bool = False
2053 ) -> Iterator[Line]:
2054     """Split a `line` into potentially many lines.
2055
2056     They should fit in the allotted `line_length` but might not be able to.
2057     `inner` signifies that there were a pair of brackets somewhere around the
2058     current `line`, possibly transitively. This means we can fallback to splitting
2059     by delimiters if the LHS/RHS don't yield any results.
2060
2061     If `py36` is True, splitting may generate syntax that is only compatible
2062     with Python 3.6 and later.
2063     """
2064     if line.is_comment:
2065         yield line
2066         return
2067
2068     line_str = str(line).strip("\n")
2069     if not line.should_explode and is_line_short_enough(
2070         line, line_length=line_length, line_str=line_str
2071     ):
2072         yield line
2073         return
2074
2075     split_funcs: List[SplitFunc]
2076     if line.is_def:
2077         split_funcs = [left_hand_split]
2078     else:
2079
2080         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2081             for omit in generate_trailers_to_omit(line, line_length):
2082                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2083                 if is_line_short_enough(lines[0], line_length=line_length):
2084                     yield from lines
2085                     return
2086
2087             # All splits failed, best effort split with no omits.
2088             # This mostly happens to multiline strings that are by definition
2089             # reported as not fitting a single line.
2090             yield from right_hand_split(line, py36)
2091
2092         if line.inside_brackets:
2093             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2094         else:
2095             split_funcs = [rhs]
2096     for split_func in split_funcs:
2097         # We are accumulating lines in `result` because we might want to abort
2098         # mission and return the original line in the end, or attempt a different
2099         # split altogether.
2100         result: List[Line] = []
2101         try:
2102             for l in split_func(line, py36):
2103                 if str(l).strip("\n") == line_str:
2104                     raise CannotSplit("Split function returned an unchanged result")
2105
2106                 result.extend(
2107                     split_line(l, line_length=line_length, inner=True, py36=py36)
2108                 )
2109         except CannotSplit as cs:
2110             continue
2111
2112         else:
2113             yield from result
2114             break
2115
2116     else:
2117         yield line
2118
2119
2120 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2121     """Split line into many lines, starting with the first matching bracket pair.
2122
2123     Note: this usually looks weird, only use this for function definitions.
2124     Prefer RHS otherwise.  This is why this function is not symmetrical with
2125     :func:`right_hand_split` which also handles optional parentheses.
2126     """
2127     head = Line(depth=line.depth)
2128     body = Line(depth=line.depth + 1, inside_brackets=True)
2129     tail = Line(depth=line.depth)
2130     tail_leaves: List[Leaf] = []
2131     body_leaves: List[Leaf] = []
2132     head_leaves: List[Leaf] = []
2133     current_leaves = head_leaves
2134     matching_bracket = None
2135     for leaf in line.leaves:
2136         if (
2137             current_leaves is body_leaves
2138             and leaf.type in CLOSING_BRACKETS
2139             and leaf.opening_bracket is matching_bracket
2140         ):
2141             current_leaves = tail_leaves if body_leaves else head_leaves
2142         current_leaves.append(leaf)
2143         if current_leaves is head_leaves:
2144             if leaf.type in OPENING_BRACKETS:
2145                 matching_bracket = leaf
2146                 current_leaves = body_leaves
2147     # Since body is a new indent level, remove spurious leading whitespace.
2148     if body_leaves:
2149         normalize_prefix(body_leaves[0], inside_brackets=True)
2150     # Build the new lines.
2151     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2152         for leaf in leaves:
2153             result.append(leaf, preformatted=True)
2154             for comment_after in line.comments_after(leaf):
2155                 result.append(comment_after, preformatted=True)
2156     bracket_split_succeeded_or_raise(head, body, tail)
2157     for result in (head, body, tail):
2158         if result:
2159             yield result
2160
2161
2162 def right_hand_split(
2163     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2164 ) -> Iterator[Line]:
2165     """Split line into many lines, starting with the last matching bracket pair.
2166
2167     If the split was by optional parentheses, attempt splitting without them, too.
2168     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2169     this split.
2170
2171     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2172     """
2173     head = Line(depth=line.depth)
2174     body = Line(depth=line.depth + 1, inside_brackets=True)
2175     tail = Line(depth=line.depth)
2176     tail_leaves: List[Leaf] = []
2177     body_leaves: List[Leaf] = []
2178     head_leaves: List[Leaf] = []
2179     current_leaves = tail_leaves
2180     opening_bracket = None
2181     closing_bracket = None
2182     for leaf in reversed(line.leaves):
2183         if current_leaves is body_leaves:
2184             if leaf is opening_bracket:
2185                 current_leaves = head_leaves if body_leaves else tail_leaves
2186         current_leaves.append(leaf)
2187         if current_leaves is tail_leaves:
2188             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2189                 opening_bracket = leaf.opening_bracket
2190                 closing_bracket = leaf
2191                 current_leaves = body_leaves
2192     tail_leaves.reverse()
2193     body_leaves.reverse()
2194     head_leaves.reverse()
2195     # Since body is a new indent level, remove spurious leading whitespace.
2196     if body_leaves:
2197         normalize_prefix(body_leaves[0], inside_brackets=True)
2198     if not head_leaves:
2199         # No `head` means the split failed. Either `tail` has all content or
2200         # the matching `opening_bracket` wasn't available on `line` anymore.
2201         raise CannotSplit("No brackets found")
2202
2203     # Build the new lines.
2204     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2205         for leaf in leaves:
2206             result.append(leaf, preformatted=True)
2207             for comment_after in line.comments_after(leaf):
2208                 result.append(comment_after, preformatted=True)
2209     assert opening_bracket and closing_bracket
2210     body.should_explode = should_explode(body, opening_bracket)
2211     bracket_split_succeeded_or_raise(head, body, tail)
2212     if (
2213         # the body shouldn't be exploded
2214         not body.should_explode
2215         # the opening bracket is an optional paren
2216         and opening_bracket.type == token.LPAR
2217         and not opening_bracket.value
2218         # the closing bracket is an optional paren
2219         and closing_bracket.type == token.RPAR
2220         and not closing_bracket.value
2221         # it's not an import (optional parens are the only thing we can split on
2222         # in this case; attempting a split without them is a waste of time)
2223         and not line.is_import
2224         # there are no standalone comments in the body
2225         and not body.contains_standalone_comments(0)
2226         # and we can actually remove the parens
2227         and can_omit_invisible_parens(body, line_length)
2228     ):
2229         omit = {id(closing_bracket), *omit}
2230         try:
2231             yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2232             return
2233
2234         except CannotSplit:
2235             if not (
2236                 can_be_split(body)
2237                 or is_line_short_enough(body, line_length=line_length)
2238             ):
2239                 raise CannotSplit(
2240                     "Splitting failed, body is still too long and can't be split."
2241                 )
2242
2243             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2244                 raise CannotSplit(
2245                     "The current optional pair of parentheses is bound to fail to "
2246                     "satisfy the splitting algorithm because the head or the tail "
2247                     "contains multiline strings which by definition never fit one "
2248                     "line."
2249                 )
2250
2251     ensure_visible(opening_bracket)
2252     ensure_visible(closing_bracket)
2253     for result in (head, body, tail):
2254         if result:
2255             yield result
2256
2257
2258 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2259     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2260
2261     Do nothing otherwise.
2262
2263     A left- or right-hand split is based on a pair of brackets. Content before
2264     (and including) the opening bracket is left on one line, content inside the
2265     brackets is put on a separate line, and finally content starting with and
2266     following the closing bracket is put on a separate line.
2267
2268     Those are called `head`, `body`, and `tail`, respectively. If the split
2269     produced the same line (all content in `head`) or ended up with an empty `body`
2270     and the `tail` is just the closing bracket, then it's considered failed.
2271     """
2272     tail_len = len(str(tail).strip())
2273     if not body:
2274         if tail_len == 0:
2275             raise CannotSplit("Splitting brackets produced the same line")
2276
2277         elif tail_len < 3:
2278             raise CannotSplit(
2279                 f"Splitting brackets on an empty body to save "
2280                 f"{tail_len} characters is not worth it"
2281             )
2282
2283
2284 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2285     """Normalize prefix of the first leaf in every line returned by `split_func`.
2286
2287     This is a decorator over relevant split functions.
2288     """
2289
2290     @wraps(split_func)
2291     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2292         for l in split_func(line, py36):
2293             normalize_prefix(l.leaves[0], inside_brackets=True)
2294             yield l
2295
2296     return split_wrapper
2297
2298
2299 @dont_increase_indentation
2300 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2301     """Split according to delimiters of the highest priority.
2302
2303     If `py36` is True, the split will add trailing commas also in function
2304     signatures that contain `*` and `**`.
2305     """
2306     try:
2307         last_leaf = line.leaves[-1]
2308     except IndexError:
2309         raise CannotSplit("Line empty")
2310
2311     bt = line.bracket_tracker
2312     try:
2313         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2314     except ValueError:
2315         raise CannotSplit("No delimiters found")
2316
2317     if delimiter_priority == DOT_PRIORITY:
2318         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2319             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2320
2321     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2322     lowest_depth = sys.maxsize
2323     trailing_comma_safe = True
2324
2325     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2326         """Append `leaf` to current line or to new line if appending impossible."""
2327         nonlocal current_line
2328         try:
2329             current_line.append_safe(leaf, preformatted=True)
2330         except ValueError as ve:
2331             yield current_line
2332
2333             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2334             current_line.append(leaf)
2335
2336     for index, leaf in enumerate(line.leaves):
2337         yield from append_to_line(leaf)
2338
2339         for comment_after in line.comments_after(leaf, index):
2340             yield from append_to_line(comment_after)
2341
2342         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2343         if leaf.bracket_depth == lowest_depth and is_vararg(
2344             leaf, within=VARARGS_PARENTS
2345         ):
2346             trailing_comma_safe = trailing_comma_safe and py36
2347         leaf_priority = bt.delimiters.get(id(leaf))
2348         if leaf_priority == delimiter_priority:
2349             yield current_line
2350
2351             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2352     if current_line:
2353         if (
2354             trailing_comma_safe
2355             and delimiter_priority == COMMA_PRIORITY
2356             and current_line.leaves[-1].type != token.COMMA
2357             and current_line.leaves[-1].type != STANDALONE_COMMENT
2358         ):
2359             current_line.append(Leaf(token.COMMA, ","))
2360         yield current_line
2361
2362
2363 @dont_increase_indentation
2364 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2365     """Split standalone comments from the rest of the line."""
2366     if not line.contains_standalone_comments(0):
2367         raise CannotSplit("Line does not have any standalone comments")
2368
2369     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2370
2371     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2372         """Append `leaf` to current line or to new line if appending impossible."""
2373         nonlocal current_line
2374         try:
2375             current_line.append_safe(leaf, preformatted=True)
2376         except ValueError as ve:
2377             yield current_line
2378
2379             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2380             current_line.append(leaf)
2381
2382     for index, leaf in enumerate(line.leaves):
2383         yield from append_to_line(leaf)
2384
2385         for comment_after in line.comments_after(leaf, index):
2386             yield from append_to_line(comment_after)
2387
2388     if current_line:
2389         yield current_line
2390
2391
2392 def is_import(leaf: Leaf) -> bool:
2393     """Return True if the given leaf starts an import statement."""
2394     p = leaf.parent
2395     t = leaf.type
2396     v = leaf.value
2397     return bool(
2398         t == token.NAME
2399         and (
2400             (v == "import" and p and p.type == syms.import_name)
2401             or (v == "from" and p and p.type == syms.import_from)
2402         )
2403     )
2404
2405
2406 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2407     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2408     else.
2409
2410     Note: don't use backslashes for formatting or you'll lose your voting rights.
2411     """
2412     if not inside_brackets:
2413         spl = leaf.prefix.split("#")
2414         if "\\" not in spl[0]:
2415             nl_count = spl[-1].count("\n")
2416             if len(spl) > 1:
2417                 nl_count -= 1
2418             leaf.prefix = "\n" * nl_count
2419             return
2420
2421     leaf.prefix = ""
2422
2423
2424 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2425     """Make all string prefixes lowercase.
2426
2427     If remove_u_prefix is given, also removes any u prefix from the string.
2428
2429     Note: Mutates its argument.
2430     """
2431     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2432     assert match is not None, f"failed to match string {leaf.value!r}"
2433     orig_prefix = match.group(1)
2434     new_prefix = orig_prefix.lower()
2435     if remove_u_prefix:
2436         new_prefix = new_prefix.replace("u", "")
2437     leaf.value = f"{new_prefix}{match.group(2)}"
2438
2439
2440 def normalize_string_quotes(leaf: Leaf) -> None:
2441     """Prefer double quotes but only if it doesn't cause more escaping.
2442
2443     Adds or removes backslashes as appropriate. Doesn't parse and fix
2444     strings nested in f-strings (yet).
2445
2446     Note: Mutates its argument.
2447     """
2448     value = leaf.value.lstrip("furbFURB")
2449     if value[:3] == '"""':
2450         return
2451
2452     elif value[:3] == "'''":
2453         orig_quote = "'''"
2454         new_quote = '"""'
2455     elif value[0] == '"':
2456         orig_quote = '"'
2457         new_quote = "'"
2458     else:
2459         orig_quote = "'"
2460         new_quote = '"'
2461     first_quote_pos = leaf.value.find(orig_quote)
2462     if first_quote_pos == -1:
2463         return  # There's an internal error
2464
2465     prefix = leaf.value[:first_quote_pos]
2466     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2467     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2468     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2469     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2470     if "r" in prefix.casefold():
2471         if unescaped_new_quote.search(body):
2472             # There's at least one unescaped new_quote in this raw string
2473             # so converting is impossible
2474             return
2475
2476         # Do not introduce or remove backslashes in raw strings
2477         new_body = body
2478     else:
2479         # remove unnecessary escapes
2480         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2481         if body != new_body:
2482             # Consider the string without unnecessary escapes as the original
2483             body = new_body
2484             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2485         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2486         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2487     if "f" in prefix.casefold():
2488         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2489         for m in matches:
2490             if "\\" in str(m):
2491                 # Do not introduce backslashes in interpolated expressions
2492                 return
2493     if new_quote == '"""' and new_body[-1:] == '"':
2494         # edge case:
2495         new_body = new_body[:-1] + '\\"'
2496     orig_escape_count = body.count("\\")
2497     new_escape_count = new_body.count("\\")
2498     if new_escape_count > orig_escape_count:
2499         return  # Do not introduce more escaping
2500
2501     if new_escape_count == orig_escape_count and orig_quote == '"':
2502         return  # Prefer double quotes
2503
2504     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2505
2506
2507 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2508     """Make existing optional parentheses invisible or create new ones.
2509
2510     `parens_after` is a set of string leaf values immeditely after which parens
2511     should be put.
2512
2513     Standardizes on visible parentheses for single-element tuples, and keeps
2514     existing visible parentheses for other tuples and generator expressions.
2515     """
2516     for pc in list_comments(node.prefix, is_endmarker=False):
2517         if pc.value in FMT_OFF:
2518             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2519             return
2520
2521     check_lpar = False
2522     for index, child in enumerate(list(node.children)):
2523         if check_lpar:
2524             if child.type == syms.atom:
2525                 maybe_make_parens_invisible_in_atom(child)
2526             elif is_one_tuple(child):
2527                 # wrap child in visible parentheses
2528                 lpar = Leaf(token.LPAR, "(")
2529                 rpar = Leaf(token.RPAR, ")")
2530                 child.remove()
2531                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2532             elif node.type == syms.import_from:
2533                 # "import from" nodes store parentheses directly as part of
2534                 # the statement
2535                 if child.type == token.LPAR:
2536                     # make parentheses invisible
2537                     child.value = ""  # type: ignore
2538                     node.children[-1].value = ""  # type: ignore
2539                 elif child.type != token.STAR:
2540                     # insert invisible parentheses
2541                     node.insert_child(index, Leaf(token.LPAR, ""))
2542                     node.append_child(Leaf(token.RPAR, ""))
2543                 break
2544
2545             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2546                 # wrap child in invisible parentheses
2547                 lpar = Leaf(token.LPAR, "")
2548                 rpar = Leaf(token.RPAR, "")
2549                 index = child.remove() or 0
2550                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2551
2552         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2553
2554
2555 def normalize_fmt_off(node: Node) -> None:
2556     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2557     try_again = True
2558     while try_again:
2559         try_again = convert_one_fmt_off_pair(node)
2560
2561
2562 def convert_one_fmt_off_pair(node: Node) -> bool:
2563     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2564
2565     Returns True if a pair was converted.
2566     """
2567     for leaf in node.leaves():
2568         previous_consumed = 0
2569         for comment in list_comments(leaf.prefix, is_endmarker=False):
2570             if comment.value in FMT_OFF:
2571                 # We only want standalone comments. If there's no previous leaf or
2572                 # the previous leaf is indentation, it's a standalone comment in
2573                 # disguise.
2574                 if comment.type != STANDALONE_COMMENT:
2575                     prev = preceding_leaf(leaf)
2576                     if prev and prev.type not in WHITESPACE:
2577                         continue
2578
2579                 ignored_nodes = list(generate_ignored_nodes(leaf))
2580                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2581                 parent = first.parent
2582                 prefix = first.prefix
2583                 first.prefix = prefix[comment.consumed :]
2584                 hidden_value = (
2585                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2586                 )
2587                 if hidden_value.endswith("\n"):
2588                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2589                     # leaf (possibly followed by a DEDENT).
2590                     hidden_value = hidden_value[:-1]
2591                 first_idx = None
2592                 for ignored in ignored_nodes:
2593                     index = ignored.remove()
2594                     if first_idx is None:
2595                         first_idx = index
2596                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2597                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2598                 parent.insert_child(
2599                     first_idx,
2600                     Leaf(
2601                         STANDALONE_COMMENT,
2602                         hidden_value,
2603                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2604                     ),
2605                 )
2606                 return True
2607
2608             previous_consumed += comment.consumed
2609
2610     return False
2611
2612
2613 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2614     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2615
2616     Stops at the end of the block.
2617     """
2618     container: Optional[LN] = container_of(leaf)
2619     while container is not None and container.type != token.ENDMARKER:
2620         for comment in list_comments(container.prefix, is_endmarker=False):
2621             if comment.value in FMT_ON:
2622                 return
2623
2624         yield container
2625
2626         container = container.next_sibling
2627
2628
2629 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2630     """If it's safe, make the parens in the atom `node` invisible, recursively."""
2631     if (
2632         node.type != syms.atom
2633         or is_empty_tuple(node)
2634         or is_one_tuple(node)
2635         or is_yield(node)
2636         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2637     ):
2638         return False
2639
2640     first = node.children[0]
2641     last = node.children[-1]
2642     if first.type == token.LPAR and last.type == token.RPAR:
2643         # make parentheses invisible
2644         first.value = ""  # type: ignore
2645         last.value = ""  # type: ignore
2646         if len(node.children) > 1:
2647             maybe_make_parens_invisible_in_atom(node.children[1])
2648         return True
2649
2650     return False
2651
2652
2653 def is_empty_tuple(node: LN) -> bool:
2654     """Return True if `node` holds an empty tuple."""
2655     return (
2656         node.type == syms.atom
2657         and len(node.children) == 2
2658         and node.children[0].type == token.LPAR
2659         and node.children[1].type == token.RPAR
2660     )
2661
2662
2663 def is_one_tuple(node: LN) -> bool:
2664     """Return True if `node` holds a tuple with one element, with or without parens."""
2665     if node.type == syms.atom:
2666         if len(node.children) != 3:
2667             return False
2668
2669         lpar, gexp, rpar = node.children
2670         if not (
2671             lpar.type == token.LPAR
2672             and gexp.type == syms.testlist_gexp
2673             and rpar.type == token.RPAR
2674         ):
2675             return False
2676
2677         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2678
2679     return (
2680         node.type in IMPLICIT_TUPLE
2681         and len(node.children) == 2
2682         and node.children[1].type == token.COMMA
2683     )
2684
2685
2686 def is_yield(node: LN) -> bool:
2687     """Return True if `node` holds a `yield` or `yield from` expression."""
2688     if node.type == syms.yield_expr:
2689         return True
2690
2691     if node.type == token.NAME and node.value == "yield":  # type: ignore
2692         return True
2693
2694     if node.type != syms.atom:
2695         return False
2696
2697     if len(node.children) != 3:
2698         return False
2699
2700     lpar, expr, rpar = node.children
2701     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2702         return is_yield(expr)
2703
2704     return False
2705
2706
2707 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2708     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2709
2710     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2711     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2712     extended iterable unpacking (PEP 3132) and additional unpacking
2713     generalizations (PEP 448).
2714     """
2715     if leaf.type not in STARS or not leaf.parent:
2716         return False
2717
2718     p = leaf.parent
2719     if p.type == syms.star_expr:
2720         # Star expressions are also used as assignment targets in extended
2721         # iterable unpacking (PEP 3132).  See what its parent is instead.
2722         if not p.parent:
2723             return False
2724
2725         p = p.parent
2726
2727     return p.type in within
2728
2729
2730 def is_multiline_string(leaf: Leaf) -> bool:
2731     """Return True if `leaf` is a multiline string that actually spans many lines."""
2732     value = leaf.value.lstrip("furbFURB")
2733     return value[:3] in {'"""', "'''"} and "\n" in value
2734
2735
2736 def is_stub_suite(node: Node) -> bool:
2737     """Return True if `node` is a suite with a stub body."""
2738     if (
2739         len(node.children) != 4
2740         or node.children[0].type != token.NEWLINE
2741         or node.children[1].type != token.INDENT
2742         or node.children[3].type != token.DEDENT
2743     ):
2744         return False
2745
2746     return is_stub_body(node.children[2])
2747
2748
2749 def is_stub_body(node: LN) -> bool:
2750     """Return True if `node` is a simple statement containing an ellipsis."""
2751     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2752         return False
2753
2754     if len(node.children) != 2:
2755         return False
2756
2757     child = node.children[0]
2758     return (
2759         child.type == syms.atom
2760         and len(child.children) == 3
2761         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2762     )
2763
2764
2765 def max_delimiter_priority_in_atom(node: LN) -> int:
2766     """Return maximum delimiter priority inside `node`.
2767
2768     This is specific to atoms with contents contained in a pair of parentheses.
2769     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2770     """
2771     if node.type != syms.atom:
2772         return 0
2773
2774     first = node.children[0]
2775     last = node.children[-1]
2776     if not (first.type == token.LPAR and last.type == token.RPAR):
2777         return 0
2778
2779     bt = BracketTracker()
2780     for c in node.children[1:-1]:
2781         if isinstance(c, Leaf):
2782             bt.mark(c)
2783         else:
2784             for leaf in c.leaves():
2785                 bt.mark(leaf)
2786     try:
2787         return bt.max_delimiter_priority()
2788
2789     except ValueError:
2790         return 0
2791
2792
2793 def ensure_visible(leaf: Leaf) -> None:
2794     """Make sure parentheses are visible.
2795
2796     They could be invisible as part of some statements (see
2797     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2798     """
2799     if leaf.type == token.LPAR:
2800         leaf.value = "("
2801     elif leaf.type == token.RPAR:
2802         leaf.value = ")"
2803
2804
2805 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2806     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2807     if not (
2808         opening_bracket.parent
2809         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2810         and opening_bracket.value in "[{("
2811     ):
2812         return False
2813
2814     try:
2815         last_leaf = line.leaves[-1]
2816         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2817         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2818     except (IndexError, ValueError):
2819         return False
2820
2821     return max_priority == COMMA_PRIORITY
2822
2823
2824 def is_python36(node: Node) -> bool:
2825     """Return True if the current file is using Python 3.6+ features.
2826
2827     Currently looking for:
2828     - f-strings; and
2829     - trailing commas after * or ** in function signatures and calls.
2830     """
2831     for n in node.pre_order():
2832         if n.type == token.STRING:
2833             value_head = n.value[:2]  # type: ignore
2834             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2835                 return True
2836
2837         elif (
2838             n.type in {syms.typedargslist, syms.arglist}
2839             and n.children
2840             and n.children[-1].type == token.COMMA
2841         ):
2842             for ch in n.children:
2843                 if ch.type in STARS:
2844                     return True
2845
2846                 if ch.type == syms.argument:
2847                     for argch in ch.children:
2848                         if argch.type in STARS:
2849                             return True
2850
2851     return False
2852
2853
2854 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2855     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2856
2857     Brackets can be omitted if the entire trailer up to and including
2858     a preceding closing bracket fits in one line.
2859
2860     Yielded sets are cumulative (contain results of previous yields, too).  First
2861     set is empty.
2862     """
2863
2864     omit: Set[LeafID] = set()
2865     yield omit
2866
2867     length = 4 * line.depth
2868     opening_bracket = None
2869     closing_bracket = None
2870     optional_brackets: Set[LeafID] = set()
2871     inner_brackets: Set[LeafID] = set()
2872     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2873         length += leaf_length
2874         if length > line_length:
2875             break
2876
2877         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2878         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2879             break
2880
2881         optional_brackets.discard(id(leaf))
2882         if opening_bracket:
2883             if leaf is opening_bracket:
2884                 opening_bracket = None
2885             elif leaf.type in CLOSING_BRACKETS:
2886                 inner_brackets.add(id(leaf))
2887         elif leaf.type in CLOSING_BRACKETS:
2888             if not leaf.value:
2889                 optional_brackets.add(id(opening_bracket))
2890                 continue
2891
2892             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2893                 # Empty brackets would fail a split so treat them as "inner"
2894                 # brackets (e.g. only add them to the `omit` set if another
2895                 # pair of brackets was good enough.
2896                 inner_brackets.add(id(leaf))
2897                 continue
2898
2899             opening_bracket = leaf.opening_bracket
2900             if closing_bracket:
2901                 omit.add(id(closing_bracket))
2902                 omit.update(inner_brackets)
2903                 inner_brackets.clear()
2904                 yield omit
2905             closing_bracket = leaf
2906
2907
2908 def get_future_imports(node: Node) -> Set[str]:
2909     """Return a set of __future__ imports in the file."""
2910     imports = set()
2911     for child in node.children:
2912         if child.type != syms.simple_stmt:
2913             break
2914         first_child = child.children[0]
2915         if isinstance(first_child, Leaf):
2916             # Continue looking if we see a docstring; otherwise stop.
2917             if (
2918                 len(child.children) == 2
2919                 and first_child.type == token.STRING
2920                 and child.children[1].type == token.NEWLINE
2921             ):
2922                 continue
2923             else:
2924                 break
2925         elif first_child.type == syms.import_from:
2926             module_name = first_child.children[1]
2927             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2928                 break
2929             for import_from_child in first_child.children[3:]:
2930                 if isinstance(import_from_child, Leaf):
2931                     if import_from_child.type == token.NAME:
2932                         imports.add(import_from_child.value)
2933                 else:
2934                     assert import_from_child.type == syms.import_as_names
2935                     for leaf in import_from_child.children:
2936                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2937                             imports.add(leaf.value)
2938         else:
2939             break
2940     return imports
2941
2942
2943 def gen_python_files_in_dir(
2944     path: Path,
2945     root: Path,
2946     include: Pattern[str],
2947     exclude: Pattern[str],
2948     report: "Report",
2949 ) -> Iterator[Path]:
2950     """Generate all files under `path` whose paths are not excluded by the
2951     `exclude` regex, but are included by the `include` regex.
2952
2953     Symbolic links pointing outside of the root directory are ignored.
2954
2955     `report` is where output about exclusions goes.
2956     """
2957     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
2958     for child in path.iterdir():
2959         try:
2960             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
2961         except ValueError:
2962             if child.is_symlink():
2963                 report.path_ignored(
2964                     child,
2965                     "is a symbolic link that points outside of the root directory",
2966                 )
2967                 continue
2968
2969             raise
2970
2971         if child.is_dir():
2972             normalized_path += "/"
2973         exclude_match = exclude.search(normalized_path)
2974         if exclude_match and exclude_match.group(0):
2975             report.path_ignored(child, f"matches the --exclude regular expression")
2976             continue
2977
2978         if child.is_dir():
2979             yield from gen_python_files_in_dir(child, root, include, exclude, report)
2980
2981         elif child.is_file():
2982             include_match = include.search(normalized_path)
2983             if include_match:
2984                 yield child
2985
2986
2987 @lru_cache()
2988 def find_project_root(srcs: Iterable[str]) -> Path:
2989     """Return a directory containing .git, .hg, or pyproject.toml.
2990
2991     That directory can be one of the directories passed in `srcs` or their
2992     common parent.
2993
2994     If no directory in the tree contains a marker that would specify it's the
2995     project root, the root of the file system is returned.
2996     """
2997     if not srcs:
2998         return Path("/").resolve()
2999
3000     common_base = min(Path(src).resolve() for src in srcs)
3001     if common_base.is_dir():
3002         # Append a fake file so `parents` below returns `common_base_dir`, too.
3003         common_base /= "fake-file"
3004     for directory in common_base.parents:
3005         if (directory / ".git").is_dir():
3006             return directory
3007
3008         if (directory / ".hg").is_dir():
3009             return directory
3010
3011         if (directory / "pyproject.toml").is_file():
3012             return directory
3013
3014     return directory
3015
3016
3017 @dataclass
3018 class Report:
3019     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3020
3021     check: bool = False
3022     quiet: bool = False
3023     verbose: bool = False
3024     change_count: int = 0
3025     same_count: int = 0
3026     failure_count: int = 0
3027
3028     def done(self, src: Path, changed: Changed) -> None:
3029         """Increment the counter for successful reformatting. Write out a message."""
3030         if changed is Changed.YES:
3031             reformatted = "would reformat" if self.check else "reformatted"
3032             if self.verbose or not self.quiet:
3033                 out(f"{reformatted} {src}")
3034             self.change_count += 1
3035         else:
3036             if self.verbose:
3037                 if changed is Changed.NO:
3038                     msg = f"{src} already well formatted, good job."
3039                 else:
3040                     msg = f"{src} wasn't modified on disk since last run."
3041                 out(msg, bold=False)
3042             self.same_count += 1
3043
3044     def failed(self, src: Path, message: str) -> None:
3045         """Increment the counter for failed reformatting. Write out a message."""
3046         err(f"error: cannot format {src}: {message}")
3047         self.failure_count += 1
3048
3049     def path_ignored(self, path: Path, message: str) -> None:
3050         if self.verbose:
3051             out(f"{path} ignored: {message}", bold=False)
3052
3053     @property
3054     def return_code(self) -> int:
3055         """Return the exit code that the app should use.
3056
3057         This considers the current state of changed files and failures:
3058         - if there were any failures, return 123;
3059         - if any files were changed and --check is being used, return 1;
3060         - otherwise return 0.
3061         """
3062         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3063         # 126 we have special returncodes reserved by the shell.
3064         if self.failure_count:
3065             return 123
3066
3067         elif self.change_count and self.check:
3068             return 1
3069
3070         return 0
3071
3072     def __str__(self) -> str:
3073         """Render a color report of the current state.
3074
3075         Use `click.unstyle` to remove colors.
3076         """
3077         if self.check:
3078             reformatted = "would be reformatted"
3079             unchanged = "would be left unchanged"
3080             failed = "would fail to reformat"
3081         else:
3082             reformatted = "reformatted"
3083             unchanged = "left unchanged"
3084             failed = "failed to reformat"
3085         report = []
3086         if self.change_count:
3087             s = "s" if self.change_count > 1 else ""
3088             report.append(
3089                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3090             )
3091         if self.same_count:
3092             s = "s" if self.same_count > 1 else ""
3093             report.append(f"{self.same_count} file{s} {unchanged}")
3094         if self.failure_count:
3095             s = "s" if self.failure_count > 1 else ""
3096             report.append(
3097                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3098             )
3099         return ", ".join(report) + "."
3100
3101
3102 def assert_equivalent(src: str, dst: str) -> None:
3103     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3104
3105     import ast
3106     import traceback
3107
3108     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3109         """Simple visitor generating strings to compare ASTs by content."""
3110         yield f"{'  ' * depth}{node.__class__.__name__}("
3111
3112         for field in sorted(node._fields):
3113             try:
3114                 value = getattr(node, field)
3115             except AttributeError:
3116                 continue
3117
3118             yield f"{'  ' * (depth+1)}{field}="
3119
3120             if isinstance(value, list):
3121                 for item in value:
3122                     if isinstance(item, ast.AST):
3123                         yield from _v(item, depth + 2)
3124
3125             elif isinstance(value, ast.AST):
3126                 yield from _v(value, depth + 2)
3127
3128             else:
3129                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3130
3131         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3132
3133     try:
3134         src_ast = ast.parse(src)
3135     except Exception as exc:
3136         major, minor = sys.version_info[:2]
3137         raise AssertionError(
3138             f"cannot use --safe with this file; failed to parse source file "
3139             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3140             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3141         )
3142
3143     try:
3144         dst_ast = ast.parse(dst)
3145     except Exception as exc:
3146         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3147         raise AssertionError(
3148             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3149             f"Please report a bug on https://github.com/ambv/black/issues.  "
3150             f"This invalid output might be helpful: {log}"
3151         ) from None
3152
3153     src_ast_str = "\n".join(_v(src_ast))
3154     dst_ast_str = "\n".join(_v(dst_ast))
3155     if src_ast_str != dst_ast_str:
3156         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3157         raise AssertionError(
3158             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3159             f"the source.  "
3160             f"Please report a bug on https://github.com/ambv/black/issues.  "
3161             f"This diff might be helpful: {log}"
3162         ) from None
3163
3164
3165 def assert_stable(
3166     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3167 ) -> None:
3168     """Raise AssertionError if `dst` reformats differently the second time."""
3169     newdst = format_str(dst, line_length=line_length, mode=mode)
3170     if dst != newdst:
3171         log = dump_to_file(
3172             diff(src, dst, "source", "first pass"),
3173             diff(dst, newdst, "first pass", "second pass"),
3174         )
3175         raise AssertionError(
3176             f"INTERNAL ERROR: Black produced different code on the second pass "
3177             f"of the formatter.  "
3178             f"Please report a bug on https://github.com/ambv/black/issues.  "
3179             f"This diff might be helpful: {log}"
3180         ) from None
3181
3182
3183 def dump_to_file(*output: str) -> str:
3184     """Dump `output` to a temporary file. Return path to the file."""
3185     import tempfile
3186
3187     with tempfile.NamedTemporaryFile(
3188         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3189     ) as f:
3190         for lines in output:
3191             f.write(lines)
3192             if lines and lines[-1] != "\n":
3193                 f.write("\n")
3194     return f.name
3195
3196
3197 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3198     """Return a unified diff string between strings `a` and `b`."""
3199     import difflib
3200
3201     a_lines = [line + "\n" for line in a.split("\n")]
3202     b_lines = [line + "\n" for line in b.split("\n")]
3203     return "".join(
3204         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3205     )
3206
3207
3208 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3209     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3210     err("Aborted!")
3211     for task in tasks:
3212         task.cancel()
3213
3214
3215 def shutdown(loop: BaseEventLoop) -> None:
3216     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3217     try:
3218         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3219         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3220         if not to_cancel:
3221             return
3222
3223         for task in to_cancel:
3224             task.cancel()
3225         loop.run_until_complete(
3226             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3227         )
3228     finally:
3229         # `concurrent.futures.Future` objects cannot be cancelled once they
3230         # are already running. There might be some when the `shutdown()` happened.
3231         # Silence their logger's spew about the event loop being closed.
3232         cf_logger = logging.getLogger("concurrent.futures")
3233         cf_logger.setLevel(logging.CRITICAL)
3234         loop.close()
3235
3236
3237 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3238     """Replace `regex` with `replacement` twice on `original`.
3239
3240     This is used by string normalization to perform replaces on
3241     overlapping matches.
3242     """
3243     return regex.sub(replacement, regex.sub(replacement, original))
3244
3245
3246 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3247     """Compile a regular expression string in `regex`.
3248
3249     If it contains newlines, use verbose mode.
3250     """
3251     if "\n" in regex:
3252         regex = "(?x)" + regex
3253     return re.compile(regex)
3254
3255
3256 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3257     """Like `reversed(enumerate(sequence))` if that were possible."""
3258     index = len(sequence) - 1
3259     for element in reversed(sequence):
3260         yield (index, element)
3261         index -= 1
3262
3263
3264 def enumerate_with_length(
3265     line: Line, reversed: bool = False
3266 ) -> Iterator[Tuple[Index, Leaf, int]]:
3267     """Return an enumeration of leaves with their length.
3268
3269     Stops prematurely on multiline strings and standalone comments.
3270     """
3271     op = cast(
3272         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3273         enumerate_reversed if reversed else enumerate,
3274     )
3275     for index, leaf in op(line.leaves):
3276         length = len(leaf.prefix) + len(leaf.value)
3277         if "\n" in leaf.value:
3278             return  # Multiline strings, we can't continue.
3279
3280         comment: Optional[Leaf]
3281         for comment in line.comments_after(leaf, index):
3282             length += len(comment.value)
3283
3284         yield index, leaf, length
3285
3286
3287 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3288     """Return True if `line` is no longer than `line_length`.
3289
3290     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3291     """
3292     if not line_str:
3293         line_str = str(line).strip("\n")
3294     return (
3295         len(line_str) <= line_length
3296         and "\n" not in line_str  # multiline strings
3297         and not line.contains_standalone_comments()
3298     )
3299
3300
3301 def can_be_split(line: Line) -> bool:
3302     """Return False if the line cannot be split *for sure*.
3303
3304     This is not an exhaustive search but a cheap heuristic that we can use to
3305     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3306     in unnecessary parentheses).
3307     """
3308     leaves = line.leaves
3309     if len(leaves) < 2:
3310         return False
3311
3312     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3313         call_count = 0
3314         dot_count = 0
3315         next = leaves[-1]
3316         for leaf in leaves[-2::-1]:
3317             if leaf.type in OPENING_BRACKETS:
3318                 if next.type not in CLOSING_BRACKETS:
3319                     return False
3320
3321                 call_count += 1
3322             elif leaf.type == token.DOT:
3323                 dot_count += 1
3324             elif leaf.type == token.NAME:
3325                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3326                     return False
3327
3328             elif leaf.type not in CLOSING_BRACKETS:
3329                 return False
3330
3331             if dot_count > 1 and call_count > 1:
3332                 return False
3333
3334     return True
3335
3336
3337 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3338     """Does `line` have a shape safe to reformat without optional parens around it?
3339
3340     Returns True for only a subset of potentially nice looking formattings but
3341     the point is to not return false positives that end up producing lines that
3342     are too long.
3343     """
3344     bt = line.bracket_tracker
3345     if not bt.delimiters:
3346         # Without delimiters the optional parentheses are useless.
3347         return True
3348
3349     max_priority = bt.max_delimiter_priority()
3350     if bt.delimiter_count_with_priority(max_priority) > 1:
3351         # With more than one delimiter of a kind the optional parentheses read better.
3352         return False
3353
3354     if max_priority == DOT_PRIORITY:
3355         # A single stranded method call doesn't require optional parentheses.
3356         return True
3357
3358     assert len(line.leaves) >= 2, "Stranded delimiter"
3359
3360     first = line.leaves[0]
3361     second = line.leaves[1]
3362     penultimate = line.leaves[-2]
3363     last = line.leaves[-1]
3364
3365     # With a single delimiter, omit if the expression starts or ends with
3366     # a bracket.
3367     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3368         remainder = False
3369         length = 4 * line.depth
3370         for _index, leaf, leaf_length in enumerate_with_length(line):
3371             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3372                 remainder = True
3373             if remainder:
3374                 length += leaf_length
3375                 if length > line_length:
3376                     break
3377
3378                 if leaf.type in OPENING_BRACKETS:
3379                     # There are brackets we can further split on.
3380                     remainder = False
3381
3382         else:
3383             # checked the entire string and line length wasn't exceeded
3384             if len(line.leaves) == _index + 1:
3385                 return True
3386
3387         # Note: we are not returning False here because a line might have *both*
3388         # a leading opening bracket and a trailing closing bracket.  If the
3389         # opening bracket doesn't match our rule, maybe the closing will.
3390
3391     if (
3392         last.type == token.RPAR
3393         or last.type == token.RBRACE
3394         or (
3395             # don't use indexing for omitting optional parentheses;
3396             # it looks weird
3397             last.type == token.RSQB
3398             and last.parent
3399             and last.parent.type != syms.trailer
3400         )
3401     ):
3402         if penultimate.type in OPENING_BRACKETS:
3403             # Empty brackets don't help.
3404             return False
3405
3406         if is_multiline_string(first):
3407             # Additional wrapping of a multiline string in this situation is
3408             # unnecessary.
3409             return True
3410
3411         length = 4 * line.depth
3412         seen_other_brackets = False
3413         for _index, leaf, leaf_length in enumerate_with_length(line):
3414             length += leaf_length
3415             if leaf is last.opening_bracket:
3416                 if seen_other_brackets or length <= line_length:
3417                     return True
3418
3419             elif leaf.type in OPENING_BRACKETS:
3420                 # There are brackets we can further split on.
3421                 seen_other_brackets = True
3422
3423     return False
3424
3425
3426 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3427     return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
3428
3429
3430 def read_cache(line_length: int, mode: FileMode) -> Cache:
3431     """Read the cache if it exists and is well formed.
3432
3433     If it is not well formed, the call to write_cache later should resolve the issue.
3434     """
3435     cache_file = get_cache_file(line_length, mode)
3436     if not cache_file.exists():
3437         return {}
3438
3439     with cache_file.open("rb") as fobj:
3440         try:
3441             cache: Cache = pickle.load(fobj)
3442         except pickle.UnpicklingError:
3443             return {}
3444
3445     return cache
3446
3447
3448 def get_cache_info(path: Path) -> CacheInfo:
3449     """Return the information used to check if a file is already formatted or not."""
3450     stat = path.stat()
3451     return stat.st_mtime, stat.st_size
3452
3453
3454 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3455     """Split an iterable of paths in `sources` into two sets.
3456
3457     The first contains paths of files that modified on disk or are not in the
3458     cache. The other contains paths to non-modified files.
3459     """
3460     todo, done = set(), set()
3461     for src in sources:
3462         src = src.resolve()
3463         if cache.get(src) != get_cache_info(src):
3464             todo.add(src)
3465         else:
3466             done.add(src)
3467     return todo, done
3468
3469
3470 def write_cache(
3471     cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
3472 ) -> None:
3473     """Update the cache file."""
3474     cache_file = get_cache_file(line_length, mode)
3475     try:
3476         if not CACHE_DIR.exists():
3477             CACHE_DIR.mkdir(parents=True)
3478         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3479         with cache_file.open("wb") as fobj:
3480             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3481     except OSError:
3482         pass
3483
3484
3485 def patch_click() -> None:
3486     """Make Click not crash.
3487
3488     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3489     default which restricts paths that it can access during the lifetime of the
3490     application.  Click refuses to work in this scenario by raising a RuntimeError.
3491
3492     In case of Black the likelihood that non-ASCII characters are going to be used in
3493     file paths is minimal since it's Python source code.  Moreover, this crash was
3494     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3495     """
3496     try:
3497         from click import core
3498         from click import _unicodefun  # type: ignore
3499     except ModuleNotFoundError:
3500         return
3501
3502     for module in (core, _unicodefun):
3503         if hasattr(module, "_verify_python3_env"):
3504             module._verify_python3_env = lambda: None
3505
3506
3507 if __name__ == "__main__":
3508     patch_click()
3509     main()