]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add PyCharm setup step (#680)
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum
6 from functools import lru_cache, partial, wraps
7 import io
8 import itertools
9 import logging
10 from multiprocessing import Manager, freeze_support
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tempfile
18 import tokenize
19 from typing import (
20     Any,
21     Callable,
22     Collection,
23     Dict,
24     Generator,
25     Generic,
26     Iterable,
27     Iterator,
28     List,
29     Optional,
30     Pattern,
31     Sequence,
32     Set,
33     Tuple,
34     TypeVar,
35     Union,
36     cast,
37 )
38
39 from appdirs import user_cache_dir
40 from attr import dataclass, evolve, Factory
41 import click
42 import toml
43
44 # lib2to3 fork
45 from blib2to3.pytree import Node, Leaf, type_repr
46 from blib2to3 import pygram, pytree
47 from blib2to3.pgen2 import driver, token
48 from blib2to3.pgen2.grammar import Grammar
49 from blib2to3.pgen2.parse import ParseError
50
51
52 __version__ = "18.9b0"
53 DEFAULT_LINE_LENGTH = 88
54 DEFAULT_EXCLUDES = (
55     r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
56 )
57 DEFAULT_INCLUDES = r"\.pyi?$"
58 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
59
60
61 # types
62 FileContent = str
63 Encoding = str
64 NewLine = str
65 Depth = int
66 NodeType = int
67 LeafID = int
68 Priority = int
69 Index = int
70 LN = Union[Leaf, Node]
71 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
72 Timestamp = float
73 FileSize = int
74 CacheInfo = Tuple[Timestamp, FileSize]
75 Cache = Dict[Path, CacheInfo]
76 out = partial(click.secho, bold=True, err=True)
77 err = partial(click.secho, fg="red", err=True)
78
79 pygram.initialize(CACHE_DIR)
80 syms = pygram.python_symbols
81
82
83 class NothingChanged(UserWarning):
84     """Raised when reformatted code is the same as source."""
85
86
87 class CannotSplit(Exception):
88     """A readable split that fits the allotted line length is impossible."""
89
90
91 class InvalidInput(ValueError):
92     """Raised when input source code fails all parse attempts."""
93
94
95 class WriteBack(Enum):
96     NO = 0
97     YES = 1
98     DIFF = 2
99     CHECK = 3
100
101     @classmethod
102     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
103         if check and not diff:
104             return cls.CHECK
105
106         return cls.DIFF if diff else cls.YES
107
108
109 class Changed(Enum):
110     NO = 0
111     CACHED = 1
112     YES = 2
113
114
115 class TargetVersion(Enum):
116     PYPY35 = 1
117     CPY27 = 2
118     CPY33 = 3
119     CPY34 = 4
120     CPY35 = 5
121     CPY36 = 6
122     CPY37 = 7
123     CPY38 = 8
124
125     def is_python2(self) -> bool:
126         return self is TargetVersion.CPY27
127
128
129 PY36_VERSIONS = {TargetVersion.CPY36, TargetVersion.CPY37, TargetVersion.CPY38}
130
131
132 class Feature(Enum):
133     # All string literals are unicode
134     UNICODE_LITERALS = 1
135     F_STRINGS = 2
136     NUMERIC_UNDERSCORES = 3
137     TRAILING_COMMA = 4
138
139
140 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
141     TargetVersion.CPY27: set(),
142     TargetVersion.PYPY35: {Feature.UNICODE_LITERALS, Feature.F_STRINGS},
143     TargetVersion.CPY33: {Feature.UNICODE_LITERALS},
144     TargetVersion.CPY34: {Feature.UNICODE_LITERALS},
145     TargetVersion.CPY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
146     TargetVersion.CPY36: {
147         Feature.UNICODE_LITERALS,
148         Feature.F_STRINGS,
149         Feature.NUMERIC_UNDERSCORES,
150         Feature.TRAILING_COMMA,
151     },
152     TargetVersion.CPY37: {
153         Feature.UNICODE_LITERALS,
154         Feature.F_STRINGS,
155         Feature.NUMERIC_UNDERSCORES,
156         Feature.TRAILING_COMMA,
157     },
158     TargetVersion.CPY38: {
159         Feature.UNICODE_LITERALS,
160         Feature.F_STRINGS,
161         Feature.NUMERIC_UNDERSCORES,
162         Feature.TRAILING_COMMA,
163     },
164 }
165
166
167 @dataclass
168 class FileMode:
169     target_versions: Set[TargetVersion] = Factory(set)
170     line_length: int = DEFAULT_LINE_LENGTH
171     string_normalization: bool = True
172     is_pyi: bool = False
173
174     def get_cache_key(self) -> str:
175         if self.target_versions:
176             version_str = ",".join(
177                 str(version.value)
178                 for version in sorted(self.target_versions, key=lambda v: v.value)
179             )
180         else:
181             version_str = "-"
182         parts = [
183             version_str,
184             str(self.line_length),
185             str(int(self.string_normalization)),
186             str(int(self.is_pyi)),
187         ]
188         return ".".join(parts)
189
190
191 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
192     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
193
194
195 def read_pyproject_toml(
196     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
197 ) -> Optional[str]:
198     """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
199
200     Returns the path to a successfully found and read configuration file, None
201     otherwise.
202     """
203     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
204     if not value:
205         root = find_project_root(ctx.params.get("src", ()))
206         path = root / "pyproject.toml"
207         if path.is_file():
208             value = str(path)
209         else:
210             return None
211
212     try:
213         pyproject_toml = toml.load(value)
214         config = pyproject_toml.get("tool", {}).get("black", {})
215     except (toml.TomlDecodeError, OSError) as e:
216         raise click.FileError(
217             filename=value, hint=f"Error reading configuration file: {e}"
218         )
219
220     if not config:
221         return None
222
223     if ctx.default_map is None:
224         ctx.default_map = {}
225     ctx.default_map.update(  # type: ignore  # bad types in .pyi
226         {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
227     )
228     return value
229
230
231 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
232 @click.option(
233     "-l",
234     "--line-length",
235     type=int,
236     default=DEFAULT_LINE_LENGTH,
237     help="How many characters per line to allow.",
238     show_default=True,
239 )
240 @click.option(
241     "-t",
242     "--target-version",
243     type=click.Choice([v.name.lower() for v in TargetVersion]),
244     callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
245     multiple=True,
246     help=(
247         "Python versions that should be supported by Black's output. [default: "
248         "per-file auto-detection]"
249     ),
250 )
251 @click.option(
252     "--py36",
253     is_flag=True,
254     help=(
255         "Allow using Python 3.6-only syntax on all input files.  This will put "
256         "trailing commas in function signatures and calls also after *args and "
257         "**kwargs.  [default: per-file auto-detection]"
258     ),
259 )
260 @click.option(
261     "--pyi",
262     is_flag=True,
263     help=(
264         "Format all input files like typing stubs regardless of file extension "
265         "(useful when piping source on standard input)."
266     ),
267 )
268 @click.option(
269     "-S",
270     "--skip-string-normalization",
271     is_flag=True,
272     help="Don't normalize string quotes or prefixes.",
273 )
274 @click.option(
275     "--check",
276     is_flag=True,
277     help=(
278         "Don't write the files back, just return the status.  Return code 0 "
279         "means nothing would change.  Return code 1 means some files would be "
280         "reformatted.  Return code 123 means there was an internal error."
281     ),
282 )
283 @click.option(
284     "--diff",
285     is_flag=True,
286     help="Don't write the files back, just output a diff for each file on stdout.",
287 )
288 @click.option(
289     "--fast/--safe",
290     is_flag=True,
291     help="If --fast given, skip temporary sanity checks. [default: --safe]",
292 )
293 @click.option(
294     "--include",
295     type=str,
296     default=DEFAULT_INCLUDES,
297     help=(
298         "A regular expression that matches files and directories that should be "
299         "included on recursive searches.  An empty value means all files are "
300         "included regardless of the name.  Use forward slashes for directories on "
301         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
302         "later."
303     ),
304     show_default=True,
305 )
306 @click.option(
307     "--exclude",
308     type=str,
309     default=DEFAULT_EXCLUDES,
310     help=(
311         "A regular expression that matches files and directories that should be "
312         "excluded on recursive searches.  An empty value means no paths are excluded. "
313         "Use forward slashes for directories on all platforms (Windows, too).  "
314         "Exclusions are calculated first, inclusions later."
315     ),
316     show_default=True,
317 )
318 @click.option(
319     "-q",
320     "--quiet",
321     is_flag=True,
322     help=(
323         "Don't emit non-error messages to stderr. Errors are still emitted, "
324         "silence those with 2>/dev/null."
325     ),
326 )
327 @click.option(
328     "-v",
329     "--verbose",
330     is_flag=True,
331     help=(
332         "Also emit messages to stderr about files that were not changed or were "
333         "ignored due to --exclude=."
334     ),
335 )
336 @click.version_option(version=__version__)
337 @click.argument(
338     "src",
339     nargs=-1,
340     type=click.Path(
341         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
342     ),
343     is_eager=True,
344 )
345 @click.option(
346     "--config",
347     type=click.Path(
348         exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
349     ),
350     is_eager=True,
351     callback=read_pyproject_toml,
352     help="Read configuration from PATH.",
353 )
354 @click.pass_context
355 def main(
356     ctx: click.Context,
357     line_length: int,
358     target_version: List[TargetVersion],
359     check: bool,
360     diff: bool,
361     fast: bool,
362     pyi: bool,
363     py36: bool,
364     skip_string_normalization: bool,
365     quiet: bool,
366     verbose: bool,
367     include: str,
368     exclude: str,
369     src: Tuple[str],
370     config: Optional[str],
371 ) -> None:
372     """The uncompromising code formatter."""
373     write_back = WriteBack.from_configuration(check=check, diff=diff)
374     if target_version:
375         if py36:
376             err(f"Cannot use both --target-version and --py36")
377             ctx.exit(2)
378         else:
379             versions = set(target_version)
380     elif py36:
381         versions = PY36_VERSIONS
382     else:
383         # We'll autodetect later.
384         versions = set()
385     mode = FileMode(
386         target_versions=versions,
387         line_length=line_length,
388         is_pyi=pyi,
389         string_normalization=not skip_string_normalization,
390     )
391     if config and verbose:
392         out(f"Using configuration from {config}.", bold=False, fg="blue")
393     try:
394         include_regex = re_compile_maybe_verbose(include)
395     except re.error:
396         err(f"Invalid regular expression for include given: {include!r}")
397         ctx.exit(2)
398     try:
399         exclude_regex = re_compile_maybe_verbose(exclude)
400     except re.error:
401         err(f"Invalid regular expression for exclude given: {exclude!r}")
402         ctx.exit(2)
403     report = Report(check=check, quiet=quiet, verbose=verbose)
404     root = find_project_root(src)
405     sources: Set[Path] = set()
406     for s in src:
407         p = Path(s)
408         if p.is_dir():
409             sources.update(
410                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
411             )
412         elif p.is_file() or s == "-":
413             # if a file was explicitly given, we don't care about its extension
414             sources.add(p)
415         else:
416             err(f"invalid path: {s}")
417     if len(sources) == 0:
418         if verbose or not quiet:
419             out("No paths given. Nothing to do 😴")
420         ctx.exit(0)
421
422     if len(sources) == 1:
423         reformat_one(
424             src=sources.pop(),
425             fast=fast,
426             write_back=write_back,
427             mode=mode,
428             report=report,
429         )
430     else:
431         loop = asyncio.get_event_loop()
432         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
433         try:
434             loop.run_until_complete(
435                 schedule_formatting(
436                     sources=sources,
437                     fast=fast,
438                     write_back=write_back,
439                     mode=mode,
440                     report=report,
441                     loop=loop,
442                     executor=executor,
443                 )
444             )
445         finally:
446             shutdown(loop)
447     if verbose or not quiet:
448         bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
449         out(f"All done! {bang}")
450         click.secho(str(report), err=True)
451     ctx.exit(report.return_code)
452
453
454 def reformat_one(
455     src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
456 ) -> None:
457     """Reformat a single file under `src` without spawning child processes.
458
459     If `quiet` is True, non-error messages are not output. `line_length`,
460     `write_back`, `fast` and `pyi` options are passed to
461     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
462     """
463     try:
464         changed = Changed.NO
465         if not src.is_file() and str(src) == "-":
466             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
467                 changed = Changed.YES
468         else:
469             cache: Cache = {}
470             if write_back != WriteBack.DIFF:
471                 cache = read_cache(mode)
472                 res_src = src.resolve()
473                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
474                     changed = Changed.CACHED
475             if changed is not Changed.CACHED and format_file_in_place(
476                 src, fast=fast, write_back=write_back, mode=mode
477             ):
478                 changed = Changed.YES
479             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
480                 write_back is WriteBack.CHECK and changed is Changed.NO
481             ):
482                 write_cache(cache, [src], mode)
483         report.done(src, changed)
484     except Exception as exc:
485         report.failed(src, str(exc))
486
487
488 async def schedule_formatting(
489     sources: Set[Path],
490     fast: bool,
491     write_back: WriteBack,
492     mode: FileMode,
493     report: "Report",
494     loop: BaseEventLoop,
495     executor: Executor,
496 ) -> None:
497     """Run formatting of `sources` in parallel using the provided `executor`.
498
499     (Use ProcessPoolExecutors for actual parallelism.)
500
501     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
502     :func:`format_file_in_place`.
503     """
504     cache: Cache = {}
505     if write_back != WriteBack.DIFF:
506         cache = read_cache(mode)
507         sources, cached = filter_cached(cache, sources)
508         for src in sorted(cached):
509             report.done(src, Changed.CACHED)
510     if not sources:
511         return
512
513     cancelled = []
514     sources_to_cache = []
515     lock = None
516     if write_back == WriteBack.DIFF:
517         # For diff output, we need locks to ensure we don't interleave output
518         # from different processes.
519         manager = Manager()
520         lock = manager.Lock()
521     tasks = {
522         loop.run_in_executor(
523             executor, format_file_in_place, src, fast, mode, write_back, lock
524         ): src
525         for src in sorted(sources)
526     }
527     pending: Iterable[asyncio.Task] = tasks.keys()
528     try:
529         loop.add_signal_handler(signal.SIGINT, cancel, pending)
530         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
531     except NotImplementedError:
532         # There are no good alternatives for these on Windows.
533         pass
534     while pending:
535         done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
536         for task in done:
537             src = tasks.pop(task)
538             if task.cancelled():
539                 cancelled.append(task)
540             elif task.exception():
541                 report.failed(src, str(task.exception()))
542             else:
543                 changed = Changed.YES if task.result() else Changed.NO
544                 # If the file was written back or was successfully checked as
545                 # well-formatted, store this information in the cache.
546                 if write_back is WriteBack.YES or (
547                     write_back is WriteBack.CHECK and changed is Changed.NO
548                 ):
549                     sources_to_cache.append(src)
550                 report.done(src, changed)
551     if cancelled:
552         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
553     if sources_to_cache:
554         write_cache(cache, sources_to_cache, mode)
555
556
557 def format_file_in_place(
558     src: Path,
559     fast: bool,
560     mode: FileMode,
561     write_back: WriteBack = WriteBack.NO,
562     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
563 ) -> bool:
564     """Format file under `src` path. Return True if changed.
565
566     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
567     code to the file.
568     `line_length` and `fast` options are passed to :func:`format_file_contents`.
569     """
570     if src.suffix == ".pyi":
571         mode = evolve(mode, is_pyi=True)
572
573     then = datetime.utcfromtimestamp(src.stat().st_mtime)
574     with open(src, "rb") as buf:
575         src_contents, encoding, newline = decode_bytes(buf.read())
576     try:
577         dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
578     except NothingChanged:
579         return False
580
581     if write_back == write_back.YES:
582         with open(src, "w", encoding=encoding, newline=newline) as f:
583             f.write(dst_contents)
584     elif write_back == write_back.DIFF:
585         now = datetime.utcnow()
586         src_name = f"{src}\t{then} +0000"
587         dst_name = f"{src}\t{now} +0000"
588         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
589         if lock:
590             lock.acquire()
591         try:
592             f = io.TextIOWrapper(
593                 sys.stdout.buffer,
594                 encoding=encoding,
595                 newline=newline,
596                 write_through=True,
597             )
598             f.write(diff_contents)
599             f.detach()
600         finally:
601             if lock:
602                 lock.release()
603     return True
604
605
606 def format_stdin_to_stdout(
607     fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
608 ) -> bool:
609     """Format file on stdin. Return True if changed.
610
611     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
612     write a diff to stdout. The `mode` argument is passed to
613     :func:`format_file_contents`.
614     """
615     then = datetime.utcnow()
616     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
617     dst = src
618     try:
619         dst = format_file_contents(src, fast=fast, mode=mode)
620         return True
621
622     except NothingChanged:
623         return False
624
625     finally:
626         f = io.TextIOWrapper(
627             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
628         )
629         if write_back == WriteBack.YES:
630             f.write(dst)
631         elif write_back == WriteBack.DIFF:
632             now = datetime.utcnow()
633             src_name = f"STDIN\t{then} +0000"
634             dst_name = f"STDOUT\t{now} +0000"
635             f.write(diff(src, dst, src_name, dst_name))
636         f.detach()
637
638
639 def format_file_contents(
640     src_contents: str, *, fast: bool, mode: FileMode
641 ) -> FileContent:
642     """Reformat contents a file and return new contents.
643
644     If `fast` is False, additionally confirm that the reformatted code is
645     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
646     `line_length` is passed to :func:`format_str`.
647     """
648     if src_contents.strip() == "":
649         raise NothingChanged
650
651     dst_contents = format_str(src_contents, mode=mode)
652     if src_contents == dst_contents:
653         raise NothingChanged
654
655     if not fast:
656         assert_equivalent(src_contents, dst_contents)
657         assert_stable(src_contents, dst_contents, mode=mode)
658     return dst_contents
659
660
661 def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
662     """Reformat a string and return new contents.
663
664     `line_length` determines how many characters per line are allowed.
665     """
666     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
667     dst_contents = ""
668     future_imports = get_future_imports(src_node)
669     if mode.target_versions:
670         versions = mode.target_versions
671     else:
672         versions = detect_target_versions(src_node)
673     normalize_fmt_off(src_node)
674     lines = LineGenerator(
675         remove_u_prefix="unicode_literals" in future_imports
676         or supports_feature(versions, Feature.UNICODE_LITERALS),
677         is_pyi=mode.is_pyi,
678         normalize_strings=mode.string_normalization,
679     )
680     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
681     empty_line = Line()
682     after = 0
683     for current_line in lines.visit(src_node):
684         for _ in range(after):
685             dst_contents += str(empty_line)
686         before, after = elt.maybe_empty_lines(current_line)
687         for _ in range(before):
688             dst_contents += str(empty_line)
689         for line in split_line(
690             current_line,
691             line_length=mode.line_length,
692             supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
693         ):
694             dst_contents += str(line)
695     return dst_contents
696
697
698 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
699     """Return a tuple of (decoded_contents, encoding, newline).
700
701     `newline` is either CRLF or LF but `decoded_contents` is decoded with
702     universal newlines (i.e. only contains LF).
703     """
704     srcbuf = io.BytesIO(src)
705     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
706     if not lines:
707         return "", encoding, "\n"
708
709     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
710     srcbuf.seek(0)
711     with io.TextIOWrapper(srcbuf, encoding) as tiow:
712         return tiow.read(), encoding, newline
713
714
715 GRAMMARS = [
716     pygram.python_grammar_no_print_statement_no_exec_statement,
717     pygram.python_grammar_no_print_statement,
718     pygram.python_grammar,
719 ]
720
721
722 def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
723     if not target_versions:
724         return GRAMMARS
725     elif all(not version.is_python2() for version in target_versions):
726         # Python 2-compatible code, so don't try Python 3 grammar.
727         return [
728             pygram.python_grammar_no_print_statement_no_exec_statement,
729             pygram.python_grammar_no_print_statement,
730         ]
731     else:
732         return [pygram.python_grammar]
733
734
735 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
736     """Given a string with source, return the lib2to3 Node."""
737     if src_txt[-1:] != "\n":
738         src_txt += "\n"
739
740     for grammar in get_grammars(set(target_versions)):
741         drv = driver.Driver(grammar, pytree.convert)
742         try:
743             result = drv.parse_string(src_txt, True)
744             break
745
746         except ParseError as pe:
747             lineno, column = pe.context[1]
748             lines = src_txt.splitlines()
749             try:
750                 faulty_line = lines[lineno - 1]
751             except IndexError:
752                 faulty_line = "<line number missing in source>"
753             exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
754     else:
755         raise exc from None
756
757     if isinstance(result, Leaf):
758         result = Node(syms.file_input, [result])
759     return result
760
761
762 def lib2to3_unparse(node: Node) -> str:
763     """Given a lib2to3 node, return its string representation."""
764     code = str(node)
765     return code
766
767
768 T = TypeVar("T")
769
770
771 class Visitor(Generic[T]):
772     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
773
774     def visit(self, node: LN) -> Iterator[T]:
775         """Main method to visit `node` and its children.
776
777         It tries to find a `visit_*()` method for the given `node.type`, like
778         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
779         If no dedicated `visit_*()` method is found, chooses `visit_default()`
780         instead.
781
782         Then yields objects of type `T` from the selected visitor.
783         """
784         if node.type < 256:
785             name = token.tok_name[node.type]
786         else:
787             name = type_repr(node.type)
788         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
789
790     def visit_default(self, node: LN) -> Iterator[T]:
791         """Default `visit_*()` implementation. Recurses to children of `node`."""
792         if isinstance(node, Node):
793             for child in node.children:
794                 yield from self.visit(child)
795
796
797 @dataclass
798 class DebugVisitor(Visitor[T]):
799     tree_depth: int = 0
800
801     def visit_default(self, node: LN) -> Iterator[T]:
802         indent = " " * (2 * self.tree_depth)
803         if isinstance(node, Node):
804             _type = type_repr(node.type)
805             out(f"{indent}{_type}", fg="yellow")
806             self.tree_depth += 1
807             for child in node.children:
808                 yield from self.visit(child)
809
810             self.tree_depth -= 1
811             out(f"{indent}/{_type}", fg="yellow", bold=False)
812         else:
813             _type = token.tok_name.get(node.type, str(node.type))
814             out(f"{indent}{_type}", fg="blue", nl=False)
815             if node.prefix:
816                 # We don't have to handle prefixes for `Node` objects since
817                 # that delegates to the first child anyway.
818                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
819             out(f" {node.value!r}", fg="blue", bold=False)
820
821     @classmethod
822     def show(cls, code: Union[str, Leaf, Node]) -> None:
823         """Pretty-print the lib2to3 AST of a given string of `code`.
824
825         Convenience method for debugging.
826         """
827         v: DebugVisitor[None] = DebugVisitor()
828         if isinstance(code, str):
829             code = lib2to3_parse(code)
830         list(v.visit(code))
831
832
833 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
834 STATEMENT = {
835     syms.if_stmt,
836     syms.while_stmt,
837     syms.for_stmt,
838     syms.try_stmt,
839     syms.except_clause,
840     syms.with_stmt,
841     syms.funcdef,
842     syms.classdef,
843 }
844 STANDALONE_COMMENT = 153
845 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
846 LOGIC_OPERATORS = {"and", "or"}
847 COMPARATORS = {
848     token.LESS,
849     token.GREATER,
850     token.EQEQUAL,
851     token.NOTEQUAL,
852     token.LESSEQUAL,
853     token.GREATEREQUAL,
854 }
855 MATH_OPERATORS = {
856     token.VBAR,
857     token.CIRCUMFLEX,
858     token.AMPER,
859     token.LEFTSHIFT,
860     token.RIGHTSHIFT,
861     token.PLUS,
862     token.MINUS,
863     token.STAR,
864     token.SLASH,
865     token.DOUBLESLASH,
866     token.PERCENT,
867     token.AT,
868     token.TILDE,
869     token.DOUBLESTAR,
870 }
871 STARS = {token.STAR, token.DOUBLESTAR}
872 VARARGS_PARENTS = {
873     syms.arglist,
874     syms.argument,  # double star in arglist
875     syms.trailer,  # single argument to call
876     syms.typedargslist,
877     syms.varargslist,  # lambdas
878 }
879 UNPACKING_PARENTS = {
880     syms.atom,  # single element of a list or set literal
881     syms.dictsetmaker,
882     syms.listmaker,
883     syms.testlist_gexp,
884     syms.testlist_star_expr,
885 }
886 TEST_DESCENDANTS = {
887     syms.test,
888     syms.lambdef,
889     syms.or_test,
890     syms.and_test,
891     syms.not_test,
892     syms.comparison,
893     syms.star_expr,
894     syms.expr,
895     syms.xor_expr,
896     syms.and_expr,
897     syms.shift_expr,
898     syms.arith_expr,
899     syms.trailer,
900     syms.term,
901     syms.power,
902 }
903 ASSIGNMENTS = {
904     "=",
905     "+=",
906     "-=",
907     "*=",
908     "@=",
909     "/=",
910     "%=",
911     "&=",
912     "|=",
913     "^=",
914     "<<=",
915     ">>=",
916     "**=",
917     "//=",
918 }
919 COMPREHENSION_PRIORITY = 20
920 COMMA_PRIORITY = 18
921 TERNARY_PRIORITY = 16
922 LOGIC_PRIORITY = 14
923 STRING_PRIORITY = 12
924 COMPARATOR_PRIORITY = 10
925 MATH_PRIORITIES = {
926     token.VBAR: 9,
927     token.CIRCUMFLEX: 8,
928     token.AMPER: 7,
929     token.LEFTSHIFT: 6,
930     token.RIGHTSHIFT: 6,
931     token.PLUS: 5,
932     token.MINUS: 5,
933     token.STAR: 4,
934     token.SLASH: 4,
935     token.DOUBLESLASH: 4,
936     token.PERCENT: 4,
937     token.AT: 4,
938     token.TILDE: 3,
939     token.DOUBLESTAR: 2,
940 }
941 DOT_PRIORITY = 1
942
943
944 @dataclass
945 class BracketTracker:
946     """Keeps track of brackets on a line."""
947
948     depth: int = 0
949     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
950     delimiters: Dict[LeafID, Priority] = Factory(dict)
951     previous: Optional[Leaf] = None
952     _for_loop_depths: List[int] = Factory(list)
953     _lambda_argument_depths: List[int] = Factory(list)
954
955     def mark(self, leaf: Leaf) -> None:
956         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
957
958         All leaves receive an int `bracket_depth` field that stores how deep
959         within brackets a given leaf is. 0 means there are no enclosing brackets
960         that started on this line.
961
962         If a leaf is itself a closing bracket, it receives an `opening_bracket`
963         field that it forms a pair with. This is a one-directional link to
964         avoid reference cycles.
965
966         If a leaf is a delimiter (a token on which Black can split the line if
967         needed) and it's on depth 0, its `id()` is stored in the tracker's
968         `delimiters` field.
969         """
970         if leaf.type == token.COMMENT:
971             return
972
973         self.maybe_decrement_after_for_loop_variable(leaf)
974         self.maybe_decrement_after_lambda_arguments(leaf)
975         if leaf.type in CLOSING_BRACKETS:
976             self.depth -= 1
977             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
978             leaf.opening_bracket = opening_bracket
979         leaf.bracket_depth = self.depth
980         if self.depth == 0:
981             delim = is_split_before_delimiter(leaf, self.previous)
982             if delim and self.previous is not None:
983                 self.delimiters[id(self.previous)] = delim
984             else:
985                 delim = is_split_after_delimiter(leaf, self.previous)
986                 if delim:
987                     self.delimiters[id(leaf)] = delim
988         if leaf.type in OPENING_BRACKETS:
989             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
990             self.depth += 1
991         self.previous = leaf
992         self.maybe_increment_lambda_arguments(leaf)
993         self.maybe_increment_for_loop_variable(leaf)
994
995     def any_open_brackets(self) -> bool:
996         """Return True if there is an yet unmatched open bracket on the line."""
997         return bool(self.bracket_match)
998
999     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
1000         """Return the highest priority of a delimiter found on the line.
1001
1002         Values are consistent with what `is_split_*_delimiter()` return.
1003         Raises ValueError on no delimiters.
1004         """
1005         return max(v for k, v in self.delimiters.items() if k not in exclude)
1006
1007     def delimiter_count_with_priority(self, priority: int = 0) -> int:
1008         """Return the number of delimiters with the given `priority`.
1009
1010         If no `priority` is passed, defaults to max priority on the line.
1011         """
1012         if not self.delimiters:
1013             return 0
1014
1015         priority = priority or self.max_delimiter_priority()
1016         return sum(1 for p in self.delimiters.values() if p == priority)
1017
1018     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1019         """In a for loop, or comprehension, the variables are often unpacks.
1020
1021         To avoid splitting on the comma in this situation, increase the depth of
1022         tokens between `for` and `in`.
1023         """
1024         if leaf.type == token.NAME and leaf.value == "for":
1025             self.depth += 1
1026             self._for_loop_depths.append(self.depth)
1027             return True
1028
1029         return False
1030
1031     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
1032         """See `maybe_increment_for_loop_variable` above for explanation."""
1033         if (
1034             self._for_loop_depths
1035             and self._for_loop_depths[-1] == self.depth
1036             and leaf.type == token.NAME
1037             and leaf.value == "in"
1038         ):
1039             self.depth -= 1
1040             self._for_loop_depths.pop()
1041             return True
1042
1043         return False
1044
1045     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
1046         """In a lambda expression, there might be more than one argument.
1047
1048         To avoid splitting on the comma in this situation, increase the depth of
1049         tokens between `lambda` and `:`.
1050         """
1051         if leaf.type == token.NAME and leaf.value == "lambda":
1052             self.depth += 1
1053             self._lambda_argument_depths.append(self.depth)
1054             return True
1055
1056         return False
1057
1058     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
1059         """See `maybe_increment_lambda_arguments` above for explanation."""
1060         if (
1061             self._lambda_argument_depths
1062             and self._lambda_argument_depths[-1] == self.depth
1063             and leaf.type == token.COLON
1064         ):
1065             self.depth -= 1
1066             self._lambda_argument_depths.pop()
1067             return True
1068
1069         return False
1070
1071     def get_open_lsqb(self) -> Optional[Leaf]:
1072         """Return the most recent opening square bracket (if any)."""
1073         return self.bracket_match.get((self.depth - 1, token.RSQB))
1074
1075
1076 @dataclass
1077 class Line:
1078     """Holds leaves and comments. Can be printed with `str(line)`."""
1079
1080     depth: int = 0
1081     leaves: List[Leaf] = Factory(list)
1082     # The LeafID keys of comments must remain ordered by the corresponding leaf's index
1083     # in leaves
1084     comments: Dict[LeafID, List[Leaf]] = Factory(dict)
1085     bracket_tracker: BracketTracker = Factory(BracketTracker)
1086     inside_brackets: bool = False
1087     should_explode: bool = False
1088
1089     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
1090         """Add a new `leaf` to the end of the line.
1091
1092         Unless `preformatted` is True, the `leaf` will receive a new consistent
1093         whitespace prefix and metadata applied by :class:`BracketTracker`.
1094         Trailing commas are maybe removed, unpacked for loop variables are
1095         demoted from being delimiters.
1096
1097         Inline comments are put aside.
1098         """
1099         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
1100         if not has_value:
1101             return
1102
1103         if token.COLON == leaf.type and self.is_class_paren_empty:
1104             del self.leaves[-2:]
1105         if self.leaves and not preformatted:
1106             # Note: at this point leaf.prefix should be empty except for
1107             # imports, for which we only preserve newlines.
1108             leaf.prefix += whitespace(
1109                 leaf, complex_subscript=self.is_complex_subscript(leaf)
1110             )
1111         if self.inside_brackets or not preformatted:
1112             self.bracket_tracker.mark(leaf)
1113             self.maybe_remove_trailing_comma(leaf)
1114         if not self.append_comment(leaf):
1115             self.leaves.append(leaf)
1116
1117     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
1118         """Like :func:`append()` but disallow invalid standalone comment structure.
1119
1120         Raises ValueError when any `leaf` is appended after a standalone comment
1121         or when a standalone comment is not the first leaf on the line.
1122         """
1123         if self.bracket_tracker.depth == 0:
1124             if self.is_comment:
1125                 raise ValueError("cannot append to standalone comments")
1126
1127             if self.leaves and leaf.type == STANDALONE_COMMENT:
1128                 raise ValueError(
1129                     "cannot append standalone comments to a populated line"
1130                 )
1131
1132         self.append(leaf, preformatted=preformatted)
1133
1134     @property
1135     def is_comment(self) -> bool:
1136         """Is this line a standalone comment?"""
1137         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1138
1139     @property
1140     def is_decorator(self) -> bool:
1141         """Is this line a decorator?"""
1142         return bool(self) and self.leaves[0].type == token.AT
1143
1144     @property
1145     def is_import(self) -> bool:
1146         """Is this an import line?"""
1147         return bool(self) and is_import(self.leaves[0])
1148
1149     @property
1150     def is_class(self) -> bool:
1151         """Is this line a class definition?"""
1152         return (
1153             bool(self)
1154             and self.leaves[0].type == token.NAME
1155             and self.leaves[0].value == "class"
1156         )
1157
1158     @property
1159     def is_stub_class(self) -> bool:
1160         """Is this line a class definition with a body consisting only of "..."?"""
1161         return self.is_class and self.leaves[-3:] == [
1162             Leaf(token.DOT, ".") for _ in range(3)
1163         ]
1164
1165     @property
1166     def is_def(self) -> bool:
1167         """Is this a function definition? (Also returns True for async defs.)"""
1168         try:
1169             first_leaf = self.leaves[0]
1170         except IndexError:
1171             return False
1172
1173         try:
1174             second_leaf: Optional[Leaf] = self.leaves[1]
1175         except IndexError:
1176             second_leaf = None
1177         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1178             first_leaf.type == token.ASYNC
1179             and second_leaf is not None
1180             and second_leaf.type == token.NAME
1181             and second_leaf.value == "def"
1182         )
1183
1184     @property
1185     def is_class_paren_empty(self) -> bool:
1186         """Is this a class with no base classes but using parentheses?
1187
1188         Those are unnecessary and should be removed.
1189         """
1190         return (
1191             bool(self)
1192             and len(self.leaves) == 4
1193             and self.is_class
1194             and self.leaves[2].type == token.LPAR
1195             and self.leaves[2].value == "("
1196             and self.leaves[3].type == token.RPAR
1197             and self.leaves[3].value == ")"
1198         )
1199
1200     @property
1201     def is_triple_quoted_string(self) -> bool:
1202         """Is the line a triple quoted string?"""
1203         return (
1204             bool(self)
1205             and self.leaves[0].type == token.STRING
1206             and self.leaves[0].value.startswith(('"""', "'''"))
1207         )
1208
1209     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1210         """If so, needs to be split before emitting."""
1211         for leaf in self.leaves:
1212             if leaf.type == STANDALONE_COMMENT:
1213                 if leaf.bracket_depth <= depth_limit:
1214                     return True
1215
1216         return False
1217
1218     def contains_multiline_strings(self) -> bool:
1219         for leaf in self.leaves:
1220             if is_multiline_string(leaf):
1221                 return True
1222
1223         return False
1224
1225     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1226         """Remove trailing comma if there is one and it's safe."""
1227         if not (
1228             self.leaves
1229             and self.leaves[-1].type == token.COMMA
1230             and closing.type in CLOSING_BRACKETS
1231         ):
1232             return False
1233
1234         if closing.type == token.RBRACE:
1235             self.remove_trailing_comma()
1236             return True
1237
1238         if closing.type == token.RSQB:
1239             comma = self.leaves[-1]
1240             if comma.parent and comma.parent.type == syms.listmaker:
1241                 self.remove_trailing_comma()
1242                 return True
1243
1244         # For parens let's check if it's safe to remove the comma.
1245         # Imports are always safe.
1246         if self.is_import:
1247             self.remove_trailing_comma()
1248             return True
1249
1250         # Otherwise, if the trailing one is the only one, we might mistakenly
1251         # change a tuple into a different type by removing the comma.
1252         depth = closing.bracket_depth + 1
1253         commas = 0
1254         opening = closing.opening_bracket
1255         for _opening_index, leaf in enumerate(self.leaves):
1256             if leaf is opening:
1257                 break
1258
1259         else:
1260             return False
1261
1262         for leaf in self.leaves[_opening_index + 1 :]:
1263             if leaf is closing:
1264                 break
1265
1266             bracket_depth = leaf.bracket_depth
1267             if bracket_depth == depth and leaf.type == token.COMMA:
1268                 commas += 1
1269                 if leaf.parent and leaf.parent.type == syms.arglist:
1270                     commas += 1
1271                     break
1272
1273         if commas > 1:
1274             self.remove_trailing_comma()
1275             return True
1276
1277         return False
1278
1279     def append_comment(self, comment: Leaf) -> bool:
1280         """Add an inline or standalone comment to the line."""
1281         if (
1282             comment.type == STANDALONE_COMMENT
1283             and self.bracket_tracker.any_open_brackets()
1284         ):
1285             comment.prefix = ""
1286             return False
1287
1288         if comment.type != token.COMMENT:
1289             return False
1290
1291         if not self.leaves:
1292             comment.type = STANDALONE_COMMENT
1293             comment.prefix = ""
1294             return False
1295
1296         else:
1297             leaf_id = id(self.leaves[-1])
1298             if leaf_id not in self.comments:
1299                 self.comments[leaf_id] = [comment]
1300             else:
1301                 self.comments[leaf_id].append(comment)
1302             return True
1303
1304     def comments_after(self, leaf: Leaf) -> List[Leaf]:
1305         """Generate comments that should appear directly after `leaf`."""
1306         return self.comments.get(id(leaf), [])
1307
1308     def remove_trailing_comma(self) -> None:
1309         """Remove the trailing comma and moves the comments attached to it."""
1310         # Remember, the LeafID keys of self.comments are ordered by the
1311         # corresponding leaf's index in self.leaves
1312         # If id(self.leaves[-2]) is in self.comments, the order doesn't change.
1313         # Otherwise, we insert it into self.comments, and it becomes the last entry.
1314         # However, since we delete id(self.leaves[-1]) from self.comments, the invariant
1315         # is maintained
1316         self.comments.setdefault(id(self.leaves[-2]), []).extend(
1317             self.comments.get(id(self.leaves[-1]), [])
1318         )
1319         self.comments.pop(id(self.leaves[-1]), None)
1320         self.leaves.pop()
1321
1322     def is_complex_subscript(self, leaf: Leaf) -> bool:
1323         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1324         open_lsqb = self.bracket_tracker.get_open_lsqb()
1325         if open_lsqb is None:
1326             return False
1327
1328         subscript_start = open_lsqb.next_sibling
1329
1330         if isinstance(subscript_start, Node):
1331             if subscript_start.type == syms.listmaker:
1332                 return False
1333
1334             if subscript_start.type == syms.subscriptlist:
1335                 subscript_start = child_towards(subscript_start, leaf)
1336         return subscript_start is not None and any(
1337             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1338         )
1339
1340     def __str__(self) -> str:
1341         """Render the line."""
1342         if not self:
1343             return "\n"
1344
1345         indent = "    " * self.depth
1346         leaves = iter(self.leaves)
1347         first = next(leaves)
1348         res = f"{first.prefix}{indent}{first.value}"
1349         for leaf in leaves:
1350             res += str(leaf)
1351         for comment in itertools.chain.from_iterable(self.comments.values()):
1352             res += str(comment)
1353         return res + "\n"
1354
1355     def __bool__(self) -> bool:
1356         """Return True if the line has leaves or comments."""
1357         return bool(self.leaves or self.comments)
1358
1359
1360 @dataclass
1361 class EmptyLineTracker:
1362     """Provides a stateful method that returns the number of potential extra
1363     empty lines needed before and after the currently processed line.
1364
1365     Note: this tracker works on lines that haven't been split yet.  It assumes
1366     the prefix of the first leaf consists of optional newlines.  Those newlines
1367     are consumed by `maybe_empty_lines()` and included in the computation.
1368     """
1369
1370     is_pyi: bool = False
1371     previous_line: Optional[Line] = None
1372     previous_after: int = 0
1373     previous_defs: List[int] = Factory(list)
1374
1375     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1376         """Return the number of extra empty lines before and after the `current_line`.
1377
1378         This is for separating `def`, `async def` and `class` with extra empty
1379         lines (two on module-level).
1380         """
1381         before, after = self._maybe_empty_lines(current_line)
1382         before -= self.previous_after
1383         self.previous_after = after
1384         self.previous_line = current_line
1385         return before, after
1386
1387     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1388         max_allowed = 1
1389         if current_line.depth == 0:
1390             max_allowed = 1 if self.is_pyi else 2
1391         if current_line.leaves:
1392             # Consume the first leaf's extra newlines.
1393             first_leaf = current_line.leaves[0]
1394             before = first_leaf.prefix.count("\n")
1395             before = min(before, max_allowed)
1396             first_leaf.prefix = ""
1397         else:
1398             before = 0
1399         depth = current_line.depth
1400         while self.previous_defs and self.previous_defs[-1] >= depth:
1401             self.previous_defs.pop()
1402             if self.is_pyi:
1403                 before = 0 if depth else 1
1404             else:
1405                 before = 1 if depth else 2
1406         if current_line.is_decorator or current_line.is_def or current_line.is_class:
1407             return self._maybe_empty_lines_for_class_or_def(current_line, before)
1408
1409         if (
1410             self.previous_line
1411             and self.previous_line.is_import
1412             and not current_line.is_import
1413             and depth == self.previous_line.depth
1414         ):
1415             return (before or 1), 0
1416
1417         if (
1418             self.previous_line
1419             and self.previous_line.is_class
1420             and current_line.is_triple_quoted_string
1421         ):
1422             return before, 1
1423
1424         return before, 0
1425
1426     def _maybe_empty_lines_for_class_or_def(
1427         self, current_line: Line, before: int
1428     ) -> Tuple[int, int]:
1429         if not current_line.is_decorator:
1430             self.previous_defs.append(current_line.depth)
1431         if self.previous_line is None:
1432             # Don't insert empty lines before the first line in the file.
1433             return 0, 0
1434
1435         if self.previous_line.is_decorator:
1436             return 0, 0
1437
1438         if self.previous_line.depth < current_line.depth and (
1439             self.previous_line.is_class or self.previous_line.is_def
1440         ):
1441             return 0, 0
1442
1443         if (
1444             self.previous_line.is_comment
1445             and self.previous_line.depth == current_line.depth
1446             and before == 0
1447         ):
1448             return 0, 0
1449
1450         if self.is_pyi:
1451             if self.previous_line.depth > current_line.depth:
1452                 newlines = 1
1453             elif current_line.is_class or self.previous_line.is_class:
1454                 if current_line.is_stub_class and self.previous_line.is_stub_class:
1455                     # No blank line between classes with an empty body
1456                     newlines = 0
1457                 else:
1458                     newlines = 1
1459             elif current_line.is_def and not self.previous_line.is_def:
1460                 # Blank line between a block of functions and a block of non-functions
1461                 newlines = 1
1462             else:
1463                 newlines = 0
1464         else:
1465             newlines = 2
1466         if current_line.depth and newlines:
1467             newlines -= 1
1468         return newlines, 0
1469
1470
1471 @dataclass
1472 class LineGenerator(Visitor[Line]):
1473     """Generates reformatted Line objects.  Empty lines are not emitted.
1474
1475     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1476     in ways that will no longer stringify to valid Python code on the tree.
1477     """
1478
1479     is_pyi: bool = False
1480     normalize_strings: bool = True
1481     current_line: Line = Factory(Line)
1482     remove_u_prefix: bool = False
1483
1484     def line(self, indent: int = 0) -> Iterator[Line]:
1485         """Generate a line.
1486
1487         If the line is empty, only emit if it makes sense.
1488         If the line is too long, split it first and then generate.
1489
1490         If any lines were generated, set up a new current_line.
1491         """
1492         if not self.current_line:
1493             self.current_line.depth += indent
1494             return  # Line is empty, don't emit. Creating a new one unnecessary.
1495
1496         complete_line = self.current_line
1497         self.current_line = Line(depth=complete_line.depth + indent)
1498         yield complete_line
1499
1500     def visit_default(self, node: LN) -> Iterator[Line]:
1501         """Default `visit_*()` implementation. Recurses to children of `node`."""
1502         if isinstance(node, Leaf):
1503             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1504             for comment in generate_comments(node):
1505                 if any_open_brackets:
1506                     # any comment within brackets is subject to splitting
1507                     self.current_line.append(comment)
1508                 elif comment.type == token.COMMENT:
1509                     # regular trailing comment
1510                     self.current_line.append(comment)
1511                     yield from self.line()
1512
1513                 else:
1514                     # regular standalone comment
1515                     yield from self.line()
1516
1517                     self.current_line.append(comment)
1518                     yield from self.line()
1519
1520             normalize_prefix(node, inside_brackets=any_open_brackets)
1521             if self.normalize_strings and node.type == token.STRING:
1522                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1523                 normalize_string_quotes(node)
1524             if node.type == token.NUMBER:
1525                 normalize_numeric_literal(node)
1526             if node.type not in WHITESPACE:
1527                 self.current_line.append(node)
1528         yield from super().visit_default(node)
1529
1530     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1531         """Increase indentation level, maybe yield a line."""
1532         # In blib2to3 INDENT never holds comments.
1533         yield from self.line(+1)
1534         yield from self.visit_default(node)
1535
1536     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1537         """Decrease indentation level, maybe yield a line."""
1538         # The current line might still wait for trailing comments.  At DEDENT time
1539         # there won't be any (they would be prefixes on the preceding NEWLINE).
1540         # Emit the line then.
1541         yield from self.line()
1542
1543         # While DEDENT has no value, its prefix may contain standalone comments
1544         # that belong to the current indentation level.  Get 'em.
1545         yield from self.visit_default(node)
1546
1547         # Finally, emit the dedent.
1548         yield from self.line(-1)
1549
1550     def visit_stmt(
1551         self, node: Node, keywords: Set[str], parens: Set[str]
1552     ) -> Iterator[Line]:
1553         """Visit a statement.
1554
1555         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1556         `def`, `with`, `class`, `assert` and assignments.
1557
1558         The relevant Python language `keywords` for a given statement will be
1559         NAME leaves within it. This methods puts those on a separate line.
1560
1561         `parens` holds a set of string leaf values immediately after which
1562         invisible parens should be put.
1563         """
1564         normalize_invisible_parens(node, parens_after=parens)
1565         for child in node.children:
1566             if child.type == token.NAME and child.value in keywords:  # type: ignore
1567                 yield from self.line()
1568
1569             yield from self.visit(child)
1570
1571     def visit_suite(self, node: Node) -> Iterator[Line]:
1572         """Visit a suite."""
1573         if self.is_pyi and is_stub_suite(node):
1574             yield from self.visit(node.children[2])
1575         else:
1576             yield from self.visit_default(node)
1577
1578     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1579         """Visit a statement without nested statements."""
1580         is_suite_like = node.parent and node.parent.type in STATEMENT
1581         if is_suite_like:
1582             if self.is_pyi and is_stub_body(node):
1583                 yield from self.visit_default(node)
1584             else:
1585                 yield from self.line(+1)
1586                 yield from self.visit_default(node)
1587                 yield from self.line(-1)
1588
1589         else:
1590             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1591                 yield from self.line()
1592             yield from self.visit_default(node)
1593
1594     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1595         """Visit `async def`, `async for`, `async with`."""
1596         yield from self.line()
1597
1598         children = iter(node.children)
1599         for child in children:
1600             yield from self.visit(child)
1601
1602             if child.type == token.ASYNC:
1603                 break
1604
1605         internal_stmt = next(children)
1606         for child in internal_stmt.children:
1607             yield from self.visit(child)
1608
1609     def visit_decorators(self, node: Node) -> Iterator[Line]:
1610         """Visit decorators."""
1611         for child in node.children:
1612             yield from self.line()
1613             yield from self.visit(child)
1614
1615     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1616         """Remove a semicolon and put the other statement on a separate line."""
1617         yield from self.line()
1618
1619     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1620         """End of file. Process outstanding comments and end with a newline."""
1621         yield from self.visit_default(leaf)
1622         yield from self.line()
1623
1624     def visit_STANDALONE_COMMENT(self, leaf: Leaf) -> Iterator[Line]:
1625         if not self.current_line.bracket_tracker.any_open_brackets():
1626             yield from self.line()
1627         yield from self.visit_default(leaf)
1628
1629     def __attrs_post_init__(self) -> None:
1630         """You are in a twisty little maze of passages."""
1631         v = self.visit_stmt
1632         Ø: Set[str] = set()
1633         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1634         self.visit_if_stmt = partial(
1635             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1636         )
1637         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1638         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1639         self.visit_try_stmt = partial(
1640             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1641         )
1642         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1643         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1644         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1645         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1646         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1647         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1648         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1649         self.visit_async_funcdef = self.visit_async_stmt
1650         self.visit_decorated = self.visit_decorators
1651
1652
1653 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1654 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1655 OPENING_BRACKETS = set(BRACKET.keys())
1656 CLOSING_BRACKETS = set(BRACKET.values())
1657 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1658 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1659
1660
1661 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
1662     """Return whitespace prefix if needed for the given `leaf`.
1663
1664     `complex_subscript` signals whether the given leaf is part of a subscription
1665     which has non-trivial arguments, like arithmetic expressions or function calls.
1666     """
1667     NO = ""
1668     SPACE = " "
1669     DOUBLESPACE = "  "
1670     t = leaf.type
1671     p = leaf.parent
1672     v = leaf.value
1673     if t in ALWAYS_NO_SPACE:
1674         return NO
1675
1676     if t == token.COMMENT:
1677         return DOUBLESPACE
1678
1679     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1680     if t == token.COLON and p.type not in {
1681         syms.subscript,
1682         syms.subscriptlist,
1683         syms.sliceop,
1684     }:
1685         return NO
1686
1687     prev = leaf.prev_sibling
1688     if not prev:
1689         prevp = preceding_leaf(p)
1690         if not prevp or prevp.type in OPENING_BRACKETS:
1691             return NO
1692
1693         if t == token.COLON:
1694             if prevp.type == token.COLON:
1695                 return NO
1696
1697             elif prevp.type != token.COMMA and not complex_subscript:
1698                 return NO
1699
1700             return SPACE
1701
1702         if prevp.type == token.EQUAL:
1703             if prevp.parent:
1704                 if prevp.parent.type in {
1705                     syms.arglist,
1706                     syms.argument,
1707                     syms.parameters,
1708                     syms.varargslist,
1709                 }:
1710                     return NO
1711
1712                 elif prevp.parent.type == syms.typedargslist:
1713                     # A bit hacky: if the equal sign has whitespace, it means we
1714                     # previously found it's a typed argument.  So, we're using
1715                     # that, too.
1716                     return prevp.prefix
1717
1718         elif prevp.type in STARS:
1719             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1720                 return NO
1721
1722         elif prevp.type == token.COLON:
1723             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1724                 return SPACE if complex_subscript else NO
1725
1726         elif (
1727             prevp.parent
1728             and prevp.parent.type == syms.factor
1729             and prevp.type in MATH_OPERATORS
1730         ):
1731             return NO
1732
1733         elif (
1734             prevp.type == token.RIGHTSHIFT
1735             and prevp.parent
1736             and prevp.parent.type == syms.shift_expr
1737             and prevp.prev_sibling
1738             and prevp.prev_sibling.type == token.NAME
1739             and prevp.prev_sibling.value == "print"  # type: ignore
1740         ):
1741             # Python 2 print chevron
1742             return NO
1743
1744     elif prev.type in OPENING_BRACKETS:
1745         return NO
1746
1747     if p.type in {syms.parameters, syms.arglist}:
1748         # untyped function signatures or calls
1749         if not prev or prev.type != token.COMMA:
1750             return NO
1751
1752     elif p.type == syms.varargslist:
1753         # lambdas
1754         if prev and prev.type != token.COMMA:
1755             return NO
1756
1757     elif p.type == syms.typedargslist:
1758         # typed function signatures
1759         if not prev:
1760             return NO
1761
1762         if t == token.EQUAL:
1763             if prev.type != syms.tname:
1764                 return NO
1765
1766         elif prev.type == token.EQUAL:
1767             # A bit hacky: if the equal sign has whitespace, it means we
1768             # previously found it's a typed argument.  So, we're using that, too.
1769             return prev.prefix
1770
1771         elif prev.type != token.COMMA:
1772             return NO
1773
1774     elif p.type == syms.tname:
1775         # type names
1776         if not prev:
1777             prevp = preceding_leaf(p)
1778             if not prevp or prevp.type != token.COMMA:
1779                 return NO
1780
1781     elif p.type == syms.trailer:
1782         # attributes and calls
1783         if t == token.LPAR or t == token.RPAR:
1784             return NO
1785
1786         if not prev:
1787             if t == token.DOT:
1788                 prevp = preceding_leaf(p)
1789                 if not prevp or prevp.type != token.NUMBER:
1790                     return NO
1791
1792             elif t == token.LSQB:
1793                 return NO
1794
1795         elif prev.type != token.COMMA:
1796             return NO
1797
1798     elif p.type == syms.argument:
1799         # single argument
1800         if t == token.EQUAL:
1801             return NO
1802
1803         if not prev:
1804             prevp = preceding_leaf(p)
1805             if not prevp or prevp.type == token.LPAR:
1806                 return NO
1807
1808         elif prev.type in {token.EQUAL} | STARS:
1809             return NO
1810
1811     elif p.type == syms.decorator:
1812         # decorators
1813         return NO
1814
1815     elif p.type == syms.dotted_name:
1816         if prev:
1817             return NO
1818
1819         prevp = preceding_leaf(p)
1820         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1821             return NO
1822
1823     elif p.type == syms.classdef:
1824         if t == token.LPAR:
1825             return NO
1826
1827         if prev and prev.type == token.LPAR:
1828             return NO
1829
1830     elif p.type in {syms.subscript, syms.sliceop}:
1831         # indexing
1832         if not prev:
1833             assert p.parent is not None, "subscripts are always parented"
1834             if p.parent.type == syms.subscriptlist:
1835                 return SPACE
1836
1837             return NO
1838
1839         elif not complex_subscript:
1840             return NO
1841
1842     elif p.type == syms.atom:
1843         if prev and t == token.DOT:
1844             # dots, but not the first one.
1845             return NO
1846
1847     elif p.type == syms.dictsetmaker:
1848         # dict unpacking
1849         if prev and prev.type == token.DOUBLESTAR:
1850             return NO
1851
1852     elif p.type in {syms.factor, syms.star_expr}:
1853         # unary ops
1854         if not prev:
1855             prevp = preceding_leaf(p)
1856             if not prevp or prevp.type in OPENING_BRACKETS:
1857                 return NO
1858
1859             prevp_parent = prevp.parent
1860             assert prevp_parent is not None
1861             if prevp.type == token.COLON and prevp_parent.type in {
1862                 syms.subscript,
1863                 syms.sliceop,
1864             }:
1865                 return NO
1866
1867             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1868                 return NO
1869
1870         elif t in {token.NAME, token.NUMBER, token.STRING}:
1871             return NO
1872
1873     elif p.type == syms.import_from:
1874         if t == token.DOT:
1875             if prev and prev.type == token.DOT:
1876                 return NO
1877
1878         elif t == token.NAME:
1879             if v == "import":
1880                 return SPACE
1881
1882             if prev and prev.type == token.DOT:
1883                 return NO
1884
1885     elif p.type == syms.sliceop:
1886         return NO
1887
1888     return SPACE
1889
1890
1891 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1892     """Return the first leaf that precedes `node`, if any."""
1893     while node:
1894         res = node.prev_sibling
1895         if res:
1896             if isinstance(res, Leaf):
1897                 return res
1898
1899             try:
1900                 return list(res.leaves())[-1]
1901
1902             except IndexError:
1903                 return None
1904
1905         node = node.parent
1906     return None
1907
1908
1909 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1910     """Return the child of `ancestor` that contains `descendant`."""
1911     node: Optional[LN] = descendant
1912     while node and node.parent != ancestor:
1913         node = node.parent
1914     return node
1915
1916
1917 def container_of(leaf: Leaf) -> LN:
1918     """Return `leaf` or one of its ancestors that is the topmost container of it.
1919
1920     By "container" we mean a node where `leaf` is the very first child.
1921     """
1922     same_prefix = leaf.prefix
1923     container: LN = leaf
1924     while container:
1925         parent = container.parent
1926         if parent is None:
1927             break
1928
1929         if parent.children[0].prefix != same_prefix:
1930             break
1931
1932         if parent.type == syms.file_input:
1933             break
1934
1935         if parent.prev_sibling is not None and parent.prev_sibling.type in BRACKETS:
1936             break
1937
1938         container = parent
1939     return container
1940
1941
1942 def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1943     """Return the priority of the `leaf` delimiter, given a line break after it.
1944
1945     The delimiter priorities returned here are from those delimiters that would
1946     cause a line break after themselves.
1947
1948     Higher numbers are higher priority.
1949     """
1950     if leaf.type == token.COMMA:
1951         return COMMA_PRIORITY
1952
1953     return 0
1954
1955
1956 def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
1957     """Return the priority of the `leaf` delimiter, given a line break before it.
1958
1959     The delimiter priorities returned here are from those delimiters that would
1960     cause a line break before themselves.
1961
1962     Higher numbers are higher priority.
1963     """
1964     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1965         # * and ** might also be MATH_OPERATORS but in this case they are not.
1966         # Don't treat them as a delimiter.
1967         return 0
1968
1969     if (
1970         leaf.type == token.DOT
1971         and leaf.parent
1972         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1973         and (previous is None or previous.type in CLOSING_BRACKETS)
1974     ):
1975         return DOT_PRIORITY
1976
1977     if (
1978         leaf.type in MATH_OPERATORS
1979         and leaf.parent
1980         and leaf.parent.type not in {syms.factor, syms.star_expr}
1981     ):
1982         return MATH_PRIORITIES[leaf.type]
1983
1984     if leaf.type in COMPARATORS:
1985         return COMPARATOR_PRIORITY
1986
1987     if (
1988         leaf.type == token.STRING
1989         and previous is not None
1990         and previous.type == token.STRING
1991     ):
1992         return STRING_PRIORITY
1993
1994     if leaf.type not in {token.NAME, token.ASYNC}:
1995         return 0
1996
1997     if (
1998         leaf.value == "for"
1999         and leaf.parent
2000         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
2001         or leaf.type == token.ASYNC
2002     ):
2003         if (
2004             not isinstance(leaf.prev_sibling, Leaf)
2005             or leaf.prev_sibling.value != "async"
2006         ):
2007             return COMPREHENSION_PRIORITY
2008
2009     if (
2010         leaf.value == "if"
2011         and leaf.parent
2012         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
2013     ):
2014         return COMPREHENSION_PRIORITY
2015
2016     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
2017         return TERNARY_PRIORITY
2018
2019     if leaf.value == "is":
2020         return COMPARATOR_PRIORITY
2021
2022     if (
2023         leaf.value == "in"
2024         and leaf.parent
2025         and leaf.parent.type in {syms.comp_op, syms.comparison}
2026         and not (
2027             previous is not None
2028             and previous.type == token.NAME
2029             and previous.value == "not"
2030         )
2031     ):
2032         return COMPARATOR_PRIORITY
2033
2034     if (
2035         leaf.value == "not"
2036         and leaf.parent
2037         and leaf.parent.type == syms.comp_op
2038         and not (
2039             previous is not None
2040             and previous.type == token.NAME
2041             and previous.value == "is"
2042         )
2043     ):
2044         return COMPARATOR_PRIORITY
2045
2046     if leaf.value in LOGIC_OPERATORS and leaf.parent:
2047         return LOGIC_PRIORITY
2048
2049     return 0
2050
2051
2052 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
2053 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
2054
2055
2056 def generate_comments(leaf: LN) -> Iterator[Leaf]:
2057     """Clean the prefix of the `leaf` and generate comments from it, if any.
2058
2059     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
2060     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
2061     move because it does away with modifying the grammar to include all the
2062     possible places in which comments can be placed.
2063
2064     The sad consequence for us though is that comments don't "belong" anywhere.
2065     This is why this function generates simple parentless Leaf objects for
2066     comments.  We simply don't know what the correct parent should be.
2067
2068     No matter though, we can live without this.  We really only need to
2069     differentiate between inline and standalone comments.  The latter don't
2070     share the line with any code.
2071
2072     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
2073     are emitted with a fake STANDALONE_COMMENT token identifier.
2074     """
2075     for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
2076         yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
2077
2078
2079 @dataclass
2080 class ProtoComment:
2081     """Describes a piece of syntax that is a comment.
2082
2083     It's not a :class:`blib2to3.pytree.Leaf` so that:
2084
2085     * it can be cached (`Leaf` objects should not be reused more than once as
2086       they store their lineno, column, prefix, and parent information);
2087     * `newlines` and `consumed` fields are kept separate from the `value`. This
2088       simplifies handling of special marker comments like ``# fmt: off/on``.
2089     """
2090
2091     type: int  # token.COMMENT or STANDALONE_COMMENT
2092     value: str  # content of the comment
2093     newlines: int  # how many newlines before the comment
2094     consumed: int  # how many characters of the original leaf's prefix did we consume
2095
2096
2097 @lru_cache(maxsize=4096)
2098 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
2099     """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
2100     result: List[ProtoComment] = []
2101     if not prefix or "#" not in prefix:
2102         return result
2103
2104     consumed = 0
2105     nlines = 0
2106     for index, line in enumerate(prefix.split("\n")):
2107         consumed += len(line) + 1  # adding the length of the split '\n'
2108         line = line.lstrip()
2109         if not line:
2110             nlines += 1
2111         if not line.startswith("#"):
2112             continue
2113
2114         if index == 0 and not is_endmarker:
2115             comment_type = token.COMMENT  # simple trailing comment
2116         else:
2117             comment_type = STANDALONE_COMMENT
2118         comment = make_comment(line)
2119         result.append(
2120             ProtoComment(
2121                 type=comment_type, value=comment, newlines=nlines, consumed=consumed
2122             )
2123         )
2124         nlines = 0
2125     return result
2126
2127
2128 def make_comment(content: str) -> str:
2129     """Return a consistently formatted comment from the given `content` string.
2130
2131     All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
2132     space between the hash sign and the content.
2133
2134     If `content` didn't start with a hash sign, one is provided.
2135     """
2136     content = content.rstrip()
2137     if not content:
2138         return "#"
2139
2140     if content[0] == "#":
2141         content = content[1:]
2142     if content and content[0] not in " !:#'%":
2143         content = " " + content
2144     return "#" + content
2145
2146
2147 def split_line(
2148     line: Line,
2149     line_length: int,
2150     inner: bool = False,
2151     supports_trailing_commas: bool = False,
2152 ) -> Iterator[Line]:
2153     """Split a `line` into potentially many lines.
2154
2155     They should fit in the allotted `line_length` but might not be able to.
2156     `inner` signifies that there were a pair of brackets somewhere around the
2157     current `line`, possibly transitively. This means we can fallback to splitting
2158     by delimiters if the LHS/RHS don't yield any results.
2159
2160     If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
2161     """
2162     if line.is_comment:
2163         yield line
2164         return
2165
2166     line_str = str(line).strip("\n")
2167
2168     # we don't want to split special comments like type annotations
2169     # https://github.com/python/typing/issues/186
2170     has_special_comment = False
2171     for leaf in line.leaves:
2172         for comment in line.comments_after(leaf):
2173             if leaf.type == token.COMMA and is_special_comment(comment):
2174                 has_special_comment = True
2175
2176     if (
2177         not has_special_comment
2178         and not line.should_explode
2179         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
2180     ):
2181         yield line
2182         return
2183
2184     split_funcs: List[SplitFunc]
2185     if line.is_def:
2186         split_funcs = [left_hand_split]
2187     else:
2188
2189         def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
2190             for omit in generate_trailers_to_omit(line, line_length):
2191                 lines = list(
2192                     right_hand_split(
2193                         line, line_length, supports_trailing_commas, omit=omit
2194                     )
2195                 )
2196                 if is_line_short_enough(lines[0], line_length=line_length):
2197                     yield from lines
2198                     return
2199
2200             # All splits failed, best effort split with no omits.
2201             # This mostly happens to multiline strings that are by definition
2202             # reported as not fitting a single line.
2203             yield from right_hand_split(line, supports_trailing_commas)
2204
2205         if line.inside_brackets:
2206             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2207         else:
2208             split_funcs = [rhs]
2209     for split_func in split_funcs:
2210         # We are accumulating lines in `result` because we might want to abort
2211         # mission and return the original line in the end, or attempt a different
2212         # split altogether.
2213         result: List[Line] = []
2214         try:
2215             for l in split_func(line, supports_trailing_commas):
2216                 if str(l).strip("\n") == line_str:
2217                     raise CannotSplit("Split function returned an unchanged result")
2218
2219                 result.extend(
2220                     split_line(
2221                         l,
2222                         line_length=line_length,
2223                         inner=True,
2224                         supports_trailing_commas=supports_trailing_commas,
2225                     )
2226                 )
2227         except CannotSplit:
2228             continue
2229
2230         else:
2231             yield from result
2232             break
2233
2234     else:
2235         yield line
2236
2237
2238 def left_hand_split(
2239     line: Line, supports_trailing_commas: bool = False
2240 ) -> Iterator[Line]:
2241     """Split line into many lines, starting with the first matching bracket pair.
2242
2243     Note: this usually looks weird, only use this for function definitions.
2244     Prefer RHS otherwise.  This is why this function is not symmetrical with
2245     :func:`right_hand_split` which also handles optional parentheses.
2246     """
2247     tail_leaves: List[Leaf] = []
2248     body_leaves: List[Leaf] = []
2249     head_leaves: List[Leaf] = []
2250     current_leaves = head_leaves
2251     matching_bracket = None
2252     for leaf in line.leaves:
2253         if (
2254             current_leaves is body_leaves
2255             and leaf.type in CLOSING_BRACKETS
2256             and leaf.opening_bracket is matching_bracket
2257         ):
2258             current_leaves = tail_leaves if body_leaves else head_leaves
2259         current_leaves.append(leaf)
2260         if current_leaves is head_leaves:
2261             if leaf.type in OPENING_BRACKETS:
2262                 matching_bracket = leaf
2263                 current_leaves = body_leaves
2264     if not matching_bracket:
2265         raise CannotSplit("No brackets found")
2266
2267     head = bracket_split_build_line(head_leaves, line, matching_bracket)
2268     body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
2269     tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
2270     bracket_split_succeeded_or_raise(head, body, tail)
2271     for result in (head, body, tail):
2272         if result:
2273             yield result
2274
2275
2276 def right_hand_split(
2277     line: Line,
2278     line_length: int,
2279     supports_trailing_commas: bool = False,
2280     omit: Collection[LeafID] = (),
2281 ) -> Iterator[Line]:
2282     """Split line into many lines, starting with the last matching bracket pair.
2283
2284     If the split was by optional parentheses, attempt splitting without them, too.
2285     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2286     this split.
2287
2288     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2289     """
2290     tail_leaves: List[Leaf] = []
2291     body_leaves: List[Leaf] = []
2292     head_leaves: List[Leaf] = []
2293     current_leaves = tail_leaves
2294     opening_bracket = None
2295     closing_bracket = None
2296     for leaf in reversed(line.leaves):
2297         if current_leaves is body_leaves:
2298             if leaf is opening_bracket:
2299                 current_leaves = head_leaves if body_leaves else tail_leaves
2300         current_leaves.append(leaf)
2301         if current_leaves is tail_leaves:
2302             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2303                 opening_bracket = leaf.opening_bracket
2304                 closing_bracket = leaf
2305                 current_leaves = body_leaves
2306     if not (opening_bracket and closing_bracket and head_leaves):
2307         # If there is no opening or closing_bracket that means the split failed and
2308         # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
2309         # the matching `opening_bracket` wasn't available on `line` anymore.
2310         raise CannotSplit("No brackets found")
2311
2312     tail_leaves.reverse()
2313     body_leaves.reverse()
2314     head_leaves.reverse()
2315     head = bracket_split_build_line(head_leaves, line, opening_bracket)
2316     body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
2317     tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
2318     bracket_split_succeeded_or_raise(head, body, tail)
2319     if (
2320         # the body shouldn't be exploded
2321         not body.should_explode
2322         # the opening bracket is an optional paren
2323         and opening_bracket.type == token.LPAR
2324         and not opening_bracket.value
2325         # the closing bracket is an optional paren
2326         and closing_bracket.type == token.RPAR
2327         and not closing_bracket.value
2328         # it's not an import (optional parens are the only thing we can split on
2329         # in this case; attempting a split without them is a waste of time)
2330         and not line.is_import
2331         # there are no standalone comments in the body
2332         and not body.contains_standalone_comments(0)
2333         # and we can actually remove the parens
2334         and can_omit_invisible_parens(body, line_length)
2335     ):
2336         omit = {id(closing_bracket), *omit}
2337         try:
2338             yield from right_hand_split(
2339                 line,
2340                 line_length,
2341                 supports_trailing_commas=supports_trailing_commas,
2342                 omit=omit,
2343             )
2344             return
2345
2346         except CannotSplit:
2347             if not (
2348                 can_be_split(body)
2349                 or is_line_short_enough(body, line_length=line_length)
2350             ):
2351                 raise CannotSplit(
2352                     "Splitting failed, body is still too long and can't be split."
2353                 )
2354
2355             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
2356                 raise CannotSplit(
2357                     "The current optional pair of parentheses is bound to fail to "
2358                     "satisfy the splitting algorithm because the head or the tail "
2359                     "contains multiline strings which by definition never fit one "
2360                     "line."
2361                 )
2362
2363     ensure_visible(opening_bracket)
2364     ensure_visible(closing_bracket)
2365     for result in (head, body, tail):
2366         if result:
2367             yield result
2368
2369
2370 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2371     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2372
2373     Do nothing otherwise.
2374
2375     A left- or right-hand split is based on a pair of brackets. Content before
2376     (and including) the opening bracket is left on one line, content inside the
2377     brackets is put on a separate line, and finally content starting with and
2378     following the closing bracket is put on a separate line.
2379
2380     Those are called `head`, `body`, and `tail`, respectively. If the split
2381     produced the same line (all content in `head`) or ended up with an empty `body`
2382     and the `tail` is just the closing bracket, then it's considered failed.
2383     """
2384     tail_len = len(str(tail).strip())
2385     if not body:
2386         if tail_len == 0:
2387             raise CannotSplit("Splitting brackets produced the same line")
2388
2389         elif tail_len < 3:
2390             raise CannotSplit(
2391                 f"Splitting brackets on an empty body to save "
2392                 f"{tail_len} characters is not worth it"
2393             )
2394
2395
2396 def bracket_split_build_line(
2397     leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
2398 ) -> Line:
2399     """Return a new line with given `leaves` and respective comments from `original`.
2400
2401     If `is_body` is True, the result line is one-indented inside brackets and as such
2402     has its first leaf's prefix normalized and a trailing comma added when expected.
2403     """
2404     result = Line(depth=original.depth)
2405     if is_body:
2406         result.inside_brackets = True
2407         result.depth += 1
2408         if leaves:
2409             # Since body is a new indent level, remove spurious leading whitespace.
2410             normalize_prefix(leaves[0], inside_brackets=True)
2411             # Ensure a trailing comma when expected.
2412             if original.is_import:
2413                 if leaves[-1].type != token.COMMA:
2414                     leaves.append(Leaf(token.COMMA, ","))
2415     # Populate the line
2416     for leaf in leaves:
2417         result.append(leaf, preformatted=True)
2418         for comment_after in original.comments_after(leaf):
2419             result.append(comment_after, preformatted=True)
2420     if is_body:
2421         result.should_explode = should_explode(result, opening_bracket)
2422     return result
2423
2424
2425 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2426     """Normalize prefix of the first leaf in every line returned by `split_func`.
2427
2428     This is a decorator over relevant split functions.
2429     """
2430
2431     @wraps(split_func)
2432     def split_wrapper(
2433         line: Line, supports_trailing_commas: bool = False
2434     ) -> Iterator[Line]:
2435         for l in split_func(line, supports_trailing_commas):
2436             normalize_prefix(l.leaves[0], inside_brackets=True)
2437             yield l
2438
2439     return split_wrapper
2440
2441
2442 @dont_increase_indentation
2443 def delimiter_split(
2444     line: Line, supports_trailing_commas: bool = False
2445 ) -> Iterator[Line]:
2446     """Split according to delimiters of the highest priority.
2447
2448     If `py36` is True, the split will add trailing commas also in function
2449     signatures that contain `*` and `**`.
2450     """
2451     try:
2452         last_leaf = line.leaves[-1]
2453     except IndexError:
2454         raise CannotSplit("Line empty")
2455
2456     bt = line.bracket_tracker
2457     try:
2458         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2459     except ValueError:
2460         raise CannotSplit("No delimiters found")
2461
2462     if delimiter_priority == DOT_PRIORITY:
2463         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2464             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2465
2466     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2467     lowest_depth = sys.maxsize
2468     trailing_comma_safe = True
2469
2470     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2471         """Append `leaf` to current line or to new line if appending impossible."""
2472         nonlocal current_line
2473         try:
2474             current_line.append_safe(leaf, preformatted=True)
2475         except ValueError:
2476             yield current_line
2477
2478             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2479             current_line.append(leaf)
2480
2481     for leaf in line.leaves:
2482         yield from append_to_line(leaf)
2483
2484         for comment_after in line.comments_after(leaf):
2485             yield from append_to_line(comment_after)
2486
2487         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2488         if leaf.bracket_depth == lowest_depth and is_vararg(
2489             leaf, within=VARARGS_PARENTS
2490         ):
2491             trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
2492         leaf_priority = bt.delimiters.get(id(leaf))
2493         if leaf_priority == delimiter_priority:
2494             yield current_line
2495
2496             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2497     if current_line:
2498         if (
2499             trailing_comma_safe
2500             and delimiter_priority == COMMA_PRIORITY
2501             and current_line.leaves[-1].type != token.COMMA
2502             and current_line.leaves[-1].type != STANDALONE_COMMENT
2503         ):
2504             current_line.append(Leaf(token.COMMA, ","))
2505         yield current_line
2506
2507
2508 @dont_increase_indentation
2509 def standalone_comment_split(
2510     line: Line, supports_trailing_commas: bool = False
2511 ) -> Iterator[Line]:
2512     """Split standalone comments from the rest of the line."""
2513     if not line.contains_standalone_comments(0):
2514         raise CannotSplit("Line does not have any standalone comments")
2515
2516     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2517
2518     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2519         """Append `leaf` to current line or to new line if appending impossible."""
2520         nonlocal current_line
2521         try:
2522             current_line.append_safe(leaf, preformatted=True)
2523         except ValueError:
2524             yield current_line
2525
2526             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2527             current_line.append(leaf)
2528
2529     for leaf in line.leaves:
2530         yield from append_to_line(leaf)
2531
2532         for comment_after in line.comments_after(leaf):
2533             yield from append_to_line(comment_after)
2534
2535     if current_line:
2536         yield current_line
2537
2538
2539 def is_import(leaf: Leaf) -> bool:
2540     """Return True if the given leaf starts an import statement."""
2541     p = leaf.parent
2542     t = leaf.type
2543     v = leaf.value
2544     return bool(
2545         t == token.NAME
2546         and (
2547             (v == "import" and p and p.type == syms.import_name)
2548             or (v == "from" and p and p.type == syms.import_from)
2549         )
2550     )
2551
2552
2553 def is_special_comment(leaf: Leaf) -> bool:
2554     """Return True if the given leaf is a special comment.
2555     Only returns true for type comments for now."""
2556     t = leaf.type
2557     v = leaf.value
2558     return bool(
2559         (t == token.COMMENT or t == STANDALONE_COMMENT) and (v.startswith("# type:"))
2560     )
2561
2562
2563 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2564     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2565     else.
2566
2567     Note: don't use backslashes for formatting or you'll lose your voting rights.
2568     """
2569     if not inside_brackets:
2570         spl = leaf.prefix.split("#")
2571         if "\\" not in spl[0]:
2572             nl_count = spl[-1].count("\n")
2573             if len(spl) > 1:
2574                 nl_count -= 1
2575             leaf.prefix = "\n" * nl_count
2576             return
2577
2578     leaf.prefix = ""
2579
2580
2581 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2582     """Make all string prefixes lowercase.
2583
2584     If remove_u_prefix is given, also removes any u prefix from the string.
2585
2586     Note: Mutates its argument.
2587     """
2588     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2589     assert match is not None, f"failed to match string {leaf.value!r}"
2590     orig_prefix = match.group(1)
2591     new_prefix = orig_prefix.lower()
2592     if remove_u_prefix:
2593         new_prefix = new_prefix.replace("u", "")
2594     leaf.value = f"{new_prefix}{match.group(2)}"
2595
2596
2597 def normalize_string_quotes(leaf: Leaf) -> None:
2598     """Prefer double quotes but only if it doesn't cause more escaping.
2599
2600     Adds or removes backslashes as appropriate. Doesn't parse and fix
2601     strings nested in f-strings (yet).
2602
2603     Note: Mutates its argument.
2604     """
2605     value = leaf.value.lstrip("furbFURB")
2606     if value[:3] == '"""':
2607         return
2608
2609     elif value[:3] == "'''":
2610         orig_quote = "'''"
2611         new_quote = '"""'
2612     elif value[0] == '"':
2613         orig_quote = '"'
2614         new_quote = "'"
2615     else:
2616         orig_quote = "'"
2617         new_quote = '"'
2618     first_quote_pos = leaf.value.find(orig_quote)
2619     if first_quote_pos == -1:
2620         return  # There's an internal error
2621
2622     prefix = leaf.value[:first_quote_pos]
2623     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2624     escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
2625     escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
2626     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2627     if "r" in prefix.casefold():
2628         if unescaped_new_quote.search(body):
2629             # There's at least one unescaped new_quote in this raw string
2630             # so converting is impossible
2631             return
2632
2633         # Do not introduce or remove backslashes in raw strings
2634         new_body = body
2635     else:
2636         # remove unnecessary escapes
2637         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2638         if body != new_body:
2639             # Consider the string without unnecessary escapes as the original
2640             body = new_body
2641             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2642         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2643         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2644     if "f" in prefix.casefold():
2645         matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
2646         for m in matches:
2647             if "\\" in str(m):
2648                 # Do not introduce backslashes in interpolated expressions
2649                 return
2650     if new_quote == '"""' and new_body[-1:] == '"':
2651         # edge case:
2652         new_body = new_body[:-1] + '\\"'
2653     orig_escape_count = body.count("\\")
2654     new_escape_count = new_body.count("\\")
2655     if new_escape_count > orig_escape_count:
2656         return  # Do not introduce more escaping
2657
2658     if new_escape_count == orig_escape_count and orig_quote == '"':
2659         return  # Prefer double quotes
2660
2661     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2662
2663
2664 def normalize_numeric_literal(leaf: Leaf) -> None:
2665     """Normalizes numeric (float, int, and complex) literals.
2666
2667     All letters used in the representation are normalized to lowercase (except
2668     in Python 2 long literals).
2669     """
2670     text = leaf.value.lower()
2671     if text.startswith(("0o", "0b")):
2672         # Leave octal and binary literals alone.
2673         pass
2674     elif text.startswith("0x"):
2675         # Change hex literals to upper case.
2676         before, after = text[:2], text[2:]
2677         text = f"{before}{after.upper()}"
2678     elif "e" in text:
2679         before, after = text.split("e")
2680         sign = ""
2681         if after.startswith("-"):
2682             after = after[1:]
2683             sign = "-"
2684         elif after.startswith("+"):
2685             after = after[1:]
2686         before = format_float_or_int_string(before)
2687         text = f"{before}e{sign}{after}"
2688     elif text.endswith(("j", "l")):
2689         number = text[:-1]
2690         suffix = text[-1]
2691         # Capitalize in "2L" because "l" looks too similar to "1".
2692         if suffix == "l":
2693             suffix = "L"
2694         text = f"{format_float_or_int_string(number)}{suffix}"
2695     else:
2696         text = format_float_or_int_string(text)
2697     leaf.value = text
2698
2699
2700 def format_float_or_int_string(text: str) -> str:
2701     """Formats a float string like "1.0"."""
2702     if "." not in text:
2703         return text
2704
2705     before, after = text.split(".")
2706     return f"{before or 0}.{after or 0}"
2707
2708
2709 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2710     """Make existing optional parentheses invisible or create new ones.
2711
2712     `parens_after` is a set of string leaf values immeditely after which parens
2713     should be put.
2714
2715     Standardizes on visible parentheses for single-element tuples, and keeps
2716     existing visible parentheses for other tuples and generator expressions.
2717     """
2718     for pc in list_comments(node.prefix, is_endmarker=False):
2719         if pc.value in FMT_OFF:
2720             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2721             return
2722
2723     check_lpar = False
2724     for index, child in enumerate(list(node.children)):
2725         if check_lpar:
2726             if child.type == syms.atom:
2727                 if maybe_make_parens_invisible_in_atom(child):
2728                     lpar = Leaf(token.LPAR, "")
2729                     rpar = Leaf(token.RPAR, "")
2730                     index = child.remove() or 0
2731                     node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2732             elif is_one_tuple(child):
2733                 # wrap child in visible parentheses
2734                 lpar = Leaf(token.LPAR, "(")
2735                 rpar = Leaf(token.RPAR, ")")
2736                 child.remove()
2737                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2738             elif node.type == syms.import_from:
2739                 # "import from" nodes store parentheses directly as part of
2740                 # the statement
2741                 if child.type == token.LPAR:
2742                     # make parentheses invisible
2743                     child.value = ""  # type: ignore
2744                     node.children[-1].value = ""  # type: ignore
2745                 elif child.type != token.STAR:
2746                     # insert invisible parentheses
2747                     node.insert_child(index, Leaf(token.LPAR, ""))
2748                     node.append_child(Leaf(token.RPAR, ""))
2749                 break
2750
2751             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2752                 # wrap child in invisible parentheses
2753                 lpar = Leaf(token.LPAR, "")
2754                 rpar = Leaf(token.RPAR, "")
2755                 index = child.remove() or 0
2756                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2757
2758         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2759
2760
2761 def normalize_fmt_off(node: Node) -> None:
2762     """Convert content between `# fmt: off`/`# fmt: on` into standalone comments."""
2763     try_again = True
2764     while try_again:
2765         try_again = convert_one_fmt_off_pair(node)
2766
2767
2768 def convert_one_fmt_off_pair(node: Node) -> bool:
2769     """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment.
2770
2771     Returns True if a pair was converted.
2772     """
2773     for leaf in node.leaves():
2774         previous_consumed = 0
2775         for comment in list_comments(leaf.prefix, is_endmarker=False):
2776             if comment.value in FMT_OFF:
2777                 # We only want standalone comments. If there's no previous leaf or
2778                 # the previous leaf is indentation, it's a standalone comment in
2779                 # disguise.
2780                 if comment.type != STANDALONE_COMMENT:
2781                     prev = preceding_leaf(leaf)
2782                     if prev and prev.type not in WHITESPACE:
2783                         continue
2784
2785                 ignored_nodes = list(generate_ignored_nodes(leaf))
2786                 if not ignored_nodes:
2787                     continue
2788
2789                 first = ignored_nodes[0]  # Can be a container node with the `leaf`.
2790                 parent = first.parent
2791                 prefix = first.prefix
2792                 first.prefix = prefix[comment.consumed :]
2793                 hidden_value = (
2794                     comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
2795                 )
2796                 if hidden_value.endswith("\n"):
2797                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
2798                     # leaf (possibly followed by a DEDENT).
2799                     hidden_value = hidden_value[:-1]
2800                 first_idx = None
2801                 for ignored in ignored_nodes:
2802                     index = ignored.remove()
2803                     if first_idx is None:
2804                         first_idx = index
2805                 assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
2806                 assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
2807                 parent.insert_child(
2808                     first_idx,
2809                     Leaf(
2810                         STANDALONE_COMMENT,
2811                         hidden_value,
2812                         prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
2813                     ),
2814                 )
2815                 return True
2816
2817             previous_consumed = comment.consumed
2818
2819     return False
2820
2821
2822 def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
2823     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
2824
2825     Stops at the end of the block.
2826     """
2827     container: Optional[LN] = container_of(leaf)
2828     while container is not None and container.type != token.ENDMARKER:
2829         for comment in list_comments(container.prefix, is_endmarker=False):
2830             if comment.value in FMT_ON:
2831                 return
2832
2833         yield container
2834
2835         container = container.next_sibling
2836
2837
2838 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2839     """If it's safe, make the parens in the atom `node` invisible, recursively.
2840
2841     Returns whether the node should itself be wrapped in invisible parentheses.
2842
2843     """
2844     if (
2845         node.type != syms.atom
2846         or is_empty_tuple(node)
2847         or is_one_tuple(node)
2848         or is_yield(node)
2849         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2850     ):
2851         return False
2852
2853     first = node.children[0]
2854     last = node.children[-1]
2855     if first.type == token.LPAR and last.type == token.RPAR:
2856         # make parentheses invisible
2857         first.value = ""  # type: ignore
2858         last.value = ""  # type: ignore
2859         if len(node.children) > 1:
2860             maybe_make_parens_invisible_in_atom(node.children[1])
2861         return False
2862
2863     return True
2864
2865
2866 def is_empty_tuple(node: LN) -> bool:
2867     """Return True if `node` holds an empty tuple."""
2868     return (
2869         node.type == syms.atom
2870         and len(node.children) == 2
2871         and node.children[0].type == token.LPAR
2872         and node.children[1].type == token.RPAR
2873     )
2874
2875
2876 def is_one_tuple(node: LN) -> bool:
2877     """Return True if `node` holds a tuple with one element, with or without parens."""
2878     if node.type == syms.atom:
2879         if len(node.children) != 3:
2880             return False
2881
2882         lpar, gexp, rpar = node.children
2883         if not (
2884             lpar.type == token.LPAR
2885             and gexp.type == syms.testlist_gexp
2886             and rpar.type == token.RPAR
2887         ):
2888             return False
2889
2890         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2891
2892     return (
2893         node.type in IMPLICIT_TUPLE
2894         and len(node.children) == 2
2895         and node.children[1].type == token.COMMA
2896     )
2897
2898
2899 def is_yield(node: LN) -> bool:
2900     """Return True if `node` holds a `yield` or `yield from` expression."""
2901     if node.type == syms.yield_expr:
2902         return True
2903
2904     if node.type == token.NAME and node.value == "yield":  # type: ignore
2905         return True
2906
2907     if node.type != syms.atom:
2908         return False
2909
2910     if len(node.children) != 3:
2911         return False
2912
2913     lpar, expr, rpar = node.children
2914     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2915         return is_yield(expr)
2916
2917     return False
2918
2919
2920 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2921     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2922
2923     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2924     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2925     extended iterable unpacking (PEP 3132) and additional unpacking
2926     generalizations (PEP 448).
2927     """
2928     if leaf.type not in STARS or not leaf.parent:
2929         return False
2930
2931     p = leaf.parent
2932     if p.type == syms.star_expr:
2933         # Star expressions are also used as assignment targets in extended
2934         # iterable unpacking (PEP 3132).  See what its parent is instead.
2935         if not p.parent:
2936             return False
2937
2938         p = p.parent
2939
2940     return p.type in within
2941
2942
2943 def is_multiline_string(leaf: Leaf) -> bool:
2944     """Return True if `leaf` is a multiline string that actually spans many lines."""
2945     value = leaf.value.lstrip("furbFURB")
2946     return value[:3] in {'"""', "'''"} and "\n" in value
2947
2948
2949 def is_stub_suite(node: Node) -> bool:
2950     """Return True if `node` is a suite with a stub body."""
2951     if (
2952         len(node.children) != 4
2953         or node.children[0].type != token.NEWLINE
2954         or node.children[1].type != token.INDENT
2955         or node.children[3].type != token.DEDENT
2956     ):
2957         return False
2958
2959     return is_stub_body(node.children[2])
2960
2961
2962 def is_stub_body(node: LN) -> bool:
2963     """Return True if `node` is a simple statement containing an ellipsis."""
2964     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2965         return False
2966
2967     if len(node.children) != 2:
2968         return False
2969
2970     child = node.children[0]
2971     return (
2972         child.type == syms.atom
2973         and len(child.children) == 3
2974         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2975     )
2976
2977
2978 def max_delimiter_priority_in_atom(node: LN) -> int:
2979     """Return maximum delimiter priority inside `node`.
2980
2981     This is specific to atoms with contents contained in a pair of parentheses.
2982     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2983     """
2984     if node.type != syms.atom:
2985         return 0
2986
2987     first = node.children[0]
2988     last = node.children[-1]
2989     if not (first.type == token.LPAR and last.type == token.RPAR):
2990         return 0
2991
2992     bt = BracketTracker()
2993     for c in node.children[1:-1]:
2994         if isinstance(c, Leaf):
2995             bt.mark(c)
2996         else:
2997             for leaf in c.leaves():
2998                 bt.mark(leaf)
2999     try:
3000         return bt.max_delimiter_priority()
3001
3002     except ValueError:
3003         return 0
3004
3005
3006 def ensure_visible(leaf: Leaf) -> None:
3007     """Make sure parentheses are visible.
3008
3009     They could be invisible as part of some statements (see
3010     :func:`normalize_invible_parens` and :func:`visit_import_from`).
3011     """
3012     if leaf.type == token.LPAR:
3013         leaf.value = "("
3014     elif leaf.type == token.RPAR:
3015         leaf.value = ")"
3016
3017
3018 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
3019     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
3020
3021     if not (
3022         opening_bracket.parent
3023         and opening_bracket.parent.type in {syms.atom, syms.import_from}
3024         and opening_bracket.value in "[{("
3025     ):
3026         return False
3027
3028     try:
3029         last_leaf = line.leaves[-1]
3030         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
3031         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
3032     except (IndexError, ValueError):
3033         return False
3034
3035     return max_priority == COMMA_PRIORITY
3036
3037
3038 def get_features_used(node: Node) -> Set[Feature]:
3039     """Return a set of (relatively) new Python features used in this file.
3040
3041     Currently looking for:
3042     - f-strings;
3043     - underscores in numeric literals; and
3044     - trailing commas after * or ** in function signatures and calls.
3045     """
3046     features: Set[Feature] = set()
3047     for n in node.pre_order():
3048         if n.type == token.STRING:
3049             value_head = n.value[:2]  # type: ignore
3050             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
3051                 features.add(Feature.F_STRINGS)
3052
3053         elif n.type == token.NUMBER:
3054             if "_" in n.value:  # type: ignore
3055                 features.add(Feature.NUMERIC_UNDERSCORES)
3056
3057         elif (
3058             n.type in {syms.typedargslist, syms.arglist}
3059             and n.children
3060             and n.children[-1].type == token.COMMA
3061         ):
3062             for ch in n.children:
3063                 if ch.type in STARS:
3064                     features.add(Feature.TRAILING_COMMA)
3065
3066                 if ch.type == syms.argument:
3067                     for argch in ch.children:
3068                         if argch.type in STARS:
3069                             features.add(Feature.TRAILING_COMMA)
3070
3071     return features
3072
3073
3074 def detect_target_versions(node: Node) -> Set[TargetVersion]:
3075     """Detect the version to target based on the nodes used."""
3076     features = get_features_used(node)
3077     return {
3078         version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
3079     }
3080
3081
3082 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
3083     """Generate sets of closing bracket IDs that should be omitted in a RHS.
3084
3085     Brackets can be omitted if the entire trailer up to and including
3086     a preceding closing bracket fits in one line.
3087
3088     Yielded sets are cumulative (contain results of previous yields, too).  First
3089     set is empty.
3090     """
3091
3092     omit: Set[LeafID] = set()
3093     yield omit
3094
3095     length = 4 * line.depth
3096     opening_bracket = None
3097     closing_bracket = None
3098     inner_brackets: Set[LeafID] = set()
3099     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
3100         length += leaf_length
3101         if length > line_length:
3102             break
3103
3104         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
3105         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
3106             break
3107
3108         if opening_bracket:
3109             if leaf is opening_bracket:
3110                 opening_bracket = None
3111             elif leaf.type in CLOSING_BRACKETS:
3112                 inner_brackets.add(id(leaf))
3113         elif leaf.type in CLOSING_BRACKETS:
3114             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
3115                 # Empty brackets would fail a split so treat them as "inner"
3116                 # brackets (e.g. only add them to the `omit` set if another
3117                 # pair of brackets was good enough.
3118                 inner_brackets.add(id(leaf))
3119                 continue
3120
3121             if closing_bracket:
3122                 omit.add(id(closing_bracket))
3123                 omit.update(inner_brackets)
3124                 inner_brackets.clear()
3125                 yield omit
3126
3127             if leaf.value:
3128                 opening_bracket = leaf.opening_bracket
3129                 closing_bracket = leaf
3130
3131
3132 def get_future_imports(node: Node) -> Set[str]:
3133     """Return a set of __future__ imports in the file."""
3134     imports: Set[str] = set()
3135
3136     def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
3137         for child in children:
3138             if isinstance(child, Leaf):
3139                 if child.type == token.NAME:
3140                     yield child.value
3141             elif child.type == syms.import_as_name:
3142                 orig_name = child.children[0]
3143                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
3144                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
3145                 yield orig_name.value
3146             elif child.type == syms.import_as_names:
3147                 yield from get_imports_from_children(child.children)
3148             else:
3149                 assert False, "Invalid syntax parsing imports"
3150
3151     for child in node.children:
3152         if child.type != syms.simple_stmt:
3153             break
3154         first_child = child.children[0]
3155         if isinstance(first_child, Leaf):
3156             # Continue looking if we see a docstring; otherwise stop.
3157             if (
3158                 len(child.children) == 2
3159                 and first_child.type == token.STRING
3160                 and child.children[1].type == token.NEWLINE
3161             ):
3162                 continue
3163             else:
3164                 break
3165         elif first_child.type == syms.import_from:
3166             module_name = first_child.children[1]
3167             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
3168                 break
3169             imports |= set(get_imports_from_children(first_child.children[3:]))
3170         else:
3171             break
3172     return imports
3173
3174
3175 def gen_python_files_in_dir(
3176     path: Path,
3177     root: Path,
3178     include: Pattern[str],
3179     exclude: Pattern[str],
3180     report: "Report",
3181 ) -> Iterator[Path]:
3182     """Generate all files under `path` whose paths are not excluded by the
3183     `exclude` regex, but are included by the `include` regex.
3184
3185     Symbolic links pointing outside of the `root` directory are ignored.
3186
3187     `report` is where output about exclusions goes.
3188     """
3189     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
3190     for child in path.iterdir():
3191         try:
3192             normalized_path = "/" + child.resolve().relative_to(root).as_posix()
3193         except ValueError:
3194             if child.is_symlink():
3195                 report.path_ignored(
3196                     child, f"is a symbolic link that points outside {root}"
3197                 )
3198                 continue
3199
3200             raise
3201
3202         if child.is_dir():
3203             normalized_path += "/"
3204         exclude_match = exclude.search(normalized_path)
3205         if exclude_match and exclude_match.group(0):
3206             report.path_ignored(child, f"matches the --exclude regular expression")
3207             continue
3208
3209         if child.is_dir():
3210             yield from gen_python_files_in_dir(child, root, include, exclude, report)
3211
3212         elif child.is_file():
3213             include_match = include.search(normalized_path)
3214             if include_match:
3215                 yield child
3216
3217
3218 @lru_cache()
3219 def find_project_root(srcs: Iterable[str]) -> Path:
3220     """Return a directory containing .git, .hg, or pyproject.toml.
3221
3222     That directory can be one of the directories passed in `srcs` or their
3223     common parent.
3224
3225     If no directory in the tree contains a marker that would specify it's the
3226     project root, the root of the file system is returned.
3227     """
3228     if not srcs:
3229         return Path("/").resolve()
3230
3231     common_base = min(Path(src).resolve() for src in srcs)
3232     if common_base.is_dir():
3233         # Append a fake file so `parents` below returns `common_base_dir`, too.
3234         common_base /= "fake-file"
3235     for directory in common_base.parents:
3236         if (directory / ".git").is_dir():
3237             return directory
3238
3239         if (directory / ".hg").is_dir():
3240             return directory
3241
3242         if (directory / "pyproject.toml").is_file():
3243             return directory
3244
3245     return directory
3246
3247
3248 @dataclass
3249 class Report:
3250     """Provides a reformatting counter. Can be rendered with `str(report)`."""
3251
3252     check: bool = False
3253     quiet: bool = False
3254     verbose: bool = False
3255     change_count: int = 0
3256     same_count: int = 0
3257     failure_count: int = 0
3258
3259     def done(self, src: Path, changed: Changed) -> None:
3260         """Increment the counter for successful reformatting. Write out a message."""
3261         if changed is Changed.YES:
3262             reformatted = "would reformat" if self.check else "reformatted"
3263             if self.verbose or not self.quiet:
3264                 out(f"{reformatted} {src}")
3265             self.change_count += 1
3266         else:
3267             if self.verbose:
3268                 if changed is Changed.NO:
3269                     msg = f"{src} already well formatted, good job."
3270                 else:
3271                     msg = f"{src} wasn't modified on disk since last run."
3272                 out(msg, bold=False)
3273             self.same_count += 1
3274
3275     def failed(self, src: Path, message: str) -> None:
3276         """Increment the counter for failed reformatting. Write out a message."""
3277         err(f"error: cannot format {src}: {message}")
3278         self.failure_count += 1
3279
3280     def path_ignored(self, path: Path, message: str) -> None:
3281         if self.verbose:
3282             out(f"{path} ignored: {message}", bold=False)
3283
3284     @property
3285     def return_code(self) -> int:
3286         """Return the exit code that the app should use.
3287
3288         This considers the current state of changed files and failures:
3289         - if there were any failures, return 123;
3290         - if any files were changed and --check is being used, return 1;
3291         - otherwise return 0.
3292         """
3293         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
3294         # 126 we have special return codes reserved by the shell.
3295         if self.failure_count:
3296             return 123
3297
3298         elif self.change_count and self.check:
3299             return 1
3300
3301         return 0
3302
3303     def __str__(self) -> str:
3304         """Render a color report of the current state.
3305
3306         Use `click.unstyle` to remove colors.
3307         """
3308         if self.check:
3309             reformatted = "would be reformatted"
3310             unchanged = "would be left unchanged"
3311             failed = "would fail to reformat"
3312         else:
3313             reformatted = "reformatted"
3314             unchanged = "left unchanged"
3315             failed = "failed to reformat"
3316         report = []
3317         if self.change_count:
3318             s = "s" if self.change_count > 1 else ""
3319             report.append(
3320                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
3321             )
3322         if self.same_count:
3323             s = "s" if self.same_count > 1 else ""
3324             report.append(f"{self.same_count} file{s} {unchanged}")
3325         if self.failure_count:
3326             s = "s" if self.failure_count > 1 else ""
3327             report.append(
3328                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
3329             )
3330         return ", ".join(report) + "."
3331
3332
3333 def assert_equivalent(src: str, dst: str) -> None:
3334     """Raise AssertionError if `src` and `dst` aren't equivalent."""
3335
3336     import ast
3337     import traceback
3338
3339     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3340         """Simple visitor generating strings to compare ASTs by content."""
3341         yield f"{'  ' * depth}{node.__class__.__name__}("
3342
3343         for field in sorted(node._fields):
3344             try:
3345                 value = getattr(node, field)
3346             except AttributeError:
3347                 continue
3348
3349             yield f"{'  ' * (depth+1)}{field}="
3350
3351             if isinstance(value, list):
3352                 for item in value:
3353                     if isinstance(item, ast.AST):
3354                         yield from _v(item, depth + 2)
3355
3356             elif isinstance(value, ast.AST):
3357                 yield from _v(value, depth + 2)
3358
3359             else:
3360                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3361
3362         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3363
3364     try:
3365         src_ast = ast.parse(src)
3366     except Exception as exc:
3367         major, minor = sys.version_info[:2]
3368         raise AssertionError(
3369             f"cannot use --safe with this file; failed to parse source file "
3370             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3371             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3372         )
3373
3374     try:
3375         dst_ast = ast.parse(dst)
3376     except Exception as exc:
3377         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3378         raise AssertionError(
3379             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3380             f"Please report a bug on https://github.com/ambv/black/issues.  "
3381             f"This invalid output might be helpful: {log}"
3382         ) from None
3383
3384     src_ast_str = "\n".join(_v(src_ast))
3385     dst_ast_str = "\n".join(_v(dst_ast))
3386     if src_ast_str != dst_ast_str:
3387         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3388         raise AssertionError(
3389             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3390             f"the source.  "
3391             f"Please report a bug on https://github.com/ambv/black/issues.  "
3392             f"This diff might be helpful: {log}"
3393         ) from None
3394
3395
3396 def assert_stable(src: str, dst: str, mode: FileMode) -> None:
3397     """Raise AssertionError if `dst` reformats differently the second time."""
3398     newdst = format_str(dst, mode=mode)
3399     if dst != newdst:
3400         log = dump_to_file(
3401             diff(src, dst, "source", "first pass"),
3402             diff(dst, newdst, "first pass", "second pass"),
3403         )
3404         raise AssertionError(
3405             f"INTERNAL ERROR: Black produced different code on the second pass "
3406             f"of the formatter.  "
3407             f"Please report a bug on https://github.com/ambv/black/issues.  "
3408             f"This diff might be helpful: {log}"
3409         ) from None
3410
3411
3412 def dump_to_file(*output: str) -> str:
3413     """Dump `output` to a temporary file. Return path to the file."""
3414     import tempfile
3415
3416     with tempfile.NamedTemporaryFile(
3417         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3418     ) as f:
3419         for lines in output:
3420             f.write(lines)
3421             if lines and lines[-1] != "\n":
3422                 f.write("\n")
3423     return f.name
3424
3425
3426 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3427     """Return a unified diff string between strings `a` and `b`."""
3428     import difflib
3429
3430     a_lines = [line + "\n" for line in a.split("\n")]
3431     b_lines = [line + "\n" for line in b.split("\n")]
3432     return "".join(
3433         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3434     )
3435
3436
3437 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3438     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3439     err("Aborted!")
3440     for task in tasks:
3441         task.cancel()
3442
3443
3444 def shutdown(loop: BaseEventLoop) -> None:
3445     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3446     try:
3447         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3448         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3449         if not to_cancel:
3450             return
3451
3452         for task in to_cancel:
3453             task.cancel()
3454         loop.run_until_complete(
3455             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3456         )
3457     finally:
3458         # `concurrent.futures.Future` objects cannot be cancelled once they
3459         # are already running. There might be some when the `shutdown()` happened.
3460         # Silence their logger's spew about the event loop being closed.
3461         cf_logger = logging.getLogger("concurrent.futures")
3462         cf_logger.setLevel(logging.CRITICAL)
3463         loop.close()
3464
3465
3466 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3467     """Replace `regex` with `replacement` twice on `original`.
3468
3469     This is used by string normalization to perform replaces on
3470     overlapping matches.
3471     """
3472     return regex.sub(replacement, regex.sub(replacement, original))
3473
3474
3475 def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
3476     """Compile a regular expression string in `regex`.
3477
3478     If it contains newlines, use verbose mode.
3479     """
3480     if "\n" in regex:
3481         regex = "(?x)" + regex
3482     return re.compile(regex)
3483
3484
3485 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3486     """Like `reversed(enumerate(sequence))` if that were possible."""
3487     index = len(sequence) - 1
3488     for element in reversed(sequence):
3489         yield (index, element)
3490         index -= 1
3491
3492
3493 def enumerate_with_length(
3494     line: Line, reversed: bool = False
3495 ) -> Iterator[Tuple[Index, Leaf, int]]:
3496     """Return an enumeration of leaves with their length.
3497
3498     Stops prematurely on multiline strings and standalone comments.
3499     """
3500     op = cast(
3501         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3502         enumerate_reversed if reversed else enumerate,
3503     )
3504     for index, leaf in op(line.leaves):
3505         length = len(leaf.prefix) + len(leaf.value)
3506         if "\n" in leaf.value:
3507             return  # Multiline strings, we can't continue.
3508
3509         comment: Optional[Leaf]
3510         for comment in line.comments_after(leaf):
3511             length += len(comment.value)
3512
3513         yield index, leaf, length
3514
3515
3516 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3517     """Return True if `line` is no longer than `line_length`.
3518
3519     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3520     """
3521     if not line_str:
3522         line_str = str(line).strip("\n")
3523     return (
3524         len(line_str) <= line_length
3525         and "\n" not in line_str  # multiline strings
3526         and not line.contains_standalone_comments()
3527     )
3528
3529
3530 def can_be_split(line: Line) -> bool:
3531     """Return False if the line cannot be split *for sure*.
3532
3533     This is not an exhaustive search but a cheap heuristic that we can use to
3534     avoid some unfortunate formattings (mostly around wrapping unsplittable code
3535     in unnecessary parentheses).
3536     """
3537     leaves = line.leaves
3538     if len(leaves) < 2:
3539         return False
3540
3541     if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
3542         call_count = 0
3543         dot_count = 0
3544         next = leaves[-1]
3545         for leaf in leaves[-2::-1]:
3546             if leaf.type in OPENING_BRACKETS:
3547                 if next.type not in CLOSING_BRACKETS:
3548                     return False
3549
3550                 call_count += 1
3551             elif leaf.type == token.DOT:
3552                 dot_count += 1
3553             elif leaf.type == token.NAME:
3554                 if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
3555                     return False
3556
3557             elif leaf.type not in CLOSING_BRACKETS:
3558                 return False
3559
3560             if dot_count > 1 and call_count > 1:
3561                 return False
3562
3563     return True
3564
3565
3566 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3567     """Does `line` have a shape safe to reformat without optional parens around it?
3568
3569     Returns True for only a subset of potentially nice looking formattings but
3570     the point is to not return false positives that end up producing lines that
3571     are too long.
3572     """
3573     bt = line.bracket_tracker
3574     if not bt.delimiters:
3575         # Without delimiters the optional parentheses are useless.
3576         return True
3577
3578     max_priority = bt.max_delimiter_priority()
3579     if bt.delimiter_count_with_priority(max_priority) > 1:
3580         # With more than one delimiter of a kind the optional parentheses read better.
3581         return False
3582
3583     if max_priority == DOT_PRIORITY:
3584         # A single stranded method call doesn't require optional parentheses.
3585         return True
3586
3587     assert len(line.leaves) >= 2, "Stranded delimiter"
3588
3589     first = line.leaves[0]
3590     second = line.leaves[1]
3591     penultimate = line.leaves[-2]
3592     last = line.leaves[-1]
3593
3594     # With a single delimiter, omit if the expression starts or ends with
3595     # a bracket.
3596     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3597         remainder = False
3598         length = 4 * line.depth
3599         for _index, leaf, leaf_length in enumerate_with_length(line):
3600             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3601                 remainder = True
3602             if remainder:
3603                 length += leaf_length
3604                 if length > line_length:
3605                     break
3606
3607                 if leaf.type in OPENING_BRACKETS:
3608                     # There are brackets we can further split on.
3609                     remainder = False
3610
3611         else:
3612             # checked the entire string and line length wasn't exceeded
3613             if len(line.leaves) == _index + 1:
3614                 return True
3615
3616         # Note: we are not returning False here because a line might have *both*
3617         # a leading opening bracket and a trailing closing bracket.  If the
3618         # opening bracket doesn't match our rule, maybe the closing will.
3619
3620     if (
3621         last.type == token.RPAR
3622         or last.type == token.RBRACE
3623         or (
3624             # don't use indexing for omitting optional parentheses;
3625             # it looks weird
3626             last.type == token.RSQB
3627             and last.parent
3628             and last.parent.type != syms.trailer
3629         )
3630     ):
3631         if penultimate.type in OPENING_BRACKETS:
3632             # Empty brackets don't help.
3633             return False
3634
3635         if is_multiline_string(first):
3636             # Additional wrapping of a multiline string in this situation is
3637             # unnecessary.
3638             return True
3639
3640         length = 4 * line.depth
3641         seen_other_brackets = False
3642         for _index, leaf, leaf_length in enumerate_with_length(line):
3643             length += leaf_length
3644             if leaf is last.opening_bracket:
3645                 if seen_other_brackets or length <= line_length:
3646                     return True
3647
3648             elif leaf.type in OPENING_BRACKETS:
3649                 # There are brackets we can further split on.
3650                 seen_other_brackets = True
3651
3652     return False
3653
3654
3655 def get_cache_file(mode: FileMode) -> Path:
3656     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
3657
3658
3659 def read_cache(mode: FileMode) -> Cache:
3660     """Read the cache if it exists and is well formed.
3661
3662     If it is not well formed, the call to write_cache later should resolve the issue.
3663     """
3664     cache_file = get_cache_file(mode)
3665     if not cache_file.exists():
3666         return {}
3667
3668     with cache_file.open("rb") as fobj:
3669         try:
3670             cache: Cache = pickle.load(fobj)
3671         except pickle.UnpicklingError:
3672             return {}
3673
3674     return cache
3675
3676
3677 def get_cache_info(path: Path) -> CacheInfo:
3678     """Return the information used to check if a file is already formatted or not."""
3679     stat = path.stat()
3680     return stat.st_mtime, stat.st_size
3681
3682
3683 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3684     """Split an iterable of paths in `sources` into two sets.
3685
3686     The first contains paths of files that modified on disk or are not in the
3687     cache. The other contains paths to non-modified files.
3688     """
3689     todo, done = set(), set()
3690     for src in sources:
3691         src = src.resolve()
3692         if cache.get(src) != get_cache_info(src):
3693             todo.add(src)
3694         else:
3695             done.add(src)
3696     return todo, done
3697
3698
3699 def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
3700     """Update the cache file."""
3701     cache_file = get_cache_file(mode)
3702     try:
3703         CACHE_DIR.mkdir(parents=True, exist_ok=True)
3704         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3705         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
3706             pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
3707         os.replace(f.name, cache_file)
3708     except OSError:
3709         pass
3710
3711
3712 def patch_click() -> None:
3713     """Make Click not crash.
3714
3715     On certain misconfigured environments, Python 3 selects the ASCII encoding as the
3716     default which restricts paths that it can access during the lifetime of the
3717     application.  Click refuses to work in this scenario by raising a RuntimeError.
3718
3719     In case of Black the likelihood that non-ASCII characters are going to be used in
3720     file paths is minimal since it's Python source code.  Moreover, this crash was
3721     spurious on Python 3.7 thanks to PEP 538 and PEP 540.
3722     """
3723     try:
3724         from click import core
3725         from click import _unicodefun  # type: ignore
3726     except ModuleNotFoundError:
3727         return
3728
3729     for module in (core, _unicodefun):
3730         if hasattr(module, "_verify_python3_env"):
3731             module._verify_python3_env = lambda: None
3732
3733
3734 def patched_main() -> None:
3735     freeze_support()
3736     patch_click()
3737     main()
3738
3739
3740 if __name__ == "__main__":
3741     patched_main()