]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Use new github token for appveyor release
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43
44 # lib2to3 fork
45 from blib2to3.pytree import Node, Leaf, type_repr
46 from blib2to3 import pygram, pytree
47 from blib2to3.pgen2 import driver, token
48 from blib2to3.pgen2.grammar import Grammar
49 from blib2to3.pgen2.parse import ParseError
50
51
52 __version__ = "19.3b0"
53 DEFAULT_LINE_LENGTH = 88
54 DEFAULT_EXCLUDES = (
55     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
56 )
57 DEFAULT_INCLUDES = r"\.pyi?$"
58 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
59
60
61 # types
62 FileContent = str
63 Encoding = str
64 NewLine = str
65 Depth = int
66 NodeType = int
67 LeafID = int
68 Priority = int
69 Index = int
70 LN = Union[Leaf, Node]
71 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
72 Timestamp = float
73 FileSize = int
74 CacheInfo = Tuple[Timestamp, FileSize]
75 Cache = Dict[Path, CacheInfo]
76 out = partial(click.secho, bold=True, err=True)
77 err = partial(click.secho, fg="red", err=True)
78
79 pygram.initialize(CACHE_DIR)
80 syms = pygram.python_symbols
81
82
83 class NothingChanged(UserWarning):
84     """Raised when reformatted code is the same as source."""
85
86
87 class CannotSplit(Exception):
88     """A readable split that fits the allotted line length is impossible."""
89
90
91 class InvalidInput(ValueError):
92     """Raised when input source code fails all parse attempts."""
93
94
95 class WriteBack(Enum):
96     NO = 0
97     YES = 1
98     DIFF = 2
99     CHECK = 3
100
101     @classmethod
102     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
103         if check and not diff:
104             return cls.CHECK
105
106         return cls.DIFF if diff else cls.YES
107
108
109 class Changed(Enum):
110     NO = 0
111     CACHED = 1
112     YES = 2
113
114
115 class TargetVersion(Enum):
116     PY27 = 2
117     PY33 = 3
118     PY34 = 4
119     PY35 = 5
120     PY36 = 6
121     PY37 = 7
122     PY38 = 8
123
124     def is_python2(self) -> bool:
125         return self is TargetVersion.PY27
126
127
128 PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
129
130
131 class Feature(Enum):
132     # All string literals are unicode
133     UNICODE_LITERALS = 1
134     F_STRINGS = 2
135     NUMERIC_UNDERSCORES = 3
136     TRAILING_COMMA = 4
137
138
139 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
140     TargetVersion.PY27: set(),
141     TargetVersion.PY33: {Feature.UNICODE_LITERALS},
142     TargetVersion.PY34: {Feature.UNICODE_LITERALS},
143     TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
144     TargetVersion.PY36: {
145         Feature.UNICODE_LITERALS,
146         Feature.F_STRINGS,
147         Feature.NUMERIC_UNDERSCORES,
148         Feature.TRAILING_COMMA,
149     },
150     TargetVersion.PY37: {
151         Feature.UNICODE_LITERALS,
152         Feature.F_STRINGS,
153         Feature.NUMERIC_UNDERSCORES,
154         Feature.TRAILING_COMMA,
155     },
156     TargetVersion.PY38: {
157         Feature.UNICODE_LITERALS,
158         Feature.F_STRINGS,
159         Feature.NUMERIC_UNDERSCORES,
160         Feature.TRAILING_COMMA,
161     },
162 }
163
164
165 @dataclass
166 class FileMode:
167     target_versions: Set[TargetVersion] = Factory(set)
168     line_length: int = DEFAULT_LINE_LENGTH
169     string_normalization: bool = True
170     is_pyi: bool = False
171
172     def get_cache_key(self) -> str:
173         if self.target_versions:
174             version_str = ",".join(
175                 str(version.value)
176                 for version in sorted(self.target_versions, key=lambda v: v.value)
177             )
178         else:
179             version_str = "-"
180         parts = [
181             version_str,
182             str(self.line_length),
183             str(int(self.string_normalization)),
184             str(int(self.is_pyi)),
185         ]
186         return ".".join(parts)
187
188
189 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
190     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
191
192
193 def read_pyproject_toml(
194     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
195 ) -> Optional[str]:
196     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
197
198     Returns the path to a successfully found and read configuration file, None
199     otherwise.
200     """
201     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
202     if not value:
203         root = find_project_root(ctx.params.get("src", ()))
204         path = root / "pyproject.toml"
205         if path.is_file():
206             value = str(path)
207         else:
208             return None
209
210     try:
211         pyproject_toml = toml.load(value)
212         config = pyproject_toml.get("tool", {}).get("black", {})
213     except (toml.TomlDecodeError, OSError) as e:
214         raise click.FileError(
215             filename=value, hint=f"Error reading configuration file: {e}"
216         )
217
218     if not config:
219         return None
220
221     if ctx.default_map is None:
222         ctx.default_map = {}
223     ctx.default_map.update(  # type: ignore  # bad types in .pyi
224         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
225     )
226     return value
227
228
229 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
230 @click.option(
231     "-l",
232     "--line-length",
233     type=int,
234     default=DEFAULT_LINE_LENGTH,
235     help="How many characters per line to allow.",
236     show_default=True,
237 )
238 @click.option(
239     "-t",
240     "--target-version",
241     type=click.Choice([v.name.lower() for v in TargetVersion]),
242     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
243     multiple=True,
244     help=(
245         "Python versions that should be supported by Black's output. [default: "
246         "per-file auto-detection]"
247     ),
248 )
249 @click.option(
250     "--py36",
251     is_flag=True,
252     help=(
253         "Allow using Python 3.6-only syntax on all input files.  This will put "
254         "trailing commas in function signatures and calls also after *args and "
255         "**kwargs. Deprecated; use --target-version instead. "
256         "[default: per-file auto-detection]"
257     ),
258 )
259 @click.option(
260     "--pyi",
261     is_flag=True,
262     help=(
263         "Format all input files like typing stubs regardless of file extension "
264         "(useful when piping source on standard input)."
265     ),
266 )
267 @click.option(
268     "-S",
269     "--skip-string-normalization",
270     is_flag=True,
271     help="Don't normalize string quotes or prefixes.",
272 )
273 @click.option(
274     "--check",
275     is_flag=True,
276     help=(
277         "Don't write the files back, just return the status.  Return code 0 "
278         "means nothing would change.  Return code 1 means some files would be "
279         "reformatted.  Return code 123 means there was an internal error."
280     ),
281 )
282 @click.option(
283     "--diff",
284     is_flag=True,
285     help="Don't write the files back, just output a diff for each file on stdout.",
286 )
287 @click.option(
288     "--fast/--safe",
289     is_flag=True,
290     help="If --fast given, skip temporary sanity checks. [default: --safe]",
291 )
292 @click.option(
293     "--include",
294     type=str,
295     default=DEFAULT_INCLUDES,
296     help=(
297         "A regular expression that matches files and directories that should be "
298         "included on recursive searches.  An empty value means all files are "
299         "included regardless of the name.  Use forward slashes for directories on "
300         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
301         "later."
302     ),
303     show_default=True,
304 )
305 @click.option(
306     "--exclude",
307     type=str,
308     default=DEFAULT_EXCLUDES,
309     help=(
310         "A regular expression that matches files and directories that should be "
311         "excluded on recursive searches.  An empty value means no paths are excluded. "
312         "Use forward slashes for directories on all platforms (Windows, too).  "
313         "Exclusions are calculated first, inclusions later."
314     ),
315     show_default=True,
316 )
317 @click.option(
318     "-q",
319     "--quiet",
320     is_flag=True,
321     help=(
322         "Don't emit non-error messages to stderr. Errors are still emitted, "
323         "silence those with 2>/dev/null."
324     ),
325 )
326 @click.option(
327     "-v",
328     "--verbose",
329     is_flag=True,
330     help=(
331         "Also emit messages to stderr about files that were not changed or were "
332         "ignored due to --exclude=."
333     ),
334 )
335 @click.version_option(version=__version__)
336 @click.argument(
337     "src",
338     nargs=-1,
339     type=click.Path(
340         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
341     ),
342     is_eager=True,
343 )
344 @click.option(
345     "--config",
346     type=click.Path(
347         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
348     ),
349     is_eager=True,
350     callback=read_pyproject_toml,
351     help="Read configuration from PATH.",
352 )
353 @click.pass_context
354 def main(
355     ctx: click.Context,
356     line_length: int,
357     target_version: List[TargetVersion],
358     check: bool,
359     diff: bool,
360     fast: bool,
361     pyi: bool,
362     py36: bool,
363     skip_string_normalization: bool,
364     quiet: bool,
365     verbose: bool,
366     include: str,
367     exclude: str,
368     src: Tuple[str],
369     config: Optional[str],
370 ) -> None:
371     """The uncompromising code formatter."""
372     write_back = WriteBack.from_configuration(check=check, diff=diff)
373     if target_version:
374         if py36:
375             err(f"Cannot use both --target-version and --py36")
376             ctx.exit(2)
377         else:
378             versions = set(target_version)
379     elif py36:
380         err(
381             "--py36 is deprecated and will be removed in a future version. "
382             "Use --target-version py36 instead."
383         )
384         versions = PY36_VERSIONS
385     else:
386         # We'll autodetect later.
387         versions = set()
388     mode = FileMode(
389         target_versions=versions,
390         line_length=line_length,
391         is_pyi=pyi,
392         string_normalization=not skip_string_normalization,
393     )
394     if config and verbose:
395         out(f"Using configuration from {config}.", bold=False, fg="blue")
396     try:
397         include_regex = re_compile_maybe_verbose(include)
398     except re.error:
399         err(f"Invalid regular expression for include given: {include!r}")
400         ctx.exit(2)
401     try:
402         exclude_regex = re_compile_maybe_verbose(exclude)
403     except re.error:
404         err(f"Invalid regular expression for exclude given: {exclude!r}")
405         ctx.exit(2)
406     report = Report(check=check, quiet=quiet, verbose=verbose)
407     root = find_project_root(src)
408     sources: Set[Path] = set()
409     for s in src:
410         p = Path(s)
411         if p.is_dir():
412             sources.update(
413                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
414             )
415         elif p.is_file() or s == "-":
416             # if a file was explicitly given, we don't care about its extension
417             sources.add(p)
418         else:
419             err(f"invalid path: {s}")
420     if len(sources) == 0:
421         if verbose or not quiet:
422             out("No paths given. Nothing to do 😴")
423         ctx.exit(0)
424
425     if len(sources) == 1:
426         reformat_one(
427             src=sources.pop(),
428             fast=fast,
429             write_back=write_back,
430             mode=mode,
431             report=report,
432         )
433     else:
434         loop = asyncio.get_event_loop()
435         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
436         try:
437             loop.run_until_complete(
438                 schedule_formatting(
439                     sources=sources,
440                     fast=fast,
441                     write_back=write_back,
442                     mode=mode,
443                     report=report,
444                     loop=loop,
445                     executor=executor,
446                 )
447             )
448         finally:
449             shutdown(loop)
450     if verbose or not quiet:
451         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
452         out(f"All done! {bang}")
453         click.secho(str(report), err=True)
454     ctx.exit(report.return_code)
455
456
457 def reformat_one(
458     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
459 ) -> None:
460     """Reformat a single file under `src` without spawning child processes.
461
462     If `quiet` is True, non-error messages are not output. `line_length`,
463     `write_back`, `fast` and `pyi` options are passed to
464     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
465     """
466     try:
467         changed = Changed.NO
468         if not src.is_file() and str(src) == "-":
469             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
470                 changed = Changed.YES
471         else:
472             cache: Cache = {}
473             if write_back != WriteBack.DIFF:
474                 cache = read_cache(mode)
475                 res_src = src.resolve()
476                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
477                     changed = Changed.CACHED
478             if changed is not Changed.CACHED and format_file_in_place(
479                 src, fast=fast, write_back=write_back, mode=mode
480             ):
481                 changed = Changed.YES
482             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
483                 write_back is WriteBack.CHECK and changed is Changed.NO
484             ):
485                 write_cache(cache, [src], mode)
486         report.done(src, changed)
487     except Exception as exc:
488         report.failed(src, str(exc))
489
490
491 async def schedule_formatting(
492     sources: Set[Path],
493     fast: bool,
494     write_back: WriteBack,
495     mode: FileMode,
496     report: "Report",
497     loop: BaseEventLoop,
498     executor: Executor,
499 ) -> None:
500     """Run formatting of `sources` in parallel using the provided `executor`.
501
502     (Use ProcessPoolExecutors for actual parallelism.)
503
504     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
505     :func:`format_file_in_place`.
506     """
507     cache: Cache = {}
508     if write_back != WriteBack.DIFF:
509         cache = read_cache(mode)
510         sources, cached = filter_cached(cache, sources)
511         for src in sorted(cached):
512             report.done(src, Changed.CACHED)
513     if not sources:
514         return
515
516     cancelled = []
517     sources_to_cache = []
518     lock = None
519     if write_back == WriteBack.DIFF:
520         # For diff output, we need locks to ensure we don't interleave output
521         # from different processes.
522         manager = Manager()
523         lock = manager.Lock()
524     tasks = {
525         loop.run_in_executor(
526             executor, format_file_in_place, src, fast, mode, write_back, lock
527         ): src
528         for src in sorted(sources)
529     }
530     pending: Iterable[asyncio.Task] = tasks.keys()
531     try:
532         loop.add_signal_handler(signal.SIGINT, cancel, pending)
533         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
534     except NotImplementedError:
535         # There are no good alternatives for these on Windows.
536         pass
537     while pending:
538         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
539         for task in done:
540             src = tasks.pop(task)
541             if task.cancelled():
542                 cancelled.append(task)
543             elif task.exception():
544                 report.failed(src, str(task.exception()))
545             else:
546                 changed = Changed.YES if task.result() else Changed.NO
547                 # If the file was written back or was successfully checked as
548                 # well-formatted, store this information in the cache.
549                 if write_back is WriteBack.YES or (
550                     write_back is WriteBack.CHECK and changed is Changed.NO
551                 ):
552                     sources_to_cache.append(src)
553                 report.done(src, changed)
554     if cancelled:
555         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
556     if sources_to_cache:
557         write_cache(cache, sources_to_cache, mode)
558
559
560 def format_file_in_place(
561     src: Path,
562     fast: bool,
563     mode: FileMode,
564     write_back: WriteBack = WriteBack.NO,
565     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
566 ) -> bool:
567     """Format file under `src` path. Return True if changed.
568
569     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
570     code to the file.
571     `line_length` and `fast` options are passed to :func:`format_file_contents`.
572     """
573     if src.suffix == ".pyi":
574         mode = evolve(mode, is_pyi=True)
575
576     then = datetime.utcfromtimestamp(src.stat().st_mtime)
577     with open(src, "rb") as buf:
578         src_contents, encoding, newline = decode_bytes(buf.read())
579     try:
580         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
581     except NothingChanged:
582         return False
583
584     if write_back == write_back.YES:
585         with open(src, "w", encoding=encoding, newline=newline) as f:
586             f.write(dst_contents)
587     elif write_back == write_back.DIFF:
588         now = datetime.utcnow()
589         src_name = f"{src}\t{then} +0000"
590         dst_name = f"{src}\t{now} +0000"
591         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
592         if lock:
593             lock.acquire()
594         try:
595             f = io.TextIOWrapper(
596                 sys.stdout.buffer,
597                 encoding=encoding,
598                 newline=newline,
599                 write_through=True,
600             )
601             f.write(diff_contents)
602             f.detach()
603         finally:
604             if lock:
605                 lock.release()
606     return True
607
608
609 def format_stdin_to_stdout(
610     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
611 ) -> bool:
612     """Format file on stdin. Return True if changed.
613
614     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
615     write a diff to stdout. The `mode` argument is passed to
616     :func:`format_file_contents`.
617     """
618     then = datetime.utcnow()
619     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
620     dst = src
621     try:
622         dst = format_file_contents(src, fast=fast, mode=mode)
623         return True
624
625     except NothingChanged:
626         return False
627
628     finally:
629         f = io.TextIOWrapper(
630             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
631         )
632         if write_back == WriteBack.YES:
633             f.write(dst)
634         elif write_back == WriteBack.DIFF:
635             now = datetime.utcnow()
636             src_name = f"STDIN\t{then} +0000"
637             dst_name = f"STDOUT\t{now} +0000"
638             f.write(diff(src, dst, src_name, dst_name))
639         f.detach()
640
641
642 def format_file_contents(
643     src_contents: str, *, fast: bool, mode: FileMode
644 ) -> FileContent:
645     """Reformat contents a file and return new contents.
646
647     If `fast` is False, additionally confirm that the reformatted code is
648     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
649     `line_length` is passed to :func:`format_str`.
650     """
651     if src_contents.strip() == "":
652         raise NothingChanged
653
654     dst_contents = format_str(src_contents, mode=mode)
655     if src_contents == dst_contents:
656         raise NothingChanged
657
658     if not fast:
659         assert_equivalent(src_contents, dst_contents)
660         assert_stable(src_contents, dst_contents, mode=mode)
661     return dst_contents
662
663
664 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
665     """Reformat a string and return new contents.
666
667     `line_length` determines how many characters per line are allowed.
668     """
669     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
670     dst_contents = ""
671     future_imports = get_future_imports(src_node)
672     if mode.target_versions:
673         versions = mode.target_versions
674     else:
675         versions = detect_target_versions(src_node)
676     normalize_fmt_off(src_node)
677     lines = LineGenerator(
678         remove_u_prefix="unicode_literals" in future_imports
679         or supports_feature(versions, Feature.UNICODE_LITERALS),
680         is_pyi=mode.is_pyi,
681         normalize_strings=mode.string_normalization,
682     )
683     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
684     empty_line = Line()
685     after = 0
686     for current_line in lines.visit(src_node):
687         for _ in range(after):
688             dst_contents += str(empty_line)
689         before, after = elt.maybe_empty_lines(current_line)
690         for _ in range(before):
691             dst_contents += str(empty_line)
692         for line in split_line(
693             current_line,
694             line_length=mode.line_length,
695             supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
696         ):
697             dst_contents += str(line)
698     return dst_contents
699
700
701 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
702     """Return a tuple of (decoded_contents, encoding, newline).
703
704     `newline` is either CRLF or LF but `decoded_contents` is decoded with
705     universal newlines (i.e. only contains LF).
706     """
707     srcbuf = io.BytesIO(src)
708     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
709     if not lines:
710         return "", encoding, "\n"
711
712     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
713     srcbuf.seek(0)
714     with io.TextIOWrapper(srcbuf, encoding) as tiow:
715         return tiow.read(), encoding, newline
716
717
718 GRAMMARS = [
719     pygram.python_grammar_no_print_statement_no_exec_statement,
720     pygram.python_grammar_no_print_statement,
721     pygram.python_grammar,
722 ]
723
724
725 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
726     if not target_versions:
727         return GRAMMARS
728     elif all(not version.is_python2() for version in target_versions):
729         # Python 3-compatible code, so don't try Python 2 grammar
730         return [
731             pygram.python_grammar_no_print_statement_no_exec_statement,
732             pygram.python_grammar_no_print_statement,
733         ]
734     else:
735         return [pygram.python_grammar_no_print_statement, pygram.python_grammar]
736
737
738 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
739     """Given a string with source, return the lib2to3 Node."""
740     if src_txt[-1:] != "\n":
741         src_txt += "\n"
742
743     for grammar in get_grammars(set(target_versions)):
744         drv = driver.Driver(grammar, pytree.convert)
745         try:
746             result = drv.parse_string(src_txt, True)
747             break
748
749         except ParseError as pe:
750             lineno, column = pe.context[1]
751             lines = src_txt.splitlines()
752             try:
753                 faulty_line = lines[lineno - 1]
754             except IndexError:
755                 faulty_line = "<line number missing in source>"
756             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
757     else:
758         raise exc from None
759
760     if isinstance(result, Leaf):
761         result = Node(syms.file_input, [result])
762     return result
763
764
765 def lib2to3_unparse(node: Node) -> str:
766     """Given a lib2to3 node, return its string representation."""
767     code = str(node)
768     return code
769
770
771 T = TypeVar("T")
772
773
774 class Visitor(Generic[T]):
775     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
776
777     def visit(self, node: LN) -> Iterator[T]:
778         """Main method to visit `node` and its children.
779
780         It tries to find a `visit_*()` method for the given `node.type`, like
781         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
782         If no dedicated `visit_*()` method is found, chooses `visit_default()`
783         instead.
784
785         Then yields objects of type `T` from the selected visitor.
786         """
787         if node.type < 256:
788             name = token.tok_name[node.type]
789         else:
790             name = type_repr(node.type)
791         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
792
793     def visit_default(self, node: LN) -> Iterator[T]:
794         """Default `visit_*()` implementation. Recurses to children of `node`."""
795         if isinstance(node, Node):
796             for child in node.children:
797                 yield from self.visit(child)
798
799
800 @dataclass
801 class DebugVisitor(Visitor[T]):
802     tree_depth: int = 0
803
804     def visit_default(self, node: LN) -> Iterator[T]:
805         indent = " " * (2 * self.tree_depth)
806         if isinstance(node, Node):
807             _type = type_repr(node.type)
808             out(f"{indent}{_type}", fg="yellow")
809             self.tree_depth += 1
810             for child in node.children:
811                 yield from self.visit(child)
812
813             self.tree_depth -= 1
814             out(f"{indent}/{_type}", fg="yellow", bold=False)
815         else:
816             _type = token.tok_name.get(node.type, str(node.type))
817             out(f"{indent}{_type}", fg="blue", nl=False)
818             if node.prefix:
819                 # We don't have to handle prefixes for `Node` objects since
820                 # that delegates to the first child anyway.
821                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
822             out(f" {node.value!r}", fg="blue", bold=False)
823
824     @classmethod
825     def show(cls, code: Union[str, Leaf, Node]) -> None:
826         """Pretty-print the lib2to3 AST of a given string of `code`.
827
828         Convenience method for debugging.
829         """
830         v: DebugVisitor[None] = DebugVisitor()
831         if isinstance(code, str):
832             code = lib2to3_parse(code)
833         list(v.visit(code))
834
835
836 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
837 STATEMENT = {
838     syms.if_stmt,
839     syms.while_stmt,
840     syms.for_stmt,
841     syms.try_stmt,
842     syms.except_clause,
843     syms.with_stmt,
844     syms.funcdef,
845     syms.classdef,
846 }
847 STANDALONE_COMMENT = 153
848 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
849 LOGIC_OPERATORS = {"and", "or"}
850 COMPARATORS = {
851     token.LESS,
852     token.GREATER,
853     token.EQEQUAL,
854     token.NOTEQUAL,
855     token.LESSEQUAL,
856     token.GREATEREQUAL,
857 }
858 MATH_OPERATORS = {
859     token.VBAR,
860     token.CIRCUMFLEX,
861     token.AMPER,
862     token.LEFTSHIFT,
863     token.RIGHTSHIFT,
864     token.PLUS,
865     token.MINUS,
866     token.STAR,
867     token.SLASH,
868     token.DOUBLESLASH,
869     token.PERCENT,
870     token.AT,
871     token.TILDE,
872     token.DOUBLESTAR,
873 }
874 STARS = {token.STAR, token.DOUBLESTAR}
875 VARARGS_PARENTS = {
876     syms.arglist,
877     syms.argument,  # double star in arglist
878     syms.trailer,  # single argument to call
879     syms.typedargslist,
880     syms.varargslist,  # lambdas
881 }
882 UNPACKING_PARENTS = {
883     syms.atom,  # single element of a list or set literal
884     syms.dictsetmaker,
885     syms.listmaker,
886     syms.testlist_gexp,
887     syms.testlist_star_expr,
888 }
889 TEST_DESCENDANTS = {
890     syms.test,
891     syms.lambdef,
892     syms.or_test,
893     syms.and_test,
894     syms.not_test,
895     syms.comparison,
896     syms.star_expr,
897     syms.expr,
898     syms.xor_expr,
899     syms.and_expr,
900     syms.shift_expr,
901     syms.arith_expr,
902     syms.trailer,
903     syms.term,
904     syms.power,
905 }
906 ASSIGNMENTS = {
907     "=",
908     "+=",
909     "-=",
910     "*=",
911     "@=",
912     "/=",
913     "%=",
914     "&=",
915     "|=",
916     "^=",
917     "<<=",
918     ">>=",
919     "**=",
920     "//=",
921 }
922 COMPREHENSION_PRIORITY = 20
923 COMMA_PRIORITY = 18
924 TERNARY_PRIORITY = 16
925 LOGIC_PRIORITY = 14
926 STRING_PRIORITY = 12
927 COMPARATOR_PRIORITY = 10
928 MATH_PRIORITIES = {
929     token.VBAR: 9,
930     token.CIRCUMFLEX: 8,
931     token.AMPER: 7,
932     token.LEFTSHIFT: 6,
933     token.RIGHTSHIFT: 6,
934     token.PLUS: 5,
935     token.MINUS: 5,
936     token.STAR: 4,
937     token.SLASH: 4,
938     token.DOUBLESLASH: 4,
939     token.PERCENT: 4,
940     token.AT: 4,
941     token.TILDE: 3,
942     token.DOUBLESTAR: 2,
943 }
944 DOT_PRIORITY = 1
945
946
947 @dataclass
948 class BracketTracker:
949     """Keeps track of brackets on a line."""
950
951     depth: int = 0
952     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
953     delimiters: Dict[LeafID, Priority] = Factory(dict)
954     previous: Optional[Leaf] = None
955     _for_loop_depths: List[int] = Factory(list)
956     _lambda_argument_depths: List[int] = Factory(list)
957
958     def mark(self, leaf: Leaf) -> None:
959         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
960
961         All leaves receive an int `bracket_depth` field that stores how deep
962         within brackets a given leaf is. 0 means there are no enclosing brackets
963         that started on this line.
964
965         If a leaf is itself a closing bracket, it receives an `opening_bracket`
966         field that it forms a pair with. This is a one-directional link to
967         avoid reference cycles.
968
969         If a leaf is a delimiter (a token on which Black can split the line if
970         needed) and it's on depth 0, its `id()` is stored in the tracker's
971         `delimiters` field.
972         """
973         if leaf.type == token.COMMENT:
974             return
975
976         self.maybe_decrement_after_for_loop_variable(leaf)
977         self.maybe_decrement_after_lambda_arguments(leaf)
978         if leaf.type in CLOSING_BRACKETS:
979             self.depth -= 1
980             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
981             leaf.opening_bracket = opening_bracket
982         leaf.bracket_depth = self.depth
983         if self.depth == 0:
984             delim = is_split_before_delimiter(leaf, self.previous)
985             if delim and self.previous is not None:
986                 self.delimiters[id(self.previous)] = delim
987             else:
988                 delim = is_split_after_delimiter(leaf, self.previous)
989                 if delim:
990                     self.delimiters[id(leaf)] = delim
991         if leaf.type in OPENING_BRACKETS:
992             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
993             self.depth += 1
994         self.previous = leaf
995         self.maybe_increment_lambda_arguments(leaf)
996         self.maybe_increment_for_loop_variable(leaf)
997
998     def any_open_brackets(self) -> bool:
999         """Return True if there is an yet unmatched open bracket on the line."""
1000         return bool(self.bracket_match)
1001
1002     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
1003         """Return the highest priority of a delimiter found on the line.
1004
1005         Values are consistent with what `is_split_*_delimiter()` return.
1006         Raises ValueError on no delimiters.
1007         """
1008         return max(v for k, v in self.delimiters.items() if k not in exclude)
1009
1010     def delimiter_count_with_priority(self, priority: int = 0) -> int:
1011         """Return the number of delimiters with the given `priority`.
1012
1013         If no `priority` is passed, defaults to max priority on the line.
1014         """
1015         if not self.delimiters:
1016             return 0
1017
1018         priority = priority or self.max_delimiter_priority()
1019         return sum(1 for p in self.delimiters.values() if p == priority)
1020
1021     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1022         """In a for loop, or comprehension, the variables are often unpacks.
1023
1024         To avoid splitting on the comma in this situation, increase the depth of
1025         tokens between `for` and `in`.
1026         """
1027         if leaf.type == token.NAME and leaf.value == "for":
1028             self.depth += 1
1029             self._for_loop_depths.append(self.depth)
1030             return True
1031
1032         return False
1033
1034     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1035         """See `maybe_increment_for_loop_variable` above for explanation."""
1036         if (
1037             self._for_loop_depths
1038             and self._for_loop_depths[-1] == self.depth
1039             and leaf.type == token.NAME
1040             and leaf.value == "in"
1041         ):
1042             self.depth -= 1
1043             self._for_loop_depths.pop()
1044             return True
1045
1046         return False
1047
1048     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1049         """In a lambda expression, there might be more than one argument.
1050
1051         To avoid splitting on the comma in this situation, increase the depth of
1052         tokens between `lambda` and `:`.
1053         """
1054         if leaf.type == token.NAME and leaf.value == "lambda":
1055             self.depth += 1
1056             self._lambda_argument_depths.append(self.depth)
1057             return True
1058
1059         return False
1060
1061     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1062         """See `maybe_increment_lambda_arguments` above for explanation."""
1063         if (
1064             self._lambda_argument_depths
1065             and self._lambda_argument_depths[-1] == self.depth
1066             and leaf.type == token.COLON
1067         ):
1068             self.depth -= 1
1069             self._lambda_argument_depths.pop()
1070             return True
1071
1072         return False
1073
1074     def get_open_lsqb(self) -> Optional[Leaf]:
1075         """Return the most recent opening square bracket (if any)."""
1076         return self.bracket_match.get((self.depth - 1, token.RSQB))
1077
1078
1079 @dataclass
1080 class Line:
1081     """Holds leaves and comments. Can be printed with `str(line)`."""
1082
1083     depth: int = 0
1084     leaves: List[Leaf] = Factory(list)
1085     comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
1086     bracket_tracker: BracketTracker = Factory(BracketTracker)
1087     inside_brackets: bool = False
1088     should_explode: bool = False
1089
1090     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1091         """Add a new `leaf` to the end of the line.
1092
1093         Unless `preformatted` is True, the `leaf` will receive a new consistent
1094         whitespace prefix and metadata applied by :class:`BracketTracker`.
1095         Trailing commas are maybe removed, unpacked for loop variables are
1096         demoted from being delimiters.
1097
1098         Inline comments are put aside.
1099         """
1100         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1101         if not has_value:
1102             return
1103
1104         if token.COLON == leaf.type and self.is_class_paren_empty:
1105             del self.leaves[-2:]
1106         if self.leaves and not preformatted:
1107             # Note: at this point leaf.prefix should be empty except for
1108             # imports, for which we only preserve newlines.
1109             leaf.prefix += whitespace(
1110                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1111             )
1112         if self.inside_brackets or not preformatted:
1113             self.bracket_tracker.mark(leaf)
1114             self.maybe_remove_trailing_comma(leaf)
1115         if not self.append_comment(leaf):
1116             self.leaves.append(leaf)
1117
1118     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1119         """Like :func:`append()` but disallow invalid standalone comment structure.
1120
1121         Raises ValueError when any `leaf` is appended after a standalone comment
1122         or when a standalone comment is not the first leaf on the line.
1123         """
1124         if self.bracket_tracker.depth == 0:
1125             if self.is_comment:
1126                 raise ValueError("cannot append to standalone comments")
1127
1128             if self.leaves and leaf.type == STANDALONE_COMMENT:
1129                 raise ValueError(
1130                     "cannot append standalone comments to a populated line"
1131                 )
1132
1133         self.append(leaf, preformatted=preformatted)
1134
1135     @property
1136     def is_comment(self) -> bool:
1137         """Is this line a standalone comment?"""
1138         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1139
1140     @property
1141     def is_decorator(self) -> bool:
1142         """Is this line a decorator?"""
1143         return bool(self) and self.leaves[0].type == token.AT
1144
1145     @property
1146     def is_import(self) -> bool:
1147         """Is this an import line?"""
1148         return bool(self) and is_import(self.leaves[0])
1149
1150     @property
1151     def is_class(self) -> bool:
1152         """Is this line a class definition?"""
1153         return (
1154             bool(self)
1155             and self.leaves[0].type == token.NAME
1156             and self.leaves[0].value == "class"
1157         )
1158
1159     @property
1160     def is_stub_class(self) -> bool:
1161         """Is this line a class definition with a body consisting only of "..."?"""
1162         return self.is_class and self.leaves[-3:] == [
1163             Leaf(token.DOT, ".") for _ in range(3)
1164         ]
1165
1166     @property
1167     def is_def(self) -> bool:
1168         """Is this a function definition? (Also returns True for async defs.)"""
1169         try:
1170             first_leaf = self.leaves[0]
1171         except IndexError:
1172             return False
1173
1174         try:
1175             second_leaf: Optional[Leaf] = self.leaves[1]
1176         except IndexError:
1177             second_leaf = None
1178         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1179             first_leaf.type == token.ASYNC
1180             and second_leaf is not None
1181             and second_leaf.type == token.NAME
1182             and second_leaf.value == "def"
1183         )
1184
1185     @property
1186     def is_class_paren_empty(self) -> bool:
1187         """Is this a class with no base classes but using parentheses?
1188
1189         Those are unnecessary and should be removed.
1190         """
1191         return (
1192             bool(self)
1193             and len(self.leaves) == 4
1194             and self.is_class
1195             and self.leaves[2].type == token.LPAR
1196             and self.leaves[2].value == "("
1197             and self.leaves[3].type == token.RPAR
1198             and self.leaves[3].value == ")"
1199         )
1200
1201     @property
1202     def is_triple_quoted_string(self) -> bool:
1203         """Is the line a triple quoted string?"""
1204         return (
1205             bool(self)
1206             and self.leaves[0].type == token.STRING
1207             and self.leaves[0].value.startswith(('"""', "'''"))
1208         )
1209
1210     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1211         """If so, needs to be split before emitting."""
1212         for leaf in self.leaves:
1213             if leaf.type == STANDALONE_COMMENT:
1214                 if leaf.bracket_depth <= depth_limit:
1215                     return True
1216         return False
1217
1218     def contains_inner_type_comments(self) -> bool:
1219         ignored_ids = set()
1220         try:
1221             last_leaf = self.leaves[-1]
1222             ignored_ids.add(id(last_leaf))
1223             if last_leaf.type == token.COMMA:
1224                 # When trailing commas are inserted by Black for consistency, comments
1225                 # after the previous last element are not moved (they don't have to,
1226                 # rendering will still be correct).  So we ignore trailing commas.
1227                 last_leaf = self.leaves[-2]
1228                 ignored_ids.add(id(last_leaf))
1229         except IndexError:
1230             return False
1231
1232         for leaf_id, comments in self.comments.items():
1233             if leaf_id in ignored_ids:
1234                 continue
1235
1236             for comment in comments:
1237                 if is_type_comment(comment):
1238                     return True
1239
1240         return False
1241
1242     def contains_multiline_strings(self) -> bool:
1243         for leaf in self.leaves:
1244             if is_multiline_string(leaf):
1245                 return True
1246
1247         return False
1248
1249     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1250         """Remove trailing comma if there is one and it's safe."""
1251         if not (
1252             self.leaves
1253             and self.leaves[-1].type == token.COMMA
1254             and closing.type in CLOSING_BRACKETS
1255         ):
1256             return False
1257
1258         if closing.type == token.RBRACE:
1259             self.remove_trailing_comma()
1260             return True
1261
1262         if closing.type == token.RSQB:
1263             comma = self.leaves[-1]
1264             if comma.parent and comma.parent.type == syms.listmaker:
1265                 self.remove_trailing_comma()
1266                 return True
1267
1268         # For parens let's check if it's safe to remove the comma.
1269         # Imports are always safe.
1270         if self.is_import:
1271             self.remove_trailing_comma()
1272             return True
1273
1274         # Otherwise, if the trailing one is the only one, we might mistakenly
1275         # change a tuple into a different type by removing the comma.
1276         depth = closing.bracket_depth + 1
1277         commas = 0
1278         opening = closing.opening_bracket
1279         for _opening_index, leaf in enumerate(self.leaves):
1280             if leaf is opening:
1281                 break
1282
1283         else:
1284             return False
1285
1286         for leaf in self.leaves[_opening_index + 1 :]:
1287             if leaf is closing:
1288                 break
1289
1290             bracket_depth = leaf.bracket_depth
1291             if bracket_depth == depth and leaf.type == token.COMMA:
1292                 commas += 1
1293                 if leaf.parent and leaf.parent.type == syms.arglist:
1294                     commas += 1
1295                     break
1296
1297         if commas > 1:
1298             self.remove_trailing_comma()
1299             return True
1300
1301         return False
1302
1303     def append_comment(self, comment: Leaf) -> bool:
1304         """Add an inline or standalone comment to the line."""
1305         if (
1306             comment.type == STANDALONE_COMMENT
1307             and self.bracket_tracker.any_open_brackets()
1308         ):
1309             comment.prefix = ""
1310             return False
1311
1312         if comment.type != token.COMMENT:
1313             return False
1314
1315         if not self.leaves:
1316             comment.type = STANDALONE_COMMENT
1317             comment.prefix = ""
1318             return False
1319
1320         self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
1321         return True
1322
1323     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1324         """Generate comments that should appear directly after `leaf`."""
1325         return self.comments.get(id(leaf), [])
1326
1327     def remove_trailing_comma(self) -> None:
1328         """Remove the trailing comma and moves the comments attached to it."""
1329         trailing_comma = self.leaves.pop()
1330         trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
1331         self.comments.setdefault(id(self.leaves[-1]), []).extend(
1332             trailing_comma_comments
1333         )
1334
1335     def is_complex_subscript(self, leaf: Leaf) -> bool:
1336         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1337         open_lsqb = self.bracket_tracker.get_open_lsqb()
1338         if open_lsqb is None:
1339             return False
1340
1341         subscript_start = open_lsqb.next_sibling
1342
1343         if isinstance(subscript_start, Node):
1344             if subscript_start.type == syms.listmaker:
1345                 return False
1346
1347             if subscript_start.type == syms.subscriptlist:
1348                 subscript_start = child_towards(subscript_start, leaf)
1349         return subscript_start is not None and any(
1350             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1351         )
1352
1353     def __str__(self) -> str:
1354         """Render the line."""
1355         if not self:
1356             return "\n"
1357
1358         indent = "    " * self.depth
1359         leaves = iter(self.leaves)
1360         first = next(leaves)
1361         res = f"{first.prefix}{indent}{first.value}"
1362         for leaf in leaves:
1363             res += str(leaf)
1364         for comment in itertools.chain.from_iterable(self.comments.values()):
1365             res += str(comment)
1366         return res + "\n"
1367
1368     def __bool__(self) -> bool:
1369         """Return True if the line has leaves or comments."""
1370         return bool(self.leaves or self.comments)
1371
1372
1373 @dataclass
1374 class EmptyLineTracker:
1375     """Provides a stateful method that returns the number of potential extra
1376     empty lines needed before and after the currently processed line.
1377
1378     Note: this tracker works on lines that haven't been split yet.  It assumes
1379     the prefix of the first leaf consists of optional newlines.  Those newlines
1380     are consumed by `maybe_empty_lines()` and included in the computation.
1381     """
1382
1383     is_pyi: bool = False
1384     previous_line: Optional[Line] = None
1385     previous_after: int = 0
1386     previous_defs: List[int] = Factory(list)
1387
1388     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1389         """Return the number of extra empty lines before and after the `current_line`.
1390
1391         This is for separating `def`, `async def` and `class` with extra empty
1392         lines (two on module-level).
1393         """
1394         before, after = self._maybe_empty_lines(current_line)
1395         before -= self.previous_after
1396         self.previous_after = after
1397         self.previous_line = current_line
1398         return before, after
1399
1400     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1401         max_allowed = 1
1402         if current_line.depth == 0:
1403             max_allowed = 1 if self.is_pyi else 2
1404         if current_line.leaves:
1405             # Consume the first leaf's extra newlines.
1406             first_leaf = current_line.leaves[0]
1407             before = first_leaf.prefix.count("\n")
1408             before = min(before, max_allowed)
1409             first_leaf.prefix = ""
1410         else:
1411             before = 0
1412         depth = current_line.depth
1413         while self.previous_defs and self.previous_defs[-1] >= depth:
1414             self.previous_defs.pop()
1415             if self.is_pyi:
1416                 before = 0 if depth else 1
1417             else:
1418                 before = 1 if depth else 2
1419         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1420             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1421
1422         if (
1423             self.previous_line
1424             and self.previous_line.is_import
1425             and not current_line.is_import
1426             and depth == self.previous_line.depth
1427         ):
1428             return (before or 1), 0
1429
1430         if (
1431             self.previous_line
1432             and self.previous_line.is_class
1433             and current_line.is_triple_quoted_string
1434         ):
1435             return before, 1
1436
1437         return before, 0
1438
1439     def _maybe_empty_lines_for_class_or_def(
1440         self, current_line: Line, before: int
1441     ) -> Tuple[int, int]:
1442         if not current_line.is_decorator:
1443             self.previous_defs.append(current_line.depth)
1444         if self.previous_line is None:
1445             # Don't insert empty lines before the first line in the file.
1446             return 0, 0
1447
1448         if self.previous_line.is_decorator:
1449             return 0, 0
1450
1451         if self.previous_line.depth < current_line.depth and (
1452             self.previous_line.is_class or self.previous_line.is_def
1453         ):
1454             return 0, 0
1455
1456         if (
1457             self.previous_line.is_comment
1458             and self.previous_line.depth == current_line.depth
1459             and before == 0
1460         ):
1461             return 0, 0
1462
1463         if self.is_pyi:
1464             if self.previous_line.depth > current_line.depth:
1465                 newlines = 1
1466             elif current_line.is_class or self.previous_line.is_class:
1467                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1468                     # No blank line between classes with an empty body
1469                     newlines = 0
1470                 else:
1471                     newlines = 1
1472             elif current_line.is_def and not self.previous_line.is_def:
1473                 # Blank line between a block of functions and a block of non-functions
1474                 newlines = 1
1475             else:
1476                 newlines = 0
1477         else:
1478             newlines = 2
1479         if current_line.depth and newlines:
1480             newlines -= 1
1481         return newlines, 0
1482
1483
1484 @dataclass
1485 class LineGenerator(Visitor[Line]):
1486     """Generates reformatted Line objects.  Empty lines are not emitted.
1487
1488     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1489     in ways that will no longer stringify to valid Python code on the tree.
1490     """
1491
1492     is_pyi: bool = False
1493     normalize_strings: bool = True
1494     current_line: Line = Factory(Line)
1495     remove_u_prefix: bool = False
1496
1497     def line(self, indent: int = 0) -> Iterator[Line]:
1498         """Generate a line.
1499
1500         If the line is empty, only emit if it makes sense.
1501         If the line is too long, split it first and then generate.
1502
1503         If any lines were generated, set up a new current_line.
1504         """
1505         if not self.current_line:
1506             self.current_line.depth += indent
1507             return  # Line is empty, don't emit. Creating a new one unnecessary.
1508
1509         complete_line = self.current_line
1510         self.current_line = Line(depth=complete_line.depth + indent)
1511         yield complete_line
1512
1513     def visit_default(self, node: LN) -> Iterator[Line]:
1514         """Default `visit_*()` implementation. Recurses to children of `node`."""
1515         if isinstance(node, Leaf):
1516             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1517             for comment in generate_comments(node):
1518                 if any_open_brackets:
1519                     # any comment within brackets is subject to splitting
1520                     self.current_line.append(comment)
1521                 elif comment.type == token.COMMENT:
1522                     # regular trailing comment
1523                     self.current_line.append(comment)
1524                     yield from self.line()
1525
1526                 else:
1527                     # regular standalone comment
1528                     yield from self.line()
1529
1530                     self.current_line.append(comment)
1531                     yield from self.line()
1532
1533             normalize_prefix(node, inside_brackets=any_open_brackets)
1534             if self.normalize_strings and node.type == token.STRING:
1535                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1536                 normalize_string_quotes(node)
1537             if node.type == token.NUMBER:
1538                 normalize_numeric_literal(node)
1539             if node.type not in WHITESPACE:
1540                 self.current_line.append(node)
1541         yield from super().visit_default(node)
1542
1543     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1544         """Increase indentation level, maybe yield a line."""
1545         # In blib2to3 INDENT never holds comments.
1546         yield from self.line(+1)
1547         yield from self.visit_default(node)
1548
1549     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1550         """Decrease indentation level, maybe yield a line."""
1551         # The current line might still wait for trailing comments.  At DEDENT time
1552         # there won't be any (they would be prefixes on the preceding NEWLINE).
1553         # Emit the line then.
1554         yield from self.line()
1555
1556         # While DEDENT has no value, its prefix may contain standalone comments
1557         # that belong to the current indentation level.  Get 'em.
1558         yield from self.visit_default(node)
1559
1560         # Finally, emit the dedent.
1561         yield from self.line(-1)
1562
1563     def visit_stmt(
1564         self, node: Node, keywords: Set[str], parens: Set[str]
1565     ) -> Iterator[Line]:
1566         """Visit a statement.
1567
1568         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1569         `def`, `with`, `class`, `assert` and assignments.
1570
1571         The relevant Python language `keywords` for a given statement will be
1572         NAME leaves within it. This methods puts those on a separate line.
1573
1574         `parens` holds a set of string leaf values immediately after which
1575         invisible parens should be put.
1576         """
1577         normalize_invisible_parens(node, parens_after=parens)
1578         for child in node.children:
1579             if child.type == token.NAME and child.value in keywords:  # type: ignore
1580                 yield from self.line()
1581
1582             yield from self.visit(child)
1583
1584     def visit_suite(self, node: Node) -> Iterator[Line]:
1585         """Visit a suite."""
1586         if self.is_pyi and is_stub_suite(node):
1587             yield from self.visit(node.children[2])
1588         else:
1589             yield from self.visit_default(node)
1590
1591     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1592         """Visit a statement without nested statements."""
1593         is_suite_like = node.parent and node.parent.type in STATEMENT
1594         if is_suite_like:
1595             if self.is_pyi and is_stub_body(node):
1596                 yield from self.visit_default(node)
1597             else:
1598                 yield from self.line(+1)
1599                 yield from self.visit_default(node)
1600                 yield from self.line(-1)
1601
1602         else:
1603             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1604                 yield from self.line()
1605             yield from self.visit_default(node)
1606
1607     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1608         """Visit `async def`, `async for`, `async with`."""
1609         yield from self.line()
1610
1611         children = iter(node.children)
1612         for child in children:
1613             yield from self.visit(child)
1614
1615             if child.type == token.ASYNC:
1616                 break
1617
1618         internal_stmt = next(children)
1619         for child in internal_stmt.children:
1620             yield from self.visit(child)
1621
1622     def visit_decorators(self, node: Node) -> Iterator[Line]:
1623         """Visit decorators."""
1624         for child in node.children:
1625             yield from self.line()
1626             yield from self.visit(child)
1627
1628     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1629         """Remove a semicolon and put the other statement on a separate line."""
1630         yield from self.line()
1631
1632     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1633         """End of file. Process outstanding comments and end with a newline."""
1634         yield from self.visit_default(leaf)
1635         yield from self.line()
1636
1637     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1638         if not self.current_line.bracket_tracker.any_open_brackets():
1639             yield from self.line()
1640         yield from self.visit_default(leaf)
1641
1642     def __attrs_post_init__(self) -> None:
1643         """You are in a twisty little maze of passages."""
1644         v = self.visit_stmt
1645         Ø: Set[str] = set()
1646         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1647         self.visit_if_stmt = partial(
1648             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1649         )
1650         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1651         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1652         self.visit_try_stmt = partial(
1653             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1654         )
1655         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1656         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1657         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1658         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1659         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1660         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1661         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1662         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1663         self.visit_async_funcdef = self.visit_async_stmt
1664         self.visit_decorated = self.visit_decorators
1665
1666
1667 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1668 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1669 OPENING_BRACKETS = set(BRACKET.keys())
1670 CLOSING_BRACKETS = set(BRACKET.values())
1671 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1672 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1673
1674
1675 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1676     """Return whitespace prefix if needed for the given `leaf`.
1677
1678     `complex_subscript` signals whether the given leaf is part of a subscription
1679     which has non-trivial arguments, like arithmetic expressions or function calls.
1680     """
1681     NO = ""
1682     SPACE = " "
1683     DOUBLESPACE = "  "
1684     t = leaf.type
1685     p = leaf.parent
1686     v = leaf.value
1687     if t in ALWAYS_NO_SPACE:
1688         return NO
1689
1690     if t == token.COMMENT:
1691         return DOUBLESPACE
1692
1693     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1694     if t == token.COLON and p.type not in {
1695         syms.subscript,
1696         syms.subscriptlist,
1697         syms.sliceop,
1698     }:
1699         return NO
1700
1701     prev = leaf.prev_sibling
1702     if not prev:
1703         prevp = preceding_leaf(p)
1704         if not prevp or prevp.type in OPENING_BRACKETS:
1705             return NO
1706
1707         if t == token.COLON:
1708             if prevp.type == token.COLON:
1709                 return NO
1710
1711             elif prevp.type != token.COMMA and not complex_subscript:
1712                 return NO
1713
1714             return SPACE
1715
1716         if prevp.type == token.EQUAL:
1717             if prevp.parent:
1718                 if prevp.parent.type in {
1719                     syms.arglist,
1720                     syms.argument,
1721                     syms.parameters,
1722                     syms.varargslist,
1723                 }:
1724                     return NO
1725
1726                 elif prevp.parent.type == syms.typedargslist:
1727                     # A bit hacky: if the equal sign has whitespace, it means we
1728                     # previously found it's a typed argument.  So, we're using
1729                     # that, too.
1730                     return prevp.prefix
1731
1732         elif prevp.type in STARS:
1733             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1734                 return NO
1735
1736         elif prevp.type == token.COLON:
1737             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1738                 return SPACE if complex_subscript else NO
1739
1740         elif (
1741             prevp.parent
1742             and prevp.parent.type == syms.factor
1743             and prevp.type in MATH_OPERATORS
1744         ):
1745             return NO
1746
1747         elif (
1748             prevp.type == token.RIGHTSHIFT
1749             and prevp.parent
1750             and prevp.parent.type == syms.shift_expr
1751             and prevp.prev_sibling
1752             and prevp.prev_sibling.type == token.NAME
1753             and prevp.prev_sibling.value == "print"  # type: ignore
1754         ):
1755             # Python 2 print chevron
1756             return NO
1757
1758     elif prev.type in OPENING_BRACKETS:
1759         return NO
1760
1761     if p.type in {syms.parameters, syms.arglist}:
1762         # untyped function signatures or calls
1763         if not prev or prev.type != token.COMMA:
1764             return NO
1765
1766     elif p.type == syms.varargslist:
1767         # lambdas
1768         if prev and prev.type != token.COMMA:
1769             return NO
1770
1771     elif p.type == syms.typedargslist:
1772         # typed function signatures
1773         if not prev:
1774             return NO
1775
1776         if t == token.EQUAL:
1777             if prev.type != syms.tname:
1778                 return NO
1779
1780         elif prev.type == token.EQUAL:
1781             # A bit hacky: if the equal sign has whitespace, it means we
1782             # previously found it's a typed argument.  So, we're using that, too.
1783             return prev.prefix
1784
1785         elif prev.type != token.COMMA:
1786             return NO
1787
1788     elif p.type == syms.tname:
1789         # type names
1790         if not prev:
1791             prevp = preceding_leaf(p)
1792             if not prevp or prevp.type != token.COMMA:
1793                 return NO
1794
1795     elif p.type == syms.trailer:
1796         # attributes and calls
1797         if t == token.LPAR or t == token.RPAR:
1798             return NO
1799
1800         if not prev:
1801             if t == token.DOT:
1802                 prevp = preceding_leaf(p)
1803                 if not prevp or prevp.type != token.NUMBER:
1804                     return NO
1805
1806             elif t == token.LSQB:
1807                 return NO
1808
1809         elif prev.type != token.COMMA:
1810             return NO
1811
1812     elif p.type == syms.argument:
1813         # single argument
1814         if t == token.EQUAL:
1815             return NO
1816
1817         if not prev:
1818             prevp = preceding_leaf(p)
1819             if not prevp or prevp.type == token.LPAR:
1820                 return NO
1821
1822         elif prev.type in {token.EQUAL} | STARS:
1823             return NO
1824
1825     elif p.type == syms.decorator:
1826         # decorators
1827         return NO
1828
1829     elif p.type == syms.dotted_name:
1830         if prev:
1831             return NO
1832
1833         prevp = preceding_leaf(p)
1834         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1835             return NO
1836
1837     elif p.type == syms.classdef:
1838         if t == token.LPAR:
1839             return NO
1840
1841         if prev and prev.type == token.LPAR:
1842             return NO
1843
1844     elif p.type in {syms.subscript, syms.sliceop}:
1845         # indexing
1846         if not prev:
1847             assert p.parent is not None, "subscripts are always parented"
1848             if p.parent.type == syms.subscriptlist:
1849                 return SPACE
1850
1851             return NO
1852
1853         elif not complex_subscript:
1854             return NO
1855
1856     elif p.type == syms.atom:
1857         if prev and t == token.DOT:
1858             # dots, but not the first one.
1859             return NO
1860
1861     elif p.type == syms.dictsetmaker:
1862         # dict unpacking
1863         if prev and prev.type == token.DOUBLESTAR:
1864             return NO
1865
1866     elif p.type in {syms.factor, syms.star_expr}:
1867         # unary ops
1868         if not prev:
1869             prevp = preceding_leaf(p)
1870             if not prevp or prevp.type in OPENING_BRACKETS:
1871                 return NO
1872
1873             prevp_parent = prevp.parent
1874             assert prevp_parent is not None
1875             if prevp.type == token.COLON and prevp_parent.type in {
1876                 syms.subscript,
1877                 syms.sliceop,
1878             }:
1879                 return NO
1880
1881             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1882                 return NO
1883
1884         elif t in {token.NAME, token.NUMBER, token.STRING}:
1885             return NO
1886
1887     elif p.type == syms.import_from:
1888         if t == token.DOT:
1889             if prev and prev.type == token.DOT:
1890                 return NO
1891
1892         elif t == token.NAME:
1893             if v == "import":
1894                 return SPACE
1895
1896             if prev and prev.type == token.DOT:
1897                 return NO
1898
1899     elif p.type == syms.sliceop:
1900         return NO
1901
1902     return SPACE
1903
1904
1905 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1906     """Return the first leaf that precedes `node`, if any."""
1907     while node:
1908         res = node.prev_sibling
1909         if res:
1910             if isinstance(res, Leaf):
1911                 return res
1912
1913             try:
1914                 return list(res.leaves())[-1]
1915
1916             except IndexError:
1917                 return None
1918
1919         node = node.parent
1920     return None
1921
1922
1923 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1924     """Return the child of `ancestor` that contains `descendant`."""
1925     node: Optional[LN] = descendant
1926     while node and node.parent != ancestor:
1927         node = node.parent
1928     return node
1929
1930
1931 def container_of(leaf: Leaf) -> LN:
1932     """Return `leaf` or one of its ancestors that is the topmost container of it.
1933
1934     By "container" we mean a node where `leaf` is the very first child.
1935     """
1936     same_prefix = leaf.prefix
1937     container: LN = leaf
1938     while container:
1939         parent = container.parent
1940         if parent is None:
1941             break
1942
1943         if parent.children[0].prefix != same_prefix:
1944             break
1945
1946         if parent.type == syms.file_input:
1947             break
1948
1949         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1950             break
1951
1952         container = parent
1953     return container
1954
1955
1956 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1957     """Return the priority of the `leaf` delimiter, given a line break after it.
1958
1959     The delimiter priorities returned here are from those delimiters that would
1960     cause a line break after themselves.
1961
1962     Higher numbers are higher priority.
1963     """
1964     if leaf.type == token.COMMA:
1965         return COMMA_PRIORITY
1966
1967     return 0
1968
1969
1970 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1971     """Return the priority of the `leaf` delimiter, given a line break before it.
1972
1973     The delimiter priorities returned here are from those delimiters that would
1974     cause a line break before themselves.
1975
1976     Higher numbers are higher priority.
1977     """
1978     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1979         # * and ** might also be MATH_OPERATORS but in this case they are not.
1980         # Don't treat them as a delimiter.
1981         return 0
1982
1983     if (
1984         leaf.type == token.DOT
1985         and leaf.parent
1986         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1987         and (previous is None or previous.type in CLOSING_BRACKETS)
1988     ):
1989         return DOT_PRIORITY
1990
1991     if (
1992         leaf.type in MATH_OPERATORS
1993         and leaf.parent
1994         and leaf.parent.type not in {syms.factor, syms.star_expr}
1995     ):
1996         return MATH_PRIORITIES[leaf.type]
1997
1998     if leaf.type in COMPARATORS:
1999         return COMPARATOR_PRIORITY
2000
2001     if (
2002         leaf.type == token.STRING
2003         and previous is not None
2004         and previous.type == token.STRING
2005     ):
2006         return STRING_PRIORITY
2007
2008     if leaf.type not in {token.NAME, token.ASYNC}:
2009         return 0
2010
2011     if (
2012         leaf.value == "for"
2013         and leaf.parent
2014         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2015         or leaf.type == token.ASYNC
2016     ):
2017         if (
2018             not isinstance(leaf.prev_sibling, Leaf)
2019             or leaf.prev_sibling.value != "async"
2020         ):
2021             return COMPREHENSION_PRIORITY
2022
2023     if (
2024         leaf.value == "if"
2025         and leaf.parent
2026         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2027     ):
2028         return COMPREHENSION_PRIORITY
2029
2030     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2031         return TERNARY_PRIORITY
2032
2033     if leaf.value == "is":
2034         return COMPARATOR_PRIORITY
2035
2036     if (
2037         leaf.value == "in"
2038         and leaf.parent
2039         and leaf.parent.type in {syms.comp_op, syms.comparison}
2040         and not (
2041             previous is not None
2042             and previous.type == token.NAME
2043             and previous.value == "not"
2044         )
2045     ):
2046         return COMPARATOR_PRIORITY
2047
2048     if (
2049         leaf.value == "not"
2050         and leaf.parent
2051         and leaf.parent.type == syms.comp_op
2052         and not (
2053             previous is not None
2054             and previous.type == token.NAME
2055             and previous.value == "is"
2056         )
2057     ):
2058         return COMPARATOR_PRIORITY
2059
2060     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2061         return LOGIC_PRIORITY
2062
2063     return 0
2064
2065
2066 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2067 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2068
2069
2070 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2071     """Clean the prefix of the `leaf` and generate comments from it, if any.
2072
2073     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2074     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2075     move because it does away with modifying the grammar to include all the
2076     possible places in which comments can be placed.
2077
2078     The sad consequence for us though is that comments don't "belong" anywhere.
2079     This is why this function generates simple parentless Leaf objects for
2080     comments.  We simply don't know what the correct parent should be.
2081
2082     No matter though, we can live without this.  We really only need to
2083     differentiate between inline and standalone comments.  The latter don't
2084     share the line with any code.
2085
2086     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2087     are emitted with a fake STANDALONE_COMMENT token identifier.
2088     """
2089     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2090         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2091
2092
2093 @dataclass
2094 class ProtoComment:
2095     """Describes a piece of syntax that is a comment.
2096
2097     It's not a :class:`blib2to3.pytree.Leaf` so that:
2098
2099     * it can be cached (`Leaf` objects should not be reused more than once as
2100       they store their lineno, column, prefix, and parent information);
2101     * `newlines` and `consumed` fields are kept separate from the `value`. This
2102       simplifies handling of special marker comments like ``# fmt: off/on``.
2103     """
2104
2105     type: int  # token.COMMENT or STANDALONE_COMMENT
2106     value: str  # content of the comment
2107     newlines: int  # how many newlines before the comment
2108     consumed: int  # how many characters of the original leaf's prefix did we consume
2109
2110
2111 @lru_cache(maxsize=4096)
2112 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2113     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2114     result: List[ProtoComment] = []
2115     if not prefix or "#" not in prefix:
2116         return result
2117
2118     consumed = 0
2119     nlines = 0
2120     for index, line in enumerate(prefix.split("\n")):
2121         consumed += len(line) + 1  # adding the length of the split '\n'
2122         line = line.lstrip()
2123         if not line:
2124             nlines += 1
2125         if not line.startswith("#"):
2126             continue
2127
2128         if index == 0 and not is_endmarker:
2129             comment_type = token.COMMENT  # simple trailing comment
2130         else:
2131             comment_type = STANDALONE_COMMENT
2132         comment = make_comment(line)
2133         result.append(
2134             ProtoComment(
2135                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2136             )
2137         )
2138         nlines = 0
2139     return result
2140
2141
2142 def make_comment(content: str) -> str:
2143     """Return a consistently formatted comment from the given `content` string.
2144
2145     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2146     space between the hash sign and the content.
2147
2148     If `content` didn't start with a hash sign, one is provided.
2149     """
2150     content = content.rstrip()
2151     if not content:
2152         return "#"
2153
2154     if content[0] == "#":
2155         content = content[1:]
2156     if content and content[0] not in " !:#'%":
2157         content = " " + content
2158     return "#" + content
2159
2160
2161 def split_line(
2162     line: Line,
2163     line_length: int,
2164     inner: bool = False,
2165     supports_trailing_commas: bool = False,
2166 ) -> Iterator[Line]:
2167     """Split a `line` into potentially many lines.
2168
2169     They should fit in the allotted `line_length` but might not be able to.
2170     `inner` signifies that there were a pair of brackets somewhere around the
2171     current `line`, possibly transitively. This means we can fallback to splitting
2172     by delimiters if the LHS/RHS don't yield any results.
2173
2174     If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
2175     """
2176     if line.is_comment:
2177         yield line
2178         return
2179
2180     line_str = str(line).strip("\n")
2181
2182     if (
2183         not line.contains_inner_type_comments()
2184         and not line.should_explode
2185         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2186     ):
2187         yield line
2188         return
2189
2190     split_funcs: List[SplitFunc]
2191     if line.is_def:
2192         split_funcs = [left_hand_split]
2193     else:
2194
2195         def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
2196             for omit in generate_trailers_to_omit(line, line_length):
2197                 lines = list(
2198                     right_hand_split(
2199                         line, line_length, supports_trailing_commas, omit=omit
2200                     )
2201                 )
2202                 if is_line_short_enough(lines[0], line_length=line_length):
2203                     yield from lines
2204                     return
2205
2206             # All splits failed, best effort split with no omits.
2207             # This mostly happens to multiline strings that are by definition
2208             # reported as not fitting a single line.
2209             yield from right_hand_split(line, line_length, supports_trailing_commas)
2210
2211         if line.inside_brackets:
2212             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2213         else:
2214             split_funcs = [rhs]
2215     for split_func in split_funcs:
2216         # We are accumulating lines in `result` because we might want to abort
2217         # mission and return the original line in the end, or attempt a different
2218         # split altogether.
2219         result: List[Line] = []
2220         try:
2221             for l in split_func(line, supports_trailing_commas):
2222                 if str(l).strip("\n") == line_str:
2223                     raise CannotSplit("Split function returned an unchanged result")
2224
2225                 result.extend(
2226                     split_line(
2227                         l,
2228                         line_length=line_length,
2229                         inner=True,
2230                         supports_trailing_commas=supports_trailing_commas,
2231                     )
2232                 )
2233         except CannotSplit:
2234             continue
2235
2236         else:
2237             yield from result
2238             break
2239
2240     else:
2241         yield line
2242
2243
2244 def left_hand_split(
2245     line: Line, supports_trailing_commas: bool = False
2246 ) -> Iterator[Line]:
2247     """Split line into many lines, starting with the first matching bracket pair.
2248
2249     Note: this usually looks weird, only use this for function definitions.
2250     Prefer RHS otherwise.  This is why this function is not symmetrical with
2251     :func:`right_hand_split` which also handles optional parentheses.
2252     """
2253     tail_leaves: List[Leaf] = []
2254     body_leaves: List[Leaf] = []
2255     head_leaves: List[Leaf] = []
2256     current_leaves = head_leaves
2257     matching_bracket = None
2258     for leaf in line.leaves:
2259         if (
2260             current_leaves is body_leaves
2261             and leaf.type in CLOSING_BRACKETS
2262             and leaf.opening_bracket is matching_bracket
2263         ):
2264             current_leaves = tail_leaves if body_leaves else head_leaves
2265         current_leaves.append(leaf)
2266         if current_leaves is head_leaves:
2267             if leaf.type in OPENING_BRACKETS:
2268                 matching_bracket = leaf
2269                 current_leaves = body_leaves
2270     if not matching_bracket:
2271         raise CannotSplit("No brackets found")
2272
2273     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2274     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2275     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2276     bracket_split_succeeded_or_raise(head, body, tail)
2277     for result in (head, body, tail):
2278         if result:
2279             yield result
2280
2281
2282 def right_hand_split(
2283     line: Line,
2284     line_length: int,
2285     supports_trailing_commas: bool = False,
2286     omit: Collection[LeafID] = (),
2287 ) -> Iterator[Line]:
2288     """Split line into many lines, starting with the last matching bracket pair.
2289
2290     If the split was by optional parentheses, attempt splitting without them, too.
2291     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2292     this split.
2293
2294     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2295     """
2296     tail_leaves: List[Leaf] = []
2297     body_leaves: List[Leaf] = []
2298     head_leaves: List[Leaf] = []
2299     current_leaves = tail_leaves
2300     opening_bracket = None
2301     closing_bracket = None
2302     for leaf in reversed(line.leaves):
2303         if current_leaves is body_leaves:
2304             if leaf is opening_bracket:
2305                 current_leaves = head_leaves if body_leaves else tail_leaves
2306         current_leaves.append(leaf)
2307         if current_leaves is tail_leaves:
2308             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2309                 opening_bracket = leaf.opening_bracket
2310                 closing_bracket = leaf
2311                 current_leaves = body_leaves
2312     if not (opening_bracket and closing_bracket and head_leaves):
2313         # If there is no opening or closing_bracket that means the split failed and
2314         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2315         # the matching `opening_bracket` wasn't available on `line` anymore.
2316         raise CannotSplit("No brackets found")
2317
2318     tail_leaves.reverse()
2319     body_leaves.reverse()
2320     head_leaves.reverse()
2321     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2322     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2323     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2324     bracket_split_succeeded_or_raise(head, body, tail)
2325     if (
2326         # the body shouldn't be exploded
2327         not body.should_explode
2328         # the opening bracket is an optional paren
2329         and opening_bracket.type == token.LPAR
2330         and not opening_bracket.value
2331         # the closing bracket is an optional paren
2332         and closing_bracket.type == token.RPAR
2333         and not closing_bracket.value
2334         # it's not an import (optional parens are the only thing we can split on
2335         # in this case; attempting a split without them is a waste of time)
2336         and not line.is_import
2337         # there are no standalone comments in the body
2338         and not body.contains_standalone_comments(0)
2339         # and we can actually remove the parens
2340         and can_omit_invisible_parens(body, line_length)
2341     ):
2342         omit = {id(closing_bracket), *omit}
2343         try:
2344             yield from right_hand_split(
2345                 line,
2346                 line_length,
2347                 supports_trailing_commas=supports_trailing_commas,
2348                 omit=omit,
2349             )
2350             return
2351
2352         except CannotSplit:
2353             if not (
2354                 can_be_split(body)
2355                 or is_line_short_enough(body, line_length=line_length)
2356             ):
2357                 raise CannotSplit(
2358                     "Splitting failed, body is still too long and can't be split."
2359                 )
2360
2361             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2362                 raise CannotSplit(
2363                     "The current optional pair of parentheses is bound to fail to "
2364                     "satisfy the splitting algorithm because the head or the tail "
2365                     "contains multiline strings which by definition never fit one "
2366                     "line."
2367                 )
2368
2369     ensure_visible(opening_bracket)
2370     ensure_visible(closing_bracket)
2371     for result in (head, body, tail):
2372         if result:
2373             yield result
2374
2375
2376 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2377     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2378
2379     Do nothing otherwise.
2380
2381     A left- or right-hand split is based on a pair of brackets. Content before
2382     (and including) the opening bracket is left on one line, content inside the
2383     brackets is put on a separate line, and finally content starting with and
2384     following the closing bracket is put on a separate line.
2385
2386     Those are called `head`, `body`, and `tail`, respectively. If the split
2387     produced the same line (all content in `head`) or ended up with an empty `body`
2388     and the `tail` is just the closing bracket, then it's considered failed.
2389     """
2390     tail_len = len(str(tail).strip())
2391     if not body:
2392         if tail_len == 0:
2393             raise CannotSplit("Splitting brackets produced the same line")
2394
2395         elif tail_len < 3:
2396             raise CannotSplit(
2397                 f"Splitting brackets on an empty body to save "
2398                 f"{tail_len} characters is not worth it"
2399             )
2400
2401
2402 def bracket_split_build_line(
2403     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2404 ) -> Line:
2405     """Return a new line with given `leaves` and respective comments from `original`.
2406
2407     If `is_body` is True, the result line is one-indented inside brackets and as such
2408     has its first leaf's prefix normalized and a trailing comma added when expected.
2409     """
2410     result = Line(depth=original.depth)
2411     if is_body:
2412         result.inside_brackets = True
2413         result.depth += 1
2414         if leaves:
2415             # Since body is a new indent level, remove spurious leading whitespace.
2416             normalize_prefix(leaves[0], inside_brackets=True)
2417             # Ensure a trailing comma when expected.
2418             if original.is_import:
2419                 if leaves[-1].type != token.COMMA:
2420                     leaves.append(Leaf(token.COMMA, ","))
2421     # Populate the line
2422     for leaf in leaves:
2423         result.append(leaf, preformatted=True)
2424         for comment_after in original.comments_after(leaf):
2425             result.append(comment_after, preformatted=True)
2426     if is_body:
2427         result.should_explode = should_explode(result, opening_bracket)
2428     return result
2429
2430
2431 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2432     """Normalize prefix of the first leaf in every line returned by `split_func`.
2433
2434     This is a decorator over relevant split functions.
2435     """
2436
2437     @wraps(split_func)
2438     def split_wrapper(
2439         line: Line, supports_trailing_commas: bool = False
2440     ) -> Iterator[Line]:
2441         for l in split_func(line, supports_trailing_commas):
2442             normalize_prefix(l.leaves[0], inside_brackets=True)
2443             yield l
2444
2445     return split_wrapper
2446
2447
2448 @dont_increase_indentation
2449 def delimiter_split(
2450     line: Line, supports_trailing_commas: bool = False
2451 ) -> Iterator[Line]:
2452     """Split according to delimiters of the highest priority.
2453
2454     If `supports_trailing_commas` is True, the split will add trailing commas
2455     also in function signatures that contain `*` and `**`.
2456     """
2457     try:
2458         last_leaf = line.leaves[-1]
2459     except IndexError:
2460         raise CannotSplit("Line empty")
2461
2462     bt = line.bracket_tracker
2463     try:
2464         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2465     except ValueError:
2466         raise CannotSplit("No delimiters found")
2467
2468     if delimiter_priority == DOT_PRIORITY:
2469         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2470             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2471
2472     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2473     lowest_depth = sys.maxsize
2474     trailing_comma_safe = True
2475
2476     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2477         """Append `leaf` to current line or to new line if appending impossible."""
2478         nonlocal current_line
2479         try:
2480             current_line.append_safe(leaf, preformatted=True)
2481         except ValueError:
2482             yield current_line
2483
2484             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2485             current_line.append(leaf)
2486
2487     for leaf in line.leaves:
2488         yield from append_to_line(leaf)
2489
2490         for comment_after in line.comments_after(leaf):
2491             yield from append_to_line(comment_after)
2492
2493         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2494         if leaf.bracket_depth == lowest_depth and is_vararg(
2495             leaf, within=VARARGS_PARENTS
2496         ):
2497             trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
2498         leaf_priority = bt.delimiters.get(id(leaf))
2499         if leaf_priority == delimiter_priority:
2500             yield current_line
2501
2502             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2503     if current_line:
2504         if (
2505             trailing_comma_safe
2506             and delimiter_priority == COMMA_PRIORITY
2507             and current_line.leaves[-1].type != token.COMMA
2508             and current_line.leaves[-1].type != STANDALONE_COMMENT
2509         ):
2510             current_line.append(Leaf(token.COMMA, ","))
2511         yield current_line
2512
2513
2514 @dont_increase_indentation
2515 def standalone_comment_split(
2516     line: Line, supports_trailing_commas: bool = False
2517 ) -> Iterator[Line]:
2518     """Split standalone comments from the rest of the line."""
2519     if not line.contains_standalone_comments(0):
2520         raise CannotSplit("Line does not have any standalone comments")
2521
2522     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2523
2524     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2525         """Append `leaf` to current line or to new line if appending impossible."""
2526         nonlocal current_line
2527         try:
2528             current_line.append_safe(leaf, preformatted=True)
2529         except ValueError:
2530             yield current_line
2531
2532             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2533             current_line.append(leaf)
2534
2535     for leaf in line.leaves:
2536         yield from append_to_line(leaf)
2537
2538         for comment_after in line.comments_after(leaf):
2539             yield from append_to_line(comment_after)
2540
2541     if current_line:
2542         yield current_line
2543
2544
2545 def is_import(leaf: Leaf) -> bool:
2546     """Return True if the given leaf starts an import statement."""
2547     p = leaf.parent
2548     t = leaf.type
2549     v = leaf.value
2550     return bool(
2551         t == token.NAME
2552         and (
2553             (v == "import" and p and p.type == syms.import_name)
2554             or (v == "from" and p and p.type == syms.import_from)
2555         )
2556     )
2557
2558
2559 def is_type_comment(leaf: Leaf) -> bool:
2560     """Return True if the given leaf is a special comment.
2561     Only returns true for type comments for now."""
2562     t = leaf.type
2563     v = leaf.value
2564     return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:")
2565
2566
2567 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2568     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2569     else.
2570
2571     Note: don't use backslashes for formatting or you'll lose your voting rights.
2572     """
2573     if not inside_brackets:
2574         spl = leaf.prefix.split("#")
2575         if "\\" not in spl[0]:
2576             nl_count = spl[-1].count("\n")
2577             if len(spl) > 1:
2578                 nl_count -= 1
2579             leaf.prefix = "\n" * nl_count
2580             return
2581
2582     leaf.prefix = ""
2583
2584
2585 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2586     """Make all string prefixes lowercase.
2587
2588     If remove_u_prefix is given, also removes any u prefix from the string.
2589
2590     Note: Mutates its argument.
2591     """
2592     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2593     assert match is not None, f"failed to match string {leaf.value!r}"
2594     orig_prefix = match.group(1)
2595     new_prefix = orig_prefix.lower()
2596     if remove_u_prefix:
2597         new_prefix = new_prefix.replace("u", "")
2598     leaf.value = f"{new_prefix}{match.group(2)}"
2599
2600
2601 def normalize_string_quotes(leaf: Leaf) -> None:
2602     """Prefer double quotes but only if it doesn't cause more escaping.
2603
2604     Adds or removes backslashes as appropriate. Doesn't parse and fix
2605     strings nested in f-strings (yet).
2606
2607     Note: Mutates its argument.
2608     """
2609     value = leaf.value.lstrip("furbFURB")
2610     if value[:3] == '"""':
2611         return
2612
2613     elif value[:3] == "'''":
2614         orig_quote = "'''"
2615         new_quote = '"""'
2616     elif value[0] == '"':
2617         orig_quote = '"'
2618         new_quote = "'"
2619     else:
2620         orig_quote = "'"
2621         new_quote = '"'
2622     first_quote_pos = leaf.value.find(orig_quote)
2623     if first_quote_pos == -1:
2624         return  # There's an internal error
2625
2626     prefix = leaf.value[:first_quote_pos]
2627     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2628     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2629     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2630     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2631     if "r" in prefix.casefold():
2632         if unescaped_new_quote.search(body):
2633             # There's at least one unescaped new_quote in this raw string
2634             # so converting is impossible
2635             return
2636
2637         # Do not introduce or remove backslashes in raw strings
2638         new_body = body
2639     else:
2640         # remove unnecessary escapes
2641         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2642         if body != new_body:
2643             # Consider the string without unnecessary escapes as the original
2644             body = new_body
2645             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2646         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2647         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2648     if "f" in prefix.casefold():
2649         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2650         for m in matches:
2651             if "\\" in str(m):
2652                 # Do not introduce backslashes in interpolated expressions
2653                 return
2654     if new_quote == '"""' and new_body[-1:] == '"':
2655         # edge case:
2656         new_body = new_body[:-1] + '\\"'
2657     orig_escape_count = body.count("\\")
2658     new_escape_count = new_body.count("\\")
2659     if new_escape_count > orig_escape_count:
2660         return  # Do not introduce more escaping
2661
2662     if new_escape_count == orig_escape_count and orig_quote == '"':
2663         return  # Prefer double quotes
2664
2665     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2666
2667
2668 def normalize_numeric_literal(leaf: Leaf) -> None:
2669     """Normalizes numeric (float, int, and complex) literals.
2670
2671     All letters used in the representation are normalized to lowercase (except
2672     in Python 2 long literals).
2673     """
2674     text = leaf.value.lower()
2675     if text.startswith(("0o", "0b")):
2676         # Leave octal and binary literals alone.
2677         pass
2678     elif text.startswith("0x"):
2679         # Change hex literals to upper case.
2680         before, after = text[:2], text[2:]
2681         text = f"{before}{after.upper()}"
2682     elif "e" in text:
2683         before, after = text.split("e")
2684         sign = ""
2685         if after.startswith("-"):
2686             after = after[1:]
2687             sign = "-"
2688         elif after.startswith("+"):
2689             after = after[1:]
2690         before = format_float_or_int_string(before)
2691         text = f"{before}e{sign}{after}"
2692     elif text.endswith(("j", "l")):
2693         number = text[:-1]
2694         suffix = text[-1]
2695         # Capitalize in "2L" because "l" looks too similar to "1".
2696         if suffix == "l":
2697             suffix = "L"
2698         text = f"{format_float_or_int_string(number)}{suffix}"
2699     else:
2700         text = format_float_or_int_string(text)
2701     leaf.value = text
2702
2703
2704 def format_float_or_int_string(text: str) -> str:
2705     """Formats a float string like "1.0"."""
2706     if "." not in text:
2707         return text
2708
2709     before, after = text.split(".")
2710     return f"{before or 0}.{after or 0}"
2711
2712
2713 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2714     """Make existing optional parentheses invisible or create new ones.
2715
2716     `parens_after` is a set of string leaf values immeditely after which parens
2717     should be put.
2718
2719     Standardizes on visible parentheses for single-element tuples, and keeps
2720     existing visible parentheses for other tuples and generator expressions.
2721     """
2722     for pc in list_comments(node.prefix, is_endmarker=False):
2723         if pc.value in FMT_OFF:
2724             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2725             return
2726
2727     check_lpar = False
2728     for index, child in enumerate(list(node.children)):
2729         if check_lpar:
2730             if child.type == syms.atom:
2731                 if maybe_make_parens_invisible_in_atom(child):
2732                     lpar = Leaf(token.LPAR, "")
2733                     rpar = Leaf(token.RPAR, "")
2734                     index = child.remove() or 0
2735                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2736             elif is_one_tuple(child):
2737                 # wrap child in visible parentheses
2738                 lpar = Leaf(token.LPAR, "(")
2739                 rpar = Leaf(token.RPAR, ")")
2740                 child.remove()
2741                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2742             elif node.type == syms.import_from:
2743                 # "import from" nodes store parentheses directly as part of
2744                 # the statement
2745                 if child.type == token.LPAR:
2746                     # make parentheses invisible
2747                     child.value = ""  # type: ignore
2748                     node.children[-1].value = ""  # type: ignore
2749                 elif child.type != token.STAR:
2750                     # insert invisible parentheses
2751                     node.insert_child(index, Leaf(token.LPAR, ""))
2752                     node.append_child(Leaf(token.RPAR, ""))
2753                 break
2754
2755             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2756                 # wrap child in invisible parentheses
2757                 lpar = Leaf(token.LPAR, "")
2758                 rpar = Leaf(token.RPAR, "")
2759                 index = child.remove() or 0
2760                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2761
2762         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2763
2764
2765 def normalize_fmt_off(node: Node) -> None:
2766     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2767     try_again = True
2768     while try_again:
2769         try_again = convert_one_fmt_off_pair(node)
2770
2771
2772 def convert_one_fmt_off_pair(node: Node) -> bool:
2773     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2774
2775     Returns True if a pair was converted.
2776     """
2777     for leaf in node.leaves():
2778         previous_consumed = 0
2779         for comment in list_comments(leaf.prefix, is_endmarker=False):
2780             if comment.value in FMT_OFF:
2781                 # We only want standalone comments. If there's no previous leaf or
2782                 # the previous leaf is indentation, it's a standalone comment in
2783                 # disguise.
2784                 if comment.type != STANDALONE_COMMENT:
2785                     prev = preceding_leaf(leaf)
2786                     if prev and prev.type not in WHITESPACE:
2787                         continue
2788
2789                 ignored_nodes = list(generate_ignored_nodes(leaf))
2790                 if not ignored_nodes:
2791                     continue
2792
2793                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2794                 parent = first.parent
2795                 prefix = first.prefix
2796                 first.prefix = prefix[comment.consumed :]
2797                 hidden_value = (
2798                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2799                 )
2800                 if hidden_value.endswith("\n"):
2801                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2802                     # leaf (possibly followed by a DEDENT).
2803                     hidden_value = hidden_value[:-1]
2804                 first_idx = None
2805                 for ignored in ignored_nodes:
2806                     index = ignored.remove()
2807                     if first_idx is None:
2808                         first_idx = index
2809                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2810                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2811                 parent.insert_child(
2812                     first_idx,
2813                     Leaf(
2814                         STANDALONE_COMMENT,
2815                         hidden_value,
2816                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2817                     ),
2818                 )
2819                 return True
2820
2821             previous_consumed = comment.consumed
2822
2823     return False
2824
2825
2826 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2827     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2828
2829     Stops at the end of the block.
2830     """
2831     container: Optional[LN] = container_of(leaf)
2832     while container is not None and container.type != token.ENDMARKER:
2833         for comment in list_comments(container.prefix, is_endmarker=False):
2834             if comment.value in FMT_ON:
2835                 return
2836
2837         yield container
2838
2839         container = container.next_sibling
2840
2841
2842 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2843     """If it's safe, make the parens in the atom `node` invisible, recursively.
2844
2845     Returns whether the node should itself be wrapped in invisible parentheses.
2846
2847     """
2848     if (
2849         node.type != syms.atom
2850         or is_empty_tuple(node)
2851         or is_one_tuple(node)
2852         or is_yield(node)
2853         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2854     ):
2855         return False
2856
2857     first = node.children[0]
2858     last = node.children[-1]
2859     if first.type == token.LPAR and last.type == token.RPAR:
2860         # make parentheses invisible
2861         first.value = ""  # type: ignore
2862         last.value = ""  # type: ignore
2863         if len(node.children) > 1:
2864             maybe_make_parens_invisible_in_atom(node.children[1])
2865         return False
2866
2867     return True
2868
2869
2870 def is_empty_tuple(node: LN) -> bool:
2871     """Return True if `node` holds an empty tuple."""
2872     return (
2873         node.type == syms.atom
2874         and len(node.children) == 2
2875         and node.children[0].type == token.LPAR
2876         and node.children[1].type == token.RPAR
2877     )
2878
2879
2880 def is_one_tuple(node: LN) -> bool:
2881     """Return True if `node` holds a tuple with one element, with or without parens."""
2882     if node.type == syms.atom:
2883         if len(node.children) != 3:
2884             return False
2885
2886         lpar, gexp, rpar = node.children
2887         if not (
2888             lpar.type == token.LPAR
2889             and gexp.type == syms.testlist_gexp
2890             and rpar.type == token.RPAR
2891         ):
2892             return False
2893
2894         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2895
2896     return (
2897         node.type in IMPLICIT_TUPLE
2898         and len(node.children) == 2
2899         and node.children[1].type == token.COMMA
2900     )
2901
2902
2903 def is_yield(node: LN) -> bool:
2904     """Return True if `node` holds a `yield` or `yield from` expression."""
2905     if node.type == syms.yield_expr:
2906         return True
2907
2908     if node.type == token.NAME and node.value == "yield":  # type: ignore
2909         return True
2910
2911     if node.type != syms.atom:
2912         return False
2913
2914     if len(node.children) != 3:
2915         return False
2916
2917     lpar, expr, rpar = node.children
2918     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2919         return is_yield(expr)
2920
2921     return False
2922
2923
2924 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2925     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2926
2927     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2928     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2929     extended iterable unpacking (PEP 3132) and additional unpacking
2930     generalizations (PEP 448).
2931     """
2932     if leaf.type not in STARS or not leaf.parent:
2933         return False
2934
2935     p = leaf.parent
2936     if p.type == syms.star_expr:
2937         # Star expressions are also used as assignment targets in extended
2938         # iterable unpacking (PEP 3132).  See what its parent is instead.
2939         if not p.parent:
2940             return False
2941
2942         p = p.parent
2943
2944     return p.type in within
2945
2946
2947 def is_multiline_string(leaf: Leaf) -> bool:
2948     """Return True if `leaf` is a multiline string that actually spans many lines."""
2949     value = leaf.value.lstrip("furbFURB")
2950     return value[:3] in {'"""', "'''"} and "\n" in value
2951
2952
2953 def is_stub_suite(node: Node) -> bool:
2954     """Return True if `node` is a suite with a stub body."""
2955     if (
2956         len(node.children) != 4
2957         or node.children[0].type != token.NEWLINE
2958         or node.children[1].type != token.INDENT
2959         or node.children[3].type != token.DEDENT
2960     ):
2961         return False
2962
2963     return is_stub_body(node.children[2])
2964
2965
2966 def is_stub_body(node: LN) -> bool:
2967     """Return True if `node` is a simple statement containing an ellipsis."""
2968     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2969         return False
2970
2971     if len(node.children) != 2:
2972         return False
2973
2974     child = node.children[0]
2975     return (
2976         child.type == syms.atom
2977         and len(child.children) == 3
2978         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2979     )
2980
2981
2982 def max_delimiter_priority_in_atom(node: LN) -> int:
2983     """Return maximum delimiter priority inside `node`.
2984
2985     This is specific to atoms with contents contained in a pair of parentheses.
2986     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2987     """
2988     if node.type != syms.atom:
2989         return 0
2990
2991     first = node.children[0]
2992     last = node.children[-1]
2993     if not (first.type == token.LPAR and last.type == token.RPAR):
2994         return 0
2995
2996     bt = BracketTracker()
2997     for c in node.children[1:-1]:
2998         if isinstance(c, Leaf):
2999             bt.mark(c)
3000         else:
3001             for leaf in c.leaves():
3002                 bt.mark(leaf)
3003     try:
3004         return bt.max_delimiter_priority()
3005
3006     except ValueError:
3007         return 0
3008
3009
3010 def ensure_visible(leaf: Leaf) -> None:
3011     """Make sure parentheses are visible.
3012
3013     They could be invisible as part of some statements (see
3014     :func:`normalize_invible_parens` and :func:`visit_import_from`).
3015     """
3016     if leaf.type == token.LPAR:
3017         leaf.value = "("
3018     elif leaf.type == token.RPAR:
3019         leaf.value = ")"
3020
3021
3022 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3023     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3024
3025     if not (
3026         opening_bracket.parent
3027         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3028         and opening_bracket.value in "[{("
3029     ):
3030         return False
3031
3032     try:
3033         last_leaf = line.leaves[-1]
3034         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3035         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3036     except (IndexError, ValueError):
3037         return False
3038
3039     return max_priority == COMMA_PRIORITY
3040
3041
3042 def get_features_used(node: Node) -> Set[Feature]:
3043     """Return a set of (relatively) new Python features used in this file.
3044
3045     Currently looking for:
3046     - f-strings;
3047     - underscores in numeric literals; and
3048     - trailing commas after * or ** in function signatures and calls.
3049     """
3050     features: Set[Feature] = set()
3051     for n in node.pre_order():
3052         if n.type == token.STRING:
3053             value_head = n.value[:2]  # type: ignore
3054             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3055                 features.add(Feature.F_STRINGS)
3056
3057         elif n.type == token.NUMBER:
3058             if "_" in n.value:  # type: ignore
3059                 features.add(Feature.NUMERIC_UNDERSCORES)
3060
3061         elif (
3062             n.type in {syms.typedargslist, syms.arglist}
3063             and n.children
3064             and n.children[-1].type == token.COMMA
3065         ):
3066             for ch in n.children:
3067                 if ch.type in STARS:
3068                     features.add(Feature.TRAILING_COMMA)
3069
3070                 if ch.type == syms.argument:
3071                     for argch in ch.children:
3072                         if argch.type in STARS:
3073                             features.add(Feature.TRAILING_COMMA)
3074
3075     return features
3076
3077
3078 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3079     """Detect the version to target based on the nodes used."""
3080     features = get_features_used(node)
3081     return {
3082         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3083     }
3084
3085
3086 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3087     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3088
3089     Brackets can be omitted if the entire trailer up to and including
3090     a preceding closing bracket fits in one line.
3091
3092     Yielded sets are cumulative (contain results of previous yields, too).  First
3093     set is empty.
3094     """
3095
3096     omit: Set[LeafID] = set()
3097     yield omit
3098
3099     length = 4 * line.depth
3100     opening_bracket = None
3101     closing_bracket = None
3102     inner_brackets: Set[LeafID] = set()
3103     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3104         length += leaf_length
3105         if length > line_length:
3106             break
3107
3108         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3109         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3110             break
3111
3112         if opening_bracket:
3113             if leaf is opening_bracket:
3114                 opening_bracket = None
3115             elif leaf.type in CLOSING_BRACKETS:
3116                 inner_brackets.add(id(leaf))
3117         elif leaf.type in CLOSING_BRACKETS:
3118             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3119                 # Empty brackets would fail a split so treat them as "inner"
3120                 # brackets (e.g. only add them to the `omit` set if another
3121                 # pair of brackets was good enough.
3122                 inner_brackets.add(id(leaf))
3123                 continue
3124
3125             if closing_bracket:
3126                 omit.add(id(closing_bracket))
3127                 omit.update(inner_brackets)
3128                 inner_brackets.clear()
3129                 yield omit
3130
3131             if leaf.value:
3132                 opening_bracket = leaf.opening_bracket
3133                 closing_bracket = leaf
3134
3135
3136 def get_future_imports(node: Node) -> Set[str]:
3137     """Return a set of __future__ imports in the file."""
3138     imports: Set[str] = set()
3139
3140     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3141         for child in children:
3142             if isinstance(child, Leaf):
3143                 if child.type == token.NAME:
3144                     yield child.value
3145             elif child.type == syms.import_as_name:
3146                 orig_name = child.children[0]
3147                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3148                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3149                 yield orig_name.value
3150             elif child.type == syms.import_as_names:
3151                 yield from get_imports_from_children(child.children)
3152             else:
3153                 assert False, "Invalid syntax parsing imports"
3154
3155     for child in node.children:
3156         if child.type != syms.simple_stmt:
3157             break
3158         first_child = child.children[0]
3159         if isinstance(first_child, Leaf):
3160             # Continue looking if we see a docstring; otherwise stop.
3161             if (
3162                 len(child.children) == 2
3163                 and first_child.type == token.STRING
3164                 and child.children[1].type == token.NEWLINE
3165             ):
3166                 continue
3167             else:
3168                 break
3169         elif first_child.type == syms.import_from:
3170             module_name = first_child.children[1]
3171             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3172                 break
3173             imports |= set(get_imports_from_children(first_child.children[3:]))
3174         else:
3175             break
3176     return imports
3177
3178
3179 def gen_python_files_in_dir(
3180     path: Path,
3181     root: Path,
3182     include: Pattern[str],
3183     exclude: Pattern[str],
3184     report: "Report",
3185 ) -> Iterator[Path]:
3186     """Generate all files under `path` whose paths are not excluded by the
3187     `exclude` regex, but are included by the `include` regex.
3188
3189     Symbolic links pointing outside of the `root` directory are ignored.
3190
3191     `report` is where output about exclusions goes.
3192     """
3193     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3194     for child in path.iterdir():
3195         try:
3196             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3197         except ValueError:
3198             if child.is_symlink():
3199                 report.path_ignored(
3200                     child, f"is a symbolic link that points outside {root}"
3201                 )
3202                 continue
3203
3204             raise
3205
3206         if child.is_dir():
3207             normalized_path += "/"
3208         exclude_match = exclude.search(normalized_path)
3209         if exclude_match and exclude_match.group(0):
3210             report.path_ignored(child, f"matches the --exclude regular expression")
3211             continue
3212
3213         if child.is_dir():
3214             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3215
3216         elif child.is_file():
3217             include_match = include.search(normalized_path)
3218             if include_match:
3219                 yield child
3220
3221
3222 @lru_cache()
3223 def find_project_root(srcs: Iterable[str]) -> Path:
3224     """Return a directory containing .git, .hg, or pyproject.toml.
3225
3226     That directory can be one of the directories passed in `srcs` or their
3227     common parent.
3228
3229     If no directory in the tree contains a marker that would specify it's the
3230     project root, the root of the file system is returned.
3231     """
3232     if not srcs:
3233         return Path("/").resolve()
3234
3235     common_base = min(Path(src).resolve() for src in srcs)
3236     if common_base.is_dir():
3237         # Append a fake file so `parents` below returns `common_base_dir`, too.
3238         common_base /= "fake-file"
3239     for directory in common_base.parents:
3240         if (directory / ".git").is_dir():
3241             return directory
3242
3243         if (directory / ".hg").is_dir():
3244             return directory
3245
3246         if (directory / "pyproject.toml").is_file():
3247             return directory
3248
3249     return directory
3250
3251
3252 @dataclass
3253 class Report:
3254     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3255
3256     check: bool = False
3257     quiet: bool = False
3258     verbose: bool = False
3259     change_count: int = 0
3260     same_count: int = 0
3261     failure_count: int = 0
3262
3263     def done(self, src: Path, changed: Changed) -> None:
3264         """Increment the counter for successful reformatting. Write out a message."""
3265         if changed is Changed.YES:
3266             reformatted = "would reformat" if self.check else "reformatted"
3267             if self.verbose or not self.quiet:
3268                 out(f"{reformatted} {src}")
3269             self.change_count += 1
3270         else:
3271             if self.verbose:
3272                 if changed is Changed.NO:
3273                     msg = f"{src} already well formatted, good job."
3274                 else:
3275                     msg = f"{src} wasn't modified on disk since last run."
3276                 out(msg, bold=False)
3277             self.same_count += 1
3278
3279     def failed(self, src: Path, message: str) -> None:
3280         """Increment the counter for failed reformatting. Write out a message."""
3281         err(f"error: cannot format {src}: {message}")
3282         self.failure_count += 1
3283
3284     def path_ignored(self, path: Path, message: str) -> None:
3285         if self.verbose:
3286             out(f"{path} ignored: {message}", bold=False)
3287
3288     @property
3289     def return_code(self) -> int:
3290         """Return the exit code that the app should use.
3291
3292         This considers the current state of changed files and failures:
3293         - if there were any failures, return 123;
3294         - if any files were changed and --check is being used, return 1;
3295         - otherwise return 0.
3296         """
3297         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3298         # 126 we have special return codes reserved by the shell.
3299         if self.failure_count:
3300             return 123
3301
3302         elif self.change_count and self.check:
3303             return 1
3304
3305         return 0
3306
3307     def __str__(self) -> str:
3308         """Render a color report of the current state.
3309
3310         Use `click.unstyle` to remove colors.
3311         """
3312         if self.check:
3313             reformatted = "would be reformatted"
3314             unchanged = "would be left unchanged"
3315             failed = "would fail to reformat"
3316         else:
3317             reformatted = "reformatted"
3318             unchanged = "left unchanged"
3319             failed = "failed to reformat"
3320         report = []
3321         if self.change_count:
3322             s = "s" if self.change_count > 1 else ""
3323             report.append(
3324                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3325             )
3326         if self.same_count:
3327             s = "s" if self.same_count > 1 else ""
3328             report.append(f"{self.same_count} file{s} {unchanged}")
3329         if self.failure_count:
3330             s = "s" if self.failure_count > 1 else ""
3331             report.append(
3332                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3333             )
3334         return ", ".join(report) + "."
3335
3336
3337 def assert_equivalent(src: str, dst: str) -> None:
3338     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3339
3340     import ast
3341     import traceback
3342
3343     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3344         """Simple visitor generating strings to compare ASTs by content."""
3345         yield f"{'  ' * depth}{node.__class__.__name__}("
3346
3347         for field in sorted(node._fields):
3348             try:
3349                 value = getattr(node, field)
3350             except AttributeError:
3351                 continue
3352
3353             yield f"{'  ' * (depth+1)}{field}="
3354
3355             if isinstance(value, list):
3356                 for item in value:
3357                     # Ignore nested tuples within del statements, because we may insert
3358                     # parentheses and they change the AST.
3359                     if (
3360                         field == "targets"
3361                         and isinstance(node, ast.Delete)
3362                         and isinstance(item, ast.Tuple)
3363                     ):
3364                         for item in item.elts:
3365                             yield from _v(item, depth + 2)
3366                     elif isinstance(item, ast.AST):
3367                         yield from _v(item, depth + 2)
3368
3369             elif isinstance(value, ast.AST):
3370                 yield from _v(value, depth + 2)
3371
3372             else:
3373                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3374
3375         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3376
3377     try:
3378         src_ast = ast.parse(src)
3379     except Exception as exc:
3380         major, minor = sys.version_info[:2]
3381         raise AssertionError(
3382             f"cannot use --safe with this file; failed to parse source file "
3383             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3384             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3385         )
3386
3387     try:
3388         dst_ast = ast.parse(dst)
3389     except Exception as exc:
3390         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3391         raise AssertionError(
3392             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3393             f"Please report a bug on https://github.com/ambv/black/issues.  "
3394             f"This invalid output might be helpful: {log}"
3395         ) from None
3396
3397     src_ast_str = "\n".join(_v(src_ast))
3398     dst_ast_str = "\n".join(_v(dst_ast))
3399     if src_ast_str != dst_ast_str:
3400         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3401         raise AssertionError(
3402             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3403             f"the source.  "
3404             f"Please report a bug on https://github.com/ambv/black/issues.  "
3405             f"This diff might be helpful: {log}"
3406         ) from None
3407
3408
3409 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3410     """Raise AssertionError if `dst` reformats differently the second time."""
3411     newdst = format_str(dst, mode=mode)
3412     if dst != newdst:
3413         log = dump_to_file(
3414             diff(src, dst, "source", "first pass"),
3415             diff(dst, newdst, "first pass", "second pass"),
3416         )
3417         raise AssertionError(
3418             f"INTERNAL ERROR: Black produced different code on the second pass "
3419             f"of the formatter.  "
3420             f"Please report a bug on https://github.com/ambv/black/issues.  "
3421             f"This diff might be helpful: {log}"
3422         ) from None
3423
3424
3425 def dump_to_file(*output: str) -> str:
3426     """Dump `output` to a temporary file. Return path to the file."""
3427     import tempfile
3428
3429     with tempfile.NamedTemporaryFile(
3430         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3431     ) as f:
3432         for lines in output:
3433             f.write(lines)
3434             if lines and lines[-1] != "\n":
3435                 f.write("\n")
3436     return f.name
3437
3438
3439 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3440     """Return a unified diff string between strings `a` and `b`."""
3441     import difflib
3442
3443     a_lines = [line + "\n" for line in a.split("\n")]
3444     b_lines = [line + "\n" for line in b.split("\n")]
3445     return "".join(
3446         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3447     )
3448
3449
3450 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3451     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3452     err("Aborted!")
3453     for task in tasks:
3454         task.cancel()
3455
3456
3457 def shutdown(loop: BaseEventLoop) -> None:
3458     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3459     try:
3460         if sys.version_info[:2] >= (3, 7):
3461             all_tasks = asyncio.all_tasks
3462         else:
3463             all_tasks = asyncio.Task.all_tasks
3464         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3465         to_cancel = [task for task in all_tasks(loop) if not task.done()]
3466         if not to_cancel:
3467             return
3468
3469         for task in to_cancel:
3470             task.cancel()
3471         loop.run_until_complete(
3472             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3473         )
3474     finally:
3475         # `concurrent.futures.Future` objects cannot be cancelled once they
3476         # are already running. There might be some when the `shutdown()` happened.
3477         # Silence their logger's spew about the event loop being closed.
3478         cf_logger = logging.getLogger("concurrent.futures")
3479         cf_logger.setLevel(logging.CRITICAL)
3480         loop.close()
3481
3482
3483 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3484     """Replace `regex` with `replacement` twice on `original`.
3485
3486     This is used by string normalization to perform replaces on
3487     overlapping matches.
3488     """
3489     return regex.sub(replacement, regex.sub(replacement, original))
3490
3491
3492 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3493     """Compile a regular expression string in `regex`.
3494
3495     If it contains newlines, use verbose mode.
3496     """
3497     if "\n" in regex:
3498         regex = "(?x)" + regex
3499     return re.compile(regex)
3500
3501
3502 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3503     """Like `reversed(enumerate(sequence))` if that were possible."""
3504     index = len(sequence) - 1
3505     for element in reversed(sequence):
3506         yield (index, element)
3507         index -= 1
3508
3509
3510 def enumerate_with_length(
3511     line: Line, reversed: bool = False
3512 ) -> Iterator[Tuple[Index, Leaf, int]]:
3513     """Return an enumeration of leaves with their length.
3514
3515     Stops prematurely on multiline strings and standalone comments.
3516     """
3517     op = cast(
3518         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3519         enumerate_reversed if reversed else enumerate,
3520     )
3521     for index, leaf in op(line.leaves):
3522         length = len(leaf.prefix) + len(leaf.value)
3523         if "\n" in leaf.value:
3524             return  # Multiline strings, we can't continue.
3525
3526         comment: Optional[Leaf]
3527         for comment in line.comments_after(leaf):
3528             length += len(comment.value)
3529
3530         yield index, leaf, length
3531
3532
3533 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3534     """Return True if `line` is no longer than `line_length`.
3535
3536     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3537     """
3538     if not line_str:
3539         line_str = str(line).strip("\n")
3540     return (
3541         len(line_str) <= line_length
3542         and "\n" not in line_str  # multiline strings
3543         and not line.contains_standalone_comments()
3544     )
3545
3546
3547 def can_be_split(line: Line) -> bool:
3548     """Return False if the line cannot be split *for sure*.
3549
3550     This is not an exhaustive search but a cheap heuristic that we can use to
3551     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3552     in unnecessary parentheses).
3553     """
3554     leaves = line.leaves
3555     if len(leaves) < 2:
3556         return False
3557
3558     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3559         call_count = 0
3560         dot_count = 0
3561         next = leaves[-1]
3562         for leaf in leaves[-2::-1]:
3563             if leaf.type in OPENING_BRACKETS:
3564                 if next.type not in CLOSING_BRACKETS:
3565                     return False
3566
3567                 call_count += 1
3568             elif leaf.type == token.DOT:
3569                 dot_count += 1
3570             elif leaf.type == token.NAME:
3571                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3572                     return False
3573
3574             elif leaf.type not in CLOSING_BRACKETS:
3575                 return False
3576
3577             if dot_count > 1 and call_count > 1:
3578                 return False
3579
3580     return True
3581
3582
3583 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3584     """Does `line` have a shape safe to reformat without optional parens around it?
3585
3586     Returns True for only a subset of potentially nice looking formattings but
3587     the point is to not return false positives that end up producing lines that
3588     are too long.
3589     """
3590     bt = line.bracket_tracker
3591     if not bt.delimiters:
3592         # Without delimiters the optional parentheses are useless.
3593         return True
3594
3595     max_priority = bt.max_delimiter_priority()
3596     if bt.delimiter_count_with_priority(max_priority) > 1:
3597         # With more than one delimiter of a kind the optional parentheses read better.
3598         return False
3599
3600     if max_priority == DOT_PRIORITY:
3601         # A single stranded method call doesn't require optional parentheses.
3602         return True
3603
3604     assert len(line.leaves) >= 2, "Stranded delimiter"
3605
3606     first = line.leaves[0]
3607     second = line.leaves[1]
3608     penultimate = line.leaves[-2]
3609     last = line.leaves[-1]
3610
3611     # With a single delimiter, omit if the expression starts or ends with
3612     # a bracket.
3613     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3614         remainder = False
3615         length = 4 * line.depth
3616         for _index, leaf, leaf_length in enumerate_with_length(line):
3617             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3618                 remainder = True
3619             if remainder:
3620                 length += leaf_length
3621                 if length > line_length:
3622                     break
3623
3624                 if leaf.type in OPENING_BRACKETS:
3625                     # There are brackets we can further split on.
3626                     remainder = False
3627
3628         else:
3629             # checked the entire string and line length wasn't exceeded
3630             if len(line.leaves) == _index + 1:
3631                 return True
3632
3633         # Note: we are not returning False here because a line might have *both*
3634         # a leading opening bracket and a trailing closing bracket.  If the
3635         # opening bracket doesn't match our rule, maybe the closing will.
3636
3637     if (
3638         last.type == token.RPAR
3639         or last.type == token.RBRACE
3640         or (
3641             # don't use indexing for omitting optional parentheses;
3642             # it looks weird
3643             last.type == token.RSQB
3644             and last.parent
3645             and last.parent.type != syms.trailer
3646         )
3647     ):
3648         if penultimate.type in OPENING_BRACKETS:
3649             # Empty brackets don't help.
3650             return False
3651
3652         if is_multiline_string(first):
3653             # Additional wrapping of a multiline string in this situation is
3654             # unnecessary.
3655             return True
3656
3657         length = 4 * line.depth
3658         seen_other_brackets = False
3659         for _index, leaf, leaf_length in enumerate_with_length(line):
3660             length += leaf_length
3661             if leaf is last.opening_bracket:
3662                 if seen_other_brackets or length <= line_length:
3663                     return True
3664
3665             elif leaf.type in OPENING_BRACKETS:
3666                 # There are brackets we can further split on.
3667                 seen_other_brackets = True
3668
3669     return False
3670
3671
3672 def get_cache_file(mode: FileMode) -> Path:
3673     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3674
3675
3676 def read_cache(mode: FileMode) -> Cache:
3677     """Read the cache if it exists and is well formed.
3678
3679     If it is not well formed, the call to write_cache later should resolve the issue.
3680     """
3681     cache_file = get_cache_file(mode)
3682     if not cache_file.exists():
3683         return {}
3684
3685     with cache_file.open("rb") as fobj:
3686         try:
3687             cache: Cache = pickle.load(fobj)
3688         except pickle.UnpicklingError:
3689             return {}
3690
3691     return cache
3692
3693
3694 def get_cache_info(path: Path) -> CacheInfo:
3695     """Return the information used to check if a file is already formatted or not."""
3696     stat = path.stat()
3697     return stat.st_mtime, stat.st_size
3698
3699
3700 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3701     """Split an iterable of paths in `sources` into two sets.
3702
3703     The first contains paths of files that modified on disk or are not in the
3704     cache. The other contains paths to non-modified files.
3705     """
3706     todo, done = set(), set()
3707     for src in sources:
3708         src = src.resolve()
3709         if cache.get(src) != get_cache_info(src):
3710             todo.add(src)
3711         else:
3712             done.add(src)
3713     return todo, done
3714
3715
3716 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3717     """Update the cache file."""
3718     cache_file = get_cache_file(mode)
3719     try:
3720         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3721         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3722         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3723             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3724         os.replace(f.name, cache_file)
3725     except OSError:
3726         pass
3727
3728
3729 def patch_click() -> None:
3730     """Make Click not crash.
3731
3732     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3733     default which restricts paths that it can access during the lifetime of the
3734     application.  Click refuses to work in this scenario by raising a RuntimeError.
3735
3736     In case of Black the likelihood that non-ASCII characters are going to be used in
3737     file paths is minimal since it's Python source code.  Moreover, this crash was
3738     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3739     """
3740     try:
3741         from click import core
3742         from click import _unicodefun  # type: ignore
3743     except ModuleNotFoundError:
3744         return
3745
3746     for module in (core, _unicodefun):
3747         if hasattr(module, "_verify_python3_env"):
3748             module._verify_python3_env = lambda: None
3749
3750
3751 def patched_main() -> None:
3752     freeze_support()
3753     patch_click()
3754     main()
3755
3756
3757 if __name__ == "__main__":
3758     patched_main()