]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

19a023cff52db2e92d5cc0950125e20c96d24dff
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum, Flag
6 from functools import partial, wraps
7 import io
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import re
14 import tokenize
15 import signal
16 import sys
17 from typing import (
18     Any,
19     Callable,
20     Collection,
21     Dict,
22     Generic,
23     Iterable,
24     Iterator,
25     List,
26     Optional,
27     Pattern,
28     Sequence,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34     cast,
35 )
36
37 from appdirs import user_cache_dir
38 from attr import dataclass, Factory
39 import click
40
41 # lib2to3 fork
42 from blib2to3.pytree import Node, Leaf, type_repr
43 from blib2to3 import pygram, pytree
44 from blib2to3.pgen2 import driver, token
45 from blib2to3.pgen2.parse import ParseError
46
47
48 __version__ = "18.5b1"
49 DEFAULT_LINE_LENGTH = 88
50 DEFAULT_EXCLUDES = (
51     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
52 )
53 DEFAULT_INCLUDES = r"\.pyi?$"
54 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
55
56
57 # types
58 FileContent = str
59 Encoding = str
60 Depth = int
61 NodeType = int
62 LeafID = int
63 Priority = int
64 Index = int
65 LN = Union[Leaf, Node]
66 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
67 Timestamp = float
68 FileSize = int
69 CacheInfo = Tuple[Timestamp, FileSize]
70 Cache = Dict[Path, CacheInfo]
71 out = partial(click.secho, bold=True, err=True)
72 err = partial(click.secho, fg="red", err=True)
73
74 pygram.initialize(CACHE_DIR)
75 syms = pygram.python_symbols
76
77
78 class NothingChanged(UserWarning):
79     """Raised by :func:`format_file` when reformatted code is the same as source."""
80
81
82 class CannotSplit(Exception):
83     """A readable split that fits the allotted line length is impossible.
84
85     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
86     :func:`delimiter_split`.
87     """
88
89
90 class FormatError(Exception):
91     """Base exception for `# fmt: on` and `# fmt: off` handling.
92
93     It holds the number of bytes of the prefix consumed before the format
94     control comment appeared.
95     """
96
97     def __init__(self, consumed: int) -> None:
98         super().__init__(consumed)
99         self.consumed = consumed
100
101     def trim_prefix(self, leaf: Leaf) -> None:
102         leaf.prefix = leaf.prefix[self.consumed :]
103
104     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
105         """Returns a new Leaf from the consumed part of the prefix."""
106         unformatted_prefix = leaf.prefix[: self.consumed]
107         return Leaf(token.NEWLINE, unformatted_prefix)
108
109
110 class FormatOn(FormatError):
111     """Found a comment like `# fmt: on` in the file."""
112
113
114 class FormatOff(FormatError):
115     """Found a comment like `# fmt: off` in the file."""
116
117
118 class WriteBack(Enum):
119     NO = 0
120     YES = 1
121     DIFF = 2
122
123     @classmethod
124     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
125         if check and not diff:
126             return cls.NO
127
128         return cls.DIFF if diff else cls.YES
129
130
131 class Changed(Enum):
132     NO = 0
133     CACHED = 1
134     YES = 2
135
136
137 class FileMode(Flag):
138     AUTO_DETECT = 0
139     PYTHON36 = 1
140     PYI = 2
141     NO_STRING_NORMALIZATION = 4
142
143     @classmethod
144     def from_configuration(
145         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
146     ) -> "FileMode":
147         mode = cls.AUTO_DETECT
148         if py36:
149             mode |= cls.PYTHON36
150         if pyi:
151             mode |= cls.PYI
152         if skip_string_normalization:
153             mode |= cls.NO_STRING_NORMALIZATION
154         return mode
155
156
157 @click.command()
158 @click.option(
159     "-l",
160     "--line-length",
161     type=int,
162     default=DEFAULT_LINE_LENGTH,
163     help="How many character per line to allow.",
164     show_default=True,
165 )
166 @click.option(
167     "--py36",
168     is_flag=True,
169     help=(
170         "Allow using Python 3.6-only syntax on all input files.  This will put "
171         "trailing commas in function signatures and calls also after *args and "
172         "**kwargs.  [default: per-file auto-detection]"
173     ),
174 )
175 @click.option(
176     "--pyi",
177     is_flag=True,
178     help=(
179         "Format all input files like typing stubs regardless of file extension "
180         "(useful when piping source on standard input)."
181     ),
182 )
183 @click.option(
184     "-S",
185     "--skip-string-normalization",
186     is_flag=True,
187     help="Don't normalize string quotes or prefixes.",
188 )
189 @click.option(
190     "--check",
191     is_flag=True,
192     help=(
193         "Don't write the files back, just return the status.  Return code 0 "
194         "means nothing would change.  Return code 1 means some files would be "
195         "reformatted.  Return code 123 means there was an internal error."
196     ),
197 )
198 @click.option(
199     "--diff",
200     is_flag=True,
201     help="Don't write the files back, just output a diff for each file on stdout.",
202 )
203 @click.option(
204     "--fast/--safe",
205     is_flag=True,
206     help="If --fast given, skip temporary sanity checks. [default: --safe]",
207 )
208 @click.option(
209     "--include",
210     type=str,
211     default=DEFAULT_INCLUDES,
212     help=(
213         "A regular expression that matches files and directories that should be "
214         "included on recursive searches.  An empty value means all files are "
215         "included regardless of the name.  Use forward slashes for directories on "
216         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
217         "later."
218     ),
219     show_default=True,
220 )
221 @click.option(
222     "--exclude",
223     type=str,
224     default=DEFAULT_EXCLUDES,
225     help=(
226         "A regular expression that matches files and directories that should be "
227         "excluded on recursive searches.  An empty value means no paths are excluded. "
228         "Use forward slashes for directories on all platforms (Windows, too).  "
229         "Exclusions are calculated first, inclusions later."
230     ),
231     show_default=True,
232 )
233 @click.option(
234     "-q",
235     "--quiet",
236     is_flag=True,
237     help=(
238         "Don't emit non-error messages to stderr. Errors are still emitted, "
239         "silence those with 2>/dev/null."
240     ),
241 )
242 @click.option(
243     "-v",
244     "--verbose",
245     is_flag=True,
246     help=(
247         "Also emit messages to stderr about files that were not changed or were "
248         "ignored due to --exclude=."
249     ),
250 )
251 @click.version_option(version=__version__)
252 @click.argument(
253     "src",
254     nargs=-1,
255     type=click.Path(
256         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
257     ),
258 )
259 @click.pass_context
260 def main(
261     ctx: click.Context,
262     line_length: int,
263     check: bool,
264     diff: bool,
265     fast: bool,
266     pyi: bool,
267     py36: bool,
268     skip_string_normalization: bool,
269     quiet: bool,
270     verbose: bool,
271     include: str,
272     exclude: str,
273     src: List[str],
274 ) -> None:
275     """The uncompromising code formatter."""
276     write_back = WriteBack.from_configuration(check=check, diff=diff)
277     mode = FileMode.from_configuration(
278         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
279     )
280     report = Report(check=check, quiet=quiet, verbose=verbose)
281     sources: Set[Path] = set()
282     try:
283         include_regex = re.compile(include)
284     except re.error:
285         err(f"Invalid regular expression for include given: {include!r}")
286         ctx.exit(2)
287     try:
288         exclude_regex = re.compile(exclude)
289     except re.error:
290         err(f"Invalid regular expression for exclude given: {exclude!r}")
291         ctx.exit(2)
292     root = find_project_root(src)
293     for s in src:
294         p = Path(s)
295         if p.is_dir():
296             sources.update(
297                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
298             )
299         elif p.is_file() or s == "-":
300             # if a file was explicitly given, we don't care about its extension
301             sources.add(p)
302         else:
303             err(f"invalid path: {s}")
304     if len(sources) == 0:
305         if verbose or not quiet:
306             out("No paths given. Nothing to do 😴")
307         ctx.exit(0)
308         return
309
310     elif len(sources) == 1:
311         reformat_one(
312             src=sources.pop(),
313             line_length=line_length,
314             fast=fast,
315             write_back=write_back,
316             mode=mode,
317             report=report,
318         )
319     else:
320         loop = asyncio.get_event_loop()
321         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
322         try:
323             loop.run_until_complete(
324                 schedule_formatting(
325                     sources=sources,
326                     line_length=line_length,
327                     fast=fast,
328                     write_back=write_back,
329                     mode=mode,
330                     report=report,
331                     loop=loop,
332                     executor=executor,
333                 )
334             )
335         finally:
336             shutdown(loop)
337     if verbose or not quiet:
338         out("All done! ✨ 🍰 ✨")
339         click.echo(str(report))
340     ctx.exit(report.return_code)
341
342
343 def reformat_one(
344     src: Path,
345     line_length: int,
346     fast: bool,
347     write_back: WriteBack,
348     mode: FileMode,
349     report: "Report",
350 ) -> None:
351     """Reformat a single file under `src` without spawning child processes.
352
353     If `quiet` is True, non-error messages are not output. `line_length`,
354     `write_back`, `fast` and `pyi` options are passed to
355     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
356     """
357     try:
358         changed = Changed.NO
359         if not src.is_file() and str(src) == "-":
360             if format_stdin_to_stdout(
361                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
362             ):
363                 changed = Changed.YES
364         else:
365             cache: Cache = {}
366             if write_back != WriteBack.DIFF:
367                 cache = read_cache(line_length, mode)
368                 res_src = src.resolve()
369                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
370                     changed = Changed.CACHED
371             if changed is not Changed.CACHED and format_file_in_place(
372                 src,
373                 line_length=line_length,
374                 fast=fast,
375                 write_back=write_back,
376                 mode=mode,
377             ):
378                 changed = Changed.YES
379             if write_back == WriteBack.YES and changed is not Changed.NO:
380                 write_cache(cache, [src], line_length, mode)
381         report.done(src, changed)
382     except Exception as exc:
383         report.failed(src, str(exc))
384
385
386 async def schedule_formatting(
387     sources: Set[Path],
388     line_length: int,
389     fast: bool,
390     write_back: WriteBack,
391     mode: FileMode,
392     report: "Report",
393     loop: BaseEventLoop,
394     executor: Executor,
395 ) -> None:
396     """Run formatting of `sources` in parallel using the provided `executor`.
397
398     (Use ProcessPoolExecutors for actual parallelism.)
399
400     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
401     :func:`format_file_in_place`.
402     """
403     cache: Cache = {}
404     if write_back != WriteBack.DIFF:
405         cache = read_cache(line_length, mode)
406         sources, cached = filter_cached(cache, sources)
407         for src in sorted(cached):
408             report.done(src, Changed.CACHED)
409     cancelled = []
410     formatted = []
411     if sources:
412         lock = None
413         if write_back == WriteBack.DIFF:
414             # For diff output, we need locks to ensure we don't interleave output
415             # from different processes.
416             manager = Manager()
417             lock = manager.Lock()
418         tasks = {
419             loop.run_in_executor(
420                 executor,
421                 format_file_in_place,
422                 src,
423                 line_length,
424                 fast,
425                 write_back,
426                 mode,
427                 lock,
428             ): src
429             for src in sorted(sources)
430         }
431         pending: Iterable[asyncio.Task] = tasks.keys()
432         try:
433             loop.add_signal_handler(signal.SIGINT, cancel, pending)
434             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
435         except NotImplementedError:
436             # There are no good alternatives for these on Windows
437             pass
438         while pending:
439             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
440             for task in done:
441                 src = tasks.pop(task)
442                 if task.cancelled():
443                     cancelled.append(task)
444                 elif task.exception():
445                     report.failed(src, str(task.exception()))
446                 else:
447                     formatted.append(src)
448                     report.done(src, Changed.YES if task.result() else Changed.NO)
449     if cancelled:
450         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
451     if write_back == WriteBack.YES and formatted:
452         write_cache(cache, formatted, line_length, mode)
453
454
455 def format_file_in_place(
456     src: Path,
457     line_length: int,
458     fast: bool,
459     write_back: WriteBack = WriteBack.NO,
460     mode: FileMode = FileMode.AUTO_DETECT,
461     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
462 ) -> bool:
463     """Format file under `src` path. Return True if changed.
464
465     If `write_back` is True, write reformatted code back to stdout.
466     `line_length` and `fast` options are passed to :func:`format_file_contents`.
467     """
468     if src.suffix == ".pyi":
469         mode |= FileMode.PYI
470
471     with open(src, "rb") as buf:
472         newline, encoding, src_contents = prepare_input(buf.read())
473     try:
474         dst_contents = format_file_contents(
475             src_contents, line_length=line_length, fast=fast, mode=mode
476         )
477     except NothingChanged:
478         return False
479
480     if write_back == write_back.YES:
481         with open(src, "w", encoding=encoding, newline=newline) as f:
482             f.write(dst_contents)
483     elif write_back == write_back.DIFF:
484         src_name = f"{src}  (original)"
485         dst_name = f"{src}  (formatted)"
486         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
487         if lock:
488             lock.acquire()
489         try:
490             f = io.TextIOWrapper(
491                 sys.stdout.buffer,
492                 encoding=encoding,
493                 newline=newline,
494                 write_through=True,
495             )
496             f.write(diff_contents)
497             f.detach()
498         finally:
499             if lock:
500                 lock.release()
501     return True
502
503
504 def format_stdin_to_stdout(
505     line_length: int,
506     fast: bool,
507     write_back: WriteBack = WriteBack.NO,
508     mode: FileMode = FileMode.AUTO_DETECT,
509 ) -> bool:
510     """Format file on stdin. Return True if changed.
511
512     If `write_back` is True, write reformatted code back to stdout.
513     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
514     :func:`format_file_contents`.
515     """
516     newline, encoding, src = prepare_input(sys.stdin.buffer.read())
517     dst = src
518     try:
519         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
520         return True
521
522     except NothingChanged:
523         return False
524
525     finally:
526         if write_back == WriteBack.YES:
527             f = io.TextIOWrapper(
528                 sys.stdout.buffer,
529                 encoding=encoding,
530                 newline=newline,
531                 write_through=True,
532             )
533             f.write(dst)
534             f.detach()
535         elif write_back == WriteBack.DIFF:
536             src_name = "<stdin>  (original)"
537             dst_name = "<stdin>  (formatted)"
538             f = io.TextIOWrapper(
539                 sys.stdout.buffer,
540                 encoding=encoding,
541                 newline=newline,
542                 write_through=True,
543             )
544             f.write(diff(src, dst, src_name, dst_name))
545             f.detach()
546
547
548 def format_file_contents(
549     src_contents: str,
550     *,
551     line_length: int,
552     fast: bool,
553     mode: FileMode = FileMode.AUTO_DETECT,
554 ) -> FileContent:
555     """Reformat contents a file and return new contents.
556
557     If `fast` is False, additionally confirm that the reformatted code is
558     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
559     `line_length` is passed to :func:`format_str`.
560     """
561     if src_contents.strip() == "":
562         raise NothingChanged
563
564     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
565     if src_contents == dst_contents:
566         raise NothingChanged
567
568     if not fast:
569         assert_equivalent(src_contents, dst_contents)
570         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
571     return dst_contents
572
573
574 def format_str(
575     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
576 ) -> FileContent:
577     """Reformat a string and return new contents.
578
579     `line_length` determines how many characters per line are allowed.
580     """
581     src_node = lib2to3_parse(src_contents)
582     dst_contents = ""
583     future_imports = get_future_imports(src_node)
584     is_pyi = bool(mode & FileMode.PYI)
585     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
586     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
587     lines = LineGenerator(
588         remove_u_prefix=py36 or "unicode_literals" in future_imports,
589         is_pyi=is_pyi,
590         normalize_strings=normalize_strings,
591     )
592     elt = EmptyLineTracker(is_pyi=is_pyi)
593     empty_line = Line()
594     after = 0
595     for current_line in lines.visit(src_node):
596         for _ in range(after):
597             dst_contents += str(empty_line)
598         before, after = elt.maybe_empty_lines(current_line)
599         for _ in range(before):
600             dst_contents += str(empty_line)
601         for line in split_line(current_line, line_length=line_length, py36=py36):
602             dst_contents += str(line)
603     return dst_contents
604
605
606 def prepare_input(src: bytes) -> Tuple[str, str, str]:
607     """Analyze `src` and return a tuple of (newline, encoding, decoded_contents)
608
609     Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
610     universal newlines (i.e. only LF).
611     """
612     srcbuf = io.BytesIO(src)
613     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
614     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
615     srcbuf.seek(0)
616     return newline, encoding, io.TextIOWrapper(srcbuf, encoding).read()
617
618
619 GRAMMARS = [
620     pygram.python_grammar_no_print_statement_no_exec_statement,
621     pygram.python_grammar_no_print_statement,
622     pygram.python_grammar,
623 ]
624
625
626 def lib2to3_parse(src_txt: str) -> Node:
627     """Given a string with source, return the lib2to3 Node."""
628     grammar = pygram.python_grammar_no_print_statement
629     if src_txt[-1] != "\n":
630         src_txt += "\n"
631     for grammar in GRAMMARS:
632         drv = driver.Driver(grammar, pytree.convert)
633         try:
634             result = drv.parse_string(src_txt, True)
635             break
636
637         except ParseError as pe:
638             lineno, column = pe.context[1]
639             lines = src_txt.splitlines()
640             try:
641                 faulty_line = lines[lineno - 1]
642             except IndexError:
643                 faulty_line = "<line number missing in source>"
644             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
645     else:
646         raise exc from None
647
648     if isinstance(result, Leaf):
649         result = Node(syms.file_input, [result])
650     return result
651
652
653 def lib2to3_unparse(node: Node) -> str:
654     """Given a lib2to3 node, return its string representation."""
655     code = str(node)
656     return code
657
658
659 T = TypeVar("T")
660
661
662 class Visitor(Generic[T]):
663     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
664
665     def visit(self, node: LN) -> Iterator[T]:
666         """Main method to visit `node` and its children.
667
668         It tries to find a `visit_*()` method for the given `node.type`, like
669         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
670         If no dedicated `visit_*()` method is found, chooses `visit_default()`
671         instead.
672
673         Then yields objects of type `T` from the selected visitor.
674         """
675         if node.type < 256:
676             name = token.tok_name[node.type]
677         else:
678             name = type_repr(node.type)
679         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
680
681     def visit_default(self, node: LN) -> Iterator[T]:
682         """Default `visit_*()` implementation. Recurses to children of `node`."""
683         if isinstance(node, Node):
684             for child in node.children:
685                 yield from self.visit(child)
686
687
688 @dataclass
689 class DebugVisitor(Visitor[T]):
690     tree_depth: int = 0
691
692     def visit_default(self, node: LN) -> Iterator[T]:
693         indent = " " * (2 * self.tree_depth)
694         if isinstance(node, Node):
695             _type = type_repr(node.type)
696             out(f"{indent}{_type}", fg="yellow")
697             self.tree_depth += 1
698             for child in node.children:
699                 yield from self.visit(child)
700
701             self.tree_depth -= 1
702             out(f"{indent}/{_type}", fg="yellow", bold=False)
703         else:
704             _type = token.tok_name.get(node.type, str(node.type))
705             out(f"{indent}{_type}", fg="blue", nl=False)
706             if node.prefix:
707                 # We don't have to handle prefixes for `Node` objects since
708                 # that delegates to the first child anyway.
709                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
710             out(f" {node.value!r}", fg="blue", bold=False)
711
712     @classmethod
713     def show(cls, code: str) -> None:
714         """Pretty-print the lib2to3 AST of a given string of `code`.
715
716         Convenience method for debugging.
717         """
718         v: DebugVisitor[None] = DebugVisitor()
719         list(v.visit(lib2to3_parse(code)))
720
721
722 KEYWORDS = set(keyword.kwlist)
723 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
724 FLOW_CONTROL = {"return", "raise", "break", "continue"}
725 STATEMENT = {
726     syms.if_stmt,
727     syms.while_stmt,
728     syms.for_stmt,
729     syms.try_stmt,
730     syms.except_clause,
731     syms.with_stmt,
732     syms.funcdef,
733     syms.classdef,
734 }
735 STANDALONE_COMMENT = 153
736 LOGIC_OPERATORS = {"and", "or"}
737 COMPARATORS = {
738     token.LESS,
739     token.GREATER,
740     token.EQEQUAL,
741     token.NOTEQUAL,
742     token.LESSEQUAL,
743     token.GREATEREQUAL,
744 }
745 MATH_OPERATORS = {
746     token.VBAR,
747     token.CIRCUMFLEX,
748     token.AMPER,
749     token.LEFTSHIFT,
750     token.RIGHTSHIFT,
751     token.PLUS,
752     token.MINUS,
753     token.STAR,
754     token.SLASH,
755     token.DOUBLESLASH,
756     token.PERCENT,
757     token.AT,
758     token.TILDE,
759     token.DOUBLESTAR,
760 }
761 STARS = {token.STAR, token.DOUBLESTAR}
762 VARARGS_PARENTS = {
763     syms.arglist,
764     syms.argument,  # double star in arglist
765     syms.trailer,  # single argument to call
766     syms.typedargslist,
767     syms.varargslist,  # lambdas
768 }
769 UNPACKING_PARENTS = {
770     syms.atom,  # single element of a list or set literal
771     syms.dictsetmaker,
772     syms.listmaker,
773     syms.testlist_gexp,
774 }
775 TEST_DESCENDANTS = {
776     syms.test,
777     syms.lambdef,
778     syms.or_test,
779     syms.and_test,
780     syms.not_test,
781     syms.comparison,
782     syms.star_expr,
783     syms.expr,
784     syms.xor_expr,
785     syms.and_expr,
786     syms.shift_expr,
787     syms.arith_expr,
788     syms.trailer,
789     syms.term,
790     syms.power,
791 }
792 ASSIGNMENTS = {
793     "=",
794     "+=",
795     "-=",
796     "*=",
797     "@=",
798     "/=",
799     "%=",
800     "&=",
801     "|=",
802     "^=",
803     "<<=",
804     ">>=",
805     "**=",
806     "//=",
807 }
808 COMPREHENSION_PRIORITY = 20
809 COMMA_PRIORITY = 18
810 TERNARY_PRIORITY = 16
811 LOGIC_PRIORITY = 14
812 STRING_PRIORITY = 12
813 COMPARATOR_PRIORITY = 10
814 MATH_PRIORITIES = {
815     token.VBAR: 9,
816     token.CIRCUMFLEX: 8,
817     token.AMPER: 7,
818     token.LEFTSHIFT: 6,
819     token.RIGHTSHIFT: 6,
820     token.PLUS: 5,
821     token.MINUS: 5,
822     token.STAR: 4,
823     token.SLASH: 4,
824     token.DOUBLESLASH: 4,
825     token.PERCENT: 4,
826     token.AT: 4,
827     token.TILDE: 3,
828     token.DOUBLESTAR: 2,
829 }
830 DOT_PRIORITY = 1
831
832
833 @dataclass
834 class BracketTracker:
835     """Keeps track of brackets on a line."""
836
837     depth: int = 0
838     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
839     delimiters: Dict[LeafID, Priority] = Factory(dict)
840     previous: Optional[Leaf] = None
841     _for_loop_variable: int = 0
842     _lambda_arguments: int = 0
843
844     def mark(self, leaf: Leaf) -> None:
845         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
846
847         All leaves receive an int `bracket_depth` field that stores how deep
848         within brackets a given leaf is. 0 means there are no enclosing brackets
849         that started on this line.
850
851         If a leaf is itself a closing bracket, it receives an `opening_bracket`
852         field that it forms a pair with. This is a one-directional link to
853         avoid reference cycles.
854
855         If a leaf is a delimiter (a token on which Black can split the line if
856         needed) and it's on depth 0, its `id()` is stored in the tracker's
857         `delimiters` field.
858         """
859         if leaf.type == token.COMMENT:
860             return
861
862         self.maybe_decrement_after_for_loop_variable(leaf)
863         self.maybe_decrement_after_lambda_arguments(leaf)
864         if leaf.type in CLOSING_BRACKETS:
865             self.depth -= 1
866             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
867             leaf.opening_bracket = opening_bracket
868         leaf.bracket_depth = self.depth
869         if self.depth == 0:
870             delim = is_split_before_delimiter(leaf, self.previous)
871             if delim and self.previous is not None:
872                 self.delimiters[id(self.previous)] = delim
873             else:
874                 delim = is_split_after_delimiter(leaf, self.previous)
875                 if delim:
876                     self.delimiters[id(leaf)] = delim
877         if leaf.type in OPENING_BRACKETS:
878             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
879             self.depth += 1
880         self.previous = leaf
881         self.maybe_increment_lambda_arguments(leaf)
882         self.maybe_increment_for_loop_variable(leaf)
883
884     def any_open_brackets(self) -> bool:
885         """Return True if there is an yet unmatched open bracket on the line."""
886         return bool(self.bracket_match)
887
888     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
889         """Return the highest priority of a delimiter found on the line.
890
891         Values are consistent with what `is_split_*_delimiter()` return.
892         Raises ValueError on no delimiters.
893         """
894         return max(v for k, v in self.delimiters.items() if k not in exclude)
895
896     def delimiter_count_with_priority(self, priority: int = 0) -> int:
897         """Return the number of delimiters with the given `priority`.
898
899         If no `priority` is passed, defaults to max priority on the line.
900         """
901         if not self.delimiters:
902             return 0
903
904         priority = priority or self.max_delimiter_priority()
905         return sum(1 for p in self.delimiters.values() if p == priority)
906
907     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
908         """In a for loop, or comprehension, the variables are often unpacks.
909
910         To avoid splitting on the comma in this situation, increase the depth of
911         tokens between `for` and `in`.
912         """
913         if leaf.type == token.NAME and leaf.value == "for":
914             self.depth += 1
915             self._for_loop_variable += 1
916             return True
917
918         return False
919
920     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
921         """See `maybe_increment_for_loop_variable` above for explanation."""
922         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
923             self.depth -= 1
924             self._for_loop_variable -= 1
925             return True
926
927         return False
928
929     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
930         """In a lambda expression, there might be more than one argument.
931
932         To avoid splitting on the comma in this situation, increase the depth of
933         tokens between `lambda` and `:`.
934         """
935         if leaf.type == token.NAME and leaf.value == "lambda":
936             self.depth += 1
937             self._lambda_arguments += 1
938             return True
939
940         return False
941
942     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
943         """See `maybe_increment_lambda_arguments` above for explanation."""
944         if self._lambda_arguments and leaf.type == token.COLON:
945             self.depth -= 1
946             self._lambda_arguments -= 1
947             return True
948
949         return False
950
951     def get_open_lsqb(self) -> Optional[Leaf]:
952         """Return the most recent opening square bracket (if any)."""
953         return self.bracket_match.get((self.depth - 1, token.RSQB))
954
955
956 @dataclass
957 class Line:
958     """Holds leaves and comments. Can be printed with `str(line)`."""
959
960     depth: int = 0
961     leaves: List[Leaf] = Factory(list)
962     comments: List[Tuple[Index, Leaf]] = Factory(list)
963     bracket_tracker: BracketTracker = Factory(BracketTracker)
964     inside_brackets: bool = False
965     should_explode: bool = False
966
967     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
968         """Add a new `leaf` to the end of the line.
969
970         Unless `preformatted` is True, the `leaf` will receive a new consistent
971         whitespace prefix and metadata applied by :class:`BracketTracker`.
972         Trailing commas are maybe removed, unpacked for loop variables are
973         demoted from being delimiters.
974
975         Inline comments are put aside.
976         """
977         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
978         if not has_value:
979             return
980
981         if token.COLON == leaf.type and self.is_class_paren_empty:
982             del self.leaves[-2:]
983         if self.leaves and not preformatted:
984             # Note: at this point leaf.prefix should be empty except for
985             # imports, for which we only preserve newlines.
986             leaf.prefix += whitespace(
987                 leaf, complex_subscript=self.is_complex_subscript(leaf)
988             )
989         if self.inside_brackets or not preformatted:
990             self.bracket_tracker.mark(leaf)
991             self.maybe_remove_trailing_comma(leaf)
992         if not self.append_comment(leaf):
993             self.leaves.append(leaf)
994
995     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
996         """Like :func:`append()` but disallow invalid standalone comment structure.
997
998         Raises ValueError when any `leaf` is appended after a standalone comment
999         or when a standalone comment is not the first leaf on the line.
1000         """
1001         if self.bracket_tracker.depth == 0:
1002             if self.is_comment:
1003                 raise ValueError("cannot append to standalone comments")
1004
1005             if self.leaves and leaf.type == STANDALONE_COMMENT:
1006                 raise ValueError(
1007                     "cannot append standalone comments to a populated line"
1008                 )
1009
1010         self.append(leaf, preformatted=preformatted)
1011
1012     @property
1013     def is_comment(self) -> bool:
1014         """Is this line a standalone comment?"""
1015         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1016
1017     @property
1018     def is_decorator(self) -> bool:
1019         """Is this line a decorator?"""
1020         return bool(self) and self.leaves[0].type == token.AT
1021
1022     @property
1023     def is_import(self) -> bool:
1024         """Is this an import line?"""
1025         return bool(self) and is_import(self.leaves[0])
1026
1027     @property
1028     def is_class(self) -> bool:
1029         """Is this line a class definition?"""
1030         return (
1031             bool(self)
1032             and self.leaves[0].type == token.NAME
1033             and self.leaves[0].value == "class"
1034         )
1035
1036     @property
1037     def is_stub_class(self) -> bool:
1038         """Is this line a class definition with a body consisting only of "..."?"""
1039         return self.is_class and self.leaves[-3:] == [
1040             Leaf(token.DOT, ".") for _ in range(3)
1041         ]
1042
1043     @property
1044     def is_def(self) -> bool:
1045         """Is this a function definition? (Also returns True for async defs.)"""
1046         try:
1047             first_leaf = self.leaves[0]
1048         except IndexError:
1049             return False
1050
1051         try:
1052             second_leaf: Optional[Leaf] = self.leaves[1]
1053         except IndexError:
1054             second_leaf = None
1055         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1056             first_leaf.type == token.ASYNC
1057             and second_leaf is not None
1058             and second_leaf.type == token.NAME
1059             and second_leaf.value == "def"
1060         )
1061
1062     @property
1063     def is_class_paren_empty(self) -> bool:
1064         """Is this a class with no base classes but using parentheses?
1065
1066         Those are unnecessary and should be removed.
1067         """
1068         return (
1069             bool(self)
1070             and len(self.leaves) == 4
1071             and self.is_class
1072             and self.leaves[2].type == token.LPAR
1073             and self.leaves[2].value == "("
1074             and self.leaves[3].type == token.RPAR
1075             and self.leaves[3].value == ")"
1076         )
1077
1078     @property
1079     def is_triple_quoted_string(self) -> bool:
1080         """Is the line a triple quoted string?"""
1081         return (
1082             bool(self)
1083             and self.leaves[0].type == token.STRING
1084             and self.leaves[0].value.startswith(('"""', "'''"))
1085         )
1086
1087     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1088         """If so, needs to be split before emitting."""
1089         for leaf in self.leaves:
1090             if leaf.type == STANDALONE_COMMENT:
1091                 if leaf.bracket_depth <= depth_limit:
1092                     return True
1093
1094         return False
1095
1096     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1097         """Remove trailing comma if there is one and it's safe."""
1098         if not (
1099             self.leaves
1100             and self.leaves[-1].type == token.COMMA
1101             and closing.type in CLOSING_BRACKETS
1102         ):
1103             return False
1104
1105         if closing.type == token.RBRACE:
1106             self.remove_trailing_comma()
1107             return True
1108
1109         if closing.type == token.RSQB:
1110             comma = self.leaves[-1]
1111             if comma.parent and comma.parent.type == syms.listmaker:
1112                 self.remove_trailing_comma()
1113                 return True
1114
1115         # For parens let's check if it's safe to remove the comma.
1116         # Imports are always safe.
1117         if self.is_import:
1118             self.remove_trailing_comma()
1119             return True
1120
1121         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1122         # change a tuple into a different type by removing the comma.
1123         depth = closing.bracket_depth + 1
1124         commas = 0
1125         opening = closing.opening_bracket
1126         for _opening_index, leaf in enumerate(self.leaves):
1127             if leaf is opening:
1128                 break
1129
1130         else:
1131             return False
1132
1133         for leaf in self.leaves[_opening_index + 1 :]:
1134             if leaf is closing:
1135                 break
1136
1137             bracket_depth = leaf.bracket_depth
1138             if bracket_depth == depth and leaf.type == token.COMMA:
1139                 commas += 1
1140                 if leaf.parent and leaf.parent.type == syms.arglist:
1141                     commas += 1
1142                     break
1143
1144         if commas > 1:
1145             self.remove_trailing_comma()
1146             return True
1147
1148         return False
1149
1150     def append_comment(self, comment: Leaf) -> bool:
1151         """Add an inline or standalone comment to the line."""
1152         if (
1153             comment.type == STANDALONE_COMMENT
1154             and self.bracket_tracker.any_open_brackets()
1155         ):
1156             comment.prefix = ""
1157             return False
1158
1159         if comment.type != token.COMMENT:
1160             return False
1161
1162         after = len(self.leaves) - 1
1163         if after == -1:
1164             comment.type = STANDALONE_COMMENT
1165             comment.prefix = ""
1166             return False
1167
1168         else:
1169             self.comments.append((after, comment))
1170             return True
1171
1172     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1173         """Generate comments that should appear directly after `leaf`.
1174
1175         Provide a non-negative leaf `_index` to speed up the function.
1176         """
1177         if _index == -1:
1178             for _index, _leaf in enumerate(self.leaves):
1179                 if leaf is _leaf:
1180                     break
1181
1182             else:
1183                 return
1184
1185         for index, comment_after in self.comments:
1186             if _index == index:
1187                 yield comment_after
1188
1189     def remove_trailing_comma(self) -> None:
1190         """Remove the trailing comma and moves the comments attached to it."""
1191         comma_index = len(self.leaves) - 1
1192         for i in range(len(self.comments)):
1193             comment_index, comment = self.comments[i]
1194             if comment_index == comma_index:
1195                 self.comments[i] = (comma_index - 1, comment)
1196         self.leaves.pop()
1197
1198     def is_complex_subscript(self, leaf: Leaf) -> bool:
1199         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1200         open_lsqb = (
1201             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1202         )
1203         if open_lsqb is None:
1204             return False
1205
1206         subscript_start = open_lsqb.next_sibling
1207         if (
1208             isinstance(subscript_start, Node)
1209             and subscript_start.type == syms.subscriptlist
1210         ):
1211             subscript_start = child_towards(subscript_start, leaf)
1212         return subscript_start is not None and any(
1213             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1214         )
1215
1216     def __str__(self) -> str:
1217         """Render the line."""
1218         if not self:
1219             return "\n"
1220
1221         indent = "    " * self.depth
1222         leaves = iter(self.leaves)
1223         first = next(leaves)
1224         res = f"{first.prefix}{indent}{first.value}"
1225         for leaf in leaves:
1226             res += str(leaf)
1227         for _, comment in self.comments:
1228             res += str(comment)
1229         return res + "\n"
1230
1231     def __bool__(self) -> bool:
1232         """Return True if the line has leaves or comments."""
1233         return bool(self.leaves or self.comments)
1234
1235
1236 class UnformattedLines(Line):
1237     """Just like :class:`Line` but stores lines which aren't reformatted."""
1238
1239     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1240         """Just add a new `leaf` to the end of the lines.
1241
1242         The `preformatted` argument is ignored.
1243
1244         Keeps track of indentation `depth`, which is useful when the user
1245         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1246         """
1247         try:
1248             list(generate_comments(leaf))
1249         except FormatOn as f_on:
1250             self.leaves.append(f_on.leaf_from_consumed(leaf))
1251             raise
1252
1253         self.leaves.append(leaf)
1254         if leaf.type == token.INDENT:
1255             self.depth += 1
1256         elif leaf.type == token.DEDENT:
1257             self.depth -= 1
1258
1259     def __str__(self) -> str:
1260         """Render unformatted lines from leaves which were added with `append()`.
1261
1262         `depth` is not used for indentation in this case.
1263         """
1264         if not self:
1265             return "\n"
1266
1267         res = ""
1268         for leaf in self.leaves:
1269             res += str(leaf)
1270         return res
1271
1272     def append_comment(self, comment: Leaf) -> bool:
1273         """Not implemented in this class. Raises `NotImplementedError`."""
1274         raise NotImplementedError("Unformatted lines don't store comments separately.")
1275
1276     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1277         """Does nothing and returns False."""
1278         return False
1279
1280     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1281         """Does nothing and returns False."""
1282         return False
1283
1284
1285 @dataclass
1286 class EmptyLineTracker:
1287     """Provides a stateful method that returns the number of potential extra
1288     empty lines needed before and after the currently processed line.
1289
1290     Note: this tracker works on lines that haven't been split yet.  It assumes
1291     the prefix of the first leaf consists of optional newlines.  Those newlines
1292     are consumed by `maybe_empty_lines()` and included in the computation.
1293     """
1294
1295     is_pyi: bool = False
1296     previous_line: Optional[Line] = None
1297     previous_after: int = 0
1298     previous_defs: List[int] = Factory(list)
1299
1300     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1301         """Return the number of extra empty lines before and after the `current_line`.
1302
1303         This is for separating `def`, `async def` and `class` with extra empty
1304         lines (two on module-level).
1305         """
1306         if isinstance(current_line, UnformattedLines):
1307             return 0, 0
1308
1309         before, after = self._maybe_empty_lines(current_line)
1310         before -= self.previous_after
1311         self.previous_after = after
1312         self.previous_line = current_line
1313         return before, after
1314
1315     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1316         max_allowed = 1
1317         if current_line.depth == 0:
1318             max_allowed = 1 if self.is_pyi else 2
1319         if current_line.leaves:
1320             # Consume the first leaf's extra newlines.
1321             first_leaf = current_line.leaves[0]
1322             before = first_leaf.prefix.count("\n")
1323             before = min(before, max_allowed)
1324             first_leaf.prefix = ""
1325         else:
1326             before = 0
1327         depth = current_line.depth
1328         while self.previous_defs and self.previous_defs[-1] >= depth:
1329             self.previous_defs.pop()
1330             if self.is_pyi:
1331                 before = 0 if depth else 1
1332             else:
1333                 before = 1 if depth else 2
1334         is_decorator = current_line.is_decorator
1335         if is_decorator or current_line.is_def or current_line.is_class:
1336             if not is_decorator:
1337                 self.previous_defs.append(depth)
1338             if self.previous_line is None:
1339                 # Don't insert empty lines before the first line in the file.
1340                 return 0, 0
1341
1342             if self.previous_line.is_decorator:
1343                 return 0, 0
1344
1345             if self.previous_line.depth < current_line.depth and (
1346                 self.previous_line.is_class or self.previous_line.is_def
1347             ):
1348                 return 0, 0
1349
1350             if (
1351                 self.previous_line.is_comment
1352                 and self.previous_line.depth == current_line.depth
1353                 and before == 0
1354             ):
1355                 return 0, 0
1356
1357             if self.is_pyi:
1358                 if self.previous_line.depth > current_line.depth:
1359                     newlines = 1
1360                 elif current_line.is_class or self.previous_line.is_class:
1361                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1362                         newlines = 0
1363                     else:
1364                         newlines = 1
1365                 else:
1366                     newlines = 0
1367             else:
1368                 newlines = 2
1369             if current_line.depth and newlines:
1370                 newlines -= 1
1371             return newlines, 0
1372
1373         if (
1374             self.previous_line
1375             and self.previous_line.is_import
1376             and not current_line.is_import
1377             and depth == self.previous_line.depth
1378         ):
1379             return (before or 1), 0
1380
1381         if (
1382             self.previous_line
1383             and self.previous_line.is_class
1384             and current_line.is_triple_quoted_string
1385         ):
1386             return before, 1
1387
1388         return before, 0
1389
1390
1391 @dataclass
1392 class LineGenerator(Visitor[Line]):
1393     """Generates reformatted Line objects.  Empty lines are not emitted.
1394
1395     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1396     in ways that will no longer stringify to valid Python code on the tree.
1397     """
1398
1399     is_pyi: bool = False
1400     normalize_strings: bool = True
1401     current_line: Line = Factory(Line)
1402     remove_u_prefix: bool = False
1403
1404     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1405         """Generate a line.
1406
1407         If the line is empty, only emit if it makes sense.
1408         If the line is too long, split it first and then generate.
1409
1410         If any lines were generated, set up a new current_line.
1411         """
1412         if not self.current_line:
1413             if self.current_line.__class__ == type:
1414                 self.current_line.depth += indent
1415             else:
1416                 self.current_line = type(depth=self.current_line.depth + indent)
1417             return  # Line is empty, don't emit. Creating a new one unnecessary.
1418
1419         complete_line = self.current_line
1420         self.current_line = type(depth=complete_line.depth + indent)
1421         yield complete_line
1422
1423     def visit(self, node: LN) -> Iterator[Line]:
1424         """Main method to visit `node` and its children.
1425
1426         Yields :class:`Line` objects.
1427         """
1428         if isinstance(self.current_line, UnformattedLines):
1429             # File contained `# fmt: off`
1430             yield from self.visit_unformatted(node)
1431
1432         else:
1433             yield from super().visit(node)
1434
1435     def visit_default(self, node: LN) -> Iterator[Line]:
1436         """Default `visit_*()` implementation. Recurses to children of `node`."""
1437         if isinstance(node, Leaf):
1438             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1439             try:
1440                 for comment in generate_comments(node):
1441                     if any_open_brackets:
1442                         # any comment within brackets is subject to splitting
1443                         self.current_line.append(comment)
1444                     elif comment.type == token.COMMENT:
1445                         # regular trailing comment
1446                         self.current_line.append(comment)
1447                         yield from self.line()
1448
1449                     else:
1450                         # regular standalone comment
1451                         yield from self.line()
1452
1453                         self.current_line.append(comment)
1454                         yield from self.line()
1455
1456             except FormatOff as f_off:
1457                 f_off.trim_prefix(node)
1458                 yield from self.line(type=UnformattedLines)
1459                 yield from self.visit(node)
1460
1461             except FormatOn as f_on:
1462                 # This only happens here if somebody says "fmt: on" multiple
1463                 # times in a row.
1464                 f_on.trim_prefix(node)
1465                 yield from self.visit_default(node)
1466
1467             else:
1468                 normalize_prefix(node, inside_brackets=any_open_brackets)
1469                 if self.normalize_strings and node.type == token.STRING:
1470                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1471                     normalize_string_quotes(node)
1472                 if node.type not in WHITESPACE:
1473                     self.current_line.append(node)
1474         yield from super().visit_default(node)
1475
1476     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1477         """Increase indentation level, maybe yield a line."""
1478         # In blib2to3 INDENT never holds comments.
1479         yield from self.line(+1)
1480         yield from self.visit_default(node)
1481
1482     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1483         """Decrease indentation level, maybe yield a line."""
1484         # The current line might still wait for trailing comments.  At DEDENT time
1485         # there won't be any (they would be prefixes on the preceding NEWLINE).
1486         # Emit the line then.
1487         yield from self.line()
1488
1489         # While DEDENT has no value, its prefix may contain standalone comments
1490         # that belong to the current indentation level.  Get 'em.
1491         yield from self.visit_default(node)
1492
1493         # Finally, emit the dedent.
1494         yield from self.line(-1)
1495
1496     def visit_stmt(
1497         self, node: Node, keywords: Set[str], parens: Set[str]
1498     ) -> Iterator[Line]:
1499         """Visit a statement.
1500
1501         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1502         `def`, `with`, `class`, `assert` and assignments.
1503
1504         The relevant Python language `keywords` for a given statement will be
1505         NAME leaves within it. This methods puts those on a separate line.
1506
1507         `parens` holds a set of string leaf values immediately after which
1508         invisible parens should be put.
1509         """
1510         normalize_invisible_parens(node, parens_after=parens)
1511         for child in node.children:
1512             if child.type == token.NAME and child.value in keywords:  # type: ignore
1513                 yield from self.line()
1514
1515             yield from self.visit(child)
1516
1517     def visit_suite(self, node: Node) -> Iterator[Line]:
1518         """Visit a suite."""
1519         if self.is_pyi and is_stub_suite(node):
1520             yield from self.visit(node.children[2])
1521         else:
1522             yield from self.visit_default(node)
1523
1524     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1525         """Visit a statement without nested statements."""
1526         is_suite_like = node.parent and node.parent.type in STATEMENT
1527         if is_suite_like:
1528             if self.is_pyi and is_stub_body(node):
1529                 yield from self.visit_default(node)
1530             else:
1531                 yield from self.line(+1)
1532                 yield from self.visit_default(node)
1533                 yield from self.line(-1)
1534
1535         else:
1536             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1537                 yield from self.line()
1538             yield from self.visit_default(node)
1539
1540     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1541         """Visit `async def`, `async for`, `async with`."""
1542         yield from self.line()
1543
1544         children = iter(node.children)
1545         for child in children:
1546             yield from self.visit(child)
1547
1548             if child.type == token.ASYNC:
1549                 break
1550
1551         internal_stmt = next(children)
1552         for child in internal_stmt.children:
1553             yield from self.visit(child)
1554
1555     def visit_decorators(self, node: Node) -> Iterator[Line]:
1556         """Visit decorators."""
1557         for child in node.children:
1558             yield from self.line()
1559             yield from self.visit(child)
1560
1561     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1562         """Remove a semicolon and put the other statement on a separate line."""
1563         yield from self.line()
1564
1565     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1566         """End of file. Process outstanding comments and end with a newline."""
1567         yield from self.visit_default(leaf)
1568         yield from self.line()
1569
1570     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1571         """Used when file contained a `# fmt: off`."""
1572         if isinstance(node, Node):
1573             for child in node.children:
1574                 yield from self.visit(child)
1575
1576         else:
1577             try:
1578                 self.current_line.append(node)
1579             except FormatOn as f_on:
1580                 f_on.trim_prefix(node)
1581                 yield from self.line()
1582                 yield from self.visit(node)
1583
1584             if node.type == token.ENDMARKER:
1585                 # somebody decided not to put a final `# fmt: on`
1586                 yield from self.line()
1587
1588     def __attrs_post_init__(self) -> None:
1589         """You are in a twisty little maze of passages."""
1590         v = self.visit_stmt
1591         Ø: Set[str] = set()
1592         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1593         self.visit_if_stmt = partial(
1594             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1595         )
1596         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1597         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1598         self.visit_try_stmt = partial(
1599             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1600         )
1601         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1602         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1603         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1604         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1605         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1606         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1607         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1608         self.visit_async_funcdef = self.visit_async_stmt
1609         self.visit_decorated = self.visit_decorators
1610
1611
1612 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1613 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1614 OPENING_BRACKETS = set(BRACKET.keys())
1615 CLOSING_BRACKETS = set(BRACKET.values())
1616 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1617 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1618
1619
1620 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1621     """Return whitespace prefix if needed for the given `leaf`.
1622
1623     `complex_subscript` signals whether the given leaf is part of a subscription
1624     which has non-trivial arguments, like arithmetic expressions or function calls.
1625     """
1626     NO = ""
1627     SPACE = " "
1628     DOUBLESPACE = "  "
1629     t = leaf.type
1630     p = leaf.parent
1631     v = leaf.value
1632     if t in ALWAYS_NO_SPACE:
1633         return NO
1634
1635     if t == token.COMMENT:
1636         return DOUBLESPACE
1637
1638     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1639     if t == token.COLON and p.type not in {
1640         syms.subscript,
1641         syms.subscriptlist,
1642         syms.sliceop,
1643     }:
1644         return NO
1645
1646     prev = leaf.prev_sibling
1647     if not prev:
1648         prevp = preceding_leaf(p)
1649         if not prevp or prevp.type in OPENING_BRACKETS:
1650             return NO
1651
1652         if t == token.COLON:
1653             if prevp.type == token.COLON:
1654                 return NO
1655
1656             elif prevp.type != token.COMMA and not complex_subscript:
1657                 return NO
1658
1659             return SPACE
1660
1661         if prevp.type == token.EQUAL:
1662             if prevp.parent:
1663                 if prevp.parent.type in {
1664                     syms.arglist,
1665                     syms.argument,
1666                     syms.parameters,
1667                     syms.varargslist,
1668                 }:
1669                     return NO
1670
1671                 elif prevp.parent.type == syms.typedargslist:
1672                     # A bit hacky: if the equal sign has whitespace, it means we
1673                     # previously found it's a typed argument.  So, we're using
1674                     # that, too.
1675                     return prevp.prefix
1676
1677         elif prevp.type in STARS:
1678             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1679                 return NO
1680
1681         elif prevp.type == token.COLON:
1682             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1683                 return SPACE if complex_subscript else NO
1684
1685         elif (
1686             prevp.parent
1687             and prevp.parent.type == syms.factor
1688             and prevp.type in MATH_OPERATORS
1689         ):
1690             return NO
1691
1692         elif (
1693             prevp.type == token.RIGHTSHIFT
1694             and prevp.parent
1695             and prevp.parent.type == syms.shift_expr
1696             and prevp.prev_sibling
1697             and prevp.prev_sibling.type == token.NAME
1698             and prevp.prev_sibling.value == "print"  # type: ignore
1699         ):
1700             # Python 2 print chevron
1701             return NO
1702
1703     elif prev.type in OPENING_BRACKETS:
1704         return NO
1705
1706     if p.type in {syms.parameters, syms.arglist}:
1707         # untyped function signatures or calls
1708         if not prev or prev.type != token.COMMA:
1709             return NO
1710
1711     elif p.type == syms.varargslist:
1712         # lambdas
1713         if prev and prev.type != token.COMMA:
1714             return NO
1715
1716     elif p.type == syms.typedargslist:
1717         # typed function signatures
1718         if not prev:
1719             return NO
1720
1721         if t == token.EQUAL:
1722             if prev.type != syms.tname:
1723                 return NO
1724
1725         elif prev.type == token.EQUAL:
1726             # A bit hacky: if the equal sign has whitespace, it means we
1727             # previously found it's a typed argument.  So, we're using that, too.
1728             return prev.prefix
1729
1730         elif prev.type != token.COMMA:
1731             return NO
1732
1733     elif p.type == syms.tname:
1734         # type names
1735         if not prev:
1736             prevp = preceding_leaf(p)
1737             if not prevp or prevp.type != token.COMMA:
1738                 return NO
1739
1740     elif p.type == syms.trailer:
1741         # attributes and calls
1742         if t == token.LPAR or t == token.RPAR:
1743             return NO
1744
1745         if not prev:
1746             if t == token.DOT:
1747                 prevp = preceding_leaf(p)
1748                 if not prevp or prevp.type != token.NUMBER:
1749                     return NO
1750
1751             elif t == token.LSQB:
1752                 return NO
1753
1754         elif prev.type != token.COMMA:
1755             return NO
1756
1757     elif p.type == syms.argument:
1758         # single argument
1759         if t == token.EQUAL:
1760             return NO
1761
1762         if not prev:
1763             prevp = preceding_leaf(p)
1764             if not prevp or prevp.type == token.LPAR:
1765                 return NO
1766
1767         elif prev.type in {token.EQUAL} | STARS:
1768             return NO
1769
1770     elif p.type == syms.decorator:
1771         # decorators
1772         return NO
1773
1774     elif p.type == syms.dotted_name:
1775         if prev:
1776             return NO
1777
1778         prevp = preceding_leaf(p)
1779         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1780             return NO
1781
1782     elif p.type == syms.classdef:
1783         if t == token.LPAR:
1784             return NO
1785
1786         if prev and prev.type == token.LPAR:
1787             return NO
1788
1789     elif p.type in {syms.subscript, syms.sliceop}:
1790         # indexing
1791         if not prev:
1792             assert p.parent is not None, "subscripts are always parented"
1793             if p.parent.type == syms.subscriptlist:
1794                 return SPACE
1795
1796             return NO
1797
1798         elif not complex_subscript:
1799             return NO
1800
1801     elif p.type == syms.atom:
1802         if prev and t == token.DOT:
1803             # dots, but not the first one.
1804             return NO
1805
1806     elif p.type == syms.dictsetmaker:
1807         # dict unpacking
1808         if prev and prev.type == token.DOUBLESTAR:
1809             return NO
1810
1811     elif p.type in {syms.factor, syms.star_expr}:
1812         # unary ops
1813         if not prev:
1814             prevp = preceding_leaf(p)
1815             if not prevp or prevp.type in OPENING_BRACKETS:
1816                 return NO
1817
1818             prevp_parent = prevp.parent
1819             assert prevp_parent is not None
1820             if prevp.type == token.COLON and prevp_parent.type in {
1821                 syms.subscript,
1822                 syms.sliceop,
1823             }:
1824                 return NO
1825
1826             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1827                 return NO
1828
1829         elif t == token.NAME or t == token.NUMBER:
1830             return NO
1831
1832     elif p.type == syms.import_from:
1833         if t == token.DOT:
1834             if prev and prev.type == token.DOT:
1835                 return NO
1836
1837         elif t == token.NAME:
1838             if v == "import":
1839                 return SPACE
1840
1841             if prev and prev.type == token.DOT:
1842                 return NO
1843
1844     elif p.type == syms.sliceop:
1845         return NO
1846
1847     return SPACE
1848
1849
1850 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1851     """Return the first leaf that precedes `node`, if any."""
1852     while node:
1853         res = node.prev_sibling
1854         if res:
1855             if isinstance(res, Leaf):
1856                 return res
1857
1858             try:
1859                 return list(res.leaves())[-1]
1860
1861             except IndexError:
1862                 return None
1863
1864         node = node.parent
1865     return None
1866
1867
1868 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1869     """Return the child of `ancestor` that contains `descendant`."""
1870     node: Optional[LN] = descendant
1871     while node and node.parent != ancestor:
1872         node = node.parent
1873     return node
1874
1875
1876 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1877     """Return the priority of the `leaf` delimiter, given a line break after it.
1878
1879     The delimiter priorities returned here are from those delimiters that would
1880     cause a line break after themselves.
1881
1882     Higher numbers are higher priority.
1883     """
1884     if leaf.type == token.COMMA:
1885         return COMMA_PRIORITY
1886
1887     return 0
1888
1889
1890 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1891     """Return the priority of the `leaf` delimiter, given a line before after it.
1892
1893     The delimiter priorities returned here are from those delimiters that would
1894     cause a line break before themselves.
1895
1896     Higher numbers are higher priority.
1897     """
1898     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1899         # * and ** might also be MATH_OPERATORS but in this case they are not.
1900         # Don't treat them as a delimiter.
1901         return 0
1902
1903     if (
1904         leaf.type == token.DOT
1905         and leaf.parent
1906         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1907         and (previous is None or previous.type in CLOSING_BRACKETS)
1908     ):
1909         return DOT_PRIORITY
1910
1911     if (
1912         leaf.type in MATH_OPERATORS
1913         and leaf.parent
1914         and leaf.parent.type not in {syms.factor, syms.star_expr}
1915     ):
1916         return MATH_PRIORITIES[leaf.type]
1917
1918     if leaf.type in COMPARATORS:
1919         return COMPARATOR_PRIORITY
1920
1921     if (
1922         leaf.type == token.STRING
1923         and previous is not None
1924         and previous.type == token.STRING
1925     ):
1926         return STRING_PRIORITY
1927
1928     if leaf.type != token.NAME:
1929         return 0
1930
1931     if (
1932         leaf.value == "for"
1933         and leaf.parent
1934         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1935     ):
1936         return COMPREHENSION_PRIORITY
1937
1938     if (
1939         leaf.value == "if"
1940         and leaf.parent
1941         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1942     ):
1943         return COMPREHENSION_PRIORITY
1944
1945     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1946         return TERNARY_PRIORITY
1947
1948     if leaf.value == "is":
1949         return COMPARATOR_PRIORITY
1950
1951     if (
1952         leaf.value == "in"
1953         and leaf.parent
1954         and leaf.parent.type in {syms.comp_op, syms.comparison}
1955         and not (
1956             previous is not None
1957             and previous.type == token.NAME
1958             and previous.value == "not"
1959         )
1960     ):
1961         return COMPARATOR_PRIORITY
1962
1963     if (
1964         leaf.value == "not"
1965         and leaf.parent
1966         and leaf.parent.type == syms.comp_op
1967         and not (
1968             previous is not None
1969             and previous.type == token.NAME
1970             and previous.value == "is"
1971         )
1972     ):
1973         return COMPARATOR_PRIORITY
1974
1975     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1976         return LOGIC_PRIORITY
1977
1978     return 0
1979
1980
1981 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1982     """Clean the prefix of the `leaf` and generate comments from it, if any.
1983
1984     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1985     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1986     move because it does away with modifying the grammar to include all the
1987     possible places in which comments can be placed.
1988
1989     The sad consequence for us though is that comments don't "belong" anywhere.
1990     This is why this function generates simple parentless Leaf objects for
1991     comments.  We simply don't know what the correct parent should be.
1992
1993     No matter though, we can live without this.  We really only need to
1994     differentiate between inline and standalone comments.  The latter don't
1995     share the line with any code.
1996
1997     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1998     are emitted with a fake STANDALONE_COMMENT token identifier.
1999     """
2000     p = leaf.prefix
2001     if not p:
2002         return
2003
2004     if "#" not in p:
2005         return
2006
2007     consumed = 0
2008     nlines = 0
2009     for index, line in enumerate(p.split("\n")):
2010         consumed += len(line) + 1  # adding the length of the split '\n'
2011         line = line.lstrip()
2012         if not line:
2013             nlines += 1
2014         if not line.startswith("#"):
2015             continue
2016
2017         if index == 0 and leaf.type != token.ENDMARKER:
2018             comment_type = token.COMMENT  # simple trailing comment
2019         else:
2020             comment_type = STANDALONE_COMMENT
2021         comment = make_comment(line)
2022         yield Leaf(comment_type, comment, prefix="\n" * nlines)
2023
2024         if comment in {"# fmt: on", "# yapf: enable"}:
2025             raise FormatOn(consumed)
2026
2027         if comment in {"# fmt: off", "# yapf: disable"}:
2028             if comment_type == STANDALONE_COMMENT:
2029                 raise FormatOff(consumed)
2030
2031             prev = preceding_leaf(leaf)
2032             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
2033                 raise FormatOff(consumed)
2034
2035         nlines = 0
2036
2037
2038 def make_comment(content: str) -> str:
2039     """Return a consistently formatted comment from the given `content` string.
2040
2041     All comments (except for "##", "#!", "#:") should have a single space between
2042     the hash sign and the content.
2043
2044     If `content` didn't start with a hash sign, one is provided.
2045     """
2046     content = content.rstrip()
2047     if not content:
2048         return "#"
2049
2050     if content[0] == "#":
2051         content = content[1:]
2052     if content and content[0] not in " !:#":
2053         content = " " + content
2054     return "#" + content
2055
2056
2057 def split_line(
2058     line: Line, line_length: int, inner: bool = False, py36: bool = False
2059 ) -> Iterator[Line]:
2060     """Split a `line` into potentially many lines.
2061
2062     They should fit in the allotted `line_length` but might not be able to.
2063     `inner` signifies that there were a pair of brackets somewhere around the
2064     current `line`, possibly transitively. This means we can fallback to splitting
2065     by delimiters if the LHS/RHS don't yield any results.
2066
2067     If `py36` is True, splitting may generate syntax that is only compatible
2068     with Python 3.6 and later.
2069     """
2070     if isinstance(line, UnformattedLines) or line.is_comment:
2071         yield line
2072         return
2073
2074     line_str = str(line).strip("\n")
2075     if not line.should_explode and is_line_short_enough(
2076         line, line_length=line_length, line_str=line_str
2077     ):
2078         yield line
2079         return
2080
2081     split_funcs: List[SplitFunc]
2082     if line.is_def:
2083         split_funcs = [left_hand_split]
2084     else:
2085
2086         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2087             for omit in generate_trailers_to_omit(line, line_length):
2088                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2089                 if is_line_short_enough(lines[0], line_length=line_length):
2090                     yield from lines
2091                     return
2092
2093             # All splits failed, best effort split with no omits.
2094             # This mostly happens to multiline strings that are by definition
2095             # reported as not fitting a single line.
2096             yield from right_hand_split(line, py36)
2097
2098         if line.inside_brackets:
2099             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2100         else:
2101             split_funcs = [rhs]
2102     for split_func in split_funcs:
2103         # We are accumulating lines in `result` because we might want to abort
2104         # mission and return the original line in the end, or attempt a different
2105         # split altogether.
2106         result: List[Line] = []
2107         try:
2108             for l in split_func(line, py36):
2109                 if str(l).strip("\n") == line_str:
2110                     raise CannotSplit("Split function returned an unchanged result")
2111
2112                 result.extend(
2113                     split_line(l, line_length=line_length, inner=True, py36=py36)
2114                 )
2115         except CannotSplit as cs:
2116             continue
2117
2118         else:
2119             yield from result
2120             break
2121
2122     else:
2123         yield line
2124
2125
2126 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2127     """Split line into many lines, starting with the first matching bracket pair.
2128
2129     Note: this usually looks weird, only use this for function definitions.
2130     Prefer RHS otherwise.  This is why this function is not symmetrical with
2131     :func:`right_hand_split` which also handles optional parentheses.
2132     """
2133     head = Line(depth=line.depth)
2134     body = Line(depth=line.depth + 1, inside_brackets=True)
2135     tail = Line(depth=line.depth)
2136     tail_leaves: List[Leaf] = []
2137     body_leaves: List[Leaf] = []
2138     head_leaves: List[Leaf] = []
2139     current_leaves = head_leaves
2140     matching_bracket = None
2141     for leaf in line.leaves:
2142         if (
2143             current_leaves is body_leaves
2144             and leaf.type in CLOSING_BRACKETS
2145             and leaf.opening_bracket is matching_bracket
2146         ):
2147             current_leaves = tail_leaves if body_leaves else head_leaves
2148         current_leaves.append(leaf)
2149         if current_leaves is head_leaves:
2150             if leaf.type in OPENING_BRACKETS:
2151                 matching_bracket = leaf
2152                 current_leaves = body_leaves
2153     # Since body is a new indent level, remove spurious leading whitespace.
2154     if body_leaves:
2155         normalize_prefix(body_leaves[0], inside_brackets=True)
2156     # Build the new lines.
2157     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2158         for leaf in leaves:
2159             result.append(leaf, preformatted=True)
2160             for comment_after in line.comments_after(leaf):
2161                 result.append(comment_after, preformatted=True)
2162     bracket_split_succeeded_or_raise(head, body, tail)
2163     for result in (head, body, tail):
2164         if result:
2165             yield result
2166
2167
2168 def right_hand_split(
2169     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2170 ) -> Iterator[Line]:
2171     """Split line into many lines, starting with the last matching bracket pair.
2172
2173     If the split was by optional parentheses, attempt splitting without them, too.
2174     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2175     this split.
2176
2177     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2178     """
2179     head = Line(depth=line.depth)
2180     body = Line(depth=line.depth + 1, inside_brackets=True)
2181     tail = Line(depth=line.depth)
2182     tail_leaves: List[Leaf] = []
2183     body_leaves: List[Leaf] = []
2184     head_leaves: List[Leaf] = []
2185     current_leaves = tail_leaves
2186     opening_bracket = None
2187     closing_bracket = None
2188     for leaf in reversed(line.leaves):
2189         if current_leaves is body_leaves:
2190             if leaf is opening_bracket:
2191                 current_leaves = head_leaves if body_leaves else tail_leaves
2192         current_leaves.append(leaf)
2193         if current_leaves is tail_leaves:
2194             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2195                 opening_bracket = leaf.opening_bracket
2196                 closing_bracket = leaf
2197                 current_leaves = body_leaves
2198     tail_leaves.reverse()
2199     body_leaves.reverse()
2200     head_leaves.reverse()
2201     # Since body is a new indent level, remove spurious leading whitespace.
2202     if body_leaves:
2203         normalize_prefix(body_leaves[0], inside_brackets=True)
2204     if not head_leaves:
2205         # No `head` means the split failed. Either `tail` has all content or
2206         # the matching `opening_bracket` wasn't available on `line` anymore.
2207         raise CannotSplit("No brackets found")
2208
2209     # Build the new lines.
2210     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2211         for leaf in leaves:
2212             result.append(leaf, preformatted=True)
2213             for comment_after in line.comments_after(leaf):
2214                 result.append(comment_after, preformatted=True)
2215     bracket_split_succeeded_or_raise(head, body, tail)
2216     assert opening_bracket and closing_bracket
2217     if (
2218         # the opening bracket is an optional paren
2219         opening_bracket.type == token.LPAR
2220         and not opening_bracket.value
2221         # the closing bracket is an optional paren
2222         and closing_bracket.type == token.RPAR
2223         and not closing_bracket.value
2224         # there are no standalone comments in the body
2225         and not line.contains_standalone_comments(0)
2226         # and it's not an import (optional parens are the only thing we can split
2227         # on in this case; attempting a split without them is a waste of time)
2228         and not line.is_import
2229     ):
2230         omit = {id(closing_bracket), *omit}
2231         if can_omit_invisible_parens(body, line_length):
2232             try:
2233                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2234                 return
2235             except CannotSplit:
2236                 pass
2237
2238     ensure_visible(opening_bracket)
2239     ensure_visible(closing_bracket)
2240     body.should_explode = should_explode(body, opening_bracket)
2241     for result in (head, body, tail):
2242         if result:
2243             yield result
2244
2245
2246 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2247     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2248
2249     Do nothing otherwise.
2250
2251     A left- or right-hand split is based on a pair of brackets. Content before
2252     (and including) the opening bracket is left on one line, content inside the
2253     brackets is put on a separate line, and finally content starting with and
2254     following the closing bracket is put on a separate line.
2255
2256     Those are called `head`, `body`, and `tail`, respectively. If the split
2257     produced the same line (all content in `head`) or ended up with an empty `body`
2258     and the `tail` is just the closing bracket, then it's considered failed.
2259     """
2260     tail_len = len(str(tail).strip())
2261     if not body:
2262         if tail_len == 0:
2263             raise CannotSplit("Splitting brackets produced the same line")
2264
2265         elif tail_len < 3:
2266             raise CannotSplit(
2267                 f"Splitting brackets on an empty body to save "
2268                 f"{tail_len} characters is not worth it"
2269             )
2270
2271
2272 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2273     """Normalize prefix of the first leaf in every line returned by `split_func`.
2274
2275     This is a decorator over relevant split functions.
2276     """
2277
2278     @wraps(split_func)
2279     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2280         for l in split_func(line, py36):
2281             normalize_prefix(l.leaves[0], inside_brackets=True)
2282             yield l
2283
2284     return split_wrapper
2285
2286
2287 @dont_increase_indentation
2288 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2289     """Split according to delimiters of the highest priority.
2290
2291     If `py36` is True, the split will add trailing commas also in function
2292     signatures that contain `*` and `**`.
2293     """
2294     try:
2295         last_leaf = line.leaves[-1]
2296     except IndexError:
2297         raise CannotSplit("Line empty")
2298
2299     bt = line.bracket_tracker
2300     try:
2301         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2302     except ValueError:
2303         raise CannotSplit("No delimiters found")
2304
2305     if delimiter_priority == DOT_PRIORITY:
2306         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2307             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2308
2309     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2310     lowest_depth = sys.maxsize
2311     trailing_comma_safe = True
2312
2313     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2314         """Append `leaf` to current line or to new line if appending impossible."""
2315         nonlocal current_line
2316         try:
2317             current_line.append_safe(leaf, preformatted=True)
2318         except ValueError as ve:
2319             yield current_line
2320
2321             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2322             current_line.append(leaf)
2323
2324     for index, leaf in enumerate(line.leaves):
2325         yield from append_to_line(leaf)
2326
2327         for comment_after in line.comments_after(leaf, index):
2328             yield from append_to_line(comment_after)
2329
2330         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2331         if leaf.bracket_depth == lowest_depth and is_vararg(
2332             leaf, within=VARARGS_PARENTS
2333         ):
2334             trailing_comma_safe = trailing_comma_safe and py36
2335         leaf_priority = bt.delimiters.get(id(leaf))
2336         if leaf_priority == delimiter_priority:
2337             yield current_line
2338
2339             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2340     if current_line:
2341         if (
2342             trailing_comma_safe
2343             and delimiter_priority == COMMA_PRIORITY
2344             and current_line.leaves[-1].type != token.COMMA
2345             and current_line.leaves[-1].type != STANDALONE_COMMENT
2346         ):
2347             current_line.append(Leaf(token.COMMA, ","))
2348         yield current_line
2349
2350
2351 @dont_increase_indentation
2352 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2353     """Split standalone comments from the rest of the line."""
2354     if not line.contains_standalone_comments(0):
2355         raise CannotSplit("Line does not have any standalone comments")
2356
2357     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2358
2359     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2360         """Append `leaf` to current line or to new line if appending impossible."""
2361         nonlocal current_line
2362         try:
2363             current_line.append_safe(leaf, preformatted=True)
2364         except ValueError as ve:
2365             yield current_line
2366
2367             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2368             current_line.append(leaf)
2369
2370     for index, leaf in enumerate(line.leaves):
2371         yield from append_to_line(leaf)
2372
2373         for comment_after in line.comments_after(leaf, index):
2374             yield from append_to_line(comment_after)
2375
2376     if current_line:
2377         yield current_line
2378
2379
2380 def is_import(leaf: Leaf) -> bool:
2381     """Return True if the given leaf starts an import statement."""
2382     p = leaf.parent
2383     t = leaf.type
2384     v = leaf.value
2385     return bool(
2386         t == token.NAME
2387         and (
2388             (v == "import" and p and p.type == syms.import_name)
2389             or (v == "from" and p and p.type == syms.import_from)
2390         )
2391     )
2392
2393
2394 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2395     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2396     else.
2397
2398     Note: don't use backslashes for formatting or you'll lose your voting rights.
2399     """
2400     if not inside_brackets:
2401         spl = leaf.prefix.split("#")
2402         if "\\" not in spl[0]:
2403             nl_count = spl[-1].count("\n")
2404             if len(spl) > 1:
2405                 nl_count -= 1
2406             leaf.prefix = "\n" * nl_count
2407             return
2408
2409     leaf.prefix = ""
2410
2411
2412 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2413     """Make all string prefixes lowercase.
2414
2415     If remove_u_prefix is given, also removes any u prefix from the string.
2416
2417     Note: Mutates its argument.
2418     """
2419     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2420     assert match is not None, f"failed to match string {leaf.value!r}"
2421     orig_prefix = match.group(1)
2422     new_prefix = orig_prefix.lower()
2423     if remove_u_prefix:
2424         new_prefix = new_prefix.replace("u", "")
2425     leaf.value = f"{new_prefix}{match.group(2)}"
2426
2427
2428 def normalize_string_quotes(leaf: Leaf) -> None:
2429     """Prefer double quotes but only if it doesn't cause more escaping.
2430
2431     Adds or removes backslashes as appropriate. Doesn't parse and fix
2432     strings nested in f-strings (yet).
2433
2434     Note: Mutates its argument.
2435     """
2436     value = leaf.value.lstrip("furbFURB")
2437     if value[:3] == '"""':
2438         return
2439
2440     elif value[:3] == "'''":
2441         orig_quote = "'''"
2442         new_quote = '"""'
2443     elif value[0] == '"':
2444         orig_quote = '"'
2445         new_quote = "'"
2446     else:
2447         orig_quote = "'"
2448         new_quote = '"'
2449     first_quote_pos = leaf.value.find(orig_quote)
2450     if first_quote_pos == -1:
2451         return  # There's an internal error
2452
2453     prefix = leaf.value[:first_quote_pos]
2454     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2455     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2456     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2457     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2458     if "r" in prefix.casefold():
2459         if unescaped_new_quote.search(body):
2460             # There's at least one unescaped new_quote in this raw string
2461             # so converting is impossible
2462             return
2463
2464         # Do not introduce or remove backslashes in raw strings
2465         new_body = body
2466     else:
2467         # remove unnecessary quotes
2468         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2469         if body != new_body:
2470             # Consider the string without unnecessary quotes as the original
2471             body = new_body
2472             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2473         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2474         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2475     if new_quote == '"""' and new_body[-1] == '"':
2476         # edge case:
2477         new_body = new_body[:-1] + '\\"'
2478     orig_escape_count = body.count("\\")
2479     new_escape_count = new_body.count("\\")
2480     if new_escape_count > orig_escape_count:
2481         return  # Do not introduce more escaping
2482
2483     if new_escape_count == orig_escape_count and orig_quote == '"':
2484         return  # Prefer double quotes
2485
2486     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2487
2488
2489 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2490     """Make existing optional parentheses invisible or create new ones.
2491
2492     `parens_after` is a set of string leaf values immeditely after which parens
2493     should be put.
2494
2495     Standardizes on visible parentheses for single-element tuples, and keeps
2496     existing visible parentheses for other tuples and generator expressions.
2497     """
2498     try:
2499         list(generate_comments(node))
2500     except FormatOff:
2501         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2502
2503     check_lpar = False
2504     for index, child in enumerate(list(node.children)):
2505         if check_lpar:
2506             if child.type == syms.atom:
2507                 maybe_make_parens_invisible_in_atom(child)
2508             elif is_one_tuple(child):
2509                 # wrap child in visible parentheses
2510                 lpar = Leaf(token.LPAR, "(")
2511                 rpar = Leaf(token.RPAR, ")")
2512                 child.remove()
2513                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2514             elif node.type == syms.import_from:
2515                 # "import from" nodes store parentheses directly as part of
2516                 # the statement
2517                 if child.type == token.LPAR:
2518                     # make parentheses invisible
2519                     child.value = ""  # type: ignore
2520                     node.children[-1].value = ""  # type: ignore
2521                 elif child.type != token.STAR:
2522                     # insert invisible parentheses
2523                     node.insert_child(index, Leaf(token.LPAR, ""))
2524                     node.append_child(Leaf(token.RPAR, ""))
2525                 break
2526
2527             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2528                 # wrap child in invisible parentheses
2529                 lpar = Leaf(token.LPAR, "")
2530                 rpar = Leaf(token.RPAR, "")
2531                 index = child.remove() or 0
2532                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2533
2534         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2535
2536
2537 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2538     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2539     if (
2540         node.type != syms.atom
2541         or is_empty_tuple(node)
2542         or is_one_tuple(node)
2543         or is_yield(node)
2544         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2545     ):
2546         return False
2547
2548     first = node.children[0]
2549     last = node.children[-1]
2550     if first.type == token.LPAR and last.type == token.RPAR:
2551         # make parentheses invisible
2552         first.value = ""  # type: ignore
2553         last.value = ""  # type: ignore
2554         if len(node.children) > 1:
2555             maybe_make_parens_invisible_in_atom(node.children[1])
2556         return True
2557
2558     return False
2559
2560
2561 def is_empty_tuple(node: LN) -> bool:
2562     """Return True if `node` holds an empty tuple."""
2563     return (
2564         node.type == syms.atom
2565         and len(node.children) == 2
2566         and node.children[0].type == token.LPAR
2567         and node.children[1].type == token.RPAR
2568     )
2569
2570
2571 def is_one_tuple(node: LN) -> bool:
2572     """Return True if `node` holds a tuple with one element, with or without parens."""
2573     if node.type == syms.atom:
2574         if len(node.children) != 3:
2575             return False
2576
2577         lpar, gexp, rpar = node.children
2578         if not (
2579             lpar.type == token.LPAR
2580             and gexp.type == syms.testlist_gexp
2581             and rpar.type == token.RPAR
2582         ):
2583             return False
2584
2585         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2586
2587     return (
2588         node.type in IMPLICIT_TUPLE
2589         and len(node.children) == 2
2590         and node.children[1].type == token.COMMA
2591     )
2592
2593
2594 def is_yield(node: LN) -> bool:
2595     """Return True if `node` holds a `yield` or `yield from` expression."""
2596     if node.type == syms.yield_expr:
2597         return True
2598
2599     if node.type == token.NAME and node.value == "yield":  # type: ignore
2600         return True
2601
2602     if node.type != syms.atom:
2603         return False
2604
2605     if len(node.children) != 3:
2606         return False
2607
2608     lpar, expr, rpar = node.children
2609     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2610         return is_yield(expr)
2611
2612     return False
2613
2614
2615 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2616     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2617
2618     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2619     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2620     extended iterable unpacking (PEP 3132) and additional unpacking
2621     generalizations (PEP 448).
2622     """
2623     if leaf.type not in STARS or not leaf.parent:
2624         return False
2625
2626     p = leaf.parent
2627     if p.type == syms.star_expr:
2628         # Star expressions are also used as assignment targets in extended
2629         # iterable unpacking (PEP 3132).  See what its parent is instead.
2630         if not p.parent:
2631             return False
2632
2633         p = p.parent
2634
2635     return p.type in within
2636
2637
2638 def is_multiline_string(leaf: Leaf) -> bool:
2639     """Return True if `leaf` is a multiline string that actually spans many lines."""
2640     value = leaf.value.lstrip("furbFURB")
2641     return value[:3] in {'"""', "'''"} and "\n" in value
2642
2643
2644 def is_stub_suite(node: Node) -> bool:
2645     """Return True if `node` is a suite with a stub body."""
2646     if (
2647         len(node.children) != 4
2648         or node.children[0].type != token.NEWLINE
2649         or node.children[1].type != token.INDENT
2650         or node.children[3].type != token.DEDENT
2651     ):
2652         return False
2653
2654     return is_stub_body(node.children[2])
2655
2656
2657 def is_stub_body(node: LN) -> bool:
2658     """Return True if `node` is a simple statement containing an ellipsis."""
2659     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2660         return False
2661
2662     if len(node.children) != 2:
2663         return False
2664
2665     child = node.children[0]
2666     return (
2667         child.type == syms.atom
2668         and len(child.children) == 3
2669         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2670     )
2671
2672
2673 def max_delimiter_priority_in_atom(node: LN) -> int:
2674     """Return maximum delimiter priority inside `node`.
2675
2676     This is specific to atoms with contents contained in a pair of parentheses.
2677     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2678     """
2679     if node.type != syms.atom:
2680         return 0
2681
2682     first = node.children[0]
2683     last = node.children[-1]
2684     if not (first.type == token.LPAR and last.type == token.RPAR):
2685         return 0
2686
2687     bt = BracketTracker()
2688     for c in node.children[1:-1]:
2689         if isinstance(c, Leaf):
2690             bt.mark(c)
2691         else:
2692             for leaf in c.leaves():
2693                 bt.mark(leaf)
2694     try:
2695         return bt.max_delimiter_priority()
2696
2697     except ValueError:
2698         return 0
2699
2700
2701 def ensure_visible(leaf: Leaf) -> None:
2702     """Make sure parentheses are visible.
2703
2704     They could be invisible as part of some statements (see
2705     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2706     """
2707     if leaf.type == token.LPAR:
2708         leaf.value = "("
2709     elif leaf.type == token.RPAR:
2710         leaf.value = ")"
2711
2712
2713 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2714     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2715     if not (
2716         opening_bracket.parent
2717         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2718         and opening_bracket.value in "[{("
2719     ):
2720         return False
2721
2722     try:
2723         last_leaf = line.leaves[-1]
2724         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2725         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2726     except (IndexError, ValueError):
2727         return False
2728
2729     return max_priority == COMMA_PRIORITY
2730
2731
2732 def is_python36(node: Node) -> bool:
2733     """Return True if the current file is using Python 3.6+ features.
2734
2735     Currently looking for:
2736     - f-strings; and
2737     - trailing commas after * or ** in function signatures and calls.
2738     """
2739     for n in node.pre_order():
2740         if n.type == token.STRING:
2741             value_head = n.value[:2]  # type: ignore
2742             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2743                 return True
2744
2745         elif (
2746             n.type in {syms.typedargslist, syms.arglist}
2747             and n.children
2748             and n.children[-1].type == token.COMMA
2749         ):
2750             for ch in n.children:
2751                 if ch.type in STARS:
2752                     return True
2753
2754                 if ch.type == syms.argument:
2755                     for argch in ch.children:
2756                         if argch.type in STARS:
2757                             return True
2758
2759     return False
2760
2761
2762 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2763     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2764
2765     Brackets can be omitted if the entire trailer up to and including
2766     a preceding closing bracket fits in one line.
2767
2768     Yielded sets are cumulative (contain results of previous yields, too).  First
2769     set is empty.
2770     """
2771
2772     omit: Set[LeafID] = set()
2773     yield omit
2774
2775     length = 4 * line.depth
2776     opening_bracket = None
2777     closing_bracket = None
2778     optional_brackets: Set[LeafID] = set()
2779     inner_brackets: Set[LeafID] = set()
2780     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2781         length += leaf_length
2782         if length > line_length:
2783             break
2784
2785         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2786         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2787             break
2788
2789         optional_brackets.discard(id(leaf))
2790         if opening_bracket:
2791             if leaf is opening_bracket:
2792                 opening_bracket = None
2793             elif leaf.type in CLOSING_BRACKETS:
2794                 inner_brackets.add(id(leaf))
2795         elif leaf.type in CLOSING_BRACKETS:
2796             if not leaf.value:
2797                 optional_brackets.add(id(opening_bracket))
2798                 continue
2799
2800             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2801                 # Empty brackets would fail a split so treat them as "inner"
2802                 # brackets (e.g. only add them to the `omit` set if another
2803                 # pair of brackets was good enough.
2804                 inner_brackets.add(id(leaf))
2805                 continue
2806
2807             opening_bracket = leaf.opening_bracket
2808             if closing_bracket:
2809                 omit.add(id(closing_bracket))
2810                 omit.update(inner_brackets)
2811                 inner_brackets.clear()
2812                 yield omit
2813             closing_bracket = leaf
2814
2815
2816 def get_future_imports(node: Node) -> Set[str]:
2817     """Return a set of __future__ imports in the file."""
2818     imports = set()
2819     for child in node.children:
2820         if child.type != syms.simple_stmt:
2821             break
2822         first_child = child.children[0]
2823         if isinstance(first_child, Leaf):
2824             # Continue looking if we see a docstring; otherwise stop.
2825             if (
2826                 len(child.children) == 2
2827                 and first_child.type == token.STRING
2828                 and child.children[1].type == token.NEWLINE
2829             ):
2830                 continue
2831             else:
2832                 break
2833         elif first_child.type == syms.import_from:
2834             module_name = first_child.children[1]
2835             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2836                 break
2837             for import_from_child in first_child.children[3:]:
2838                 if isinstance(import_from_child, Leaf):
2839                     if import_from_child.type == token.NAME:
2840                         imports.add(import_from_child.value)
2841                 else:
2842                     assert import_from_child.type == syms.import_as_names
2843                     for leaf in import_from_child.children:
2844                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2845                             imports.add(leaf.value)
2846         else:
2847             break
2848     return imports
2849
2850
2851 def gen_python_files_in_dir(
2852     path: Path,
2853     root: Path,
2854     include: Pattern[str],
2855     exclude: Pattern[str],
2856     report: "Report",
2857 ) -> Iterator[Path]:
2858     """Generate all files under `path` whose paths are not excluded by the
2859     `exclude` regex, but are included by the `include` regex.
2860
2861     `report` is where output about exclusions goes.
2862     """
2863     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
2864     for child in path.iterdir():
2865         normalized_path = "/" + child.resolve().relative_to(root).as_posix()
2866         if child.is_dir():
2867             normalized_path += "/"
2868         exclude_match = exclude.search(normalized_path)
2869         if exclude_match and exclude_match.group(0):
2870             report.path_ignored(child, f"matches --exclude={exclude.pattern}")
2871             continue
2872
2873         if child.is_dir():
2874             yield from gen_python_files_in_dir(child, root, include, exclude, report)
2875
2876         elif child.is_file():
2877             include_match = include.search(normalized_path)
2878             if include_match:
2879                 yield child
2880
2881
2882 def find_project_root(srcs: List[str]) -> Path:
2883     """Return a directory containing .git, .hg, or pyproject.toml.
2884
2885     That directory can be one of the directories passed in `srcs` or their
2886     common parent.
2887
2888     If no directory in the tree contains a marker that would specify it's the
2889     project root, the root of the file system is returned.
2890     """
2891     if not srcs:
2892         return Path("/").resolve()
2893
2894     common_base = min(Path(src).resolve() for src in srcs)
2895     if common_base.is_dir():
2896         # Append a fake file so `parents` below returns `common_base_dir`, too.
2897         common_base /= "fake-file"
2898     for directory in common_base.parents:
2899         if (directory / ".git").is_dir():
2900             return directory
2901
2902         if (directory / ".hg").is_dir():
2903             return directory
2904
2905         if (directory / "pyproject.toml").is_file():
2906             return directory
2907
2908     return directory
2909
2910
2911 @dataclass
2912 class Report:
2913     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2914
2915     check: bool = False
2916     quiet: bool = False
2917     verbose: bool = False
2918     change_count: int = 0
2919     same_count: int = 0
2920     failure_count: int = 0
2921
2922     def done(self, src: Path, changed: Changed) -> None:
2923         """Increment the counter for successful reformatting. Write out a message."""
2924         if changed is Changed.YES:
2925             reformatted = "would reformat" if self.check else "reformatted"
2926             if self.verbose or not self.quiet:
2927                 out(f"{reformatted} {src}")
2928             self.change_count += 1
2929         else:
2930             if self.verbose:
2931                 if changed is Changed.NO:
2932                     msg = f"{src} already well formatted, good job."
2933                 else:
2934                     msg = f"{src} wasn't modified on disk since last run."
2935                 out(msg, bold=False)
2936             self.same_count += 1
2937
2938     def failed(self, src: Path, message: str) -> None:
2939         """Increment the counter for failed reformatting. Write out a message."""
2940         err(f"error: cannot format {src}: {message}")
2941         self.failure_count += 1
2942
2943     def path_ignored(self, path: Path, message: str) -> None:
2944         if self.verbose:
2945             out(f"{path} ignored: {message}", bold=False)
2946
2947     @property
2948     def return_code(self) -> int:
2949         """Return the exit code that the app should use.
2950
2951         This considers the current state of changed files and failures:
2952         - if there were any failures, return 123;
2953         - if any files were changed and --check is being used, return 1;
2954         - otherwise return 0.
2955         """
2956         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2957         # 126 we have special returncodes reserved by the shell.
2958         if self.failure_count:
2959             return 123
2960
2961         elif self.change_count and self.check:
2962             return 1
2963
2964         return 0
2965
2966     def __str__(self) -> str:
2967         """Render a color report of the current state.
2968
2969         Use `click.unstyle` to remove colors.
2970         """
2971         if self.check:
2972             reformatted = "would be reformatted"
2973             unchanged = "would be left unchanged"
2974             failed = "would fail to reformat"
2975         else:
2976             reformatted = "reformatted"
2977             unchanged = "left unchanged"
2978             failed = "failed to reformat"
2979         report = []
2980         if self.change_count:
2981             s = "s" if self.change_count > 1 else ""
2982             report.append(
2983                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2984             )
2985         if self.same_count:
2986             s = "s" if self.same_count > 1 else ""
2987             report.append(f"{self.same_count} file{s} {unchanged}")
2988         if self.failure_count:
2989             s = "s" if self.failure_count > 1 else ""
2990             report.append(
2991                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2992             )
2993         return ", ".join(report) + "."
2994
2995
2996 def assert_equivalent(src: str, dst: str) -> None:
2997     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2998
2999     import ast
3000     import traceback
3001
3002     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3003         """Simple visitor generating strings to compare ASTs by content."""
3004         yield f"{'  ' * depth}{node.__class__.__name__}("
3005
3006         for field in sorted(node._fields):
3007             try:
3008                 value = getattr(node, field)
3009             except AttributeError:
3010                 continue
3011
3012             yield f"{'  ' * (depth+1)}{field}="
3013
3014             if isinstance(value, list):
3015                 for item in value:
3016                     if isinstance(item, ast.AST):
3017                         yield from _v(item, depth + 2)
3018
3019             elif isinstance(value, ast.AST):
3020                 yield from _v(value, depth + 2)
3021
3022             else:
3023                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3024
3025         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3026
3027     try:
3028         src_ast = ast.parse(src)
3029     except Exception as exc:
3030         major, minor = sys.version_info[:2]
3031         raise AssertionError(
3032             f"cannot use --safe with this file; failed to parse source file "
3033             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3034             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3035         )
3036
3037     try:
3038         dst_ast = ast.parse(dst)
3039     except Exception as exc:
3040         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3041         raise AssertionError(
3042             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3043             f"Please report a bug on https://github.com/ambv/black/issues.  "
3044             f"This invalid output might be helpful: {log}"
3045         ) from None
3046
3047     src_ast_str = "\n".join(_v(src_ast))
3048     dst_ast_str = "\n".join(_v(dst_ast))
3049     if src_ast_str != dst_ast_str:
3050         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3051         raise AssertionError(
3052             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3053             f"the source.  "
3054             f"Please report a bug on https://github.com/ambv/black/issues.  "
3055             f"This diff might be helpful: {log}"
3056         ) from None
3057
3058
3059 def assert_stable(
3060     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3061 ) -> None:
3062     """Raise AssertionError if `dst` reformats differently the second time."""
3063     newdst = format_str(dst, line_length=line_length, mode=mode)
3064     if dst != newdst:
3065         log = dump_to_file(
3066             diff(src, dst, "source", "first pass"),
3067             diff(dst, newdst, "first pass", "second pass"),
3068         )
3069         raise AssertionError(
3070             f"INTERNAL ERROR: Black produced different code on the second pass "
3071             f"of the formatter.  "
3072             f"Please report a bug on https://github.com/ambv/black/issues.  "
3073             f"This diff might be helpful: {log}"
3074         ) from None
3075
3076
3077 def dump_to_file(*output: str) -> str:
3078     """Dump `output` to a temporary file. Return path to the file."""
3079     import tempfile
3080
3081     with tempfile.NamedTemporaryFile(
3082         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3083     ) as f:
3084         for lines in output:
3085             f.write(lines)
3086             if lines and lines[-1] != "\n":
3087                 f.write("\n")
3088     return f.name
3089
3090
3091 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3092     """Return a unified diff string between strings `a` and `b`."""
3093     import difflib
3094
3095     a_lines = [line + "\n" for line in a.split("\n")]
3096     b_lines = [line + "\n" for line in b.split("\n")]
3097     return "".join(
3098         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3099     )
3100
3101
3102 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3103     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3104     err("Aborted!")
3105     for task in tasks:
3106         task.cancel()
3107
3108
3109 def shutdown(loop: BaseEventLoop) -> None:
3110     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3111     try:
3112         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3113         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3114         if not to_cancel:
3115             return
3116
3117         for task in to_cancel:
3118             task.cancel()
3119         loop.run_until_complete(
3120             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3121         )
3122     finally:
3123         # `concurrent.futures.Future` objects cannot be cancelled once they
3124         # are already running. There might be some when the `shutdown()` happened.
3125         # Silence their logger's spew about the event loop being closed.
3126         cf_logger = logging.getLogger("concurrent.futures")
3127         cf_logger.setLevel(logging.CRITICAL)
3128         loop.close()
3129
3130
3131 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3132     """Replace `regex` with `replacement` twice on `original`.
3133
3134     This is used by string normalization to perform replaces on
3135     overlapping matches.
3136     """
3137     return regex.sub(replacement, regex.sub(replacement, original))
3138
3139
3140 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3141     """Like `reversed(enumerate(sequence))` if that were possible."""
3142     index = len(sequence) - 1
3143     for element in reversed(sequence):
3144         yield (index, element)
3145         index -= 1
3146
3147
3148 def enumerate_with_length(
3149     line: Line, reversed: bool = False
3150 ) -> Iterator[Tuple[Index, Leaf, int]]:
3151     """Return an enumeration of leaves with their length.
3152
3153     Stops prematurely on multiline strings and standalone comments.
3154     """
3155     op = cast(
3156         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3157         enumerate_reversed if reversed else enumerate,
3158     )
3159     for index, leaf in op(line.leaves):
3160         length = len(leaf.prefix) + len(leaf.value)
3161         if "\n" in leaf.value:
3162             return  # Multiline strings, we can't continue.
3163
3164         comment: Optional[Leaf]
3165         for comment in line.comments_after(leaf, index):
3166             length += len(comment.value)
3167
3168         yield index, leaf, length
3169
3170
3171 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3172     """Return True if `line` is no longer than `line_length`.
3173
3174     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3175     """
3176     if not line_str:
3177         line_str = str(line).strip("\n")
3178     return (
3179         len(line_str) <= line_length
3180         and "\n" not in line_str  # multiline strings
3181         and not line.contains_standalone_comments()
3182     )
3183
3184
3185 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3186     """Does `line` have a shape safe to reformat without optional parens around it?
3187
3188     Returns True for only a subset of potentially nice looking formattings but
3189     the point is to not return false positives that end up producing lines that
3190     are too long.
3191     """
3192     bt = line.bracket_tracker
3193     if not bt.delimiters:
3194         # Without delimiters the optional parentheses are useless.
3195         return True
3196
3197     max_priority = bt.max_delimiter_priority()
3198     if bt.delimiter_count_with_priority(max_priority) > 1:
3199         # With more than one delimiter of a kind the optional parentheses read better.
3200         return False
3201
3202     if max_priority == DOT_PRIORITY:
3203         # A single stranded method call doesn't require optional parentheses.
3204         return True
3205
3206     assert len(line.leaves) >= 2, "Stranded delimiter"
3207
3208     first = line.leaves[0]
3209     second = line.leaves[1]
3210     penultimate = line.leaves[-2]
3211     last = line.leaves[-1]
3212
3213     # With a single delimiter, omit if the expression starts or ends with
3214     # a bracket.
3215     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3216         remainder = False
3217         length = 4 * line.depth
3218         for _index, leaf, leaf_length in enumerate_with_length(line):
3219             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3220                 remainder = True
3221             if remainder:
3222                 length += leaf_length
3223                 if length > line_length:
3224                     break
3225
3226                 if leaf.type in OPENING_BRACKETS:
3227                     # There are brackets we can further split on.
3228                     remainder = False
3229
3230         else:
3231             # checked the entire string and line length wasn't exceeded
3232             if len(line.leaves) == _index + 1:
3233                 return True
3234
3235         # Note: we are not returning False here because a line might have *both*
3236         # a leading opening bracket and a trailing closing bracket.  If the
3237         # opening bracket doesn't match our rule, maybe the closing will.
3238
3239     if (
3240         last.type == token.RPAR
3241         or last.type == token.RBRACE
3242         or (
3243             # don't use indexing for omitting optional parentheses;
3244             # it looks weird
3245             last.type == token.RSQB
3246             and last.parent
3247             and last.parent.type != syms.trailer
3248         )
3249     ):
3250         if penultimate.type in OPENING_BRACKETS:
3251             # Empty brackets don't help.
3252             return False
3253
3254         if is_multiline_string(first):
3255             # Additional wrapping of a multiline string in this situation is
3256             # unnecessary.
3257             return True
3258
3259         length = 4 * line.depth
3260         seen_other_brackets = False
3261         for _index, leaf, leaf_length in enumerate_with_length(line):
3262             length += leaf_length
3263             if leaf is last.opening_bracket:
3264                 if seen_other_brackets or length <= line_length:
3265                     return True
3266
3267             elif leaf.type in OPENING_BRACKETS:
3268                 # There are brackets we can further split on.
3269                 seen_other_brackets = True
3270
3271     return False
3272
3273
3274 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3275     pyi = bool(mode & FileMode.PYI)
3276     py36 = bool(mode & FileMode.PYTHON36)
3277     return (
3278         CACHE_DIR
3279         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3280     )
3281
3282
3283 def read_cache(line_length: int, mode: FileMode) -> Cache:
3284     """Read the cache if it exists and is well formed.
3285
3286     If it is not well formed, the call to write_cache later should resolve the issue.
3287     """
3288     cache_file = get_cache_file(line_length, mode)
3289     if not cache_file.exists():
3290         return {}
3291
3292     with cache_file.open("rb") as fobj:
3293         try:
3294             cache: Cache = pickle.load(fobj)
3295         except pickle.UnpicklingError:
3296             return {}
3297
3298     return cache
3299
3300
3301 def get_cache_info(path: Path) -> CacheInfo:
3302     """Return the information used to check if a file is already formatted or not."""
3303     stat = path.stat()
3304     return stat.st_mtime, stat.st_size
3305
3306
3307 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3308     """Split an iterable of paths in `sources` into two sets.
3309
3310     The first contains paths of files that modified on disk or are not in the
3311     cache. The other contains paths to non-modified files.
3312     """
3313     todo, done = set(), set()
3314     for src in sources:
3315         src = src.resolve()
3316         if cache.get(src) != get_cache_info(src):
3317             todo.add(src)
3318         else:
3319             done.add(src)
3320     return todo, done
3321
3322
3323 def write_cache(
3324     cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
3325 ) -> None:
3326     """Update the cache file."""
3327     cache_file = get_cache_file(line_length, mode)
3328     try:
3329         if not CACHE_DIR.exists():
3330             CACHE_DIR.mkdir(parents=True)
3331         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3332         with cache_file.open("wb") as fobj:
3333             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3334     except OSError:
3335         pass
3336
3337
3338 if __name__ == "__main__":
3339     main()