]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Properly format unified diff
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from enum import Enum, Flag
6 from functools import partial, wraps
7 import io
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import pickle
14 import re
15 import signal
16 import sys
17 import tokenize
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Sequence,
30     Set,
31     Tuple,
32     Type,
33     TypeVar,
34     Union,
35     cast,
36 )
37
38 from appdirs import user_cache_dir
39 from attr import dataclass, Factory
40 import click
41
42 # lib2to3 fork
43 from blib2to3.pytree import Node, Leaf, type_repr
44 from blib2to3 import pygram, pytree
45 from blib2to3.pgen2 import driver, token
46 from blib2to3.pgen2.parse import ParseError
47
48
49 __version__ = "18.5b1"
50 DEFAULT_LINE_LENGTH = 88
51 DEFAULT_EXCLUDES = (
52     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
53 )
54 DEFAULT_INCLUDES = r"\.pyi?$"
55 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
56
57
58 # types
59 FileContent = str
60 Encoding = str
61 NewLine = str
62 Depth = int
63 NodeType = int
64 LeafID = int
65 Priority = int
66 Index = int
67 LN = Union[Leaf, Node]
68 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
69 Timestamp = float
70 FileSize = int
71 CacheInfo = Tuple[Timestamp, FileSize]
72 Cache = Dict[Path, CacheInfo]
73 out = partial(click.secho, bold=True, err=True)
74 err = partial(click.secho, fg="red", err=True)
75
76 pygram.initialize(CACHE_DIR)
77 syms = pygram.python_symbols
78
79
80 class NothingChanged(UserWarning):
81     """Raised by :func:`format_file` when reformatted code is the same as source."""
82
83
84 class CannotSplit(Exception):
85     """A readable split that fits the allotted line length is impossible.
86
87     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
88     :func:`delimiter_split`.
89     """
90
91
92 class FormatError(Exception):
93     """Base exception for `# fmt: on` and `# fmt: off` handling.
94
95     It holds the number of bytes of the prefix consumed before the format
96     control comment appeared.
97     """
98
99     def __init__(self, consumed: int) -> None:
100         super().__init__(consumed)
101         self.consumed = consumed
102
103     def trim_prefix(self, leaf: Leaf) -> None:
104         leaf.prefix = leaf.prefix[self.consumed :]
105
106     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
107         """Returns a new Leaf from the consumed part of the prefix."""
108         unformatted_prefix = leaf.prefix[: self.consumed]
109         return Leaf(token.NEWLINE, unformatted_prefix)
110
111
112 class FormatOn(FormatError):
113     """Found a comment like `# fmt: on` in the file."""
114
115
116 class FormatOff(FormatError):
117     """Found a comment like `# fmt: off` in the file."""
118
119
120 class WriteBack(Enum):
121     NO = 0
122     YES = 1
123     DIFF = 2
124
125     @classmethod
126     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
127         if check and not diff:
128             return cls.NO
129
130         return cls.DIFF if diff else cls.YES
131
132
133 class Changed(Enum):
134     NO = 0
135     CACHED = 1
136     YES = 2
137
138
139 class FileMode(Flag):
140     AUTO_DETECT = 0
141     PYTHON36 = 1
142     PYI = 2
143     NO_STRING_NORMALIZATION = 4
144
145     @classmethod
146     def from_configuration(
147         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
148     ) -> "FileMode":
149         mode = cls.AUTO_DETECT
150         if py36:
151             mode |= cls.PYTHON36
152         if pyi:
153             mode |= cls.PYI
154         if skip_string_normalization:
155             mode |= cls.NO_STRING_NORMALIZATION
156         return mode
157
158
159 @click.command()
160 @click.option(
161     "-l",
162     "--line-length",
163     type=int,
164     default=DEFAULT_LINE_LENGTH,
165     help="How many character per line to allow.",
166     show_default=True,
167 )
168 @click.option(
169     "--py36",
170     is_flag=True,
171     help=(
172         "Allow using Python 3.6-only syntax on all input files.  This will put "
173         "trailing commas in function signatures and calls also after *args and "
174         "**kwargs.  [default: per-file auto-detection]"
175     ),
176 )
177 @click.option(
178     "--pyi",
179     is_flag=True,
180     help=(
181         "Format all input files like typing stubs regardless of file extension "
182         "(useful when piping source on standard input)."
183     ),
184 )
185 @click.option(
186     "-S",
187     "--skip-string-normalization",
188     is_flag=True,
189     help="Don't normalize string quotes or prefixes.",
190 )
191 @click.option(
192     "--check",
193     is_flag=True,
194     help=(
195         "Don't write the files back, just return the status.  Return code 0 "
196         "means nothing would change.  Return code 1 means some files would be "
197         "reformatted.  Return code 123 means there was an internal error."
198     ),
199 )
200 @click.option(
201     "--diff",
202     is_flag=True,
203     help="Don't write the files back, just output a diff for each file on stdout.",
204 )
205 @click.option(
206     "--fast/--safe",
207     is_flag=True,
208     help="If --fast given, skip temporary sanity checks. [default: --safe]",
209 )
210 @click.option(
211     "--include",
212     type=str,
213     default=DEFAULT_INCLUDES,
214     help=(
215         "A regular expression that matches files and directories that should be "
216         "included on recursive searches.  An empty value means all files are "
217         "included regardless of the name.  Use forward slashes for directories on "
218         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
219         "later."
220     ),
221     show_default=True,
222 )
223 @click.option(
224     "--exclude",
225     type=str,
226     default=DEFAULT_EXCLUDES,
227     help=(
228         "A regular expression that matches files and directories that should be "
229         "excluded on recursive searches.  An empty value means no paths are excluded. "
230         "Use forward slashes for directories on all platforms (Windows, too).  "
231         "Exclusions are calculated first, inclusions later."
232     ),
233     show_default=True,
234 )
235 @click.option(
236     "-q",
237     "--quiet",
238     is_flag=True,
239     help=(
240         "Don't emit non-error messages to stderr. Errors are still emitted, "
241         "silence those with 2>/dev/null."
242     ),
243 )
244 @click.option(
245     "-v",
246     "--verbose",
247     is_flag=True,
248     help=(
249         "Also emit messages to stderr about files that were not changed or were "
250         "ignored due to --exclude=."
251     ),
252 )
253 @click.version_option(version=__version__)
254 @click.argument(
255     "src",
256     nargs=-1,
257     type=click.Path(
258         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
259     ),
260 )
261 @click.pass_context
262 def main(
263     ctx: click.Context,
264     line_length: int,
265     check: bool,
266     diff: bool,
267     fast: bool,
268     pyi: bool,
269     py36: bool,
270     skip_string_normalization: bool,
271     quiet: bool,
272     verbose: bool,
273     include: str,
274     exclude: str,
275     src: List[str],
276 ) -> None:
277     """The uncompromising code formatter."""
278     write_back = WriteBack.from_configuration(check=check, diff=diff)
279     mode = FileMode.from_configuration(
280         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
281     )
282     report = Report(check=check, quiet=quiet, verbose=verbose)
283     sources: Set[Path] = set()
284     try:
285         include_regex = re.compile(include)
286     except re.error:
287         err(f"Invalid regular expression for include given: {include!r}")
288         ctx.exit(2)
289     try:
290         exclude_regex = re.compile(exclude)
291     except re.error:
292         err(f"Invalid regular expression for exclude given: {exclude!r}")
293         ctx.exit(2)
294     root = find_project_root(src)
295     for s in src:
296         p = Path(s)
297         if p.is_dir():
298             sources.update(
299                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
300             )
301         elif p.is_file() or s == "-":
302             # if a file was explicitly given, we don't care about its extension
303             sources.add(p)
304         else:
305             err(f"invalid path: {s}")
306     if len(sources) == 0:
307         if verbose or not quiet:
308             out("No paths given. Nothing to do 😴")
309         ctx.exit(0)
310         return
311
312     elif len(sources) == 1:
313         reformat_one(
314             src=sources.pop(),
315             line_length=line_length,
316             fast=fast,
317             write_back=write_back,
318             mode=mode,
319             report=report,
320         )
321     else:
322         loop = asyncio.get_event_loop()
323         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
324         try:
325             loop.run_until_complete(
326                 schedule_formatting(
327                     sources=sources,
328                     line_length=line_length,
329                     fast=fast,
330                     write_back=write_back,
331                     mode=mode,
332                     report=report,
333                     loop=loop,
334                     executor=executor,
335                 )
336             )
337         finally:
338             shutdown(loop)
339     if verbose or not quiet:
340         out("All done! ✨ 🍰 ✨")
341         click.echo(str(report))
342     ctx.exit(report.return_code)
343
344
345 def reformat_one(
346     src: Path,
347     line_length: int,
348     fast: bool,
349     write_back: WriteBack,
350     mode: FileMode,
351     report: "Report",
352 ) -> None:
353     """Reformat a single file under `src` without spawning child processes.
354
355     If `quiet` is True, non-error messages are not output. `line_length`,
356     `write_back`, `fast` and `pyi` options are passed to
357     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
358     """
359     try:
360         changed = Changed.NO
361         if not src.is_file() and str(src) == "-":
362             if format_stdin_to_stdout(
363                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
364             ):
365                 changed = Changed.YES
366         else:
367             cache: Cache = {}
368             if write_back != WriteBack.DIFF:
369                 cache = read_cache(line_length, mode)
370                 res_src = src.resolve()
371                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
372                     changed = Changed.CACHED
373             if changed is not Changed.CACHED and format_file_in_place(
374                 src,
375                 line_length=line_length,
376                 fast=fast,
377                 write_back=write_back,
378                 mode=mode,
379             ):
380                 changed = Changed.YES
381             if write_back == WriteBack.YES and changed is not Changed.NO:
382                 write_cache(cache, [src], line_length, mode)
383         report.done(src, changed)
384     except Exception as exc:
385         report.failed(src, str(exc))
386
387
388 async def schedule_formatting(
389     sources: Set[Path],
390     line_length: int,
391     fast: bool,
392     write_back: WriteBack,
393     mode: FileMode,
394     report: "Report",
395     loop: BaseEventLoop,
396     executor: Executor,
397 ) -> None:
398     """Run formatting of `sources` in parallel using the provided `executor`.
399
400     (Use ProcessPoolExecutors for actual parallelism.)
401
402     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
403     :func:`format_file_in_place`.
404     """
405     cache: Cache = {}
406     if write_back != WriteBack.DIFF:
407         cache = read_cache(line_length, mode)
408         sources, cached = filter_cached(cache, sources)
409         for src in sorted(cached):
410             report.done(src, Changed.CACHED)
411     cancelled = []
412     formatted = []
413     if sources:
414         lock = None
415         if write_back == WriteBack.DIFF:
416             # For diff output, we need locks to ensure we don't interleave output
417             # from different processes.
418             manager = Manager()
419             lock = manager.Lock()
420         tasks = {
421             loop.run_in_executor(
422                 executor,
423                 format_file_in_place,
424                 src,
425                 line_length,
426                 fast,
427                 write_back,
428                 mode,
429                 lock,
430             ): src
431             for src in sorted(sources)
432         }
433         pending: Iterable[asyncio.Task] = tasks.keys()
434         try:
435             loop.add_signal_handler(signal.SIGINT, cancel, pending)
436             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
437         except NotImplementedError:
438             # There are no good alternatives for these on Windows
439             pass
440         while pending:
441             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
442             for task in done:
443                 src = tasks.pop(task)
444                 if task.cancelled():
445                     cancelled.append(task)
446                 elif task.exception():
447                     report.failed(src, str(task.exception()))
448                 else:
449                     formatted.append(src)
450                     report.done(src, Changed.YES if task.result() else Changed.NO)
451     if cancelled:
452         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
453     if write_back == WriteBack.YES and formatted:
454         write_cache(cache, formatted, line_length, mode)
455
456
457 def format_file_in_place(
458     src: Path,
459     line_length: int,
460     fast: bool,
461     write_back: WriteBack = WriteBack.NO,
462     mode: FileMode = FileMode.AUTO_DETECT,
463     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
464 ) -> bool:
465     """Format file under `src` path. Return True if changed.
466
467     If `write_back` is True, write reformatted code back to stdout.
468     `line_length` and `fast` options are passed to :func:`format_file_contents`.
469     """
470     if src.suffix == ".pyi":
471         mode |= FileMode.PYI
472
473     then = datetime.utcfromtimestamp(src.stat().st_mtime)
474     with open(src, "rb") as buf:
475         src_contents, encoding, newline = decode_bytes(buf.read())
476     try:
477         dst_contents = format_file_contents(
478             src_contents, line_length=line_length, fast=fast, mode=mode
479         )
480     except NothingChanged:
481         return False
482
483     if write_back == write_back.YES:
484         with open(src, "w", encoding=encoding, newline=newline) as f:
485             f.write(dst_contents)
486     elif write_back == write_back.DIFF:
487         now = datetime.utcnow()
488         src_name = f"{src}\t{then} +0000"
489         dst_name = f"{src}\t{now} +0000"
490         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
491         if lock:
492             lock.acquire()
493         try:
494             f = io.TextIOWrapper(
495                 sys.stdout.buffer,
496                 encoding=encoding,
497                 newline=newline,
498                 write_through=True,
499             )
500             f.write(diff_contents)
501             f.detach()
502         finally:
503             if lock:
504                 lock.release()
505     return True
506
507
508 def format_stdin_to_stdout(
509     line_length: int,
510     fast: bool,
511     write_back: WriteBack = WriteBack.NO,
512     mode: FileMode = FileMode.AUTO_DETECT,
513 ) -> bool:
514     """Format file on stdin. Return True if changed.
515
516     If `write_back` is True, write reformatted code back to stdout.
517     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
518     :func:`format_file_contents`.
519     """
520     then = datetime.utcnow()
521     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
522     dst = src
523     try:
524         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
525         return True
526
527     except NothingChanged:
528         return False
529
530     finally:
531         f = io.TextIOWrapper(
532             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
533         )
534         if write_back == WriteBack.YES:
535             f.write(dst)
536         elif write_back == WriteBack.DIFF:
537             now = datetime.utcnow()
538             src_name = f"STDIN\t{then} +0000"
539             dst_name = f"STDOUT\t{now} +0000"
540             f.write(diff(src, dst, src_name, dst_name))
541         f.detach()
542
543
544 def format_file_contents(
545     src_contents: str,
546     *,
547     line_length: int,
548     fast: bool,
549     mode: FileMode = FileMode.AUTO_DETECT,
550 ) -> FileContent:
551     """Reformat contents a file and return new contents.
552
553     If `fast` is False, additionally confirm that the reformatted code is
554     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
555     `line_length` is passed to :func:`format_str`.
556     """
557     if src_contents.strip() == "":
558         raise NothingChanged
559
560     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
561     if src_contents == dst_contents:
562         raise NothingChanged
563
564     if not fast:
565         assert_equivalent(src_contents, dst_contents)
566         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
567     return dst_contents
568
569
570 def format_str(
571     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
572 ) -> FileContent:
573     """Reformat a string and return new contents.
574
575     `line_length` determines how many characters per line are allowed.
576     """
577     src_node = lib2to3_parse(src_contents)
578     dst_contents = ""
579     future_imports = get_future_imports(src_node)
580     is_pyi = bool(mode & FileMode.PYI)
581     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
582     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
583     lines = LineGenerator(
584         remove_u_prefix=py36 or "unicode_literals" in future_imports,
585         is_pyi=is_pyi,
586         normalize_strings=normalize_strings,
587     )
588     elt = EmptyLineTracker(is_pyi=is_pyi)
589     empty_line = Line()
590     after = 0
591     for current_line in lines.visit(src_node):
592         for _ in range(after):
593             dst_contents += str(empty_line)
594         before, after = elt.maybe_empty_lines(current_line)
595         for _ in range(before):
596             dst_contents += str(empty_line)
597         for line in split_line(current_line, line_length=line_length, py36=py36):
598             dst_contents += str(line)
599     return dst_contents
600
601
602 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
603     """Return a tuple of (decoded_contents, encoding, newline).
604
605     `newline` is either CRLF or LF but `decoded_contents` is decoded with
606     universal newlines (i.e. only contains LF).
607     """
608     srcbuf = io.BytesIO(src)
609     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
610     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
611     srcbuf.seek(0)
612     with io.TextIOWrapper(srcbuf, encoding) as tiow:
613         return tiow.read(), encoding, newline
614
615
616 GRAMMARS = [
617     pygram.python_grammar_no_print_statement_no_exec_statement,
618     pygram.python_grammar_no_print_statement,
619     pygram.python_grammar,
620 ]
621
622
623 def lib2to3_parse(src_txt: str) -> Node:
624     """Given a string with source, return the lib2to3 Node."""
625     grammar = pygram.python_grammar_no_print_statement
626     if src_txt[-1] != "\n":
627         src_txt += "\n"
628     for grammar in GRAMMARS:
629         drv = driver.Driver(grammar, pytree.convert)
630         try:
631             result = drv.parse_string(src_txt, True)
632             break
633
634         except ParseError as pe:
635             lineno, column = pe.context[1]
636             lines = src_txt.splitlines()
637             try:
638                 faulty_line = lines[lineno - 1]
639             except IndexError:
640                 faulty_line = "<line number missing in source>"
641             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
642     else:
643         raise exc from None
644
645     if isinstance(result, Leaf):
646         result = Node(syms.file_input, [result])
647     return result
648
649
650 def lib2to3_unparse(node: Node) -> str:
651     """Given a lib2to3 node, return its string representation."""
652     code = str(node)
653     return code
654
655
656 T = TypeVar("T")
657
658
659 class Visitor(Generic[T]):
660     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
661
662     def visit(self, node: LN) -> Iterator[T]:
663         """Main method to visit `node` and its children.
664
665         It tries to find a `visit_*()` method for the given `node.type`, like
666         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
667         If no dedicated `visit_*()` method is found, chooses `visit_default()`
668         instead.
669
670         Then yields objects of type `T` from the selected visitor.
671         """
672         if node.type < 256:
673             name = token.tok_name[node.type]
674         else:
675             name = type_repr(node.type)
676         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
677
678     def visit_default(self, node: LN) -> Iterator[T]:
679         """Default `visit_*()` implementation. Recurses to children of `node`."""
680         if isinstance(node, Node):
681             for child in node.children:
682                 yield from self.visit(child)
683
684
685 @dataclass
686 class DebugVisitor(Visitor[T]):
687     tree_depth: int = 0
688
689     def visit_default(self, node: LN) -> Iterator[T]:
690         indent = " " * (2 * self.tree_depth)
691         if isinstance(node, Node):
692             _type = type_repr(node.type)
693             out(f"{indent}{_type}", fg="yellow")
694             self.tree_depth += 1
695             for child in node.children:
696                 yield from self.visit(child)
697
698             self.tree_depth -= 1
699             out(f"{indent}/{_type}", fg="yellow", bold=False)
700         else:
701             _type = token.tok_name.get(node.type, str(node.type))
702             out(f"{indent}{_type}", fg="blue", nl=False)
703             if node.prefix:
704                 # We don't have to handle prefixes for `Node` objects since
705                 # that delegates to the first child anyway.
706                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
707             out(f" {node.value!r}", fg="blue", bold=False)
708
709     @classmethod
710     def show(cls, code: str) -> None:
711         """Pretty-print the lib2to3 AST of a given string of `code`.
712
713         Convenience method for debugging.
714         """
715         v: DebugVisitor[None] = DebugVisitor()
716         list(v.visit(lib2to3_parse(code)))
717
718
719 KEYWORDS = set(keyword.kwlist)
720 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
721 FLOW_CONTROL = {"return", "raise", "break", "continue"}
722 STATEMENT = {
723     syms.if_stmt,
724     syms.while_stmt,
725     syms.for_stmt,
726     syms.try_stmt,
727     syms.except_clause,
728     syms.with_stmt,
729     syms.funcdef,
730     syms.classdef,
731 }
732 STANDALONE_COMMENT = 153
733 LOGIC_OPERATORS = {"and", "or"}
734 COMPARATORS = {
735     token.LESS,
736     token.GREATER,
737     token.EQEQUAL,
738     token.NOTEQUAL,
739     token.LESSEQUAL,
740     token.GREATEREQUAL,
741 }
742 MATH_OPERATORS = {
743     token.VBAR,
744     token.CIRCUMFLEX,
745     token.AMPER,
746     token.LEFTSHIFT,
747     token.RIGHTSHIFT,
748     token.PLUS,
749     token.MINUS,
750     token.STAR,
751     token.SLASH,
752     token.DOUBLESLASH,
753     token.PERCENT,
754     token.AT,
755     token.TILDE,
756     token.DOUBLESTAR,
757 }
758 STARS = {token.STAR, token.DOUBLESTAR}
759 VARARGS_PARENTS = {
760     syms.arglist,
761     syms.argument,  # double star in arglist
762     syms.trailer,  # single argument to call
763     syms.typedargslist,
764     syms.varargslist,  # lambdas
765 }
766 UNPACKING_PARENTS = {
767     syms.atom,  # single element of a list or set literal
768     syms.dictsetmaker,
769     syms.listmaker,
770     syms.testlist_gexp,
771 }
772 TEST_DESCENDANTS = {
773     syms.test,
774     syms.lambdef,
775     syms.or_test,
776     syms.and_test,
777     syms.not_test,
778     syms.comparison,
779     syms.star_expr,
780     syms.expr,
781     syms.xor_expr,
782     syms.and_expr,
783     syms.shift_expr,
784     syms.arith_expr,
785     syms.trailer,
786     syms.term,
787     syms.power,
788 }
789 ASSIGNMENTS = {
790     "=",
791     "+=",
792     "-=",
793     "*=",
794     "@=",
795     "/=",
796     "%=",
797     "&=",
798     "|=",
799     "^=",
800     "<<=",
801     ">>=",
802     "**=",
803     "//=",
804 }
805 COMPREHENSION_PRIORITY = 20
806 COMMA_PRIORITY = 18
807 TERNARY_PRIORITY = 16
808 LOGIC_PRIORITY = 14
809 STRING_PRIORITY = 12
810 COMPARATOR_PRIORITY = 10
811 MATH_PRIORITIES = {
812     token.VBAR: 9,
813     token.CIRCUMFLEX: 8,
814     token.AMPER: 7,
815     token.LEFTSHIFT: 6,
816     token.RIGHTSHIFT: 6,
817     token.PLUS: 5,
818     token.MINUS: 5,
819     token.STAR: 4,
820     token.SLASH: 4,
821     token.DOUBLESLASH: 4,
822     token.PERCENT: 4,
823     token.AT: 4,
824     token.TILDE: 3,
825     token.DOUBLESTAR: 2,
826 }
827 DOT_PRIORITY = 1
828
829
830 @dataclass
831 class BracketTracker:
832     """Keeps track of brackets on a line."""
833
834     depth: int = 0
835     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
836     delimiters: Dict[LeafID, Priority] = Factory(dict)
837     previous: Optional[Leaf] = None
838     _for_loop_variable: int = 0
839     _lambda_arguments: int = 0
840
841     def mark(self, leaf: Leaf) -> None:
842         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
843
844         All leaves receive an int `bracket_depth` field that stores how deep
845         within brackets a given leaf is. 0 means there are no enclosing brackets
846         that started on this line.
847
848         If a leaf is itself a closing bracket, it receives an `opening_bracket`
849         field that it forms a pair with. This is a one-directional link to
850         avoid reference cycles.
851
852         If a leaf is a delimiter (a token on which Black can split the line if
853         needed) and it's on depth 0, its `id()` is stored in the tracker's
854         `delimiters` field.
855         """
856         if leaf.type == token.COMMENT:
857             return
858
859         self.maybe_decrement_after_for_loop_variable(leaf)
860         self.maybe_decrement_after_lambda_arguments(leaf)
861         if leaf.type in CLOSING_BRACKETS:
862             self.depth -= 1
863             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
864             leaf.opening_bracket = opening_bracket
865         leaf.bracket_depth = self.depth
866         if self.depth == 0:
867             delim = is_split_before_delimiter(leaf, self.previous)
868             if delim and self.previous is not None:
869                 self.delimiters[id(self.previous)] = delim
870             else:
871                 delim = is_split_after_delimiter(leaf, self.previous)
872                 if delim:
873                     self.delimiters[id(leaf)] = delim
874         if leaf.type in OPENING_BRACKETS:
875             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
876             self.depth += 1
877         self.previous = leaf
878         self.maybe_increment_lambda_arguments(leaf)
879         self.maybe_increment_for_loop_variable(leaf)
880
881     def any_open_brackets(self) -> bool:
882         """Return True if there is an yet unmatched open bracket on the line."""
883         return bool(self.bracket_match)
884
885     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
886         """Return the highest priority of a delimiter found on the line.
887
888         Values are consistent with what `is_split_*_delimiter()` return.
889         Raises ValueError on no delimiters.
890         """
891         return max(v for k, v in self.delimiters.items() if k not in exclude)
892
893     def delimiter_count_with_priority(self, priority: int = 0) -> int:
894         """Return the number of delimiters with the given `priority`.
895
896         If no `priority` is passed, defaults to max priority on the line.
897         """
898         if not self.delimiters:
899             return 0
900
901         priority = priority or self.max_delimiter_priority()
902         return sum(1 for p in self.delimiters.values() if p == priority)
903
904     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
905         """In a for loop, or comprehension, the variables are often unpacks.
906
907         To avoid splitting on the comma in this situation, increase the depth of
908         tokens between `for` and `in`.
909         """
910         if leaf.type == token.NAME and leaf.value == "for":
911             self.depth += 1
912             self._for_loop_variable += 1
913             return True
914
915         return False
916
917     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
918         """See `maybe_increment_for_loop_variable` above for explanation."""
919         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
920             self.depth -= 1
921             self._for_loop_variable -= 1
922             return True
923
924         return False
925
926     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
927         """In a lambda expression, there might be more than one argument.
928
929         To avoid splitting on the comma in this situation, increase the depth of
930         tokens between `lambda` and `:`.
931         """
932         if leaf.type == token.NAME and leaf.value == "lambda":
933             self.depth += 1
934             self._lambda_arguments += 1
935             return True
936
937         return False
938
939     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
940         """See `maybe_increment_lambda_arguments` above for explanation."""
941         if self._lambda_arguments and leaf.type == token.COLON:
942             self.depth -= 1
943             self._lambda_arguments -= 1
944             return True
945
946         return False
947
948     def get_open_lsqb(self) -> Optional[Leaf]:
949         """Return the most recent opening square bracket (if any)."""
950         return self.bracket_match.get((self.depth - 1, token.RSQB))
951
952
953 @dataclass
954 class Line:
955     """Holds leaves and comments. Can be printed with `str(line)`."""
956
957     depth: int = 0
958     leaves: List[Leaf] = Factory(list)
959     comments: List[Tuple[Index, Leaf]] = Factory(list)
960     bracket_tracker: BracketTracker = Factory(BracketTracker)
961     inside_brackets: bool = False
962     should_explode: bool = False
963
964     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
965         """Add a new `leaf` to the end of the line.
966
967         Unless `preformatted` is True, the `leaf` will receive a new consistent
968         whitespace prefix and metadata applied by :class:`BracketTracker`.
969         Trailing commas are maybe removed, unpacked for loop variables are
970         demoted from being delimiters.
971
972         Inline comments are put aside.
973         """
974         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
975         if not has_value:
976             return
977
978         if token.COLON == leaf.type and self.is_class_paren_empty:
979             del self.leaves[-2:]
980         if self.leaves and not preformatted:
981             # Note: at this point leaf.prefix should be empty except for
982             # imports, for which we only preserve newlines.
983             leaf.prefix += whitespace(
984                 leaf, complex_subscript=self.is_complex_subscript(leaf)
985             )
986         if self.inside_brackets or not preformatted:
987             self.bracket_tracker.mark(leaf)
988             self.maybe_remove_trailing_comma(leaf)
989         if not self.append_comment(leaf):
990             self.leaves.append(leaf)
991
992     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
993         """Like :func:`append()` but disallow invalid standalone comment structure.
994
995         Raises ValueError when any `leaf` is appended after a standalone comment
996         or when a standalone comment is not the first leaf on the line.
997         """
998         if self.bracket_tracker.depth == 0:
999             if self.is_comment:
1000                 raise ValueError("cannot append to standalone comments")
1001
1002             if self.leaves and leaf.type == STANDALONE_COMMENT:
1003                 raise ValueError(
1004                     "cannot append standalone comments to a populated line"
1005                 )
1006
1007         self.append(leaf, preformatted=preformatted)
1008
1009     @property
1010     def is_comment(self) -> bool:
1011         """Is this line a standalone comment?"""
1012         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1013
1014     @property
1015     def is_decorator(self) -> bool:
1016         """Is this line a decorator?"""
1017         return bool(self) and self.leaves[0].type == token.AT
1018
1019     @property
1020     def is_import(self) -> bool:
1021         """Is this an import line?"""
1022         return bool(self) and is_import(self.leaves[0])
1023
1024     @property
1025     def is_class(self) -> bool:
1026         """Is this line a class definition?"""
1027         return (
1028             bool(self)
1029             and self.leaves[0].type == token.NAME
1030             and self.leaves[0].value == "class"
1031         )
1032
1033     @property
1034     def is_stub_class(self) -> bool:
1035         """Is this line a class definition with a body consisting only of "..."?"""
1036         return self.is_class and self.leaves[-3:] == [
1037             Leaf(token.DOT, ".") for _ in range(3)
1038         ]
1039
1040     @property
1041     def is_def(self) -> bool:
1042         """Is this a function definition? (Also returns True for async defs.)"""
1043         try:
1044             first_leaf = self.leaves[0]
1045         except IndexError:
1046             return False
1047
1048         try:
1049             second_leaf: Optional[Leaf] = self.leaves[1]
1050         except IndexError:
1051             second_leaf = None
1052         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1053             first_leaf.type == token.ASYNC
1054             and second_leaf is not None
1055             and second_leaf.type == token.NAME
1056             and second_leaf.value == "def"
1057         )
1058
1059     @property
1060     def is_class_paren_empty(self) -> bool:
1061         """Is this a class with no base classes but using parentheses?
1062
1063         Those are unnecessary and should be removed.
1064         """
1065         return (
1066             bool(self)
1067             and len(self.leaves) == 4
1068             and self.is_class
1069             and self.leaves[2].type == token.LPAR
1070             and self.leaves[2].value == "("
1071             and self.leaves[3].type == token.RPAR
1072             and self.leaves[3].value == ")"
1073         )
1074
1075     @property
1076     def is_triple_quoted_string(self) -> bool:
1077         """Is the line a triple quoted string?"""
1078         return (
1079             bool(self)
1080             and self.leaves[0].type == token.STRING
1081             and self.leaves[0].value.startswith(('"""', "'''"))
1082         )
1083
1084     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1085         """If so, needs to be split before emitting."""
1086         for leaf in self.leaves:
1087             if leaf.type == STANDALONE_COMMENT:
1088                 if leaf.bracket_depth <= depth_limit:
1089                     return True
1090
1091         return False
1092
1093     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1094         """Remove trailing comma if there is one and it's safe."""
1095         if not (
1096             self.leaves
1097             and self.leaves[-1].type == token.COMMA
1098             and closing.type in CLOSING_BRACKETS
1099         ):
1100             return False
1101
1102         if closing.type == token.RBRACE:
1103             self.remove_trailing_comma()
1104             return True
1105
1106         if closing.type == token.RSQB:
1107             comma = self.leaves[-1]
1108             if comma.parent and comma.parent.type == syms.listmaker:
1109                 self.remove_trailing_comma()
1110                 return True
1111
1112         # For parens let's check if it's safe to remove the comma.
1113         # Imports are always safe.
1114         if self.is_import:
1115             self.remove_trailing_comma()
1116             return True
1117
1118         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1119         # change a tuple into a different type by removing the comma.
1120         depth = closing.bracket_depth + 1
1121         commas = 0
1122         opening = closing.opening_bracket
1123         for _opening_index, leaf in enumerate(self.leaves):
1124             if leaf is opening:
1125                 break
1126
1127         else:
1128             return False
1129
1130         for leaf in self.leaves[_opening_index + 1 :]:
1131             if leaf is closing:
1132                 break
1133
1134             bracket_depth = leaf.bracket_depth
1135             if bracket_depth == depth and leaf.type == token.COMMA:
1136                 commas += 1
1137                 if leaf.parent and leaf.parent.type == syms.arglist:
1138                     commas += 1
1139                     break
1140
1141         if commas > 1:
1142             self.remove_trailing_comma()
1143             return True
1144
1145         return False
1146
1147     def append_comment(self, comment: Leaf) -> bool:
1148         """Add an inline or standalone comment to the line."""
1149         if (
1150             comment.type == STANDALONE_COMMENT
1151             and self.bracket_tracker.any_open_brackets()
1152         ):
1153             comment.prefix = ""
1154             return False
1155
1156         if comment.type != token.COMMENT:
1157             return False
1158
1159         after = len(self.leaves) - 1
1160         if after == -1:
1161             comment.type = STANDALONE_COMMENT
1162             comment.prefix = ""
1163             return False
1164
1165         else:
1166             self.comments.append((after, comment))
1167             return True
1168
1169     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1170         """Generate comments that should appear directly after `leaf`.
1171
1172         Provide a non-negative leaf `_index` to speed up the function.
1173         """
1174         if _index == -1:
1175             for _index, _leaf in enumerate(self.leaves):
1176                 if leaf is _leaf:
1177                     break
1178
1179             else:
1180                 return
1181
1182         for index, comment_after in self.comments:
1183             if _index == index:
1184                 yield comment_after
1185
1186     def remove_trailing_comma(self) -> None:
1187         """Remove the trailing comma and moves the comments attached to it."""
1188         comma_index = len(self.leaves) - 1
1189         for i in range(len(self.comments)):
1190             comment_index, comment = self.comments[i]
1191             if comment_index == comma_index:
1192                 self.comments[i] = (comma_index - 1, comment)
1193         self.leaves.pop()
1194
1195     def is_complex_subscript(self, leaf: Leaf) -> bool:
1196         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1197         open_lsqb = (
1198             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1199         )
1200         if open_lsqb is None:
1201             return False
1202
1203         subscript_start = open_lsqb.next_sibling
1204         if (
1205             isinstance(subscript_start, Node)
1206             and subscript_start.type == syms.subscriptlist
1207         ):
1208             subscript_start = child_towards(subscript_start, leaf)
1209         return subscript_start is not None and any(
1210             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1211         )
1212
1213     def __str__(self) -> str:
1214         """Render the line."""
1215         if not self:
1216             return "\n"
1217
1218         indent = "    " * self.depth
1219         leaves = iter(self.leaves)
1220         first = next(leaves)
1221         res = f"{first.prefix}{indent}{first.value}"
1222         for leaf in leaves:
1223             res += str(leaf)
1224         for _, comment in self.comments:
1225             res += str(comment)
1226         return res + "\n"
1227
1228     def __bool__(self) -> bool:
1229         """Return True if the line has leaves or comments."""
1230         return bool(self.leaves or self.comments)
1231
1232
1233 class UnformattedLines(Line):
1234     """Just like :class:`Line` but stores lines which aren't reformatted."""
1235
1236     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1237         """Just add a new `leaf` to the end of the lines.
1238
1239         The `preformatted` argument is ignored.
1240
1241         Keeps track of indentation `depth`, which is useful when the user
1242         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1243         """
1244         try:
1245             list(generate_comments(leaf))
1246         except FormatOn as f_on:
1247             self.leaves.append(f_on.leaf_from_consumed(leaf))
1248             raise
1249
1250         self.leaves.append(leaf)
1251         if leaf.type == token.INDENT:
1252             self.depth += 1
1253         elif leaf.type == token.DEDENT:
1254             self.depth -= 1
1255
1256     def __str__(self) -> str:
1257         """Render unformatted lines from leaves which were added with `append()`.
1258
1259         `depth` is not used for indentation in this case.
1260         """
1261         if not self:
1262             return "\n"
1263
1264         res = ""
1265         for leaf in self.leaves:
1266             res += str(leaf)
1267         return res
1268
1269     def append_comment(self, comment: Leaf) -> bool:
1270         """Not implemented in this class. Raises `NotImplementedError`."""
1271         raise NotImplementedError("Unformatted lines don't store comments separately.")
1272
1273     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1274         """Does nothing and returns False."""
1275         return False
1276
1277     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1278         """Does nothing and returns False."""
1279         return False
1280
1281
1282 @dataclass
1283 class EmptyLineTracker:
1284     """Provides a stateful method that returns the number of potential extra
1285     empty lines needed before and after the currently processed line.
1286
1287     Note: this tracker works on lines that haven't been split yet.  It assumes
1288     the prefix of the first leaf consists of optional newlines.  Those newlines
1289     are consumed by `maybe_empty_lines()` and included in the computation.
1290     """
1291
1292     is_pyi: bool = False
1293     previous_line: Optional[Line] = None
1294     previous_after: int = 0
1295     previous_defs: List[int] = Factory(list)
1296
1297     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1298         """Return the number of extra empty lines before and after the `current_line`.
1299
1300         This is for separating `def`, `async def` and `class` with extra empty
1301         lines (two on module-level).
1302         """
1303         if isinstance(current_line, UnformattedLines):
1304             return 0, 0
1305
1306         before, after = self._maybe_empty_lines(current_line)
1307         before -= self.previous_after
1308         self.previous_after = after
1309         self.previous_line = current_line
1310         return before, after
1311
1312     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1313         max_allowed = 1
1314         if current_line.depth == 0:
1315             max_allowed = 1 if self.is_pyi else 2
1316         if current_line.leaves:
1317             # Consume the first leaf's extra newlines.
1318             first_leaf = current_line.leaves[0]
1319             before = first_leaf.prefix.count("\n")
1320             before = min(before, max_allowed)
1321             first_leaf.prefix = ""
1322         else:
1323             before = 0
1324         depth = current_line.depth
1325         while self.previous_defs and self.previous_defs[-1] >= depth:
1326             self.previous_defs.pop()
1327             if self.is_pyi:
1328                 before = 0 if depth else 1
1329             else:
1330                 before = 1 if depth else 2
1331         is_decorator = current_line.is_decorator
1332         if is_decorator or current_line.is_def or current_line.is_class:
1333             if not is_decorator:
1334                 self.previous_defs.append(depth)
1335             if self.previous_line is None:
1336                 # Don't insert empty lines before the first line in the file.
1337                 return 0, 0
1338
1339             if self.previous_line.is_decorator:
1340                 return 0, 0
1341
1342             if self.previous_line.depth < current_line.depth and (
1343                 self.previous_line.is_class or self.previous_line.is_def
1344             ):
1345                 return 0, 0
1346
1347             if (
1348                 self.previous_line.is_comment
1349                 and self.previous_line.depth == current_line.depth
1350                 and before == 0
1351             ):
1352                 return 0, 0
1353
1354             if self.is_pyi:
1355                 if self.previous_line.depth > current_line.depth:
1356                     newlines = 1
1357                 elif current_line.is_class or self.previous_line.is_class:
1358                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1359                         newlines = 0
1360                     else:
1361                         newlines = 1
1362                 else:
1363                     newlines = 0
1364             else:
1365                 newlines = 2
1366             if current_line.depth and newlines:
1367                 newlines -= 1
1368             return newlines, 0
1369
1370         if (
1371             self.previous_line
1372             and self.previous_line.is_import
1373             and not current_line.is_import
1374             and depth == self.previous_line.depth
1375         ):
1376             return (before or 1), 0
1377
1378         if (
1379             self.previous_line
1380             and self.previous_line.is_class
1381             and current_line.is_triple_quoted_string
1382         ):
1383             return before, 1
1384
1385         return before, 0
1386
1387
1388 @dataclass
1389 class LineGenerator(Visitor[Line]):
1390     """Generates reformatted Line objects.  Empty lines are not emitted.
1391
1392     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1393     in ways that will no longer stringify to valid Python code on the tree.
1394     """
1395
1396     is_pyi: bool = False
1397     normalize_strings: bool = True
1398     current_line: Line = Factory(Line)
1399     remove_u_prefix: bool = False
1400
1401     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1402         """Generate a line.
1403
1404         If the line is empty, only emit if it makes sense.
1405         If the line is too long, split it first and then generate.
1406
1407         If any lines were generated, set up a new current_line.
1408         """
1409         if not self.current_line:
1410             if self.current_line.__class__ == type:
1411                 self.current_line.depth += indent
1412             else:
1413                 self.current_line = type(depth=self.current_line.depth + indent)
1414             return  # Line is empty, don't emit. Creating a new one unnecessary.
1415
1416         complete_line = self.current_line
1417         self.current_line = type(depth=complete_line.depth + indent)
1418         yield complete_line
1419
1420     def visit(self, node: LN) -> Iterator[Line]:
1421         """Main method to visit `node` and its children.
1422
1423         Yields :class:`Line` objects.
1424         """
1425         if isinstance(self.current_line, UnformattedLines):
1426             # File contained `# fmt: off`
1427             yield from self.visit_unformatted(node)
1428
1429         else:
1430             yield from super().visit(node)
1431
1432     def visit_default(self, node: LN) -> Iterator[Line]:
1433         """Default `visit_*()` implementation. Recurses to children of `node`."""
1434         if isinstance(node, Leaf):
1435             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1436             try:
1437                 for comment in generate_comments(node):
1438                     if any_open_brackets:
1439                         # any comment within brackets is subject to splitting
1440                         self.current_line.append(comment)
1441                     elif comment.type == token.COMMENT:
1442                         # regular trailing comment
1443                         self.current_line.append(comment)
1444                         yield from self.line()
1445
1446                     else:
1447                         # regular standalone comment
1448                         yield from self.line()
1449
1450                         self.current_line.append(comment)
1451                         yield from self.line()
1452
1453             except FormatOff as f_off:
1454                 f_off.trim_prefix(node)
1455                 yield from self.line(type=UnformattedLines)
1456                 yield from self.visit(node)
1457
1458             except FormatOn as f_on:
1459                 # This only happens here if somebody says "fmt: on" multiple
1460                 # times in a row.
1461                 f_on.trim_prefix(node)
1462                 yield from self.visit_default(node)
1463
1464             else:
1465                 normalize_prefix(node, inside_brackets=any_open_brackets)
1466                 if self.normalize_strings and node.type == token.STRING:
1467                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1468                     normalize_string_quotes(node)
1469                 if node.type not in WHITESPACE:
1470                     self.current_line.append(node)
1471         yield from super().visit_default(node)
1472
1473     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1474         """Increase indentation level, maybe yield a line."""
1475         # In blib2to3 INDENT never holds comments.
1476         yield from self.line(+1)
1477         yield from self.visit_default(node)
1478
1479     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1480         """Decrease indentation level, maybe yield a line."""
1481         # The current line might still wait for trailing comments.  At DEDENT time
1482         # there won't be any (they would be prefixes on the preceding NEWLINE).
1483         # Emit the line then.
1484         yield from self.line()
1485
1486         # While DEDENT has no value, its prefix may contain standalone comments
1487         # that belong to the current indentation level.  Get 'em.
1488         yield from self.visit_default(node)
1489
1490         # Finally, emit the dedent.
1491         yield from self.line(-1)
1492
1493     def visit_stmt(
1494         self, node: Node, keywords: Set[str], parens: Set[str]
1495     ) -> Iterator[Line]:
1496         """Visit a statement.
1497
1498         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1499         `def`, `with`, `class`, `assert` and assignments.
1500
1501         The relevant Python language `keywords` for a given statement will be
1502         NAME leaves within it. This methods puts those on a separate line.
1503
1504         `parens` holds a set of string leaf values immediately after which
1505         invisible parens should be put.
1506         """
1507         normalize_invisible_parens(node, parens_after=parens)
1508         for child in node.children:
1509             if child.type == token.NAME and child.value in keywords:  # type: ignore
1510                 yield from self.line()
1511
1512             yield from self.visit(child)
1513
1514     def visit_suite(self, node: Node) -> Iterator[Line]:
1515         """Visit a suite."""
1516         if self.is_pyi and is_stub_suite(node):
1517             yield from self.visit(node.children[2])
1518         else:
1519             yield from self.visit_default(node)
1520
1521     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1522         """Visit a statement without nested statements."""
1523         is_suite_like = node.parent and node.parent.type in STATEMENT
1524         if is_suite_like:
1525             if self.is_pyi and is_stub_body(node):
1526                 yield from self.visit_default(node)
1527             else:
1528                 yield from self.line(+1)
1529                 yield from self.visit_default(node)
1530                 yield from self.line(-1)
1531
1532         else:
1533             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1534                 yield from self.line()
1535             yield from self.visit_default(node)
1536
1537     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1538         """Visit `async def`, `async for`, `async with`."""
1539         yield from self.line()
1540
1541         children = iter(node.children)
1542         for child in children:
1543             yield from self.visit(child)
1544
1545             if child.type == token.ASYNC:
1546                 break
1547
1548         internal_stmt = next(children)
1549         for child in internal_stmt.children:
1550             yield from self.visit(child)
1551
1552     def visit_decorators(self, node: Node) -> Iterator[Line]:
1553         """Visit decorators."""
1554         for child in node.children:
1555             yield from self.line()
1556             yield from self.visit(child)
1557
1558     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1559         """Remove a semicolon and put the other statement on a separate line."""
1560         yield from self.line()
1561
1562     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1563         """End of file. Process outstanding comments and end with a newline."""
1564         yield from self.visit_default(leaf)
1565         yield from self.line()
1566
1567     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1568         """Used when file contained a `# fmt: off`."""
1569         if isinstance(node, Node):
1570             for child in node.children:
1571                 yield from self.visit(child)
1572
1573         else:
1574             try:
1575                 self.current_line.append(node)
1576             except FormatOn as f_on:
1577                 f_on.trim_prefix(node)
1578                 yield from self.line()
1579                 yield from self.visit(node)
1580
1581             if node.type == token.ENDMARKER:
1582                 # somebody decided not to put a final `# fmt: on`
1583                 yield from self.line()
1584
1585     def __attrs_post_init__(self) -> None:
1586         """You are in a twisty little maze of passages."""
1587         v = self.visit_stmt
1588         Ø: Set[str] = set()
1589         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1590         self.visit_if_stmt = partial(
1591             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1592         )
1593         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1594         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1595         self.visit_try_stmt = partial(
1596             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1597         )
1598         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1599         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1600         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1601         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1602         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1603         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1604         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1605         self.visit_async_funcdef = self.visit_async_stmt
1606         self.visit_decorated = self.visit_decorators
1607
1608
1609 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1610 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1611 OPENING_BRACKETS = set(BRACKET.keys())
1612 CLOSING_BRACKETS = set(BRACKET.values())
1613 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1614 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1615
1616
1617 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1618     """Return whitespace prefix if needed for the given `leaf`.
1619
1620     `complex_subscript` signals whether the given leaf is part of a subscription
1621     which has non-trivial arguments, like arithmetic expressions or function calls.
1622     """
1623     NO = ""
1624     SPACE = " "
1625     DOUBLESPACE = "  "
1626     t = leaf.type
1627     p = leaf.parent
1628     v = leaf.value
1629     if t in ALWAYS_NO_SPACE:
1630         return NO
1631
1632     if t == token.COMMENT:
1633         return DOUBLESPACE
1634
1635     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1636     if t == token.COLON and p.type not in {
1637         syms.subscript,
1638         syms.subscriptlist,
1639         syms.sliceop,
1640     }:
1641         return NO
1642
1643     prev = leaf.prev_sibling
1644     if not prev:
1645         prevp = preceding_leaf(p)
1646         if not prevp or prevp.type in OPENING_BRACKETS:
1647             return NO
1648
1649         if t == token.COLON:
1650             if prevp.type == token.COLON:
1651                 return NO
1652
1653             elif prevp.type != token.COMMA and not complex_subscript:
1654                 return NO
1655
1656             return SPACE
1657
1658         if prevp.type == token.EQUAL:
1659             if prevp.parent:
1660                 if prevp.parent.type in {
1661                     syms.arglist,
1662                     syms.argument,
1663                     syms.parameters,
1664                     syms.varargslist,
1665                 }:
1666                     return NO
1667
1668                 elif prevp.parent.type == syms.typedargslist:
1669                     # A bit hacky: if the equal sign has whitespace, it means we
1670                     # previously found it's a typed argument.  So, we're using
1671                     # that, too.
1672                     return prevp.prefix
1673
1674         elif prevp.type in STARS:
1675             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1676                 return NO
1677
1678         elif prevp.type == token.COLON:
1679             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1680                 return SPACE if complex_subscript else NO
1681
1682         elif (
1683             prevp.parent
1684             and prevp.parent.type == syms.factor
1685             and prevp.type in MATH_OPERATORS
1686         ):
1687             return NO
1688
1689         elif (
1690             prevp.type == token.RIGHTSHIFT
1691             and prevp.parent
1692             and prevp.parent.type == syms.shift_expr
1693             and prevp.prev_sibling
1694             and prevp.prev_sibling.type == token.NAME
1695             and prevp.prev_sibling.value == "print"  # type: ignore
1696         ):
1697             # Python 2 print chevron
1698             return NO
1699
1700     elif prev.type in OPENING_BRACKETS:
1701         return NO
1702
1703     if p.type in {syms.parameters, syms.arglist}:
1704         # untyped function signatures or calls
1705         if not prev or prev.type != token.COMMA:
1706             return NO
1707
1708     elif p.type == syms.varargslist:
1709         # lambdas
1710         if prev and prev.type != token.COMMA:
1711             return NO
1712
1713     elif p.type == syms.typedargslist:
1714         # typed function signatures
1715         if not prev:
1716             return NO
1717
1718         if t == token.EQUAL:
1719             if prev.type != syms.tname:
1720                 return NO
1721
1722         elif prev.type == token.EQUAL:
1723             # A bit hacky: if the equal sign has whitespace, it means we
1724             # previously found it's a typed argument.  So, we're using that, too.
1725             return prev.prefix
1726
1727         elif prev.type != token.COMMA:
1728             return NO
1729
1730     elif p.type == syms.tname:
1731         # type names
1732         if not prev:
1733             prevp = preceding_leaf(p)
1734             if not prevp or prevp.type != token.COMMA:
1735                 return NO
1736
1737     elif p.type == syms.trailer:
1738         # attributes and calls
1739         if t == token.LPAR or t == token.RPAR:
1740             return NO
1741
1742         if not prev:
1743             if t == token.DOT:
1744                 prevp = preceding_leaf(p)
1745                 if not prevp or prevp.type != token.NUMBER:
1746                     return NO
1747
1748             elif t == token.LSQB:
1749                 return NO
1750
1751         elif prev.type != token.COMMA:
1752             return NO
1753
1754     elif p.type == syms.argument:
1755         # single argument
1756         if t == token.EQUAL:
1757             return NO
1758
1759         if not prev:
1760             prevp = preceding_leaf(p)
1761             if not prevp or prevp.type == token.LPAR:
1762                 return NO
1763
1764         elif prev.type in {token.EQUAL} | STARS:
1765             return NO
1766
1767     elif p.type == syms.decorator:
1768         # decorators
1769         return NO
1770
1771     elif p.type == syms.dotted_name:
1772         if prev:
1773             return NO
1774
1775         prevp = preceding_leaf(p)
1776         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1777             return NO
1778
1779     elif p.type == syms.classdef:
1780         if t == token.LPAR:
1781             return NO
1782
1783         if prev and prev.type == token.LPAR:
1784             return NO
1785
1786     elif p.type in {syms.subscript, syms.sliceop}:
1787         # indexing
1788         if not prev:
1789             assert p.parent is not None, "subscripts are always parented"
1790             if p.parent.type == syms.subscriptlist:
1791                 return SPACE
1792
1793             return NO
1794
1795         elif not complex_subscript:
1796             return NO
1797
1798     elif p.type == syms.atom:
1799         if prev and t == token.DOT:
1800             # dots, but not the first one.
1801             return NO
1802
1803     elif p.type == syms.dictsetmaker:
1804         # dict unpacking
1805         if prev and prev.type == token.DOUBLESTAR:
1806             return NO
1807
1808     elif p.type in {syms.factor, syms.star_expr}:
1809         # unary ops
1810         if not prev:
1811             prevp = preceding_leaf(p)
1812             if not prevp or prevp.type in OPENING_BRACKETS:
1813                 return NO
1814
1815             prevp_parent = prevp.parent
1816             assert prevp_parent is not None
1817             if prevp.type == token.COLON and prevp_parent.type in {
1818                 syms.subscript,
1819                 syms.sliceop,
1820             }:
1821                 return NO
1822
1823             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1824                 return NO
1825
1826         elif t == token.NAME or t == token.NUMBER:
1827             return NO
1828
1829     elif p.type == syms.import_from:
1830         if t == token.DOT:
1831             if prev and prev.type == token.DOT:
1832                 return NO
1833
1834         elif t == token.NAME:
1835             if v == "import":
1836                 return SPACE
1837
1838             if prev and prev.type == token.DOT:
1839                 return NO
1840
1841     elif p.type == syms.sliceop:
1842         return NO
1843
1844     return SPACE
1845
1846
1847 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1848     """Return the first leaf that precedes `node`, if any."""
1849     while node:
1850         res = node.prev_sibling
1851         if res:
1852             if isinstance(res, Leaf):
1853                 return res
1854
1855             try:
1856                 return list(res.leaves())[-1]
1857
1858             except IndexError:
1859                 return None
1860
1861         node = node.parent
1862     return None
1863
1864
1865 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1866     """Return the child of `ancestor` that contains `descendant`."""
1867     node: Optional[LN] = descendant
1868     while node and node.parent != ancestor:
1869         node = node.parent
1870     return node
1871
1872
1873 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1874     """Return the priority of the `leaf` delimiter, given a line break after it.
1875
1876     The delimiter priorities returned here are from those delimiters that would
1877     cause a line break after themselves.
1878
1879     Higher numbers are higher priority.
1880     """
1881     if leaf.type == token.COMMA:
1882         return COMMA_PRIORITY
1883
1884     return 0
1885
1886
1887 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1888     """Return the priority of the `leaf` delimiter, given a line before after it.
1889
1890     The delimiter priorities returned here are from those delimiters that would
1891     cause a line break before themselves.
1892
1893     Higher numbers are higher priority.
1894     """
1895     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1896         # * and ** might also be MATH_OPERATORS but in this case they are not.
1897         # Don't treat them as a delimiter.
1898         return 0
1899
1900     if (
1901         leaf.type == token.DOT
1902         and leaf.parent
1903         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1904         and (previous is None or previous.type in CLOSING_BRACKETS)
1905     ):
1906         return DOT_PRIORITY
1907
1908     if (
1909         leaf.type in MATH_OPERATORS
1910         and leaf.parent
1911         and leaf.parent.type not in {syms.factor, syms.star_expr}
1912     ):
1913         return MATH_PRIORITIES[leaf.type]
1914
1915     if leaf.type in COMPARATORS:
1916         return COMPARATOR_PRIORITY
1917
1918     if (
1919         leaf.type == token.STRING
1920         and previous is not None
1921         and previous.type == token.STRING
1922     ):
1923         return STRING_PRIORITY
1924
1925     if leaf.type != token.NAME:
1926         return 0
1927
1928     if (
1929         leaf.value == "for"
1930         and leaf.parent
1931         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1932     ):
1933         return COMPREHENSION_PRIORITY
1934
1935     if (
1936         leaf.value == "if"
1937         and leaf.parent
1938         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1939     ):
1940         return COMPREHENSION_PRIORITY
1941
1942     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1943         return TERNARY_PRIORITY
1944
1945     if leaf.value == "is":
1946         return COMPARATOR_PRIORITY
1947
1948     if (
1949         leaf.value == "in"
1950         and leaf.parent
1951         and leaf.parent.type in {syms.comp_op, syms.comparison}
1952         and not (
1953             previous is not None
1954             and previous.type == token.NAME
1955             and previous.value == "not"
1956         )
1957     ):
1958         return COMPARATOR_PRIORITY
1959
1960     if (
1961         leaf.value == "not"
1962         and leaf.parent
1963         and leaf.parent.type == syms.comp_op
1964         and not (
1965             previous is not None
1966             and previous.type == token.NAME
1967             and previous.value == "is"
1968         )
1969     ):
1970         return COMPARATOR_PRIORITY
1971
1972     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1973         return LOGIC_PRIORITY
1974
1975     return 0
1976
1977
1978 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1979     """Clean the prefix of the `leaf` and generate comments from it, if any.
1980
1981     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1982     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1983     move because it does away with modifying the grammar to include all the
1984     possible places in which comments can be placed.
1985
1986     The sad consequence for us though is that comments don't "belong" anywhere.
1987     This is why this function generates simple parentless Leaf objects for
1988     comments.  We simply don't know what the correct parent should be.
1989
1990     No matter though, we can live without this.  We really only need to
1991     differentiate between inline and standalone comments.  The latter don't
1992     share the line with any code.
1993
1994     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1995     are emitted with a fake STANDALONE_COMMENT token identifier.
1996     """
1997     p = leaf.prefix
1998     if not p:
1999         return
2000
2001     if "#" not in p:
2002         return
2003
2004     consumed = 0
2005     nlines = 0
2006     for index, line in enumerate(p.split("\n")):
2007         consumed += len(line) + 1  # adding the length of the split '\n'
2008         line = line.lstrip()
2009         if not line:
2010             nlines += 1
2011         if not line.startswith("#"):
2012             continue
2013
2014         if index == 0 and leaf.type != token.ENDMARKER:
2015             comment_type = token.COMMENT  # simple trailing comment
2016         else:
2017             comment_type = STANDALONE_COMMENT
2018         comment = make_comment(line)
2019         yield Leaf(comment_type, comment, prefix="\n" * nlines)
2020
2021         if comment in {"# fmt: on", "# yapf: enable"}:
2022             raise FormatOn(consumed)
2023
2024         if comment in {"# fmt: off", "# yapf: disable"}:
2025             if comment_type == STANDALONE_COMMENT:
2026                 raise FormatOff(consumed)
2027
2028             prev = preceding_leaf(leaf)
2029             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
2030                 raise FormatOff(consumed)
2031
2032         nlines = 0
2033
2034
2035 def make_comment(content: str) -> str:
2036     """Return a consistently formatted comment from the given `content` string.
2037
2038     All comments (except for "##", "#!", "#:") should have a single space between
2039     the hash sign and the content.
2040
2041     If `content` didn't start with a hash sign, one is provided.
2042     """
2043     content = content.rstrip()
2044     if not content:
2045         return "#"
2046
2047     if content[0] == "#":
2048         content = content[1:]
2049     if content and content[0] not in " !:#":
2050         content = " " + content
2051     return "#" + content
2052
2053
2054 def split_line(
2055     line: Line, line_length: int, inner: bool = False, py36: bool = False
2056 ) -> Iterator[Line]:
2057     """Split a `line` into potentially many lines.
2058
2059     They should fit in the allotted `line_length` but might not be able to.
2060     `inner` signifies that there were a pair of brackets somewhere around the
2061     current `line`, possibly transitively. This means we can fallback to splitting
2062     by delimiters if the LHS/RHS don't yield any results.
2063
2064     If `py36` is True, splitting may generate syntax that is only compatible
2065     with Python 3.6 and later.
2066     """
2067     if isinstance(line, UnformattedLines) or line.is_comment:
2068         yield line
2069         return
2070
2071     line_str = str(line).strip("\n")
2072     if not line.should_explode and is_line_short_enough(
2073         line, line_length=line_length, line_str=line_str
2074     ):
2075         yield line
2076         return
2077
2078     split_funcs: List[SplitFunc]
2079     if line.is_def:
2080         split_funcs = [left_hand_split]
2081     else:
2082
2083         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2084             for omit in generate_trailers_to_omit(line, line_length):
2085                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2086                 if is_line_short_enough(lines[0], line_length=line_length):
2087                     yield from lines
2088                     return
2089
2090             # All splits failed, best effort split with no omits.
2091             # This mostly happens to multiline strings that are by definition
2092             # reported as not fitting a single line.
2093             yield from right_hand_split(line, py36)
2094
2095         if line.inside_brackets:
2096             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2097         else:
2098             split_funcs = [rhs]
2099     for split_func in split_funcs:
2100         # We are accumulating lines in `result` because we might want to abort
2101         # mission and return the original line in the end, or attempt a different
2102         # split altogether.
2103         result: List[Line] = []
2104         try:
2105             for l in split_func(line, py36):
2106                 if str(l).strip("\n") == line_str:
2107                     raise CannotSplit("Split function returned an unchanged result")
2108
2109                 result.extend(
2110                     split_line(l, line_length=line_length, inner=True, py36=py36)
2111                 )
2112         except CannotSplit as cs:
2113             continue
2114
2115         else:
2116             yield from result
2117             break
2118
2119     else:
2120         yield line
2121
2122
2123 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2124     """Split line into many lines, starting with the first matching bracket pair.
2125
2126     Note: this usually looks weird, only use this for function definitions.
2127     Prefer RHS otherwise.  This is why this function is not symmetrical with
2128     :func:`right_hand_split` which also handles optional parentheses.
2129     """
2130     head = Line(depth=line.depth)
2131     body = Line(depth=line.depth + 1, inside_brackets=True)
2132     tail = Line(depth=line.depth)
2133     tail_leaves: List[Leaf] = []
2134     body_leaves: List[Leaf] = []
2135     head_leaves: List[Leaf] = []
2136     current_leaves = head_leaves
2137     matching_bracket = None
2138     for leaf in line.leaves:
2139         if (
2140             current_leaves is body_leaves
2141             and leaf.type in CLOSING_BRACKETS
2142             and leaf.opening_bracket is matching_bracket
2143         ):
2144             current_leaves = tail_leaves if body_leaves else head_leaves
2145         current_leaves.append(leaf)
2146         if current_leaves is head_leaves:
2147             if leaf.type in OPENING_BRACKETS:
2148                 matching_bracket = leaf
2149                 current_leaves = body_leaves
2150     # Since body is a new indent level, remove spurious leading whitespace.
2151     if body_leaves:
2152         normalize_prefix(body_leaves[0], inside_brackets=True)
2153     # Build the new lines.
2154     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2155         for leaf in leaves:
2156             result.append(leaf, preformatted=True)
2157             for comment_after in line.comments_after(leaf):
2158                 result.append(comment_after, preformatted=True)
2159     bracket_split_succeeded_or_raise(head, body, tail)
2160     for result in (head, body, tail):
2161         if result:
2162             yield result
2163
2164
2165 def right_hand_split(
2166     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2167 ) -> Iterator[Line]:
2168     """Split line into many lines, starting with the last matching bracket pair.
2169
2170     If the split was by optional parentheses, attempt splitting without them, too.
2171     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2172     this split.
2173
2174     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2175     """
2176     head = Line(depth=line.depth)
2177     body = Line(depth=line.depth + 1, inside_brackets=True)
2178     tail = Line(depth=line.depth)
2179     tail_leaves: List[Leaf] = []
2180     body_leaves: List[Leaf] = []
2181     head_leaves: List[Leaf] = []
2182     current_leaves = tail_leaves
2183     opening_bracket = None
2184     closing_bracket = None
2185     for leaf in reversed(line.leaves):
2186         if current_leaves is body_leaves:
2187             if leaf is opening_bracket:
2188                 current_leaves = head_leaves if body_leaves else tail_leaves
2189         current_leaves.append(leaf)
2190         if current_leaves is tail_leaves:
2191             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2192                 opening_bracket = leaf.opening_bracket
2193                 closing_bracket = leaf
2194                 current_leaves = body_leaves
2195     tail_leaves.reverse()
2196     body_leaves.reverse()
2197     head_leaves.reverse()
2198     # Since body is a new indent level, remove spurious leading whitespace.
2199     if body_leaves:
2200         normalize_prefix(body_leaves[0], inside_brackets=True)
2201     if not head_leaves:
2202         # No `head` means the split failed. Either `tail` has all content or
2203         # the matching `opening_bracket` wasn't available on `line` anymore.
2204         raise CannotSplit("No brackets found")
2205
2206     # Build the new lines.
2207     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2208         for leaf in leaves:
2209             result.append(leaf, preformatted=True)
2210             for comment_after in line.comments_after(leaf):
2211                 result.append(comment_after, preformatted=True)
2212     bracket_split_succeeded_or_raise(head, body, tail)
2213     assert opening_bracket and closing_bracket
2214     if (
2215         # the opening bracket is an optional paren
2216         opening_bracket.type == token.LPAR
2217         and not opening_bracket.value
2218         # the closing bracket is an optional paren
2219         and closing_bracket.type == token.RPAR
2220         and not closing_bracket.value
2221         # there are no standalone comments in the body
2222         and not line.contains_standalone_comments(0)
2223         # and it's not an import (optional parens are the only thing we can split
2224         # on in this case; attempting a split without them is a waste of time)
2225         and not line.is_import
2226     ):
2227         omit = {id(closing_bracket), *omit}
2228         if can_omit_invisible_parens(body, line_length):
2229             try:
2230                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2231                 return
2232             except CannotSplit:
2233                 pass
2234
2235     ensure_visible(opening_bracket)
2236     ensure_visible(closing_bracket)
2237     body.should_explode = should_explode(body, opening_bracket)
2238     for result in (head, body, tail):
2239         if result:
2240             yield result
2241
2242
2243 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2244     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2245
2246     Do nothing otherwise.
2247
2248     A left- or right-hand split is based on a pair of brackets. Content before
2249     (and including) the opening bracket is left on one line, content inside the
2250     brackets is put on a separate line, and finally content starting with and
2251     following the closing bracket is put on a separate line.
2252
2253     Those are called `head`, `body`, and `tail`, respectively. If the split
2254     produced the same line (all content in `head`) or ended up with an empty `body`
2255     and the `tail` is just the closing bracket, then it's considered failed.
2256     """
2257     tail_len = len(str(tail).strip())
2258     if not body:
2259         if tail_len == 0:
2260             raise CannotSplit("Splitting brackets produced the same line")
2261
2262         elif tail_len < 3:
2263             raise CannotSplit(
2264                 f"Splitting brackets on an empty body to save "
2265                 f"{tail_len} characters is not worth it"
2266             )
2267
2268
2269 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2270     """Normalize prefix of the first leaf in every line returned by `split_func`.
2271
2272     This is a decorator over relevant split functions.
2273     """
2274
2275     @wraps(split_func)
2276     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2277         for l in split_func(line, py36):
2278             normalize_prefix(l.leaves[0], inside_brackets=True)
2279             yield l
2280
2281     return split_wrapper
2282
2283
2284 @dont_increase_indentation
2285 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2286     """Split according to delimiters of the highest priority.
2287
2288     If `py36` is True, the split will add trailing commas also in function
2289     signatures that contain `*` and `**`.
2290     """
2291     try:
2292         last_leaf = line.leaves[-1]
2293     except IndexError:
2294         raise CannotSplit("Line empty")
2295
2296     bt = line.bracket_tracker
2297     try:
2298         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2299     except ValueError:
2300         raise CannotSplit("No delimiters found")
2301
2302     if delimiter_priority == DOT_PRIORITY:
2303         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2304             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2305
2306     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2307     lowest_depth = sys.maxsize
2308     trailing_comma_safe = True
2309
2310     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2311         """Append `leaf` to current line or to new line if appending impossible."""
2312         nonlocal current_line
2313         try:
2314             current_line.append_safe(leaf, preformatted=True)
2315         except ValueError as ve:
2316             yield current_line
2317
2318             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2319             current_line.append(leaf)
2320
2321     for index, leaf in enumerate(line.leaves):
2322         yield from append_to_line(leaf)
2323
2324         for comment_after in line.comments_after(leaf, index):
2325             yield from append_to_line(comment_after)
2326
2327         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2328         if leaf.bracket_depth == lowest_depth and is_vararg(
2329             leaf, within=VARARGS_PARENTS
2330         ):
2331             trailing_comma_safe = trailing_comma_safe and py36
2332         leaf_priority = bt.delimiters.get(id(leaf))
2333         if leaf_priority == delimiter_priority:
2334             yield current_line
2335
2336             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2337     if current_line:
2338         if (
2339             trailing_comma_safe
2340             and delimiter_priority == COMMA_PRIORITY
2341             and current_line.leaves[-1].type != token.COMMA
2342             and current_line.leaves[-1].type != STANDALONE_COMMENT
2343         ):
2344             current_line.append(Leaf(token.COMMA, ","))
2345         yield current_line
2346
2347
2348 @dont_increase_indentation
2349 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2350     """Split standalone comments from the rest of the line."""
2351     if not line.contains_standalone_comments(0):
2352         raise CannotSplit("Line does not have any standalone comments")
2353
2354     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2355
2356     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2357         """Append `leaf` to current line or to new line if appending impossible."""
2358         nonlocal current_line
2359         try:
2360             current_line.append_safe(leaf, preformatted=True)
2361         except ValueError as ve:
2362             yield current_line
2363
2364             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2365             current_line.append(leaf)
2366
2367     for index, leaf in enumerate(line.leaves):
2368         yield from append_to_line(leaf)
2369
2370         for comment_after in line.comments_after(leaf, index):
2371             yield from append_to_line(comment_after)
2372
2373     if current_line:
2374         yield current_line
2375
2376
2377 def is_import(leaf: Leaf) -> bool:
2378     """Return True if the given leaf starts an import statement."""
2379     p = leaf.parent
2380     t = leaf.type
2381     v = leaf.value
2382     return bool(
2383         t == token.NAME
2384         and (
2385             (v == "import" and p and p.type == syms.import_name)
2386             or (v == "from" and p and p.type == syms.import_from)
2387         )
2388     )
2389
2390
2391 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2392     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2393     else.
2394
2395     Note: don't use backslashes for formatting or you'll lose your voting rights.
2396     """
2397     if not inside_brackets:
2398         spl = leaf.prefix.split("#")
2399         if "\\" not in spl[0]:
2400             nl_count = spl[-1].count("\n")
2401             if len(spl) > 1:
2402                 nl_count -= 1
2403             leaf.prefix = "\n" * nl_count
2404             return
2405
2406     leaf.prefix = ""
2407
2408
2409 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2410     """Make all string prefixes lowercase.
2411
2412     If remove_u_prefix is given, also removes any u prefix from the string.
2413
2414     Note: Mutates its argument.
2415     """
2416     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2417     assert match is not None, f"failed to match string {leaf.value!r}"
2418     orig_prefix = match.group(1)
2419     new_prefix = orig_prefix.lower()
2420     if remove_u_prefix:
2421         new_prefix = new_prefix.replace("u", "")
2422     leaf.value = f"{new_prefix}{match.group(2)}"
2423
2424
2425 def normalize_string_quotes(leaf: Leaf) -> None:
2426     """Prefer double quotes but only if it doesn't cause more escaping.
2427
2428     Adds or removes backslashes as appropriate. Doesn't parse and fix
2429     strings nested in f-strings (yet).
2430
2431     Note: Mutates its argument.
2432     """
2433     value = leaf.value.lstrip("furbFURB")
2434     if value[:3] == '"""':
2435         return
2436
2437     elif value[:3] == "'''":
2438         orig_quote = "'''"
2439         new_quote = '"""'
2440     elif value[0] == '"':
2441         orig_quote = '"'
2442         new_quote = "'"
2443     else:
2444         orig_quote = "'"
2445         new_quote = '"'
2446     first_quote_pos = leaf.value.find(orig_quote)
2447     if first_quote_pos == -1:
2448         return  # There's an internal error
2449
2450     prefix = leaf.value[:first_quote_pos]
2451     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2452     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2453     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2454     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2455     if "r" in prefix.casefold():
2456         if unescaped_new_quote.search(body):
2457             # There's at least one unescaped new_quote in this raw string
2458             # so converting is impossible
2459             return
2460
2461         # Do not introduce or remove backslashes in raw strings
2462         new_body = body
2463     else:
2464         # remove unnecessary quotes
2465         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2466         if body != new_body:
2467             # Consider the string without unnecessary quotes as the original
2468             body = new_body
2469             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2470         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2471         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2472     if new_quote == '"""' and new_body[-1] == '"':
2473         # edge case:
2474         new_body = new_body[:-1] + '\\"'
2475     orig_escape_count = body.count("\\")
2476     new_escape_count = new_body.count("\\")
2477     if new_escape_count > orig_escape_count:
2478         return  # Do not introduce more escaping
2479
2480     if new_escape_count == orig_escape_count and orig_quote == '"':
2481         return  # Prefer double quotes
2482
2483     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2484
2485
2486 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2487     """Make existing optional parentheses invisible or create new ones.
2488
2489     `parens_after` is a set of string leaf values immeditely after which parens
2490     should be put.
2491
2492     Standardizes on visible parentheses for single-element tuples, and keeps
2493     existing visible parentheses for other tuples and generator expressions.
2494     """
2495     try:
2496         list(generate_comments(node))
2497     except FormatOff:
2498         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2499
2500     check_lpar = False
2501     for index, child in enumerate(list(node.children)):
2502         if check_lpar:
2503             if child.type == syms.atom:
2504                 maybe_make_parens_invisible_in_atom(child)
2505             elif is_one_tuple(child):
2506                 # wrap child in visible parentheses
2507                 lpar = Leaf(token.LPAR, "(")
2508                 rpar = Leaf(token.RPAR, ")")
2509                 child.remove()
2510                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2511             elif node.type == syms.import_from:
2512                 # "import from" nodes store parentheses directly as part of
2513                 # the statement
2514                 if child.type == token.LPAR:
2515                     # make parentheses invisible
2516                     child.value = ""  # type: ignore
2517                     node.children[-1].value = ""  # type: ignore
2518                 elif child.type != token.STAR:
2519                     # insert invisible parentheses
2520                     node.insert_child(index, Leaf(token.LPAR, ""))
2521                     node.append_child(Leaf(token.RPAR, ""))
2522                 break
2523
2524             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2525                 # wrap child in invisible parentheses
2526                 lpar = Leaf(token.LPAR, "")
2527                 rpar = Leaf(token.RPAR, "")
2528                 index = child.remove() or 0
2529                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2530
2531         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2532
2533
2534 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2535     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2536     if (
2537         node.type != syms.atom
2538         or is_empty_tuple(node)
2539         or is_one_tuple(node)
2540         or is_yield(node)
2541         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2542     ):
2543         return False
2544
2545     first = node.children[0]
2546     last = node.children[-1]
2547     if first.type == token.LPAR and last.type == token.RPAR:
2548         # make parentheses invisible
2549         first.value = ""  # type: ignore
2550         last.value = ""  # type: ignore
2551         if len(node.children) > 1:
2552             maybe_make_parens_invisible_in_atom(node.children[1])
2553         return True
2554
2555     return False
2556
2557
2558 def is_empty_tuple(node: LN) -> bool:
2559     """Return True if `node` holds an empty tuple."""
2560     return (
2561         node.type == syms.atom
2562         and len(node.children) == 2
2563         and node.children[0].type == token.LPAR
2564         and node.children[1].type == token.RPAR
2565     )
2566
2567
2568 def is_one_tuple(node: LN) -> bool:
2569     """Return True if `node` holds a tuple with one element, with or without parens."""
2570     if node.type == syms.atom:
2571         if len(node.children) != 3:
2572             return False
2573
2574         lpar, gexp, rpar = node.children
2575         if not (
2576             lpar.type == token.LPAR
2577             and gexp.type == syms.testlist_gexp
2578             and rpar.type == token.RPAR
2579         ):
2580             return False
2581
2582         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2583
2584     return (
2585         node.type in IMPLICIT_TUPLE
2586         and len(node.children) == 2
2587         and node.children[1].type == token.COMMA
2588     )
2589
2590
2591 def is_yield(node: LN) -> bool:
2592     """Return True if `node` holds a `yield` or `yield from` expression."""
2593     if node.type == syms.yield_expr:
2594         return True
2595
2596     if node.type == token.NAME and node.value == "yield":  # type: ignore
2597         return True
2598
2599     if node.type != syms.atom:
2600         return False
2601
2602     if len(node.children) != 3:
2603         return False
2604
2605     lpar, expr, rpar = node.children
2606     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2607         return is_yield(expr)
2608
2609     return False
2610
2611
2612 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2613     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2614
2615     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2616     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2617     extended iterable unpacking (PEP 3132) and additional unpacking
2618     generalizations (PEP 448).
2619     """
2620     if leaf.type not in STARS or not leaf.parent:
2621         return False
2622
2623     p = leaf.parent
2624     if p.type == syms.star_expr:
2625         # Star expressions are also used as assignment targets in extended
2626         # iterable unpacking (PEP 3132).  See what its parent is instead.
2627         if not p.parent:
2628             return False
2629
2630         p = p.parent
2631
2632     return p.type in within
2633
2634
2635 def is_multiline_string(leaf: Leaf) -> bool:
2636     """Return True if `leaf` is a multiline string that actually spans many lines."""
2637     value = leaf.value.lstrip("furbFURB")
2638     return value[:3] in {'"""', "'''"} and "\n" in value
2639
2640
2641 def is_stub_suite(node: Node) -> bool:
2642     """Return True if `node` is a suite with a stub body."""
2643     if (
2644         len(node.children) != 4
2645         or node.children[0].type != token.NEWLINE
2646         or node.children[1].type != token.INDENT
2647         or node.children[3].type != token.DEDENT
2648     ):
2649         return False
2650
2651     return is_stub_body(node.children[2])
2652
2653
2654 def is_stub_body(node: LN) -> bool:
2655     """Return True if `node` is a simple statement containing an ellipsis."""
2656     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2657         return False
2658
2659     if len(node.children) != 2:
2660         return False
2661
2662     child = node.children[0]
2663     return (
2664         child.type == syms.atom
2665         and len(child.children) == 3
2666         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2667     )
2668
2669
2670 def max_delimiter_priority_in_atom(node: LN) -> int:
2671     """Return maximum delimiter priority inside `node`.
2672
2673     This is specific to atoms with contents contained in a pair of parentheses.
2674     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2675     """
2676     if node.type != syms.atom:
2677         return 0
2678
2679     first = node.children[0]
2680     last = node.children[-1]
2681     if not (first.type == token.LPAR and last.type == token.RPAR):
2682         return 0
2683
2684     bt = BracketTracker()
2685     for c in node.children[1:-1]:
2686         if isinstance(c, Leaf):
2687             bt.mark(c)
2688         else:
2689             for leaf in c.leaves():
2690                 bt.mark(leaf)
2691     try:
2692         return bt.max_delimiter_priority()
2693
2694     except ValueError:
2695         return 0
2696
2697
2698 def ensure_visible(leaf: Leaf) -> None:
2699     """Make sure parentheses are visible.
2700
2701     They could be invisible as part of some statements (see
2702     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2703     """
2704     if leaf.type == token.LPAR:
2705         leaf.value = "("
2706     elif leaf.type == token.RPAR:
2707         leaf.value = ")"
2708
2709
2710 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2711     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2712     if not (
2713         opening_bracket.parent
2714         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2715         and opening_bracket.value in "[{("
2716     ):
2717         return False
2718
2719     try:
2720         last_leaf = line.leaves[-1]
2721         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2722         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2723     except (IndexError, ValueError):
2724         return False
2725
2726     return max_priority == COMMA_PRIORITY
2727
2728
2729 def is_python36(node: Node) -> bool:
2730     """Return True if the current file is using Python 3.6+ features.
2731
2732     Currently looking for:
2733     - f-strings; and
2734     - trailing commas after * or ** in function signatures and calls.
2735     """
2736     for n in node.pre_order():
2737         if n.type == token.STRING:
2738             value_head = n.value[:2]  # type: ignore
2739             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2740                 return True
2741
2742         elif (
2743             n.type in {syms.typedargslist, syms.arglist}
2744             and n.children
2745             and n.children[-1].type == token.COMMA
2746         ):
2747             for ch in n.children:
2748                 if ch.type in STARS:
2749                     return True
2750
2751                 if ch.type == syms.argument:
2752                     for argch in ch.children:
2753                         if argch.type in STARS:
2754                             return True
2755
2756     return False
2757
2758
2759 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2760     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2761
2762     Brackets can be omitted if the entire trailer up to and including
2763     a preceding closing bracket fits in one line.
2764
2765     Yielded sets are cumulative (contain results of previous yields, too).  First
2766     set is empty.
2767     """
2768
2769     omit: Set[LeafID] = set()
2770     yield omit
2771
2772     length = 4 * line.depth
2773     opening_bracket = None
2774     closing_bracket = None
2775     optional_brackets: Set[LeafID] = set()
2776     inner_brackets: Set[LeafID] = set()
2777     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2778         length += leaf_length
2779         if length > line_length:
2780             break
2781
2782         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2783         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2784             break
2785
2786         optional_brackets.discard(id(leaf))
2787         if opening_bracket:
2788             if leaf is opening_bracket:
2789                 opening_bracket = None
2790             elif leaf.type in CLOSING_BRACKETS:
2791                 inner_brackets.add(id(leaf))
2792         elif leaf.type in CLOSING_BRACKETS:
2793             if not leaf.value:
2794                 optional_brackets.add(id(opening_bracket))
2795                 continue
2796
2797             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2798                 # Empty brackets would fail a split so treat them as "inner"
2799                 # brackets (e.g. only add them to the `omit` set if another
2800                 # pair of brackets was good enough.
2801                 inner_brackets.add(id(leaf))
2802                 continue
2803
2804             opening_bracket = leaf.opening_bracket
2805             if closing_bracket:
2806                 omit.add(id(closing_bracket))
2807                 omit.update(inner_brackets)
2808                 inner_brackets.clear()
2809                 yield omit
2810             closing_bracket = leaf
2811
2812
2813 def get_future_imports(node: Node) -> Set[str]:
2814     """Return a set of __future__ imports in the file."""
2815     imports = set()
2816     for child in node.children:
2817         if child.type != syms.simple_stmt:
2818             break
2819         first_child = child.children[0]
2820         if isinstance(first_child, Leaf):
2821             # Continue looking if we see a docstring; otherwise stop.
2822             if (
2823                 len(child.children) == 2
2824                 and first_child.type == token.STRING
2825                 and child.children[1].type == token.NEWLINE
2826             ):
2827                 continue
2828             else:
2829                 break
2830         elif first_child.type == syms.import_from:
2831             module_name = first_child.children[1]
2832             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2833                 break
2834             for import_from_child in first_child.children[3:]:
2835                 if isinstance(import_from_child, Leaf):
2836                     if import_from_child.type == token.NAME:
2837                         imports.add(import_from_child.value)
2838                 else:
2839                     assert import_from_child.type == syms.import_as_names
2840                     for leaf in import_from_child.children:
2841                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2842                             imports.add(leaf.value)
2843         else:
2844             break
2845     return imports
2846
2847
2848 def gen_python_files_in_dir(
2849     path: Path,
2850     root: Path,
2851     include: Pattern[str],
2852     exclude: Pattern[str],
2853     report: "Report",
2854 ) -> Iterator[Path]:
2855     """Generate all files under `path` whose paths are not excluded by the
2856     `exclude` regex, but are included by the `include` regex.
2857
2858     `report` is where output about exclusions goes.
2859     """
2860     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
2861     for child in path.iterdir():
2862         normalized_path = "/" + child.resolve().relative_to(root).as_posix()
2863         if child.is_dir():
2864             normalized_path += "/"
2865         exclude_match = exclude.search(normalized_path)
2866         if exclude_match and exclude_match.group(0):
2867             report.path_ignored(child, f"matches --exclude={exclude.pattern}")
2868             continue
2869
2870         if child.is_dir():
2871             yield from gen_python_files_in_dir(child, root, include, exclude, report)
2872
2873         elif child.is_file():
2874             include_match = include.search(normalized_path)
2875             if include_match:
2876                 yield child
2877
2878
2879 def find_project_root(srcs: List[str]) -> Path:
2880     """Return a directory containing .git, .hg, or pyproject.toml.
2881
2882     That directory can be one of the directories passed in `srcs` or their
2883     common parent.
2884
2885     If no directory in the tree contains a marker that would specify it's the
2886     project root, the root of the file system is returned.
2887     """
2888     if not srcs:
2889         return Path("/").resolve()
2890
2891     common_base = min(Path(src).resolve() for src in srcs)
2892     if common_base.is_dir():
2893         # Append a fake file so `parents` below returns `common_base_dir`, too.
2894         common_base /= "fake-file"
2895     for directory in common_base.parents:
2896         if (directory / ".git").is_dir():
2897             return directory
2898
2899         if (directory / ".hg").is_dir():
2900             return directory
2901
2902         if (directory / "pyproject.toml").is_file():
2903             return directory
2904
2905     return directory
2906
2907
2908 @dataclass
2909 class Report:
2910     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2911
2912     check: bool = False
2913     quiet: bool = False
2914     verbose: bool = False
2915     change_count: int = 0
2916     same_count: int = 0
2917     failure_count: int = 0
2918
2919     def done(self, src: Path, changed: Changed) -> None:
2920         """Increment the counter for successful reformatting. Write out a message."""
2921         if changed is Changed.YES:
2922             reformatted = "would reformat" if self.check else "reformatted"
2923             if self.verbose or not self.quiet:
2924                 out(f"{reformatted} {src}")
2925             self.change_count += 1
2926         else:
2927             if self.verbose:
2928                 if changed is Changed.NO:
2929                     msg = f"{src} already well formatted, good job."
2930                 else:
2931                     msg = f"{src} wasn't modified on disk since last run."
2932                 out(msg, bold=False)
2933             self.same_count += 1
2934
2935     def failed(self, src: Path, message: str) -> None:
2936         """Increment the counter for failed reformatting. Write out a message."""
2937         err(f"error: cannot format {src}: {message}")
2938         self.failure_count += 1
2939
2940     def path_ignored(self, path: Path, message: str) -> None:
2941         if self.verbose:
2942             out(f"{path} ignored: {message}", bold=False)
2943
2944     @property
2945     def return_code(self) -> int:
2946         """Return the exit code that the app should use.
2947
2948         This considers the current state of changed files and failures:
2949         - if there were any failures, return 123;
2950         - if any files were changed and --check is being used, return 1;
2951         - otherwise return 0.
2952         """
2953         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2954         # 126 we have special returncodes reserved by the shell.
2955         if self.failure_count:
2956             return 123
2957
2958         elif self.change_count and self.check:
2959             return 1
2960
2961         return 0
2962
2963     def __str__(self) -> str:
2964         """Render a color report of the current state.
2965
2966         Use `click.unstyle` to remove colors.
2967         """
2968         if self.check:
2969             reformatted = "would be reformatted"
2970             unchanged = "would be left unchanged"
2971             failed = "would fail to reformat"
2972         else:
2973             reformatted = "reformatted"
2974             unchanged = "left unchanged"
2975             failed = "failed to reformat"
2976         report = []
2977         if self.change_count:
2978             s = "s" if self.change_count > 1 else ""
2979             report.append(
2980                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2981             )
2982         if self.same_count:
2983             s = "s" if self.same_count > 1 else ""
2984             report.append(f"{self.same_count} file{s} {unchanged}")
2985         if self.failure_count:
2986             s = "s" if self.failure_count > 1 else ""
2987             report.append(
2988                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2989             )
2990         return ", ".join(report) + "."
2991
2992
2993 def assert_equivalent(src: str, dst: str) -> None:
2994     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2995
2996     import ast
2997     import traceback
2998
2999     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
3000         """Simple visitor generating strings to compare ASTs by content."""
3001         yield f"{'  ' * depth}{node.__class__.__name__}("
3002
3003         for field in sorted(node._fields):
3004             try:
3005                 value = getattr(node, field)
3006             except AttributeError:
3007                 continue
3008
3009             yield f"{'  ' * (depth+1)}{field}="
3010
3011             if isinstance(value, list):
3012                 for item in value:
3013                     if isinstance(item, ast.AST):
3014                         yield from _v(item, depth + 2)
3015
3016             elif isinstance(value, ast.AST):
3017                 yield from _v(value, depth + 2)
3018
3019             else:
3020                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3021
3022         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3023
3024     try:
3025         src_ast = ast.parse(src)
3026     except Exception as exc:
3027         major, minor = sys.version_info[:2]
3028         raise AssertionError(
3029             f"cannot use --safe with this file; failed to parse source file "
3030             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3031             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3032         )
3033
3034     try:
3035         dst_ast = ast.parse(dst)
3036     except Exception as exc:
3037         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3038         raise AssertionError(
3039             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3040             f"Please report a bug on https://github.com/ambv/black/issues.  "
3041             f"This invalid output might be helpful: {log}"
3042         ) from None
3043
3044     src_ast_str = "\n".join(_v(src_ast))
3045     dst_ast_str = "\n".join(_v(dst_ast))
3046     if src_ast_str != dst_ast_str:
3047         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3048         raise AssertionError(
3049             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3050             f"the source.  "
3051             f"Please report a bug on https://github.com/ambv/black/issues.  "
3052             f"This diff might be helpful: {log}"
3053         ) from None
3054
3055
3056 def assert_stable(
3057     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3058 ) -> None:
3059     """Raise AssertionError if `dst` reformats differently the second time."""
3060     newdst = format_str(dst, line_length=line_length, mode=mode)
3061     if dst != newdst:
3062         log = dump_to_file(
3063             diff(src, dst, "source", "first pass"),
3064             diff(dst, newdst, "first pass", "second pass"),
3065         )
3066         raise AssertionError(
3067             f"INTERNAL ERROR: Black produced different code on the second pass "
3068             f"of the formatter.  "
3069             f"Please report a bug on https://github.com/ambv/black/issues.  "
3070             f"This diff might be helpful: {log}"
3071         ) from None
3072
3073
3074 def dump_to_file(*output: str) -> str:
3075     """Dump `output` to a temporary file. Return path to the file."""
3076     import tempfile
3077
3078     with tempfile.NamedTemporaryFile(
3079         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3080     ) as f:
3081         for lines in output:
3082             f.write(lines)
3083             if lines and lines[-1] != "\n":
3084                 f.write("\n")
3085     return f.name
3086
3087
3088 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3089     """Return a unified diff string between strings `a` and `b`."""
3090     import difflib
3091
3092     a_lines = [line + "\n" for line in a.split("\n")]
3093     b_lines = [line + "\n" for line in b.split("\n")]
3094     return "".join(
3095         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3096     )
3097
3098
3099 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3100     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3101     err("Aborted!")
3102     for task in tasks:
3103         task.cancel()
3104
3105
3106 def shutdown(loop: BaseEventLoop) -> None:
3107     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3108     try:
3109         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3110         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3111         if not to_cancel:
3112             return
3113
3114         for task in to_cancel:
3115             task.cancel()
3116         loop.run_until_complete(
3117             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3118         )
3119     finally:
3120         # `concurrent.futures.Future` objects cannot be cancelled once they
3121         # are already running. There might be some when the `shutdown()` happened.
3122         # Silence their logger's spew about the event loop being closed.
3123         cf_logger = logging.getLogger("concurrent.futures")
3124         cf_logger.setLevel(logging.CRITICAL)
3125         loop.close()
3126
3127
3128 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3129     """Replace `regex` with `replacement` twice on `original`.
3130
3131     This is used by string normalization to perform replaces on
3132     overlapping matches.
3133     """
3134     return regex.sub(replacement, regex.sub(replacement, original))
3135
3136
3137 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3138     """Like `reversed(enumerate(sequence))` if that were possible."""
3139     index = len(sequence) - 1
3140     for element in reversed(sequence):
3141         yield (index, element)
3142         index -= 1
3143
3144
3145 def enumerate_with_length(
3146     line: Line, reversed: bool = False
3147 ) -> Iterator[Tuple[Index, Leaf, int]]:
3148     """Return an enumeration of leaves with their length.
3149
3150     Stops prematurely on multiline strings and standalone comments.
3151     """
3152     op = cast(
3153         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3154         enumerate_reversed if reversed else enumerate,
3155     )
3156     for index, leaf in op(line.leaves):
3157         length = len(leaf.prefix) + len(leaf.value)
3158         if "\n" in leaf.value:
3159             return  # Multiline strings, we can't continue.
3160
3161         comment: Optional[Leaf]
3162         for comment in line.comments_after(leaf, index):
3163             length += len(comment.value)
3164
3165         yield index, leaf, length
3166
3167
3168 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3169     """Return True if `line` is no longer than `line_length`.
3170
3171     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3172     """
3173     if not line_str:
3174         line_str = str(line).strip("\n")
3175     return (
3176         len(line_str) <= line_length
3177         and "\n" not in line_str  # multiline strings
3178         and not line.contains_standalone_comments()
3179     )
3180
3181
3182 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3183     """Does `line` have a shape safe to reformat without optional parens around it?
3184
3185     Returns True for only a subset of potentially nice looking formattings but
3186     the point is to not return false positives that end up producing lines that
3187     are too long.
3188     """
3189     bt = line.bracket_tracker
3190     if not bt.delimiters:
3191         # Without delimiters the optional parentheses are useless.
3192         return True
3193
3194     max_priority = bt.max_delimiter_priority()
3195     if bt.delimiter_count_with_priority(max_priority) > 1:
3196         # With more than one delimiter of a kind the optional parentheses read better.
3197         return False
3198
3199     if max_priority == DOT_PRIORITY:
3200         # A single stranded method call doesn't require optional parentheses.
3201         return True
3202
3203     assert len(line.leaves) >= 2, "Stranded delimiter"
3204
3205     first = line.leaves[0]
3206     second = line.leaves[1]
3207     penultimate = line.leaves[-2]
3208     last = line.leaves[-1]
3209
3210     # With a single delimiter, omit if the expression starts or ends with
3211     # a bracket.
3212     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3213         remainder = False
3214         length = 4 * line.depth
3215         for _index, leaf, leaf_length in enumerate_with_length(line):
3216             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3217                 remainder = True
3218             if remainder:
3219                 length += leaf_length
3220                 if length > line_length:
3221                     break
3222
3223                 if leaf.type in OPENING_BRACKETS:
3224                     # There are brackets we can further split on.
3225                     remainder = False
3226
3227         else:
3228             # checked the entire string and line length wasn't exceeded
3229             if len(line.leaves) == _index + 1:
3230                 return True
3231
3232         # Note: we are not returning False here because a line might have *both*
3233         # a leading opening bracket and a trailing closing bracket.  If the
3234         # opening bracket doesn't match our rule, maybe the closing will.
3235
3236     if (
3237         last.type == token.RPAR
3238         or last.type == token.RBRACE
3239         or (
3240             # don't use indexing for omitting optional parentheses;
3241             # it looks weird
3242             last.type == token.RSQB
3243             and last.parent
3244             and last.parent.type != syms.trailer
3245         )
3246     ):
3247         if penultimate.type in OPENING_BRACKETS:
3248             # Empty brackets don't help.
3249             return False
3250
3251         if is_multiline_string(first):
3252             # Additional wrapping of a multiline string in this situation is
3253             # unnecessary.
3254             return True
3255
3256         length = 4 * line.depth
3257         seen_other_brackets = False
3258         for _index, leaf, leaf_length in enumerate_with_length(line):
3259             length += leaf_length
3260             if leaf is last.opening_bracket:
3261                 if seen_other_brackets or length <= line_length:
3262                     return True
3263
3264             elif leaf.type in OPENING_BRACKETS:
3265                 # There are brackets we can further split on.
3266                 seen_other_brackets = True
3267
3268     return False
3269
3270
3271 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3272     pyi = bool(mode & FileMode.PYI)
3273     py36 = bool(mode & FileMode.PYTHON36)
3274     return (
3275         CACHE_DIR
3276         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3277     )
3278
3279
3280 def read_cache(line_length: int, mode: FileMode) -> Cache:
3281     """Read the cache if it exists and is well formed.
3282
3283     If it is not well formed, the call to write_cache later should resolve the issue.
3284     """
3285     cache_file = get_cache_file(line_length, mode)
3286     if not cache_file.exists():
3287         return {}
3288
3289     with cache_file.open("rb") as fobj:
3290         try:
3291             cache: Cache = pickle.load(fobj)
3292         except pickle.UnpicklingError:
3293             return {}
3294
3295     return cache
3296
3297
3298 def get_cache_info(path: Path) -> CacheInfo:
3299     """Return the information used to check if a file is already formatted or not."""
3300     stat = path.stat()
3301     return stat.st_mtime, stat.st_size
3302
3303
3304 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3305     """Split an iterable of paths in `sources` into two sets.
3306
3307     The first contains paths of files that modified on disk or are not in the
3308     cache. The other contains paths to non-modified files.
3309     """
3310     todo, done = set(), set()
3311     for src in sources:
3312         src = src.resolve()
3313         if cache.get(src) != get_cache_info(src):
3314             todo.add(src)
3315         else:
3316             done.add(src)
3317     return todo, done
3318
3319
3320 def write_cache(
3321     cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
3322 ) -> None:
3323     """Update the cache file."""
3324     cache_file = get_cache_file(line_length, mode)
3325     try:
3326         if not CACHE_DIR.exists():
3327             CACHE_DIR.mkdir(parents=True)
3328         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3329         with cache_file.open("wb") as fobj:
3330             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3331     except OSError:
3332         pass
3333
3334
3335 if __name__ == "__main__":
3336     main()