]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

36e49b026a53492ba635621353c2e37ffa2b9520
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum, Flag
6 from functools import partial, wraps
7 import io
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import re
14 import tokenize
15 import signal
16 import sys
17 from typing import (
18     Any,
19     Callable,
20     Collection,
21     Dict,
22     Generic,
23     Iterable,
24     Iterator,
25     List,
26     Optional,
27     Pattern,
28     Sequence,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34     cast,
35 )
36
37 from appdirs import user_cache_dir
38 from attr import dataclass, Factory
39 import click
40
41 # lib2to3 fork
42 from blib2to3.pytree import Node, Leaf, type_repr
43 from blib2to3 import pygram, pytree
44 from blib2to3.pgen2 import driver, token
45 from blib2to3.pgen2.parse import ParseError
46
47
48 __version__ = "18.5b1"
49 DEFAULT_LINE_LENGTH = 88
50 DEFAULT_EXCLUDES = (
51     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
52 )
53 DEFAULT_INCLUDES = r"\.pyi?$"
54 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
55
56
57 # types
58 FileContent = str
59 Encoding = str
60 Depth = int
61 NodeType = int
62 LeafID = int
63 Priority = int
64 Index = int
65 LN = Union[Leaf, Node]
66 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
67 Timestamp = float
68 FileSize = int
69 CacheInfo = Tuple[Timestamp, FileSize]
70 Cache = Dict[Path, CacheInfo]
71 out = partial(click.secho, bold=True, err=True)
72 err = partial(click.secho, fg="red", err=True)
73
74 pygram.initialize(CACHE_DIR)
75 syms = pygram.python_symbols
76
77
78 class NothingChanged(UserWarning):
79     """Raised by :func:`format_file` when reformatted code is the same as source."""
80
81
82 class CannotSplit(Exception):
83     """A readable split that fits the allotted line length is impossible.
84
85     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
86     :func:`delimiter_split`.
87     """
88
89
90 class FormatError(Exception):
91     """Base exception for `# fmt: on` and `# fmt: off` handling.
92
93     It holds the number of bytes of the prefix consumed before the format
94     control comment appeared.
95     """
96
97     def __init__(self, consumed: int) -> None:
98         super().__init__(consumed)
99         self.consumed = consumed
100
101     def trim_prefix(self, leaf: Leaf) -> None:
102         leaf.prefix = leaf.prefix[self.consumed :]
103
104     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
105         """Returns a new Leaf from the consumed part of the prefix."""
106         unformatted_prefix = leaf.prefix[: self.consumed]
107         return Leaf(token.NEWLINE, unformatted_prefix)
108
109
110 class FormatOn(FormatError):
111     """Found a comment like `# fmt: on` in the file."""
112
113
114 class FormatOff(FormatError):
115     """Found a comment like `# fmt: off` in the file."""
116
117
118 class WriteBack(Enum):
119     NO = 0
120     YES = 1
121     DIFF = 2
122
123     @classmethod
124     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
125         if check and not diff:
126             return cls.NO
127
128         return cls.DIFF if diff else cls.YES
129
130
131 class Changed(Enum):
132     NO = 0
133     CACHED = 1
134     YES = 2
135
136
137 class FileMode(Flag):
138     AUTO_DETECT = 0
139     PYTHON36 = 1
140     PYI = 2
141     NO_STRING_NORMALIZATION = 4
142
143     @classmethod
144     def from_configuration(
145         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
146     ) -> "FileMode":
147         mode = cls.AUTO_DETECT
148         if py36:
149             mode |= cls.PYTHON36
150         if pyi:
151             mode |= cls.PYI
152         if skip_string_normalization:
153             mode |= cls.NO_STRING_NORMALIZATION
154         return mode
155
156
157 @click.command()
158 @click.option(
159     "-l",
160     "--line-length",
161     type=int,
162     default=DEFAULT_LINE_LENGTH,
163     help="How many character per line to allow.",
164     show_default=True,
165 )
166 @click.option(
167     "--py36",
168     is_flag=True,
169     help=(
170         "Allow using Python 3.6-only syntax on all input files.  This will put "
171         "trailing commas in function signatures and calls also after *args and "
172         "**kwargs.  [default: per-file auto-detection]"
173     ),
174 )
175 @click.option(
176     "--pyi",
177     is_flag=True,
178     help=(
179         "Format all input files like typing stubs regardless of file extension "
180         "(useful when piping source on standard input)."
181     ),
182 )
183 @click.option(
184     "-S",
185     "--skip-string-normalization",
186     is_flag=True,
187     help="Don't normalize string quotes or prefixes.",
188 )
189 @click.option(
190     "--check",
191     is_flag=True,
192     help=(
193         "Don't write the files back, just return the status.  Return code 0 "
194         "means nothing would change.  Return code 1 means some files would be "
195         "reformatted.  Return code 123 means there was an internal error."
196     ),
197 )
198 @click.option(
199     "--diff",
200     is_flag=True,
201     help="Don't write the files back, just output a diff for each file on stdout.",
202 )
203 @click.option(
204     "--fast/--safe",
205     is_flag=True,
206     help="If --fast given, skip temporary sanity checks. [default: --safe]",
207 )
208 @click.option(
209     "--include",
210     type=str,
211     default=DEFAULT_INCLUDES,
212     help=(
213         "A regular expression that matches files and directories that should be "
214         "included on recursive searches.  An empty value means all files are "
215         "included regardless of the name.  Use forward slashes for directories on "
216         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
217         "later."
218     ),
219     show_default=True,
220 )
221 @click.option(
222     "--exclude",
223     type=str,
224     default=DEFAULT_EXCLUDES,
225     help=(
226         "A regular expression that matches files and directories that should be "
227         "excluded on recursive searches.  An empty value means no paths are excluded. "
228         "Use forward slashes for directories on all platforms (Windows, too).  "
229         "Exclusions are calculated first, inclusions later."
230     ),
231     show_default=True,
232 )
233 @click.option(
234     "-q",
235     "--quiet",
236     is_flag=True,
237     help=(
238         "Don't emit non-error messages to stderr. Errors are still emitted, "
239         "silence those with 2>/dev/null."
240     ),
241 )
242 @click.option(
243     "-v",
244     "--verbose",
245     is_flag=True,
246     help=(
247         "Also emit messages to stderr about files that were not changed or were "
248         "ignored due to --exclude=."
249     ),
250 )
251 @click.version_option(version=__version__)
252 @click.argument(
253     "src",
254     nargs=-1,
255     type=click.Path(
256         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
257     ),
258 )
259 @click.pass_context
260 def main(
261     ctx: click.Context,
262     line_length: int,
263     check: bool,
264     diff: bool,
265     fast: bool,
266     pyi: bool,
267     py36: bool,
268     skip_string_normalization: bool,
269     quiet: bool,
270     verbose: bool,
271     include: str,
272     exclude: str,
273     src: List[str],
274 ) -> None:
275     """The uncompromising code formatter."""
276     write_back = WriteBack.from_configuration(check=check, diff=diff)
277     mode = FileMode.from_configuration(
278         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
279     )
280     report = Report(check=check, quiet=quiet, verbose=verbose)
281     sources: List[Path] = []
282     try:
283         include_regex = re.compile(include)
284     except re.error:
285         err(f"Invalid regular expression for include given: {include!r}")
286         ctx.exit(2)
287     try:
288         exclude_regex = re.compile(exclude)
289     except re.error:
290         err(f"Invalid regular expression for exclude given: {exclude!r}")
291         ctx.exit(2)
292     root = find_project_root(src)
293     for s in src:
294         p = Path(s)
295         if p.is_dir():
296             sources.extend(
297                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
298             )
299         elif p.is_file() or s == "-":
300             # if a file was explicitly given, we don't care about its extension
301             sources.append(p)
302         else:
303             err(f"invalid path: {s}")
304     if len(sources) == 0:
305         out("No paths given. Nothing to do 😴")
306         ctx.exit(0)
307         return
308
309     elif len(sources) == 1:
310         reformat_one(
311             src=sources[0],
312             line_length=line_length,
313             fast=fast,
314             write_back=write_back,
315             mode=mode,
316             report=report,
317         )
318     else:
319         loop = asyncio.get_event_loop()
320         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
321         try:
322             loop.run_until_complete(
323                 schedule_formatting(
324                     sources=sources,
325                     line_length=line_length,
326                     fast=fast,
327                     write_back=write_back,
328                     mode=mode,
329                     report=report,
330                     loop=loop,
331                     executor=executor,
332                 )
333             )
334         finally:
335             shutdown(loop)
336         if not quiet:
337             out("All done! ✨ 🍰 ✨")
338             click.echo(str(report))
339     ctx.exit(report.return_code)
340
341
342 def reformat_one(
343     src: Path,
344     line_length: int,
345     fast: bool,
346     write_back: WriteBack,
347     mode: FileMode,
348     report: "Report",
349 ) -> None:
350     """Reformat a single file under `src` without spawning child processes.
351
352     If `quiet` is True, non-error messages are not output. `line_length`,
353     `write_back`, `fast` and `pyi` options are passed to
354     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
355     """
356     try:
357         changed = Changed.NO
358         if not src.is_file() and str(src) == "-":
359             if format_stdin_to_stdout(
360                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
361             ):
362                 changed = Changed.YES
363         else:
364             cache: Cache = {}
365             if write_back != WriteBack.DIFF:
366                 cache = read_cache(line_length, mode)
367                 res_src = src.resolve()
368                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
369                     changed = Changed.CACHED
370             if changed is not Changed.CACHED and format_file_in_place(
371                 src,
372                 line_length=line_length,
373                 fast=fast,
374                 write_back=write_back,
375                 mode=mode,
376             ):
377                 changed = Changed.YES
378             if write_back == WriteBack.YES and changed is not Changed.NO:
379                 write_cache(cache, [src], line_length, mode)
380         report.done(src, changed)
381     except Exception as exc:
382         report.failed(src, str(exc))
383
384
385 async def schedule_formatting(
386     sources: List[Path],
387     line_length: int,
388     fast: bool,
389     write_back: WriteBack,
390     mode: FileMode,
391     report: "Report",
392     loop: BaseEventLoop,
393     executor: Executor,
394 ) -> None:
395     """Run formatting of `sources` in parallel using the provided `executor`.
396
397     (Use ProcessPoolExecutors for actual parallelism.)
398
399     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
400     :func:`format_file_in_place`.
401     """
402     cache: Cache = {}
403     if write_back != WriteBack.DIFF:
404         cache = read_cache(line_length, mode)
405         sources, cached = filter_cached(cache, sources)
406         for src in cached:
407             report.done(src, Changed.CACHED)
408     cancelled = []
409     formatted = []
410     if sources:
411         lock = None
412         if write_back == WriteBack.DIFF:
413             # For diff output, we need locks to ensure we don't interleave output
414             # from different processes.
415             manager = Manager()
416             lock = manager.Lock()
417         tasks = {
418             loop.run_in_executor(
419                 executor,
420                 format_file_in_place,
421                 src,
422                 line_length,
423                 fast,
424                 write_back,
425                 mode,
426                 lock,
427             ): src
428             for src in sorted(sources)
429         }
430         pending: Iterable[asyncio.Task] = tasks.keys()
431         try:
432             loop.add_signal_handler(signal.SIGINT, cancel, pending)
433             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
434         except NotImplementedError:
435             # There are no good alternatives for these on Windows
436             pass
437         while pending:
438             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
439             for task in done:
440                 src = tasks.pop(task)
441                 if task.cancelled():
442                     cancelled.append(task)
443                 elif task.exception():
444                     report.failed(src, str(task.exception()))
445                 else:
446                     formatted.append(src)
447                     report.done(src, Changed.YES if task.result() else Changed.NO)
448     if cancelled:
449         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
450     if write_back == WriteBack.YES and formatted:
451         write_cache(cache, formatted, line_length, mode)
452
453
454 def format_file_in_place(
455     src: Path,
456     line_length: int,
457     fast: bool,
458     write_back: WriteBack = WriteBack.NO,
459     mode: FileMode = FileMode.AUTO_DETECT,
460     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
461 ) -> bool:
462     """Format file under `src` path. Return True if changed.
463
464     If `write_back` is True, write reformatted code back to stdout.
465     `line_length` and `fast` options are passed to :func:`format_file_contents`.
466     """
467     if src.suffix == ".pyi":
468         mode |= FileMode.PYI
469
470     with open(src, "rb") as buf:
471         newline, encoding, src_contents = prepare_input(buf.read())
472     try:
473         dst_contents = format_file_contents(
474             src_contents, line_length=line_length, fast=fast, mode=mode
475         )
476     except NothingChanged:
477         return False
478
479     if write_back == write_back.YES:
480         with open(src, "w", encoding=encoding, newline=newline) as f:
481             f.write(dst_contents)
482     elif write_back == write_back.DIFF:
483         src_name = f"{src}  (original)"
484         dst_name = f"{src}  (formatted)"
485         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
486         if lock:
487             lock.acquire()
488         try:
489             f = io.TextIOWrapper(
490                 sys.stdout.buffer,
491                 encoding=encoding,
492                 newline=newline,
493                 write_through=True,
494             )
495             f.write(diff_contents)
496             f.detach()
497         finally:
498             if lock:
499                 lock.release()
500     return True
501
502
503 def format_stdin_to_stdout(
504     line_length: int,
505     fast: bool,
506     write_back: WriteBack = WriteBack.NO,
507     mode: FileMode = FileMode.AUTO_DETECT,
508 ) -> bool:
509     """Format file on stdin. Return True if changed.
510
511     If `write_back` is True, write reformatted code back to stdout.
512     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
513     :func:`format_file_contents`.
514     """
515     newline, encoding, src = prepare_input(sys.stdin.buffer.read())
516     dst = src
517     try:
518         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
519         return True
520
521     except NothingChanged:
522         return False
523
524     finally:
525         if write_back == WriteBack.YES:
526             f = io.TextIOWrapper(
527                 sys.stdout.buffer,
528                 encoding=encoding,
529                 newline=newline,
530                 write_through=True,
531             )
532             f.write(dst)
533             f.detach()
534         elif write_back == WriteBack.DIFF:
535             src_name = "<stdin>  (original)"
536             dst_name = "<stdin>  (formatted)"
537             f = io.TextIOWrapper(
538                 sys.stdout.buffer,
539                 encoding=encoding,
540                 newline=newline,
541                 write_through=True,
542             )
543             f.write(diff(src, dst, src_name, dst_name))
544             f.detach()
545
546
547 def format_file_contents(
548     src_contents: str,
549     *,
550     line_length: int,
551     fast: bool,
552     mode: FileMode = FileMode.AUTO_DETECT,
553 ) -> FileContent:
554     """Reformat contents a file and return new contents.
555
556     If `fast` is False, additionally confirm that the reformatted code is
557     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
558     `line_length` is passed to :func:`format_str`.
559     """
560     if src_contents.strip() == "":
561         raise NothingChanged
562
563     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
564     if src_contents == dst_contents:
565         raise NothingChanged
566
567     if not fast:
568         assert_equivalent(src_contents, dst_contents)
569         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
570     return dst_contents
571
572
573 def format_str(
574     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
575 ) -> FileContent:
576     """Reformat a string and return new contents.
577
578     `line_length` determines how many characters per line are allowed.
579     """
580     src_node = lib2to3_parse(src_contents)
581     dst_contents = ""
582     future_imports = get_future_imports(src_node)
583     is_pyi = bool(mode & FileMode.PYI)
584     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
585     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
586     lines = LineGenerator(
587         remove_u_prefix=py36 or "unicode_literals" in future_imports,
588         is_pyi=is_pyi,
589         normalize_strings=normalize_strings,
590     )
591     elt = EmptyLineTracker(is_pyi=is_pyi)
592     empty_line = Line()
593     after = 0
594     for current_line in lines.visit(src_node):
595         for _ in range(after):
596             dst_contents += str(empty_line)
597         before, after = elt.maybe_empty_lines(current_line)
598         for _ in range(before):
599             dst_contents += str(empty_line)
600         for line in split_line(current_line, line_length=line_length, py36=py36):
601             dst_contents += str(line)
602     return dst_contents
603
604
605 def prepare_input(src: bytes) -> Tuple[str, str, str]:
606     """Analyze `src` and return a tuple of (newline, encoding, decoded_contents)
607
608     Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
609     universal newlines (i.e. only LF).
610     """
611     srcbuf = io.BytesIO(src)
612     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
613     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
614     srcbuf.seek(0)
615     return newline, encoding, io.TextIOWrapper(srcbuf, encoding).read()
616
617
618 GRAMMARS = [
619     pygram.python_grammar_no_print_statement_no_exec_statement,
620     pygram.python_grammar_no_print_statement,
621     pygram.python_grammar,
622 ]
623
624
625 def lib2to3_parse(src_txt: str) -> Node:
626     """Given a string with source, return the lib2to3 Node."""
627     grammar = pygram.python_grammar_no_print_statement
628     if src_txt[-1] != "\n":
629         src_txt += "\n"
630     for grammar in GRAMMARS:
631         drv = driver.Driver(grammar, pytree.convert)
632         try:
633             result = drv.parse_string(src_txt, True)
634             break
635
636         except ParseError as pe:
637             lineno, column = pe.context[1]
638             lines = src_txt.splitlines()
639             try:
640                 faulty_line = lines[lineno - 1]
641             except IndexError:
642                 faulty_line = "<line number missing in source>"
643             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
644     else:
645         raise exc from None
646
647     if isinstance(result, Leaf):
648         result = Node(syms.file_input, [result])
649     return result
650
651
652 def lib2to3_unparse(node: Node) -> str:
653     """Given a lib2to3 node, return its string representation."""
654     code = str(node)
655     return code
656
657
658 T = TypeVar("T")
659
660
661 class Visitor(Generic[T]):
662     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
663
664     def visit(self, node: LN) -> Iterator[T]:
665         """Main method to visit `node` and its children.
666
667         It tries to find a `visit_*()` method for the given `node.type`, like
668         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
669         If no dedicated `visit_*()` method is found, chooses `visit_default()`
670         instead.
671
672         Then yields objects of type `T` from the selected visitor.
673         """
674         if node.type < 256:
675             name = token.tok_name[node.type]
676         else:
677             name = type_repr(node.type)
678         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
679
680     def visit_default(self, node: LN) -> Iterator[T]:
681         """Default `visit_*()` implementation. Recurses to children of `node`."""
682         if isinstance(node, Node):
683             for child in node.children:
684                 yield from self.visit(child)
685
686
687 @dataclass
688 class DebugVisitor(Visitor[T]):
689     tree_depth: int = 0
690
691     def visit_default(self, node: LN) -> Iterator[T]:
692         indent = " " * (2 * self.tree_depth)
693         if isinstance(node, Node):
694             _type = type_repr(node.type)
695             out(f"{indent}{_type}", fg="yellow")
696             self.tree_depth += 1
697             for child in node.children:
698                 yield from self.visit(child)
699
700             self.tree_depth -= 1
701             out(f"{indent}/{_type}", fg="yellow", bold=False)
702         else:
703             _type = token.tok_name.get(node.type, str(node.type))
704             out(f"{indent}{_type}", fg="blue", nl=False)
705             if node.prefix:
706                 # We don't have to handle prefixes for `Node` objects since
707                 # that delegates to the first child anyway.
708                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
709             out(f" {node.value!r}", fg="blue", bold=False)
710
711     @classmethod
712     def show(cls, code: str) -> None:
713         """Pretty-print the lib2to3 AST of a given string of `code`.
714
715         Convenience method for debugging.
716         """
717         v: DebugVisitor[None] = DebugVisitor()
718         list(v.visit(lib2to3_parse(code)))
719
720
721 KEYWORDS = set(keyword.kwlist)
722 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
723 FLOW_CONTROL = {"return", "raise", "break", "continue"}
724 STATEMENT = {
725     syms.if_stmt,
726     syms.while_stmt,
727     syms.for_stmt,
728     syms.try_stmt,
729     syms.except_clause,
730     syms.with_stmt,
731     syms.funcdef,
732     syms.classdef,
733 }
734 STANDALONE_COMMENT = 153
735 LOGIC_OPERATORS = {"and", "or"}
736 COMPARATORS = {
737     token.LESS,
738     token.GREATER,
739     token.EQEQUAL,
740     token.NOTEQUAL,
741     token.LESSEQUAL,
742     token.GREATEREQUAL,
743 }
744 MATH_OPERATORS = {
745     token.VBAR,
746     token.CIRCUMFLEX,
747     token.AMPER,
748     token.LEFTSHIFT,
749     token.RIGHTSHIFT,
750     token.PLUS,
751     token.MINUS,
752     token.STAR,
753     token.SLASH,
754     token.DOUBLESLASH,
755     token.PERCENT,
756     token.AT,
757     token.TILDE,
758     token.DOUBLESTAR,
759 }
760 STARS = {token.STAR, token.DOUBLESTAR}
761 VARARGS_PARENTS = {
762     syms.arglist,
763     syms.argument,  # double star in arglist
764     syms.trailer,  # single argument to call
765     syms.typedargslist,
766     syms.varargslist,  # lambdas
767 }
768 UNPACKING_PARENTS = {
769     syms.atom,  # single element of a list or set literal
770     syms.dictsetmaker,
771     syms.listmaker,
772     syms.testlist_gexp,
773 }
774 TEST_DESCENDANTS = {
775     syms.test,
776     syms.lambdef,
777     syms.or_test,
778     syms.and_test,
779     syms.not_test,
780     syms.comparison,
781     syms.star_expr,
782     syms.expr,
783     syms.xor_expr,
784     syms.and_expr,
785     syms.shift_expr,
786     syms.arith_expr,
787     syms.trailer,
788     syms.term,
789     syms.power,
790 }
791 ASSIGNMENTS = {
792     "=",
793     "+=",
794     "-=",
795     "*=",
796     "@=",
797     "/=",
798     "%=",
799     "&=",
800     "|=",
801     "^=",
802     "<<=",
803     ">>=",
804     "**=",
805     "//=",
806 }
807 COMPREHENSION_PRIORITY = 20
808 COMMA_PRIORITY = 18
809 TERNARY_PRIORITY = 16
810 LOGIC_PRIORITY = 14
811 STRING_PRIORITY = 12
812 COMPARATOR_PRIORITY = 10
813 MATH_PRIORITIES = {
814     token.VBAR: 9,
815     token.CIRCUMFLEX: 8,
816     token.AMPER: 7,
817     token.LEFTSHIFT: 6,
818     token.RIGHTSHIFT: 6,
819     token.PLUS: 5,
820     token.MINUS: 5,
821     token.STAR: 4,
822     token.SLASH: 4,
823     token.DOUBLESLASH: 4,
824     token.PERCENT: 4,
825     token.AT: 4,
826     token.TILDE: 3,
827     token.DOUBLESTAR: 2,
828 }
829 DOT_PRIORITY = 1
830
831
832 @dataclass
833 class BracketTracker:
834     """Keeps track of brackets on a line."""
835
836     depth: int = 0
837     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
838     delimiters: Dict[LeafID, Priority] = Factory(dict)
839     previous: Optional[Leaf] = None
840     _for_loop_variable: int = 0
841     _lambda_arguments: int = 0
842
843     def mark(self, leaf: Leaf) -> None:
844         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
845
846         All leaves receive an int `bracket_depth` field that stores how deep
847         within brackets a given leaf is. 0 means there are no enclosing brackets
848         that started on this line.
849
850         If a leaf is itself a closing bracket, it receives an `opening_bracket`
851         field that it forms a pair with. This is a one-directional link to
852         avoid reference cycles.
853
854         If a leaf is a delimiter (a token on which Black can split the line if
855         needed) and it's on depth 0, its `id()` is stored in the tracker's
856         `delimiters` field.
857         """
858         if leaf.type == token.COMMENT:
859             return
860
861         self.maybe_decrement_after_for_loop_variable(leaf)
862         self.maybe_decrement_after_lambda_arguments(leaf)
863         if leaf.type in CLOSING_BRACKETS:
864             self.depth -= 1
865             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
866             leaf.opening_bracket = opening_bracket
867         leaf.bracket_depth = self.depth
868         if self.depth == 0:
869             delim = is_split_before_delimiter(leaf, self.previous)
870             if delim and self.previous is not None:
871                 self.delimiters[id(self.previous)] = delim
872             else:
873                 delim = is_split_after_delimiter(leaf, self.previous)
874                 if delim:
875                     self.delimiters[id(leaf)] = delim
876         if leaf.type in OPENING_BRACKETS:
877             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
878             self.depth += 1
879         self.previous = leaf
880         self.maybe_increment_lambda_arguments(leaf)
881         self.maybe_increment_for_loop_variable(leaf)
882
883     def any_open_brackets(self) -> bool:
884         """Return True if there is an yet unmatched open bracket on the line."""
885         return bool(self.bracket_match)
886
887     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
888         """Return the highest priority of a delimiter found on the line.
889
890         Values are consistent with what `is_split_*_delimiter()` return.
891         Raises ValueError on no delimiters.
892         """
893         return max(v for k, v in self.delimiters.items() if k not in exclude)
894
895     def delimiter_count_with_priority(self, priority: int = 0) -> int:
896         """Return the number of delimiters with the given `priority`.
897
898         If no `priority` is passed, defaults to max priority on the line.
899         """
900         if not self.delimiters:
901             return 0
902
903         priority = priority or self.max_delimiter_priority()
904         return sum(1 for p in self.delimiters.values() if p == priority)
905
906     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
907         """In a for loop, or comprehension, the variables are often unpacks.
908
909         To avoid splitting on the comma in this situation, increase the depth of
910         tokens between `for` and `in`.
911         """
912         if leaf.type == token.NAME and leaf.value == "for":
913             self.depth += 1
914             self._for_loop_variable += 1
915             return True
916
917         return False
918
919     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
920         """See `maybe_increment_for_loop_variable` above for explanation."""
921         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
922             self.depth -= 1
923             self._for_loop_variable -= 1
924             return True
925
926         return False
927
928     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
929         """In a lambda expression, there might be more than one argument.
930
931         To avoid splitting on the comma in this situation, increase the depth of
932         tokens between `lambda` and `:`.
933         """
934         if leaf.type == token.NAME and leaf.value == "lambda":
935             self.depth += 1
936             self._lambda_arguments += 1
937             return True
938
939         return False
940
941     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
942         """See `maybe_increment_lambda_arguments` above for explanation."""
943         if self._lambda_arguments and leaf.type == token.COLON:
944             self.depth -= 1
945             self._lambda_arguments -= 1
946             return True
947
948         return False
949
950     def get_open_lsqb(self) -> Optional[Leaf]:
951         """Return the most recent opening square bracket (if any)."""
952         return self.bracket_match.get((self.depth - 1, token.RSQB))
953
954
955 @dataclass
956 class Line:
957     """Holds leaves and comments. Can be printed with `str(line)`."""
958
959     depth: int = 0
960     leaves: List[Leaf] = Factory(list)
961     comments: List[Tuple[Index, Leaf]] = Factory(list)
962     bracket_tracker: BracketTracker = Factory(BracketTracker)
963     inside_brackets: bool = False
964     should_explode: bool = False
965
966     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
967         """Add a new `leaf` to the end of the line.
968
969         Unless `preformatted` is True, the `leaf` will receive a new consistent
970         whitespace prefix and metadata applied by :class:`BracketTracker`.
971         Trailing commas are maybe removed, unpacked for loop variables are
972         demoted from being delimiters.
973
974         Inline comments are put aside.
975         """
976         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
977         if not has_value:
978             return
979
980         if token.COLON == leaf.type and self.is_class_paren_empty:
981             del self.leaves[-2:]
982         if self.leaves and not preformatted:
983             # Note: at this point leaf.prefix should be empty except for
984             # imports, for which we only preserve newlines.
985             leaf.prefix += whitespace(
986                 leaf, complex_subscript=self.is_complex_subscript(leaf)
987             )
988         if self.inside_brackets or not preformatted:
989             self.bracket_tracker.mark(leaf)
990             self.maybe_remove_trailing_comma(leaf)
991         if not self.append_comment(leaf):
992             self.leaves.append(leaf)
993
994     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
995         """Like :func:`append()` but disallow invalid standalone comment structure.
996
997         Raises ValueError when any `leaf` is appended after a standalone comment
998         or when a standalone comment is not the first leaf on the line.
999         """
1000         if self.bracket_tracker.depth == 0:
1001             if self.is_comment:
1002                 raise ValueError("cannot append to standalone comments")
1003
1004             if self.leaves and leaf.type == STANDALONE_COMMENT:
1005                 raise ValueError(
1006                     "cannot append standalone comments to a populated line"
1007                 )
1008
1009         self.append(leaf, preformatted=preformatted)
1010
1011     @property
1012     def is_comment(self) -> bool:
1013         """Is this line a standalone comment?"""
1014         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1015
1016     @property
1017     def is_decorator(self) -> bool:
1018         """Is this line a decorator?"""
1019         return bool(self) and self.leaves[0].type == token.AT
1020
1021     @property
1022     def is_import(self) -> bool:
1023         """Is this an import line?"""
1024         return bool(self) and is_import(self.leaves[0])
1025
1026     @property
1027     def is_class(self) -> bool:
1028         """Is this line a class definition?"""
1029         return (
1030             bool(self)
1031             and self.leaves[0].type == token.NAME
1032             and self.leaves[0].value == "class"
1033         )
1034
1035     @property
1036     def is_stub_class(self) -> bool:
1037         """Is this line a class definition with a body consisting only of "..."?"""
1038         return self.is_class and self.leaves[-3:] == [
1039             Leaf(token.DOT, ".") for _ in range(3)
1040         ]
1041
1042     @property
1043     def is_def(self) -> bool:
1044         """Is this a function definition? (Also returns True for async defs.)"""
1045         try:
1046             first_leaf = self.leaves[0]
1047         except IndexError:
1048             return False
1049
1050         try:
1051             second_leaf: Optional[Leaf] = self.leaves[1]
1052         except IndexError:
1053             second_leaf = None
1054         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1055             first_leaf.type == token.ASYNC
1056             and second_leaf is not None
1057             and second_leaf.type == token.NAME
1058             and second_leaf.value == "def"
1059         )
1060
1061     @property
1062     def is_class_paren_empty(self) -> bool:
1063         """Is this a class with no base classes but using parentheses?
1064
1065         Those are unnecessary and should be removed.
1066         """
1067         return (
1068             bool(self)
1069             and len(self.leaves) == 4
1070             and self.is_class
1071             and self.leaves[2].type == token.LPAR
1072             and self.leaves[2].value == "("
1073             and self.leaves[3].type == token.RPAR
1074             and self.leaves[3].value == ")"
1075         )
1076
1077     @property
1078     def is_triple_quoted_string(self) -> bool:
1079         """Is the line a triple quoted string?"""
1080         return (
1081             bool(self)
1082             and self.leaves[0].type == token.STRING
1083             and self.leaves[0].value.startswith(('"""', "'''"))
1084         )
1085
1086     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1087         """If so, needs to be split before emitting."""
1088         for leaf in self.leaves:
1089             if leaf.type == STANDALONE_COMMENT:
1090                 if leaf.bracket_depth <= depth_limit:
1091                     return True
1092
1093         return False
1094
1095     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1096         """Remove trailing comma if there is one and it's safe."""
1097         if not (
1098             self.leaves
1099             and self.leaves[-1].type == token.COMMA
1100             and closing.type in CLOSING_BRACKETS
1101         ):
1102             return False
1103
1104         if closing.type == token.RBRACE:
1105             self.remove_trailing_comma()
1106             return True
1107
1108         if closing.type == token.RSQB:
1109             comma = self.leaves[-1]
1110             if comma.parent and comma.parent.type == syms.listmaker:
1111                 self.remove_trailing_comma()
1112                 return True
1113
1114         # For parens let's check if it's safe to remove the comma.
1115         # Imports are always safe.
1116         if self.is_import:
1117             self.remove_trailing_comma()
1118             return True
1119
1120         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1121         # change a tuple into a different type by removing the comma.
1122         depth = closing.bracket_depth + 1
1123         commas = 0
1124         opening = closing.opening_bracket
1125         for _opening_index, leaf in enumerate(self.leaves):
1126             if leaf is opening:
1127                 break
1128
1129         else:
1130             return False
1131
1132         for leaf in self.leaves[_opening_index + 1 :]:
1133             if leaf is closing:
1134                 break
1135
1136             bracket_depth = leaf.bracket_depth
1137             if bracket_depth == depth and leaf.type == token.COMMA:
1138                 commas += 1
1139                 if leaf.parent and leaf.parent.type == syms.arglist:
1140                     commas += 1
1141                     break
1142
1143         if commas > 1:
1144             self.remove_trailing_comma()
1145             return True
1146
1147         return False
1148
1149     def append_comment(self, comment: Leaf) -> bool:
1150         """Add an inline or standalone comment to the line."""
1151         if (
1152             comment.type == STANDALONE_COMMENT
1153             and self.bracket_tracker.any_open_brackets()
1154         ):
1155             comment.prefix = ""
1156             return False
1157
1158         if comment.type != token.COMMENT:
1159             return False
1160
1161         after = len(self.leaves) - 1
1162         if after == -1:
1163             comment.type = STANDALONE_COMMENT
1164             comment.prefix = ""
1165             return False
1166
1167         else:
1168             self.comments.append((after, comment))
1169             return True
1170
1171     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1172         """Generate comments that should appear directly after `leaf`.
1173
1174         Provide a non-negative leaf `_index` to speed up the function.
1175         """
1176         if _index == -1:
1177             for _index, _leaf in enumerate(self.leaves):
1178                 if leaf is _leaf:
1179                     break
1180
1181             else:
1182                 return
1183
1184         for index, comment_after in self.comments:
1185             if _index == index:
1186                 yield comment_after
1187
1188     def remove_trailing_comma(self) -> None:
1189         """Remove the trailing comma and moves the comments attached to it."""
1190         comma_index = len(self.leaves) - 1
1191         for i in range(len(self.comments)):
1192             comment_index, comment = self.comments[i]
1193             if comment_index == comma_index:
1194                 self.comments[i] = (comma_index - 1, comment)
1195         self.leaves.pop()
1196
1197     def is_complex_subscript(self, leaf: Leaf) -> bool:
1198         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1199         open_lsqb = (
1200             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1201         )
1202         if open_lsqb is None:
1203             return False
1204
1205         subscript_start = open_lsqb.next_sibling
1206         if (
1207             isinstance(subscript_start, Node)
1208             and subscript_start.type == syms.subscriptlist
1209         ):
1210             subscript_start = child_towards(subscript_start, leaf)
1211         return subscript_start is not None and any(
1212             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1213         )
1214
1215     def __str__(self) -> str:
1216         """Render the line."""
1217         if not self:
1218             return "\n"
1219
1220         indent = "    " * self.depth
1221         leaves = iter(self.leaves)
1222         first = next(leaves)
1223         res = f"{first.prefix}{indent}{first.value}"
1224         for leaf in leaves:
1225             res += str(leaf)
1226         for _, comment in self.comments:
1227             res += str(comment)
1228         return res + "\n"
1229
1230     def __bool__(self) -> bool:
1231         """Return True if the line has leaves or comments."""
1232         return bool(self.leaves or self.comments)
1233
1234
1235 class UnformattedLines(Line):
1236     """Just like :class:`Line` but stores lines which aren't reformatted."""
1237
1238     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1239         """Just add a new `leaf` to the end of the lines.
1240
1241         The `preformatted` argument is ignored.
1242
1243         Keeps track of indentation `depth`, which is useful when the user
1244         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1245         """
1246         try:
1247             list(generate_comments(leaf))
1248         except FormatOn as f_on:
1249             self.leaves.append(f_on.leaf_from_consumed(leaf))
1250             raise
1251
1252         self.leaves.append(leaf)
1253         if leaf.type == token.INDENT:
1254             self.depth += 1
1255         elif leaf.type == token.DEDENT:
1256             self.depth -= 1
1257
1258     def __str__(self) -> str:
1259         """Render unformatted lines from leaves which were added with `append()`.
1260
1261         `depth` is not used for indentation in this case.
1262         """
1263         if not self:
1264             return "\n"
1265
1266         res = ""
1267         for leaf in self.leaves:
1268             res += str(leaf)
1269         return res
1270
1271     def append_comment(self, comment: Leaf) -> bool:
1272         """Not implemented in this class. Raises `NotImplementedError`."""
1273         raise NotImplementedError("Unformatted lines don't store comments separately.")
1274
1275     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1276         """Does nothing and returns False."""
1277         return False
1278
1279     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1280         """Does nothing and returns False."""
1281         return False
1282
1283
1284 @dataclass
1285 class EmptyLineTracker:
1286     """Provides a stateful method that returns the number of potential extra
1287     empty lines needed before and after the currently processed line.
1288
1289     Note: this tracker works on lines that haven't been split yet.  It assumes
1290     the prefix of the first leaf consists of optional newlines.  Those newlines
1291     are consumed by `maybe_empty_lines()` and included in the computation.
1292     """
1293
1294     is_pyi: bool = False
1295     previous_line: Optional[Line] = None
1296     previous_after: int = 0
1297     previous_defs: List[int] = Factory(list)
1298
1299     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1300         """Return the number of extra empty lines before and after the `current_line`.
1301
1302         This is for separating `def`, `async def` and `class` with extra empty
1303         lines (two on module-level).
1304         """
1305         if isinstance(current_line, UnformattedLines):
1306             return 0, 0
1307
1308         before, after = self._maybe_empty_lines(current_line)
1309         before -= self.previous_after
1310         self.previous_after = after
1311         self.previous_line = current_line
1312         return before, after
1313
1314     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1315         max_allowed = 1
1316         if current_line.depth == 0:
1317             max_allowed = 1 if self.is_pyi else 2
1318         if current_line.leaves:
1319             # Consume the first leaf's extra newlines.
1320             first_leaf = current_line.leaves[0]
1321             before = first_leaf.prefix.count("\n")
1322             before = min(before, max_allowed)
1323             first_leaf.prefix = ""
1324         else:
1325             before = 0
1326         depth = current_line.depth
1327         while self.previous_defs and self.previous_defs[-1] >= depth:
1328             self.previous_defs.pop()
1329             if self.is_pyi:
1330                 before = 0 if depth else 1
1331             else:
1332                 before = 1 if depth else 2
1333         is_decorator = current_line.is_decorator
1334         if is_decorator or current_line.is_def or current_line.is_class:
1335             if not is_decorator:
1336                 self.previous_defs.append(depth)
1337             if self.previous_line is None:
1338                 # Don't insert empty lines before the first line in the file.
1339                 return 0, 0
1340
1341             if self.previous_line.is_decorator:
1342                 return 0, 0
1343
1344             if self.previous_line.depth < current_line.depth and (
1345                 self.previous_line.is_class or self.previous_line.is_def
1346             ):
1347                 return 0, 0
1348
1349             if (
1350                 self.previous_line.is_comment
1351                 and self.previous_line.depth == current_line.depth
1352                 and before == 0
1353             ):
1354                 return 0, 0
1355
1356             if self.is_pyi:
1357                 if self.previous_line.depth > current_line.depth:
1358                     newlines = 1
1359                 elif current_line.is_class or self.previous_line.is_class:
1360                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1361                         newlines = 0
1362                     else:
1363                         newlines = 1
1364                 else:
1365                     newlines = 0
1366             else:
1367                 newlines = 2
1368             if current_line.depth and newlines:
1369                 newlines -= 1
1370             return newlines, 0
1371
1372         if (
1373             self.previous_line
1374             and self.previous_line.is_import
1375             and not current_line.is_import
1376             and depth == self.previous_line.depth
1377         ):
1378             return (before or 1), 0
1379
1380         if (
1381             self.previous_line
1382             and self.previous_line.is_class
1383             and current_line.is_triple_quoted_string
1384         ):
1385             return before, 1
1386
1387         return before, 0
1388
1389
1390 @dataclass
1391 class LineGenerator(Visitor[Line]):
1392     """Generates reformatted Line objects.  Empty lines are not emitted.
1393
1394     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1395     in ways that will no longer stringify to valid Python code on the tree.
1396     """
1397
1398     is_pyi: bool = False
1399     normalize_strings: bool = True
1400     current_line: Line = Factory(Line)
1401     remove_u_prefix: bool = False
1402
1403     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1404         """Generate a line.
1405
1406         If the line is empty, only emit if it makes sense.
1407         If the line is too long, split it first and then generate.
1408
1409         If any lines were generated, set up a new current_line.
1410         """
1411         if not self.current_line:
1412             if self.current_line.__class__ == type:
1413                 self.current_line.depth += indent
1414             else:
1415                 self.current_line = type(depth=self.current_line.depth + indent)
1416             return  # Line is empty, don't emit. Creating a new one unnecessary.
1417
1418         complete_line = self.current_line
1419         self.current_line = type(depth=complete_line.depth + indent)
1420         yield complete_line
1421
1422     def visit(self, node: LN) -> Iterator[Line]:
1423         """Main method to visit `node` and its children.
1424
1425         Yields :class:`Line` objects.
1426         """
1427         if isinstance(self.current_line, UnformattedLines):
1428             # File contained `# fmt: off`
1429             yield from self.visit_unformatted(node)
1430
1431         else:
1432             yield from super().visit(node)
1433
1434     def visit_default(self, node: LN) -> Iterator[Line]:
1435         """Default `visit_*()` implementation. Recurses to children of `node`."""
1436         if isinstance(node, Leaf):
1437             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1438             try:
1439                 for comment in generate_comments(node):
1440                     if any_open_brackets:
1441                         # any comment within brackets is subject to splitting
1442                         self.current_line.append(comment)
1443                     elif comment.type == token.COMMENT:
1444                         # regular trailing comment
1445                         self.current_line.append(comment)
1446                         yield from self.line()
1447
1448                     else:
1449                         # regular standalone comment
1450                         yield from self.line()
1451
1452                         self.current_line.append(comment)
1453                         yield from self.line()
1454
1455             except FormatOff as f_off:
1456                 f_off.trim_prefix(node)
1457                 yield from self.line(type=UnformattedLines)
1458                 yield from self.visit(node)
1459
1460             except FormatOn as f_on:
1461                 # This only happens here if somebody says "fmt: on" multiple
1462                 # times in a row.
1463                 f_on.trim_prefix(node)
1464                 yield from self.visit_default(node)
1465
1466             else:
1467                 normalize_prefix(node, inside_brackets=any_open_brackets)
1468                 if self.normalize_strings and node.type == token.STRING:
1469                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1470                     normalize_string_quotes(node)
1471                 if node.type not in WHITESPACE:
1472                     self.current_line.append(node)
1473         yield from super().visit_default(node)
1474
1475     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1476         """Increase indentation level, maybe yield a line."""
1477         # In blib2to3 INDENT never holds comments.
1478         yield from self.line(+1)
1479         yield from self.visit_default(node)
1480
1481     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1482         """Decrease indentation level, maybe yield a line."""
1483         # The current line might still wait for trailing comments.  At DEDENT time
1484         # there won't be any (they would be prefixes on the preceding NEWLINE).
1485         # Emit the line then.
1486         yield from self.line()
1487
1488         # While DEDENT has no value, its prefix may contain standalone comments
1489         # that belong to the current indentation level.  Get 'em.
1490         yield from self.visit_default(node)
1491
1492         # Finally, emit the dedent.
1493         yield from self.line(-1)
1494
1495     def visit_stmt(
1496         self, node: Node, keywords: Set[str], parens: Set[str]
1497     ) -> Iterator[Line]:
1498         """Visit a statement.
1499
1500         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1501         `def`, `with`, `class`, `assert` and assignments.
1502
1503         The relevant Python language `keywords` for a given statement will be
1504         NAME leaves within it. This methods puts those on a separate line.
1505
1506         `parens` holds a set of string leaf values immediately after which
1507         invisible parens should be put.
1508         """
1509         normalize_invisible_parens(node, parens_after=parens)
1510         for child in node.children:
1511             if child.type == token.NAME and child.value in keywords:  # type: ignore
1512                 yield from self.line()
1513
1514             yield from self.visit(child)
1515
1516     def visit_suite(self, node: Node) -> Iterator[Line]:
1517         """Visit a suite."""
1518         if self.is_pyi and is_stub_suite(node):
1519             yield from self.visit(node.children[2])
1520         else:
1521             yield from self.visit_default(node)
1522
1523     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1524         """Visit a statement without nested statements."""
1525         is_suite_like = node.parent and node.parent.type in STATEMENT
1526         if is_suite_like:
1527             if self.is_pyi and is_stub_body(node):
1528                 yield from self.visit_default(node)
1529             else:
1530                 yield from self.line(+1)
1531                 yield from self.visit_default(node)
1532                 yield from self.line(-1)
1533
1534         else:
1535             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1536                 yield from self.line()
1537             yield from self.visit_default(node)
1538
1539     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1540         """Visit `async def`, `async for`, `async with`."""
1541         yield from self.line()
1542
1543         children = iter(node.children)
1544         for child in children:
1545             yield from self.visit(child)
1546
1547             if child.type == token.ASYNC:
1548                 break
1549
1550         internal_stmt = next(children)
1551         for child in internal_stmt.children:
1552             yield from self.visit(child)
1553
1554     def visit_decorators(self, node: Node) -> Iterator[Line]:
1555         """Visit decorators."""
1556         for child in node.children:
1557             yield from self.line()
1558             yield from self.visit(child)
1559
1560     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1561         """Remove a semicolon and put the other statement on a separate line."""
1562         yield from self.line()
1563
1564     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1565         """End of file. Process outstanding comments and end with a newline."""
1566         yield from self.visit_default(leaf)
1567         yield from self.line()
1568
1569     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1570         """Used when file contained a `# fmt: off`."""
1571         if isinstance(node, Node):
1572             for child in node.children:
1573                 yield from self.visit(child)
1574
1575         else:
1576             try:
1577                 self.current_line.append(node)
1578             except FormatOn as f_on:
1579                 f_on.trim_prefix(node)
1580                 yield from self.line()
1581                 yield from self.visit(node)
1582
1583             if node.type == token.ENDMARKER:
1584                 # somebody decided not to put a final `# fmt: on`
1585                 yield from self.line()
1586
1587     def __attrs_post_init__(self) -> None:
1588         """You are in a twisty little maze of passages."""
1589         v = self.visit_stmt
1590         Ø: Set[str] = set()
1591         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1592         self.visit_if_stmt = partial(
1593             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1594         )
1595         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1596         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1597         self.visit_try_stmt = partial(
1598             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1599         )
1600         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1601         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1602         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1603         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1604         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1605         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1606         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1607         self.visit_async_funcdef = self.visit_async_stmt
1608         self.visit_decorated = self.visit_decorators
1609
1610
1611 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1612 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1613 OPENING_BRACKETS = set(BRACKET.keys())
1614 CLOSING_BRACKETS = set(BRACKET.values())
1615 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1616 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1617
1618
1619 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1620     """Return whitespace prefix if needed for the given `leaf`.
1621
1622     `complex_subscript` signals whether the given leaf is part of a subscription
1623     which has non-trivial arguments, like arithmetic expressions or function calls.
1624     """
1625     NO = ""
1626     SPACE = " "
1627     DOUBLESPACE = "  "
1628     t = leaf.type
1629     p = leaf.parent
1630     v = leaf.value
1631     if t in ALWAYS_NO_SPACE:
1632         return NO
1633
1634     if t == token.COMMENT:
1635         return DOUBLESPACE
1636
1637     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1638     if t == token.COLON and p.type not in {
1639         syms.subscript,
1640         syms.subscriptlist,
1641         syms.sliceop,
1642     }:
1643         return NO
1644
1645     prev = leaf.prev_sibling
1646     if not prev:
1647         prevp = preceding_leaf(p)
1648         if not prevp or prevp.type in OPENING_BRACKETS:
1649             return NO
1650
1651         if t == token.COLON:
1652             if prevp.type == token.COLON:
1653                 return NO
1654
1655             elif prevp.type != token.COMMA and not complex_subscript:
1656                 return NO
1657
1658             return SPACE
1659
1660         if prevp.type == token.EQUAL:
1661             if prevp.parent:
1662                 if prevp.parent.type in {
1663                     syms.arglist,
1664                     syms.argument,
1665                     syms.parameters,
1666                     syms.varargslist,
1667                 }:
1668                     return NO
1669
1670                 elif prevp.parent.type == syms.typedargslist:
1671                     # A bit hacky: if the equal sign has whitespace, it means we
1672                     # previously found it's a typed argument.  So, we're using
1673                     # that, too.
1674                     return prevp.prefix
1675
1676         elif prevp.type in STARS:
1677             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1678                 return NO
1679
1680         elif prevp.type == token.COLON:
1681             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1682                 return SPACE if complex_subscript else NO
1683
1684         elif (
1685             prevp.parent
1686             and prevp.parent.type == syms.factor
1687             and prevp.type in MATH_OPERATORS
1688         ):
1689             return NO
1690
1691         elif (
1692             prevp.type == token.RIGHTSHIFT
1693             and prevp.parent
1694             and prevp.parent.type == syms.shift_expr
1695             and prevp.prev_sibling
1696             and prevp.prev_sibling.type == token.NAME
1697             and prevp.prev_sibling.value == "print"  # type: ignore
1698         ):
1699             # Python 2 print chevron
1700             return NO
1701
1702     elif prev.type in OPENING_BRACKETS:
1703         return NO
1704
1705     if p.type in {syms.parameters, syms.arglist}:
1706         # untyped function signatures or calls
1707         if not prev or prev.type != token.COMMA:
1708             return NO
1709
1710     elif p.type == syms.varargslist:
1711         # lambdas
1712         if prev and prev.type != token.COMMA:
1713             return NO
1714
1715     elif p.type == syms.typedargslist:
1716         # typed function signatures
1717         if not prev:
1718             return NO
1719
1720         if t == token.EQUAL:
1721             if prev.type != syms.tname:
1722                 return NO
1723
1724         elif prev.type == token.EQUAL:
1725             # A bit hacky: if the equal sign has whitespace, it means we
1726             # previously found it's a typed argument.  So, we're using that, too.
1727             return prev.prefix
1728
1729         elif prev.type != token.COMMA:
1730             return NO
1731
1732     elif p.type == syms.tname:
1733         # type names
1734         if not prev:
1735             prevp = preceding_leaf(p)
1736             if not prevp or prevp.type != token.COMMA:
1737                 return NO
1738
1739     elif p.type == syms.trailer:
1740         # attributes and calls
1741         if t == token.LPAR or t == token.RPAR:
1742             return NO
1743
1744         if not prev:
1745             if t == token.DOT:
1746                 prevp = preceding_leaf(p)
1747                 if not prevp or prevp.type != token.NUMBER:
1748                     return NO
1749
1750             elif t == token.LSQB:
1751                 return NO
1752
1753         elif prev.type != token.COMMA:
1754             return NO
1755
1756     elif p.type == syms.argument:
1757         # single argument
1758         if t == token.EQUAL:
1759             return NO
1760
1761         if not prev:
1762             prevp = preceding_leaf(p)
1763             if not prevp or prevp.type == token.LPAR:
1764                 return NO
1765
1766         elif prev.type in {token.EQUAL} | STARS:
1767             return NO
1768
1769     elif p.type == syms.decorator:
1770         # decorators
1771         return NO
1772
1773     elif p.type == syms.dotted_name:
1774         if prev:
1775             return NO
1776
1777         prevp = preceding_leaf(p)
1778         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1779             return NO
1780
1781     elif p.type == syms.classdef:
1782         if t == token.LPAR:
1783             return NO
1784
1785         if prev and prev.type == token.LPAR:
1786             return NO
1787
1788     elif p.type in {syms.subscript, syms.sliceop}:
1789         # indexing
1790         if not prev:
1791             assert p.parent is not None, "subscripts are always parented"
1792             if p.parent.type == syms.subscriptlist:
1793                 return SPACE
1794
1795             return NO
1796
1797         elif not complex_subscript:
1798             return NO
1799
1800     elif p.type == syms.atom:
1801         if prev and t == token.DOT:
1802             # dots, but not the first one.
1803             return NO
1804
1805     elif p.type == syms.dictsetmaker:
1806         # dict unpacking
1807         if prev and prev.type == token.DOUBLESTAR:
1808             return NO
1809
1810     elif p.type in {syms.factor, syms.star_expr}:
1811         # unary ops
1812         if not prev:
1813             prevp = preceding_leaf(p)
1814             if not prevp or prevp.type in OPENING_BRACKETS:
1815                 return NO
1816
1817             prevp_parent = prevp.parent
1818             assert prevp_parent is not None
1819             if prevp.type == token.COLON and prevp_parent.type in {
1820                 syms.subscript,
1821                 syms.sliceop,
1822             }:
1823                 return NO
1824
1825             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1826                 return NO
1827
1828         elif t == token.NAME or t == token.NUMBER:
1829             return NO
1830
1831     elif p.type == syms.import_from:
1832         if t == token.DOT:
1833             if prev and prev.type == token.DOT:
1834                 return NO
1835
1836         elif t == token.NAME:
1837             if v == "import":
1838                 return SPACE
1839
1840             if prev and prev.type == token.DOT:
1841                 return NO
1842
1843     elif p.type == syms.sliceop:
1844         return NO
1845
1846     return SPACE
1847
1848
1849 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1850     """Return the first leaf that precedes `node`, if any."""
1851     while node:
1852         res = node.prev_sibling
1853         if res:
1854             if isinstance(res, Leaf):
1855                 return res
1856
1857             try:
1858                 return list(res.leaves())[-1]
1859
1860             except IndexError:
1861                 return None
1862
1863         node = node.parent
1864     return None
1865
1866
1867 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1868     """Return the child of `ancestor` that contains `descendant`."""
1869     node: Optional[LN] = descendant
1870     while node and node.parent != ancestor:
1871         node = node.parent
1872     return node
1873
1874
1875 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1876     """Return the priority of the `leaf` delimiter, given a line break after it.
1877
1878     The delimiter priorities returned here are from those delimiters that would
1879     cause a line break after themselves.
1880
1881     Higher numbers are higher priority.
1882     """
1883     if leaf.type == token.COMMA:
1884         return COMMA_PRIORITY
1885
1886     return 0
1887
1888
1889 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1890     """Return the priority of the `leaf` delimiter, given a line before after it.
1891
1892     The delimiter priorities returned here are from those delimiters that would
1893     cause a line break before themselves.
1894
1895     Higher numbers are higher priority.
1896     """
1897     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1898         # * and ** might also be MATH_OPERATORS but in this case they are not.
1899         # Don't treat them as a delimiter.
1900         return 0
1901
1902     if (
1903         leaf.type == token.DOT
1904         and leaf.parent
1905         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1906         and (previous is None or previous.type in CLOSING_BRACKETS)
1907     ):
1908         return DOT_PRIORITY
1909
1910     if (
1911         leaf.type in MATH_OPERATORS
1912         and leaf.parent
1913         and leaf.parent.type not in {syms.factor, syms.star_expr}
1914     ):
1915         return MATH_PRIORITIES[leaf.type]
1916
1917     if leaf.type in COMPARATORS:
1918         return COMPARATOR_PRIORITY
1919
1920     if (
1921         leaf.type == token.STRING
1922         and previous is not None
1923         and previous.type == token.STRING
1924     ):
1925         return STRING_PRIORITY
1926
1927     if leaf.type != token.NAME:
1928         return 0
1929
1930     if (
1931         leaf.value == "for"
1932         and leaf.parent
1933         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1934     ):
1935         return COMPREHENSION_PRIORITY
1936
1937     if (
1938         leaf.value == "if"
1939         and leaf.parent
1940         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1941     ):
1942         return COMPREHENSION_PRIORITY
1943
1944     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1945         return TERNARY_PRIORITY
1946
1947     if leaf.value == "is":
1948         return COMPARATOR_PRIORITY
1949
1950     if (
1951         leaf.value == "in"
1952         and leaf.parent
1953         and leaf.parent.type in {syms.comp_op, syms.comparison}
1954         and not (
1955             previous is not None
1956             and previous.type == token.NAME
1957             and previous.value == "not"
1958         )
1959     ):
1960         return COMPARATOR_PRIORITY
1961
1962     if (
1963         leaf.value == "not"
1964         and leaf.parent
1965         and leaf.parent.type == syms.comp_op
1966         and not (
1967             previous is not None
1968             and previous.type == token.NAME
1969             and previous.value == "is"
1970         )
1971     ):
1972         return COMPARATOR_PRIORITY
1973
1974     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1975         return LOGIC_PRIORITY
1976
1977     return 0
1978
1979
1980 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1981     """Clean the prefix of the `leaf` and generate comments from it, if any.
1982
1983     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1984     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1985     move because it does away with modifying the grammar to include all the
1986     possible places in which comments can be placed.
1987
1988     The sad consequence for us though is that comments don't "belong" anywhere.
1989     This is why this function generates simple parentless Leaf objects for
1990     comments.  We simply don't know what the correct parent should be.
1991
1992     No matter though, we can live without this.  We really only need to
1993     differentiate between inline and standalone comments.  The latter don't
1994     share the line with any code.
1995
1996     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1997     are emitted with a fake STANDALONE_COMMENT token identifier.
1998     """
1999     p = leaf.prefix
2000     if not p:
2001         return
2002
2003     if "#" not in p:
2004         return
2005
2006     consumed = 0
2007     nlines = 0
2008     for index, line in enumerate(p.split("\n")):
2009         consumed += len(line) + 1  # adding the length of the split '\n'
2010         line = line.lstrip()
2011         if not line:
2012             nlines += 1
2013         if not line.startswith("#"):
2014             continue
2015
2016         if index == 0 and leaf.type != token.ENDMARKER:
2017             comment_type = token.COMMENT  # simple trailing comment
2018         else:
2019             comment_type = STANDALONE_COMMENT
2020         comment = make_comment(line)
2021         yield Leaf(comment_type, comment, prefix="\n" * nlines)
2022
2023         if comment in {"# fmt: on", "# yapf: enable"}:
2024             raise FormatOn(consumed)
2025
2026         if comment in {"# fmt: off", "# yapf: disable"}:
2027             if comment_type == STANDALONE_COMMENT:
2028                 raise FormatOff(consumed)
2029
2030             prev = preceding_leaf(leaf)
2031             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
2032                 raise FormatOff(consumed)
2033
2034         nlines = 0
2035
2036
2037 def make_comment(content: str) -> str:
2038     """Return a consistently formatted comment from the given `content` string.
2039
2040     All comments (except for "##", "#!", "#:") should have a single space between
2041     the hash sign and the content.
2042
2043     If `content` didn't start with a hash sign, one is provided.
2044     """
2045     content = content.rstrip()
2046     if not content:
2047         return "#"
2048
2049     if content[0] == "#":
2050         content = content[1:]
2051     if content and content[0] not in " !:#":
2052         content = " " + content
2053     return "#" + content
2054
2055
2056 def split_line(
2057     line: Line, line_length: int, inner: bool = False, py36: bool = False
2058 ) -> Iterator[Line]:
2059     """Split a `line` into potentially many lines.
2060
2061     They should fit in the allotted `line_length` but might not be able to.
2062     `inner` signifies that there were a pair of brackets somewhere around the
2063     current `line`, possibly transitively. This means we can fallback to splitting
2064     by delimiters if the LHS/RHS don't yield any results.
2065
2066     If `py36` is True, splitting may generate syntax that is only compatible
2067     with Python 3.6 and later.
2068     """
2069     if isinstance(line, UnformattedLines) or line.is_comment:
2070         yield line
2071         return
2072
2073     line_str = str(line).strip("\n")
2074     if not line.should_explode and is_line_short_enough(
2075         line, line_length=line_length, line_str=line_str
2076     ):
2077         yield line
2078         return
2079
2080     split_funcs: List[SplitFunc]
2081     if line.is_def:
2082         split_funcs = [left_hand_split]
2083     else:
2084
2085         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2086             for omit in generate_trailers_to_omit(line, line_length):
2087                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2088                 if is_line_short_enough(lines[0], line_length=line_length):
2089                     yield from lines
2090                     return
2091
2092             # All splits failed, best effort split with no omits.
2093             # This mostly happens to multiline strings that are by definition
2094             # reported as not fitting a single line.
2095             yield from right_hand_split(line, py36)
2096
2097         if line.inside_brackets:
2098             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2099         else:
2100             split_funcs = [rhs]
2101     for split_func in split_funcs:
2102         # We are accumulating lines in `result` because we might want to abort
2103         # mission and return the original line in the end, or attempt a different
2104         # split altogether.
2105         result: List[Line] = []
2106         try:
2107             for l in split_func(line, py36):
2108                 if str(l).strip("\n") == line_str:
2109                     raise CannotSplit("Split function returned an unchanged result")
2110
2111                 result.extend(
2112                     split_line(l, line_length=line_length, inner=True, py36=py36)
2113                 )
2114         except CannotSplit as cs:
2115             continue
2116
2117         else:
2118             yield from result
2119             break
2120
2121     else:
2122         yield line
2123
2124
2125 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2126     """Split line into many lines, starting with the first matching bracket pair.
2127
2128     Note: this usually looks weird, only use this for function definitions.
2129     Prefer RHS otherwise.  This is why this function is not symmetrical with
2130     :func:`right_hand_split` which also handles optional parentheses.
2131     """
2132     head = Line(depth=line.depth)
2133     body = Line(depth=line.depth + 1, inside_brackets=True)
2134     tail = Line(depth=line.depth)
2135     tail_leaves: List[Leaf] = []
2136     body_leaves: List[Leaf] = []
2137     head_leaves: List[Leaf] = []
2138     current_leaves = head_leaves
2139     matching_bracket = None
2140     for leaf in line.leaves:
2141         if (
2142             current_leaves is body_leaves
2143             and leaf.type in CLOSING_BRACKETS
2144             and leaf.opening_bracket is matching_bracket
2145         ):
2146             current_leaves = tail_leaves if body_leaves else head_leaves
2147         current_leaves.append(leaf)
2148         if current_leaves is head_leaves:
2149             if leaf.type in OPENING_BRACKETS:
2150                 matching_bracket = leaf
2151                 current_leaves = body_leaves
2152     # Since body is a new indent level, remove spurious leading whitespace.
2153     if body_leaves:
2154         normalize_prefix(body_leaves[0], inside_brackets=True)
2155     # Build the new lines.
2156     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2157         for leaf in leaves:
2158             result.append(leaf, preformatted=True)
2159             for comment_after in line.comments_after(leaf):
2160                 result.append(comment_after, preformatted=True)
2161     bracket_split_succeeded_or_raise(head, body, tail)
2162     for result in (head, body, tail):
2163         if result:
2164             yield result
2165
2166
2167 def right_hand_split(
2168     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2169 ) -> Iterator[Line]:
2170     """Split line into many lines, starting with the last matching bracket pair.
2171
2172     If the split was by optional parentheses, attempt splitting without them, too.
2173     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2174     this split.
2175
2176     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2177     """
2178     head = Line(depth=line.depth)
2179     body = Line(depth=line.depth + 1, inside_brackets=True)
2180     tail = Line(depth=line.depth)
2181     tail_leaves: List[Leaf] = []
2182     body_leaves: List[Leaf] = []
2183     head_leaves: List[Leaf] = []
2184     current_leaves = tail_leaves
2185     opening_bracket = None
2186     closing_bracket = None
2187     for leaf in reversed(line.leaves):
2188         if current_leaves is body_leaves:
2189             if leaf is opening_bracket:
2190                 current_leaves = head_leaves if body_leaves else tail_leaves
2191         current_leaves.append(leaf)
2192         if current_leaves is tail_leaves:
2193             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2194                 opening_bracket = leaf.opening_bracket
2195                 closing_bracket = leaf
2196                 current_leaves = body_leaves
2197     tail_leaves.reverse()
2198     body_leaves.reverse()
2199     head_leaves.reverse()
2200     # Since body is a new indent level, remove spurious leading whitespace.
2201     if body_leaves:
2202         normalize_prefix(body_leaves[0], inside_brackets=True)
2203     if not head_leaves:
2204         # No `head` means the split failed. Either `tail` has all content or
2205         # the matching `opening_bracket` wasn't available on `line` anymore.
2206         raise CannotSplit("No brackets found")
2207
2208     # Build the new lines.
2209     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2210         for leaf in leaves:
2211             result.append(leaf, preformatted=True)
2212             for comment_after in line.comments_after(leaf):
2213                 result.append(comment_after, preformatted=True)
2214     bracket_split_succeeded_or_raise(head, body, tail)
2215     assert opening_bracket and closing_bracket
2216     if (
2217         # the opening bracket is an optional paren
2218         opening_bracket.type == token.LPAR
2219         and not opening_bracket.value
2220         # the closing bracket is an optional paren
2221         and closing_bracket.type == token.RPAR
2222         and not closing_bracket.value
2223         # there are no standalone comments in the body
2224         and not line.contains_standalone_comments(0)
2225         # and it's not an import (optional parens are the only thing we can split
2226         # on in this case; attempting a split without them is a waste of time)
2227         and not line.is_import
2228     ):
2229         omit = {id(closing_bracket), *omit}
2230         if can_omit_invisible_parens(body, line_length):
2231             try:
2232                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2233                 return
2234             except CannotSplit:
2235                 pass
2236
2237     ensure_visible(opening_bracket)
2238     ensure_visible(closing_bracket)
2239     body.should_explode = should_explode(body, opening_bracket)
2240     for result in (head, body, tail):
2241         if result:
2242             yield result
2243
2244
2245 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2246     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2247
2248     Do nothing otherwise.
2249
2250     A left- or right-hand split is based on a pair of brackets. Content before
2251     (and including) the opening bracket is left on one line, content inside the
2252     brackets is put on a separate line, and finally content starting with and
2253     following the closing bracket is put on a separate line.
2254
2255     Those are called `head`, `body`, and `tail`, respectively. If the split
2256     produced the same line (all content in `head`) or ended up with an empty `body`
2257     and the `tail` is just the closing bracket, then it's considered failed.
2258     """
2259     tail_len = len(str(tail).strip())
2260     if not body:
2261         if tail_len == 0:
2262             raise CannotSplit("Splitting brackets produced the same line")
2263
2264         elif tail_len < 3:
2265             raise CannotSplit(
2266                 f"Splitting brackets on an empty body to save "
2267                 f"{tail_len} characters is not worth it"
2268             )
2269
2270
2271 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2272     """Normalize prefix of the first leaf in every line returned by `split_func`.
2273
2274     This is a decorator over relevant split functions.
2275     """
2276
2277     @wraps(split_func)
2278     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2279         for l in split_func(line, py36):
2280             normalize_prefix(l.leaves[0], inside_brackets=True)
2281             yield l
2282
2283     return split_wrapper
2284
2285
2286 @dont_increase_indentation
2287 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2288     """Split according to delimiters of the highest priority.
2289
2290     If `py36` is True, the split will add trailing commas also in function
2291     signatures that contain `*` and `**`.
2292     """
2293     try:
2294         last_leaf = line.leaves[-1]
2295     except IndexError:
2296         raise CannotSplit("Line empty")
2297
2298     bt = line.bracket_tracker
2299     try:
2300         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2301     except ValueError:
2302         raise CannotSplit("No delimiters found")
2303
2304     if delimiter_priority == DOT_PRIORITY:
2305         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2306             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2307
2308     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2309     lowest_depth = sys.maxsize
2310     trailing_comma_safe = True
2311
2312     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2313         """Append `leaf` to current line or to new line if appending impossible."""
2314         nonlocal current_line
2315         try:
2316             current_line.append_safe(leaf, preformatted=True)
2317         except ValueError as ve:
2318             yield current_line
2319
2320             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2321             current_line.append(leaf)
2322
2323     for index, leaf in enumerate(line.leaves):
2324         yield from append_to_line(leaf)
2325
2326         for comment_after in line.comments_after(leaf, index):
2327             yield from append_to_line(comment_after)
2328
2329         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2330         if leaf.bracket_depth == lowest_depth and is_vararg(
2331             leaf, within=VARARGS_PARENTS
2332         ):
2333             trailing_comma_safe = trailing_comma_safe and py36
2334         leaf_priority = bt.delimiters.get(id(leaf))
2335         if leaf_priority == delimiter_priority:
2336             yield current_line
2337
2338             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2339     if current_line:
2340         if (
2341             trailing_comma_safe
2342             and delimiter_priority == COMMA_PRIORITY
2343             and current_line.leaves[-1].type != token.COMMA
2344             and current_line.leaves[-1].type != STANDALONE_COMMENT
2345         ):
2346             current_line.append(Leaf(token.COMMA, ","))
2347         yield current_line
2348
2349
2350 @dont_increase_indentation
2351 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2352     """Split standalone comments from the rest of the line."""
2353     if not line.contains_standalone_comments(0):
2354         raise CannotSplit("Line does not have any standalone comments")
2355
2356     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2357
2358     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2359         """Append `leaf` to current line or to new line if appending impossible."""
2360         nonlocal current_line
2361         try:
2362             current_line.append_safe(leaf, preformatted=True)
2363         except ValueError as ve:
2364             yield current_line
2365
2366             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2367             current_line.append(leaf)
2368
2369     for index, leaf in enumerate(line.leaves):
2370         yield from append_to_line(leaf)
2371
2372         for comment_after in line.comments_after(leaf, index):
2373             yield from append_to_line(comment_after)
2374
2375     if current_line:
2376         yield current_line
2377
2378
2379 def is_import(leaf: Leaf) -> bool:
2380     """Return True if the given leaf starts an import statement."""
2381     p = leaf.parent
2382     t = leaf.type
2383     v = leaf.value
2384     return bool(
2385         t == token.NAME
2386         and (
2387             (v == "import" and p and p.type == syms.import_name)
2388             or (v == "from" and p and p.type == syms.import_from)
2389         )
2390     )
2391
2392
2393 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2394     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2395     else.
2396
2397     Note: don't use backslashes for formatting or you'll lose your voting rights.
2398     """
2399     if not inside_brackets:
2400         spl = leaf.prefix.split("#")
2401         if "\\" not in spl[0]:
2402             nl_count = spl[-1].count("\n")
2403             if len(spl) > 1:
2404                 nl_count -= 1
2405             leaf.prefix = "\n" * nl_count
2406             return
2407
2408     leaf.prefix = ""
2409
2410
2411 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2412     """Make all string prefixes lowercase.
2413
2414     If remove_u_prefix is given, also removes any u prefix from the string.
2415
2416     Note: Mutates its argument.
2417     """
2418     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2419     assert match is not None, f"failed to match string {leaf.value!r}"
2420     orig_prefix = match.group(1)
2421     new_prefix = orig_prefix.lower()
2422     if remove_u_prefix:
2423         new_prefix = new_prefix.replace("u", "")
2424     leaf.value = f"{new_prefix}{match.group(2)}"
2425
2426
2427 def normalize_string_quotes(leaf: Leaf) -> None:
2428     """Prefer double quotes but only if it doesn't cause more escaping.
2429
2430     Adds or removes backslashes as appropriate. Doesn't parse and fix
2431     strings nested in f-strings (yet).
2432
2433     Note: Mutates its argument.
2434     """
2435     value = leaf.value.lstrip("furbFURB")
2436     if value[:3] == '"""':
2437         return
2438
2439     elif value[:3] == "'''":
2440         orig_quote = "'''"
2441         new_quote = '"""'
2442     elif value[0] == '"':
2443         orig_quote = '"'
2444         new_quote = "'"
2445     else:
2446         orig_quote = "'"
2447         new_quote = '"'
2448     first_quote_pos = leaf.value.find(orig_quote)
2449     if first_quote_pos == -1:
2450         return  # There's an internal error
2451
2452     prefix = leaf.value[:first_quote_pos]
2453     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2454     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2455     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2456     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2457     if "r" in prefix.casefold():
2458         if unescaped_new_quote.search(body):
2459             # There's at least one unescaped new_quote in this raw string
2460             # so converting is impossible
2461             return
2462
2463         # Do not introduce or remove backslashes in raw strings
2464         new_body = body
2465     else:
2466         # remove unnecessary quotes
2467         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2468         if body != new_body:
2469             # Consider the string without unnecessary quotes as the original
2470             body = new_body
2471             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2472         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2473         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2474     if new_quote == '"""' and new_body[-1] == '"':
2475         # edge case:
2476         new_body = new_body[:-1] + '\\"'
2477     orig_escape_count = body.count("\\")
2478     new_escape_count = new_body.count("\\")
2479     if new_escape_count > orig_escape_count:
2480         return  # Do not introduce more escaping
2481
2482     if new_escape_count == orig_escape_count and orig_quote == '"':
2483         return  # Prefer double quotes
2484
2485     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2486
2487
2488 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2489     """Make existing optional parentheses invisible or create new ones.
2490
2491     `parens_after` is a set of string leaf values immeditely after which parens
2492     should be put.
2493
2494     Standardizes on visible parentheses for single-element tuples, and keeps
2495     existing visible parentheses for other tuples and generator expressions.
2496     """
2497     try:
2498         list(generate_comments(node))
2499     except FormatOff:
2500         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2501
2502     check_lpar = False
2503     for index, child in enumerate(list(node.children)):
2504         if check_lpar:
2505             if child.type == syms.atom:
2506                 maybe_make_parens_invisible_in_atom(child)
2507             elif is_one_tuple(child):
2508                 # wrap child in visible parentheses
2509                 lpar = Leaf(token.LPAR, "(")
2510                 rpar = Leaf(token.RPAR, ")")
2511                 child.remove()
2512                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2513             elif node.type == syms.import_from:
2514                 # "import from" nodes store parentheses directly as part of
2515                 # the statement
2516                 if child.type == token.LPAR:
2517                     # make parentheses invisible
2518                     child.value = ""  # type: ignore
2519                     node.children[-1].value = ""  # type: ignore
2520                 elif child.type != token.STAR:
2521                     # insert invisible parentheses
2522                     node.insert_child(index, Leaf(token.LPAR, ""))
2523                     node.append_child(Leaf(token.RPAR, ""))
2524                 break
2525
2526             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2527                 # wrap child in invisible parentheses
2528                 lpar = Leaf(token.LPAR, "")
2529                 rpar = Leaf(token.RPAR, "")
2530                 index = child.remove() or 0
2531                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2532
2533         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2534
2535
2536 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2537     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2538     if (
2539         node.type != syms.atom
2540         or is_empty_tuple(node)
2541         or is_one_tuple(node)
2542         or is_yield(node)
2543         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2544     ):
2545         return False
2546
2547     first = node.children[0]
2548     last = node.children[-1]
2549     if first.type == token.LPAR and last.type == token.RPAR:
2550         # make parentheses invisible
2551         first.value = ""  # type: ignore
2552         last.value = ""  # type: ignore
2553         if len(node.children) > 1:
2554             maybe_make_parens_invisible_in_atom(node.children[1])
2555         return True
2556
2557     return False
2558
2559
2560 def is_empty_tuple(node: LN) -> bool:
2561     """Return True if `node` holds an empty tuple."""
2562     return (
2563         node.type == syms.atom
2564         and len(node.children) == 2
2565         and node.children[0].type == token.LPAR
2566         and node.children[1].type == token.RPAR
2567     )
2568
2569
2570 def is_one_tuple(node: LN) -> bool:
2571     """Return True if `node` holds a tuple with one element, with or without parens."""
2572     if node.type == syms.atom:
2573         if len(node.children) != 3:
2574             return False
2575
2576         lpar, gexp, rpar = node.children
2577         if not (
2578             lpar.type == token.LPAR
2579             and gexp.type == syms.testlist_gexp
2580             and rpar.type == token.RPAR
2581         ):
2582             return False
2583
2584         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2585
2586     return (
2587         node.type in IMPLICIT_TUPLE
2588         and len(node.children) == 2
2589         and node.children[1].type == token.COMMA
2590     )
2591
2592
2593 def is_yield(node: LN) -> bool:
2594     """Return True if `node` holds a `yield` or `yield from` expression."""
2595     if node.type == syms.yield_expr:
2596         return True
2597
2598     if node.type == token.NAME and node.value == "yield":  # type: ignore
2599         return True
2600
2601     if node.type != syms.atom:
2602         return False
2603
2604     if len(node.children) != 3:
2605         return False
2606
2607     lpar, expr, rpar = node.children
2608     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2609         return is_yield(expr)
2610
2611     return False
2612
2613
2614 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2615     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2616
2617     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2618     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2619     extended iterable unpacking (PEP 3132) and additional unpacking
2620     generalizations (PEP 448).
2621     """
2622     if leaf.type not in STARS or not leaf.parent:
2623         return False
2624
2625     p = leaf.parent
2626     if p.type == syms.star_expr:
2627         # Star expressions are also used as assignment targets in extended
2628         # iterable unpacking (PEP 3132).  See what its parent is instead.
2629         if not p.parent:
2630             return False
2631
2632         p = p.parent
2633
2634     return p.type in within
2635
2636
2637 def is_multiline_string(leaf: Leaf) -> bool:
2638     """Return True if `leaf` is a multiline string that actually spans many lines."""
2639     value = leaf.value.lstrip("furbFURB")
2640     return value[:3] in {'"""', "'''"} and "\n" in value
2641
2642
2643 def is_stub_suite(node: Node) -> bool:
2644     """Return True if `node` is a suite with a stub body."""
2645     if (
2646         len(node.children) != 4
2647         or node.children[0].type != token.NEWLINE
2648         or node.children[1].type != token.INDENT
2649         or node.children[3].type != token.DEDENT
2650     ):
2651         return False
2652
2653     return is_stub_body(node.children[2])
2654
2655
2656 def is_stub_body(node: LN) -> bool:
2657     """Return True if `node` is a simple statement containing an ellipsis."""
2658     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2659         return False
2660
2661     if len(node.children) != 2:
2662         return False
2663
2664     child = node.children[0]
2665     return (
2666         child.type == syms.atom
2667         and len(child.children) == 3
2668         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2669     )
2670
2671
2672 def max_delimiter_priority_in_atom(node: LN) -> int:
2673     """Return maximum delimiter priority inside `node`.
2674
2675     This is specific to atoms with contents contained in a pair of parentheses.
2676     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2677     """
2678     if node.type != syms.atom:
2679         return 0
2680
2681     first = node.children[0]
2682     last = node.children[-1]
2683     if not (first.type == token.LPAR and last.type == token.RPAR):
2684         return 0
2685
2686     bt = BracketTracker()
2687     for c in node.children[1:-1]:
2688         if isinstance(c, Leaf):
2689             bt.mark(c)
2690         else:
2691             for leaf in c.leaves():
2692                 bt.mark(leaf)
2693     try:
2694         return bt.max_delimiter_priority()
2695
2696     except ValueError:
2697         return 0
2698
2699
2700 def ensure_visible(leaf: Leaf) -> None:
2701     """Make sure parentheses are visible.
2702
2703     They could be invisible as part of some statements (see
2704     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2705     """
2706     if leaf.type == token.LPAR:
2707         leaf.value = "("
2708     elif leaf.type == token.RPAR:
2709         leaf.value = ")"
2710
2711
2712 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2713     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2714     if not (
2715         opening_bracket.parent
2716         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2717         and opening_bracket.value in "[{("
2718     ):
2719         return False
2720
2721     try:
2722         last_leaf = line.leaves[-1]
2723         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2724         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2725     except (IndexError, ValueError):
2726         return False
2727
2728     return max_priority == COMMA_PRIORITY
2729
2730
2731 def is_python36(node: Node) -> bool:
2732     """Return True if the current file is using Python 3.6+ features.
2733
2734     Currently looking for:
2735     - f-strings; and
2736     - trailing commas after * or ** in function signatures and calls.
2737     """
2738     for n in node.pre_order():
2739         if n.type == token.STRING:
2740             value_head = n.value[:2]  # type: ignore
2741             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2742                 return True
2743
2744         elif (
2745             n.type in {syms.typedargslist, syms.arglist}
2746             and n.children
2747             and n.children[-1].type == token.COMMA
2748         ):
2749             for ch in n.children:
2750                 if ch.type in STARS:
2751                     return True
2752
2753                 if ch.type == syms.argument:
2754                     for argch in ch.children:
2755                         if argch.type in STARS:
2756                             return True
2757
2758     return False
2759
2760
2761 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2762     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2763
2764     Brackets can be omitted if the entire trailer up to and including
2765     a preceding closing bracket fits in one line.
2766
2767     Yielded sets are cumulative (contain results of previous yields, too).  First
2768     set is empty.
2769     """
2770
2771     omit: Set[LeafID] = set()
2772     yield omit
2773
2774     length = 4 * line.depth
2775     opening_bracket = None
2776     closing_bracket = None
2777     optional_brackets: Set[LeafID] = set()
2778     inner_brackets: Set[LeafID] = set()
2779     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2780         length += leaf_length
2781         if length > line_length:
2782             break
2783
2784         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2785         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2786             break
2787
2788         optional_brackets.discard(id(leaf))
2789         if opening_bracket:
2790             if leaf is opening_bracket:
2791                 opening_bracket = None
2792             elif leaf.type in CLOSING_BRACKETS:
2793                 inner_brackets.add(id(leaf))
2794         elif leaf.type in CLOSING_BRACKETS:
2795             if not leaf.value:
2796                 optional_brackets.add(id(opening_bracket))
2797                 continue
2798
2799             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2800                 # Empty brackets would fail a split so treat them as "inner"
2801                 # brackets (e.g. only add them to the `omit` set if another
2802                 # pair of brackets was good enough.
2803                 inner_brackets.add(id(leaf))
2804                 continue
2805
2806             opening_bracket = leaf.opening_bracket
2807             if closing_bracket:
2808                 omit.add(id(closing_bracket))
2809                 omit.update(inner_brackets)
2810                 inner_brackets.clear()
2811                 yield omit
2812             closing_bracket = leaf
2813
2814
2815 def get_future_imports(node: Node) -> Set[str]:
2816     """Return a set of __future__ imports in the file."""
2817     imports = set()
2818     for child in node.children:
2819         if child.type != syms.simple_stmt:
2820             break
2821         first_child = child.children[0]
2822         if isinstance(first_child, Leaf):
2823             # Continue looking if we see a docstring; otherwise stop.
2824             if (
2825                 len(child.children) == 2
2826                 and first_child.type == token.STRING
2827                 and child.children[1].type == token.NEWLINE
2828             ):
2829                 continue
2830             else:
2831                 break
2832         elif first_child.type == syms.import_from:
2833             module_name = first_child.children[1]
2834             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2835                 break
2836             for import_from_child in first_child.children[3:]:
2837                 if isinstance(import_from_child, Leaf):
2838                     if import_from_child.type == token.NAME:
2839                         imports.add(import_from_child.value)
2840                 else:
2841                     assert import_from_child.type == syms.import_as_names
2842                     for leaf in import_from_child.children:
2843                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2844                             imports.add(leaf.value)
2845         else:
2846             break
2847     return imports
2848
2849
2850 def gen_python_files_in_dir(
2851     path: Path,
2852     root: Path,
2853     include: Pattern[str],
2854     exclude: Pattern[str],
2855     report: "Report",
2856 ) -> Iterator[Path]:
2857     """Generate all files under `path` whose paths are not excluded by the
2858     `exclude` regex, but are included by the `include` regex.
2859
2860     `report` is where output about exclusions goes.
2861     """
2862     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
2863     for child in path.iterdir():
2864         normalized_path = "/" + child.resolve().relative_to(root).as_posix()
2865         if child.is_dir():
2866             normalized_path += "/"
2867         exclude_match = exclude.search(normalized_path)
2868         if exclude_match and exclude_match.group(0):
2869             report.path_ignored(child, f"matches --exclude={exclude.pattern}")
2870             continue
2871
2872         if child.is_dir():
2873             yield from gen_python_files_in_dir(child, root, include, exclude, report)
2874
2875         elif child.is_file():
2876             include_match = include.search(normalized_path)
2877             if include_match:
2878                 yield child
2879
2880
2881 def find_project_root(srcs: List[str]) -> Path:
2882     """Return a directory containing .git, .hg, or pyproject.toml.
2883
2884     That directory can be one of the directories passed in `srcs` or their
2885     common parent.
2886
2887     If no directory in the tree contains a marker that would specify it's the
2888     project root, the root of the file system is returned.
2889     """
2890     if not srcs:
2891         return Path("/").resolve()
2892
2893     common_base = min(Path(src).resolve() for src in srcs)
2894     if common_base.is_dir():
2895         # Append a fake file so `parents` below returns `common_base_dir`, too.
2896         common_base /= "fake-file"
2897     for directory in common_base.parents:
2898         if (directory / ".git").is_dir():
2899             return directory
2900
2901         if (directory / ".hg").is_dir():
2902             return directory
2903
2904         if (directory / "pyproject.toml").is_file():
2905             return directory
2906
2907     return directory
2908
2909
2910 @dataclass
2911 class Report:
2912     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2913
2914     check: bool = False
2915     quiet: bool = False
2916     verbose: bool = False
2917     change_count: int = 0
2918     same_count: int = 0
2919     failure_count: int = 0
2920
2921     def done(self, src: Path, changed: Changed) -> None:
2922         """Increment the counter for successful reformatting. Write out a message."""
2923         if changed is Changed.YES:
2924             reformatted = "would reformat" if self.check else "reformatted"
2925             if self.verbose or not self.quiet:
2926                 out(f"{reformatted} {src}")
2927             self.change_count += 1
2928         else:
2929             if self.verbose:
2930                 if changed is Changed.NO:
2931                     msg = f"{src} already well formatted, good job."
2932                 else:
2933                     msg = f"{src} wasn't modified on disk since last run."
2934                 out(msg, bold=False)
2935             self.same_count += 1
2936
2937     def failed(self, src: Path, message: str) -> None:
2938         """Increment the counter for failed reformatting. Write out a message."""
2939         err(f"error: cannot format {src}: {message}")
2940         self.failure_count += 1
2941
2942     def path_ignored(self, path: Path, message: str) -> None:
2943         if self.verbose:
2944             out(f"{path} ignored: {message}", bold=False)
2945
2946     @property
2947     def return_code(self) -> int:
2948         """Return the exit code that the app should use.
2949
2950         This considers the current state of changed files and failures:
2951         - if there were any failures, return 123;
2952         - if any files were changed and --check is being used, return 1;
2953         - otherwise return 0.
2954         """
2955         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2956         # 126 we have special returncodes reserved by the shell.
2957         if self.failure_count:
2958             return 123
2959
2960         elif self.change_count and self.check:
2961             return 1
2962
2963         return 0
2964
2965     def __str__(self) -> str:
2966         """Render a color report of the current state.
2967
2968         Use `click.unstyle` to remove colors.
2969         """
2970         if self.check:
2971             reformatted = "would be reformatted"
2972             unchanged = "would be left unchanged"
2973             failed = "would fail to reformat"
2974         else:
2975             reformatted = "reformatted"
2976             unchanged = "left unchanged"
2977             failed = "failed to reformat"
2978         report = []
2979         if self.change_count:
2980             s = "s" if self.change_count > 1 else ""
2981             report.append(
2982                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2983             )
2984         if self.same_count:
2985             s = "s" if self.same_count > 1 else ""
2986             report.append(f"{self.same_count} file{s} {unchanged}")
2987         if self.failure_count:
2988             s = "s" if self.failure_count > 1 else ""
2989             report.append(
2990                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2991             )
2992         return ", ".join(report) + "."
2993
2994
2995 def assert_equivalent(src: str, dst: str) -> None:
2996     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2997
2998     import ast
2999     import traceback
3000
3001     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3002         """Simple visitor generating strings to compare ASTs by content."""
3003         yield f"{'  ' * depth}{node.__class__.__name__}("
3004
3005         for field in sorted(node._fields):
3006             try:
3007                 value = getattr(node, field)
3008             except AttributeError:
3009                 continue
3010
3011             yield f"{'  ' * (depth+1)}{field}="
3012
3013             if isinstance(value, list):
3014                 for item in value:
3015                     if isinstance(item, ast.AST):
3016                         yield from _v(item, depth + 2)
3017
3018             elif isinstance(value, ast.AST):
3019                 yield from _v(value, depth + 2)
3020
3021             else:
3022                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3023
3024         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3025
3026     try:
3027         src_ast = ast.parse(src)
3028     except Exception as exc:
3029         major, minor = sys.version_info[:2]
3030         raise AssertionError(
3031             f"cannot use --safe with this file; failed to parse source file "
3032             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3033             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3034         )
3035
3036     try:
3037         dst_ast = ast.parse(dst)
3038     except Exception as exc:
3039         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3040         raise AssertionError(
3041             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3042             f"Please report a bug on https://github.com/ambv/black/issues.  "
3043             f"This invalid output might be helpful: {log}"
3044         ) from None
3045
3046     src_ast_str = "\n".join(_v(src_ast))
3047     dst_ast_str = "\n".join(_v(dst_ast))
3048     if src_ast_str != dst_ast_str:
3049         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3050         raise AssertionError(
3051             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3052             f"the source.  "
3053             f"Please report a bug on https://github.com/ambv/black/issues.  "
3054             f"This diff might be helpful: {log}"
3055         ) from None
3056
3057
3058 def assert_stable(
3059     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3060 ) -> None:
3061     """Raise AssertionError if `dst` reformats differently the second time."""
3062     newdst = format_str(dst, line_length=line_length, mode=mode)
3063     if dst != newdst:
3064         log = dump_to_file(
3065             diff(src, dst, "source", "first pass"),
3066             diff(dst, newdst, "first pass", "second pass"),
3067         )
3068         raise AssertionError(
3069             f"INTERNAL ERROR: Black produced different code on the second pass "
3070             f"of the formatter.  "
3071             f"Please report a bug on https://github.com/ambv/black/issues.  "
3072             f"This diff might be helpful: {log}"
3073         ) from None
3074
3075
3076 def dump_to_file(*output: str) -> str:
3077     """Dump `output` to a temporary file. Return path to the file."""
3078     import tempfile
3079
3080     with tempfile.NamedTemporaryFile(
3081         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3082     ) as f:
3083         for lines in output:
3084             f.write(lines)
3085             if lines and lines[-1] != "\n":
3086                 f.write("\n")
3087     return f.name
3088
3089
3090 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3091     """Return a unified diff string between strings `a` and `b`."""
3092     import difflib
3093
3094     a_lines = [line + "\n" for line in a.split("\n")]
3095     b_lines = [line + "\n" for line in b.split("\n")]
3096     return "".join(
3097         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3098     )
3099
3100
3101 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3102     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3103     err("Aborted!")
3104     for task in tasks:
3105         task.cancel()
3106
3107
3108 def shutdown(loop: BaseEventLoop) -> None:
3109     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3110     try:
3111         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3112         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3113         if not to_cancel:
3114             return
3115
3116         for task in to_cancel:
3117             task.cancel()
3118         loop.run_until_complete(
3119             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3120         )
3121     finally:
3122         # `concurrent.futures.Future` objects cannot be cancelled once they
3123         # are already running. There might be some when the `shutdown()` happened.
3124         # Silence their logger's spew about the event loop being closed.
3125         cf_logger = logging.getLogger("concurrent.futures")
3126         cf_logger.setLevel(logging.CRITICAL)
3127         loop.close()
3128
3129
3130 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3131     """Replace `regex` with `replacement` twice on `original`.
3132
3133     This is used by string normalization to perform replaces on
3134     overlapping matches.
3135     """
3136     return regex.sub(replacement, regex.sub(replacement, original))
3137
3138
3139 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3140     """Like `reversed(enumerate(sequence))` if that were possible."""
3141     index = len(sequence) - 1
3142     for element in reversed(sequence):
3143         yield (index, element)
3144         index -= 1
3145
3146
3147 def enumerate_with_length(
3148     line: Line, reversed: bool = False
3149 ) -> Iterator[Tuple[Index, Leaf, int]]:
3150     """Return an enumeration of leaves with their length.
3151
3152     Stops prematurely on multiline strings and standalone comments.
3153     """
3154     op = cast(
3155         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3156         enumerate_reversed if reversed else enumerate,
3157     )
3158     for index, leaf in op(line.leaves):
3159         length = len(leaf.prefix) + len(leaf.value)
3160         if "\n" in leaf.value:
3161             return  # Multiline strings, we can't continue.
3162
3163         comment: Optional[Leaf]
3164         for comment in line.comments_after(leaf, index):
3165             length += len(comment.value)
3166
3167         yield index, leaf, length
3168
3169
3170 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3171     """Return True if `line` is no longer than `line_length`.
3172
3173     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3174     """
3175     if not line_str:
3176         line_str = str(line).strip("\n")
3177     return (
3178         len(line_str) <= line_length
3179         and "\n" not in line_str  # multiline strings
3180         and not line.contains_standalone_comments()
3181     )
3182
3183
3184 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3185     """Does `line` have a shape safe to reformat without optional parens around it?
3186
3187     Returns True for only a subset of potentially nice looking formattings but
3188     the point is to not return false positives that end up producing lines that
3189     are too long.
3190     """
3191     bt = line.bracket_tracker
3192     if not bt.delimiters:
3193         # Without delimiters the optional parentheses are useless.
3194         return True
3195
3196     max_priority = bt.max_delimiter_priority()
3197     if bt.delimiter_count_with_priority(max_priority) > 1:
3198         # With more than one delimiter of a kind the optional parentheses read better.
3199         return False
3200
3201     if max_priority == DOT_PRIORITY:
3202         # A single stranded method call doesn't require optional parentheses.
3203         return True
3204
3205     assert len(line.leaves) >= 2, "Stranded delimiter"
3206
3207     first = line.leaves[0]
3208     second = line.leaves[1]
3209     penultimate = line.leaves[-2]
3210     last = line.leaves[-1]
3211
3212     # With a single delimiter, omit if the expression starts or ends with
3213     # a bracket.
3214     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3215         remainder = False
3216         length = 4 * line.depth
3217         for _index, leaf, leaf_length in enumerate_with_length(line):
3218             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3219                 remainder = True
3220             if remainder:
3221                 length += leaf_length
3222                 if length > line_length:
3223                     break
3224
3225                 if leaf.type in OPENING_BRACKETS:
3226                     # There are brackets we can further split on.
3227                     remainder = False
3228
3229         else:
3230             # checked the entire string and line length wasn't exceeded
3231             if len(line.leaves) == _index + 1:
3232                 return True
3233
3234         # Note: we are not returning False here because a line might have *both*
3235         # a leading opening bracket and a trailing closing bracket.  If the
3236         # opening bracket doesn't match our rule, maybe the closing will.
3237
3238     if (
3239         last.type == token.RPAR
3240         or last.type == token.RBRACE
3241         or (
3242             # don't use indexing for omitting optional parentheses;
3243             # it looks weird
3244             last.type == token.RSQB
3245             and last.parent
3246             and last.parent.type != syms.trailer
3247         )
3248     ):
3249         if penultimate.type in OPENING_BRACKETS:
3250             # Empty brackets don't help.
3251             return False
3252
3253         if is_multiline_string(first):
3254             # Additional wrapping of a multiline string in this situation is
3255             # unnecessary.
3256             return True
3257
3258         length = 4 * line.depth
3259         seen_other_brackets = False
3260         for _index, leaf, leaf_length in enumerate_with_length(line):
3261             length += leaf_length
3262             if leaf is last.opening_bracket:
3263                 if seen_other_brackets or length <= line_length:
3264                     return True
3265
3266             elif leaf.type in OPENING_BRACKETS:
3267                 # There are brackets we can further split on.
3268                 seen_other_brackets = True
3269
3270     return False
3271
3272
3273 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3274     pyi = bool(mode & FileMode.PYI)
3275     py36 = bool(mode & FileMode.PYTHON36)
3276     return (
3277         CACHE_DIR
3278         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3279     )
3280
3281
3282 def read_cache(line_length: int, mode: FileMode) -> Cache:
3283     """Read the cache if it exists and is well formed.
3284
3285     If it is not well formed, the call to write_cache later should resolve the issue.
3286     """
3287     cache_file = get_cache_file(line_length, mode)
3288     if not cache_file.exists():
3289         return {}
3290
3291     with cache_file.open("rb") as fobj:
3292         try:
3293             cache: Cache = pickle.load(fobj)
3294         except pickle.UnpicklingError:
3295             return {}
3296
3297     return cache
3298
3299
3300 def get_cache_info(path: Path) -> CacheInfo:
3301     """Return the information used to check if a file is already formatted or not."""
3302     stat = path.stat()
3303     return stat.st_mtime, stat.st_size
3304
3305
3306 def filter_cached(
3307     cache: Cache, sources: Iterable[Path]
3308 ) -> Tuple[List[Path], List[Path]]:
3309     """Split a list of paths into two.
3310
3311     The first list contains paths of files that modified on disk or are not in the
3312     cache. The other list contains paths to non-modified files.
3313     """
3314     todo, done = [], []
3315     for src in sources:
3316         src = src.resolve()
3317         if cache.get(src) != get_cache_info(src):
3318             todo.append(src)
3319         else:
3320             done.append(src)
3321     return todo, done
3322
3323
3324 def write_cache(
3325     cache: Cache, sources: List[Path], line_length: int, mode: FileMode
3326 ) -> None:
3327     """Update the cache file."""
3328     cache_file = get_cache_file(line_length, mode)
3329     try:
3330         if not CACHE_DIR.exists():
3331             CACHE_DIR.mkdir(parents=True)
3332         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3333         with cache_file.open("wb") as fobj:
3334             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3335     except OSError:
3336         pass
3337
3338
3339 if __name__ == "__main__":
3340     main()