]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Suggest BufWritePre instead of BufWritePost for vi (#376)
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum, Flag
6 from functools import lru_cache, partial, wraps
7 import io
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tokenize
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Sequence,
30     Set,
31     Tuple,
32     TypeVar,
33     Union,
34     cast,
35 )
36
37 from appdirs import user_cache_dir
38 from attr import dataclass, Factory
39 import click
40 import toml
41
42 # lib2to3 fork
43 from blib2to3.pytree import Node, Leaf, type_repr
44 from blib2to3 import pygram, pytree
45 from blib2to3.pgen2 import driver, token
46 from blib2to3.pgen2.parse import ParseError
47
48
49 __version__ = "18.6b4"
50 DEFAULT_LINE_LENGTH = 88
51 DEFAULT_EXCLUDES = (
52     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
53 )
54 DEFAULT_INCLUDES = r"\.pyi?$"
55 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
56
57
58 # types
59 FileContent = str
60 Encoding = str
61 NewLine = str
62 Depth = int
63 NodeType = int
64 LeafID = int
65 Priority = int
66 Index = int
67 LN = Union[Leaf, Node]
68 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
69 Timestamp = float
70 FileSize = int
71 CacheInfo = Tuple[Timestamp, FileSize]
72 Cache = Dict[Path, CacheInfo]
73 out = partial(click.secho, bold=True, err=True)
74 err = partial(click.secho, fg="red", err=True)
75
76 pygram.initialize(CACHE_DIR)
77 syms = pygram.python_symbols
78
79
80 class NothingChanged(UserWarning):
81     """Raised by :func:`format_file` when reformatted code is the same as source."""
82
83
84 class CannotSplit(Exception):
85     """A readable split that fits the allotted line length is impossible.
86
87     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
88     :func:`delimiter_split`.
89     """
90
91
92 class WriteBack(Enum):
93     NO = 0
94     YES = 1
95     DIFF = 2
96
97     @classmethod
98     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
99         if check and not diff:
100             return cls.NO
101
102         return cls.DIFF if diff else cls.YES
103
104
105 class Changed(Enum):
106     NO = 0
107     CACHED = 1
108     YES = 2
109
110
111 class FileMode(Flag):
112     AUTO_DETECT = 0
113     PYTHON36 = 1
114     PYI = 2
115     NO_STRING_NORMALIZATION = 4
116
117     @classmethod
118     def from_configuration(
119         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
120     ) -> "FileMode":
121         mode = cls.AUTO_DETECT
122         if py36:
123             mode |= cls.PYTHON36
124         if pyi:
125             mode |= cls.PYI
126         if skip_string_normalization:
127             mode |= cls.NO_STRING_NORMALIZATION
128         return mode
129
130
131 def read_pyproject_toml(
132     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
133 ) -> Optional[str]:
134     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
135
136     Returns the path to a successfully found and read configuration file, None
137     otherwise.
138     """
139     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
140     if not value:
141         root = find_project_root(ctx.params.get("src", ()))
142         path = root / "pyproject.toml"
143         if path.is_file():
144             value = str(path)
145         else:
146             return None
147
148     try:
149         pyproject_toml = toml.load(value)
150         config = pyproject_toml.get("tool", {}).get("black", {})
151     except (toml.TomlDecodeError, OSError) as e:
152         raise click.BadOptionUsage(f"Error reading configuration file: {e}", ctx)
153
154     if not config:
155         return None
156
157     if ctx.default_map is None:
158         ctx.default_map = {}
159     ctx.default_map.update(  # type: ignore  # bad types in .pyi
160         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
161     )
162     return value
163
164
165 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
166 @click.option(
167     "-l",
168     "--line-length",
169     type=int,
170     default=DEFAULT_LINE_LENGTH,
171     help="How many character per line to allow.",
172     show_default=True,
173 )
174 @click.option(
175     "--py36",
176     is_flag=True,
177     help=(
178         "Allow using Python 3.6-only syntax on all input files.  This will put "
179         "trailing commas in function signatures and calls also after *args and "
180         "**kwargs.  [default: per-file auto-detection]"
181     ),
182 )
183 @click.option(
184     "--pyi",
185     is_flag=True,
186     help=(
187         "Format all input files like typing stubs regardless of file extension "
188         "(useful when piping source on standard input)."
189     ),
190 )
191 @click.option(
192     "-S",
193     "--skip-string-normalization",
194     is_flag=True,
195     help="Don't normalize string quotes or prefixes.",
196 )
197 @click.option(
198     "--check",
199     is_flag=True,
200     help=(
201         "Don't write the files back, just return the status.  Return code 0 "
202         "means nothing would change.  Return code 1 means some files would be "
203         "reformatted.  Return code 123 means there was an internal error."
204     ),
205 )
206 @click.option(
207     "--diff",
208     is_flag=True,
209     help="Don't write the files back, just output a diff for each file on stdout.",
210 )
211 @click.option(
212     "--fast/--safe",
213     is_flag=True,
214     help="If --fast given, skip temporary sanity checks. [default: --safe]",
215 )
216 @click.option(
217     "--include",
218     type=str,
219     default=DEFAULT_INCLUDES,
220     help=(
221         "A regular expression that matches files and directories that should be "
222         "included on recursive searches.  An empty value means all files are "
223         "included regardless of the name.  Use forward slashes for directories on "
224         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
225         "later."
226     ),
227     show_default=True,
228 )
229 @click.option(
230     "--exclude",
231     type=str,
232     default=DEFAULT_EXCLUDES,
233     help=(
234         "A regular expression that matches files and directories that should be "
235         "excluded on recursive searches.  An empty value means no paths are excluded. "
236         "Use forward slashes for directories on all platforms (Windows, too).  "
237         "Exclusions are calculated first, inclusions later."
238     ),
239     show_default=True,
240 )
241 @click.option(
242     "-q",
243     "--quiet",
244     is_flag=True,
245     help=(
246         "Don't emit non-error messages to stderr. Errors are still emitted, "
247         "silence those with 2>/dev/null."
248     ),
249 )
250 @click.option(
251     "-v",
252     "--verbose",
253     is_flag=True,
254     help=(
255         "Also emit messages to stderr about files that were not changed or were "
256         "ignored due to --exclude=."
257     ),
258 )
259 @click.version_option(version=__version__)
260 @click.argument(
261     "src",
262     nargs=-1,
263     type=click.Path(
264         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
265     ),
266     is_eager=True,
267 )
268 @click.option(
269     "--config",
270     type=click.Path(
271         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
272     ),
273     is_eager=True,
274     callback=read_pyproject_toml,
275     help="Read configuration from PATH.",
276 )
277 @click.pass_context
278 def main(
279     ctx: click.Context,
280     line_length: int,
281     check: bool,
282     diff: bool,
283     fast: bool,
284     pyi: bool,
285     py36: bool,
286     skip_string_normalization: bool,
287     quiet: bool,
288     verbose: bool,
289     include: str,
290     exclude: str,
291     src: Tuple[str],
292     config: Optional[str],
293 ) -> None:
294     """The uncompromising code formatter."""
295     write_back = WriteBack.from_configuration(check=check, diff=diff)
296     mode = FileMode.from_configuration(
297         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
298     )
299     if config and verbose:
300         out(f"Using configuration from {config}.", bold=False, fg="blue")
301     try:
302         include_regex = re_compile_maybe_verbose(include)
303     except re.error:
304         err(f"Invalid regular expression for include given: {include!r}")
305         ctx.exit(2)
306     try:
307         exclude_regex = re_compile_maybe_verbose(exclude)
308     except re.error:
309         err(f"Invalid regular expression for exclude given: {exclude!r}")
310         ctx.exit(2)
311     report = Report(check=check, quiet=quiet, verbose=verbose)
312     root = find_project_root(src)
313     sources: Set[Path] = set()
314     for s in src:
315         p = Path(s)
316         if p.is_dir():
317             sources.update(
318                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
319             )
320         elif p.is_file() or s == "-":
321             # if a file was explicitly given, we don't care about its extension
322             sources.add(p)
323         else:
324             err(f"invalid path: {s}")
325     if len(sources) == 0:
326         if verbose or not quiet:
327             out("No paths given. Nothing to do 😴")
328         ctx.exit(0)
329
330     if len(sources) == 1:
331         reformat_one(
332             src=sources.pop(),
333             line_length=line_length,
334             fast=fast,
335             write_back=write_back,
336             mode=mode,
337             report=report,
338         )
339     else:
340         loop = asyncio.get_event_loop()
341         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
342         try:
343             loop.run_until_complete(
344                 schedule_formatting(
345                     sources=sources,
346                     line_length=line_length,
347                     fast=fast,
348                     write_back=write_back,
349                     mode=mode,
350                     report=report,
351                     loop=loop,
352                     executor=executor,
353                 )
354             )
355         finally:
356             shutdown(loop)
357     if verbose or not quiet:
358         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
359         out(f"All done! {bang}")
360         click.secho(str(report), err=True)
361     ctx.exit(report.return_code)
362
363
364 def reformat_one(
365     src: Path,
366     line_length: int,
367     fast: bool,
368     write_back: WriteBack,
369     mode: FileMode,
370     report: "Report",
371 ) -> None:
372     """Reformat a single file under `src` without spawning child processes.
373
374     If `quiet` is True, non-error messages are not output. `line_length`,
375     `write_back`, `fast` and `pyi` options are passed to
376     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
377     """
378     try:
379         changed = Changed.NO
380         if not src.is_file() and str(src) == "-":
381             if format_stdin_to_stdout(
382                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
383             ):
384                 changed = Changed.YES
385         else:
386             cache: Cache = {}
387             if write_back != WriteBack.DIFF:
388                 cache = read_cache(line_length, mode)
389                 res_src = src.resolve()
390                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
391                     changed = Changed.CACHED
392             if changed is not Changed.CACHED and format_file_in_place(
393                 src,
394                 line_length=line_length,
395                 fast=fast,
396                 write_back=write_back,
397                 mode=mode,
398             ):
399                 changed = Changed.YES
400             if write_back == WriteBack.YES and changed is not Changed.NO:
401                 write_cache(cache, [src], line_length, mode)
402         report.done(src, changed)
403     except Exception as exc:
404         report.failed(src, str(exc))
405
406
407 async def schedule_formatting(
408     sources: Set[Path],
409     line_length: int,
410     fast: bool,
411     write_back: WriteBack,
412     mode: FileMode,
413     report: "Report",
414     loop: BaseEventLoop,
415     executor: Executor,
416 ) -> None:
417     """Run formatting of `sources` in parallel using the provided `executor`.
418
419     (Use ProcessPoolExecutors for actual parallelism.)
420
421     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
422     :func:`format_file_in_place`.
423     """
424     cache: Cache = {}
425     if write_back != WriteBack.DIFF:
426         cache = read_cache(line_length, mode)
427         sources, cached = filter_cached(cache, sources)
428         for src in sorted(cached):
429             report.done(src, Changed.CACHED)
430     cancelled = []
431     formatted = []
432     if sources:
433         lock = None
434         if write_back == WriteBack.DIFF:
435             # For diff output, we need locks to ensure we don't interleave output
436             # from different processes.
437             manager = Manager()
438             lock = manager.Lock()
439         tasks = {
440             loop.run_in_executor(
441                 executor,
442                 format_file_in_place,
443                 src,
444                 line_length,
445                 fast,
446                 write_back,
447                 mode,
448                 lock,
449             ): src
450             for src in sorted(sources)
451         }
452         pending: Iterable[asyncio.Task] = tasks.keys()
453         try:
454             loop.add_signal_handler(signal.SIGINT, cancel, pending)
455             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
456         except NotImplementedError:
457             # There are no good alternatives for these on Windows
458             pass
459         while pending:
460             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
461             for task in done:
462                 src = tasks.pop(task)
463                 if task.cancelled():
464                     cancelled.append(task)
465                 elif task.exception():
466                     report.failed(src, str(task.exception()))
467                 else:
468                     formatted.append(src)
469                     report.done(src, Changed.YES if task.result() else Changed.NO)
470     if cancelled:
471         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
472     if write_back == WriteBack.YES and formatted:
473         write_cache(cache, formatted, line_length, mode)
474
475
476 def format_file_in_place(
477     src: Path,
478     line_length: int,
479     fast: bool,
480     write_back: WriteBack = WriteBack.NO,
481     mode: FileMode = FileMode.AUTO_DETECT,
482     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
483 ) -> bool:
484     """Format file under `src` path. Return True if changed.
485
486     If `write_back` is True, write reformatted code back to stdout.
487     `line_length` and `fast` options are passed to :func:`format_file_contents`.
488     """
489     if src.suffix == ".pyi":
490         mode |= FileMode.PYI
491
492     then = datetime.utcfromtimestamp(src.stat().st_mtime)
493     with open(src, "rb") as buf:
494         src_contents, encoding, newline = decode_bytes(buf.read())
495     try:
496         dst_contents = format_file_contents(
497             src_contents, line_length=line_length, fast=fast, mode=mode
498         )
499     except NothingChanged:
500         return False
501
502     if write_back == write_back.YES:
503         with open(src, "w", encoding=encoding, newline=newline) as f:
504             f.write(dst_contents)
505     elif write_back == write_back.DIFF:
506         now = datetime.utcnow()
507         src_name = f"{src}\t{then} +0000"
508         dst_name = f"{src}\t{now} +0000"
509         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
510         if lock:
511             lock.acquire()
512         try:
513             f = io.TextIOWrapper(
514                 sys.stdout.buffer,
515                 encoding=encoding,
516                 newline=newline,
517                 write_through=True,
518             )
519             f.write(diff_contents)
520             f.detach()
521         finally:
522             if lock:
523                 lock.release()
524     return True
525
526
527 def format_stdin_to_stdout(
528     line_length: int,
529     fast: bool,
530     write_back: WriteBack = WriteBack.NO,
531     mode: FileMode = FileMode.AUTO_DETECT,
532 ) -> bool:
533     """Format file on stdin. Return True if changed.
534
535     If `write_back` is True, write reformatted code back to stdout.
536     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
537     :func:`format_file_contents`.
538     """
539     then = datetime.utcnow()
540     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
541     dst = src
542     try:
543         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
544         return True
545
546     except NothingChanged:
547         return False
548
549     finally:
550         f = io.TextIOWrapper(
551             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
552         )
553         if write_back == WriteBack.YES:
554             f.write(dst)
555         elif write_back == WriteBack.DIFF:
556             now = datetime.utcnow()
557             src_name = f"STDIN\t{then} +0000"
558             dst_name = f"STDOUT\t{now} +0000"
559             f.write(diff(src, dst, src_name, dst_name))
560         f.detach()
561
562
563 def format_file_contents(
564     src_contents: str,
565     *,
566     line_length: int,
567     fast: bool,
568     mode: FileMode = FileMode.AUTO_DETECT,
569 ) -> FileContent:
570     """Reformat contents a file and return new contents.
571
572     If `fast` is False, additionally confirm that the reformatted code is
573     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
574     `line_length` is passed to :func:`format_str`.
575     """
576     if src_contents.strip() == "":
577         raise NothingChanged
578
579     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
580     if src_contents == dst_contents:
581         raise NothingChanged
582
583     if not fast:
584         assert_equivalent(src_contents, dst_contents)
585         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
586     return dst_contents
587
588
589 def format_str(
590     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
591 ) -> FileContent:
592     """Reformat a string and return new contents.
593
594     `line_length` determines how many characters per line are allowed.
595     """
596     src_node = lib2to3_parse(src_contents)
597     dst_contents = ""
598     future_imports = get_future_imports(src_node)
599     is_pyi = bool(mode & FileMode.PYI)
600     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
601     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
602     normalize_fmt_off(src_node)
603     lines = LineGenerator(
604         remove_u_prefix=py36 or "unicode_literals" in future_imports,
605         is_pyi=is_pyi,
606         normalize_strings=normalize_strings,
607     )
608     elt = EmptyLineTracker(is_pyi=is_pyi)
609     empty_line = Line()
610     after = 0
611     for current_line in lines.visit(src_node):
612         for _ in range(after):
613             dst_contents += str(empty_line)
614         before, after = elt.maybe_empty_lines(current_line)
615         for _ in range(before):
616             dst_contents += str(empty_line)
617         for line in split_line(current_line, line_length=line_length, py36=py36):
618             dst_contents += str(line)
619     return dst_contents
620
621
622 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
623     """Return a tuple of (decoded_contents, encoding, newline).
624
625     `newline` is either CRLF or LF but `decoded_contents` is decoded with
626     universal newlines (i.e. only contains LF).
627     """
628     srcbuf = io.BytesIO(src)
629     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
630     if not lines:
631         return "", encoding, "\n"
632
633     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
634     srcbuf.seek(0)
635     with io.TextIOWrapper(srcbuf, encoding) as tiow:
636         return tiow.read(), encoding, newline
637
638
639 GRAMMARS = [
640     pygram.python_grammar_no_print_statement_no_exec_statement,
641     pygram.python_grammar_no_print_statement,
642     pygram.python_grammar,
643 ]
644
645
646 def lib2to3_parse(src_txt: str) -> Node:
647     """Given a string with source, return the lib2to3 Node."""
648     grammar = pygram.python_grammar_no_print_statement
649     if src_txt[-1:] != "\n":
650         src_txt += "\n"
651     for grammar in GRAMMARS:
652         drv = driver.Driver(grammar, pytree.convert)
653         try:
654             result = drv.parse_string(src_txt, True)
655             break
656
657         except ParseError as pe:
658             lineno, column = pe.context[1]
659             lines = src_txt.splitlines()
660             try:
661                 faulty_line = lines[lineno - 1]
662             except IndexError:
663                 faulty_line = "<line number missing in source>"
664             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
665     else:
666         raise exc from None
667
668     if isinstance(result, Leaf):
669         result = Node(syms.file_input, [result])
670     return result
671
672
673 def lib2to3_unparse(node: Node) -> str:
674     """Given a lib2to3 node, return its string representation."""
675     code = str(node)
676     return code
677
678
679 T = TypeVar("T")
680
681
682 class Visitor(Generic[T]):
683     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
684
685     def visit(self, node: LN) -> Iterator[T]:
686         """Main method to visit `node` and its children.
687
688         It tries to find a `visit_*()` method for the given `node.type`, like
689         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
690         If no dedicated `visit_*()` method is found, chooses `visit_default()`
691         instead.
692
693         Then yields objects of type `T` from the selected visitor.
694         """
695         if node.type < 256:
696             name = token.tok_name[node.type]
697         else:
698             name = type_repr(node.type)
699         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
700
701     def visit_default(self, node: LN) -> Iterator[T]:
702         """Default `visit_*()` implementation. Recurses to children of `node`."""
703         if isinstance(node, Node):
704             for child in node.children:
705                 yield from self.visit(child)
706
707
708 @dataclass
709 class DebugVisitor(Visitor[T]):
710     tree_depth: int = 0
711
712     def visit_default(self, node: LN) -> Iterator[T]:
713         indent = " " * (2 * self.tree_depth)
714         if isinstance(node, Node):
715             _type = type_repr(node.type)
716             out(f"{indent}{_type}", fg="yellow")
717             self.tree_depth += 1
718             for child in node.children:
719                 yield from self.visit(child)
720
721             self.tree_depth -= 1
722             out(f"{indent}/{_type}", fg="yellow", bold=False)
723         else:
724             _type = token.tok_name.get(node.type, str(node.type))
725             out(f"{indent}{_type}", fg="blue", nl=False)
726             if node.prefix:
727                 # We don't have to handle prefixes for `Node` objects since
728                 # that delegates to the first child anyway.
729                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
730             out(f" {node.value!r}", fg="blue", bold=False)
731
732     @classmethod
733     def show(cls, code: Union[str, Leaf, Node]) -> None:
734         """Pretty-print the lib2to3 AST of a given string of `code`.
735
736         Convenience method for debugging.
737         """
738         v: DebugVisitor[None] = DebugVisitor()
739         if isinstance(code, str):
740             code = lib2to3_parse(code)
741         list(v.visit(code))
742
743
744 KEYWORDS = set(keyword.kwlist)
745 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
746 FLOW_CONTROL = {"return", "raise", "break", "continue"}
747 STATEMENT = {
748     syms.if_stmt,
749     syms.while_stmt,
750     syms.for_stmt,
751     syms.try_stmt,
752     syms.except_clause,
753     syms.with_stmt,
754     syms.funcdef,
755     syms.classdef,
756 }
757 STANDALONE_COMMENT = 153
758 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
759 LOGIC_OPERATORS = {"and", "or"}
760 COMPARATORS = {
761     token.LESS,
762     token.GREATER,
763     token.EQEQUAL,
764     token.NOTEQUAL,
765     token.LESSEQUAL,
766     token.GREATEREQUAL,
767 }
768 MATH_OPERATORS = {
769     token.VBAR,
770     token.CIRCUMFLEX,
771     token.AMPER,
772     token.LEFTSHIFT,
773     token.RIGHTSHIFT,
774     token.PLUS,
775     token.MINUS,
776     token.STAR,
777     token.SLASH,
778     token.DOUBLESLASH,
779     token.PERCENT,
780     token.AT,
781     token.TILDE,
782     token.DOUBLESTAR,
783 }
784 STARS = {token.STAR, token.DOUBLESTAR}
785 VARARGS_PARENTS = {
786     syms.arglist,
787     syms.argument,  # double star in arglist
788     syms.trailer,  # single argument to call
789     syms.typedargslist,
790     syms.varargslist,  # lambdas
791 }
792 UNPACKING_PARENTS = {
793     syms.atom,  # single element of a list or set literal
794     syms.dictsetmaker,
795     syms.listmaker,
796     syms.testlist_gexp,
797     syms.testlist_star_expr,
798 }
799 SURROUNDED_BY_BRACKETS = {
800     syms.typedargslist,
801     syms.arglist,
802     syms.subscriptlist,
803     syms.vfplist,
804     syms.import_as_names,
805     syms.yield_expr,
806     syms.testlist_gexp,
807     syms.testlist_star_expr,
808     syms.listmaker,
809     syms.dictsetmaker,
810 }
811 TEST_DESCENDANTS = {
812     syms.test,
813     syms.lambdef,
814     syms.or_test,
815     syms.and_test,
816     syms.not_test,
817     syms.comparison,
818     syms.star_expr,
819     syms.expr,
820     syms.xor_expr,
821     syms.and_expr,
822     syms.shift_expr,
823     syms.arith_expr,
824     syms.trailer,
825     syms.term,
826     syms.power,
827 }
828 ASSIGNMENTS = {
829     "=",
830     "+=",
831     "-=",
832     "*=",
833     "@=",
834     "/=",
835     "%=",
836     "&=",
837     "|=",
838     "^=",
839     "<<=",
840     ">>=",
841     "**=",
842     "//=",
843 }
844 COMPREHENSION_PRIORITY = 20
845 COMMA_PRIORITY = 18
846 TERNARY_PRIORITY = 16
847 LOGIC_PRIORITY = 14
848 STRING_PRIORITY = 12
849 COMPARATOR_PRIORITY = 10
850 MATH_PRIORITIES = {
851     token.VBAR: 9,
852     token.CIRCUMFLEX: 8,
853     token.AMPER: 7,
854     token.LEFTSHIFT: 6,
855     token.RIGHTSHIFT: 6,
856     token.PLUS: 5,
857     token.MINUS: 5,
858     token.STAR: 4,
859     token.SLASH: 4,
860     token.DOUBLESLASH: 4,
861     token.PERCENT: 4,
862     token.AT: 4,
863     token.TILDE: 3,
864     token.DOUBLESTAR: 2,
865 }
866 DOT_PRIORITY = 1
867
868
869 @dataclass
870 class BracketTracker:
871     """Keeps track of brackets on a line."""
872
873     depth: int = 0
874     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
875     delimiters: Dict[LeafID, Priority] = Factory(dict)
876     previous: Optional[Leaf] = None
877     _for_loop_variable: int = 0
878     _lambda_arguments: int = 0
879
880     def mark(self, leaf: Leaf) -> None:
881         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
882
883         All leaves receive an int `bracket_depth` field that stores how deep
884         within brackets a given leaf is. 0 means there are no enclosing brackets
885         that started on this line.
886
887         If a leaf is itself a closing bracket, it receives an `opening_bracket`
888         field that it forms a pair with. This is a one-directional link to
889         avoid reference cycles.
890
891         If a leaf is a delimiter (a token on which Black can split the line if
892         needed) and it's on depth 0, its `id()` is stored in the tracker's
893         `delimiters` field.
894         """
895         if leaf.type == token.COMMENT:
896             return
897
898         self.maybe_decrement_after_for_loop_variable(leaf)
899         self.maybe_decrement_after_lambda_arguments(leaf)
900         if leaf.type in CLOSING_BRACKETS:
901             self.depth -= 1
902             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
903             leaf.opening_bracket = opening_bracket
904         leaf.bracket_depth = self.depth
905         if self.depth == 0:
906             delim = is_split_before_delimiter(leaf, self.previous)
907             if delim and self.previous is not None:
908                 self.delimiters[id(self.previous)] = delim
909             else:
910                 delim = is_split_after_delimiter(leaf, self.previous)
911                 if delim:
912                     self.delimiters[id(leaf)] = delim
913         if leaf.type in OPENING_BRACKETS:
914             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
915             self.depth += 1
916         self.previous = leaf
917         self.maybe_increment_lambda_arguments(leaf)
918         self.maybe_increment_for_loop_variable(leaf)
919
920     def any_open_brackets(self) -> bool:
921         """Return True if there is an yet unmatched open bracket on the line."""
922         return bool(self.bracket_match)
923
924     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
925         """Return the highest priority of a delimiter found on the line.
926
927         Values are consistent with what `is_split_*_delimiter()` return.
928         Raises ValueError on no delimiters.
929         """
930         return max(v for k, v in self.delimiters.items() if k not in exclude)
931
932     def delimiter_count_with_priority(self, priority: int = 0) -> int:
933         """Return the number of delimiters with the given `priority`.
934
935         If no `priority` is passed, defaults to max priority on the line.
936         """
937         if not self.delimiters:
938             return 0
939
940         priority = priority or self.max_delimiter_priority()
941         return sum(1 for p in self.delimiters.values() if p == priority)
942
943     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
944         """In a for loop, or comprehension, the variables are often unpacks.
945
946         To avoid splitting on the comma in this situation, increase the depth of
947         tokens between `for` and `in`.
948         """
949         if leaf.type == token.NAME and leaf.value == "for":
950             self.depth += 1
951             self._for_loop_variable += 1
952             return True
953
954         return False
955
956     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
957         """See `maybe_increment_for_loop_variable` above for explanation."""
958         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
959             self.depth -= 1
960             self._for_loop_variable -= 1
961             return True
962
963         return False
964
965     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
966         """In a lambda expression, there might be more than one argument.
967
968         To avoid splitting on the comma in this situation, increase the depth of
969         tokens between `lambda` and `:`.
970         """
971         if leaf.type == token.NAME and leaf.value == "lambda":
972             self.depth += 1
973             self._lambda_arguments += 1
974             return True
975
976         return False
977
978     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
979         """See `maybe_increment_lambda_arguments` above for explanation."""
980         if self._lambda_arguments and leaf.type == token.COLON:
981             self.depth -= 1
982             self._lambda_arguments -= 1
983             return True
984
985         return False
986
987     def get_open_lsqb(self) -> Optional[Leaf]:
988         """Return the most recent opening square bracket (if any)."""
989         return self.bracket_match.get((self.depth - 1, token.RSQB))
990
991
992 @dataclass
993 class Line:
994     """Holds leaves and comments. Can be printed with `str(line)`."""
995
996     depth: int = 0
997     leaves: List[Leaf] = Factory(list)
998     comments: List[Tuple[Index, Leaf]] = Factory(list)
999     bracket_tracker: BracketTracker = Factory(BracketTracker)
1000     inside_brackets: bool = False
1001     should_explode: bool = False
1002
1003     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1004         """Add a new `leaf` to the end of the line.
1005
1006         Unless `preformatted` is True, the `leaf` will receive a new consistent
1007         whitespace prefix and metadata applied by :class:`BracketTracker`.
1008         Trailing commas are maybe removed, unpacked for loop variables are
1009         demoted from being delimiters.
1010
1011         Inline comments are put aside.
1012         """
1013         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1014         if not has_value:
1015             return
1016
1017         if token.COLON == leaf.type and self.is_class_paren_empty:
1018             del self.leaves[-2:]
1019         if self.leaves and not preformatted:
1020             # Note: at this point leaf.prefix should be empty except for
1021             # imports, for which we only preserve newlines.
1022             leaf.prefix += whitespace(
1023                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1024             )
1025         if self.inside_brackets or not preformatted:
1026             self.bracket_tracker.mark(leaf)
1027             self.maybe_remove_trailing_comma(leaf)
1028         if not self.append_comment(leaf):
1029             self.leaves.append(leaf)
1030
1031     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1032         """Like :func:`append()` but disallow invalid standalone comment structure.
1033
1034         Raises ValueError when any `leaf` is appended after a standalone comment
1035         or when a standalone comment is not the first leaf on the line.
1036         """
1037         if self.bracket_tracker.depth == 0:
1038             if self.is_comment:
1039                 raise ValueError("cannot append to standalone comments")
1040
1041             if self.leaves and leaf.type == STANDALONE_COMMENT:
1042                 raise ValueError(
1043                     "cannot append standalone comments to a populated line"
1044                 )
1045
1046         self.append(leaf, preformatted=preformatted)
1047
1048     @property
1049     def is_comment(self) -> bool:
1050         """Is this line a standalone comment?"""
1051         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1052
1053     @property
1054     def is_decorator(self) -> bool:
1055         """Is this line a decorator?"""
1056         return bool(self) and self.leaves[0].type == token.AT
1057
1058     @property
1059     def is_import(self) -> bool:
1060         """Is this an import line?"""
1061         return bool(self) and is_import(self.leaves[0])
1062
1063     @property
1064     def is_class(self) -> bool:
1065         """Is this line a class definition?"""
1066         return (
1067             bool(self)
1068             and self.leaves[0].type == token.NAME
1069             and self.leaves[0].value == "class"
1070         )
1071
1072     @property
1073     def is_stub_class(self) -> bool:
1074         """Is this line a class definition with a body consisting only of "..."?"""
1075         return self.is_class and self.leaves[-3:] == [
1076             Leaf(token.DOT, ".") for _ in range(3)
1077         ]
1078
1079     @property
1080     def is_def(self) -> bool:
1081         """Is this a function definition? (Also returns True for async defs.)"""
1082         try:
1083             first_leaf = self.leaves[0]
1084         except IndexError:
1085             return False
1086
1087         try:
1088             second_leaf: Optional[Leaf] = self.leaves[1]
1089         except IndexError:
1090             second_leaf = None
1091         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1092             first_leaf.type == token.ASYNC
1093             and second_leaf is not None
1094             and second_leaf.type == token.NAME
1095             and second_leaf.value == "def"
1096         )
1097
1098     @property
1099     def is_class_paren_empty(self) -> bool:
1100         """Is this a class with no base classes but using parentheses?
1101
1102         Those are unnecessary and should be removed.
1103         """
1104         return (
1105             bool(self)
1106             and len(self.leaves) == 4
1107             and self.is_class
1108             and self.leaves[2].type == token.LPAR
1109             and self.leaves[2].value == "("
1110             and self.leaves[3].type == token.RPAR
1111             and self.leaves[3].value == ")"
1112         )
1113
1114     @property
1115     def is_triple_quoted_string(self) -> bool:
1116         """Is the line a triple quoted string?"""
1117         return (
1118             bool(self)
1119             and self.leaves[0].type == token.STRING
1120             and self.leaves[0].value.startswith(('"""', "'''"))
1121         )
1122
1123     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1124         """If so, needs to be split before emitting."""
1125         for leaf in self.leaves:
1126             if leaf.type == STANDALONE_COMMENT:
1127                 if leaf.bracket_depth <= depth_limit:
1128                     return True
1129
1130         return False
1131
1132     def contains_multiline_strings(self) -> bool:
1133         for leaf in self.leaves:
1134             if is_multiline_string(leaf):
1135                 return True
1136
1137         return False
1138
1139     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1140         """Remove trailing comma if there is one and it's safe."""
1141         if not (
1142             self.leaves
1143             and self.leaves[-1].type == token.COMMA
1144             and closing.type in CLOSING_BRACKETS
1145         ):
1146             return False
1147
1148         if closing.type == token.RBRACE:
1149             self.remove_trailing_comma()
1150             return True
1151
1152         if closing.type == token.RSQB:
1153             comma = self.leaves[-1]
1154             if comma.parent and comma.parent.type == syms.listmaker:
1155                 self.remove_trailing_comma()
1156                 return True
1157
1158         # For parens let's check if it's safe to remove the comma.
1159         # Imports are always safe.
1160         if self.is_import:
1161             self.remove_trailing_comma()
1162             return True
1163
1164         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1165         # change a tuple into a different type by removing the comma.
1166         depth = closing.bracket_depth + 1
1167         commas = 0
1168         opening = closing.opening_bracket
1169         for _opening_index, leaf in enumerate(self.leaves):
1170             if leaf is opening:
1171                 break
1172
1173         else:
1174             return False
1175
1176         for leaf in self.leaves[_opening_index + 1 :]:
1177             if leaf is closing:
1178                 break
1179
1180             bracket_depth = leaf.bracket_depth
1181             if bracket_depth == depth and leaf.type == token.COMMA:
1182                 commas += 1
1183                 if leaf.parent and leaf.parent.type == syms.arglist:
1184                     commas += 1
1185                     break
1186
1187         if commas > 1:
1188             self.remove_trailing_comma()
1189             return True
1190
1191         return False
1192
1193     def append_comment(self, comment: Leaf) -> bool:
1194         """Add an inline or standalone comment to the line."""
1195         if (
1196             comment.type == STANDALONE_COMMENT
1197             and self.bracket_tracker.any_open_brackets()
1198         ):
1199             comment.prefix = ""
1200             return False
1201
1202         if comment.type != token.COMMENT:
1203             return False
1204
1205         after = len(self.leaves) - 1
1206         if after == -1:
1207             comment.type = STANDALONE_COMMENT
1208             comment.prefix = ""
1209             return False
1210
1211         else:
1212             self.comments.append((after, comment))
1213             return True
1214
1215     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1216         """Generate comments that should appear directly after `leaf`.
1217
1218         Provide a non-negative leaf `_index` to speed up the function.
1219         """
1220         if not self.comments:
1221             return
1222
1223         if _index == -1:
1224             for _index, _leaf in enumerate(self.leaves):
1225                 if leaf is _leaf:
1226                     break
1227
1228             else:
1229                 return
1230
1231         for index, comment_after in self.comments:
1232             if _index == index:
1233                 yield comment_after
1234
1235     def remove_trailing_comma(self) -> None:
1236         """Remove the trailing comma and moves the comments attached to it."""
1237         comma_index = len(self.leaves) - 1
1238         for i in range(len(self.comments)):
1239             comment_index, comment = self.comments[i]
1240             if comment_index == comma_index:
1241                 self.comments[i] = (comma_index - 1, comment)
1242         self.leaves.pop()
1243
1244     def is_complex_subscript(self, leaf: Leaf) -> bool:
1245         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1246         open_lsqb = self.bracket_tracker.get_open_lsqb()
1247         if open_lsqb is None:
1248             return False
1249
1250         subscript_start = open_lsqb.next_sibling
1251
1252         if isinstance(subscript_start, Node):
1253             if subscript_start.type == syms.listmaker:
1254                 return False
1255
1256             if subscript_start.type == syms.subscriptlist:
1257                 subscript_start = child_towards(subscript_start, leaf)
1258         return subscript_start is not None and any(
1259             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1260         )
1261
1262     def __str__(self) -> str:
1263         """Render the line."""
1264         if not self:
1265             return "\n"
1266
1267         indent = "    " * self.depth
1268         leaves = iter(self.leaves)
1269         first = next(leaves)
1270         res = f"{first.prefix}{indent}{first.value}"
1271         for leaf in leaves:
1272             res += str(leaf)
1273         for _, comment in self.comments:
1274             res += str(comment)
1275         return res + "\n"
1276
1277     def __bool__(self) -> bool:
1278         """Return True if the line has leaves or comments."""
1279         return bool(self.leaves or self.comments)
1280
1281
1282 @dataclass
1283 class EmptyLineTracker:
1284     """Provides a stateful method that returns the number of potential extra
1285     empty lines needed before and after the currently processed line.
1286
1287     Note: this tracker works on lines that haven't been split yet.  It assumes
1288     the prefix of the first leaf consists of optional newlines.  Those newlines
1289     are consumed by `maybe_empty_lines()` and included in the computation.
1290     """
1291
1292     is_pyi: bool = False
1293     previous_line: Optional[Line] = None
1294     previous_after: int = 0
1295     previous_defs: List[int] = Factory(list)
1296
1297     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1298         """Return the number of extra empty lines before and after the `current_line`.
1299
1300         This is for separating `def`, `async def` and `class` with extra empty
1301         lines (two on module-level).
1302         """
1303         before, after = self._maybe_empty_lines(current_line)
1304         before -= self.previous_after
1305         self.previous_after = after
1306         self.previous_line = current_line
1307         return before, after
1308
1309     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1310         max_allowed = 1
1311         if current_line.depth == 0:
1312             max_allowed = 1 if self.is_pyi else 2
1313         if current_line.leaves:
1314             # Consume the first leaf's extra newlines.
1315             first_leaf = current_line.leaves[0]
1316             before = first_leaf.prefix.count("\n")
1317             before = min(before, max_allowed)
1318             first_leaf.prefix = ""
1319         else:
1320             before = 0
1321         depth = current_line.depth
1322         while self.previous_defs and self.previous_defs[-1] >= depth:
1323             self.previous_defs.pop()
1324             if self.is_pyi:
1325                 before = 0 if depth else 1
1326             else:
1327                 before = 1 if depth else 2
1328         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1329             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1330
1331         if (
1332             self.previous_line
1333             and self.previous_line.is_import
1334             and not current_line.is_import
1335             and depth == self.previous_line.depth
1336         ):
1337             return (before or 1), 0
1338
1339         if (
1340             self.previous_line
1341             and self.previous_line.is_class
1342             and current_line.is_triple_quoted_string
1343         ):
1344             return before, 1
1345
1346         return before, 0
1347
1348     def _maybe_empty_lines_for_class_or_def(
1349         self, current_line: Line, before: int
1350     ) -> Tuple[int, int]:
1351         if not current_line.is_decorator:
1352             self.previous_defs.append(current_line.depth)
1353         if self.previous_line is None:
1354             # Don't insert empty lines before the first line in the file.
1355             return 0, 0
1356
1357         if self.previous_line.is_decorator:
1358             return 0, 0
1359
1360         if self.previous_line.depth < current_line.depth and (
1361             self.previous_line.is_class or self.previous_line.is_def
1362         ):
1363             return 0, 0
1364
1365         if (
1366             self.previous_line.is_comment
1367             and self.previous_line.depth == current_line.depth
1368             and before == 0
1369         ):
1370             return 0, 0
1371
1372         if self.is_pyi:
1373             if self.previous_line.depth > current_line.depth:
1374                 newlines = 1
1375             elif current_line.is_class or self.previous_line.is_class:
1376                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1377                     # No blank line between classes with an emty body
1378                     newlines = 0
1379                 else:
1380                     newlines = 1
1381             elif current_line.is_def and not self.previous_line.is_def:
1382                 # Blank line between a block of functions and a block of non-functions
1383                 newlines = 1
1384             else:
1385                 newlines = 0
1386         else:
1387             newlines = 2
1388         if current_line.depth and newlines:
1389             newlines -= 1
1390         return newlines, 0
1391
1392
1393 @dataclass
1394 class LineGenerator(Visitor[Line]):
1395     """Generates reformatted Line objects.  Empty lines are not emitted.
1396
1397     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1398     in ways that will no longer stringify to valid Python code on the tree.
1399     """
1400
1401     is_pyi: bool = False
1402     normalize_strings: bool = True
1403     current_line: Line = Factory(Line)
1404     remove_u_prefix: bool = False
1405
1406     def line(self, indent: int = 0) -> Iterator[Line]:
1407         """Generate a line.
1408
1409         If the line is empty, only emit if it makes sense.
1410         If the line is too long, split it first and then generate.
1411
1412         If any lines were generated, set up a new current_line.
1413         """
1414         if not self.current_line:
1415             self.current_line.depth += indent
1416             return  # Line is empty, don't emit. Creating a new one unnecessary.
1417
1418         complete_line = self.current_line
1419         self.current_line = Line(depth=complete_line.depth + indent)
1420         yield complete_line
1421
1422     def visit_default(self, node: LN) -> Iterator[Line]:
1423         """Default `visit_*()` implementation. Recurses to children of `node`."""
1424         if isinstance(node, Leaf):
1425             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1426             for comment in generate_comments(node):
1427                 if any_open_brackets:
1428                     # any comment within brackets is subject to splitting
1429                     self.current_line.append(comment)
1430                 elif comment.type == token.COMMENT:
1431                     # regular trailing comment
1432                     self.current_line.append(comment)
1433                     yield from self.line()
1434
1435                 else:
1436                     # regular standalone comment
1437                     yield from self.line()
1438
1439                     self.current_line.append(comment)
1440                     yield from self.line()
1441
1442             normalize_prefix(node, inside_brackets=any_open_brackets)
1443             if self.normalize_strings and node.type == token.STRING:
1444                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1445                 normalize_string_quotes(node)
1446             if node.type not in WHITESPACE:
1447                 self.current_line.append(node)
1448         yield from super().visit_default(node)
1449
1450     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1451         """Increase indentation level, maybe yield a line."""
1452         # In blib2to3 INDENT never holds comments.
1453         yield from self.line(+1)
1454         yield from self.visit_default(node)
1455
1456     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1457         """Decrease indentation level, maybe yield a line."""
1458         # The current line might still wait for trailing comments.  At DEDENT time
1459         # there won't be any (they would be prefixes on the preceding NEWLINE).
1460         # Emit the line then.
1461         yield from self.line()
1462
1463         # While DEDENT has no value, its prefix may contain standalone comments
1464         # that belong to the current indentation level.  Get 'em.
1465         yield from self.visit_default(node)
1466
1467         # Finally, emit the dedent.
1468         yield from self.line(-1)
1469
1470     def visit_stmt(
1471         self, node: Node, keywords: Set[str], parens: Set[str]
1472     ) -> Iterator[Line]:
1473         """Visit a statement.
1474
1475         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1476         `def`, `with`, `class`, `assert` and assignments.
1477
1478         The relevant Python language `keywords` for a given statement will be
1479         NAME leaves within it. This methods puts those on a separate line.
1480
1481         `parens` holds a set of string leaf values immediately after which
1482         invisible parens should be put.
1483         """
1484         normalize_invisible_parens(node, parens_after=parens)
1485         for child in node.children:
1486             if child.type == token.NAME and child.value in keywords:  # type: ignore
1487                 yield from self.line()
1488
1489             yield from self.visit(child)
1490
1491     def visit_suite(self, node: Node) -> Iterator[Line]:
1492         """Visit a suite."""
1493         if self.is_pyi and is_stub_suite(node):
1494             yield from self.visit(node.children[2])
1495         else:
1496             yield from self.visit_default(node)
1497
1498     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1499         """Visit a statement without nested statements."""
1500         is_suite_like = node.parent and node.parent.type in STATEMENT
1501         if is_suite_like:
1502             if self.is_pyi and is_stub_body(node):
1503                 yield from self.visit_default(node)
1504             else:
1505                 yield from self.line(+1)
1506                 yield from self.visit_default(node)
1507                 yield from self.line(-1)
1508
1509         else:
1510             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1511                 yield from self.line()
1512             yield from self.visit_default(node)
1513
1514     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1515         """Visit `async def`, `async for`, `async with`."""
1516         yield from self.line()
1517
1518         children = iter(node.children)
1519         for child in children:
1520             yield from self.visit(child)
1521
1522             if child.type == token.ASYNC:
1523                 break
1524
1525         internal_stmt = next(children)
1526         for child in internal_stmt.children:
1527             yield from self.visit(child)
1528
1529     def visit_decorators(self, node: Node) -> Iterator[Line]:
1530         """Visit decorators."""
1531         for child in node.children:
1532             yield from self.line()
1533             yield from self.visit(child)
1534
1535     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1536         """Remove a semicolon and put the other statement on a separate line."""
1537         yield from self.line()
1538
1539     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1540         """End of file. Process outstanding comments and end with a newline."""
1541         yield from self.visit_default(leaf)
1542         yield from self.line()
1543
1544     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1545         if not self.current_line.bracket_tracker.any_open_brackets():
1546             yield from self.line()
1547         yield from self.visit_default(leaf)
1548
1549     def __attrs_post_init__(self) -> None:
1550         """You are in a twisty little maze of passages."""
1551         v = self.visit_stmt
1552         Ø: Set[str] = set()
1553         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1554         self.visit_if_stmt = partial(
1555             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1556         )
1557         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1558         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1559         self.visit_try_stmt = partial(
1560             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1561         )
1562         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1563         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1564         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1565         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1566         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1567         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1568         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1569         self.visit_async_funcdef = self.visit_async_stmt
1570         self.visit_decorated = self.visit_decorators
1571
1572
1573 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1574 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1575 OPENING_BRACKETS = set(BRACKET.keys())
1576 CLOSING_BRACKETS = set(BRACKET.values())
1577 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1578 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1579
1580
1581 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1582     """Return whitespace prefix if needed for the given `leaf`.
1583
1584     `complex_subscript` signals whether the given leaf is part of a subscription
1585     which has non-trivial arguments, like arithmetic expressions or function calls.
1586     """
1587     NO = ""
1588     SPACE = " "
1589     DOUBLESPACE = "  "
1590     t = leaf.type
1591     p = leaf.parent
1592     v = leaf.value
1593     if t in ALWAYS_NO_SPACE:
1594         return NO
1595
1596     if t == token.COMMENT:
1597         return DOUBLESPACE
1598
1599     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1600     if t == token.COLON and p.type not in {
1601         syms.subscript,
1602         syms.subscriptlist,
1603         syms.sliceop,
1604     }:
1605         return NO
1606
1607     prev = leaf.prev_sibling
1608     if not prev:
1609         prevp = preceding_leaf(p)
1610         if not prevp or prevp.type in OPENING_BRACKETS:
1611             return NO
1612
1613         if t == token.COLON:
1614             if prevp.type == token.COLON:
1615                 return NO
1616
1617             elif prevp.type != token.COMMA and not complex_subscript:
1618                 return NO
1619
1620             return SPACE
1621
1622         if prevp.type == token.EQUAL:
1623             if prevp.parent:
1624                 if prevp.parent.type in {
1625                     syms.arglist,
1626                     syms.argument,
1627                     syms.parameters,
1628                     syms.varargslist,
1629                 }:
1630                     return NO
1631
1632                 elif prevp.parent.type == syms.typedargslist:
1633                     # A bit hacky: if the equal sign has whitespace, it means we
1634                     # previously found it's a typed argument.  So, we're using
1635                     # that, too.
1636                     return prevp.prefix
1637
1638         elif prevp.type in STARS:
1639             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1640                 return NO
1641
1642         elif prevp.type == token.COLON:
1643             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1644                 return SPACE if complex_subscript else NO
1645
1646         elif (
1647             prevp.parent
1648             and prevp.parent.type == syms.factor
1649             and prevp.type in MATH_OPERATORS
1650         ):
1651             return NO
1652
1653         elif (
1654             prevp.type == token.RIGHTSHIFT
1655             and prevp.parent
1656             and prevp.parent.type == syms.shift_expr
1657             and prevp.prev_sibling
1658             and prevp.prev_sibling.type == token.NAME
1659             and prevp.prev_sibling.value == "print"  # type: ignore
1660         ):
1661             # Python 2 print chevron
1662             return NO
1663
1664     elif prev.type in OPENING_BRACKETS:
1665         return NO
1666
1667     if p.type in {syms.parameters, syms.arglist}:
1668         # untyped function signatures or calls
1669         if not prev or prev.type != token.COMMA:
1670             return NO
1671
1672     elif p.type == syms.varargslist:
1673         # lambdas
1674         if prev and prev.type != token.COMMA:
1675             return NO
1676
1677     elif p.type == syms.typedargslist:
1678         # typed function signatures
1679         if not prev:
1680             return NO
1681
1682         if t == token.EQUAL:
1683             if prev.type != syms.tname:
1684                 return NO
1685
1686         elif prev.type == token.EQUAL:
1687             # A bit hacky: if the equal sign has whitespace, it means we
1688             # previously found it's a typed argument.  So, we're using that, too.
1689             return prev.prefix
1690
1691         elif prev.type != token.COMMA:
1692             return NO
1693
1694     elif p.type == syms.tname:
1695         # type names
1696         if not prev:
1697             prevp = preceding_leaf(p)
1698             if not prevp or prevp.type != token.COMMA:
1699                 return NO
1700
1701     elif p.type == syms.trailer:
1702         # attributes and calls
1703         if t == token.LPAR or t == token.RPAR:
1704             return NO
1705
1706         if not prev:
1707             if t == token.DOT:
1708                 prevp = preceding_leaf(p)
1709                 if not prevp or prevp.type != token.NUMBER:
1710                     return NO
1711
1712             elif t == token.LSQB:
1713                 return NO
1714
1715         elif prev.type != token.COMMA:
1716             return NO
1717
1718     elif p.type == syms.argument:
1719         # single argument
1720         if t == token.EQUAL:
1721             return NO
1722
1723         if not prev:
1724             prevp = preceding_leaf(p)
1725             if not prevp or prevp.type == token.LPAR:
1726                 return NO
1727
1728         elif prev.type in {token.EQUAL} | STARS:
1729             return NO
1730
1731     elif p.type == syms.decorator:
1732         # decorators
1733         return NO
1734
1735     elif p.type == syms.dotted_name:
1736         if prev:
1737             return NO
1738
1739         prevp = preceding_leaf(p)
1740         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1741             return NO
1742
1743     elif p.type == syms.classdef:
1744         if t == token.LPAR:
1745             return NO
1746
1747         if prev and prev.type == token.LPAR:
1748             return NO
1749
1750     elif p.type in {syms.subscript, syms.sliceop}:
1751         # indexing
1752         if not prev:
1753             assert p.parent is not None, "subscripts are always parented"
1754             if p.parent.type == syms.subscriptlist:
1755                 return SPACE
1756
1757             return NO
1758
1759         elif not complex_subscript:
1760             return NO
1761
1762     elif p.type == syms.atom:
1763         if prev and t == token.DOT:
1764             # dots, but not the first one.
1765             return NO
1766
1767     elif p.type == syms.dictsetmaker:
1768         # dict unpacking
1769         if prev and prev.type == token.DOUBLESTAR:
1770             return NO
1771
1772     elif p.type in {syms.factor, syms.star_expr}:
1773         # unary ops
1774         if not prev:
1775             prevp = preceding_leaf(p)
1776             if not prevp or prevp.type in OPENING_BRACKETS:
1777                 return NO
1778
1779             prevp_parent = prevp.parent
1780             assert prevp_parent is not None
1781             if prevp.type == token.COLON and prevp_parent.type in {
1782                 syms.subscript,
1783                 syms.sliceop,
1784             }:
1785                 return NO
1786
1787             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1788                 return NO
1789
1790         elif t in {token.NAME, token.NUMBER, token.STRING}:
1791             return NO
1792
1793     elif p.type == syms.import_from:
1794         if t == token.DOT:
1795             if prev and prev.type == token.DOT:
1796                 return NO
1797
1798         elif t == token.NAME:
1799             if v == "import":
1800                 return SPACE
1801
1802             if prev and prev.type == token.DOT:
1803                 return NO
1804
1805     elif p.type == syms.sliceop:
1806         return NO
1807
1808     return SPACE
1809
1810
1811 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1812     """Return the first leaf that precedes `node`, if any."""
1813     while node:
1814         res = node.prev_sibling
1815         if res:
1816             if isinstance(res, Leaf):
1817                 return res
1818
1819             try:
1820                 return list(res.leaves())[-1]
1821
1822             except IndexError:
1823                 return None
1824
1825         node = node.parent
1826     return None
1827
1828
1829 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1830     """Return the child of `ancestor` that contains `descendant`."""
1831     node: Optional[LN] = descendant
1832     while node and node.parent != ancestor:
1833         node = node.parent
1834     return node
1835
1836
1837 def container_of(leaf: Leaf) -> LN:
1838     """Return `leaf` or one of its ancestors that is the topmost container of it.
1839
1840     By "container" we mean a node where `leaf` is the very first child.
1841     """
1842     same_prefix = leaf.prefix
1843     container: LN = leaf
1844     while container:
1845         parent = container.parent
1846         if parent is None:
1847             break
1848
1849         if parent.children[0].prefix != same_prefix:
1850             break
1851
1852         if parent.type == syms.file_input:
1853             break
1854
1855         if parent.type in SURROUNDED_BY_BRACKETS:
1856             break
1857
1858         container = parent
1859     return container
1860
1861
1862 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1863     """Return the priority of the `leaf` delimiter, given a line break after it.
1864
1865     The delimiter priorities returned here are from those delimiters that would
1866     cause a line break after themselves.
1867
1868     Higher numbers are higher priority.
1869     """
1870     if leaf.type == token.COMMA:
1871         return COMMA_PRIORITY
1872
1873     return 0
1874
1875
1876 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1877     """Return the priority of the `leaf` delimiter, given a line before after it.
1878
1879     The delimiter priorities returned here are from those delimiters that would
1880     cause a line break before themselves.
1881
1882     Higher numbers are higher priority.
1883     """
1884     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1885         # * and ** might also be MATH_OPERATORS but in this case they are not.
1886         # Don't treat them as a delimiter.
1887         return 0
1888
1889     if (
1890         leaf.type == token.DOT
1891         and leaf.parent
1892         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1893         and (previous is None or previous.type in CLOSING_BRACKETS)
1894     ):
1895         return DOT_PRIORITY
1896
1897     if (
1898         leaf.type in MATH_OPERATORS
1899         and leaf.parent
1900         and leaf.parent.type not in {syms.factor, syms.star_expr}
1901     ):
1902         return MATH_PRIORITIES[leaf.type]
1903
1904     if leaf.type in COMPARATORS:
1905         return COMPARATOR_PRIORITY
1906
1907     if (
1908         leaf.type == token.STRING
1909         and previous is not None
1910         and previous.type == token.STRING
1911     ):
1912         return STRING_PRIORITY
1913
1914     if leaf.type != token.NAME:
1915         return 0
1916
1917     if (
1918         leaf.value == "for"
1919         and leaf.parent
1920         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1921     ):
1922         return COMPREHENSION_PRIORITY
1923
1924     if (
1925         leaf.value == "if"
1926         and leaf.parent
1927         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1928     ):
1929         return COMPREHENSION_PRIORITY
1930
1931     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1932         return TERNARY_PRIORITY
1933
1934     if leaf.value == "is":
1935         return COMPARATOR_PRIORITY
1936
1937     if (
1938         leaf.value == "in"
1939         and leaf.parent
1940         and leaf.parent.type in {syms.comp_op, syms.comparison}
1941         and not (
1942             previous is not None
1943             and previous.type == token.NAME
1944             and previous.value == "not"
1945         )
1946     ):
1947         return COMPARATOR_PRIORITY
1948
1949     if (
1950         leaf.value == "not"
1951         and leaf.parent
1952         and leaf.parent.type == syms.comp_op
1953         and not (
1954             previous is not None
1955             and previous.type == token.NAME
1956             and previous.value == "is"
1957         )
1958     ):
1959         return COMPARATOR_PRIORITY
1960
1961     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1962         return LOGIC_PRIORITY
1963
1964     return 0
1965
1966
1967 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
1968 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
1969
1970
1971 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1972     """Clean the prefix of the `leaf` and generate comments from it, if any.
1973
1974     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1975     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1976     move because it does away with modifying the grammar to include all the
1977     possible places in which comments can be placed.
1978
1979     The sad consequence for us though is that comments don't "belong" anywhere.
1980     This is why this function generates simple parentless Leaf objects for
1981     comments.  We simply don't know what the correct parent should be.
1982
1983     No matter though, we can live without this.  We really only need to
1984     differentiate between inline and standalone comments.  The latter don't
1985     share the line with any code.
1986
1987     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1988     are emitted with a fake STANDALONE_COMMENT token identifier.
1989     """
1990     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
1991         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
1992
1993
1994 @dataclass
1995 class ProtoComment:
1996     type: int  # token.COMMENT or STANDALONE_COMMENT
1997     value: str  # content of the comment
1998     newlines: int  # how many newlines before the comment
1999     consumed: int  # how many characters of the original leaf's prefix did we consume
2000
2001
2002 @lru_cache(maxsize=4096)
2003 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2004     result: List[ProtoComment] = []
2005     if not prefix or "#" not in prefix:
2006         return result
2007
2008     consumed = 0
2009     nlines = 0
2010     for index, line in enumerate(prefix.split("\n")):
2011         consumed += len(line) + 1  # adding the length of the split '\n'
2012         line = line.lstrip()
2013         if not line:
2014             nlines += 1
2015         if not line.startswith("#"):
2016             continue
2017
2018         if index == 0 and not is_endmarker:
2019             comment_type = token.COMMENT  # simple trailing comment
2020         else:
2021             comment_type = STANDALONE_COMMENT
2022         comment = make_comment(line)
2023         result.append(
2024             ProtoComment(
2025                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2026             )
2027         )
2028         nlines = 0
2029     return result
2030
2031
2032 def make_comment(content: str) -> str:
2033     """Return a consistently formatted comment from the given `content` string.
2034
2035     All comments (except for "##", "#!", "#:") should have a single space between
2036     the hash sign and the content.
2037
2038     If `content` didn't start with a hash sign, one is provided.
2039     """
2040     content = content.rstrip()
2041     if not content:
2042         return "#"
2043
2044     if content[0] == "#":
2045         content = content[1:]
2046     if content and content[0] not in " !:#":
2047         content = " " + content
2048     return "#" + content
2049
2050
2051 def split_line(
2052     line: Line, line_length: int, inner: bool = False, py36: bool = False
2053 ) -> Iterator[Line]:
2054     """Split a `line` into potentially many lines.
2055
2056     They should fit in the allotted `line_length` but might not be able to.
2057     `inner` signifies that there were a pair of brackets somewhere around the
2058     current `line`, possibly transitively. This means we can fallback to splitting
2059     by delimiters if the LHS/RHS don't yield any results.
2060
2061     If `py36` is True, splitting may generate syntax that is only compatible
2062     with Python 3.6 and later.
2063     """
2064     if line.is_comment:
2065         yield line
2066         return
2067
2068     line_str = str(line).strip("\n")
2069     if not line.should_explode and is_line_short_enough(
2070         line, line_length=line_length, line_str=line_str
2071     ):
2072         yield line
2073         return
2074
2075     split_funcs: List[SplitFunc]
2076     if line.is_def:
2077         split_funcs = [left_hand_split]
2078     else:
2079
2080         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2081             for omit in generate_trailers_to_omit(line, line_length):
2082                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2083                 if is_line_short_enough(lines[0], line_length=line_length):
2084                     yield from lines
2085                     return
2086
2087             # All splits failed, best effort split with no omits.
2088             # This mostly happens to multiline strings that are by definition
2089             # reported as not fitting a single line.
2090             yield from right_hand_split(line, py36)
2091
2092         if line.inside_brackets:
2093             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2094         else:
2095             split_funcs = [rhs]
2096     for split_func in split_funcs:
2097         # We are accumulating lines in `result` because we might want to abort
2098         # mission and return the original line in the end, or attempt a different
2099         # split altogether.
2100         result: List[Line] = []
2101         try:
2102             for l in split_func(line, py36):
2103                 if str(l).strip("\n") == line_str:
2104                     raise CannotSplit("Split function returned an unchanged result")
2105
2106                 result.extend(
2107                     split_line(l, line_length=line_length, inner=True, py36=py36)
2108                 )
2109         except CannotSplit as cs:
2110             continue
2111
2112         else:
2113             yield from result
2114             break
2115
2116     else:
2117         yield line
2118
2119
2120 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2121     """Split line into many lines, starting with the first matching bracket pair.
2122
2123     Note: this usually looks weird, only use this for function definitions.
2124     Prefer RHS otherwise.  This is why this function is not symmetrical with
2125     :func:`right_hand_split` which also handles optional parentheses.
2126     """
2127     head = Line(depth=line.depth)
2128     body = Line(depth=line.depth + 1, inside_brackets=True)
2129     tail = Line(depth=line.depth)
2130     tail_leaves: List[Leaf] = []
2131     body_leaves: List[Leaf] = []
2132     head_leaves: List[Leaf] = []
2133     current_leaves = head_leaves
2134     matching_bracket = None
2135     for leaf in line.leaves:
2136         if (
2137             current_leaves is body_leaves
2138             and leaf.type in CLOSING_BRACKETS
2139             and leaf.opening_bracket is matching_bracket
2140         ):
2141             current_leaves = tail_leaves if body_leaves else head_leaves
2142         current_leaves.append(leaf)
2143         if current_leaves is head_leaves:
2144             if leaf.type in OPENING_BRACKETS:
2145                 matching_bracket = leaf
2146                 current_leaves = body_leaves
2147     # Since body is a new indent level, remove spurious leading whitespace.
2148     if body_leaves:
2149         normalize_prefix(body_leaves[0], inside_brackets=True)
2150     # Build the new lines.
2151     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2152         for leaf in leaves:
2153             result.append(leaf, preformatted=True)
2154             for comment_after in line.comments_after(leaf):
2155                 result.append(comment_after, preformatted=True)
2156     bracket_split_succeeded_or_raise(head, body, tail)
2157     for result in (head, body, tail):
2158         if result:
2159             yield result
2160
2161
2162 def right_hand_split(
2163     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2164 ) -> Iterator[Line]:
2165     """Split line into many lines, starting with the last matching bracket pair.
2166
2167     If the split was by optional parentheses, attempt splitting without them, too.
2168     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2169     this split.
2170
2171     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2172     """
2173     head = Line(depth=line.depth)
2174     body = Line(depth=line.depth + 1, inside_brackets=True)
2175     tail = Line(depth=line.depth)
2176     tail_leaves: List[Leaf] = []
2177     body_leaves: List[Leaf] = []
2178     head_leaves: List[Leaf] = []
2179     current_leaves = tail_leaves
2180     opening_bracket = None
2181     closing_bracket = None
2182     for leaf in reversed(line.leaves):
2183         if current_leaves is body_leaves:
2184             if leaf is opening_bracket:
2185                 current_leaves = head_leaves if body_leaves else tail_leaves
2186         current_leaves.append(leaf)
2187         if current_leaves is tail_leaves:
2188             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2189                 opening_bracket = leaf.opening_bracket
2190                 closing_bracket = leaf
2191                 current_leaves = body_leaves
2192     tail_leaves.reverse()
2193     body_leaves.reverse()
2194     head_leaves.reverse()
2195     # Since body is a new indent level, remove spurious leading whitespace.
2196     if body_leaves:
2197         normalize_prefix(body_leaves[0], inside_brackets=True)
2198     if not head_leaves:
2199         # No `head` means the split failed. Either `tail` has all content or
2200         # the matching `opening_bracket` wasn't available on `line` anymore.
2201         raise CannotSplit("No brackets found")
2202
2203     # Build the new lines.
2204     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2205         for leaf in leaves:
2206             result.append(leaf, preformatted=True)
2207             for comment_after in line.comments_after(leaf):
2208                 result.append(comment_after, preformatted=True)
2209     assert opening_bracket and closing_bracket
2210     body.should_explode = should_explode(body, opening_bracket)
2211     bracket_split_succeeded_or_raise(head, body, tail)
2212     if (
2213         # the body shouldn't be exploded
2214         not body.should_explode
2215         # the opening bracket is an optional paren
2216         and opening_bracket.type == token.LPAR
2217         and not opening_bracket.value
2218         # the closing bracket is an optional paren
2219         and closing_bracket.type == token.RPAR
2220         and not closing_bracket.value
2221         # it's not an import (optional parens are the only thing we can split on
2222         # in this case; attempting a split without them is a waste of time)
2223         and not line.is_import
2224         # there are no standalone comments in the body
2225         and not body.contains_standalone_comments(0)
2226         # and we can actually remove the parens
2227         and can_omit_invisible_parens(body, line_length)
2228     ):
2229         omit = {id(closing_bracket), *omit}
2230         try:
2231             yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2232             return
2233
2234         except CannotSplit:
2235             if not (
2236                 can_be_split(body)
2237                 or is_line_short_enough(body, line_length=line_length)
2238             ):
2239                 raise CannotSplit(
2240                     "Splitting failed, body is still too long and can't be split."
2241                 )
2242
2243             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2244                 raise CannotSplit(
2245                     "The current optional pair of parentheses is bound to fail to "
2246                     "satisfy the splitting algorithm because the head or the tail "
2247                     "contains multiline strings which by definition never fit one "
2248                     "line."
2249                 )
2250
2251     ensure_visible(opening_bracket)
2252     ensure_visible(closing_bracket)
2253     for result in (head, body, tail):
2254         if result:
2255             yield result
2256
2257
2258 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2259     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2260
2261     Do nothing otherwise.
2262
2263     A left- or right-hand split is based on a pair of brackets. Content before
2264     (and including) the opening bracket is left on one line, content inside the
2265     brackets is put on a separate line, and finally content starting with and
2266     following the closing bracket is put on a separate line.
2267
2268     Those are called `head`, `body`, and `tail`, respectively. If the split
2269     produced the same line (all content in `head`) or ended up with an empty `body`
2270     and the `tail` is just the closing bracket, then it's considered failed.
2271     """
2272     tail_len = len(str(tail).strip())
2273     if not body:
2274         if tail_len == 0:
2275             raise CannotSplit("Splitting brackets produced the same line")
2276
2277         elif tail_len < 3:
2278             raise CannotSplit(
2279                 f"Splitting brackets on an empty body to save "
2280                 f"{tail_len} characters is not worth it"
2281             )
2282
2283
2284 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2285     """Normalize prefix of the first leaf in every line returned by `split_func`.
2286
2287     This is a decorator over relevant split functions.
2288     """
2289
2290     @wraps(split_func)
2291     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2292         for l in split_func(line, py36):
2293             normalize_prefix(l.leaves[0], inside_brackets=True)
2294             yield l
2295
2296     return split_wrapper
2297
2298
2299 @dont_increase_indentation
2300 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2301     """Split according to delimiters of the highest priority.
2302
2303     If `py36` is True, the split will add trailing commas also in function
2304     signatures that contain `*` and `**`.
2305     """
2306     try:
2307         last_leaf = line.leaves[-1]
2308     except IndexError:
2309         raise CannotSplit("Line empty")
2310
2311     bt = line.bracket_tracker
2312     try:
2313         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2314     except ValueError:
2315         raise CannotSplit("No delimiters found")
2316
2317     if delimiter_priority == DOT_PRIORITY:
2318         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2319             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2320
2321     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2322     lowest_depth = sys.maxsize
2323     trailing_comma_safe = True
2324
2325     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2326         """Append `leaf` to current line or to new line if appending impossible."""
2327         nonlocal current_line
2328         try:
2329             current_line.append_safe(leaf, preformatted=True)
2330         except ValueError as ve:
2331             yield current_line
2332
2333             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2334             current_line.append(leaf)
2335
2336     for index, leaf in enumerate(line.leaves):
2337         yield from append_to_line(leaf)
2338
2339         for comment_after in line.comments_after(leaf, index):
2340             yield from append_to_line(comment_after)
2341
2342         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2343         if leaf.bracket_depth == lowest_depth and is_vararg(
2344             leaf, within=VARARGS_PARENTS
2345         ):
2346             trailing_comma_safe = trailing_comma_safe and py36
2347         leaf_priority = bt.delimiters.get(id(leaf))
2348         if leaf_priority == delimiter_priority:
2349             yield current_line
2350
2351             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2352     if current_line:
2353         if (
2354             trailing_comma_safe
2355             and delimiter_priority == COMMA_PRIORITY
2356             and current_line.leaves[-1].type != token.COMMA
2357             and current_line.leaves[-1].type != STANDALONE_COMMENT
2358         ):
2359             current_line.append(Leaf(token.COMMA, ","))
2360         yield current_line
2361
2362
2363 @dont_increase_indentation
2364 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2365     """Split standalone comments from the rest of the line."""
2366     if not line.contains_standalone_comments(0):
2367         raise CannotSplit("Line does not have any standalone comments")
2368
2369     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2370
2371     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2372         """Append `leaf` to current line or to new line if appending impossible."""
2373         nonlocal current_line
2374         try:
2375             current_line.append_safe(leaf, preformatted=True)
2376         except ValueError as ve:
2377             yield current_line
2378
2379             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2380             current_line.append(leaf)
2381
2382     for index, leaf in enumerate(line.leaves):
2383         yield from append_to_line(leaf)
2384
2385         for comment_after in line.comments_after(leaf, index):
2386             yield from append_to_line(comment_after)
2387
2388     if current_line:
2389         yield current_line
2390
2391
2392 def is_import(leaf: Leaf) -> bool:
2393     """Return True if the given leaf starts an import statement."""
2394     p = leaf.parent
2395     t = leaf.type
2396     v = leaf.value
2397     return bool(
2398         t == token.NAME
2399         and (
2400             (v == "import" and p and p.type == syms.import_name)
2401             or (v == "from" and p and p.type == syms.import_from)
2402         )
2403     )
2404
2405
2406 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2407     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2408     else.
2409
2410     Note: don't use backslashes for formatting or you'll lose your voting rights.
2411     """
2412     if not inside_brackets:
2413         spl = leaf.prefix.split("#")
2414         if "\\" not in spl[0]:
2415             nl_count = spl[-1].count("\n")
2416             if len(spl) > 1:
2417                 nl_count -= 1
2418             leaf.prefix = "\n" * nl_count
2419             return
2420
2421     leaf.prefix = ""
2422
2423
2424 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2425     """Make all string prefixes lowercase.
2426
2427     If remove_u_prefix is given, also removes any u prefix from the string.
2428
2429     Note: Mutates its argument.
2430     """
2431     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2432     assert match is not None, f"failed to match string {leaf.value!r}"
2433     orig_prefix = match.group(1)
2434     new_prefix = orig_prefix.lower()
2435     if remove_u_prefix:
2436         new_prefix = new_prefix.replace("u", "")
2437     leaf.value = f"{new_prefix}{match.group(2)}"
2438
2439
2440 def normalize_string_quotes(leaf: Leaf) -> None:
2441     """Prefer double quotes but only if it doesn't cause more escaping.
2442
2443     Adds or removes backslashes as appropriate. Doesn't parse and fix
2444     strings nested in f-strings (yet).
2445
2446     Note: Mutates its argument.
2447     """
2448     value = leaf.value.lstrip("furbFURB")
2449     if value[:3] == '"""':
2450         return
2451
2452     elif value[:3] == "'''":
2453         orig_quote = "'''"
2454         new_quote = '"""'
2455     elif value[0] == '"':
2456         orig_quote = '"'
2457         new_quote = "'"
2458     else:
2459         orig_quote = "'"
2460         new_quote = '"'
2461     first_quote_pos = leaf.value.find(orig_quote)
2462     if first_quote_pos == -1:
2463         return  # There's an internal error
2464
2465     prefix = leaf.value[:first_quote_pos]
2466     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2467     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2468     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2469     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2470     if "r" in prefix.casefold():
2471         if unescaped_new_quote.search(body):
2472             # There's at least one unescaped new_quote in this raw string
2473             # so converting is impossible
2474             return
2475
2476         # Do not introduce or remove backslashes in raw strings
2477         new_body = body
2478     else:
2479         # remove unnecessary escapes
2480         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2481         if body != new_body:
2482             # Consider the string without unnecessary escapes as the original
2483             body = new_body
2484             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2485         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2486         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2487     if "f" in prefix.casefold():
2488         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2489         for m in matches:
2490             if "\\" in str(m):
2491                 # Do not introduce backslashes in interpolated expressions
2492                 return
2493     if new_quote == '"""' and new_body[-1:] == '"':
2494         # edge case:
2495         new_body = new_body[:-1] + '\\"'
2496     orig_escape_count = body.count("\\")
2497     new_escape_count = new_body.count("\\")
2498     if new_escape_count > orig_escape_count:
2499         return  # Do not introduce more escaping
2500
2501     if new_escape_count == orig_escape_count and orig_quote == '"':
2502         return  # Prefer double quotes
2503
2504     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2505
2506
2507 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2508     """Make existing optional parentheses invisible or create new ones.
2509
2510     `parens_after` is a set of string leaf values immeditely after which parens
2511     should be put.
2512
2513     Standardizes on visible parentheses for single-element tuples, and keeps
2514     existing visible parentheses for other tuples and generator expressions.
2515     """
2516     for pc in list_comments(node.prefix, is_endmarker=False):
2517         if pc.value in FMT_OFF:
2518             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2519             return
2520
2521     check_lpar = False
2522     for index, child in enumerate(list(node.children)):
2523         if check_lpar:
2524             if child.type == syms.atom:
2525                 maybe_make_parens_invisible_in_atom(child)
2526             elif is_one_tuple(child):
2527                 # wrap child in visible parentheses
2528                 lpar = Leaf(token.LPAR, "(")
2529                 rpar = Leaf(token.RPAR, ")")
2530                 child.remove()
2531                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2532             elif node.type == syms.import_from:
2533                 # "import from" nodes store parentheses directly as part of
2534                 # the statement
2535                 if child.type == token.LPAR:
2536                     # make parentheses invisible
2537                     child.value = ""  # type: ignore
2538                     node.children[-1].value = ""  # type: ignore
2539                 elif child.type != token.STAR:
2540                     # insert invisible parentheses
2541                     node.insert_child(index, Leaf(token.LPAR, ""))
2542                     node.append_child(Leaf(token.RPAR, ""))
2543                 break
2544
2545             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2546                 # wrap child in invisible parentheses
2547                 lpar = Leaf(token.LPAR, "")
2548                 rpar = Leaf(token.RPAR, "")
2549                 index = child.remove() or 0
2550                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2551
2552         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2553
2554
2555 def normalize_fmt_off(node: Node) -> None:
2556     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2557     try_again = True
2558     while try_again:
2559         try_again = convert_one_fmt_off_pair(node)
2560
2561
2562 def convert_one_fmt_off_pair(node: Node) -> bool:
2563     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2564
2565     Returns True if a pair was converted.
2566     """
2567     for leaf in node.leaves():
2568         previous_consumed = 0
2569         for comment in list_comments(leaf.prefix, is_endmarker=False):
2570             if comment.value in FMT_OFF:
2571                 # We only want standalone comments. If there's no previous leaf or
2572                 # the previous leaf is indentation, it's a standalone comment in
2573                 # disguise.
2574                 if comment.type != STANDALONE_COMMENT:
2575                     prev = preceding_leaf(leaf)
2576                     if prev and prev.type not in WHITESPACE:
2577                         continue
2578
2579                 ignored_nodes = list(generate_ignored_nodes(leaf))
2580                 if not ignored_nodes:
2581                     continue
2582
2583                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2584                 parent = first.parent
2585                 prefix = first.prefix
2586                 first.prefix = prefix[comment.consumed :]
2587                 hidden_value = (
2588                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2589                 )
2590                 if hidden_value.endswith("\n"):
2591                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2592                     # leaf (possibly followed by a DEDENT).
2593                     hidden_value = hidden_value[:-1]
2594                 first_idx = None
2595                 for ignored in ignored_nodes:
2596                     index = ignored.remove()
2597                     if first_idx is None:
2598                         first_idx = index
2599                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2600                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2601                 parent.insert_child(
2602                     first_idx,
2603                     Leaf(
2604                         STANDALONE_COMMENT,
2605                         hidden_value,
2606                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2607                     ),
2608                 )
2609                 return True
2610
2611             previous_consumed = comment.consumed
2612
2613     return False
2614
2615
2616 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2617     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2618
2619     Stops at the end of the block.
2620     """
2621     container: Optional[LN] = container_of(leaf)
2622     while container is not None and container.type != token.ENDMARKER:
2623         for comment in list_comments(container.prefix, is_endmarker=False):
2624             if comment.value in FMT_ON:
2625                 return
2626
2627         yield container
2628
2629         container = container.next_sibling
2630
2631
2632 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2633     """If it's safe, make the parens in the atom `node` invisible, recursively."""
2634     if (
2635         node.type != syms.atom
2636         or is_empty_tuple(node)
2637         or is_one_tuple(node)
2638         or is_yield(node)
2639         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2640     ):
2641         return False
2642
2643     first = node.children[0]
2644     last = node.children[-1]
2645     if first.type == token.LPAR and last.type == token.RPAR:
2646         # make parentheses invisible
2647         first.value = ""  # type: ignore
2648         last.value = ""  # type: ignore
2649         if len(node.children) > 1:
2650             maybe_make_parens_invisible_in_atom(node.children[1])
2651         return True
2652
2653     return False
2654
2655
2656 def is_empty_tuple(node: LN) -> bool:
2657     """Return True if `node` holds an empty tuple."""
2658     return (
2659         node.type == syms.atom
2660         and len(node.children) == 2
2661         and node.children[0].type == token.LPAR
2662         and node.children[1].type == token.RPAR
2663     )
2664
2665
2666 def is_one_tuple(node: LN) -> bool:
2667     """Return True if `node` holds a tuple with one element, with or without parens."""
2668     if node.type == syms.atom:
2669         if len(node.children) != 3:
2670             return False
2671
2672         lpar, gexp, rpar = node.children
2673         if not (
2674             lpar.type == token.LPAR
2675             and gexp.type == syms.testlist_gexp
2676             and rpar.type == token.RPAR
2677         ):
2678             return False
2679
2680         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2681
2682     return (
2683         node.type in IMPLICIT_TUPLE
2684         and len(node.children) == 2
2685         and node.children[1].type == token.COMMA
2686     )
2687
2688
2689 def is_yield(node: LN) -> bool:
2690     """Return True if `node` holds a `yield` or `yield from` expression."""
2691     if node.type == syms.yield_expr:
2692         return True
2693
2694     if node.type == token.NAME and node.value == "yield":  # type: ignore
2695         return True
2696
2697     if node.type != syms.atom:
2698         return False
2699
2700     if len(node.children) != 3:
2701         return False
2702
2703     lpar, expr, rpar = node.children
2704     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2705         return is_yield(expr)
2706
2707     return False
2708
2709
2710 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2711     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2712
2713     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2714     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2715     extended iterable unpacking (PEP 3132) and additional unpacking
2716     generalizations (PEP 448).
2717     """
2718     if leaf.type not in STARS or not leaf.parent:
2719         return False
2720
2721     p = leaf.parent
2722     if p.type == syms.star_expr:
2723         # Star expressions are also used as assignment targets in extended
2724         # iterable unpacking (PEP 3132).  See what its parent is instead.
2725         if not p.parent:
2726             return False
2727
2728         p = p.parent
2729
2730     return p.type in within
2731
2732
2733 def is_multiline_string(leaf: Leaf) -> bool:
2734     """Return True if `leaf` is a multiline string that actually spans many lines."""
2735     value = leaf.value.lstrip("furbFURB")
2736     return value[:3] in {'"""', "'''"} and "\n" in value
2737
2738
2739 def is_stub_suite(node: Node) -> bool:
2740     """Return True if `node` is a suite with a stub body."""
2741     if (
2742         len(node.children) != 4
2743         or node.children[0].type != token.NEWLINE
2744         or node.children[1].type != token.INDENT
2745         or node.children[3].type != token.DEDENT
2746     ):
2747         return False
2748
2749     return is_stub_body(node.children[2])
2750
2751
2752 def is_stub_body(node: LN) -> bool:
2753     """Return True if `node` is a simple statement containing an ellipsis."""
2754     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2755         return False
2756
2757     if len(node.children) != 2:
2758         return False
2759
2760     child = node.children[0]
2761     return (
2762         child.type == syms.atom
2763         and len(child.children) == 3
2764         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2765     )
2766
2767
2768 def max_delimiter_priority_in_atom(node: LN) -> int:
2769     """Return maximum delimiter priority inside `node`.
2770
2771     This is specific to atoms with contents contained in a pair of parentheses.
2772     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2773     """
2774     if node.type != syms.atom:
2775         return 0
2776
2777     first = node.children[0]
2778     last = node.children[-1]
2779     if not (first.type == token.LPAR and last.type == token.RPAR):
2780         return 0
2781
2782     bt = BracketTracker()
2783     for c in node.children[1:-1]:
2784         if isinstance(c, Leaf):
2785             bt.mark(c)
2786         else:
2787             for leaf in c.leaves():
2788                 bt.mark(leaf)
2789     try:
2790         return bt.max_delimiter_priority()
2791
2792     except ValueError:
2793         return 0
2794
2795
2796 def ensure_visible(leaf: Leaf) -> None:
2797     """Make sure parentheses are visible.
2798
2799     They could be invisible as part of some statements (see
2800     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2801     """
2802     if leaf.type == token.LPAR:
2803         leaf.value = "("
2804     elif leaf.type == token.RPAR:
2805         leaf.value = ")"
2806
2807
2808 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2809     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2810     if not (
2811         opening_bracket.parent
2812         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2813         and opening_bracket.value in "[{("
2814     ):
2815         return False
2816
2817     try:
2818         last_leaf = line.leaves[-1]
2819         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2820         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2821     except (IndexError, ValueError):
2822         return False
2823
2824     return max_priority == COMMA_PRIORITY
2825
2826
2827 def is_python36(node: Node) -> bool:
2828     """Return True if the current file is using Python 3.6+ features.
2829
2830     Currently looking for:
2831     - f-strings; and
2832     - trailing commas after * or ** in function signatures and calls.
2833     """
2834     for n in node.pre_order():
2835         if n.type == token.STRING:
2836             value_head = n.value[:2]  # type: ignore
2837             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2838                 return True
2839
2840         elif (
2841             n.type in {syms.typedargslist, syms.arglist}
2842             and n.children
2843             and n.children[-1].type == token.COMMA
2844         ):
2845             for ch in n.children:
2846                 if ch.type in STARS:
2847                     return True
2848
2849                 if ch.type == syms.argument:
2850                     for argch in ch.children:
2851                         if argch.type in STARS:
2852                             return True
2853
2854     return False
2855
2856
2857 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2858     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2859
2860     Brackets can be omitted if the entire trailer up to and including
2861     a preceding closing bracket fits in one line.
2862
2863     Yielded sets are cumulative (contain results of previous yields, too).  First
2864     set is empty.
2865     """
2866
2867     omit: Set[LeafID] = set()
2868     yield omit
2869
2870     length = 4 * line.depth
2871     opening_bracket = None
2872     closing_bracket = None
2873     optional_brackets: Set[LeafID] = set()
2874     inner_brackets: Set[LeafID] = set()
2875     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2876         length += leaf_length
2877         if length > line_length:
2878             break
2879
2880         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2881         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2882             break
2883
2884         optional_brackets.discard(id(leaf))
2885         if opening_bracket:
2886             if leaf is opening_bracket:
2887                 opening_bracket = None
2888             elif leaf.type in CLOSING_BRACKETS:
2889                 inner_brackets.add(id(leaf))
2890         elif leaf.type in CLOSING_BRACKETS:
2891             if not leaf.value:
2892                 optional_brackets.add(id(opening_bracket))
2893                 continue
2894
2895             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2896                 # Empty brackets would fail a split so treat them as "inner"
2897                 # brackets (e.g. only add them to the `omit` set if another
2898                 # pair of brackets was good enough.
2899                 inner_brackets.add(id(leaf))
2900                 continue
2901
2902             opening_bracket = leaf.opening_bracket
2903             if closing_bracket:
2904                 omit.add(id(closing_bracket))
2905                 omit.update(inner_brackets)
2906                 inner_brackets.clear()
2907                 yield omit
2908             closing_bracket = leaf
2909
2910
2911 def get_future_imports(node: Node) -> Set[str]:
2912     """Return a set of __future__ imports in the file."""
2913     imports = set()
2914     for child in node.children:
2915         if child.type != syms.simple_stmt:
2916             break
2917         first_child = child.children[0]
2918         if isinstance(first_child, Leaf):
2919             # Continue looking if we see a docstring; otherwise stop.
2920             if (
2921                 len(child.children) == 2
2922                 and first_child.type == token.STRING
2923                 and child.children[1].type == token.NEWLINE
2924             ):
2925                 continue
2926             else:
2927                 break
2928         elif first_child.type == syms.import_from:
2929             module_name = first_child.children[1]
2930             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2931                 break
2932             for import_from_child in first_child.children[3:]:
2933                 if isinstance(import_from_child, Leaf):
2934                     if import_from_child.type == token.NAME:
2935                         imports.add(import_from_child.value)
2936                 else:
2937                     assert import_from_child.type == syms.import_as_names
2938                     for leaf in import_from_child.children:
2939                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2940                             imports.add(leaf.value)
2941         else:
2942             break
2943     return imports
2944
2945
2946 def gen_python_files_in_dir(
2947     path: Path,
2948     root: Path,
2949     include: Pattern[str],
2950     exclude: Pattern[str],
2951     report: "Report",
2952 ) -> Iterator[Path]:
2953     """Generate all files under `path` whose paths are not excluded by the
2954     `exclude` regex, but are included by the `include` regex.
2955
2956     Symbolic links pointing outside of the `root` directory are ignored.
2957
2958     `report` is where output about exclusions goes.
2959     """
2960     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
2961     for child in path.iterdir():
2962         try:
2963             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
2964         except ValueError:
2965             if child.is_symlink():
2966                 report.path_ignored(
2967                     child, f"is a symbolic link that points outside {root}"
2968                 )
2969                 continue
2970
2971             raise
2972
2973         if child.is_dir():
2974             normalized_path += "/"
2975         exclude_match = exclude.search(normalized_path)
2976         if exclude_match and exclude_match.group(0):
2977             report.path_ignored(child, f"matches the --exclude regular expression")
2978             continue
2979
2980         if child.is_dir():
2981             yield from gen_python_files_in_dir(child, root, include, exclude, report)
2982
2983         elif child.is_file():
2984             include_match = include.search(normalized_path)
2985             if include_match:
2986                 yield child
2987
2988
2989 @lru_cache()
2990 def find_project_root(srcs: Iterable[str]) -> Path:
2991     """Return a directory containing .git, .hg, or pyproject.toml.
2992
2993     That directory can be one of the directories passed in `srcs` or their
2994     common parent.
2995
2996     If no directory in the tree contains a marker that would specify it's the
2997     project root, the root of the file system is returned.
2998     """
2999     if not srcs:
3000         return Path("/").resolve()
3001
3002     common_base = min(Path(src).resolve() for src in srcs)
3003     if common_base.is_dir():
3004         # Append a fake file so `parents` below returns `common_base_dir`, too.
3005         common_base /= "fake-file"
3006     for directory in common_base.parents:
3007         if (directory / ".git").is_dir():
3008             return directory
3009
3010         if (directory / ".hg").is_dir():
3011             return directory
3012
3013         if (directory / "pyproject.toml").is_file():
3014             return directory
3015
3016     return directory
3017
3018
3019 @dataclass
3020 class Report:
3021     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3022
3023     check: bool = False
3024     quiet: bool = False
3025     verbose: bool = False
3026     change_count: int = 0
3027     same_count: int = 0
3028     failure_count: int = 0
3029
3030     def done(self, src: Path, changed: Changed) -> None:
3031         """Increment the counter for successful reformatting. Write out a message."""
3032         if changed is Changed.YES:
3033             reformatted = "would reformat" if self.check else "reformatted"
3034             if self.verbose or not self.quiet:
3035                 out(f"{reformatted} {src}")
3036             self.change_count += 1
3037         else:
3038             if self.verbose:
3039                 if changed is Changed.NO:
3040                     msg = f"{src} already well formatted, good job."
3041                 else:
3042                     msg = f"{src} wasn't modified on disk since last run."
3043                 out(msg, bold=False)
3044             self.same_count += 1
3045
3046     def failed(self, src: Path, message: str) -> None:
3047         """Increment the counter for failed reformatting. Write out a message."""
3048         err(f"error: cannot format {src}: {message}")
3049         self.failure_count += 1
3050
3051     def path_ignored(self, path: Path, message: str) -> None:
3052         if self.verbose:
3053             out(f"{path} ignored: {message}", bold=False)
3054
3055     @property
3056     def return_code(self) -> int:
3057         """Return the exit code that the app should use.
3058
3059         This considers the current state of changed files and failures:
3060         - if there were any failures, return 123;
3061         - if any files were changed and --check is being used, return 1;
3062         - otherwise return 0.
3063         """
3064         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3065         # 126 we have special returncodes reserved by the shell.
3066         if self.failure_count:
3067             return 123
3068
3069         elif self.change_count and self.check:
3070             return 1
3071
3072         return 0
3073
3074     def __str__(self) -> str:
3075         """Render a color report of the current state.
3076
3077         Use `click.unstyle` to remove colors.
3078         """
3079         if self.check:
3080             reformatted = "would be reformatted"
3081             unchanged = "would be left unchanged"
3082             failed = "would fail to reformat"
3083         else:
3084             reformatted = "reformatted"
3085             unchanged = "left unchanged"
3086             failed = "failed to reformat"
3087         report = []
3088         if self.change_count:
3089             s = "s" if self.change_count > 1 else ""
3090             report.append(
3091                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3092             )
3093         if self.same_count:
3094             s = "s" if self.same_count > 1 else ""
3095             report.append(f"{self.same_count} file{s} {unchanged}")
3096         if self.failure_count:
3097             s = "s" if self.failure_count > 1 else ""
3098             report.append(
3099                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3100             )
3101         return ", ".join(report) + "."
3102
3103
3104 def assert_equivalent(src: str, dst: str) -> None:
3105     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3106
3107     import ast
3108     import traceback
3109
3110     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3111         """Simple visitor generating strings to compare ASTs by content."""
3112         yield f"{'  ' * depth}{node.__class__.__name__}("
3113
3114         for field in sorted(node._fields):
3115             try:
3116                 value = getattr(node, field)
3117             except AttributeError:
3118                 continue
3119
3120             yield f"{'  ' * (depth+1)}{field}="
3121
3122             if isinstance(value, list):
3123                 for item in value:
3124                     if isinstance(item, ast.AST):
3125                         yield from _v(item, depth + 2)
3126
3127             elif isinstance(value, ast.AST):
3128                 yield from _v(value, depth + 2)
3129
3130             else:
3131                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3132
3133         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3134
3135     try:
3136         src_ast = ast.parse(src)
3137     except Exception as exc:
3138         major, minor = sys.version_info[:2]
3139         raise AssertionError(
3140             f"cannot use --safe with this file; failed to parse source file "
3141             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3142             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3143         )
3144
3145     try:
3146         dst_ast = ast.parse(dst)
3147     except Exception as exc:
3148         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3149         raise AssertionError(
3150             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3151             f"Please report a bug on https://github.com/ambv/black/issues.  "
3152             f"This invalid output might be helpful: {log}"
3153         ) from None
3154
3155     src_ast_str = "\n".join(_v(src_ast))
3156     dst_ast_str = "\n".join(_v(dst_ast))
3157     if src_ast_str != dst_ast_str:
3158         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3159         raise AssertionError(
3160             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3161             f"the source.  "
3162             f"Please report a bug on https://github.com/ambv/black/issues.  "
3163             f"This diff might be helpful: {log}"
3164         ) from None
3165
3166
3167 def assert_stable(
3168     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3169 ) -> None:
3170     """Raise AssertionError if `dst` reformats differently the second time."""
3171     newdst = format_str(dst, line_length=line_length, mode=mode)
3172     if dst != newdst:
3173         log = dump_to_file(
3174             diff(src, dst, "source", "first pass"),
3175             diff(dst, newdst, "first pass", "second pass"),
3176         )
3177         raise AssertionError(
3178             f"INTERNAL ERROR: Black produced different code on the second pass "
3179             f"of the formatter.  "
3180             f"Please report a bug on https://github.com/ambv/black/issues.  "
3181             f"This diff might be helpful: {log}"
3182         ) from None
3183
3184
3185 def dump_to_file(*output: str) -> str:
3186     """Dump `output` to a temporary file. Return path to the file."""
3187     import tempfile
3188
3189     with tempfile.NamedTemporaryFile(
3190         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3191     ) as f:
3192         for lines in output:
3193             f.write(lines)
3194             if lines and lines[-1] != "\n":
3195                 f.write("\n")
3196     return f.name
3197
3198
3199 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3200     """Return a unified diff string between strings `a` and `b`."""
3201     import difflib
3202
3203     a_lines = [line + "\n" for line in a.split("\n")]
3204     b_lines = [line + "\n" for line in b.split("\n")]
3205     return "".join(
3206         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3207     )
3208
3209
3210 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3211     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3212     err("Aborted!")
3213     for task in tasks:
3214         task.cancel()
3215
3216
3217 def shutdown(loop: BaseEventLoop) -> None:
3218     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3219     try:
3220         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3221         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3222         if not to_cancel:
3223             return
3224
3225         for task in to_cancel:
3226             task.cancel()
3227         loop.run_until_complete(
3228             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3229         )
3230     finally:
3231         # `concurrent.futures.Future` objects cannot be cancelled once they
3232         # are already running. There might be some when the `shutdown()` happened.
3233         # Silence their logger's spew about the event loop being closed.
3234         cf_logger = logging.getLogger("concurrent.futures")
3235         cf_logger.setLevel(logging.CRITICAL)
3236         loop.close()
3237
3238
3239 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3240     """Replace `regex` with `replacement` twice on `original`.
3241
3242     This is used by string normalization to perform replaces on
3243     overlapping matches.
3244     """
3245     return regex.sub(replacement, regex.sub(replacement, original))
3246
3247
3248 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3249     """Compile a regular expression string in `regex`.
3250
3251     If it contains newlines, use verbose mode.
3252     """
3253     if "\n" in regex:
3254         regex = "(?x)" + regex
3255     return re.compile(regex)
3256
3257
3258 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3259     """Like `reversed(enumerate(sequence))` if that were possible."""
3260     index = len(sequence) - 1
3261     for element in reversed(sequence):
3262         yield (index, element)
3263         index -= 1
3264
3265
3266 def enumerate_with_length(
3267     line: Line, reversed: bool = False
3268 ) -> Iterator[Tuple[Index, Leaf, int]]:
3269     """Return an enumeration of leaves with their length.
3270
3271     Stops prematurely on multiline strings and standalone comments.
3272     """
3273     op = cast(
3274         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3275         enumerate_reversed if reversed else enumerate,
3276     )
3277     for index, leaf in op(line.leaves):
3278         length = len(leaf.prefix) + len(leaf.value)
3279         if "\n" in leaf.value:
3280             return  # Multiline strings, we can't continue.
3281
3282         comment: Optional[Leaf]
3283         for comment in line.comments_after(leaf, index):
3284             length += len(comment.value)
3285
3286         yield index, leaf, length
3287
3288
3289 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3290     """Return True if `line` is no longer than `line_length`.
3291
3292     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3293     """
3294     if not line_str:
3295         line_str = str(line).strip("\n")
3296     return (
3297         len(line_str) <= line_length
3298         and "\n" not in line_str  # multiline strings
3299         and not line.contains_standalone_comments()
3300     )
3301
3302
3303 def can_be_split(line: Line) -> bool:
3304     """Return False if the line cannot be split *for sure*.
3305
3306     This is not an exhaustive search but a cheap heuristic that we can use to
3307     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3308     in unnecessary parentheses).
3309     """
3310     leaves = line.leaves
3311     if len(leaves) < 2:
3312         return False
3313
3314     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3315         call_count = 0
3316         dot_count = 0
3317         next = leaves[-1]
3318         for leaf in leaves[-2::-1]:
3319             if leaf.type in OPENING_BRACKETS:
3320                 if next.type not in CLOSING_BRACKETS:
3321                     return False
3322
3323                 call_count += 1
3324             elif leaf.type == token.DOT:
3325                 dot_count += 1
3326             elif leaf.type == token.NAME:
3327                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3328                     return False
3329
3330             elif leaf.type not in CLOSING_BRACKETS:
3331                 return False
3332
3333             if dot_count > 1 and call_count > 1:
3334                 return False
3335
3336     return True
3337
3338
3339 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3340     """Does `line` have a shape safe to reformat without optional parens around it?
3341
3342     Returns True for only a subset of potentially nice looking formattings but
3343     the point is to not return false positives that end up producing lines that
3344     are too long.
3345     """
3346     bt = line.bracket_tracker
3347     if not bt.delimiters:
3348         # Without delimiters the optional parentheses are useless.
3349         return True
3350
3351     max_priority = bt.max_delimiter_priority()
3352     if bt.delimiter_count_with_priority(max_priority) > 1:
3353         # With more than one delimiter of a kind the optional parentheses read better.
3354         return False
3355
3356     if max_priority == DOT_PRIORITY:
3357         # A single stranded method call doesn't require optional parentheses.
3358         return True
3359
3360     assert len(line.leaves) >= 2, "Stranded delimiter"
3361
3362     first = line.leaves[0]
3363     second = line.leaves[1]
3364     penultimate = line.leaves[-2]
3365     last = line.leaves[-1]
3366
3367     # With a single delimiter, omit if the expression starts or ends with
3368     # a bracket.
3369     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3370         remainder = False
3371         length = 4 * line.depth
3372         for _index, leaf, leaf_length in enumerate_with_length(line):
3373             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3374                 remainder = True
3375             if remainder:
3376                 length += leaf_length
3377                 if length > line_length:
3378                     break
3379
3380                 if leaf.type in OPENING_BRACKETS:
3381                     # There are brackets we can further split on.
3382                     remainder = False
3383
3384         else:
3385             # checked the entire string and line length wasn't exceeded
3386             if len(line.leaves) == _index + 1:
3387                 return True
3388
3389         # Note: we are not returning False here because a line might have *both*
3390         # a leading opening bracket and a trailing closing bracket.  If the
3391         # opening bracket doesn't match our rule, maybe the closing will.
3392
3393     if (
3394         last.type == token.RPAR
3395         or last.type == token.RBRACE
3396         or (
3397             # don't use indexing for omitting optional parentheses;
3398             # it looks weird
3399             last.type == token.RSQB
3400             and last.parent
3401             and last.parent.type != syms.trailer
3402         )
3403     ):
3404         if penultimate.type in OPENING_BRACKETS:
3405             # Empty brackets don't help.
3406             return False
3407
3408         if is_multiline_string(first):
3409             # Additional wrapping of a multiline string in this situation is
3410             # unnecessary.
3411             return True
3412
3413         length = 4 * line.depth
3414         seen_other_brackets = False
3415         for _index, leaf, leaf_length in enumerate_with_length(line):
3416             length += leaf_length
3417             if leaf is last.opening_bracket:
3418                 if seen_other_brackets or length <= line_length:
3419                     return True
3420
3421             elif leaf.type in OPENING_BRACKETS:
3422                 # There are brackets we can further split on.
3423                 seen_other_brackets = True
3424
3425     return False
3426
3427
3428 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3429     return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
3430
3431
3432 def read_cache(line_length: int, mode: FileMode) -> Cache:
3433     """Read the cache if it exists and is well formed.
3434
3435     If it is not well formed, the call to write_cache later should resolve the issue.
3436     """
3437     cache_file = get_cache_file(line_length, mode)
3438     if not cache_file.exists():
3439         return {}
3440
3441     with cache_file.open("rb") as fobj:
3442         try:
3443             cache: Cache = pickle.load(fobj)
3444         except pickle.UnpicklingError:
3445             return {}
3446
3447     return cache
3448
3449
3450 def get_cache_info(path: Path) -> CacheInfo:
3451     """Return the information used to check if a file is already formatted or not."""
3452     stat = path.stat()
3453     return stat.st_mtime, stat.st_size
3454
3455
3456 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3457     """Split an iterable of paths in `sources` into two sets.
3458
3459     The first contains paths of files that modified on disk or are not in the
3460     cache. The other contains paths to non-modified files.
3461     """
3462     todo, done = set(), set()
3463     for src in sources:
3464         src = src.resolve()
3465         if cache.get(src) != get_cache_info(src):
3466             todo.add(src)
3467         else:
3468             done.add(src)
3469     return todo, done
3470
3471
3472 def write_cache(
3473     cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
3474 ) -> None:
3475     """Update the cache file."""
3476     cache_file = get_cache_file(line_length, mode)
3477     try:
3478         if not CACHE_DIR.exists():
3479             CACHE_DIR.mkdir(parents=True)
3480         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3481         with cache_file.open("wb") as fobj:
3482             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3483     except OSError:
3484         pass
3485
3486
3487 def patch_click() -> None:
3488     """Make Click not crash.
3489
3490     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3491     default which restricts paths that it can access during the lifetime of the
3492     application.  Click refuses to work in this scenario by raising a RuntimeError.
3493
3494     In case of Black the likelihood that non-ASCII characters are going to be used in
3495     file paths is minimal since it's Python source code.  Moreover, this crash was
3496     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3497     """
3498     try:
3499         from click import core
3500         from click import _unicodefun  # type: ignore
3501     except ModuleNotFoundError:
3502         return
3503
3504     for module in (core, _unicodefun):
3505         if hasattr(module, "_verify_python3_env"):
3506             module._verify_python3_env = lambda: None
3507
3508
3509 if __name__ == "__main__":
3510     patch_click()
3511     main()