]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Consider stars in testlist_star_expr unpacking (because they are)
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum, Flag
6 from functools import partial, wraps
7 import io
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tokenize
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Sequence,
30     Set,
31     Tuple,
32     Type,
33     TypeVar,
34     Union,
35     cast,
36 )
37
38 from appdirs import user_cache_dir
39 from attr import dataclass, Factory
40 import click
41
42 # lib2to3 fork
43 from blib2to3.pytree import Node, Leaf, type_repr
44 from blib2to3 import pygram, pytree
45 from blib2to3.pgen2 import driver, token
46 from blib2to3.pgen2.parse import ParseError
47
48
49 __version__ = "18.5b1"
50 DEFAULT_LINE_LENGTH = 88
51 DEFAULT_EXCLUDES = (
52     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
53 )
54 DEFAULT_INCLUDES = r"\.pyi?$"
55 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
56
57
58 # types
59 FileContent = str
60 Encoding = str
61 NewLine = str
62 Depth = int
63 NodeType = int
64 LeafID = int
65 Priority = int
66 Index = int
67 LN = Union[Leaf, Node]
68 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
69 Timestamp = float
70 FileSize = int
71 CacheInfo = Tuple[Timestamp, FileSize]
72 Cache = Dict[Path, CacheInfo]
73 out = partial(click.secho, bold=True, err=True)
74 err = partial(click.secho, fg="red", err=True)
75
76 pygram.initialize(CACHE_DIR)
77 syms = pygram.python_symbols
78
79
80 class NothingChanged(UserWarning):
81     """Raised by :func:`format_file` when reformatted code is the same as source."""
82
83
84 class CannotSplit(Exception):
85     """A readable split that fits the allotted line length is impossible.
86
87     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
88     :func:`delimiter_split`.
89     """
90
91
92 class FormatError(Exception):
93     """Base exception for `# fmt: on` and `# fmt: off` handling.
94
95     It holds the number of bytes of the prefix consumed before the format
96     control comment appeared.
97     """
98
99     def __init__(self, consumed: int) -> None:
100         super().__init__(consumed)
101         self.consumed = consumed
102
103     def trim_prefix(self, leaf: Leaf) -> None:
104         leaf.prefix = leaf.prefix[self.consumed :]
105
106     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
107         """Returns a new Leaf from the consumed part of the prefix."""
108         unformatted_prefix = leaf.prefix[: self.consumed]
109         return Leaf(token.NEWLINE, unformatted_prefix)
110
111
112 class FormatOn(FormatError):
113     """Found a comment like `# fmt: on` in the file."""
114
115
116 class FormatOff(FormatError):
117     """Found a comment like `# fmt: off` in the file."""
118
119
120 class WriteBack(Enum):
121     NO = 0
122     YES = 1
123     DIFF = 2
124
125     @classmethod
126     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
127         if check and not diff:
128             return cls.NO
129
130         return cls.DIFF if diff else cls.YES
131
132
133 class Changed(Enum):
134     NO = 0
135     CACHED = 1
136     YES = 2
137
138
139 class FileMode(Flag):
140     AUTO_DETECT = 0
141     PYTHON36 = 1
142     PYI = 2
143     NO_STRING_NORMALIZATION = 4
144
145     @classmethod
146     def from_configuration(
147         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
148     ) -> "FileMode":
149         mode = cls.AUTO_DETECT
150         if py36:
151             mode |= cls.PYTHON36
152         if pyi:
153             mode |= cls.PYI
154         if skip_string_normalization:
155             mode |= cls.NO_STRING_NORMALIZATION
156         return mode
157
158
159 @click.command()
160 @click.option(
161     "-l",
162     "--line-length",
163     type=int,
164     default=DEFAULT_LINE_LENGTH,
165     help="How many character per line to allow.",
166     show_default=True,
167 )
168 @click.option(
169     "--py36",
170     is_flag=True,
171     help=(
172         "Allow using Python 3.6-only syntax on all input files.  This will put "
173         "trailing commas in function signatures and calls also after *args and "
174         "**kwargs.  [default: per-file auto-detection]"
175     ),
176 )
177 @click.option(
178     "--pyi",
179     is_flag=True,
180     help=(
181         "Format all input files like typing stubs regardless of file extension "
182         "(useful when piping source on standard input)."
183     ),
184 )
185 @click.option(
186     "-S",
187     "--skip-string-normalization",
188     is_flag=True,
189     help="Don't normalize string quotes or prefixes.",
190 )
191 @click.option(
192     "--check",
193     is_flag=True,
194     help=(
195         "Don't write the files back, just return the status.  Return code 0 "
196         "means nothing would change.  Return code 1 means some files would be "
197         "reformatted.  Return code 123 means there was an internal error."
198     ),
199 )
200 @click.option(
201     "--diff",
202     is_flag=True,
203     help="Don't write the files back, just output a diff for each file on stdout.",
204 )
205 @click.option(
206     "--fast/--safe",
207     is_flag=True,
208     help="If --fast given, skip temporary sanity checks. [default: --safe]",
209 )
210 @click.option(
211     "--include",
212     type=str,
213     default=DEFAULT_INCLUDES,
214     help=(
215         "A regular expression that matches files and directories that should be "
216         "included on recursive searches.  An empty value means all files are "
217         "included regardless of the name.  Use forward slashes for directories on "
218         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
219         "later."
220     ),
221     show_default=True,
222 )
223 @click.option(
224     "--exclude",
225     type=str,
226     default=DEFAULT_EXCLUDES,
227     help=(
228         "A regular expression that matches files and directories that should be "
229         "excluded on recursive searches.  An empty value means no paths are excluded. "
230         "Use forward slashes for directories on all platforms (Windows, too).  "
231         "Exclusions are calculated first, inclusions later."
232     ),
233     show_default=True,
234 )
235 @click.option(
236     "-q",
237     "--quiet",
238     is_flag=True,
239     help=(
240         "Don't emit non-error messages to stderr. Errors are still emitted, "
241         "silence those with 2>/dev/null."
242     ),
243 )
244 @click.option(
245     "-v",
246     "--verbose",
247     is_flag=True,
248     help=(
249         "Also emit messages to stderr about files that were not changed or were "
250         "ignored due to --exclude=."
251     ),
252 )
253 @click.version_option(version=__version__)
254 @click.argument(
255     "src",
256     nargs=-1,
257     type=click.Path(
258         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
259     ),
260 )
261 @click.pass_context
262 def main(
263     ctx: click.Context,
264     line_length: int,
265     check: bool,
266     diff: bool,
267     fast: bool,
268     pyi: bool,
269     py36: bool,
270     skip_string_normalization: bool,
271     quiet: bool,
272     verbose: bool,
273     include: str,
274     exclude: str,
275     src: List[str],
276 ) -> None:
277     """The uncompromising code formatter."""
278     write_back = WriteBack.from_configuration(check=check, diff=diff)
279     mode = FileMode.from_configuration(
280         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
281     )
282     report = Report(check=check, quiet=quiet, verbose=verbose)
283     sources: Set[Path] = set()
284     try:
285         include_regex = re.compile(include)
286     except re.error:
287         err(f"Invalid regular expression for include given: {include!r}")
288         ctx.exit(2)
289     try:
290         exclude_regex = re.compile(exclude)
291     except re.error:
292         err(f"Invalid regular expression for exclude given: {exclude!r}")
293         ctx.exit(2)
294     root = find_project_root(src)
295     for s in src:
296         p = Path(s)
297         if p.is_dir():
298             sources.update(
299                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
300             )
301         elif p.is_file() or s == "-":
302             # if a file was explicitly given, we don't care about its extension
303             sources.add(p)
304         else:
305             err(f"invalid path: {s}")
306     if len(sources) == 0:
307         if verbose or not quiet:
308             out("No paths given. Nothing to do 😴")
309         ctx.exit(0)
310         return
311
312     elif len(sources) == 1:
313         reformat_one(
314             src=sources.pop(),
315             line_length=line_length,
316             fast=fast,
317             write_back=write_back,
318             mode=mode,
319             report=report,
320         )
321     else:
322         loop = asyncio.get_event_loop()
323         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
324         try:
325             loop.run_until_complete(
326                 schedule_formatting(
327                     sources=sources,
328                     line_length=line_length,
329                     fast=fast,
330                     write_back=write_back,
331                     mode=mode,
332                     report=report,
333                     loop=loop,
334                     executor=executor,
335                 )
336             )
337         finally:
338             shutdown(loop)
339     if verbose or not quiet:
340         out("All done! ✨ 🍰 ✨")
341         click.echo(str(report))
342     ctx.exit(report.return_code)
343
344
345 def reformat_one(
346     src: Path,
347     line_length: int,
348     fast: bool,
349     write_back: WriteBack,
350     mode: FileMode,
351     report: "Report",
352 ) -> None:
353     """Reformat a single file under `src` without spawning child processes.
354
355     If `quiet` is True, non-error messages are not output. `line_length`,
356     `write_back`, `fast` and `pyi` options are passed to
357     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
358     """
359     try:
360         changed = Changed.NO
361         if not src.is_file() and str(src) == "-":
362             if format_stdin_to_stdout(
363                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
364             ):
365                 changed = Changed.YES
366         else:
367             cache: Cache = {}
368             if write_back != WriteBack.DIFF:
369                 cache = read_cache(line_length, mode)
370                 res_src = src.resolve()
371                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
372                     changed = Changed.CACHED
373             if changed is not Changed.CACHED and format_file_in_place(
374                 src,
375                 line_length=line_length,
376                 fast=fast,
377                 write_back=write_back,
378                 mode=mode,
379             ):
380                 changed = Changed.YES
381             if write_back == WriteBack.YES and changed is not Changed.NO:
382                 write_cache(cache, [src], line_length, mode)
383         report.done(src, changed)
384     except Exception as exc:
385         report.failed(src, str(exc))
386
387
388 async def schedule_formatting(
389     sources: Set[Path],
390     line_length: int,
391     fast: bool,
392     write_back: WriteBack,
393     mode: FileMode,
394     report: "Report",
395     loop: BaseEventLoop,
396     executor: Executor,
397 ) -> None:
398     """Run formatting of `sources` in parallel using the provided `executor`.
399
400     (Use ProcessPoolExecutors for actual parallelism.)
401
402     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
403     :func:`format_file_in_place`.
404     """
405     cache: Cache = {}
406     if write_back != WriteBack.DIFF:
407         cache = read_cache(line_length, mode)
408         sources, cached = filter_cached(cache, sources)
409         for src in sorted(cached):
410             report.done(src, Changed.CACHED)
411     cancelled = []
412     formatted = []
413     if sources:
414         lock = None
415         if write_back == WriteBack.DIFF:
416             # For diff output, we need locks to ensure we don't interleave output
417             # from different processes.
418             manager = Manager()
419             lock = manager.Lock()
420         tasks = {
421             loop.run_in_executor(
422                 executor,
423                 format_file_in_place,
424                 src,
425                 line_length,
426                 fast,
427                 write_back,
428                 mode,
429                 lock,
430             ): src
431             for src in sorted(sources)
432         }
433         pending: Iterable[asyncio.Task] = tasks.keys()
434         try:
435             loop.add_signal_handler(signal.SIGINT, cancel, pending)
436             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
437         except NotImplementedError:
438             # There are no good alternatives for these on Windows
439             pass
440         while pending:
441             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
442             for task in done:
443                 src = tasks.pop(task)
444                 if task.cancelled():
445                     cancelled.append(task)
446                 elif task.exception():
447                     report.failed(src, str(task.exception()))
448                 else:
449                     formatted.append(src)
450                     report.done(src, Changed.YES if task.result() else Changed.NO)
451     if cancelled:
452         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
453     if write_back == WriteBack.YES and formatted:
454         write_cache(cache, formatted, line_length, mode)
455
456
457 def format_file_in_place(
458     src: Path,
459     line_length: int,
460     fast: bool,
461     write_back: WriteBack = WriteBack.NO,
462     mode: FileMode = FileMode.AUTO_DETECT,
463     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
464 ) -> bool:
465     """Format file under `src` path. Return True if changed.
466
467     If `write_back` is True, write reformatted code back to stdout.
468     `line_length` and `fast` options are passed to :func:`format_file_contents`.
469     """
470     if src.suffix == ".pyi":
471         mode |= FileMode.PYI
472
473     then = datetime.utcfromtimestamp(src.stat().st_mtime)
474     with open(src, "rb") as buf:
475         src_contents, encoding, newline = decode_bytes(buf.read())
476     try:
477         dst_contents = format_file_contents(
478             src_contents, line_length=line_length, fast=fast, mode=mode
479         )
480     except NothingChanged:
481         return False
482
483     if write_back == write_back.YES:
484         with open(src, "w", encoding=encoding, newline=newline) as f:
485             f.write(dst_contents)
486     elif write_back == write_back.DIFF:
487         now = datetime.utcnow()
488         src_name = f"{src}\t{then} +0000"
489         dst_name = f"{src}\t{now} +0000"
490         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
491         if lock:
492             lock.acquire()
493         try:
494             f = io.TextIOWrapper(
495                 sys.stdout.buffer,
496                 encoding=encoding,
497                 newline=newline,
498                 write_through=True,
499             )
500             f.write(diff_contents)
501             f.detach()
502         finally:
503             if lock:
504                 lock.release()
505     return True
506
507
508 def format_stdin_to_stdout(
509     line_length: int,
510     fast: bool,
511     write_back: WriteBack = WriteBack.NO,
512     mode: FileMode = FileMode.AUTO_DETECT,
513 ) -> bool:
514     """Format file on stdin. Return True if changed.
515
516     If `write_back` is True, write reformatted code back to stdout.
517     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
518     :func:`format_file_contents`.
519     """
520     then = datetime.utcnow()
521     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
522     dst = src
523     try:
524         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
525         return True
526
527     except NothingChanged:
528         return False
529
530     finally:
531         f = io.TextIOWrapper(
532             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
533         )
534         if write_back == WriteBack.YES:
535             f.write(dst)
536         elif write_back == WriteBack.DIFF:
537             now = datetime.utcnow()
538             src_name = f"STDIN\t{then} +0000"
539             dst_name = f"STDOUT\t{now} +0000"
540             f.write(diff(src, dst, src_name, dst_name))
541         f.detach()
542
543
544 def format_file_contents(
545     src_contents: str,
546     *,
547     line_length: int,
548     fast: bool,
549     mode: FileMode = FileMode.AUTO_DETECT,
550 ) -> FileContent:
551     """Reformat contents a file and return new contents.
552
553     If `fast` is False, additionally confirm that the reformatted code is
554     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
555     `line_length` is passed to :func:`format_str`.
556     """
557     if src_contents.strip() == "":
558         raise NothingChanged
559
560     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
561     if src_contents == dst_contents:
562         raise NothingChanged
563
564     if not fast:
565         assert_equivalent(src_contents, dst_contents)
566         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
567     return dst_contents
568
569
570 def format_str(
571     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
572 ) -> FileContent:
573     """Reformat a string and return new contents.
574
575     `line_length` determines how many characters per line are allowed.
576     """
577     src_node = lib2to3_parse(src_contents)
578     dst_contents = ""
579     future_imports = get_future_imports(src_node)
580     is_pyi = bool(mode & FileMode.PYI)
581     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
582     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
583     lines = LineGenerator(
584         remove_u_prefix=py36 or "unicode_literals" in future_imports,
585         is_pyi=is_pyi,
586         normalize_strings=normalize_strings,
587     )
588     elt = EmptyLineTracker(is_pyi=is_pyi)
589     empty_line = Line()
590     after = 0
591     for current_line in lines.visit(src_node):
592         for _ in range(after):
593             dst_contents += str(empty_line)
594         before, after = elt.maybe_empty_lines(current_line)
595         for _ in range(before):
596             dst_contents += str(empty_line)
597         for line in split_line(current_line, line_length=line_length, py36=py36):
598             dst_contents += str(line)
599     return dst_contents
600
601
602 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
603     """Return a tuple of (decoded_contents, encoding, newline).
604
605     `newline` is either CRLF or LF but `decoded_contents` is decoded with
606     universal newlines (i.e. only contains LF).
607     """
608     srcbuf = io.BytesIO(src)
609     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
610     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
611     srcbuf.seek(0)
612     with io.TextIOWrapper(srcbuf, encoding) as tiow:
613         return tiow.read(), encoding, newline
614
615
616 GRAMMARS = [
617     pygram.python_grammar_no_print_statement_no_exec_statement,
618     pygram.python_grammar_no_print_statement,
619     pygram.python_grammar,
620 ]
621
622
623 def lib2to3_parse(src_txt: str) -> Node:
624     """Given a string with source, return the lib2to3 Node."""
625     grammar = pygram.python_grammar_no_print_statement
626     if src_txt[-1] != "\n":
627         src_txt += "\n"
628     for grammar in GRAMMARS:
629         drv = driver.Driver(grammar, pytree.convert)
630         try:
631             result = drv.parse_string(src_txt, True)
632             break
633
634         except ParseError as pe:
635             lineno, column = pe.context[1]
636             lines = src_txt.splitlines()
637             try:
638                 faulty_line = lines[lineno - 1]
639             except IndexError:
640                 faulty_line = "<line number missing in source>"
641             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
642     else:
643         raise exc from None
644
645     if isinstance(result, Leaf):
646         result = Node(syms.file_input, [result])
647     return result
648
649
650 def lib2to3_unparse(node: Node) -> str:
651     """Given a lib2to3 node, return its string representation."""
652     code = str(node)
653     return code
654
655
656 T = TypeVar("T")
657
658
659 class Visitor(Generic[T]):
660     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
661
662     def visit(self, node: LN) -> Iterator[T]:
663         """Main method to visit `node` and its children.
664
665         It tries to find a `visit_*()` method for the given `node.type`, like
666         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
667         If no dedicated `visit_*()` method is found, chooses `visit_default()`
668         instead.
669
670         Then yields objects of type `T` from the selected visitor.
671         """
672         if node.type < 256:
673             name = token.tok_name[node.type]
674         else:
675             name = type_repr(node.type)
676         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
677
678     def visit_default(self, node: LN) -> Iterator[T]:
679         """Default `visit_*()` implementation. Recurses to children of `node`."""
680         if isinstance(node, Node):
681             for child in node.children:
682                 yield from self.visit(child)
683
684
685 @dataclass
686 class DebugVisitor(Visitor[T]):
687     tree_depth: int = 0
688
689     def visit_default(self, node: LN) -> Iterator[T]:
690         indent = " " * (2 * self.tree_depth)
691         if isinstance(node, Node):
692             _type = type_repr(node.type)
693             out(f"{indent}{_type}", fg="yellow")
694             self.tree_depth += 1
695             for child in node.children:
696                 yield from self.visit(child)
697
698             self.tree_depth -= 1
699             out(f"{indent}/{_type}", fg="yellow", bold=False)
700         else:
701             _type = token.tok_name.get(node.type, str(node.type))
702             out(f"{indent}{_type}", fg="blue", nl=False)
703             if node.prefix:
704                 # We don't have to handle prefixes for `Node` objects since
705                 # that delegates to the first child anyway.
706                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
707             out(f" {node.value!r}", fg="blue", bold=False)
708
709     @classmethod
710     def show(cls, code: str) -> None:
711         """Pretty-print the lib2to3 AST of a given string of `code`.
712
713         Convenience method for debugging.
714         """
715         v: DebugVisitor[None] = DebugVisitor()
716         list(v.visit(lib2to3_parse(code)))
717
718
719 KEYWORDS = set(keyword.kwlist)
720 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
721 FLOW_CONTROL = {"return", "raise", "break", "continue"}
722 STATEMENT = {
723     syms.if_stmt,
724     syms.while_stmt,
725     syms.for_stmt,
726     syms.try_stmt,
727     syms.except_clause,
728     syms.with_stmt,
729     syms.funcdef,
730     syms.classdef,
731 }
732 STANDALONE_COMMENT = 153
733 LOGIC_OPERATORS = {"and", "or"}
734 COMPARATORS = {
735     token.LESS,
736     token.GREATER,
737     token.EQEQUAL,
738     token.NOTEQUAL,
739     token.LESSEQUAL,
740     token.GREATEREQUAL,
741 }
742 MATH_OPERATORS = {
743     token.VBAR,
744     token.CIRCUMFLEX,
745     token.AMPER,
746     token.LEFTSHIFT,
747     token.RIGHTSHIFT,
748     token.PLUS,
749     token.MINUS,
750     token.STAR,
751     token.SLASH,
752     token.DOUBLESLASH,
753     token.PERCENT,
754     token.AT,
755     token.TILDE,
756     token.DOUBLESTAR,
757 }
758 STARS = {token.STAR, token.DOUBLESTAR}
759 VARARGS_PARENTS = {
760     syms.arglist,
761     syms.argument,  # double star in arglist
762     syms.trailer,  # single argument to call
763     syms.typedargslist,
764     syms.varargslist,  # lambdas
765 }
766 UNPACKING_PARENTS = {
767     syms.atom,  # single element of a list or set literal
768     syms.dictsetmaker,
769     syms.listmaker,
770     syms.testlist_gexp,
771     syms.testlist_star_expr,
772 }
773 TEST_DESCENDANTS = {
774     syms.test,
775     syms.lambdef,
776     syms.or_test,
777     syms.and_test,
778     syms.not_test,
779     syms.comparison,
780     syms.star_expr,
781     syms.expr,
782     syms.xor_expr,
783     syms.and_expr,
784     syms.shift_expr,
785     syms.arith_expr,
786     syms.trailer,
787     syms.term,
788     syms.power,
789 }
790 ASSIGNMENTS = {
791     "=",
792     "+=",
793     "-=",
794     "*=",
795     "@=",
796     "/=",
797     "%=",
798     "&=",
799     "|=",
800     "^=",
801     "<<=",
802     ">>=",
803     "**=",
804     "//=",
805 }
806 COMPREHENSION_PRIORITY = 20
807 COMMA_PRIORITY = 18
808 TERNARY_PRIORITY = 16
809 LOGIC_PRIORITY = 14
810 STRING_PRIORITY = 12
811 COMPARATOR_PRIORITY = 10
812 MATH_PRIORITIES = {
813     token.VBAR: 9,
814     token.CIRCUMFLEX: 8,
815     token.AMPER: 7,
816     token.LEFTSHIFT: 6,
817     token.RIGHTSHIFT: 6,
818     token.PLUS: 5,
819     token.MINUS: 5,
820     token.STAR: 4,
821     token.SLASH: 4,
822     token.DOUBLESLASH: 4,
823     token.PERCENT: 4,
824     token.AT: 4,
825     token.TILDE: 3,
826     token.DOUBLESTAR: 2,
827 }
828 DOT_PRIORITY = 1
829
830
831 @dataclass
832 class BracketTracker:
833     """Keeps track of brackets on a line."""
834
835     depth: int = 0
836     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
837     delimiters: Dict[LeafID, Priority] = Factory(dict)
838     previous: Optional[Leaf] = None
839     _for_loop_variable: int = 0
840     _lambda_arguments: int = 0
841
842     def mark(self, leaf: Leaf) -> None:
843         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
844
845         All leaves receive an int `bracket_depth` field that stores how deep
846         within brackets a given leaf is. 0 means there are no enclosing brackets
847         that started on this line.
848
849         If a leaf is itself a closing bracket, it receives an `opening_bracket`
850         field that it forms a pair with. This is a one-directional link to
851         avoid reference cycles.
852
853         If a leaf is a delimiter (a token on which Black can split the line if
854         needed) and it's on depth 0, its `id()` is stored in the tracker's
855         `delimiters` field.
856         """
857         if leaf.type == token.COMMENT:
858             return
859
860         self.maybe_decrement_after_for_loop_variable(leaf)
861         self.maybe_decrement_after_lambda_arguments(leaf)
862         if leaf.type in CLOSING_BRACKETS:
863             self.depth -= 1
864             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
865             leaf.opening_bracket = opening_bracket
866         leaf.bracket_depth = self.depth
867         if self.depth == 0:
868             delim = is_split_before_delimiter(leaf, self.previous)
869             if delim and self.previous is not None:
870                 self.delimiters[id(self.previous)] = delim
871             else:
872                 delim = is_split_after_delimiter(leaf, self.previous)
873                 if delim:
874                     self.delimiters[id(leaf)] = delim
875         if leaf.type in OPENING_BRACKETS:
876             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
877             self.depth += 1
878         self.previous = leaf
879         self.maybe_increment_lambda_arguments(leaf)
880         self.maybe_increment_for_loop_variable(leaf)
881
882     def any_open_brackets(self) -> bool:
883         """Return True if there is an yet unmatched open bracket on the line."""
884         return bool(self.bracket_match)
885
886     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
887         """Return the highest priority of a delimiter found on the line.
888
889         Values are consistent with what `is_split_*_delimiter()` return.
890         Raises ValueError on no delimiters.
891         """
892         return max(v for k, v in self.delimiters.items() if k not in exclude)
893
894     def delimiter_count_with_priority(self, priority: int = 0) -> int:
895         """Return the number of delimiters with the given `priority`.
896
897         If no `priority` is passed, defaults to max priority on the line.
898         """
899         if not self.delimiters:
900             return 0
901
902         priority = priority or self.max_delimiter_priority()
903         return sum(1 for p in self.delimiters.values() if p == priority)
904
905     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
906         """In a for loop, or comprehension, the variables are often unpacks.
907
908         To avoid splitting on the comma in this situation, increase the depth of
909         tokens between `for` and `in`.
910         """
911         if leaf.type == token.NAME and leaf.value == "for":
912             self.depth += 1
913             self._for_loop_variable += 1
914             return True
915
916         return False
917
918     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
919         """See `maybe_increment_for_loop_variable` above for explanation."""
920         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
921             self.depth -= 1
922             self._for_loop_variable -= 1
923             return True
924
925         return False
926
927     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
928         """In a lambda expression, there might be more than one argument.
929
930         To avoid splitting on the comma in this situation, increase the depth of
931         tokens between `lambda` and `:`.
932         """
933         if leaf.type == token.NAME and leaf.value == "lambda":
934             self.depth += 1
935             self._lambda_arguments += 1
936             return True
937
938         return False
939
940     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
941         """See `maybe_increment_lambda_arguments` above for explanation."""
942         if self._lambda_arguments and leaf.type == token.COLON:
943             self.depth -= 1
944             self._lambda_arguments -= 1
945             return True
946
947         return False
948
949     def get_open_lsqb(self) -> Optional[Leaf]:
950         """Return the most recent opening square bracket (if any)."""
951         return self.bracket_match.get((self.depth - 1, token.RSQB))
952
953
954 @dataclass
955 class Line:
956     """Holds leaves and comments. Can be printed with `str(line)`."""
957
958     depth: int = 0
959     leaves: List[Leaf] = Factory(list)
960     comments: List[Tuple[Index, Leaf]] = Factory(list)
961     bracket_tracker: BracketTracker = Factory(BracketTracker)
962     inside_brackets: bool = False
963     should_explode: bool = False
964
965     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
966         """Add a new `leaf` to the end of the line.
967
968         Unless `preformatted` is True, the `leaf` will receive a new consistent
969         whitespace prefix and metadata applied by :class:`BracketTracker`.
970         Trailing commas are maybe removed, unpacked for loop variables are
971         demoted from being delimiters.
972
973         Inline comments are put aside.
974         """
975         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
976         if not has_value:
977             return
978
979         if token.COLON == leaf.type and self.is_class_paren_empty:
980             del self.leaves[-2:]
981         if self.leaves and not preformatted:
982             # Note: at this point leaf.prefix should be empty except for
983             # imports, for which we only preserve newlines.
984             leaf.prefix += whitespace(
985                 leaf, complex_subscript=self.is_complex_subscript(leaf)
986             )
987         if self.inside_brackets or not preformatted:
988             self.bracket_tracker.mark(leaf)
989             self.maybe_remove_trailing_comma(leaf)
990         if not self.append_comment(leaf):
991             self.leaves.append(leaf)
992
993     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
994         """Like :func:`append()` but disallow invalid standalone comment structure.
995
996         Raises ValueError when any `leaf` is appended after a standalone comment
997         or when a standalone comment is not the first leaf on the line.
998         """
999         if self.bracket_tracker.depth == 0:
1000             if self.is_comment:
1001                 raise ValueError("cannot append to standalone comments")
1002
1003             if self.leaves and leaf.type == STANDALONE_COMMENT:
1004                 raise ValueError(
1005                     "cannot append standalone comments to a populated line"
1006                 )
1007
1008         self.append(leaf, preformatted=preformatted)
1009
1010     @property
1011     def is_comment(self) -> bool:
1012         """Is this line a standalone comment?"""
1013         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1014
1015     @property
1016     def is_decorator(self) -> bool:
1017         """Is this line a decorator?"""
1018         return bool(self) and self.leaves[0].type == token.AT
1019
1020     @property
1021     def is_import(self) -> bool:
1022         """Is this an import line?"""
1023         return bool(self) and is_import(self.leaves[0])
1024
1025     @property
1026     def is_class(self) -> bool:
1027         """Is this line a class definition?"""
1028         return (
1029             bool(self)
1030             and self.leaves[0].type == token.NAME
1031             and self.leaves[0].value == "class"
1032         )
1033
1034     @property
1035     def is_stub_class(self) -> bool:
1036         """Is this line a class definition with a body consisting only of "..."?"""
1037         return self.is_class and self.leaves[-3:] == [
1038             Leaf(token.DOT, ".") for _ in range(3)
1039         ]
1040
1041     @property
1042     def is_def(self) -> bool:
1043         """Is this a function definition? (Also returns True for async defs.)"""
1044         try:
1045             first_leaf = self.leaves[0]
1046         except IndexError:
1047             return False
1048
1049         try:
1050             second_leaf: Optional[Leaf] = self.leaves[1]
1051         except IndexError:
1052             second_leaf = None
1053         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1054             first_leaf.type == token.ASYNC
1055             and second_leaf is not None
1056             and second_leaf.type == token.NAME
1057             and second_leaf.value == "def"
1058         )
1059
1060     @property
1061     def is_class_paren_empty(self) -> bool:
1062         """Is this a class with no base classes but using parentheses?
1063
1064         Those are unnecessary and should be removed.
1065         """
1066         return (
1067             bool(self)
1068             and len(self.leaves) == 4
1069             and self.is_class
1070             and self.leaves[2].type == token.LPAR
1071             and self.leaves[2].value == "("
1072             and self.leaves[3].type == token.RPAR
1073             and self.leaves[3].value == ")"
1074         )
1075
1076     @property
1077     def is_triple_quoted_string(self) -> bool:
1078         """Is the line a triple quoted string?"""
1079         return (
1080             bool(self)
1081             and self.leaves[0].type == token.STRING
1082             and self.leaves[0].value.startswith(('"""', "'''"))
1083         )
1084
1085     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1086         """If so, needs to be split before emitting."""
1087         for leaf in self.leaves:
1088             if leaf.type == STANDALONE_COMMENT:
1089                 if leaf.bracket_depth <= depth_limit:
1090                     return True
1091
1092         return False
1093
1094     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1095         """Remove trailing comma if there is one and it's safe."""
1096         if not (
1097             self.leaves
1098             and self.leaves[-1].type == token.COMMA
1099             and closing.type in CLOSING_BRACKETS
1100         ):
1101             return False
1102
1103         if closing.type == token.RBRACE:
1104             self.remove_trailing_comma()
1105             return True
1106
1107         if closing.type == token.RSQB:
1108             comma = self.leaves[-1]
1109             if comma.parent and comma.parent.type == syms.listmaker:
1110                 self.remove_trailing_comma()
1111                 return True
1112
1113         # For parens let's check if it's safe to remove the comma.
1114         # Imports are always safe.
1115         if self.is_import:
1116             self.remove_trailing_comma()
1117             return True
1118
1119         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1120         # change a tuple into a different type by removing the comma.
1121         depth = closing.bracket_depth + 1
1122         commas = 0
1123         opening = closing.opening_bracket
1124         for _opening_index, leaf in enumerate(self.leaves):
1125             if leaf is opening:
1126                 break
1127
1128         else:
1129             return False
1130
1131         for leaf in self.leaves[_opening_index + 1 :]:
1132             if leaf is closing:
1133                 break
1134
1135             bracket_depth = leaf.bracket_depth
1136             if bracket_depth == depth and leaf.type == token.COMMA:
1137                 commas += 1
1138                 if leaf.parent and leaf.parent.type == syms.arglist:
1139                     commas += 1
1140                     break
1141
1142         if commas > 1:
1143             self.remove_trailing_comma()
1144             return True
1145
1146         return False
1147
1148     def append_comment(self, comment: Leaf) -> bool:
1149         """Add an inline or standalone comment to the line."""
1150         if (
1151             comment.type == STANDALONE_COMMENT
1152             and self.bracket_tracker.any_open_brackets()
1153         ):
1154             comment.prefix = ""
1155             return False
1156
1157         if comment.type != token.COMMENT:
1158             return False
1159
1160         after = len(self.leaves) - 1
1161         if after == -1:
1162             comment.type = STANDALONE_COMMENT
1163             comment.prefix = ""
1164             return False
1165
1166         else:
1167             self.comments.append((after, comment))
1168             return True
1169
1170     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1171         """Generate comments that should appear directly after `leaf`.
1172
1173         Provide a non-negative leaf `_index` to speed up the function.
1174         """
1175         if _index == -1:
1176             for _index, _leaf in enumerate(self.leaves):
1177                 if leaf is _leaf:
1178                     break
1179
1180             else:
1181                 return
1182
1183         for index, comment_after in self.comments:
1184             if _index == index:
1185                 yield comment_after
1186
1187     def remove_trailing_comma(self) -> None:
1188         """Remove the trailing comma and moves the comments attached to it."""
1189         comma_index = len(self.leaves) - 1
1190         for i in range(len(self.comments)):
1191             comment_index, comment = self.comments[i]
1192             if comment_index == comma_index:
1193                 self.comments[i] = (comma_index - 1, comment)
1194         self.leaves.pop()
1195
1196     def is_complex_subscript(self, leaf: Leaf) -> bool:
1197         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1198         open_lsqb = (
1199             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1200         )
1201         if open_lsqb is None:
1202             return False
1203
1204         subscript_start = open_lsqb.next_sibling
1205         if (
1206             isinstance(subscript_start, Node)
1207             and subscript_start.type == syms.subscriptlist
1208         ):
1209             subscript_start = child_towards(subscript_start, leaf)
1210         return subscript_start is not None and any(
1211             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1212         )
1213
1214     def __str__(self) -> str:
1215         """Render the line."""
1216         if not self:
1217             return "\n"
1218
1219         indent = "    " * self.depth
1220         leaves = iter(self.leaves)
1221         first = next(leaves)
1222         res = f"{first.prefix}{indent}{first.value}"
1223         for leaf in leaves:
1224             res += str(leaf)
1225         for _, comment in self.comments:
1226             res += str(comment)
1227         return res + "\n"
1228
1229     def __bool__(self) -> bool:
1230         """Return True if the line has leaves or comments."""
1231         return bool(self.leaves or self.comments)
1232
1233
1234 class UnformattedLines(Line):
1235     """Just like :class:`Line` but stores lines which aren't reformatted."""
1236
1237     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1238         """Just add a new `leaf` to the end of the lines.
1239
1240         The `preformatted` argument is ignored.
1241
1242         Keeps track of indentation `depth`, which is useful when the user
1243         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1244         """
1245         try:
1246             list(generate_comments(leaf))
1247         except FormatOn as f_on:
1248             self.leaves.append(f_on.leaf_from_consumed(leaf))
1249             raise
1250
1251         self.leaves.append(leaf)
1252         if leaf.type == token.INDENT:
1253             self.depth += 1
1254         elif leaf.type == token.DEDENT:
1255             self.depth -= 1
1256
1257     def __str__(self) -> str:
1258         """Render unformatted lines from leaves which were added with `append()`.
1259
1260         `depth` is not used for indentation in this case.
1261         """
1262         if not self:
1263             return "\n"
1264
1265         res = ""
1266         for leaf in self.leaves:
1267             res += str(leaf)
1268         return res
1269
1270     def append_comment(self, comment: Leaf) -> bool:
1271         """Not implemented in this class. Raises `NotImplementedError`."""
1272         raise NotImplementedError("Unformatted lines don't store comments separately.")
1273
1274     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1275         """Does nothing and returns False."""
1276         return False
1277
1278     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1279         """Does nothing and returns False."""
1280         return False
1281
1282
1283 @dataclass
1284 class EmptyLineTracker:
1285     """Provides a stateful method that returns the number of potential extra
1286     empty lines needed before and after the currently processed line.
1287
1288     Note: this tracker works on lines that haven't been split yet.  It assumes
1289     the prefix of the first leaf consists of optional newlines.  Those newlines
1290     are consumed by `maybe_empty_lines()` and included in the computation.
1291     """
1292
1293     is_pyi: bool = False
1294     previous_line: Optional[Line] = None
1295     previous_after: int = 0
1296     previous_defs: List[int] = Factory(list)
1297
1298     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1299         """Return the number of extra empty lines before and after the `current_line`.
1300
1301         This is for separating `def`, `async def` and `class` with extra empty
1302         lines (two on module-level).
1303         """
1304         if isinstance(current_line, UnformattedLines):
1305             return 0, 0
1306
1307         before, after = self._maybe_empty_lines(current_line)
1308         before -= self.previous_after
1309         self.previous_after = after
1310         self.previous_line = current_line
1311         return before, after
1312
1313     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1314         max_allowed = 1
1315         if current_line.depth == 0:
1316             max_allowed = 1 if self.is_pyi else 2
1317         if current_line.leaves:
1318             # Consume the first leaf's extra newlines.
1319             first_leaf = current_line.leaves[0]
1320             before = first_leaf.prefix.count("\n")
1321             before = min(before, max_allowed)
1322             first_leaf.prefix = ""
1323         else:
1324             before = 0
1325         depth = current_line.depth
1326         while self.previous_defs and self.previous_defs[-1] >= depth:
1327             self.previous_defs.pop()
1328             if self.is_pyi:
1329                 before = 0 if depth else 1
1330             else:
1331                 before = 1 if depth else 2
1332         is_decorator = current_line.is_decorator
1333         if is_decorator or current_line.is_def or current_line.is_class:
1334             if not is_decorator:
1335                 self.previous_defs.append(depth)
1336             if self.previous_line is None:
1337                 # Don't insert empty lines before the first line in the file.
1338                 return 0, 0
1339
1340             if self.previous_line.is_decorator:
1341                 return 0, 0
1342
1343             if self.previous_line.depth < current_line.depth and (
1344                 self.previous_line.is_class or self.previous_line.is_def
1345             ):
1346                 return 0, 0
1347
1348             if (
1349                 self.previous_line.is_comment
1350                 and self.previous_line.depth == current_line.depth
1351                 and before == 0
1352             ):
1353                 return 0, 0
1354
1355             if self.is_pyi:
1356                 if self.previous_line.depth > current_line.depth:
1357                     newlines = 1
1358                 elif current_line.is_class or self.previous_line.is_class:
1359                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1360                         newlines = 0
1361                     else:
1362                         newlines = 1
1363                 else:
1364                     newlines = 0
1365             else:
1366                 newlines = 2
1367             if current_line.depth and newlines:
1368                 newlines -= 1
1369             return newlines, 0
1370
1371         if (
1372             self.previous_line
1373             and self.previous_line.is_import
1374             and not current_line.is_import
1375             and depth == self.previous_line.depth
1376         ):
1377             return (before or 1), 0
1378
1379         if (
1380             self.previous_line
1381             and self.previous_line.is_class
1382             and current_line.is_triple_quoted_string
1383         ):
1384             return before, 1
1385
1386         return before, 0
1387
1388
1389 @dataclass
1390 class LineGenerator(Visitor[Line]):
1391     """Generates reformatted Line objects.  Empty lines are not emitted.
1392
1393     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1394     in ways that will no longer stringify to valid Python code on the tree.
1395     """
1396
1397     is_pyi: bool = False
1398     normalize_strings: bool = True
1399     current_line: Line = Factory(Line)
1400     remove_u_prefix: bool = False
1401
1402     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1403         """Generate a line.
1404
1405         If the line is empty, only emit if it makes sense.
1406         If the line is too long, split it first and then generate.
1407
1408         If any lines were generated, set up a new current_line.
1409         """
1410         if not self.current_line:
1411             if self.current_line.__class__ == type:
1412                 self.current_line.depth += indent
1413             else:
1414                 self.current_line = type(depth=self.current_line.depth + indent)
1415             return  # Line is empty, don't emit. Creating a new one unnecessary.
1416
1417         complete_line = self.current_line
1418         self.current_line = type(depth=complete_line.depth + indent)
1419         yield complete_line
1420
1421     def visit(self, node: LN) -> Iterator[Line]:
1422         """Main method to visit `node` and its children.
1423
1424         Yields :class:`Line` objects.
1425         """
1426         if isinstance(self.current_line, UnformattedLines):
1427             # File contained `# fmt: off`
1428             yield from self.visit_unformatted(node)
1429
1430         else:
1431             yield from super().visit(node)
1432
1433     def visit_default(self, node: LN) -> Iterator[Line]:
1434         """Default `visit_*()` implementation. Recurses to children of `node`."""
1435         if isinstance(node, Leaf):
1436             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1437             try:
1438                 for comment in generate_comments(node):
1439                     if any_open_brackets:
1440                         # any comment within brackets is subject to splitting
1441                         self.current_line.append(comment)
1442                     elif comment.type == token.COMMENT:
1443                         # regular trailing comment
1444                         self.current_line.append(comment)
1445                         yield from self.line()
1446
1447                     else:
1448                         # regular standalone comment
1449                         yield from self.line()
1450
1451                         self.current_line.append(comment)
1452                         yield from self.line()
1453
1454             except FormatOff as f_off:
1455                 f_off.trim_prefix(node)
1456                 yield from self.line(type=UnformattedLines)
1457                 yield from self.visit(node)
1458
1459             except FormatOn as f_on:
1460                 # This only happens here if somebody says "fmt: on" multiple
1461                 # times in a row.
1462                 f_on.trim_prefix(node)
1463                 yield from self.visit_default(node)
1464
1465             else:
1466                 normalize_prefix(node, inside_brackets=any_open_brackets)
1467                 if self.normalize_strings and node.type == token.STRING:
1468                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1469                     normalize_string_quotes(node)
1470                 if node.type not in WHITESPACE:
1471                     self.current_line.append(node)
1472         yield from super().visit_default(node)
1473
1474     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1475         """Increase indentation level, maybe yield a line."""
1476         # In blib2to3 INDENT never holds comments.
1477         yield from self.line(+1)
1478         yield from self.visit_default(node)
1479
1480     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1481         """Decrease indentation level, maybe yield a line."""
1482         # The current line might still wait for trailing comments.  At DEDENT time
1483         # there won't be any (they would be prefixes on the preceding NEWLINE).
1484         # Emit the line then.
1485         yield from self.line()
1486
1487         # While DEDENT has no value, its prefix may contain standalone comments
1488         # that belong to the current indentation level.  Get 'em.
1489         yield from self.visit_default(node)
1490
1491         # Finally, emit the dedent.
1492         yield from self.line(-1)
1493
1494     def visit_stmt(
1495         self, node: Node, keywords: Set[str], parens: Set[str]
1496     ) -> Iterator[Line]:
1497         """Visit a statement.
1498
1499         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1500         `def`, `with`, `class`, `assert` and assignments.
1501
1502         The relevant Python language `keywords` for a given statement will be
1503         NAME leaves within it. This methods puts those on a separate line.
1504
1505         `parens` holds a set of string leaf values immediately after which
1506         invisible parens should be put.
1507         """
1508         normalize_invisible_parens(node, parens_after=parens)
1509         for child in node.children:
1510             if child.type == token.NAME and child.value in keywords:  # type: ignore
1511                 yield from self.line()
1512
1513             yield from self.visit(child)
1514
1515     def visit_suite(self, node: Node) -> Iterator[Line]:
1516         """Visit a suite."""
1517         if self.is_pyi and is_stub_suite(node):
1518             yield from self.visit(node.children[2])
1519         else:
1520             yield from self.visit_default(node)
1521
1522     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1523         """Visit a statement without nested statements."""
1524         is_suite_like = node.parent and node.parent.type in STATEMENT
1525         if is_suite_like:
1526             if self.is_pyi and is_stub_body(node):
1527                 yield from self.visit_default(node)
1528             else:
1529                 yield from self.line(+1)
1530                 yield from self.visit_default(node)
1531                 yield from self.line(-1)
1532
1533         else:
1534             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1535                 yield from self.line()
1536             yield from self.visit_default(node)
1537
1538     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1539         """Visit `async def`, `async for`, `async with`."""
1540         yield from self.line()
1541
1542         children = iter(node.children)
1543         for child in children:
1544             yield from self.visit(child)
1545
1546             if child.type == token.ASYNC:
1547                 break
1548
1549         internal_stmt = next(children)
1550         for child in internal_stmt.children:
1551             yield from self.visit(child)
1552
1553     def visit_decorators(self, node: Node) -> Iterator[Line]:
1554         """Visit decorators."""
1555         for child in node.children:
1556             yield from self.line()
1557             yield from self.visit(child)
1558
1559     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1560         """Remove a semicolon and put the other statement on a separate line."""
1561         yield from self.line()
1562
1563     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1564         """End of file. Process outstanding comments and end with a newline."""
1565         yield from self.visit_default(leaf)
1566         yield from self.line()
1567
1568     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1569         """Used when file contained a `# fmt: off`."""
1570         if isinstance(node, Node):
1571             for child in node.children:
1572                 yield from self.visit(child)
1573
1574         else:
1575             try:
1576                 self.current_line.append(node)
1577             except FormatOn as f_on:
1578                 f_on.trim_prefix(node)
1579                 yield from self.line()
1580                 yield from self.visit(node)
1581
1582             if node.type == token.ENDMARKER:
1583                 # somebody decided not to put a final `# fmt: on`
1584                 yield from self.line()
1585
1586     def __attrs_post_init__(self) -> None:
1587         """You are in a twisty little maze of passages."""
1588         v = self.visit_stmt
1589         Ø: Set[str] = set()
1590         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1591         self.visit_if_stmt = partial(
1592             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1593         )
1594         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1595         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1596         self.visit_try_stmt = partial(
1597             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1598         )
1599         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1600         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1601         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1602         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1603         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1604         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1605         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1606         self.visit_async_funcdef = self.visit_async_stmt
1607         self.visit_decorated = self.visit_decorators
1608
1609
1610 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1611 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1612 OPENING_BRACKETS = set(BRACKET.keys())
1613 CLOSING_BRACKETS = set(BRACKET.values())
1614 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1615 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1616
1617
1618 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1619     """Return whitespace prefix if needed for the given `leaf`.
1620
1621     `complex_subscript` signals whether the given leaf is part of a subscription
1622     which has non-trivial arguments, like arithmetic expressions or function calls.
1623     """
1624     NO = ""
1625     SPACE = " "
1626     DOUBLESPACE = "  "
1627     t = leaf.type
1628     p = leaf.parent
1629     v = leaf.value
1630     if t in ALWAYS_NO_SPACE:
1631         return NO
1632
1633     if t == token.COMMENT:
1634         return DOUBLESPACE
1635
1636     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1637     if t == token.COLON and p.type not in {
1638         syms.subscript,
1639         syms.subscriptlist,
1640         syms.sliceop,
1641     }:
1642         return NO
1643
1644     prev = leaf.prev_sibling
1645     if not prev:
1646         prevp = preceding_leaf(p)
1647         if not prevp or prevp.type in OPENING_BRACKETS:
1648             return NO
1649
1650         if t == token.COLON:
1651             if prevp.type == token.COLON:
1652                 return NO
1653
1654             elif prevp.type != token.COMMA and not complex_subscript:
1655                 return NO
1656
1657             return SPACE
1658
1659         if prevp.type == token.EQUAL:
1660             if prevp.parent:
1661                 if prevp.parent.type in {
1662                     syms.arglist,
1663                     syms.argument,
1664                     syms.parameters,
1665                     syms.varargslist,
1666                 }:
1667                     return NO
1668
1669                 elif prevp.parent.type == syms.typedargslist:
1670                     # A bit hacky: if the equal sign has whitespace, it means we
1671                     # previously found it's a typed argument.  So, we're using
1672                     # that, too.
1673                     return prevp.prefix
1674
1675         elif prevp.type in STARS:
1676             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1677                 return NO
1678
1679         elif prevp.type == token.COLON:
1680             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1681                 return SPACE if complex_subscript else NO
1682
1683         elif (
1684             prevp.parent
1685             and prevp.parent.type == syms.factor
1686             and prevp.type in MATH_OPERATORS
1687         ):
1688             return NO
1689
1690         elif (
1691             prevp.type == token.RIGHTSHIFT
1692             and prevp.parent
1693             and prevp.parent.type == syms.shift_expr
1694             and prevp.prev_sibling
1695             and prevp.prev_sibling.type == token.NAME
1696             and prevp.prev_sibling.value == "print"  # type: ignore
1697         ):
1698             # Python 2 print chevron
1699             return NO
1700
1701     elif prev.type in OPENING_BRACKETS:
1702         return NO
1703
1704     if p.type in {syms.parameters, syms.arglist}:
1705         # untyped function signatures or calls
1706         if not prev or prev.type != token.COMMA:
1707             return NO
1708
1709     elif p.type == syms.varargslist:
1710         # lambdas
1711         if prev and prev.type != token.COMMA:
1712             return NO
1713
1714     elif p.type == syms.typedargslist:
1715         # typed function signatures
1716         if not prev:
1717             return NO
1718
1719         if t == token.EQUAL:
1720             if prev.type != syms.tname:
1721                 return NO
1722
1723         elif prev.type == token.EQUAL:
1724             # A bit hacky: if the equal sign has whitespace, it means we
1725             # previously found it's a typed argument.  So, we're using that, too.
1726             return prev.prefix
1727
1728         elif prev.type != token.COMMA:
1729             return NO
1730
1731     elif p.type == syms.tname:
1732         # type names
1733         if not prev:
1734             prevp = preceding_leaf(p)
1735             if not prevp or prevp.type != token.COMMA:
1736                 return NO
1737
1738     elif p.type == syms.trailer:
1739         # attributes and calls
1740         if t == token.LPAR or t == token.RPAR:
1741             return NO
1742
1743         if not prev:
1744             if t == token.DOT:
1745                 prevp = preceding_leaf(p)
1746                 if not prevp or prevp.type != token.NUMBER:
1747                     return NO
1748
1749             elif t == token.LSQB:
1750                 return NO
1751
1752         elif prev.type != token.COMMA:
1753             return NO
1754
1755     elif p.type == syms.argument:
1756         # single argument
1757         if t == token.EQUAL:
1758             return NO
1759
1760         if not prev:
1761             prevp = preceding_leaf(p)
1762             if not prevp or prevp.type == token.LPAR:
1763                 return NO
1764
1765         elif prev.type in {token.EQUAL} | STARS:
1766             return NO
1767
1768     elif p.type == syms.decorator:
1769         # decorators
1770         return NO
1771
1772     elif p.type == syms.dotted_name:
1773         if prev:
1774             return NO
1775
1776         prevp = preceding_leaf(p)
1777         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1778             return NO
1779
1780     elif p.type == syms.classdef:
1781         if t == token.LPAR:
1782             return NO
1783
1784         if prev and prev.type == token.LPAR:
1785             return NO
1786
1787     elif p.type in {syms.subscript, syms.sliceop}:
1788         # indexing
1789         if not prev:
1790             assert p.parent is not None, "subscripts are always parented"
1791             if p.parent.type == syms.subscriptlist:
1792                 return SPACE
1793
1794             return NO
1795
1796         elif not complex_subscript:
1797             return NO
1798
1799     elif p.type == syms.atom:
1800         if prev and t == token.DOT:
1801             # dots, but not the first one.
1802             return NO
1803
1804     elif p.type == syms.dictsetmaker:
1805         # dict unpacking
1806         if prev and prev.type == token.DOUBLESTAR:
1807             return NO
1808
1809     elif p.type in {syms.factor, syms.star_expr}:
1810         # unary ops
1811         if not prev:
1812             prevp = preceding_leaf(p)
1813             if not prevp or prevp.type in OPENING_BRACKETS:
1814                 return NO
1815
1816             prevp_parent = prevp.parent
1817             assert prevp_parent is not None
1818             if prevp.type == token.COLON and prevp_parent.type in {
1819                 syms.subscript,
1820                 syms.sliceop,
1821             }:
1822                 return NO
1823
1824             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1825                 return NO
1826
1827         elif t == token.NAME or t == token.NUMBER:
1828             return NO
1829
1830     elif p.type == syms.import_from:
1831         if t == token.DOT:
1832             if prev and prev.type == token.DOT:
1833                 return NO
1834
1835         elif t == token.NAME:
1836             if v == "import":
1837                 return SPACE
1838
1839             if prev and prev.type == token.DOT:
1840                 return NO
1841
1842     elif p.type == syms.sliceop:
1843         return NO
1844
1845     return SPACE
1846
1847
1848 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1849     """Return the first leaf that precedes `node`, if any."""
1850     while node:
1851         res = node.prev_sibling
1852         if res:
1853             if isinstance(res, Leaf):
1854                 return res
1855
1856             try:
1857                 return list(res.leaves())[-1]
1858
1859             except IndexError:
1860                 return None
1861
1862         node = node.parent
1863     return None
1864
1865
1866 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1867     """Return the child of `ancestor` that contains `descendant`."""
1868     node: Optional[LN] = descendant
1869     while node and node.parent != ancestor:
1870         node = node.parent
1871     return node
1872
1873
1874 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1875     """Return the priority of the `leaf` delimiter, given a line break after it.
1876
1877     The delimiter priorities returned here are from those delimiters that would
1878     cause a line break after themselves.
1879
1880     Higher numbers are higher priority.
1881     """
1882     if leaf.type == token.COMMA:
1883         return COMMA_PRIORITY
1884
1885     return 0
1886
1887
1888 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1889     """Return the priority of the `leaf` delimiter, given a line before after it.
1890
1891     The delimiter priorities returned here are from those delimiters that would
1892     cause a line break before themselves.
1893
1894     Higher numbers are higher priority.
1895     """
1896     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1897         # * and ** might also be MATH_OPERATORS but in this case they are not.
1898         # Don't treat them as a delimiter.
1899         return 0
1900
1901     if (
1902         leaf.type == token.DOT
1903         and leaf.parent
1904         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1905         and (previous is None or previous.type in CLOSING_BRACKETS)
1906     ):
1907         return DOT_PRIORITY
1908
1909     if (
1910         leaf.type in MATH_OPERATORS
1911         and leaf.parent
1912         and leaf.parent.type not in {syms.factor, syms.star_expr}
1913     ):
1914         return MATH_PRIORITIES[leaf.type]
1915
1916     if leaf.type in COMPARATORS:
1917         return COMPARATOR_PRIORITY
1918
1919     if (
1920         leaf.type == token.STRING
1921         and previous is not None
1922         and previous.type == token.STRING
1923     ):
1924         return STRING_PRIORITY
1925
1926     if leaf.type != token.NAME:
1927         return 0
1928
1929     if (
1930         leaf.value == "for"
1931         and leaf.parent
1932         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1933     ):
1934         return COMPREHENSION_PRIORITY
1935
1936     if (
1937         leaf.value == "if"
1938         and leaf.parent
1939         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1940     ):
1941         return COMPREHENSION_PRIORITY
1942
1943     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1944         return TERNARY_PRIORITY
1945
1946     if leaf.value == "is":
1947         return COMPARATOR_PRIORITY
1948
1949     if (
1950         leaf.value == "in"
1951         and leaf.parent
1952         and leaf.parent.type in {syms.comp_op, syms.comparison}
1953         and not (
1954             previous is not None
1955             and previous.type == token.NAME
1956             and previous.value == "not"
1957         )
1958     ):
1959         return COMPARATOR_PRIORITY
1960
1961     if (
1962         leaf.value == "not"
1963         and leaf.parent
1964         and leaf.parent.type == syms.comp_op
1965         and not (
1966             previous is not None
1967             and previous.type == token.NAME
1968             and previous.value == "is"
1969         )
1970     ):
1971         return COMPARATOR_PRIORITY
1972
1973     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1974         return LOGIC_PRIORITY
1975
1976     return 0
1977
1978
1979 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1980     """Clean the prefix of the `leaf` and generate comments from it, if any.
1981
1982     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1983     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1984     move because it does away with modifying the grammar to include all the
1985     possible places in which comments can be placed.
1986
1987     The sad consequence for us though is that comments don't "belong" anywhere.
1988     This is why this function generates simple parentless Leaf objects for
1989     comments.  We simply don't know what the correct parent should be.
1990
1991     No matter though, we can live without this.  We really only need to
1992     differentiate between inline and standalone comments.  The latter don't
1993     share the line with any code.
1994
1995     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1996     are emitted with a fake STANDALONE_COMMENT token identifier.
1997     """
1998     p = leaf.prefix
1999     if not p:
2000         return
2001
2002     if "#" not in p:
2003         return
2004
2005     consumed = 0
2006     nlines = 0
2007     for index, line in enumerate(p.split("\n")):
2008         consumed += len(line) + 1  # adding the length of the split '\n'
2009         line = line.lstrip()
2010         if not line:
2011             nlines += 1
2012         if not line.startswith("#"):
2013             continue
2014
2015         if index == 0 and leaf.type != token.ENDMARKER:
2016             comment_type = token.COMMENT  # simple trailing comment
2017         else:
2018             comment_type = STANDALONE_COMMENT
2019         comment = make_comment(line)
2020         yield Leaf(comment_type, comment, prefix="\n" * nlines)
2021
2022         if comment in {"# fmt: on", "# yapf: enable"}:
2023             raise FormatOn(consumed)
2024
2025         if comment in {"# fmt: off", "# yapf: disable"}:
2026             if comment_type == STANDALONE_COMMENT:
2027                 raise FormatOff(consumed)
2028
2029             prev = preceding_leaf(leaf)
2030             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
2031                 raise FormatOff(consumed)
2032
2033         nlines = 0
2034
2035
2036 def make_comment(content: str) -> str:
2037     """Return a consistently formatted comment from the given `content` string.
2038
2039     All comments (except for "##", "#!", "#:") should have a single space between
2040     the hash sign and the content.
2041
2042     If `content` didn't start with a hash sign, one is provided.
2043     """
2044     content = content.rstrip()
2045     if not content:
2046         return "#"
2047
2048     if content[0] == "#":
2049         content = content[1:]
2050     if content and content[0] not in " !:#":
2051         content = " " + content
2052     return "#" + content
2053
2054
2055 def split_line(
2056     line: Line, line_length: int, inner: bool = False, py36: bool = False
2057 ) -> Iterator[Line]:
2058     """Split a `line` into potentially many lines.
2059
2060     They should fit in the allotted `line_length` but might not be able to.
2061     `inner` signifies that there were a pair of brackets somewhere around the
2062     current `line`, possibly transitively. This means we can fallback to splitting
2063     by delimiters if the LHS/RHS don't yield any results.
2064
2065     If `py36` is True, splitting may generate syntax that is only compatible
2066     with Python 3.6 and later.
2067     """
2068     if isinstance(line, UnformattedLines) or line.is_comment:
2069         yield line
2070         return
2071
2072     line_str = str(line).strip("\n")
2073     if not line.should_explode and is_line_short_enough(
2074         line, line_length=line_length, line_str=line_str
2075     ):
2076         yield line
2077         return
2078
2079     split_funcs: List[SplitFunc]
2080     if line.is_def:
2081         split_funcs = [left_hand_split]
2082     else:
2083
2084         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2085             for omit in generate_trailers_to_omit(line, line_length):
2086                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2087                 if is_line_short_enough(lines[0], line_length=line_length):
2088                     yield from lines
2089                     return
2090
2091             # All splits failed, best effort split with no omits.
2092             # This mostly happens to multiline strings that are by definition
2093             # reported as not fitting a single line.
2094             yield from right_hand_split(line, py36)
2095
2096         if line.inside_brackets:
2097             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2098         else:
2099             split_funcs = [rhs]
2100     for split_func in split_funcs:
2101         # We are accumulating lines in `result` because we might want to abort
2102         # mission and return the original line in the end, or attempt a different
2103         # split altogether.
2104         result: List[Line] = []
2105         try:
2106             for l in split_func(line, py36):
2107                 if str(l).strip("\n") == line_str:
2108                     raise CannotSplit("Split function returned an unchanged result")
2109
2110                 result.extend(
2111                     split_line(l, line_length=line_length, inner=True, py36=py36)
2112                 )
2113         except CannotSplit as cs:
2114             continue
2115
2116         else:
2117             yield from result
2118             break
2119
2120     else:
2121         yield line
2122
2123
2124 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2125     """Split line into many lines, starting with the first matching bracket pair.
2126
2127     Note: this usually looks weird, only use this for function definitions.
2128     Prefer RHS otherwise.  This is why this function is not symmetrical with
2129     :func:`right_hand_split` which also handles optional parentheses.
2130     """
2131     head = Line(depth=line.depth)
2132     body = Line(depth=line.depth + 1, inside_brackets=True)
2133     tail = Line(depth=line.depth)
2134     tail_leaves: List[Leaf] = []
2135     body_leaves: List[Leaf] = []
2136     head_leaves: List[Leaf] = []
2137     current_leaves = head_leaves
2138     matching_bracket = None
2139     for leaf in line.leaves:
2140         if (
2141             current_leaves is body_leaves
2142             and leaf.type in CLOSING_BRACKETS
2143             and leaf.opening_bracket is matching_bracket
2144         ):
2145             current_leaves = tail_leaves if body_leaves else head_leaves
2146         current_leaves.append(leaf)
2147         if current_leaves is head_leaves:
2148             if leaf.type in OPENING_BRACKETS:
2149                 matching_bracket = leaf
2150                 current_leaves = body_leaves
2151     # Since body is a new indent level, remove spurious leading whitespace.
2152     if body_leaves:
2153         normalize_prefix(body_leaves[0], inside_brackets=True)
2154     # Build the new lines.
2155     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2156         for leaf in leaves:
2157             result.append(leaf, preformatted=True)
2158             for comment_after in line.comments_after(leaf):
2159                 result.append(comment_after, preformatted=True)
2160     bracket_split_succeeded_or_raise(head, body, tail)
2161     for result in (head, body, tail):
2162         if result:
2163             yield result
2164
2165
2166 def right_hand_split(
2167     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2168 ) -> Iterator[Line]:
2169     """Split line into many lines, starting with the last matching bracket pair.
2170
2171     If the split was by optional parentheses, attempt splitting without them, too.
2172     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2173     this split.
2174
2175     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2176     """
2177     head = Line(depth=line.depth)
2178     body = Line(depth=line.depth + 1, inside_brackets=True)
2179     tail = Line(depth=line.depth)
2180     tail_leaves: List[Leaf] = []
2181     body_leaves: List[Leaf] = []
2182     head_leaves: List[Leaf] = []
2183     current_leaves = tail_leaves
2184     opening_bracket = None
2185     closing_bracket = None
2186     for leaf in reversed(line.leaves):
2187         if current_leaves is body_leaves:
2188             if leaf is opening_bracket:
2189                 current_leaves = head_leaves if body_leaves else tail_leaves
2190         current_leaves.append(leaf)
2191         if current_leaves is tail_leaves:
2192             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2193                 opening_bracket = leaf.opening_bracket
2194                 closing_bracket = leaf
2195                 current_leaves = body_leaves
2196     tail_leaves.reverse()
2197     body_leaves.reverse()
2198     head_leaves.reverse()
2199     # Since body is a new indent level, remove spurious leading whitespace.
2200     if body_leaves:
2201         normalize_prefix(body_leaves[0], inside_brackets=True)
2202     if not head_leaves:
2203         # No `head` means the split failed. Either `tail` has all content or
2204         # the matching `opening_bracket` wasn't available on `line` anymore.
2205         raise CannotSplit("No brackets found")
2206
2207     # Build the new lines.
2208     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2209         for leaf in leaves:
2210             result.append(leaf, preformatted=True)
2211             for comment_after in line.comments_after(leaf):
2212                 result.append(comment_after, preformatted=True)
2213     bracket_split_succeeded_or_raise(head, body, tail)
2214     assert opening_bracket and closing_bracket
2215     if (
2216         # the opening bracket is an optional paren
2217         opening_bracket.type == token.LPAR
2218         and not opening_bracket.value
2219         # the closing bracket is an optional paren
2220         and closing_bracket.type == token.RPAR
2221         and not closing_bracket.value
2222         # there are no standalone comments in the body
2223         and not line.contains_standalone_comments(0)
2224         # and it's not an import (optional parens are the only thing we can split
2225         # on in this case; attempting a split without them is a waste of time)
2226         and not line.is_import
2227     ):
2228         omit = {id(closing_bracket), *omit}
2229         if can_omit_invisible_parens(body, line_length):
2230             try:
2231                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2232                 return
2233             except CannotSplit:
2234                 pass
2235
2236     ensure_visible(opening_bracket)
2237     ensure_visible(closing_bracket)
2238     body.should_explode = should_explode(body, opening_bracket)
2239     for result in (head, body, tail):
2240         if result:
2241             yield result
2242
2243
2244 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2245     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2246
2247     Do nothing otherwise.
2248
2249     A left- or right-hand split is based on a pair of brackets. Content before
2250     (and including) the opening bracket is left on one line, content inside the
2251     brackets is put on a separate line, and finally content starting with and
2252     following the closing bracket is put on a separate line.
2253
2254     Those are called `head`, `body`, and `tail`, respectively. If the split
2255     produced the same line (all content in `head`) or ended up with an empty `body`
2256     and the `tail` is just the closing bracket, then it's considered failed.
2257     """
2258     tail_len = len(str(tail).strip())
2259     if not body:
2260         if tail_len == 0:
2261             raise CannotSplit("Splitting brackets produced the same line")
2262
2263         elif tail_len < 3:
2264             raise CannotSplit(
2265                 f"Splitting brackets on an empty body to save "
2266                 f"{tail_len} characters is not worth it"
2267             )
2268
2269
2270 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2271     """Normalize prefix of the first leaf in every line returned by `split_func`.
2272
2273     This is a decorator over relevant split functions.
2274     """
2275
2276     @wraps(split_func)
2277     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2278         for l in split_func(line, py36):
2279             normalize_prefix(l.leaves[0], inside_brackets=True)
2280             yield l
2281
2282     return split_wrapper
2283
2284
2285 @dont_increase_indentation
2286 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2287     """Split according to delimiters of the highest priority.
2288
2289     If `py36` is True, the split will add trailing commas also in function
2290     signatures that contain `*` and `**`.
2291     """
2292     try:
2293         last_leaf = line.leaves[-1]
2294     except IndexError:
2295         raise CannotSplit("Line empty")
2296
2297     bt = line.bracket_tracker
2298     try:
2299         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2300     except ValueError:
2301         raise CannotSplit("No delimiters found")
2302
2303     if delimiter_priority == DOT_PRIORITY:
2304         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2305             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2306
2307     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2308     lowest_depth = sys.maxsize
2309     trailing_comma_safe = True
2310
2311     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2312         """Append `leaf` to current line or to new line if appending impossible."""
2313         nonlocal current_line
2314         try:
2315             current_line.append_safe(leaf, preformatted=True)
2316         except ValueError as ve:
2317             yield current_line
2318
2319             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2320             current_line.append(leaf)
2321
2322     for index, leaf in enumerate(line.leaves):
2323         yield from append_to_line(leaf)
2324
2325         for comment_after in line.comments_after(leaf, index):
2326             yield from append_to_line(comment_after)
2327
2328         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2329         if leaf.bracket_depth == lowest_depth and is_vararg(
2330             leaf, within=VARARGS_PARENTS
2331         ):
2332             trailing_comma_safe = trailing_comma_safe and py36
2333         leaf_priority = bt.delimiters.get(id(leaf))
2334         if leaf_priority == delimiter_priority:
2335             yield current_line
2336
2337             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2338     if current_line:
2339         if (
2340             trailing_comma_safe
2341             and delimiter_priority == COMMA_PRIORITY
2342             and current_line.leaves[-1].type != token.COMMA
2343             and current_line.leaves[-1].type != STANDALONE_COMMENT
2344         ):
2345             current_line.append(Leaf(token.COMMA, ","))
2346         yield current_line
2347
2348
2349 @dont_increase_indentation
2350 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2351     """Split standalone comments from the rest of the line."""
2352     if not line.contains_standalone_comments(0):
2353         raise CannotSplit("Line does not have any standalone comments")
2354
2355     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2356
2357     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2358         """Append `leaf` to current line or to new line if appending impossible."""
2359         nonlocal current_line
2360         try:
2361             current_line.append_safe(leaf, preformatted=True)
2362         except ValueError as ve:
2363             yield current_line
2364
2365             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2366             current_line.append(leaf)
2367
2368     for index, leaf in enumerate(line.leaves):
2369         yield from append_to_line(leaf)
2370
2371         for comment_after in line.comments_after(leaf, index):
2372             yield from append_to_line(comment_after)
2373
2374     if current_line:
2375         yield current_line
2376
2377
2378 def is_import(leaf: Leaf) -> bool:
2379     """Return True if the given leaf starts an import statement."""
2380     p = leaf.parent
2381     t = leaf.type
2382     v = leaf.value
2383     return bool(
2384         t == token.NAME
2385         and (
2386             (v == "import" and p and p.type == syms.import_name)
2387             or (v == "from" and p and p.type == syms.import_from)
2388         )
2389     )
2390
2391
2392 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2393     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2394     else.
2395
2396     Note: don't use backslashes for formatting or you'll lose your voting rights.
2397     """
2398     if not inside_brackets:
2399         spl = leaf.prefix.split("#")
2400         if "\\" not in spl[0]:
2401             nl_count = spl[-1].count("\n")
2402             if len(spl) > 1:
2403                 nl_count -= 1
2404             leaf.prefix = "\n" * nl_count
2405             return
2406
2407     leaf.prefix = ""
2408
2409
2410 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2411     """Make all string prefixes lowercase.
2412
2413     If remove_u_prefix is given, also removes any u prefix from the string.
2414
2415     Note: Mutates its argument.
2416     """
2417     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2418     assert match is not None, f"failed to match string {leaf.value!r}"
2419     orig_prefix = match.group(1)
2420     new_prefix = orig_prefix.lower()
2421     if remove_u_prefix:
2422         new_prefix = new_prefix.replace("u", "")
2423     leaf.value = f"{new_prefix}{match.group(2)}"
2424
2425
2426 def normalize_string_quotes(leaf: Leaf) -> None:
2427     """Prefer double quotes but only if it doesn't cause more escaping.
2428
2429     Adds or removes backslashes as appropriate. Doesn't parse and fix
2430     strings nested in f-strings (yet).
2431
2432     Note: Mutates its argument.
2433     """
2434     value = leaf.value.lstrip("furbFURB")
2435     if value[:3] == '"""':
2436         return
2437
2438     elif value[:3] == "'''":
2439         orig_quote = "'''"
2440         new_quote = '"""'
2441     elif value[0] == '"':
2442         orig_quote = '"'
2443         new_quote = "'"
2444     else:
2445         orig_quote = "'"
2446         new_quote = '"'
2447     first_quote_pos = leaf.value.find(orig_quote)
2448     if first_quote_pos == -1:
2449         return  # There's an internal error
2450
2451     prefix = leaf.value[:first_quote_pos]
2452     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2453     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2454     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2455     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2456     if "r" in prefix.casefold():
2457         if unescaped_new_quote.search(body):
2458             # There's at least one unescaped new_quote in this raw string
2459             # so converting is impossible
2460             return
2461
2462         # Do not introduce or remove backslashes in raw strings
2463         new_body = body
2464     else:
2465         # remove unnecessary quotes
2466         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2467         if body != new_body:
2468             # Consider the string without unnecessary quotes as the original
2469             body = new_body
2470             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2471         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2472         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2473     if new_quote == '"""' and new_body[-1] == '"':
2474         # edge case:
2475         new_body = new_body[:-1] + '\\"'
2476     orig_escape_count = body.count("\\")
2477     new_escape_count = new_body.count("\\")
2478     if new_escape_count > orig_escape_count:
2479         return  # Do not introduce more escaping
2480
2481     if new_escape_count == orig_escape_count and orig_quote == '"':
2482         return  # Prefer double quotes
2483
2484     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2485
2486
2487 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2488     """Make existing optional parentheses invisible or create new ones.
2489
2490     `parens_after` is a set of string leaf values immeditely after which parens
2491     should be put.
2492
2493     Standardizes on visible parentheses for single-element tuples, and keeps
2494     existing visible parentheses for other tuples and generator expressions.
2495     """
2496     try:
2497         list(generate_comments(node))
2498     except FormatOff:
2499         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2500
2501     check_lpar = False
2502     for index, child in enumerate(list(node.children)):
2503         if check_lpar:
2504             if child.type == syms.atom:
2505                 maybe_make_parens_invisible_in_atom(child)
2506             elif is_one_tuple(child):
2507                 # wrap child in visible parentheses
2508                 lpar = Leaf(token.LPAR, "(")
2509                 rpar = Leaf(token.RPAR, ")")
2510                 child.remove()
2511                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2512             elif node.type == syms.import_from:
2513                 # "import from" nodes store parentheses directly as part of
2514                 # the statement
2515                 if child.type == token.LPAR:
2516                     # make parentheses invisible
2517                     child.value = ""  # type: ignore
2518                     node.children[-1].value = ""  # type: ignore
2519                 elif child.type != token.STAR:
2520                     # insert invisible parentheses
2521                     node.insert_child(index, Leaf(token.LPAR, ""))
2522                     node.append_child(Leaf(token.RPAR, ""))
2523                 break
2524
2525             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2526                 # wrap child in invisible parentheses
2527                 lpar = Leaf(token.LPAR, "")
2528                 rpar = Leaf(token.RPAR, "")
2529                 index = child.remove() or 0
2530                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2531
2532         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2533
2534
2535 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2536     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2537     if (
2538         node.type != syms.atom
2539         or is_empty_tuple(node)
2540         or is_one_tuple(node)
2541         or is_yield(node)
2542         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2543     ):
2544         return False
2545
2546     first = node.children[0]
2547     last = node.children[-1]
2548     if first.type == token.LPAR and last.type == token.RPAR:
2549         # make parentheses invisible
2550         first.value = ""  # type: ignore
2551         last.value = ""  # type: ignore
2552         if len(node.children) > 1:
2553             maybe_make_parens_invisible_in_atom(node.children[1])
2554         return True
2555
2556     return False
2557
2558
2559 def is_empty_tuple(node: LN) -> bool:
2560     """Return True if `node` holds an empty tuple."""
2561     return (
2562         node.type == syms.atom
2563         and len(node.children) == 2
2564         and node.children[0].type == token.LPAR
2565         and node.children[1].type == token.RPAR
2566     )
2567
2568
2569 def is_one_tuple(node: LN) -> bool:
2570     """Return True if `node` holds a tuple with one element, with or without parens."""
2571     if node.type == syms.atom:
2572         if len(node.children) != 3:
2573             return False
2574
2575         lpar, gexp, rpar = node.children
2576         if not (
2577             lpar.type == token.LPAR
2578             and gexp.type == syms.testlist_gexp
2579             and rpar.type == token.RPAR
2580         ):
2581             return False
2582
2583         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2584
2585     return (
2586         node.type in IMPLICIT_TUPLE
2587         and len(node.children) == 2
2588         and node.children[1].type == token.COMMA
2589     )
2590
2591
2592 def is_yield(node: LN) -> bool:
2593     """Return True if `node` holds a `yield` or `yield from` expression."""
2594     if node.type == syms.yield_expr:
2595         return True
2596
2597     if node.type == token.NAME and node.value == "yield":  # type: ignore
2598         return True
2599
2600     if node.type != syms.atom:
2601         return False
2602
2603     if len(node.children) != 3:
2604         return False
2605
2606     lpar, expr, rpar = node.children
2607     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2608         return is_yield(expr)
2609
2610     return False
2611
2612
2613 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2614     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2615
2616     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2617     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2618     extended iterable unpacking (PEP 3132) and additional unpacking
2619     generalizations (PEP 448).
2620     """
2621     if leaf.type not in STARS or not leaf.parent:
2622         return False
2623
2624     p = leaf.parent
2625     if p.type == syms.star_expr:
2626         # Star expressions are also used as assignment targets in extended
2627         # iterable unpacking (PEP 3132).  See what its parent is instead.
2628         if not p.parent:
2629             return False
2630
2631         p = p.parent
2632
2633     return p.type in within
2634
2635
2636 def is_multiline_string(leaf: Leaf) -> bool:
2637     """Return True if `leaf` is a multiline string that actually spans many lines."""
2638     value = leaf.value.lstrip("furbFURB")
2639     return value[:3] in {'"""', "'''"} and "\n" in value
2640
2641
2642 def is_stub_suite(node: Node) -> bool:
2643     """Return True if `node` is a suite with a stub body."""
2644     if (
2645         len(node.children) != 4
2646         or node.children[0].type != token.NEWLINE
2647         or node.children[1].type != token.INDENT
2648         or node.children[3].type != token.DEDENT
2649     ):
2650         return False
2651
2652     return is_stub_body(node.children[2])
2653
2654
2655 def is_stub_body(node: LN) -> bool:
2656     """Return True if `node` is a simple statement containing an ellipsis."""
2657     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2658         return False
2659
2660     if len(node.children) != 2:
2661         return False
2662
2663     child = node.children[0]
2664     return (
2665         child.type == syms.atom
2666         and len(child.children) == 3
2667         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2668     )
2669
2670
2671 def max_delimiter_priority_in_atom(node: LN) -> int:
2672     """Return maximum delimiter priority inside `node`.
2673
2674     This is specific to atoms with contents contained in a pair of parentheses.
2675     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2676     """
2677     if node.type != syms.atom:
2678         return 0
2679
2680     first = node.children[0]
2681     last = node.children[-1]
2682     if not (first.type == token.LPAR and last.type == token.RPAR):
2683         return 0
2684
2685     bt = BracketTracker()
2686     for c in node.children[1:-1]:
2687         if isinstance(c, Leaf):
2688             bt.mark(c)
2689         else:
2690             for leaf in c.leaves():
2691                 bt.mark(leaf)
2692     try:
2693         return bt.max_delimiter_priority()
2694
2695     except ValueError:
2696         return 0
2697
2698
2699 def ensure_visible(leaf: Leaf) -> None:
2700     """Make sure parentheses are visible.
2701
2702     They could be invisible as part of some statements (see
2703     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2704     """
2705     if leaf.type == token.LPAR:
2706         leaf.value = "("
2707     elif leaf.type == token.RPAR:
2708         leaf.value = ")"
2709
2710
2711 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2712     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2713     if not (
2714         opening_bracket.parent
2715         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2716         and opening_bracket.value in "[{("
2717     ):
2718         return False
2719
2720     try:
2721         last_leaf = line.leaves[-1]
2722         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2723         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2724     except (IndexError, ValueError):
2725         return False
2726
2727     return max_priority == COMMA_PRIORITY
2728
2729
2730 def is_python36(node: Node) -> bool:
2731     """Return True if the current file is using Python 3.6+ features.
2732
2733     Currently looking for:
2734     - f-strings; and
2735     - trailing commas after * or ** in function signatures and calls.
2736     """
2737     for n in node.pre_order():
2738         if n.type == token.STRING:
2739             value_head = n.value[:2]  # type: ignore
2740             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2741                 return True
2742
2743         elif (
2744             n.type in {syms.typedargslist, syms.arglist}
2745             and n.children
2746             and n.children[-1].type == token.COMMA
2747         ):
2748             for ch in n.children:
2749                 if ch.type in STARS:
2750                     return True
2751
2752                 if ch.type == syms.argument:
2753                     for argch in ch.children:
2754                         if argch.type in STARS:
2755                             return True
2756
2757     return False
2758
2759
2760 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2761     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2762
2763     Brackets can be omitted if the entire trailer up to and including
2764     a preceding closing bracket fits in one line.
2765
2766     Yielded sets are cumulative (contain results of previous yields, too).  First
2767     set is empty.
2768     """
2769
2770     omit: Set[LeafID] = set()
2771     yield omit
2772
2773     length = 4 * line.depth
2774     opening_bracket = None
2775     closing_bracket = None
2776     optional_brackets: Set[LeafID] = set()
2777     inner_brackets: Set[LeafID] = set()
2778     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2779         length += leaf_length
2780         if length > line_length:
2781             break
2782
2783         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2784         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2785             break
2786
2787         optional_brackets.discard(id(leaf))
2788         if opening_bracket:
2789             if leaf is opening_bracket:
2790                 opening_bracket = None
2791             elif leaf.type in CLOSING_BRACKETS:
2792                 inner_brackets.add(id(leaf))
2793         elif leaf.type in CLOSING_BRACKETS:
2794             if not leaf.value:
2795                 optional_brackets.add(id(opening_bracket))
2796                 continue
2797
2798             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2799                 # Empty brackets would fail a split so treat them as "inner"
2800                 # brackets (e.g. only add them to the `omit` set if another
2801                 # pair of brackets was good enough.
2802                 inner_brackets.add(id(leaf))
2803                 continue
2804
2805             opening_bracket = leaf.opening_bracket
2806             if closing_bracket:
2807                 omit.add(id(closing_bracket))
2808                 omit.update(inner_brackets)
2809                 inner_brackets.clear()
2810                 yield omit
2811             closing_bracket = leaf
2812
2813
2814 def get_future_imports(node: Node) -> Set[str]:
2815     """Return a set of __future__ imports in the file."""
2816     imports = set()
2817     for child in node.children:
2818         if child.type != syms.simple_stmt:
2819             break
2820         first_child = child.children[0]
2821         if isinstance(first_child, Leaf):
2822             # Continue looking if we see a docstring; otherwise stop.
2823             if (
2824                 len(child.children) == 2
2825                 and first_child.type == token.STRING
2826                 and child.children[1].type == token.NEWLINE
2827             ):
2828                 continue
2829             else:
2830                 break
2831         elif first_child.type == syms.import_from:
2832             module_name = first_child.children[1]
2833             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2834                 break
2835             for import_from_child in first_child.children[3:]:
2836                 if isinstance(import_from_child, Leaf):
2837                     if import_from_child.type == token.NAME:
2838                         imports.add(import_from_child.value)
2839                 else:
2840                     assert import_from_child.type == syms.import_as_names
2841                     for leaf in import_from_child.children:
2842                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2843                             imports.add(leaf.value)
2844         else:
2845             break
2846     return imports
2847
2848
2849 def gen_python_files_in_dir(
2850     path: Path,
2851     root: Path,
2852     include: Pattern[str],
2853     exclude: Pattern[str],
2854     report: "Report",
2855 ) -> Iterator[Path]:
2856     """Generate all files under `path` whose paths are not excluded by the
2857     `exclude` regex, but are included by the `include` regex.
2858
2859     `report` is where output about exclusions goes.
2860     """
2861     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
2862     for child in path.iterdir():
2863         normalized_path = "/" + child.resolve().relative_to(root).as_posix()
2864         if child.is_dir():
2865             normalized_path += "/"
2866         exclude_match = exclude.search(normalized_path)
2867         if exclude_match and exclude_match.group(0):
2868             report.path_ignored(child, f"matches --exclude={exclude.pattern}")
2869             continue
2870
2871         if child.is_dir():
2872             yield from gen_python_files_in_dir(child, root, include, exclude, report)
2873
2874         elif child.is_file():
2875             include_match = include.search(normalized_path)
2876             if include_match:
2877                 yield child
2878
2879
2880 def find_project_root(srcs: List[str]) -> Path:
2881     """Return a directory containing .git, .hg, or pyproject.toml.
2882
2883     That directory can be one of the directories passed in `srcs` or their
2884     common parent.
2885
2886     If no directory in the tree contains a marker that would specify it's the
2887     project root, the root of the file system is returned.
2888     """
2889     if not srcs:
2890         return Path("/").resolve()
2891
2892     common_base = min(Path(src).resolve() for src in srcs)
2893     if common_base.is_dir():
2894         # Append a fake file so `parents` below returns `common_base_dir`, too.
2895         common_base /= "fake-file"
2896     for directory in common_base.parents:
2897         if (directory / ".git").is_dir():
2898             return directory
2899
2900         if (directory / ".hg").is_dir():
2901             return directory
2902
2903         if (directory / "pyproject.toml").is_file():
2904             return directory
2905
2906     return directory
2907
2908
2909 @dataclass
2910 class Report:
2911     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2912
2913     check: bool = False
2914     quiet: bool = False
2915     verbose: bool = False
2916     change_count: int = 0
2917     same_count: int = 0
2918     failure_count: int = 0
2919
2920     def done(self, src: Path, changed: Changed) -> None:
2921         """Increment the counter for successful reformatting. Write out a message."""
2922         if changed is Changed.YES:
2923             reformatted = "would reformat" if self.check else "reformatted"
2924             if self.verbose or not self.quiet:
2925                 out(f"{reformatted} {src}")
2926             self.change_count += 1
2927         else:
2928             if self.verbose:
2929                 if changed is Changed.NO:
2930                     msg = f"{src} already well formatted, good job."
2931                 else:
2932                     msg = f"{src} wasn't modified on disk since last run."
2933                 out(msg, bold=False)
2934             self.same_count += 1
2935
2936     def failed(self, src: Path, message: str) -> None:
2937         """Increment the counter for failed reformatting. Write out a message."""
2938         err(f"error: cannot format {src}: {message}")
2939         self.failure_count += 1
2940
2941     def path_ignored(self, path: Path, message: str) -> None:
2942         if self.verbose:
2943             out(f"{path} ignored: {message}", bold=False)
2944
2945     @property
2946     def return_code(self) -> int:
2947         """Return the exit code that the app should use.
2948
2949         This considers the current state of changed files and failures:
2950         - if there were any failures, return 123;
2951         - if any files were changed and --check is being used, return 1;
2952         - otherwise return 0.
2953         """
2954         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2955         # 126 we have special returncodes reserved by the shell.
2956         if self.failure_count:
2957             return 123
2958
2959         elif self.change_count and self.check:
2960             return 1
2961
2962         return 0
2963
2964     def __str__(self) -> str:
2965         """Render a color report of the current state.
2966
2967         Use `click.unstyle` to remove colors.
2968         """
2969         if self.check:
2970             reformatted = "would be reformatted"
2971             unchanged = "would be left unchanged"
2972             failed = "would fail to reformat"
2973         else:
2974             reformatted = "reformatted"
2975             unchanged = "left unchanged"
2976             failed = "failed to reformat"
2977         report = []
2978         if self.change_count:
2979             s = "s" if self.change_count > 1 else ""
2980             report.append(
2981                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2982             )
2983         if self.same_count:
2984             s = "s" if self.same_count > 1 else ""
2985             report.append(f"{self.same_count} file{s} {unchanged}")
2986         if self.failure_count:
2987             s = "s" if self.failure_count > 1 else ""
2988             report.append(
2989                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2990             )
2991         return ", ".join(report) + "."
2992
2993
2994 def assert_equivalent(src: str, dst: str) -> None:
2995     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2996
2997     import ast
2998     import traceback
2999
3000     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3001         """Simple visitor generating strings to compare ASTs by content."""
3002         yield f"{'  ' * depth}{node.__class__.__name__}("
3003
3004         for field in sorted(node._fields):
3005             try:
3006                 value = getattr(node, field)
3007             except AttributeError:
3008                 continue
3009
3010             yield f"{'  ' * (depth+1)}{field}="
3011
3012             if isinstance(value, list):
3013                 for item in value:
3014                     if isinstance(item, ast.AST):
3015                         yield from _v(item, depth + 2)
3016
3017             elif isinstance(value, ast.AST):
3018                 yield from _v(value, depth + 2)
3019
3020             else:
3021                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3022
3023         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3024
3025     try:
3026         src_ast = ast.parse(src)
3027     except Exception as exc:
3028         major, minor = sys.version_info[:2]
3029         raise AssertionError(
3030             f"cannot use --safe with this file; failed to parse source file "
3031             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3032             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3033         )
3034
3035     try:
3036         dst_ast = ast.parse(dst)
3037     except Exception as exc:
3038         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3039         raise AssertionError(
3040             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3041             f"Please report a bug on https://github.com/ambv/black/issues.  "
3042             f"This invalid output might be helpful: {log}"
3043         ) from None
3044
3045     src_ast_str = "\n".join(_v(src_ast))
3046     dst_ast_str = "\n".join(_v(dst_ast))
3047     if src_ast_str != dst_ast_str:
3048         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3049         raise AssertionError(
3050             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3051             f"the source.  "
3052             f"Please report a bug on https://github.com/ambv/black/issues.  "
3053             f"This diff might be helpful: {log}"
3054         ) from None
3055
3056
3057 def assert_stable(
3058     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3059 ) -> None:
3060     """Raise AssertionError if `dst` reformats differently the second time."""
3061     newdst = format_str(dst, line_length=line_length, mode=mode)
3062     if dst != newdst:
3063         log = dump_to_file(
3064             diff(src, dst, "source", "first pass"),
3065             diff(dst, newdst, "first pass", "second pass"),
3066         )
3067         raise AssertionError(
3068             f"INTERNAL ERROR: Black produced different code on the second pass "
3069             f"of the formatter.  "
3070             f"Please report a bug on https://github.com/ambv/black/issues.  "
3071             f"This diff might be helpful: {log}"
3072         ) from None
3073
3074
3075 def dump_to_file(*output: str) -> str:
3076     """Dump `output` to a temporary file. Return path to the file."""
3077     import tempfile
3078
3079     with tempfile.NamedTemporaryFile(
3080         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3081     ) as f:
3082         for lines in output:
3083             f.write(lines)
3084             if lines and lines[-1] != "\n":
3085                 f.write("\n")
3086     return f.name
3087
3088
3089 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3090     """Return a unified diff string between strings `a` and `b`."""
3091     import difflib
3092
3093     a_lines = [line + "\n" for line in a.split("\n")]
3094     b_lines = [line + "\n" for line in b.split("\n")]
3095     return "".join(
3096         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3097     )
3098
3099
3100 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3101     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3102     err("Aborted!")
3103     for task in tasks:
3104         task.cancel()
3105
3106
3107 def shutdown(loop: BaseEventLoop) -> None:
3108     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3109     try:
3110         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3111         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3112         if not to_cancel:
3113             return
3114
3115         for task in to_cancel:
3116             task.cancel()
3117         loop.run_until_complete(
3118             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3119         )
3120     finally:
3121         # `concurrent.futures.Future` objects cannot be cancelled once they
3122         # are already running. There might be some when the `shutdown()` happened.
3123         # Silence their logger's spew about the event loop being closed.
3124         cf_logger = logging.getLogger("concurrent.futures")
3125         cf_logger.setLevel(logging.CRITICAL)
3126         loop.close()
3127
3128
3129 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3130     """Replace `regex` with `replacement` twice on `original`.
3131
3132     This is used by string normalization to perform replaces on
3133     overlapping matches.
3134     """
3135     return regex.sub(replacement, regex.sub(replacement, original))
3136
3137
3138 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3139     """Like `reversed(enumerate(sequence))` if that were possible."""
3140     index = len(sequence) - 1
3141     for element in reversed(sequence):
3142         yield (index, element)
3143         index -= 1
3144
3145
3146 def enumerate_with_length(
3147     line: Line, reversed: bool = False
3148 ) -> Iterator[Tuple[Index, Leaf, int]]:
3149     """Return an enumeration of leaves with their length.
3150
3151     Stops prematurely on multiline strings and standalone comments.
3152     """
3153     op = cast(
3154         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3155         enumerate_reversed if reversed else enumerate,
3156     )
3157     for index, leaf in op(line.leaves):
3158         length = len(leaf.prefix) + len(leaf.value)
3159         if "\n" in leaf.value:
3160             return  # Multiline strings, we can't continue.
3161
3162         comment: Optional[Leaf]
3163         for comment in line.comments_after(leaf, index):
3164             length += len(comment.value)
3165
3166         yield index, leaf, length
3167
3168
3169 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3170     """Return True if `line` is no longer than `line_length`.
3171
3172     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3173     """
3174     if not line_str:
3175         line_str = str(line).strip("\n")
3176     return (
3177         len(line_str) <= line_length
3178         and "\n" not in line_str  # multiline strings
3179         and not line.contains_standalone_comments()
3180     )
3181
3182
3183 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3184     """Does `line` have a shape safe to reformat without optional parens around it?
3185
3186     Returns True for only a subset of potentially nice looking formattings but
3187     the point is to not return false positives that end up producing lines that
3188     are too long.
3189     """
3190     bt = line.bracket_tracker
3191     if not bt.delimiters:
3192         # Without delimiters the optional parentheses are useless.
3193         return True
3194
3195     max_priority = bt.max_delimiter_priority()
3196     if bt.delimiter_count_with_priority(max_priority) > 1:
3197         # With more than one delimiter of a kind the optional parentheses read better.
3198         return False
3199
3200     if max_priority == DOT_PRIORITY:
3201         # A single stranded method call doesn't require optional parentheses.
3202         return True
3203
3204     assert len(line.leaves) >= 2, "Stranded delimiter"
3205
3206     first = line.leaves[0]
3207     second = line.leaves[1]
3208     penultimate = line.leaves[-2]
3209     last = line.leaves[-1]
3210
3211     # With a single delimiter, omit if the expression starts or ends with
3212     # a bracket.
3213     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3214         remainder = False
3215         length = 4 * line.depth
3216         for _index, leaf, leaf_length in enumerate_with_length(line):
3217             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3218                 remainder = True
3219             if remainder:
3220                 length += leaf_length
3221                 if length > line_length:
3222                     break
3223
3224                 if leaf.type in OPENING_BRACKETS:
3225                     # There are brackets we can further split on.
3226                     remainder = False
3227
3228         else:
3229             # checked the entire string and line length wasn't exceeded
3230             if len(line.leaves) == _index + 1:
3231                 return True
3232
3233         # Note: we are not returning False here because a line might have *both*
3234         # a leading opening bracket and a trailing closing bracket.  If the
3235         # opening bracket doesn't match our rule, maybe the closing will.
3236
3237     if (
3238         last.type == token.RPAR
3239         or last.type == token.RBRACE
3240         or (
3241             # don't use indexing for omitting optional parentheses;
3242             # it looks weird
3243             last.type == token.RSQB
3244             and last.parent
3245             and last.parent.type != syms.trailer
3246         )
3247     ):
3248         if penultimate.type in OPENING_BRACKETS:
3249             # Empty brackets don't help.
3250             return False
3251
3252         if is_multiline_string(first):
3253             # Additional wrapping of a multiline string in this situation is
3254             # unnecessary.
3255             return True
3256
3257         length = 4 * line.depth
3258         seen_other_brackets = False
3259         for _index, leaf, leaf_length in enumerate_with_length(line):
3260             length += leaf_length
3261             if leaf is last.opening_bracket:
3262                 if seen_other_brackets or length <= line_length:
3263                     return True
3264
3265             elif leaf.type in OPENING_BRACKETS:
3266                 # There are brackets we can further split on.
3267                 seen_other_brackets = True
3268
3269     return False
3270
3271
3272 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3273     pyi = bool(mode & FileMode.PYI)
3274     py36 = bool(mode & FileMode.PYTHON36)
3275     return (
3276         CACHE_DIR
3277         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3278     )
3279
3280
3281 def read_cache(line_length: int, mode: FileMode) -> Cache:
3282     """Read the cache if it exists and is well formed.
3283
3284     If it is not well formed, the call to write_cache later should resolve the issue.
3285     """
3286     cache_file = get_cache_file(line_length, mode)
3287     if not cache_file.exists():
3288         return {}
3289
3290     with cache_file.open("rb") as fobj:
3291         try:
3292             cache: Cache = pickle.load(fobj)
3293         except pickle.UnpicklingError:
3294             return {}
3295
3296     return cache
3297
3298
3299 def get_cache_info(path: Path) -> CacheInfo:
3300     """Return the information used to check if a file is already formatted or not."""
3301     stat = path.stat()
3302     return stat.st_mtime, stat.st_size
3303
3304
3305 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3306     """Split an iterable of paths in `sources` into two sets.
3307
3308     The first contains paths of files that modified on disk or are not in the
3309     cache. The other contains paths to non-modified files.
3310     """
3311     todo, done = set(), set()
3312     for src in sources:
3313         src = src.resolve()
3314         if cache.get(src) != get_cache_info(src):
3315             todo.add(src)
3316         else:
3317             done.add(src)
3318     return todo, done
3319
3320
3321 def write_cache(
3322     cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
3323 ) -> None:
3324     """Update the cache file."""
3325     cache_file = get_cache_file(line_length, mode)
3326     try:
3327         if not CACHE_DIR.exists():
3328             CACHE_DIR.mkdir(parents=True)
3329         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3330         with cache_file.open("wb") as fobj:
3331             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3332     except OSError:
3333         pass
3334
3335
3336 if __name__ == "__main__":
3337     main()