]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

8f50d3e217dbf1746036186691ce4cb3f7dbb1fe
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43
44 # lib2to3 fork
45 from blib2to3.pytree import Node, Leaf, type_repr
46 from blib2to3 import pygram, pytree
47 from blib2to3.pgen2 import driver, token
48 from blib2to3.pgen2.grammar import Grammar
49 from blib2to3.pgen2.parse import ParseError
50
51
52 __version__ = "18.9b0"
53 DEFAULT_LINE_LENGTH = 88
54 DEFAULT_EXCLUDES = (
55     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
56 )
57 DEFAULT_INCLUDES = r"\.pyi?$"
58 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
59
60
61 # types
62 FileContent = str
63 Encoding = str
64 NewLine = str
65 Depth = int
66 NodeType = int
67 LeafID = int
68 Priority = int
69 Index = int
70 LN = Union[Leaf, Node]
71 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
72 Timestamp = float
73 FileSize = int
74 CacheInfo = Tuple[Timestamp, FileSize]
75 Cache = Dict[Path, CacheInfo]
76 out = partial(click.secho, bold=True, err=True)
77 err = partial(click.secho, fg="red", err=True)
78
79 pygram.initialize(CACHE_DIR)
80 syms = pygram.python_symbols
81
82
83 class NothingChanged(UserWarning):
84     """Raised when reformatted code is the same as source."""
85
86
87 class CannotSplit(Exception):
88     """A readable split that fits the allotted line length is impossible."""
89
90
91 class InvalidInput(ValueError):
92     """Raised when input source code fails all parse attempts."""
93
94
95 class WriteBack(Enum):
96     NO = 0
97     YES = 1
98     DIFF = 2
99     CHECK = 3
100
101     @classmethod
102     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
103         if check and not diff:
104             return cls.CHECK
105
106         return cls.DIFF if diff else cls.YES
107
108
109 class Changed(Enum):
110     NO = 0
111     CACHED = 1
112     YES = 2
113
114
115 class TargetVersion(Enum):
116     PYPY35 = 1
117     CPY27 = 2
118     CPY33 = 3
119     CPY34 = 4
120     CPY35 = 5
121     CPY36 = 6
122     CPY37 = 7
123     CPY38 = 8
124
125     def is_python2(self) -> bool:
126         return self is TargetVersion.CPY27
127
128
129 PY36_VERSIONS = {TargetVersion.CPY36, TargetVersion.CPY37, TargetVersion.CPY38}
130
131
132 class Feature(Enum):
133     # All string literals are unicode
134     UNICODE_LITERALS = 1
135     F_STRINGS = 2
136     NUMERIC_UNDERSCORES = 3
137     TRAILING_COMMA = 4
138
139
140 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
141     TargetVersion.CPY27: set(),
142     TargetVersion.PYPY35: {Feature.UNICODE_LITERALS, Feature.F_STRINGS},
143     TargetVersion.CPY33: {Feature.UNICODE_LITERALS},
144     TargetVersion.CPY34: {Feature.UNICODE_LITERALS},
145     TargetVersion.CPY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
146     TargetVersion.CPY36: {
147         Feature.UNICODE_LITERALS,
148         Feature.F_STRINGS,
149         Feature.NUMERIC_UNDERSCORES,
150         Feature.TRAILING_COMMA,
151     },
152     TargetVersion.CPY37: {
153         Feature.UNICODE_LITERALS,
154         Feature.F_STRINGS,
155         Feature.NUMERIC_UNDERSCORES,
156         Feature.TRAILING_COMMA,
157     },
158     TargetVersion.CPY38: {
159         Feature.UNICODE_LITERALS,
160         Feature.F_STRINGS,
161         Feature.NUMERIC_UNDERSCORES,
162         Feature.TRAILING_COMMA,
163     },
164 }
165
166
167 @dataclass
168 class FileMode:
169     target_versions: Set[TargetVersion] = Factory(set)
170     line_length: int = DEFAULT_LINE_LENGTH
171     string_normalization: bool = True
172     is_pyi: bool = False
173
174     def get_cache_key(self) -> str:
175         if self.target_versions:
176             version_str = ",".join(
177                 str(version.value)
178                 for version in sorted(self.target_versions, key=lambda v: v.value)
179             )
180         else:
181             version_str = "-"
182         parts = [
183             version_str,
184             str(self.line_length),
185             str(int(self.string_normalization)),
186             str(int(self.is_pyi)),
187         ]
188         return ".".join(parts)
189
190
191 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
192     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
193
194
195 def read_pyproject_toml(
196     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
197 ) -> Optional[str]:
198     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
199
200     Returns the path to a successfully found and read configuration file, None
201     otherwise.
202     """
203     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
204     if not value:
205         root = find_project_root(ctx.params.get("src", ()))
206         path = root / "pyproject.toml"
207         if path.is_file():
208             value = str(path)
209         else:
210             return None
211
212     try:
213         pyproject_toml = toml.load(value)
214         config = pyproject_toml.get("tool", {}).get("black", {})
215     except (toml.TomlDecodeError, OSError) as e:
216         raise click.FileError(
217             filename=value, hint=f"Error reading configuration file: {e}"
218         )
219
220     if not config:
221         return None
222
223     if ctx.default_map is None:
224         ctx.default_map = {}
225     ctx.default_map.update(  # type: ignore  # bad types in .pyi
226         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
227     )
228     return value
229
230
231 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
232 @click.option(
233     "-l",
234     "--line-length",
235     type=int,
236     default=DEFAULT_LINE_LENGTH,
237     help="How many characters per line to allow.",
238     show_default=True,
239 )
240 @click.option(
241     "-t",
242     "--target-version",
243     type=click.Choice([v.name.lower() for v in TargetVersion]),
244     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
245     multiple=True,
246     help=(
247         "Python versions that should be supported by Black's output. [default: "
248         "per-file auto-detection]"
249     ),
250 )
251 @click.option(
252     "--py36",
253     is_flag=True,
254     help=(
255         "Allow using Python 3.6-only syntax on all input files.  This will put "
256         "trailing commas in function signatures and calls also after *args and "
257         "**kwargs.  [default: per-file auto-detection]"
258     ),
259 )
260 @click.option(
261     "--pyi",
262     is_flag=True,
263     help=(
264         "Format all input files like typing stubs regardless of file extension "
265         "(useful when piping source on standard input)."
266     ),
267 )
268 @click.option(
269     "-S",
270     "--skip-string-normalization",
271     is_flag=True,
272     help="Don't normalize string quotes or prefixes.",
273 )
274 @click.option(
275     "--check",
276     is_flag=True,
277     help=(
278         "Don't write the files back, just return the status.  Return code 0 "
279         "means nothing would change.  Return code 1 means some files would be "
280         "reformatted.  Return code 123 means there was an internal error."
281     ),
282 )
283 @click.option(
284     "--diff",
285     is_flag=True,
286     help="Don't write the files back, just output a diff for each file on stdout.",
287 )
288 @click.option(
289     "--fast/--safe",
290     is_flag=True,
291     help="If --fast given, skip temporary sanity checks. [default: --safe]",
292 )
293 @click.option(
294     "--include",
295     type=str,
296     default=DEFAULT_INCLUDES,
297     help=(
298         "A regular expression that matches files and directories that should be "
299         "included on recursive searches.  An empty value means all files are "
300         "included regardless of the name.  Use forward slashes for directories on "
301         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
302         "later."
303     ),
304     show_default=True,
305 )
306 @click.option(
307     "--exclude",
308     type=str,
309     default=DEFAULT_EXCLUDES,
310     help=(
311         "A regular expression that matches files and directories that should be "
312         "excluded on recursive searches.  An empty value means no paths are excluded. "
313         "Use forward slashes for directories on all platforms (Windows, too).  "
314         "Exclusions are calculated first, inclusions later."
315     ),
316     show_default=True,
317 )
318 @click.option(
319     "-q",
320     "--quiet",
321     is_flag=True,
322     help=(
323         "Don't emit non-error messages to stderr. Errors are still emitted, "
324         "silence those with 2>/dev/null."
325     ),
326 )
327 @click.option(
328     "-v",
329     "--verbose",
330     is_flag=True,
331     help=(
332         "Also emit messages to stderr about files that were not changed or were "
333         "ignored due to --exclude=."
334     ),
335 )
336 @click.version_option(version=__version__)
337 @click.argument(
338     "src",
339     nargs=-1,
340     type=click.Path(
341         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
342     ),
343     is_eager=True,
344 )
345 @click.option(
346     "--config",
347     type=click.Path(
348         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
349     ),
350     is_eager=True,
351     callback=read_pyproject_toml,
352     help="Read configuration from PATH.",
353 )
354 @click.pass_context
355 def main(
356     ctx: click.Context,
357     line_length: int,
358     target_version: List[TargetVersion],
359     check: bool,
360     diff: bool,
361     fast: bool,
362     pyi: bool,
363     py36: bool,
364     skip_string_normalization: bool,
365     quiet: bool,
366     verbose: bool,
367     include: str,
368     exclude: str,
369     src: Tuple[str],
370     config: Optional[str],
371 ) -> None:
372     """The uncompromising code formatter."""
373     write_back = WriteBack.from_configuration(check=check, diff=diff)
374     if target_version:
375         if py36:
376             err(f"Cannot use both --target-version and --py36")
377             ctx.exit(2)
378         else:
379             versions = set(target_version)
380     elif py36:
381         versions = PY36_VERSIONS
382     else:
383         # We'll autodetect later.
384         versions = set()
385     mode = FileMode(
386         target_versions=versions,
387         line_length=line_length,
388         is_pyi=pyi,
389         string_normalization=not skip_string_normalization,
390     )
391     if config and verbose:
392         out(f"Using configuration from {config}.", bold=False, fg="blue")
393     try:
394         include_regex = re_compile_maybe_verbose(include)
395     except re.error:
396         err(f"Invalid regular expression for include given: {include!r}")
397         ctx.exit(2)
398     try:
399         exclude_regex = re_compile_maybe_verbose(exclude)
400     except re.error:
401         err(f"Invalid regular expression for exclude given: {exclude!r}")
402         ctx.exit(2)
403     report = Report(check=check, quiet=quiet, verbose=verbose)
404     root = find_project_root(src)
405     sources: Set[Path] = set()
406     for s in src:
407         p = Path(s)
408         if p.is_dir():
409             sources.update(
410                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
411             )
412         elif p.is_file() or s == "-":
413             # if a file was explicitly given, we don't care about its extension
414             sources.add(p)
415         else:
416             err(f"invalid path: {s}")
417     if len(sources) == 0:
418         if verbose or not quiet:
419             out("No paths given. Nothing to do 😴")
420         ctx.exit(0)
421
422     if len(sources) == 1:
423         reformat_one(
424             src=sources.pop(),
425             fast=fast,
426             write_back=write_back,
427             mode=mode,
428             report=report,
429         )
430     else:
431         loop = asyncio.get_event_loop()
432         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
433         try:
434             loop.run_until_complete(
435                 schedule_formatting(
436                     sources=sources,
437                     fast=fast,
438                     write_back=write_back,
439                     mode=mode,
440                     report=report,
441                     loop=loop,
442                     executor=executor,
443                 )
444             )
445         finally:
446             shutdown(loop)
447     if verbose or not quiet:
448         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
449         out(f"All done! {bang}")
450         click.secho(str(report), err=True)
451     ctx.exit(report.return_code)
452
453
454 def reformat_one(
455     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
456 ) -> None:
457     """Reformat a single file under `src` without spawning child processes.
458
459     If `quiet` is True, non-error messages are not output. `line_length`,
460     `write_back`, `fast` and `pyi` options are passed to
461     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
462     """
463     try:
464         changed = Changed.NO
465         if not src.is_file() and str(src) == "-":
466             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
467                 changed = Changed.YES
468         else:
469             cache: Cache = {}
470             if write_back != WriteBack.DIFF:
471                 cache = read_cache(mode)
472                 res_src = src.resolve()
473                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
474                     changed = Changed.CACHED
475             if changed is not Changed.CACHED and format_file_in_place(
476                 src, fast=fast, write_back=write_back, mode=mode
477             ):
478                 changed = Changed.YES
479             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
480                 write_back is WriteBack.CHECK and changed is Changed.NO
481             ):
482                 write_cache(cache, [src], mode)
483         report.done(src, changed)
484     except Exception as exc:
485         report.failed(src, str(exc))
486
487
488 async def schedule_formatting(
489     sources: Set[Path],
490     fast: bool,
491     write_back: WriteBack,
492     mode: FileMode,
493     report: "Report",
494     loop: BaseEventLoop,
495     executor: Executor,
496 ) -> None:
497     """Run formatting of `sources` in parallel using the provided `executor`.
498
499     (Use ProcessPoolExecutors for actual parallelism.)
500
501     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
502     :func:`format_file_in_place`.
503     """
504     cache: Cache = {}
505     if write_back != WriteBack.DIFF:
506         cache = read_cache(mode)
507         sources, cached = filter_cached(cache, sources)
508         for src in sorted(cached):
509             report.done(src, Changed.CACHED)
510     if not sources:
511         return
512
513     cancelled = []
514     sources_to_cache = []
515     lock = None
516     if write_back == WriteBack.DIFF:
517         # For diff output, we need locks to ensure we don't interleave output
518         # from different processes.
519         manager = Manager()
520         lock = manager.Lock()
521     tasks = {
522         loop.run_in_executor(
523             executor, format_file_in_place, src, fast, mode, write_back, lock
524         ): src
525         for src in sorted(sources)
526     }
527     pending: Iterable[asyncio.Task] = tasks.keys()
528     try:
529         loop.add_signal_handler(signal.SIGINT, cancel, pending)
530         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
531     except NotImplementedError:
532         # There are no good alternatives for these on Windows.
533         pass
534     while pending:
535         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
536         for task in done:
537             src = tasks.pop(task)
538             if task.cancelled():
539                 cancelled.append(task)
540             elif task.exception():
541                 report.failed(src, str(task.exception()))
542             else:
543                 changed = Changed.YES if task.result() else Changed.NO
544                 # If the file was written back or was successfully checked as
545                 # well-formatted, store this information in the cache.
546                 if write_back is WriteBack.YES or (
547                     write_back is WriteBack.CHECK and changed is Changed.NO
548                 ):
549                     sources_to_cache.append(src)
550                 report.done(src, changed)
551     if cancelled:
552         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
553     if sources_to_cache:
554         write_cache(cache, sources_to_cache, mode)
555
556
557 def format_file_in_place(
558     src: Path,
559     fast: bool,
560     mode: FileMode,
561     write_back: WriteBack = WriteBack.NO,
562     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
563 ) -> bool:
564     """Format file under `src` path. Return True if changed.
565
566     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
567     code to the file.
568     `line_length` and `fast` options are passed to :func:`format_file_contents`.
569     """
570     if src.suffix == ".pyi":
571         mode = evolve(mode, is_pyi=True)
572
573     then = datetime.utcfromtimestamp(src.stat().st_mtime)
574     with open(src, "rb") as buf:
575         src_contents, encoding, newline = decode_bytes(buf.read())
576     try:
577         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
578     except NothingChanged:
579         return False
580
581     if write_back == write_back.YES:
582         with open(src, "w", encoding=encoding, newline=newline) as f:
583             f.write(dst_contents)
584     elif write_back == write_back.DIFF:
585         now = datetime.utcnow()
586         src_name = f"{src}\t{then} +0000"
587         dst_name = f"{src}\t{now} +0000"
588         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
589         if lock:
590             lock.acquire()
591         try:
592             f = io.TextIOWrapper(
593                 sys.stdout.buffer,
594                 encoding=encoding,
595                 newline=newline,
596                 write_through=True,
597             )
598             f.write(diff_contents)
599             f.detach()
600         finally:
601             if lock:
602                 lock.release()
603     return True
604
605
606 def format_stdin_to_stdout(
607     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
608 ) -> bool:
609     """Format file on stdin. Return True if changed.
610
611     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
612     write a diff to stdout. The `mode` argument is passed to
613     :func:`format_file_contents`.
614     """
615     then = datetime.utcnow()
616     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
617     dst = src
618     try:
619         dst = format_file_contents(src, fast=fast, mode=mode)
620         return True
621
622     except NothingChanged:
623         return False
624
625     finally:
626         f = io.TextIOWrapper(
627             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
628         )
629         if write_back == WriteBack.YES:
630             f.write(dst)
631         elif write_back == WriteBack.DIFF:
632             now = datetime.utcnow()
633             src_name = f"STDIN\t{then} +0000"
634             dst_name = f"STDOUT\t{now} +0000"
635             f.write(diff(src, dst, src_name, dst_name))
636         f.detach()
637
638
639 def format_file_contents(
640     src_contents: str, *, fast: bool, mode: FileMode
641 ) -> FileContent:
642     """Reformat contents a file and return new contents.
643
644     If `fast` is False, additionally confirm that the reformatted code is
645     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
646     `line_length` is passed to :func:`format_str`.
647     """
648     if src_contents.strip() == "":
649         raise NothingChanged
650
651     dst_contents = format_str(src_contents, mode=mode)
652     if src_contents == dst_contents:
653         raise NothingChanged
654
655     if not fast:
656         assert_equivalent(src_contents, dst_contents)
657         assert_stable(src_contents, dst_contents, mode=mode)
658     return dst_contents
659
660
661 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
662     """Reformat a string and return new contents.
663
664     `line_length` determines how many characters per line are allowed.
665     """
666     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
667     dst_contents = ""
668     future_imports = get_future_imports(src_node)
669     if mode.target_versions:
670         versions = mode.target_versions
671     else:
672         versions = detect_target_versions(src_node)
673     normalize_fmt_off(src_node)
674     lines = LineGenerator(
675         remove_u_prefix="unicode_literals" in future_imports
676         or supports_feature(versions, Feature.UNICODE_LITERALS),
677         is_pyi=mode.is_pyi,
678         normalize_strings=mode.string_normalization,
679     )
680     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
681     empty_line = Line()
682     after = 0
683     for current_line in lines.visit(src_node):
684         for _ in range(after):
685             dst_contents += str(empty_line)
686         before, after = elt.maybe_empty_lines(current_line)
687         for _ in range(before):
688             dst_contents += str(empty_line)
689         for line in split_line(
690             current_line,
691             line_length=mode.line_length,
692             supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
693         ):
694             dst_contents += str(line)
695     return dst_contents
696
697
698 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
699     """Return a tuple of (decoded_contents, encoding, newline).
700
701     `newline` is either CRLF or LF but `decoded_contents` is decoded with
702     universal newlines (i.e. only contains LF).
703     """
704     srcbuf = io.BytesIO(src)
705     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
706     if not lines:
707         return "", encoding, "\n"
708
709     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
710     srcbuf.seek(0)
711     with io.TextIOWrapper(srcbuf, encoding) as tiow:
712         return tiow.read(), encoding, newline
713
714
715 GRAMMARS = [
716     pygram.python_grammar_no_print_statement_no_exec_statement,
717     pygram.python_grammar_no_print_statement,
718     pygram.python_grammar,
719 ]
720
721
722 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
723     if not target_versions:
724         return GRAMMARS
725     elif all(not version.is_python2() for version in target_versions):
726         # Python 2-compatible code, so don't try Python 3 grammar.
727         return [
728             pygram.python_grammar_no_print_statement_no_exec_statement,
729             pygram.python_grammar_no_print_statement,
730         ]
731     else:
732         return [pygram.python_grammar]
733
734
735 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
736     """Given a string with source, return the lib2to3 Node."""
737     if src_txt[-1:] != "\n":
738         src_txt += "\n"
739
740     for grammar in get_grammars(set(target_versions)):
741         drv = driver.Driver(grammar, pytree.convert)
742         try:
743             result = drv.parse_string(src_txt, True)
744             break
745
746         except ParseError as pe:
747             lineno, column = pe.context[1]
748             lines = src_txt.splitlines()
749             try:
750                 faulty_line = lines[lineno - 1]
751             except IndexError:
752                 faulty_line = "<line number missing in source>"
753             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
754     else:
755         raise exc from None
756
757     if isinstance(result, Leaf):
758         result = Node(syms.file_input, [result])
759     return result
760
761
762 def lib2to3_unparse(node: Node) -> str:
763     """Given a lib2to3 node, return its string representation."""
764     code = str(node)
765     return code
766
767
768 T = TypeVar("T")
769
770
771 class Visitor(Generic[T]):
772     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
773
774     def visit(self, node: LN) -> Iterator[T]:
775         """Main method to visit `node` and its children.
776
777         It tries to find a `visit_*()` method for the given `node.type`, like
778         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
779         If no dedicated `visit_*()` method is found, chooses `visit_default()`
780         instead.
781
782         Then yields objects of type `T` from the selected visitor.
783         """
784         if node.type < 256:
785             name = token.tok_name[node.type]
786         else:
787             name = type_repr(node.type)
788         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
789
790     def visit_default(self, node: LN) -> Iterator[T]:
791         """Default `visit_*()` implementation. Recurses to children of `node`."""
792         if isinstance(node, Node):
793             for child in node.children:
794                 yield from self.visit(child)
795
796
797 @dataclass
798 class DebugVisitor(Visitor[T]):
799     tree_depth: int = 0
800
801     def visit_default(self, node: LN) -> Iterator[T]:
802         indent = " " * (2 * self.tree_depth)
803         if isinstance(node, Node):
804             _type = type_repr(node.type)
805             out(f"{indent}{_type}", fg="yellow")
806             self.tree_depth += 1
807             for child in node.children:
808                 yield from self.visit(child)
809
810             self.tree_depth -= 1
811             out(f"{indent}/{_type}", fg="yellow", bold=False)
812         else:
813             _type = token.tok_name.get(node.type, str(node.type))
814             out(f"{indent}{_type}", fg="blue", nl=False)
815             if node.prefix:
816                 # We don't have to handle prefixes for `Node` objects since
817                 # that delegates to the first child anyway.
818                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
819             out(f" {node.value!r}", fg="blue", bold=False)
820
821     @classmethod
822     def show(cls, code: Union[str, Leaf, Node]) -> None:
823         """Pretty-print the lib2to3 AST of a given string of `code`.
824
825         Convenience method for debugging.
826         """
827         v: DebugVisitor[None] = DebugVisitor()
828         if isinstance(code, str):
829             code = lib2to3_parse(code)
830         list(v.visit(code))
831
832
833 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
834 STATEMENT = {
835     syms.if_stmt,
836     syms.while_stmt,
837     syms.for_stmt,
838     syms.try_stmt,
839     syms.except_clause,
840     syms.with_stmt,
841     syms.funcdef,
842     syms.classdef,
843 }
844 STANDALONE_COMMENT = 153
845 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
846 LOGIC_OPERATORS = {"and", "or"}
847 COMPARATORS = {
848     token.LESS,
849     token.GREATER,
850     token.EQEQUAL,
851     token.NOTEQUAL,
852     token.LESSEQUAL,
853     token.GREATEREQUAL,
854 }
855 MATH_OPERATORS = {
856     token.VBAR,
857     token.CIRCUMFLEX,
858     token.AMPER,
859     token.LEFTSHIFT,
860     token.RIGHTSHIFT,
861     token.PLUS,
862     token.MINUS,
863     token.STAR,
864     token.SLASH,
865     token.DOUBLESLASH,
866     token.PERCENT,
867     token.AT,
868     token.TILDE,
869     token.DOUBLESTAR,
870 }
871 STARS = {token.STAR, token.DOUBLESTAR}
872 VARARGS_PARENTS = {
873     syms.arglist,
874     syms.argument,  # double star in arglist
875     syms.trailer,  # single argument to call
876     syms.typedargslist,
877     syms.varargslist,  # lambdas
878 }
879 UNPACKING_PARENTS = {
880     syms.atom,  # single element of a list or set literal
881     syms.dictsetmaker,
882     syms.listmaker,
883     syms.testlist_gexp,
884     syms.testlist_star_expr,
885 }
886 TEST_DESCENDANTS = {
887     syms.test,
888     syms.lambdef,
889     syms.or_test,
890     syms.and_test,
891     syms.not_test,
892     syms.comparison,
893     syms.star_expr,
894     syms.expr,
895     syms.xor_expr,
896     syms.and_expr,
897     syms.shift_expr,
898     syms.arith_expr,
899     syms.trailer,
900     syms.term,
901     syms.power,
902 }
903 ASSIGNMENTS = {
904     "=",
905     "+=",
906     "-=",
907     "*=",
908     "@=",
909     "/=",
910     "%=",
911     "&=",
912     "|=",
913     "^=",
914     "<<=",
915     ">>=",
916     "**=",
917     "//=",
918 }
919 COMPREHENSION_PRIORITY = 20
920 COMMA_PRIORITY = 18
921 TERNARY_PRIORITY = 16
922 LOGIC_PRIORITY = 14
923 STRING_PRIORITY = 12
924 COMPARATOR_PRIORITY = 10
925 MATH_PRIORITIES = {
926     token.VBAR: 9,
927     token.CIRCUMFLEX: 8,
928     token.AMPER: 7,
929     token.LEFTSHIFT: 6,
930     token.RIGHTSHIFT: 6,
931     token.PLUS: 5,
932     token.MINUS: 5,
933     token.STAR: 4,
934     token.SLASH: 4,
935     token.DOUBLESLASH: 4,
936     token.PERCENT: 4,
937     token.AT: 4,
938     token.TILDE: 3,
939     token.DOUBLESTAR: 2,
940 }
941 DOT_PRIORITY = 1
942
943
944 @dataclass
945 class BracketTracker:
946     """Keeps track of brackets on a line."""
947
948     depth: int = 0
949     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
950     delimiters: Dict[LeafID, Priority] = Factory(dict)
951     previous: Optional[Leaf] = None
952     _for_loop_depths: List[int] = Factory(list)
953     _lambda_argument_depths: List[int] = Factory(list)
954
955     def mark(self, leaf: Leaf) -> None:
956         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
957
958         All leaves receive an int `bracket_depth` field that stores how deep
959         within brackets a given leaf is. 0 means there are no enclosing brackets
960         that started on this line.
961
962         If a leaf is itself a closing bracket, it receives an `opening_bracket`
963         field that it forms a pair with. This is a one-directional link to
964         avoid reference cycles.
965
966         If a leaf is a delimiter (a token on which Black can split the line if
967         needed) and it's on depth 0, its `id()` is stored in the tracker's
968         `delimiters` field.
969         """
970         if leaf.type == token.COMMENT:
971             return
972
973         self.maybe_decrement_after_for_loop_variable(leaf)
974         self.maybe_decrement_after_lambda_arguments(leaf)
975         if leaf.type in CLOSING_BRACKETS:
976             self.depth -= 1
977             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
978             leaf.opening_bracket = opening_bracket
979         leaf.bracket_depth = self.depth
980         if self.depth == 0:
981             delim = is_split_before_delimiter(leaf, self.previous)
982             if delim and self.previous is not None:
983                 self.delimiters[id(self.previous)] = delim
984             else:
985                 delim = is_split_after_delimiter(leaf, self.previous)
986                 if delim:
987                     self.delimiters[id(leaf)] = delim
988         if leaf.type in OPENING_BRACKETS:
989             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
990             self.depth += 1
991         self.previous = leaf
992         self.maybe_increment_lambda_arguments(leaf)
993         self.maybe_increment_for_loop_variable(leaf)
994
995     def any_open_brackets(self) -> bool:
996         """Return True if there is an yet unmatched open bracket on the line."""
997         return bool(self.bracket_match)
998
999     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
1000         """Return the highest priority of a delimiter found on the line.
1001
1002         Values are consistent with what `is_split_*_delimiter()` return.
1003         Raises ValueError on no delimiters.
1004         """
1005         return max(v for k, v in self.delimiters.items() if k not in exclude)
1006
1007     def delimiter_count_with_priority(self, priority: int = 0) -> int:
1008         """Return the number of delimiters with the given `priority`.
1009
1010         If no `priority` is passed, defaults to max priority on the line.
1011         """
1012         if not self.delimiters:
1013             return 0
1014
1015         priority = priority or self.max_delimiter_priority()
1016         return sum(1 for p in self.delimiters.values() if p == priority)
1017
1018     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1019         """In a for loop, or comprehension, the variables are often unpacks.
1020
1021         To avoid splitting on the comma in this situation, increase the depth of
1022         tokens between `for` and `in`.
1023         """
1024         if leaf.type == token.NAME and leaf.value == "for":
1025             self.depth += 1
1026             self._for_loop_depths.append(self.depth)
1027             return True
1028
1029         return False
1030
1031     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1032         """See `maybe_increment_for_loop_variable` above for explanation."""
1033         if (
1034             self._for_loop_depths
1035             and self._for_loop_depths[-1] == self.depth
1036             and leaf.type == token.NAME
1037             and leaf.value == "in"
1038         ):
1039             self.depth -= 1
1040             self._for_loop_depths.pop()
1041             return True
1042
1043         return False
1044
1045     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1046         """In a lambda expression, there might be more than one argument.
1047
1048         To avoid splitting on the comma in this situation, increase the depth of
1049         tokens between `lambda` and `:`.
1050         """
1051         if leaf.type == token.NAME and leaf.value == "lambda":
1052             self.depth += 1
1053             self._lambda_argument_depths.append(self.depth)
1054             return True
1055
1056         return False
1057
1058     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1059         """See `maybe_increment_lambda_arguments` above for explanation."""
1060         if (
1061             self._lambda_argument_depths
1062             and self._lambda_argument_depths[-1] == self.depth
1063             and leaf.type == token.COLON
1064         ):
1065             self.depth -= 1
1066             self._lambda_argument_depths.pop()
1067             return True
1068
1069         return False
1070
1071     def get_open_lsqb(self) -> Optional[Leaf]:
1072         """Return the most recent opening square bracket (if any)."""
1073         return self.bracket_match.get((self.depth - 1, token.RSQB))
1074
1075
1076 @dataclass
1077 class Line:
1078     """Holds leaves and comments. Can be printed with `str(line)`."""
1079
1080     depth: int = 0
1081     leaves: List[Leaf] = Factory(list)
1082     # The LeafID keys of comments must remain ordered by the corresponding leaf's index
1083     # in leaves
1084     comments: Dict[LeafID, List[Leaf]] = Factory(dict)
1085     bracket_tracker: BracketTracker = Factory(BracketTracker)
1086     inside_brackets: bool = False
1087     should_explode: bool = False
1088
1089     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1090         """Add a new `leaf` to the end of the line.
1091
1092         Unless `preformatted` is True, the `leaf` will receive a new consistent
1093         whitespace prefix and metadata applied by :class:`BracketTracker`.
1094         Trailing commas are maybe removed, unpacked for loop variables are
1095         demoted from being delimiters.
1096
1097         Inline comments are put aside.
1098         """
1099         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1100         if not has_value:
1101             return
1102
1103         if token.COLON == leaf.type and self.is_class_paren_empty:
1104             del self.leaves[-2:]
1105         if self.leaves and not preformatted:
1106             # Note: at this point leaf.prefix should be empty except for
1107             # imports, for which we only preserve newlines.
1108             leaf.prefix += whitespace(
1109                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1110             )
1111         if self.inside_brackets or not preformatted:
1112             self.bracket_tracker.mark(leaf)
1113             self.maybe_remove_trailing_comma(leaf)
1114         if not self.append_comment(leaf):
1115             self.leaves.append(leaf)
1116
1117     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1118         """Like :func:`append()` but disallow invalid standalone comment structure.
1119
1120         Raises ValueError when any `leaf` is appended after a standalone comment
1121         or when a standalone comment is not the first leaf on the line.
1122         """
1123         if self.bracket_tracker.depth == 0:
1124             if self.is_comment:
1125                 raise ValueError("cannot append to standalone comments")
1126
1127             if self.leaves and leaf.type == STANDALONE_COMMENT:
1128                 raise ValueError(
1129                     "cannot append standalone comments to a populated line"
1130                 )
1131
1132         self.append(leaf, preformatted=preformatted)
1133
1134     @property
1135     def is_comment(self) -> bool:
1136         """Is this line a standalone comment?"""
1137         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1138
1139     @property
1140     def is_decorator(self) -> bool:
1141         """Is this line a decorator?"""
1142         return bool(self) and self.leaves[0].type == token.AT
1143
1144     @property
1145     def is_import(self) -> bool:
1146         """Is this an import line?"""
1147         return bool(self) and is_import(self.leaves[0])
1148
1149     @property
1150     def is_class(self) -> bool:
1151         """Is this line a class definition?"""
1152         return (
1153             bool(self)
1154             and self.leaves[0].type == token.NAME
1155             and self.leaves[0].value == "class"
1156         )
1157
1158     @property
1159     def is_stub_class(self) -> bool:
1160         """Is this line a class definition with a body consisting only of "..."?"""
1161         return self.is_class and self.leaves[-3:] == [
1162             Leaf(token.DOT, ".") for _ in range(3)
1163         ]
1164
1165     @property
1166     def is_def(self) -> bool:
1167         """Is this a function definition? (Also returns True for async defs.)"""
1168         try:
1169             first_leaf = self.leaves[0]
1170         except IndexError:
1171             return False
1172
1173         try:
1174             second_leaf: Optional[Leaf] = self.leaves[1]
1175         except IndexError:
1176             second_leaf = None
1177         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1178             first_leaf.type == token.ASYNC
1179             and second_leaf is not None
1180             and second_leaf.type == token.NAME
1181             and second_leaf.value == "def"
1182         )
1183
1184     @property
1185     def is_class_paren_empty(self) -> bool:
1186         """Is this a class with no base classes but using parentheses?
1187
1188         Those are unnecessary and should be removed.
1189         """
1190         return (
1191             bool(self)
1192             and len(self.leaves) == 4
1193             and self.is_class
1194             and self.leaves[2].type == token.LPAR
1195             and self.leaves[2].value == "("
1196             and self.leaves[3].type == token.RPAR
1197             and self.leaves[3].value == ")"
1198         )
1199
1200     @property
1201     def is_triple_quoted_string(self) -> bool:
1202         """Is the line a triple quoted string?"""
1203         return (
1204             bool(self)
1205             and self.leaves[0].type == token.STRING
1206             and self.leaves[0].value.startswith(('"""', "'''"))
1207         )
1208
1209     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1210         """If so, needs to be split before emitting."""
1211         for leaf in self.leaves:
1212             if leaf.type == STANDALONE_COMMENT:
1213                 if leaf.bracket_depth <= depth_limit:
1214                     return True
1215
1216         return False
1217
1218     def contains_multiline_strings(self) -> bool:
1219         for leaf in self.leaves:
1220             if is_multiline_string(leaf):
1221                 return True
1222
1223         return False
1224
1225     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1226         """Remove trailing comma if there is one and it's safe."""
1227         if not (
1228             self.leaves
1229             and self.leaves[-1].type == token.COMMA
1230             and closing.type in CLOSING_BRACKETS
1231         ):
1232             return False
1233
1234         if closing.type == token.RBRACE:
1235             self.remove_trailing_comma()
1236             return True
1237
1238         if closing.type == token.RSQB:
1239             comma = self.leaves[-1]
1240             if comma.parent and comma.parent.type == syms.listmaker:
1241                 self.remove_trailing_comma()
1242                 return True
1243
1244         # For parens let's check if it's safe to remove the comma.
1245         # Imports are always safe.
1246         if self.is_import:
1247             self.remove_trailing_comma()
1248             return True
1249
1250         # Otherwise, if the trailing one is the only one, we might mistakenly
1251         # change a tuple into a different type by removing the comma.
1252         depth = closing.bracket_depth + 1
1253         commas = 0
1254         opening = closing.opening_bracket
1255         for _opening_index, leaf in enumerate(self.leaves):
1256             if leaf is opening:
1257                 break
1258
1259         else:
1260             return False
1261
1262         for leaf in self.leaves[_opening_index + 1 :]:
1263             if leaf is closing:
1264                 break
1265
1266             bracket_depth = leaf.bracket_depth
1267             if bracket_depth == depth and leaf.type == token.COMMA:
1268                 commas += 1
1269                 if leaf.parent and leaf.parent.type == syms.arglist:
1270                     commas += 1
1271                     break
1272
1273         if commas > 1:
1274             self.remove_trailing_comma()
1275             return True
1276
1277         return False
1278
1279     def append_comment(self, comment: Leaf) -> bool:
1280         """Add an inline or standalone comment to the line."""
1281         if (
1282             comment.type == STANDALONE_COMMENT
1283             and self.bracket_tracker.any_open_brackets()
1284         ):
1285             comment.prefix = ""
1286             return False
1287
1288         if comment.type != token.COMMENT:
1289             return False
1290
1291         if not self.leaves:
1292             comment.type = STANDALONE_COMMENT
1293             comment.prefix = ""
1294             return False
1295
1296         else:
1297             leaf_id = id(self.leaves[-1])
1298             if leaf_id not in self.comments:
1299                 self.comments[leaf_id] = [comment]
1300             else:
1301                 self.comments[leaf_id].append(comment)
1302             return True
1303
1304     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1305         """Generate comments that should appear directly after `leaf`."""
1306         return self.comments.get(id(leaf), [])
1307
1308     def remove_trailing_comma(self) -> None:
1309         """Remove the trailing comma and moves the comments attached to it."""
1310         # Remember, the LeafID keys of self.comments are ordered by the
1311         # corresponding leaf's index in self.leaves
1312         # If id(self.leaves[-2]) is in self.comments, the order doesn't change.
1313         # Otherwise, we insert it into self.comments, and it becomes the last entry.
1314         # However, since we delete id(self.leaves[-1]) from self.comments, the invariant
1315         # is maintained
1316         self.comments.setdefault(id(self.leaves[-2]), []).extend(
1317             self.comments.get(id(self.leaves[-1]), [])
1318         )
1319         self.comments.pop(id(self.leaves[-1]), None)
1320         self.leaves.pop()
1321
1322     def is_complex_subscript(self, leaf: Leaf) -> bool:
1323         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1324         open_lsqb = self.bracket_tracker.get_open_lsqb()
1325         if open_lsqb is None:
1326             return False
1327
1328         subscript_start = open_lsqb.next_sibling
1329
1330         if isinstance(subscript_start, Node):
1331             if subscript_start.type == syms.listmaker:
1332                 return False
1333
1334             if subscript_start.type == syms.subscriptlist:
1335                 subscript_start = child_towards(subscript_start, leaf)
1336         return subscript_start is not None and any(
1337             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1338         )
1339
1340     def __str__(self) -> str:
1341         """Render the line."""
1342         if not self:
1343             return "\n"
1344
1345         indent = "    " * self.depth
1346         leaves = iter(self.leaves)
1347         first = next(leaves)
1348         res = f"{first.prefix}{indent}{first.value}"
1349         for leaf in leaves:
1350             res += str(leaf)
1351         for comment in itertools.chain.from_iterable(self.comments.values()):
1352             res += str(comment)
1353         return res + "\n"
1354
1355     def __bool__(self) -> bool:
1356         """Return True if the line has leaves or comments."""
1357         return bool(self.leaves or self.comments)
1358
1359
1360 @dataclass
1361 class EmptyLineTracker:
1362     """Provides a stateful method that returns the number of potential extra
1363     empty lines needed before and after the currently processed line.
1364
1365     Note: this tracker works on lines that haven't been split yet.  It assumes
1366     the prefix of the first leaf consists of optional newlines.  Those newlines
1367     are consumed by `maybe_empty_lines()` and included in the computation.
1368     """
1369
1370     is_pyi: bool = False
1371     previous_line: Optional[Line] = None
1372     previous_after: int = 0
1373     previous_defs: List[int] = Factory(list)
1374
1375     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1376         """Return the number of extra empty lines before and after the `current_line`.
1377
1378         This is for separating `def`, `async def` and `class` with extra empty
1379         lines (two on module-level).
1380         """
1381         before, after = self._maybe_empty_lines(current_line)
1382         before -= self.previous_after
1383         self.previous_after = after
1384         self.previous_line = current_line
1385         return before, after
1386
1387     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1388         max_allowed = 1
1389         if current_line.depth == 0:
1390             max_allowed = 1 if self.is_pyi else 2
1391         if current_line.leaves:
1392             # Consume the first leaf's extra newlines.
1393             first_leaf = current_line.leaves[0]
1394             before = first_leaf.prefix.count("\n")
1395             before = min(before, max_allowed)
1396             first_leaf.prefix = ""
1397         else:
1398             before = 0
1399         depth = current_line.depth
1400         while self.previous_defs and self.previous_defs[-1] >= depth:
1401             self.previous_defs.pop()
1402             if self.is_pyi:
1403                 before = 0 if depth else 1
1404             else:
1405                 before = 1 if depth else 2
1406         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1407             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1408
1409         if (
1410             self.previous_line
1411             and self.previous_line.is_import
1412             and not current_line.is_import
1413             and depth == self.previous_line.depth
1414         ):
1415             return (before or 1), 0
1416
1417         if (
1418             self.previous_line
1419             and self.previous_line.is_class
1420             and current_line.is_triple_quoted_string
1421         ):
1422             return before, 1
1423
1424         return before, 0
1425
1426     def _maybe_empty_lines_for_class_or_def(
1427         self, current_line: Line, before: int
1428     ) -> Tuple[int, int]:
1429         if not current_line.is_decorator:
1430             self.previous_defs.append(current_line.depth)
1431         if self.previous_line is None:
1432             # Don't insert empty lines before the first line in the file.
1433             return 0, 0
1434
1435         if self.previous_line.is_decorator:
1436             return 0, 0
1437
1438         if self.previous_line.depth < current_line.depth and (
1439             self.previous_line.is_class or self.previous_line.is_def
1440         ):
1441             return 0, 0
1442
1443         if (
1444             self.previous_line.is_comment
1445             and self.previous_line.depth == current_line.depth
1446             and before == 0
1447         ):
1448             return 0, 0
1449
1450         if self.is_pyi:
1451             if self.previous_line.depth > current_line.depth:
1452                 newlines = 1
1453             elif current_line.is_class or self.previous_line.is_class:
1454                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1455                     # No blank line between classes with an empty body
1456                     newlines = 0
1457                 else:
1458                     newlines = 1
1459             elif current_line.is_def and not self.previous_line.is_def:
1460                 # Blank line between a block of functions and a block of non-functions
1461                 newlines = 1
1462             else:
1463                 newlines = 0
1464         else:
1465             newlines = 2
1466         if current_line.depth and newlines:
1467             newlines -= 1
1468         return newlines, 0
1469
1470
1471 @dataclass
1472 class LineGenerator(Visitor[Line]):
1473     """Generates reformatted Line objects.  Empty lines are not emitted.
1474
1475     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1476     in ways that will no longer stringify to valid Python code on the tree.
1477     """
1478
1479     is_pyi: bool = False
1480     normalize_strings: bool = True
1481     current_line: Line = Factory(Line)
1482     remove_u_prefix: bool = False
1483
1484     def line(self, indent: int = 0) -> Iterator[Line]:
1485         """Generate a line.
1486
1487         If the line is empty, only emit if it makes sense.
1488         If the line is too long, split it first and then generate.
1489
1490         If any lines were generated, set up a new current_line.
1491         """
1492         if not self.current_line:
1493             self.current_line.depth += indent
1494             return  # Line is empty, don't emit. Creating a new one unnecessary.
1495
1496         complete_line = self.current_line
1497         self.current_line = Line(depth=complete_line.depth + indent)
1498         yield complete_line
1499
1500     def visit_default(self, node: LN) -> Iterator[Line]:
1501         """Default `visit_*()` implementation. Recurses to children of `node`."""
1502         if isinstance(node, Leaf):
1503             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1504             for comment in generate_comments(node):
1505                 if any_open_brackets:
1506                     # any comment within brackets is subject to splitting
1507                     self.current_line.append(comment)
1508                 elif comment.type == token.COMMENT:
1509                     # regular trailing comment
1510                     self.current_line.append(comment)
1511                     yield from self.line()
1512
1513                 else:
1514                     # regular standalone comment
1515                     yield from self.line()
1516
1517                     self.current_line.append(comment)
1518                     yield from self.line()
1519
1520             normalize_prefix(node, inside_brackets=any_open_brackets)
1521             if self.normalize_strings and node.type == token.STRING:
1522                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1523                 normalize_string_quotes(node)
1524             if node.type == token.NUMBER:
1525                 normalize_numeric_literal(node)
1526             if node.type not in WHITESPACE:
1527                 self.current_line.append(node)
1528         yield from super().visit_default(node)
1529
1530     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1531         """Increase indentation level, maybe yield a line."""
1532         # In blib2to3 INDENT never holds comments.
1533         yield from self.line(+1)
1534         yield from self.visit_default(node)
1535
1536     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1537         """Decrease indentation level, maybe yield a line."""
1538         # The current line might still wait for trailing comments.  At DEDENT time
1539         # there won't be any (they would be prefixes on the preceding NEWLINE).
1540         # Emit the line then.
1541         yield from self.line()
1542
1543         # While DEDENT has no value, its prefix may contain standalone comments
1544         # that belong to the current indentation level.  Get 'em.
1545         yield from self.visit_default(node)
1546
1547         # Finally, emit the dedent.
1548         yield from self.line(-1)
1549
1550     def visit_stmt(
1551         self, node: Node, keywords: Set[str], parens: Set[str]
1552     ) -> Iterator[Line]:
1553         """Visit a statement.
1554
1555         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1556         `def`, `with`, `class`, `assert` and assignments.
1557
1558         The relevant Python language `keywords` for a given statement will be
1559         NAME leaves within it. This methods puts those on a separate line.
1560
1561         `parens` holds a set of string leaf values immediately after which
1562         invisible parens should be put.
1563         """
1564         normalize_invisible_parens(node, parens_after=parens)
1565         for child in node.children:
1566             if child.type == token.NAME and child.value in keywords:  # type: ignore
1567                 yield from self.line()
1568
1569             yield from self.visit(child)
1570
1571     def visit_suite(self, node: Node) -> Iterator[Line]:
1572         """Visit a suite."""
1573         if self.is_pyi and is_stub_suite(node):
1574             yield from self.visit(node.children[2])
1575         else:
1576             yield from self.visit_default(node)
1577
1578     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1579         """Visit a statement without nested statements."""
1580         is_suite_like = node.parent and node.parent.type in STATEMENT
1581         if is_suite_like:
1582             if self.is_pyi and is_stub_body(node):
1583                 yield from self.visit_default(node)
1584             else:
1585                 yield from self.line(+1)
1586                 yield from self.visit_default(node)
1587                 yield from self.line(-1)
1588
1589         else:
1590             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1591                 yield from self.line()
1592             yield from self.visit_default(node)
1593
1594     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1595         """Visit `async def`, `async for`, `async with`."""
1596         yield from self.line()
1597
1598         children = iter(node.children)
1599         for child in children:
1600             yield from self.visit(child)
1601
1602             if child.type == token.ASYNC:
1603                 break
1604
1605         internal_stmt = next(children)
1606         for child in internal_stmt.children:
1607             yield from self.visit(child)
1608
1609     def visit_decorators(self, node: Node) -> Iterator[Line]:
1610         """Visit decorators."""
1611         for child in node.children:
1612             yield from self.line()
1613             yield from self.visit(child)
1614
1615     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1616         """Remove a semicolon and put the other statement on a separate line."""
1617         yield from self.line()
1618
1619     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1620         """End of file. Process outstanding comments and end with a newline."""
1621         yield from self.visit_default(leaf)
1622         yield from self.line()
1623
1624     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1625         if not self.current_line.bracket_tracker.any_open_brackets():
1626             yield from self.line()
1627         yield from self.visit_default(leaf)
1628
1629     def __attrs_post_init__(self) -> None:
1630         """You are in a twisty little maze of passages."""
1631         v = self.visit_stmt
1632         Ø: Set[str] = set()
1633         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1634         self.visit_if_stmt = partial(
1635             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1636         )
1637         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1638         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1639         self.visit_try_stmt = partial(
1640             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1641         )
1642         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1643         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1644         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1645         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1646         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1647         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1648         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1649         self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
1650         self.visit_async_funcdef = self.visit_async_stmt
1651         self.visit_decorated = self.visit_decorators
1652
1653
1654 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1655 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1656 OPENING_BRACKETS = set(BRACKET.keys())
1657 CLOSING_BRACKETS = set(BRACKET.values())
1658 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1659 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1660
1661
1662 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1663     """Return whitespace prefix if needed for the given `leaf`.
1664
1665     `complex_subscript` signals whether the given leaf is part of a subscription
1666     which has non-trivial arguments, like arithmetic expressions or function calls.
1667     """
1668     NO = ""
1669     SPACE = " "
1670     DOUBLESPACE = "  "
1671     t = leaf.type
1672     p = leaf.parent
1673     v = leaf.value
1674     if t in ALWAYS_NO_SPACE:
1675         return NO
1676
1677     if t == token.COMMENT:
1678         return DOUBLESPACE
1679
1680     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1681     if t == token.COLON and p.type not in {
1682         syms.subscript,
1683         syms.subscriptlist,
1684         syms.sliceop,
1685     }:
1686         return NO
1687
1688     prev = leaf.prev_sibling
1689     if not prev:
1690         prevp = preceding_leaf(p)
1691         if not prevp or prevp.type in OPENING_BRACKETS:
1692             return NO
1693
1694         if t == token.COLON:
1695             if prevp.type == token.COLON:
1696                 return NO
1697
1698             elif prevp.type != token.COMMA and not complex_subscript:
1699                 return NO
1700
1701             return SPACE
1702
1703         if prevp.type == token.EQUAL:
1704             if prevp.parent:
1705                 if prevp.parent.type in {
1706                     syms.arglist,
1707                     syms.argument,
1708                     syms.parameters,
1709                     syms.varargslist,
1710                 }:
1711                     return NO
1712
1713                 elif prevp.parent.type == syms.typedargslist:
1714                     # A bit hacky: if the equal sign has whitespace, it means we
1715                     # previously found it's a typed argument.  So, we're using
1716                     # that, too.
1717                     return prevp.prefix
1718
1719         elif prevp.type in STARS:
1720             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1721                 return NO
1722
1723         elif prevp.type == token.COLON:
1724             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1725                 return SPACE if complex_subscript else NO
1726
1727         elif (
1728             prevp.parent
1729             and prevp.parent.type == syms.factor
1730             and prevp.type in MATH_OPERATORS
1731         ):
1732             return NO
1733
1734         elif (
1735             prevp.type == token.RIGHTSHIFT
1736             and prevp.parent
1737             and prevp.parent.type == syms.shift_expr
1738             and prevp.prev_sibling
1739             and prevp.prev_sibling.type == token.NAME
1740             and prevp.prev_sibling.value == "print"  # type: ignore
1741         ):
1742             # Python 2 print chevron
1743             return NO
1744
1745     elif prev.type in OPENING_BRACKETS:
1746         return NO
1747
1748     if p.type in {syms.parameters, syms.arglist}:
1749         # untyped function signatures or calls
1750         if not prev or prev.type != token.COMMA:
1751             return NO
1752
1753     elif p.type == syms.varargslist:
1754         # lambdas
1755         if prev and prev.type != token.COMMA:
1756             return NO
1757
1758     elif p.type == syms.typedargslist:
1759         # typed function signatures
1760         if not prev:
1761             return NO
1762
1763         if t == token.EQUAL:
1764             if prev.type != syms.tname:
1765                 return NO
1766
1767         elif prev.type == token.EQUAL:
1768             # A bit hacky: if the equal sign has whitespace, it means we
1769             # previously found it's a typed argument.  So, we're using that, too.
1770             return prev.prefix
1771
1772         elif prev.type != token.COMMA:
1773             return NO
1774
1775     elif p.type == syms.tname:
1776         # type names
1777         if not prev:
1778             prevp = preceding_leaf(p)
1779             if not prevp or prevp.type != token.COMMA:
1780                 return NO
1781
1782     elif p.type == syms.trailer:
1783         # attributes and calls
1784         if t == token.LPAR or t == token.RPAR:
1785             return NO
1786
1787         if not prev:
1788             if t == token.DOT:
1789                 prevp = preceding_leaf(p)
1790                 if not prevp or prevp.type != token.NUMBER:
1791                     return NO
1792
1793             elif t == token.LSQB:
1794                 return NO
1795
1796         elif prev.type != token.COMMA:
1797             return NO
1798
1799     elif p.type == syms.argument:
1800         # single argument
1801         if t == token.EQUAL:
1802             return NO
1803
1804         if not prev:
1805             prevp = preceding_leaf(p)
1806             if not prevp or prevp.type == token.LPAR:
1807                 return NO
1808
1809         elif prev.type in {token.EQUAL} | STARS:
1810             return NO
1811
1812     elif p.type == syms.decorator:
1813         # decorators
1814         return NO
1815
1816     elif p.type == syms.dotted_name:
1817         if prev:
1818             return NO
1819
1820         prevp = preceding_leaf(p)
1821         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1822             return NO
1823
1824     elif p.type == syms.classdef:
1825         if t == token.LPAR:
1826             return NO
1827
1828         if prev and prev.type == token.LPAR:
1829             return NO
1830
1831     elif p.type in {syms.subscript, syms.sliceop}:
1832         # indexing
1833         if not prev:
1834             assert p.parent is not None, "subscripts are always parented"
1835             if p.parent.type == syms.subscriptlist:
1836                 return SPACE
1837
1838             return NO
1839
1840         elif not complex_subscript:
1841             return NO
1842
1843     elif p.type == syms.atom:
1844         if prev and t == token.DOT:
1845             # dots, but not the first one.
1846             return NO
1847
1848     elif p.type == syms.dictsetmaker:
1849         # dict unpacking
1850         if prev and prev.type == token.DOUBLESTAR:
1851             return NO
1852
1853     elif p.type in {syms.factor, syms.star_expr}:
1854         # unary ops
1855         if not prev:
1856             prevp = preceding_leaf(p)
1857             if not prevp or prevp.type in OPENING_BRACKETS:
1858                 return NO
1859
1860             prevp_parent = prevp.parent
1861             assert prevp_parent is not None
1862             if prevp.type == token.COLON and prevp_parent.type in {
1863                 syms.subscript,
1864                 syms.sliceop,
1865             }:
1866                 return NO
1867
1868             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1869                 return NO
1870
1871         elif t in {token.NAME, token.NUMBER, token.STRING}:
1872             return NO
1873
1874     elif p.type == syms.import_from:
1875         if t == token.DOT:
1876             if prev and prev.type == token.DOT:
1877                 return NO
1878
1879         elif t == token.NAME:
1880             if v == "import":
1881                 return SPACE
1882
1883             if prev and prev.type == token.DOT:
1884                 return NO
1885
1886     elif p.type == syms.sliceop:
1887         return NO
1888
1889     return SPACE
1890
1891
1892 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1893     """Return the first leaf that precedes `node`, if any."""
1894     while node:
1895         res = node.prev_sibling
1896         if res:
1897             if isinstance(res, Leaf):
1898                 return res
1899
1900             try:
1901                 return list(res.leaves())[-1]
1902
1903             except IndexError:
1904                 return None
1905
1906         node = node.parent
1907     return None
1908
1909
1910 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1911     """Return the child of `ancestor` that contains `descendant`."""
1912     node: Optional[LN] = descendant
1913     while node and node.parent != ancestor:
1914         node = node.parent
1915     return node
1916
1917
1918 def container_of(leaf: Leaf) -> LN:
1919     """Return `leaf` or one of its ancestors that is the topmost container of it.
1920
1921     By "container" we mean a node where `leaf` is the very first child.
1922     """
1923     same_prefix = leaf.prefix
1924     container: LN = leaf
1925     while container:
1926         parent = container.parent
1927         if parent is None:
1928             break
1929
1930         if parent.children[0].prefix != same_prefix:
1931             break
1932
1933         if parent.type == syms.file_input:
1934             break
1935
1936         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1937             break
1938
1939         container = parent
1940     return container
1941
1942
1943 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1944     """Return the priority of the `leaf` delimiter, given a line break after it.
1945
1946     The delimiter priorities returned here are from those delimiters that would
1947     cause a line break after themselves.
1948
1949     Higher numbers are higher priority.
1950     """
1951     if leaf.type == token.COMMA:
1952         return COMMA_PRIORITY
1953
1954     return 0
1955
1956
1957 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1958     """Return the priority of the `leaf` delimiter, given a line break before it.
1959
1960     The delimiter priorities returned here are from those delimiters that would
1961     cause a line break before themselves.
1962
1963     Higher numbers are higher priority.
1964     """
1965     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1966         # * and ** might also be MATH_OPERATORS but in this case they are not.
1967         # Don't treat them as a delimiter.
1968         return 0
1969
1970     if (
1971         leaf.type == token.DOT
1972         and leaf.parent
1973         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1974         and (previous is None or previous.type in CLOSING_BRACKETS)
1975     ):
1976         return DOT_PRIORITY
1977
1978     if (
1979         leaf.type in MATH_OPERATORS
1980         and leaf.parent
1981         and leaf.parent.type not in {syms.factor, syms.star_expr}
1982     ):
1983         return MATH_PRIORITIES[leaf.type]
1984
1985     if leaf.type in COMPARATORS:
1986         return COMPARATOR_PRIORITY
1987
1988     if (
1989         leaf.type == token.STRING
1990         and previous is not None
1991         and previous.type == token.STRING
1992     ):
1993         return STRING_PRIORITY
1994
1995     if leaf.type not in {token.NAME, token.ASYNC}:
1996         return 0
1997
1998     if (
1999         leaf.value == "for"
2000         and leaf.parent
2001         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2002         or leaf.type == token.ASYNC
2003     ):
2004         if (
2005             not isinstance(leaf.prev_sibling, Leaf)
2006             or leaf.prev_sibling.value != "async"
2007         ):
2008             return COMPREHENSION_PRIORITY
2009
2010     if (
2011         leaf.value == "if"
2012         and leaf.parent
2013         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2014     ):
2015         return COMPREHENSION_PRIORITY
2016
2017     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2018         return TERNARY_PRIORITY
2019
2020     if leaf.value == "is":
2021         return COMPARATOR_PRIORITY
2022
2023     if (
2024         leaf.value == "in"
2025         and leaf.parent
2026         and leaf.parent.type in {syms.comp_op, syms.comparison}
2027         and not (
2028             previous is not None
2029             and previous.type == token.NAME
2030             and previous.value == "not"
2031         )
2032     ):
2033         return COMPARATOR_PRIORITY
2034
2035     if (
2036         leaf.value == "not"
2037         and leaf.parent
2038         and leaf.parent.type == syms.comp_op
2039         and not (
2040             previous is not None
2041             and previous.type == token.NAME
2042             and previous.value == "is"
2043         )
2044     ):
2045         return COMPARATOR_PRIORITY
2046
2047     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2048         return LOGIC_PRIORITY
2049
2050     return 0
2051
2052
2053 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2054 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2055
2056
2057 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2058     """Clean the prefix of the `leaf` and generate comments from it, if any.
2059
2060     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2061     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2062     move because it does away with modifying the grammar to include all the
2063     possible places in which comments can be placed.
2064
2065     The sad consequence for us though is that comments don't "belong" anywhere.
2066     This is why this function generates simple parentless Leaf objects for
2067     comments.  We simply don't know what the correct parent should be.
2068
2069     No matter though, we can live without this.  We really only need to
2070     differentiate between inline and standalone comments.  The latter don't
2071     share the line with any code.
2072
2073     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2074     are emitted with a fake STANDALONE_COMMENT token identifier.
2075     """
2076     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2077         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2078
2079
2080 @dataclass
2081 class ProtoComment:
2082     """Describes a piece of syntax that is a comment.
2083
2084     It's not a :class:`blib2to3.pytree.Leaf` so that:
2085
2086     * it can be cached (`Leaf` objects should not be reused more than once as
2087       they store their lineno, column, prefix, and parent information);
2088     * `newlines` and `consumed` fields are kept separate from the `value`. This
2089       simplifies handling of special marker comments like ``# fmt: off/on``.
2090     """
2091
2092     type: int  # token.COMMENT or STANDALONE_COMMENT
2093     value: str  # content of the comment
2094     newlines: int  # how many newlines before the comment
2095     consumed: int  # how many characters of the original leaf's prefix did we consume
2096
2097
2098 @lru_cache(maxsize=4096)
2099 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2100     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2101     result: List[ProtoComment] = []
2102     if not prefix or "#" not in prefix:
2103         return result
2104
2105     consumed = 0
2106     nlines = 0
2107     for index, line in enumerate(prefix.split("\n")):
2108         consumed += len(line) + 1  # adding the length of the split '\n'
2109         line = line.lstrip()
2110         if not line:
2111             nlines += 1
2112         if not line.startswith("#"):
2113             continue
2114
2115         if index == 0 and not is_endmarker:
2116             comment_type = token.COMMENT  # simple trailing comment
2117         else:
2118             comment_type = STANDALONE_COMMENT
2119         comment = make_comment(line)
2120         result.append(
2121             ProtoComment(
2122                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2123             )
2124         )
2125         nlines = 0
2126     return result
2127
2128
2129 def make_comment(content: str) -> str:
2130     """Return a consistently formatted comment from the given `content` string.
2131
2132     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2133     space between the hash sign and the content.
2134
2135     If `content` didn't start with a hash sign, one is provided.
2136     """
2137     content = content.rstrip()
2138     if not content:
2139         return "#"
2140
2141     if content[0] == "#":
2142         content = content[1:]
2143     if content and content[0] not in " !:#'%":
2144         content = " " + content
2145     return "#" + content
2146
2147
2148 def split_line(
2149     line: Line,
2150     line_length: int,
2151     inner: bool = False,
2152     supports_trailing_commas: bool = False,
2153 ) -> Iterator[Line]:
2154     """Split a `line` into potentially many lines.
2155
2156     They should fit in the allotted `line_length` but might not be able to.
2157     `inner` signifies that there were a pair of brackets somewhere around the
2158     current `line`, possibly transitively. This means we can fallback to splitting
2159     by delimiters if the LHS/RHS don't yield any results.
2160
2161     If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
2162     """
2163     if line.is_comment:
2164         yield line
2165         return
2166
2167     line_str = str(line).strip("\n")
2168
2169     # we don't want to split special comments like type annotations
2170     # https://github.com/python/typing/issues/186
2171     has_special_comment = False
2172     for leaf in line.leaves:
2173         for comment in line.comments_after(leaf):
2174             if leaf.type == token.COMMA and is_special_comment(comment):
2175                 has_special_comment = True
2176
2177     if (
2178         not has_special_comment
2179         and not line.should_explode
2180         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2181     ):
2182         yield line
2183         return
2184
2185     split_funcs: List[SplitFunc]
2186     if line.is_def:
2187         split_funcs = [left_hand_split]
2188     else:
2189
2190         def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
2191             for omit in generate_trailers_to_omit(line, line_length):
2192                 lines = list(
2193                     right_hand_split(
2194                         line, line_length, supports_trailing_commas, omit=omit
2195                     )
2196                 )
2197                 if is_line_short_enough(lines[0], line_length=line_length):
2198                     yield from lines
2199                     return
2200
2201             # All splits failed, best effort split with no omits.
2202             # This mostly happens to multiline strings that are by definition
2203             # reported as not fitting a single line.
2204             yield from right_hand_split(line, supports_trailing_commas)
2205
2206         if line.inside_brackets:
2207             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2208         else:
2209             split_funcs = [rhs]
2210     for split_func in split_funcs:
2211         # We are accumulating lines in `result` because we might want to abort
2212         # mission and return the original line in the end, or attempt a different
2213         # split altogether.
2214         result: List[Line] = []
2215         try:
2216             for l in split_func(line, supports_trailing_commas):
2217                 if str(l).strip("\n") == line_str:
2218                     raise CannotSplit("Split function returned an unchanged result")
2219
2220                 result.extend(
2221                     split_line(
2222                         l,
2223                         line_length=line_length,
2224                         inner=True,
2225                         supports_trailing_commas=supports_trailing_commas,
2226                     )
2227                 )
2228         except CannotSplit:
2229             continue
2230
2231         else:
2232             yield from result
2233             break
2234
2235     else:
2236         yield line
2237
2238
2239 def left_hand_split(
2240     line: Line, supports_trailing_commas: bool = False
2241 ) -> Iterator[Line]:
2242     """Split line into many lines, starting with the first matching bracket pair.
2243
2244     Note: this usually looks weird, only use this for function definitions.
2245     Prefer RHS otherwise.  This is why this function is not symmetrical with
2246     :func:`right_hand_split` which also handles optional parentheses.
2247     """
2248     tail_leaves: List[Leaf] = []
2249     body_leaves: List[Leaf] = []
2250     head_leaves: List[Leaf] = []
2251     current_leaves = head_leaves
2252     matching_bracket = None
2253     for leaf in line.leaves:
2254         if (
2255             current_leaves is body_leaves
2256             and leaf.type in CLOSING_BRACKETS
2257             and leaf.opening_bracket is matching_bracket
2258         ):
2259             current_leaves = tail_leaves if body_leaves else head_leaves
2260         current_leaves.append(leaf)
2261         if current_leaves is head_leaves:
2262             if leaf.type in OPENING_BRACKETS:
2263                 matching_bracket = leaf
2264                 current_leaves = body_leaves
2265     if not matching_bracket:
2266         raise CannotSplit("No brackets found")
2267
2268     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2269     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2270     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2271     bracket_split_succeeded_or_raise(head, body, tail)
2272     for result in (head, body, tail):
2273         if result:
2274             yield result
2275
2276
2277 def right_hand_split(
2278     line: Line,
2279     line_length: int,
2280     supports_trailing_commas: bool = False,
2281     omit: Collection[LeafID] = (),
2282 ) -> Iterator[Line]:
2283     """Split line into many lines, starting with the last matching bracket pair.
2284
2285     If the split was by optional parentheses, attempt splitting without them, too.
2286     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2287     this split.
2288
2289     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2290     """
2291     tail_leaves: List[Leaf] = []
2292     body_leaves: List[Leaf] = []
2293     head_leaves: List[Leaf] = []
2294     current_leaves = tail_leaves
2295     opening_bracket = None
2296     closing_bracket = None
2297     for leaf in reversed(line.leaves):
2298         if current_leaves is body_leaves:
2299             if leaf is opening_bracket:
2300                 current_leaves = head_leaves if body_leaves else tail_leaves
2301         current_leaves.append(leaf)
2302         if current_leaves is tail_leaves:
2303             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2304                 opening_bracket = leaf.opening_bracket
2305                 closing_bracket = leaf
2306                 current_leaves = body_leaves
2307     if not (opening_bracket and closing_bracket and head_leaves):
2308         # If there is no opening or closing_bracket that means the split failed and
2309         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2310         # the matching `opening_bracket` wasn't available on `line` anymore.
2311         raise CannotSplit("No brackets found")
2312
2313     tail_leaves.reverse()
2314     body_leaves.reverse()
2315     head_leaves.reverse()
2316     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2317     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2318     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2319     bracket_split_succeeded_or_raise(head, body, tail)
2320     if (
2321         # the body shouldn't be exploded
2322         not body.should_explode
2323         # the opening bracket is an optional paren
2324         and opening_bracket.type == token.LPAR
2325         and not opening_bracket.value
2326         # the closing bracket is an optional paren
2327         and closing_bracket.type == token.RPAR
2328         and not closing_bracket.value
2329         # it's not an import (optional parens are the only thing we can split on
2330         # in this case; attempting a split without them is a waste of time)
2331         and not line.is_import
2332         # there are no standalone comments in the body
2333         and not body.contains_standalone_comments(0)
2334         # and we can actually remove the parens
2335         and can_omit_invisible_parens(body, line_length)
2336     ):
2337         omit = {id(closing_bracket), *omit}
2338         try:
2339             yield from right_hand_split(
2340                 line,
2341                 line_length,
2342                 supports_trailing_commas=supports_trailing_commas,
2343                 omit=omit,
2344             )
2345             return
2346
2347         except CannotSplit:
2348             if not (
2349                 can_be_split(body)
2350                 or is_line_short_enough(body, line_length=line_length)
2351             ):
2352                 raise CannotSplit(
2353                     "Splitting failed, body is still too long and can't be split."
2354                 )
2355
2356             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2357                 raise CannotSplit(
2358                     "The current optional pair of parentheses is bound to fail to "
2359                     "satisfy the splitting algorithm because the head or the tail "
2360                     "contains multiline strings which by definition never fit one "
2361                     "line."
2362                 )
2363
2364     ensure_visible(opening_bracket)
2365     ensure_visible(closing_bracket)
2366     for result in (head, body, tail):
2367         if result:
2368             yield result
2369
2370
2371 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2372     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2373
2374     Do nothing otherwise.
2375
2376     A left- or right-hand split is based on a pair of brackets. Content before
2377     (and including) the opening bracket is left on one line, content inside the
2378     brackets is put on a separate line, and finally content starting with and
2379     following the closing bracket is put on a separate line.
2380
2381     Those are called `head`, `body`, and `tail`, respectively. If the split
2382     produced the same line (all content in `head`) or ended up with an empty `body`
2383     and the `tail` is just the closing bracket, then it's considered failed.
2384     """
2385     tail_len = len(str(tail).strip())
2386     if not body:
2387         if tail_len == 0:
2388             raise CannotSplit("Splitting brackets produced the same line")
2389
2390         elif tail_len < 3:
2391             raise CannotSplit(
2392                 f"Splitting brackets on an empty body to save "
2393                 f"{tail_len} characters is not worth it"
2394             )
2395
2396
2397 def bracket_split_build_line(
2398     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2399 ) -> Line:
2400     """Return a new line with given `leaves` and respective comments from `original`.
2401
2402     If `is_body` is True, the result line is one-indented inside brackets and as such
2403     has its first leaf's prefix normalized and a trailing comma added when expected.
2404     """
2405     result = Line(depth=original.depth)
2406     if is_body:
2407         result.inside_brackets = True
2408         result.depth += 1
2409         if leaves:
2410             # Since body is a new indent level, remove spurious leading whitespace.
2411             normalize_prefix(leaves[0], inside_brackets=True)
2412             # Ensure a trailing comma when expected.
2413             if original.is_import:
2414                 if leaves[-1].type != token.COMMA:
2415                     leaves.append(Leaf(token.COMMA, ","))
2416     # Populate the line
2417     for leaf in leaves:
2418         result.append(leaf, preformatted=True)
2419         for comment_after in original.comments_after(leaf):
2420             result.append(comment_after, preformatted=True)
2421     if is_body:
2422         result.should_explode = should_explode(result, opening_bracket)
2423     return result
2424
2425
2426 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2427     """Normalize prefix of the first leaf in every line returned by `split_func`.
2428
2429     This is a decorator over relevant split functions.
2430     """
2431
2432     @wraps(split_func)
2433     def split_wrapper(
2434         line: Line, supports_trailing_commas: bool = False
2435     ) -> Iterator[Line]:
2436         for l in split_func(line, supports_trailing_commas):
2437             normalize_prefix(l.leaves[0], inside_brackets=True)
2438             yield l
2439
2440     return split_wrapper
2441
2442
2443 @dont_increase_indentation
2444 def delimiter_split(
2445     line: Line, supports_trailing_commas: bool = False
2446 ) -> Iterator[Line]:
2447     """Split according to delimiters of the highest priority.
2448
2449     If `py36` is True, the split will add trailing commas also in function
2450     signatures that contain `*` and `**`.
2451     """
2452     try:
2453         last_leaf = line.leaves[-1]
2454     except IndexError:
2455         raise CannotSplit("Line empty")
2456
2457     bt = line.bracket_tracker
2458     try:
2459         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2460     except ValueError:
2461         raise CannotSplit("No delimiters found")
2462
2463     if delimiter_priority == DOT_PRIORITY:
2464         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2465             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2466
2467     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2468     lowest_depth = sys.maxsize
2469     trailing_comma_safe = True
2470
2471     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2472         """Append `leaf` to current line or to new line if appending impossible."""
2473         nonlocal current_line
2474         try:
2475             current_line.append_safe(leaf, preformatted=True)
2476         except ValueError:
2477             yield current_line
2478
2479             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2480             current_line.append(leaf)
2481
2482     for leaf in line.leaves:
2483         yield from append_to_line(leaf)
2484
2485         for comment_after in line.comments_after(leaf):
2486             yield from append_to_line(comment_after)
2487
2488         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2489         if leaf.bracket_depth == lowest_depth and is_vararg(
2490             leaf, within=VARARGS_PARENTS
2491         ):
2492             trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
2493         leaf_priority = bt.delimiters.get(id(leaf))
2494         if leaf_priority == delimiter_priority:
2495             yield current_line
2496
2497             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2498     if current_line:
2499         if (
2500             trailing_comma_safe
2501             and delimiter_priority == COMMA_PRIORITY
2502             and current_line.leaves[-1].type != token.COMMA
2503             and current_line.leaves[-1].type != STANDALONE_COMMENT
2504         ):
2505             current_line.append(Leaf(token.COMMA, ","))
2506         yield current_line
2507
2508
2509 @dont_increase_indentation
2510 def standalone_comment_split(
2511     line: Line, supports_trailing_commas: bool = False
2512 ) -> Iterator[Line]:
2513     """Split standalone comments from the rest of the line."""
2514     if not line.contains_standalone_comments(0):
2515         raise CannotSplit("Line does not have any standalone comments")
2516
2517     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2518
2519     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2520         """Append `leaf` to current line or to new line if appending impossible."""
2521         nonlocal current_line
2522         try:
2523             current_line.append_safe(leaf, preformatted=True)
2524         except ValueError:
2525             yield current_line
2526
2527             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2528             current_line.append(leaf)
2529
2530     for leaf in line.leaves:
2531         yield from append_to_line(leaf)
2532
2533         for comment_after in line.comments_after(leaf):
2534             yield from append_to_line(comment_after)
2535
2536     if current_line:
2537         yield current_line
2538
2539
2540 def is_import(leaf: Leaf) -> bool:
2541     """Return True if the given leaf starts an import statement."""
2542     p = leaf.parent
2543     t = leaf.type
2544     v = leaf.value
2545     return bool(
2546         t == token.NAME
2547         and (
2548             (v == "import" and p and p.type == syms.import_name)
2549             or (v == "from" and p and p.type == syms.import_from)
2550         )
2551     )
2552
2553
2554 def is_special_comment(leaf: Leaf) -> bool:
2555     """Return True if the given leaf is a special comment.
2556     Only returns true for type comments for now."""
2557     t = leaf.type
2558     v = leaf.value
2559     return bool(
2560         (t == token.COMMENT or t == STANDALONE_COMMENT) and (v.startswith("# type:"))
2561     )
2562
2563
2564 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2565     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2566     else.
2567
2568     Note: don't use backslashes for formatting or you'll lose your voting rights.
2569     """
2570     if not inside_brackets:
2571         spl = leaf.prefix.split("#")
2572         if "\\" not in spl[0]:
2573             nl_count = spl[-1].count("\n")
2574             if len(spl) > 1:
2575                 nl_count -= 1
2576             leaf.prefix = "\n" * nl_count
2577             return
2578
2579     leaf.prefix = ""
2580
2581
2582 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2583     """Make all string prefixes lowercase.
2584
2585     If remove_u_prefix is given, also removes any u prefix from the string.
2586
2587     Note: Mutates its argument.
2588     """
2589     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2590     assert match is not None, f"failed to match string {leaf.value!r}"
2591     orig_prefix = match.group(1)
2592     new_prefix = orig_prefix.lower()
2593     if remove_u_prefix:
2594         new_prefix = new_prefix.replace("u", "")
2595     leaf.value = f"{new_prefix}{match.group(2)}"
2596
2597
2598 def normalize_string_quotes(leaf: Leaf) -> None:
2599     """Prefer double quotes but only if it doesn't cause more escaping.
2600
2601     Adds or removes backslashes as appropriate. Doesn't parse and fix
2602     strings nested in f-strings (yet).
2603
2604     Note: Mutates its argument.
2605     """
2606     value = leaf.value.lstrip("furbFURB")
2607     if value[:3] == '"""':
2608         return
2609
2610     elif value[:3] == "'''":
2611         orig_quote = "'''"
2612         new_quote = '"""'
2613     elif value[0] == '"':
2614         orig_quote = '"'
2615         new_quote = "'"
2616     else:
2617         orig_quote = "'"
2618         new_quote = '"'
2619     first_quote_pos = leaf.value.find(orig_quote)
2620     if first_quote_pos == -1:
2621         return  # There's an internal error
2622
2623     prefix = leaf.value[:first_quote_pos]
2624     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2625     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2626     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2627     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2628     if "r" in prefix.casefold():
2629         if unescaped_new_quote.search(body):
2630             # There's at least one unescaped new_quote in this raw string
2631             # so converting is impossible
2632             return
2633
2634         # Do not introduce or remove backslashes in raw strings
2635         new_body = body
2636     else:
2637         # remove unnecessary escapes
2638         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2639         if body != new_body:
2640             # Consider the string without unnecessary escapes as the original
2641             body = new_body
2642             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2643         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2644         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2645     if "f" in prefix.casefold():
2646         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2647         for m in matches:
2648             if "\\" in str(m):
2649                 # Do not introduce backslashes in interpolated expressions
2650                 return
2651     if new_quote == '"""' and new_body[-1:] == '"':
2652         # edge case:
2653         new_body = new_body[:-1] + '\\"'
2654     orig_escape_count = body.count("\\")
2655     new_escape_count = new_body.count("\\")
2656     if new_escape_count > orig_escape_count:
2657         return  # Do not introduce more escaping
2658
2659     if new_escape_count == orig_escape_count and orig_quote == '"':
2660         return  # Prefer double quotes
2661
2662     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2663
2664
2665 def normalize_numeric_literal(leaf: Leaf) -> None:
2666     """Normalizes numeric (float, int, and complex) literals.
2667
2668     All letters used in the representation are normalized to lowercase (except
2669     in Python 2 long literals).
2670     """
2671     text = leaf.value.lower()
2672     if text.startswith(("0o", "0b")):
2673         # Leave octal and binary literals alone.
2674         pass
2675     elif text.startswith("0x"):
2676         # Change hex literals to upper case.
2677         before, after = text[:2], text[2:]
2678         text = f"{before}{after.upper()}"
2679     elif "e" in text:
2680         before, after = text.split("e")
2681         sign = ""
2682         if after.startswith("-"):
2683             after = after[1:]
2684             sign = "-"
2685         elif after.startswith("+"):
2686             after = after[1:]
2687         before = format_float_or_int_string(before)
2688         text = f"{before}e{sign}{after}"
2689     elif text.endswith(("j", "l")):
2690         number = text[:-1]
2691         suffix = text[-1]
2692         # Capitalize in "2L" because "l" looks too similar to "1".
2693         if suffix == "l":
2694             suffix = "L"
2695         text = f"{format_float_or_int_string(number)}{suffix}"
2696     else:
2697         text = format_float_or_int_string(text)
2698     leaf.value = text
2699
2700
2701 def format_float_or_int_string(text: str) -> str:
2702     """Formats a float string like "1.0"."""
2703     if "." not in text:
2704         return text
2705
2706     before, after = text.split(".")
2707     return f"{before or 0}.{after or 0}"
2708
2709
2710 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2711     """Make existing optional parentheses invisible or create new ones.
2712
2713     `parens_after` is a set of string leaf values immeditely after which parens
2714     should be put.
2715
2716     Standardizes on visible parentheses for single-element tuples, and keeps
2717     existing visible parentheses for other tuples and generator expressions.
2718     """
2719     for pc in list_comments(node.prefix, is_endmarker=False):
2720         if pc.value in FMT_OFF:
2721             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2722             return
2723
2724     check_lpar = False
2725     for index, child in enumerate(list(node.children)):
2726         if check_lpar:
2727             if child.type == syms.atom:
2728                 if maybe_make_parens_invisible_in_atom(child):
2729                     lpar = Leaf(token.LPAR, "")
2730                     rpar = Leaf(token.RPAR, "")
2731                     index = child.remove() or 0
2732                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2733             elif is_one_tuple(child):
2734                 # wrap child in visible parentheses
2735                 lpar = Leaf(token.LPAR, "(")
2736                 rpar = Leaf(token.RPAR, ")")
2737                 child.remove()
2738                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2739             elif node.type == syms.import_from:
2740                 # "import from" nodes store parentheses directly as part of
2741                 # the statement
2742                 if child.type == token.LPAR:
2743                     # make parentheses invisible
2744                     child.value = ""  # type: ignore
2745                     node.children[-1].value = ""  # type: ignore
2746                 elif child.type != token.STAR:
2747                     # insert invisible parentheses
2748                     node.insert_child(index, Leaf(token.LPAR, ""))
2749                     node.append_child(Leaf(token.RPAR, ""))
2750                 break
2751
2752             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2753                 # wrap child in invisible parentheses
2754                 lpar = Leaf(token.LPAR, "")
2755                 rpar = Leaf(token.RPAR, "")
2756                 index = child.remove() or 0
2757                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2758
2759         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2760
2761
2762 def normalize_fmt_off(node: Node) -> None:
2763     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2764     try_again = True
2765     while try_again:
2766         try_again = convert_one_fmt_off_pair(node)
2767
2768
2769 def convert_one_fmt_off_pair(node: Node) -> bool:
2770     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2771
2772     Returns True if a pair was converted.
2773     """
2774     for leaf in node.leaves():
2775         previous_consumed = 0
2776         for comment in list_comments(leaf.prefix, is_endmarker=False):
2777             if comment.value in FMT_OFF:
2778                 # We only want standalone comments. If there's no previous leaf or
2779                 # the previous leaf is indentation, it's a standalone comment in
2780                 # disguise.
2781                 if comment.type != STANDALONE_COMMENT:
2782                     prev = preceding_leaf(leaf)
2783                     if prev and prev.type not in WHITESPACE:
2784                         continue
2785
2786                 ignored_nodes = list(generate_ignored_nodes(leaf))
2787                 if not ignored_nodes:
2788                     continue
2789
2790                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2791                 parent = first.parent
2792                 prefix = first.prefix
2793                 first.prefix = prefix[comment.consumed :]
2794                 hidden_value = (
2795                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2796                 )
2797                 if hidden_value.endswith("\n"):
2798                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2799                     # leaf (possibly followed by a DEDENT).
2800                     hidden_value = hidden_value[:-1]
2801                 first_idx = None
2802                 for ignored in ignored_nodes:
2803                     index = ignored.remove()
2804                     if first_idx is None:
2805                         first_idx = index
2806                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2807                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2808                 parent.insert_child(
2809                     first_idx,
2810                     Leaf(
2811                         STANDALONE_COMMENT,
2812                         hidden_value,
2813                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2814                     ),
2815                 )
2816                 return True
2817
2818             previous_consumed = comment.consumed
2819
2820     return False
2821
2822
2823 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2824     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2825
2826     Stops at the end of the block.
2827     """
2828     container: Optional[LN] = container_of(leaf)
2829     while container is not None and container.type != token.ENDMARKER:
2830         for comment in list_comments(container.prefix, is_endmarker=False):
2831             if comment.value in FMT_ON:
2832                 return
2833
2834         yield container
2835
2836         container = container.next_sibling
2837
2838
2839 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2840     """If it's safe, make the parens in the atom `node` invisible, recursively.
2841
2842     Returns whether the node should itself be wrapped in invisible parentheses.
2843
2844     """
2845     if (
2846         node.type != syms.atom
2847         or is_empty_tuple(node)
2848         or is_one_tuple(node)
2849         or is_yield(node)
2850         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2851     ):
2852         return False
2853
2854     first = node.children[0]
2855     last = node.children[-1]
2856     if first.type == token.LPAR and last.type == token.RPAR:
2857         # make parentheses invisible
2858         first.value = ""  # type: ignore
2859         last.value = ""  # type: ignore
2860         if len(node.children) > 1:
2861             maybe_make_parens_invisible_in_atom(node.children[1])
2862         return False
2863
2864     return True
2865
2866
2867 def is_empty_tuple(node: LN) -> bool:
2868     """Return True if `node` holds an empty tuple."""
2869     return (
2870         node.type == syms.atom
2871         and len(node.children) == 2
2872         and node.children[0].type == token.LPAR
2873         and node.children[1].type == token.RPAR
2874     )
2875
2876
2877 def is_one_tuple(node: LN) -> bool:
2878     """Return True if `node` holds a tuple with one element, with or without parens."""
2879     if node.type == syms.atom:
2880         if len(node.children) != 3:
2881             return False
2882
2883         lpar, gexp, rpar = node.children
2884         if not (
2885             lpar.type == token.LPAR
2886             and gexp.type == syms.testlist_gexp
2887             and rpar.type == token.RPAR
2888         ):
2889             return False
2890
2891         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2892
2893     return (
2894         node.type in IMPLICIT_TUPLE
2895         and len(node.children) == 2
2896         and node.children[1].type == token.COMMA
2897     )
2898
2899
2900 def is_yield(node: LN) -> bool:
2901     """Return True if `node` holds a `yield` or `yield from` expression."""
2902     if node.type == syms.yield_expr:
2903         return True
2904
2905     if node.type == token.NAME and node.value == "yield":  # type: ignore
2906         return True
2907
2908     if node.type != syms.atom:
2909         return False
2910
2911     if len(node.children) != 3:
2912         return False
2913
2914     lpar, expr, rpar = node.children
2915     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2916         return is_yield(expr)
2917
2918     return False
2919
2920
2921 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2922     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2923
2924     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2925     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2926     extended iterable unpacking (PEP 3132) and additional unpacking
2927     generalizations (PEP 448).
2928     """
2929     if leaf.type not in STARS or not leaf.parent:
2930         return False
2931
2932     p = leaf.parent
2933     if p.type == syms.star_expr:
2934         # Star expressions are also used as assignment targets in extended
2935         # iterable unpacking (PEP 3132).  See what its parent is instead.
2936         if not p.parent:
2937             return False
2938
2939         p = p.parent
2940
2941     return p.type in within
2942
2943
2944 def is_multiline_string(leaf: Leaf) -> bool:
2945     """Return True if `leaf` is a multiline string that actually spans many lines."""
2946     value = leaf.value.lstrip("furbFURB")
2947     return value[:3] in {'"""', "'''"} and "\n" in value
2948
2949
2950 def is_stub_suite(node: Node) -> bool:
2951     """Return True if `node` is a suite with a stub body."""
2952     if (
2953         len(node.children) != 4
2954         or node.children[0].type != token.NEWLINE
2955         or node.children[1].type != token.INDENT
2956         or node.children[3].type != token.DEDENT
2957     ):
2958         return False
2959
2960     return is_stub_body(node.children[2])
2961
2962
2963 def is_stub_body(node: LN) -> bool:
2964     """Return True if `node` is a simple statement containing an ellipsis."""
2965     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2966         return False
2967
2968     if len(node.children) != 2:
2969         return False
2970
2971     child = node.children[0]
2972     return (
2973         child.type == syms.atom
2974         and len(child.children) == 3
2975         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2976     )
2977
2978
2979 def max_delimiter_priority_in_atom(node: LN) -> int:
2980     """Return maximum delimiter priority inside `node`.
2981
2982     This is specific to atoms with contents contained in a pair of parentheses.
2983     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2984     """
2985     if node.type != syms.atom:
2986         return 0
2987
2988     first = node.children[0]
2989     last = node.children[-1]
2990     if not (first.type == token.LPAR and last.type == token.RPAR):
2991         return 0
2992
2993     bt = BracketTracker()
2994     for c in node.children[1:-1]:
2995         if isinstance(c, Leaf):
2996             bt.mark(c)
2997         else:
2998             for leaf in c.leaves():
2999                 bt.mark(leaf)
3000     try:
3001         return bt.max_delimiter_priority()
3002
3003     except ValueError:
3004         return 0
3005
3006
3007 def ensure_visible(leaf: Leaf) -> None:
3008     """Make sure parentheses are visible.
3009
3010     They could be invisible as part of some statements (see
3011     :func:`normalize_invible_parens` and :func:`visit_import_from`).
3012     """
3013     if leaf.type == token.LPAR:
3014         leaf.value = "("
3015     elif leaf.type == token.RPAR:
3016         leaf.value = ")"
3017
3018
3019 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3020     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3021
3022     if not (
3023         opening_bracket.parent
3024         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3025         and opening_bracket.value in "[{("
3026     ):
3027         return False
3028
3029     try:
3030         last_leaf = line.leaves[-1]
3031         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3032         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3033     except (IndexError, ValueError):
3034         return False
3035
3036     return max_priority == COMMA_PRIORITY
3037
3038
3039 def get_features_used(node: Node) -> Set[Feature]:
3040     """Return a set of (relatively) new Python features used in this file.
3041
3042     Currently looking for:
3043     - f-strings;
3044     - underscores in numeric literals; and
3045     - trailing commas after * or ** in function signatures and calls.
3046     """
3047     features: Set[Feature] = set()
3048     for n in node.pre_order():
3049         if n.type == token.STRING:
3050             value_head = n.value[:2]  # type: ignore
3051             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3052                 features.add(Feature.F_STRINGS)
3053
3054         elif n.type == token.NUMBER:
3055             if "_" in n.value:  # type: ignore
3056                 features.add(Feature.NUMERIC_UNDERSCORES)
3057
3058         elif (
3059             n.type in {syms.typedargslist, syms.arglist}
3060             and n.children
3061             and n.children[-1].type == token.COMMA
3062         ):
3063             for ch in n.children:
3064                 if ch.type in STARS:
3065                     features.add(Feature.TRAILING_COMMA)
3066
3067                 if ch.type == syms.argument:
3068                     for argch in ch.children:
3069                         if argch.type in STARS:
3070                             features.add(Feature.TRAILING_COMMA)
3071
3072     return features
3073
3074
3075 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3076     """Detect the version to target based on the nodes used."""
3077     features = get_features_used(node)
3078     return {
3079         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3080     }
3081
3082
3083 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3084     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3085
3086     Brackets can be omitted if the entire trailer up to and including
3087     a preceding closing bracket fits in one line.
3088
3089     Yielded sets are cumulative (contain results of previous yields, too).  First
3090     set is empty.
3091     """
3092
3093     omit: Set[LeafID] = set()
3094     yield omit
3095
3096     length = 4 * line.depth
3097     opening_bracket = None
3098     closing_bracket = None
3099     inner_brackets: Set[LeafID] = set()
3100     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3101         length += leaf_length
3102         if length > line_length:
3103             break
3104
3105         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3106         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3107             break
3108
3109         if opening_bracket:
3110             if leaf is opening_bracket:
3111                 opening_bracket = None
3112             elif leaf.type in CLOSING_BRACKETS:
3113                 inner_brackets.add(id(leaf))
3114         elif leaf.type in CLOSING_BRACKETS:
3115             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3116                 # Empty brackets would fail a split so treat them as "inner"
3117                 # brackets (e.g. only add them to the `omit` set if another
3118                 # pair of brackets was good enough.
3119                 inner_brackets.add(id(leaf))
3120                 continue
3121
3122             if closing_bracket:
3123                 omit.add(id(closing_bracket))
3124                 omit.update(inner_brackets)
3125                 inner_brackets.clear()
3126                 yield omit
3127
3128             if leaf.value:
3129                 opening_bracket = leaf.opening_bracket
3130                 closing_bracket = leaf
3131
3132
3133 def get_future_imports(node: Node) -> Set[str]:
3134     """Return a set of __future__ imports in the file."""
3135     imports: Set[str] = set()
3136
3137     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3138         for child in children:
3139             if isinstance(child, Leaf):
3140                 if child.type == token.NAME:
3141                     yield child.value
3142             elif child.type == syms.import_as_name:
3143                 orig_name = child.children[0]
3144                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3145                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3146                 yield orig_name.value
3147             elif child.type == syms.import_as_names:
3148                 yield from get_imports_from_children(child.children)
3149             else:
3150                 assert False, "Invalid syntax parsing imports"
3151
3152     for child in node.children:
3153         if child.type != syms.simple_stmt:
3154             break
3155         first_child = child.children[0]
3156         if isinstance(first_child, Leaf):
3157             # Continue looking if we see a docstring; otherwise stop.
3158             if (
3159                 len(child.children) == 2
3160                 and first_child.type == token.STRING
3161                 and child.children[1].type == token.NEWLINE
3162             ):
3163                 continue
3164             else:
3165                 break
3166         elif first_child.type == syms.import_from:
3167             module_name = first_child.children[1]
3168             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3169                 break
3170             imports |= set(get_imports_from_children(first_child.children[3:]))
3171         else:
3172             break
3173     return imports
3174
3175
3176 def gen_python_files_in_dir(
3177     path: Path,
3178     root: Path,
3179     include: Pattern[str],
3180     exclude: Pattern[str],
3181     report: "Report",
3182 ) -> Iterator[Path]:
3183     """Generate all files under `path` whose paths are not excluded by the
3184     `exclude` regex, but are included by the `include` regex.
3185
3186     Symbolic links pointing outside of the `root` directory are ignored.
3187
3188     `report` is where output about exclusions goes.
3189     """
3190     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3191     for child in path.iterdir():
3192         try:
3193             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3194         except ValueError:
3195             if child.is_symlink():
3196                 report.path_ignored(
3197                     child, f"is a symbolic link that points outside {root}"
3198                 )
3199                 continue
3200
3201             raise
3202
3203         if child.is_dir():
3204             normalized_path += "/"
3205         exclude_match = exclude.search(normalized_path)
3206         if exclude_match and exclude_match.group(0):
3207             report.path_ignored(child, f"matches the --exclude regular expression")
3208             continue
3209
3210         if child.is_dir():
3211             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3212
3213         elif child.is_file():
3214             include_match = include.search(normalized_path)
3215             if include_match:
3216                 yield child
3217
3218
3219 @lru_cache()
3220 def find_project_root(srcs: Iterable[str]) -> Path:
3221     """Return a directory containing .git, .hg, or pyproject.toml.
3222
3223     That directory can be one of the directories passed in `srcs` or their
3224     common parent.
3225
3226     If no directory in the tree contains a marker that would specify it's the
3227     project root, the root of the file system is returned.
3228     """
3229     if not srcs:
3230         return Path("/").resolve()
3231
3232     common_base = min(Path(src).resolve() for src in srcs)
3233     if common_base.is_dir():
3234         # Append a fake file so `parents` below returns `common_base_dir`, too.
3235         common_base /= "fake-file"
3236     for directory in common_base.parents:
3237         if (directory / ".git").is_dir():
3238             return directory
3239
3240         if (directory / ".hg").is_dir():
3241             return directory
3242
3243         if (directory / "pyproject.toml").is_file():
3244             return directory
3245
3246     return directory
3247
3248
3249 @dataclass
3250 class Report:
3251     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3252
3253     check: bool = False
3254     quiet: bool = False
3255     verbose: bool = False
3256     change_count: int = 0
3257     same_count: int = 0
3258     failure_count: int = 0
3259
3260     def done(self, src: Path, changed: Changed) -> None:
3261         """Increment the counter for successful reformatting. Write out a message."""
3262         if changed is Changed.YES:
3263             reformatted = "would reformat" if self.check else "reformatted"
3264             if self.verbose or not self.quiet:
3265                 out(f"{reformatted} {src}")
3266             self.change_count += 1
3267         else:
3268             if self.verbose:
3269                 if changed is Changed.NO:
3270                     msg = f"{src} already well formatted, good job."
3271                 else:
3272                     msg = f"{src} wasn't modified on disk since last run."
3273                 out(msg, bold=False)
3274             self.same_count += 1
3275
3276     def failed(self, src: Path, message: str) -> None:
3277         """Increment the counter for failed reformatting. Write out a message."""
3278         err(f"error: cannot format {src}: {message}")
3279         self.failure_count += 1
3280
3281     def path_ignored(self, path: Path, message: str) -> None:
3282         if self.verbose:
3283             out(f"{path} ignored: {message}", bold=False)
3284
3285     @property
3286     def return_code(self) -> int:
3287         """Return the exit code that the app should use.
3288
3289         This considers the current state of changed files and failures:
3290         - if there were any failures, return 123;
3291         - if any files were changed and --check is being used, return 1;
3292         - otherwise return 0.
3293         """
3294         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3295         # 126 we have special return codes reserved by the shell.
3296         if self.failure_count:
3297             return 123
3298
3299         elif self.change_count and self.check:
3300             return 1
3301
3302         return 0
3303
3304     def __str__(self) -> str:
3305         """Render a color report of the current state.
3306
3307         Use `click.unstyle` to remove colors.
3308         """
3309         if self.check:
3310             reformatted = "would be reformatted"
3311             unchanged = "would be left unchanged"
3312             failed = "would fail to reformat"
3313         else:
3314             reformatted = "reformatted"
3315             unchanged = "left unchanged"
3316             failed = "failed to reformat"
3317         report = []
3318         if self.change_count:
3319             s = "s" if self.change_count > 1 else ""
3320             report.append(
3321                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3322             )
3323         if self.same_count:
3324             s = "s" if self.same_count > 1 else ""
3325             report.append(f"{self.same_count} file{s} {unchanged}")
3326         if self.failure_count:
3327             s = "s" if self.failure_count > 1 else ""
3328             report.append(
3329                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3330             )
3331         return ", ".join(report) + "."
3332
3333
3334 def assert_equivalent(src: str, dst: str) -> None:
3335     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3336
3337     import ast
3338     import traceback
3339
3340     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3341         """Simple visitor generating strings to compare ASTs by content."""
3342         yield f"{'  ' * depth}{node.__class__.__name__}("
3343
3344         for field in sorted(node._fields):
3345             try:
3346                 value = getattr(node, field)
3347             except AttributeError:
3348                 continue
3349
3350             yield f"{'  ' * (depth+1)}{field}="
3351
3352             if isinstance(value, list):
3353                 for item in value:
3354                     # Ignore nested tuples within del statements, because we may insert
3355                     # parentheses and they change the AST.
3356                     if (
3357                         field == "targets"
3358                         and isinstance(node, ast.Delete)
3359                         and isinstance(item, ast.Tuple)
3360                     ):
3361                         for item in item.elts:
3362                             yield from _v(item, depth + 2)
3363                     elif isinstance(item, ast.AST):
3364                         yield from _v(item, depth + 2)
3365
3366             elif isinstance(value, ast.AST):
3367                 yield from _v(value, depth + 2)
3368
3369             else:
3370                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3371
3372         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3373
3374     try:
3375         src_ast = ast.parse(src)
3376     except Exception as exc:
3377         major, minor = sys.version_info[:2]
3378         raise AssertionError(
3379             f"cannot use --safe with this file; failed to parse source file "
3380             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3381             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3382         )
3383
3384     try:
3385         dst_ast = ast.parse(dst)
3386     except Exception as exc:
3387         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3388         raise AssertionError(
3389             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3390             f"Please report a bug on https://github.com/ambv/black/issues.  "
3391             f"This invalid output might be helpful: {log}"
3392         ) from None
3393
3394     src_ast_str = "\n".join(_v(src_ast))
3395     dst_ast_str = "\n".join(_v(dst_ast))
3396     if src_ast_str != dst_ast_str:
3397         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3398         raise AssertionError(
3399             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3400             f"the source.  "
3401             f"Please report a bug on https://github.com/ambv/black/issues.  "
3402             f"This diff might be helpful: {log}"
3403         ) from None
3404
3405
3406 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3407     """Raise AssertionError if `dst` reformats differently the second time."""
3408     newdst = format_str(dst, mode=mode)
3409     if dst != newdst:
3410         log = dump_to_file(
3411             diff(src, dst, "source", "first pass"),
3412             diff(dst, newdst, "first pass", "second pass"),
3413         )
3414         raise AssertionError(
3415             f"INTERNAL ERROR: Black produced different code on the second pass "
3416             f"of the formatter.  "
3417             f"Please report a bug on https://github.com/ambv/black/issues.  "
3418             f"This diff might be helpful: {log}"
3419         ) from None
3420
3421
3422 def dump_to_file(*output: str) -> str:
3423     """Dump `output` to a temporary file. Return path to the file."""
3424     import tempfile
3425
3426     with tempfile.NamedTemporaryFile(
3427         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3428     ) as f:
3429         for lines in output:
3430             f.write(lines)
3431             if lines and lines[-1] != "\n":
3432                 f.write("\n")
3433     return f.name
3434
3435
3436 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3437     """Return a unified diff string between strings `a` and `b`."""
3438     import difflib
3439
3440     a_lines = [line + "\n" for line in a.split("\n")]
3441     b_lines = [line + "\n" for line in b.split("\n")]
3442     return "".join(
3443         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3444     )
3445
3446
3447 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3448     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3449     err("Aborted!")
3450     for task in tasks:
3451         task.cancel()
3452
3453
3454 def shutdown(loop: BaseEventLoop) -> None:
3455     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3456     try:
3457         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3458         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3459         if not to_cancel:
3460             return
3461
3462         for task in to_cancel:
3463             task.cancel()
3464         loop.run_until_complete(
3465             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3466         )
3467     finally:
3468         # `concurrent.futures.Future` objects cannot be cancelled once they
3469         # are already running. There might be some when the `shutdown()` happened.
3470         # Silence their logger's spew about the event loop being closed.
3471         cf_logger = logging.getLogger("concurrent.futures")
3472         cf_logger.setLevel(logging.CRITICAL)
3473         loop.close()
3474
3475
3476 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3477     """Replace `regex` with `replacement` twice on `original`.
3478
3479     This is used by string normalization to perform replaces on
3480     overlapping matches.
3481     """
3482     return regex.sub(replacement, regex.sub(replacement, original))
3483
3484
3485 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3486     """Compile a regular expression string in `regex`.
3487
3488     If it contains newlines, use verbose mode.
3489     """
3490     if "\n" in regex:
3491         regex = "(?x)" + regex
3492     return re.compile(regex)
3493
3494
3495 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3496     """Like `reversed(enumerate(sequence))` if that were possible."""
3497     index = len(sequence) - 1
3498     for element in reversed(sequence):
3499         yield (index, element)
3500         index -= 1
3501
3502
3503 def enumerate_with_length(
3504     line: Line, reversed: bool = False
3505 ) -> Iterator[Tuple[Index, Leaf, int]]:
3506     """Return an enumeration of leaves with their length.
3507
3508     Stops prematurely on multiline strings and standalone comments.
3509     """
3510     op = cast(
3511         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3512         enumerate_reversed if reversed else enumerate,
3513     )
3514     for index, leaf in op(line.leaves):
3515         length = len(leaf.prefix) + len(leaf.value)
3516         if "\n" in leaf.value:
3517             return  # Multiline strings, we can't continue.
3518
3519         comment: Optional[Leaf]
3520         for comment in line.comments_after(leaf):
3521             length += len(comment.value)
3522
3523         yield index, leaf, length
3524
3525
3526 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3527     """Return True if `line` is no longer than `line_length`.
3528
3529     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3530     """
3531     if not line_str:
3532         line_str = str(line).strip("\n")
3533     return (
3534         len(line_str) <= line_length
3535         and "\n" not in line_str  # multiline strings
3536         and not line.contains_standalone_comments()
3537     )
3538
3539
3540 def can_be_split(line: Line) -> bool:
3541     """Return False if the line cannot be split *for sure*.
3542
3543     This is not an exhaustive search but a cheap heuristic that we can use to
3544     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3545     in unnecessary parentheses).
3546     """
3547     leaves = line.leaves
3548     if len(leaves) < 2:
3549         return False
3550
3551     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3552         call_count = 0
3553         dot_count = 0
3554         next = leaves[-1]
3555         for leaf in leaves[-2::-1]:
3556             if leaf.type in OPENING_BRACKETS:
3557                 if next.type not in CLOSING_BRACKETS:
3558                     return False
3559
3560                 call_count += 1
3561             elif leaf.type == token.DOT:
3562                 dot_count += 1
3563             elif leaf.type == token.NAME:
3564                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3565                     return False
3566
3567             elif leaf.type not in CLOSING_BRACKETS:
3568                 return False
3569
3570             if dot_count > 1 and call_count > 1:
3571                 return False
3572
3573     return True
3574
3575
3576 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3577     """Does `line` have a shape safe to reformat without optional parens around it?
3578
3579     Returns True for only a subset of potentially nice looking formattings but
3580     the point is to not return false positives that end up producing lines that
3581     are too long.
3582     """
3583     bt = line.bracket_tracker
3584     if not bt.delimiters:
3585         # Without delimiters the optional parentheses are useless.
3586         return True
3587
3588     max_priority = bt.max_delimiter_priority()
3589     if bt.delimiter_count_with_priority(max_priority) > 1:
3590         # With more than one delimiter of a kind the optional parentheses read better.
3591         return False
3592
3593     if max_priority == DOT_PRIORITY:
3594         # A single stranded method call doesn't require optional parentheses.
3595         return True
3596
3597     assert len(line.leaves) >= 2, "Stranded delimiter"
3598
3599     first = line.leaves[0]
3600     second = line.leaves[1]
3601     penultimate = line.leaves[-2]
3602     last = line.leaves[-1]
3603
3604     # With a single delimiter, omit if the expression starts or ends with
3605     # a bracket.
3606     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3607         remainder = False
3608         length = 4 * line.depth
3609         for _index, leaf, leaf_length in enumerate_with_length(line):
3610             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3611                 remainder = True
3612             if remainder:
3613                 length += leaf_length
3614                 if length > line_length:
3615                     break
3616
3617                 if leaf.type in OPENING_BRACKETS:
3618                     # There are brackets we can further split on.
3619                     remainder = False
3620
3621         else:
3622             # checked the entire string and line length wasn't exceeded
3623             if len(line.leaves) == _index + 1:
3624                 return True
3625
3626         # Note: we are not returning False here because a line might have *both*
3627         # a leading opening bracket and a trailing closing bracket.  If the
3628         # opening bracket doesn't match our rule, maybe the closing will.
3629
3630     if (
3631         last.type == token.RPAR
3632         or last.type == token.RBRACE
3633         or (
3634             # don't use indexing for omitting optional parentheses;
3635             # it looks weird
3636             last.type == token.RSQB
3637             and last.parent
3638             and last.parent.type != syms.trailer
3639         )
3640     ):
3641         if penultimate.type in OPENING_BRACKETS:
3642             # Empty brackets don't help.
3643             return False
3644
3645         if is_multiline_string(first):
3646             # Additional wrapping of a multiline string in this situation is
3647             # unnecessary.
3648             return True
3649
3650         length = 4 * line.depth
3651         seen_other_brackets = False
3652         for _index, leaf, leaf_length in enumerate_with_length(line):
3653             length += leaf_length
3654             if leaf is last.opening_bracket:
3655                 if seen_other_brackets or length <= line_length:
3656                     return True
3657
3658             elif leaf.type in OPENING_BRACKETS:
3659                 # There are brackets we can further split on.
3660                 seen_other_brackets = True
3661
3662     return False
3663
3664
3665 def get_cache_file(mode: FileMode) -> Path:
3666     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3667
3668
3669 def read_cache(mode: FileMode) -> Cache:
3670     """Read the cache if it exists and is well formed.
3671
3672     If it is not well formed, the call to write_cache later should resolve the issue.
3673     """
3674     cache_file = get_cache_file(mode)
3675     if not cache_file.exists():
3676         return {}
3677
3678     with cache_file.open("rb") as fobj:
3679         try:
3680             cache: Cache = pickle.load(fobj)
3681         except pickle.UnpicklingError:
3682             return {}
3683
3684     return cache
3685
3686
3687 def get_cache_info(path: Path) -> CacheInfo:
3688     """Return the information used to check if a file is already formatted or not."""
3689     stat = path.stat()
3690     return stat.st_mtime, stat.st_size
3691
3692
3693 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3694     """Split an iterable of paths in `sources` into two sets.
3695
3696     The first contains paths of files that modified on disk or are not in the
3697     cache. The other contains paths to non-modified files.
3698     """
3699     todo, done = set(), set()
3700     for src in sources:
3701         src = src.resolve()
3702         if cache.get(src) != get_cache_info(src):
3703             todo.add(src)
3704         else:
3705             done.add(src)
3706     return todo, done
3707
3708
3709 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3710     """Update the cache file."""
3711     cache_file = get_cache_file(mode)
3712     try:
3713         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3714         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3715         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3716             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3717         os.replace(f.name, cache_file)
3718     except OSError:
3719         pass
3720
3721
3722 def patch_click() -> None:
3723     """Make Click not crash.
3724
3725     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3726     default which restricts paths that it can access during the lifetime of the
3727     application.  Click refuses to work in this scenario by raising a RuntimeError.
3728
3729     In case of Black the likelihood that non-ASCII characters are going to be used in
3730     file paths is minimal since it's Python source code.  Moreover, this crash was
3731     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3732     """
3733     try:
3734         from click import core
3735         from click import _unicodefun  # type: ignore
3736     except ModuleNotFoundError:
3737         return
3738
3739     for module in (core, _unicodefun):
3740         if hasattr(module, "_verify_python3_env"):
3741             module._verify_python3_env = lambda: None
3742
3743
3744 def patched_main() -> None:
3745     freeze_support()
3746     patch_click()
3747     main()
3748
3749
3750 if __name__ == "__main__":
3751     patched_main()