]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Nits
[etc/vim.git] / black.py
1 import asyncio
2 from asyncio.base_events import BaseEventLoop
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from enum import Enum, Flag
5 from functools import partial, wraps
6 import io
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import pickle
13 import re
14 import signal
15 import sys
16 import tokenize
17 from typing import (
18     Any,
19     Callable,
20     Collection,
21     Dict,
22     Generic,
23     Iterable,
24     Iterator,
25     List,
26     Optional,
27     Pattern,
28     Sequence,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34     cast,
35 )
36
37 from appdirs import user_cache_dir
38 from attr import dataclass, Factory
39 import click
40
41 # lib2to3 fork
42 from blib2to3.pytree import Node, Leaf, type_repr
43 from blib2to3 import pygram, pytree
44 from blib2to3.pgen2 import driver, token
45 from blib2to3.pgen2.parse import ParseError
46
47
48 __version__ = "18.5b1"
49 DEFAULT_LINE_LENGTH = 88
50 DEFAULT_EXCLUDES = (
51     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
52 )
53 DEFAULT_INCLUDES = r"\.pyi?$"
54 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
55
56
57 # types
58 FileContent = str
59 Encoding = str
60 NewLine = str
61 Depth = int
62 NodeType = int
63 LeafID = int
64 Priority = int
65 Index = int
66 LN = Union[Leaf, Node]
67 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
68 Timestamp = float
69 FileSize = int
70 CacheInfo = Tuple[Timestamp, FileSize]
71 Cache = Dict[Path, CacheInfo]
72 out = partial(click.secho, bold=True, err=True)
73 err = partial(click.secho, fg="red", err=True)
74
75 pygram.initialize(CACHE_DIR)
76 syms = pygram.python_symbols
77
78
79 class NothingChanged(UserWarning):
80     """Raised by :func:`format_file` when reformatted code is the same as source."""
81
82
83 class CannotSplit(Exception):
84     """A readable split that fits the allotted line length is impossible.
85
86     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
87     :func:`delimiter_split`.
88     """
89
90
91 class FormatError(Exception):
92     """Base exception for `# fmt: on` and `# fmt: off` handling.
93
94     It holds the number of bytes of the prefix consumed before the format
95     control comment appeared.
96     """
97
98     def __init__(self, consumed: int) -> None:
99         super().__init__(consumed)
100         self.consumed = consumed
101
102     def trim_prefix(self, leaf: Leaf) -> None:
103         leaf.prefix = leaf.prefix[self.consumed :]
104
105     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
106         """Returns a new Leaf from the consumed part of the prefix."""
107         unformatted_prefix = leaf.prefix[: self.consumed]
108         return Leaf(token.NEWLINE, unformatted_prefix)
109
110
111 class FormatOn(FormatError):
112     """Found a comment like `# fmt: on` in the file."""
113
114
115 class FormatOff(FormatError):
116     """Found a comment like `# fmt: off` in the file."""
117
118
119 class WriteBack(Enum):
120     NO = 0
121     YES = 1
122     DIFF = 2
123
124     @classmethod
125     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
126         if check and not diff:
127             return cls.NO
128
129         return cls.DIFF if diff else cls.YES
130
131
132 class Changed(Enum):
133     NO = 0
134     CACHED = 1
135     YES = 2
136
137
138 class FileMode(Flag):
139     AUTO_DETECT = 0
140     PYTHON36 = 1
141     PYI = 2
142     NO_STRING_NORMALIZATION = 4
143
144     @classmethod
145     def from_configuration(
146         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
147     ) -> "FileMode":
148         mode = cls.AUTO_DETECT
149         if py36:
150             mode |= cls.PYTHON36
151         if pyi:
152             mode |= cls.PYI
153         if skip_string_normalization:
154             mode |= cls.NO_STRING_NORMALIZATION
155         return mode
156
157
158 @click.command()
159 @click.option(
160     "-l",
161     "--line-length",
162     type=int,
163     default=DEFAULT_LINE_LENGTH,
164     help="How many character per line to allow.",
165     show_default=True,
166 )
167 @click.option(
168     "--py36",
169     is_flag=True,
170     help=(
171         "Allow using Python 3.6-only syntax on all input files.  This will put "
172         "trailing commas in function signatures and calls also after *args and "
173         "**kwargs.  [default: per-file auto-detection]"
174     ),
175 )
176 @click.option(
177     "--pyi",
178     is_flag=True,
179     help=(
180         "Format all input files like typing stubs regardless of file extension "
181         "(useful when piping source on standard input)."
182     ),
183 )
184 @click.option(
185     "-S",
186     "--skip-string-normalization",
187     is_flag=True,
188     help="Don't normalize string quotes or prefixes.",
189 )
190 @click.option(
191     "--check",
192     is_flag=True,
193     help=(
194         "Don't write the files back, just return the status.  Return code 0 "
195         "means nothing would change.  Return code 1 means some files would be "
196         "reformatted.  Return code 123 means there was an internal error."
197     ),
198 )
199 @click.option(
200     "--diff",
201     is_flag=True,
202     help="Don't write the files back, just output a diff for each file on stdout.",
203 )
204 @click.option(
205     "--fast/--safe",
206     is_flag=True,
207     help="If --fast given, skip temporary sanity checks. [default: --safe]",
208 )
209 @click.option(
210     "--include",
211     type=str,
212     default=DEFAULT_INCLUDES,
213     help=(
214         "A regular expression that matches files and directories that should be "
215         "included on recursive searches.  An empty value means all files are "
216         "included regardless of the name.  Use forward slashes for directories on "
217         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
218         "later."
219     ),
220     show_default=True,
221 )
222 @click.option(
223     "--exclude",
224     type=str,
225     default=DEFAULT_EXCLUDES,
226     help=(
227         "A regular expression that matches files and directories that should be "
228         "excluded on recursive searches.  An empty value means no paths are excluded. "
229         "Use forward slashes for directories on all platforms (Windows, too).  "
230         "Exclusions are calculated first, inclusions later."
231     ),
232     show_default=True,
233 )
234 @click.option(
235     "-q",
236     "--quiet",
237     is_flag=True,
238     help=(
239         "Don't emit non-error messages to stderr. Errors are still emitted, "
240         "silence those with 2>/dev/null."
241     ),
242 )
243 @click.option(
244     "-v",
245     "--verbose",
246     is_flag=True,
247     help=(
248         "Also emit messages to stderr about files that were not changed or were "
249         "ignored due to --exclude=."
250     ),
251 )
252 @click.version_option(version=__version__)
253 @click.argument(
254     "src",
255     nargs=-1,
256     type=click.Path(
257         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
258     ),
259 )
260 @click.pass_context
261 def main(
262     ctx: click.Context,
263     line_length: int,
264     check: bool,
265     diff: bool,
266     fast: bool,
267     pyi: bool,
268     py36: bool,
269     skip_string_normalization: bool,
270     quiet: bool,
271     verbose: bool,
272     include: str,
273     exclude: str,
274     src: List[str],
275 ) -> None:
276     """The uncompromising code formatter."""
277     write_back = WriteBack.from_configuration(check=check, diff=diff)
278     mode = FileMode.from_configuration(
279         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
280     )
281     report = Report(check=check, quiet=quiet, verbose=verbose)
282     sources: Set[Path] = set()
283     try:
284         include_regex = re.compile(include)
285     except re.error:
286         err(f"Invalid regular expression for include given: {include!r}")
287         ctx.exit(2)
288     try:
289         exclude_regex = re.compile(exclude)
290     except re.error:
291         err(f"Invalid regular expression for exclude given: {exclude!r}")
292         ctx.exit(2)
293     root = find_project_root(src)
294     for s in src:
295         p = Path(s)
296         if p.is_dir():
297             sources.update(
298                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
299             )
300         elif p.is_file() or s == "-":
301             # if a file was explicitly given, we don't care about its extension
302             sources.add(p)
303         else:
304             err(f"invalid path: {s}")
305     if len(sources) == 0:
306         if verbose or not quiet:
307             out("No paths given. Nothing to do 😴")
308         ctx.exit(0)
309         return
310
311     elif len(sources) == 1:
312         reformat_one(
313             src=sources.pop(),
314             line_length=line_length,
315             fast=fast,
316             write_back=write_back,
317             mode=mode,
318             report=report,
319         )
320     else:
321         loop = asyncio.get_event_loop()
322         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
323         try:
324             loop.run_until_complete(
325                 schedule_formatting(
326                     sources=sources,
327                     line_length=line_length,
328                     fast=fast,
329                     write_back=write_back,
330                     mode=mode,
331                     report=report,
332                     loop=loop,
333                     executor=executor,
334                 )
335             )
336         finally:
337             shutdown(loop)
338     if verbose or not quiet:
339         out("All done! ✨ 🍰 ✨")
340         click.echo(str(report))
341     ctx.exit(report.return_code)
342
343
344 def reformat_one(
345     src: Path,
346     line_length: int,
347     fast: bool,
348     write_back: WriteBack,
349     mode: FileMode,
350     report: "Report",
351 ) -> None:
352     """Reformat a single file under `src` without spawning child processes.
353
354     If `quiet` is True, non-error messages are not output. `line_length`,
355     `write_back`, `fast` and `pyi` options are passed to
356     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
357     """
358     try:
359         changed = Changed.NO
360         if not src.is_file() and str(src) == "-":
361             if format_stdin_to_stdout(
362                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
363             ):
364                 changed = Changed.YES
365         else:
366             cache: Cache = {}
367             if write_back != WriteBack.DIFF:
368                 cache = read_cache(line_length, mode)
369                 res_src = src.resolve()
370                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
371                     changed = Changed.CACHED
372             if changed is not Changed.CACHED and format_file_in_place(
373                 src,
374                 line_length=line_length,
375                 fast=fast,
376                 write_back=write_back,
377                 mode=mode,
378             ):
379                 changed = Changed.YES
380             if write_back == WriteBack.YES and changed is not Changed.NO:
381                 write_cache(cache, [src], line_length, mode)
382         report.done(src, changed)
383     except Exception as exc:
384         report.failed(src, str(exc))
385
386
387 async def schedule_formatting(
388     sources: Set[Path],
389     line_length: int,
390     fast: bool,
391     write_back: WriteBack,
392     mode: FileMode,
393     report: "Report",
394     loop: BaseEventLoop,
395     executor: Executor,
396 ) -> None:
397     """Run formatting of `sources` in parallel using the provided `executor`.
398
399     (Use ProcessPoolExecutors for actual parallelism.)
400
401     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
402     :func:`format_file_in_place`.
403     """
404     cache: Cache = {}
405     if write_back != WriteBack.DIFF:
406         cache = read_cache(line_length, mode)
407         sources, cached = filter_cached(cache, sources)
408         for src in sorted(cached):
409             report.done(src, Changed.CACHED)
410     cancelled = []
411     formatted = []
412     if sources:
413         lock = None
414         if write_back == WriteBack.DIFF:
415             # For diff output, we need locks to ensure we don't interleave output
416             # from different processes.
417             manager = Manager()
418             lock = manager.Lock()
419         tasks = {
420             loop.run_in_executor(
421                 executor,
422                 format_file_in_place,
423                 src,
424                 line_length,
425                 fast,
426                 write_back,
427                 mode,
428                 lock,
429             ): src
430             for src in sorted(sources)
431         }
432         pending: Iterable[asyncio.Task] = tasks.keys()
433         try:
434             loop.add_signal_handler(signal.SIGINT, cancel, pending)
435             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
436         except NotImplementedError:
437             # There are no good alternatives for these on Windows
438             pass
439         while pending:
440             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
441             for task in done:
442                 src = tasks.pop(task)
443                 if task.cancelled():
444                     cancelled.append(task)
445                 elif task.exception():
446                     report.failed(src, str(task.exception()))
447                 else:
448                     formatted.append(src)
449                     report.done(src, Changed.YES if task.result() else Changed.NO)
450     if cancelled:
451         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
452     if write_back == WriteBack.YES and formatted:
453         write_cache(cache, formatted, line_length, mode)
454
455
456 def format_file_in_place(
457     src: Path,
458     line_length: int,
459     fast: bool,
460     write_back: WriteBack = WriteBack.NO,
461     mode: FileMode = FileMode.AUTO_DETECT,
462     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
463 ) -> bool:
464     """Format file under `src` path. Return True if changed.
465
466     If `write_back` is True, write reformatted code back to stdout.
467     `line_length` and `fast` options are passed to :func:`format_file_contents`.
468     """
469     if src.suffix == ".pyi":
470         mode |= FileMode.PYI
471
472     with open(src, "rb") as buf:
473         src_contents, encoding, newline = decode_bytes(buf.read())
474     try:
475         dst_contents = format_file_contents(
476             src_contents, line_length=line_length, fast=fast, mode=mode
477         )
478     except NothingChanged:
479         return False
480
481     if write_back == write_back.YES:
482         with open(src, "w", encoding=encoding, newline=newline) as f:
483             f.write(dst_contents)
484     elif write_back == write_back.DIFF:
485         src_name = f"{src}  (original)"
486         dst_name = f"{src}  (formatted)"
487         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
488         if lock:
489             lock.acquire()
490         try:
491             f = io.TextIOWrapper(
492                 sys.stdout.buffer,
493                 encoding=encoding,
494                 newline=newline,
495                 write_through=True,
496             )
497             f.write(diff_contents)
498             f.detach()
499         finally:
500             if lock:
501                 lock.release()
502     return True
503
504
505 def format_stdin_to_stdout(
506     line_length: int,
507     fast: bool,
508     write_back: WriteBack = WriteBack.NO,
509     mode: FileMode = FileMode.AUTO_DETECT,
510 ) -> bool:
511     """Format file on stdin. Return True if changed.
512
513     If `write_back` is True, write reformatted code back to stdout.
514     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
515     :func:`format_file_contents`.
516     """
517     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
518     dst = src
519     try:
520         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
521         return True
522
523     except NothingChanged:
524         return False
525
526     finally:
527         f = io.TextIOWrapper(
528             sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
529         )
530         if write_back == WriteBack.YES:
531             f.write(dst)
532         elif write_back == WriteBack.DIFF:
533             src_name = "<stdin>  (original)"
534             dst_name = "<stdin>  (formatted)"
535             f.write(diff(src, dst, src_name, dst_name))
536         f.detach()
537
538
539 def format_file_contents(
540     src_contents: str,
541     *,
542     line_length: int,
543     fast: bool,
544     mode: FileMode = FileMode.AUTO_DETECT,
545 ) -> FileContent:
546     """Reformat contents a file and return new contents.
547
548     If `fast` is False, additionally confirm that the reformatted code is
549     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
550     `line_length` is passed to :func:`format_str`.
551     """
552     if src_contents.strip() == "":
553         raise NothingChanged
554
555     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
556     if src_contents == dst_contents:
557         raise NothingChanged
558
559     if not fast:
560         assert_equivalent(src_contents, dst_contents)
561         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
562     return dst_contents
563
564
565 def format_str(
566     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
567 ) -> FileContent:
568     """Reformat a string and return new contents.
569
570     `line_length` determines how many characters per line are allowed.
571     """
572     src_node = lib2to3_parse(src_contents)
573     dst_contents = ""
574     future_imports = get_future_imports(src_node)
575     is_pyi = bool(mode & FileMode.PYI)
576     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
577     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
578     lines = LineGenerator(
579         remove_u_prefix=py36 or "unicode_literals" in future_imports,
580         is_pyi=is_pyi,
581         normalize_strings=normalize_strings,
582     )
583     elt = EmptyLineTracker(is_pyi=is_pyi)
584     empty_line = Line()
585     after = 0
586     for current_line in lines.visit(src_node):
587         for _ in range(after):
588             dst_contents += str(empty_line)
589         before, after = elt.maybe_empty_lines(current_line)
590         for _ in range(before):
591             dst_contents += str(empty_line)
592         for line in split_line(current_line, line_length=line_length, py36=py36):
593             dst_contents += str(line)
594     return dst_contents
595
596
597 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
598     """Return a tuple of (decoded_contents, encoding, newline).
599
600     `newline` is either CRLF or LF but `decoded_contents` is decoded with
601     universal newlines (i.e. only contains LF).
602     """
603     srcbuf = io.BytesIO(src)
604     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
605     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
606     srcbuf.seek(0)
607     with io.TextIOWrapper(srcbuf, encoding) as tiow:
608         return tiow.read(), encoding, newline
609
610
611 GRAMMARS = [
612     pygram.python_grammar_no_print_statement_no_exec_statement,
613     pygram.python_grammar_no_print_statement,
614     pygram.python_grammar,
615 ]
616
617
618 def lib2to3_parse(src_txt: str) -> Node:
619     """Given a string with source, return the lib2to3 Node."""
620     grammar = pygram.python_grammar_no_print_statement
621     if src_txt[-1] != "\n":
622         src_txt += "\n"
623     for grammar in GRAMMARS:
624         drv = driver.Driver(grammar, pytree.convert)
625         try:
626             result = drv.parse_string(src_txt, True)
627             break
628
629         except ParseError as pe:
630             lineno, column = pe.context[1]
631             lines = src_txt.splitlines()
632             try:
633                 faulty_line = lines[lineno - 1]
634             except IndexError:
635                 faulty_line = "<line number missing in source>"
636             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
637     else:
638         raise exc from None
639
640     if isinstance(result, Leaf):
641         result = Node(syms.file_input, [result])
642     return result
643
644
645 def lib2to3_unparse(node: Node) -> str:
646     """Given a lib2to3 node, return its string representation."""
647     code = str(node)
648     return code
649
650
651 T = TypeVar("T")
652
653
654 class Visitor(Generic[T]):
655     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
656
657     def visit(self, node: LN) -> Iterator[T]:
658         """Main method to visit `node` and its children.
659
660         It tries to find a `visit_*()` method for the given `node.type`, like
661         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
662         If no dedicated `visit_*()` method is found, chooses `visit_default()`
663         instead.
664
665         Then yields objects of type `T` from the selected visitor.
666         """
667         if node.type < 256:
668             name = token.tok_name[node.type]
669         else:
670             name = type_repr(node.type)
671         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
672
673     def visit_default(self, node: LN) -> Iterator[T]:
674         """Default `visit_*()` implementation. Recurses to children of `node`."""
675         if isinstance(node, Node):
676             for child in node.children:
677                 yield from self.visit(child)
678
679
680 @dataclass
681 class DebugVisitor(Visitor[T]):
682     tree_depth: int = 0
683
684     def visit_default(self, node: LN) -> Iterator[T]:
685         indent = " " * (2 * self.tree_depth)
686         if isinstance(node, Node):
687             _type = type_repr(node.type)
688             out(f"{indent}{_type}", fg="yellow")
689             self.tree_depth += 1
690             for child in node.children:
691                 yield from self.visit(child)
692
693             self.tree_depth -= 1
694             out(f"{indent}/{_type}", fg="yellow", bold=False)
695         else:
696             _type = token.tok_name.get(node.type, str(node.type))
697             out(f"{indent}{_type}", fg="blue", nl=False)
698             if node.prefix:
699                 # We don't have to handle prefixes for `Node` objects since
700                 # that delegates to the first child anyway.
701                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
702             out(f" {node.value!r}", fg="blue", bold=False)
703
704     @classmethod
705     def show(cls, code: str) -> None:
706         """Pretty-print the lib2to3 AST of a given string of `code`.
707
708         Convenience method for debugging.
709         """
710         v: DebugVisitor[None] = DebugVisitor()
711         list(v.visit(lib2to3_parse(code)))
712
713
714 KEYWORDS = set(keyword.kwlist)
715 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
716 FLOW_CONTROL = {"return", "raise", "break", "continue"}
717 STATEMENT = {
718     syms.if_stmt,
719     syms.while_stmt,
720     syms.for_stmt,
721     syms.try_stmt,
722     syms.except_clause,
723     syms.with_stmt,
724     syms.funcdef,
725     syms.classdef,
726 }
727 STANDALONE_COMMENT = 153
728 LOGIC_OPERATORS = {"and", "or"}
729 COMPARATORS = {
730     token.LESS,
731     token.GREATER,
732     token.EQEQUAL,
733     token.NOTEQUAL,
734     token.LESSEQUAL,
735     token.GREATEREQUAL,
736 }
737 MATH_OPERATORS = {
738     token.VBAR,
739     token.CIRCUMFLEX,
740     token.AMPER,
741     token.LEFTSHIFT,
742     token.RIGHTSHIFT,
743     token.PLUS,
744     token.MINUS,
745     token.STAR,
746     token.SLASH,
747     token.DOUBLESLASH,
748     token.PERCENT,
749     token.AT,
750     token.TILDE,
751     token.DOUBLESTAR,
752 }
753 STARS = {token.STAR, token.DOUBLESTAR}
754 VARARGS_PARENTS = {
755     syms.arglist,
756     syms.argument,  # double star in arglist
757     syms.trailer,  # single argument to call
758     syms.typedargslist,
759     syms.varargslist,  # lambdas
760 }
761 UNPACKING_PARENTS = {
762     syms.atom,  # single element of a list or set literal
763     syms.dictsetmaker,
764     syms.listmaker,
765     syms.testlist_gexp,
766 }
767 TEST_DESCENDANTS = {
768     syms.test,
769     syms.lambdef,
770     syms.or_test,
771     syms.and_test,
772     syms.not_test,
773     syms.comparison,
774     syms.star_expr,
775     syms.expr,
776     syms.xor_expr,
777     syms.and_expr,
778     syms.shift_expr,
779     syms.arith_expr,
780     syms.trailer,
781     syms.term,
782     syms.power,
783 }
784 ASSIGNMENTS = {
785     "=",
786     "+=",
787     "-=",
788     "*=",
789     "@=",
790     "/=",
791     "%=",
792     "&=",
793     "|=",
794     "^=",
795     "<<=",
796     ">>=",
797     "**=",
798     "//=",
799 }
800 COMPREHENSION_PRIORITY = 20
801 COMMA_PRIORITY = 18
802 TERNARY_PRIORITY = 16
803 LOGIC_PRIORITY = 14
804 STRING_PRIORITY = 12
805 COMPARATOR_PRIORITY = 10
806 MATH_PRIORITIES = {
807     token.VBAR: 9,
808     token.CIRCUMFLEX: 8,
809     token.AMPER: 7,
810     token.LEFTSHIFT: 6,
811     token.RIGHTSHIFT: 6,
812     token.PLUS: 5,
813     token.MINUS: 5,
814     token.STAR: 4,
815     token.SLASH: 4,
816     token.DOUBLESLASH: 4,
817     token.PERCENT: 4,
818     token.AT: 4,
819     token.TILDE: 3,
820     token.DOUBLESTAR: 2,
821 }
822 DOT_PRIORITY = 1
823
824
825 @dataclass
826 class BracketTracker:
827     """Keeps track of brackets on a line."""
828
829     depth: int = 0
830     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
831     delimiters: Dict[LeafID, Priority] = Factory(dict)
832     previous: Optional[Leaf] = None
833     _for_loop_variable: int = 0
834     _lambda_arguments: int = 0
835
836     def mark(self, leaf: Leaf) -> None:
837         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
838
839         All leaves receive an int `bracket_depth` field that stores how deep
840         within brackets a given leaf is. 0 means there are no enclosing brackets
841         that started on this line.
842
843         If a leaf is itself a closing bracket, it receives an `opening_bracket`
844         field that it forms a pair with. This is a one-directional link to
845         avoid reference cycles.
846
847         If a leaf is a delimiter (a token on which Black can split the line if
848         needed) and it's on depth 0, its `id()` is stored in the tracker's
849         `delimiters` field.
850         """
851         if leaf.type == token.COMMENT:
852             return
853
854         self.maybe_decrement_after_for_loop_variable(leaf)
855         self.maybe_decrement_after_lambda_arguments(leaf)
856         if leaf.type in CLOSING_BRACKETS:
857             self.depth -= 1
858             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
859             leaf.opening_bracket = opening_bracket
860         leaf.bracket_depth = self.depth
861         if self.depth == 0:
862             delim = is_split_before_delimiter(leaf, self.previous)
863             if delim and self.previous is not None:
864                 self.delimiters[id(self.previous)] = delim
865             else:
866                 delim = is_split_after_delimiter(leaf, self.previous)
867                 if delim:
868                     self.delimiters[id(leaf)] = delim
869         if leaf.type in OPENING_BRACKETS:
870             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
871             self.depth += 1
872         self.previous = leaf
873         self.maybe_increment_lambda_arguments(leaf)
874         self.maybe_increment_for_loop_variable(leaf)
875
876     def any_open_brackets(self) -> bool:
877         """Return True if there is an yet unmatched open bracket on the line."""
878         return bool(self.bracket_match)
879
880     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
881         """Return the highest priority of a delimiter found on the line.
882
883         Values are consistent with what `is_split_*_delimiter()` return.
884         Raises ValueError on no delimiters.
885         """
886         return max(v for k, v in self.delimiters.items() if k not in exclude)
887
888     def delimiter_count_with_priority(self, priority: int = 0) -> int:
889         """Return the number of delimiters with the given `priority`.
890
891         If no `priority` is passed, defaults to max priority on the line.
892         """
893         if not self.delimiters:
894             return 0
895
896         priority = priority or self.max_delimiter_priority()
897         return sum(1 for p in self.delimiters.values() if p == priority)
898
899     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
900         """In a for loop, or comprehension, the variables are often unpacks.
901
902         To avoid splitting on the comma in this situation, increase the depth of
903         tokens between `for` and `in`.
904         """
905         if leaf.type == token.NAME and leaf.value == "for":
906             self.depth += 1
907             self._for_loop_variable += 1
908             return True
909
910         return False
911
912     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
913         """See `maybe_increment_for_loop_variable` above for explanation."""
914         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
915             self.depth -= 1
916             self._for_loop_variable -= 1
917             return True
918
919         return False
920
921     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
922         """In a lambda expression, there might be more than one argument.
923
924         To avoid splitting on the comma in this situation, increase the depth of
925         tokens between `lambda` and `:`.
926         """
927         if leaf.type == token.NAME and leaf.value == "lambda":
928             self.depth += 1
929             self._lambda_arguments += 1
930             return True
931
932         return False
933
934     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
935         """See `maybe_increment_lambda_arguments` above for explanation."""
936         if self._lambda_arguments and leaf.type == token.COLON:
937             self.depth -= 1
938             self._lambda_arguments -= 1
939             return True
940
941         return False
942
943     def get_open_lsqb(self) -> Optional[Leaf]:
944         """Return the most recent opening square bracket (if any)."""
945         return self.bracket_match.get((self.depth - 1, token.RSQB))
946
947
948 @dataclass
949 class Line:
950     """Holds leaves and comments. Can be printed with `str(line)`."""
951
952     depth: int = 0
953     leaves: List[Leaf] = Factory(list)
954     comments: List[Tuple[Index, Leaf]] = Factory(list)
955     bracket_tracker: BracketTracker = Factory(BracketTracker)
956     inside_brackets: bool = False
957     should_explode: bool = False
958
959     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
960         """Add a new `leaf` to the end of the line.
961
962         Unless `preformatted` is True, the `leaf` will receive a new consistent
963         whitespace prefix and metadata applied by :class:`BracketTracker`.
964         Trailing commas are maybe removed, unpacked for loop variables are
965         demoted from being delimiters.
966
967         Inline comments are put aside.
968         """
969         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
970         if not has_value:
971             return
972
973         if token.COLON == leaf.type and self.is_class_paren_empty:
974             del self.leaves[-2:]
975         if self.leaves and not preformatted:
976             # Note: at this point leaf.prefix should be empty except for
977             # imports, for which we only preserve newlines.
978             leaf.prefix += whitespace(
979                 leaf, complex_subscript=self.is_complex_subscript(leaf)
980             )
981         if self.inside_brackets or not preformatted:
982             self.bracket_tracker.mark(leaf)
983             self.maybe_remove_trailing_comma(leaf)
984         if not self.append_comment(leaf):
985             self.leaves.append(leaf)
986
987     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
988         """Like :func:`append()` but disallow invalid standalone comment structure.
989
990         Raises ValueError when any `leaf` is appended after a standalone comment
991         or when a standalone comment is not the first leaf on the line.
992         """
993         if self.bracket_tracker.depth == 0:
994             if self.is_comment:
995                 raise ValueError("cannot append to standalone comments")
996
997             if self.leaves and leaf.type == STANDALONE_COMMENT:
998                 raise ValueError(
999                     "cannot append standalone comments to a populated line"
1000                 )
1001
1002         self.append(leaf, preformatted=preformatted)
1003
1004     @property
1005     def is_comment(self) -> bool:
1006         """Is this line a standalone comment?"""
1007         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
1008
1009     @property
1010     def is_decorator(self) -> bool:
1011         """Is this line a decorator?"""
1012         return bool(self) and self.leaves[0].type == token.AT
1013
1014     @property
1015     def is_import(self) -> bool:
1016         """Is this an import line?"""
1017         return bool(self) and is_import(self.leaves[0])
1018
1019     @property
1020     def is_class(self) -> bool:
1021         """Is this line a class definition?"""
1022         return (
1023             bool(self)
1024             and self.leaves[0].type == token.NAME
1025             and self.leaves[0].value == "class"
1026         )
1027
1028     @property
1029     def is_stub_class(self) -> bool:
1030         """Is this line a class definition with a body consisting only of "..."?"""
1031         return self.is_class and self.leaves[-3:] == [
1032             Leaf(token.DOT, ".") for _ in range(3)
1033         ]
1034
1035     @property
1036     def is_def(self) -> bool:
1037         """Is this a function definition? (Also returns True for async defs.)"""
1038         try:
1039             first_leaf = self.leaves[0]
1040         except IndexError:
1041             return False
1042
1043         try:
1044             second_leaf: Optional[Leaf] = self.leaves[1]
1045         except IndexError:
1046             second_leaf = None
1047         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1048             first_leaf.type == token.ASYNC
1049             and second_leaf is not None
1050             and second_leaf.type == token.NAME
1051             and second_leaf.value == "def"
1052         )
1053
1054     @property
1055     def is_class_paren_empty(self) -> bool:
1056         """Is this a class with no base classes but using parentheses?
1057
1058         Those are unnecessary and should be removed.
1059         """
1060         return (
1061             bool(self)
1062             and len(self.leaves) == 4
1063             and self.is_class
1064             and self.leaves[2].type == token.LPAR
1065             and self.leaves[2].value == "("
1066             and self.leaves[3].type == token.RPAR
1067             and self.leaves[3].value == ")"
1068         )
1069
1070     @property
1071     def is_triple_quoted_string(self) -> bool:
1072         """Is the line a triple quoted string?"""
1073         return (
1074             bool(self)
1075             and self.leaves[0].type == token.STRING
1076             and self.leaves[0].value.startswith(('"""', "'''"))
1077         )
1078
1079     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1080         """If so, needs to be split before emitting."""
1081         for leaf in self.leaves:
1082             if leaf.type == STANDALONE_COMMENT:
1083                 if leaf.bracket_depth <= depth_limit:
1084                     return True
1085
1086         return False
1087
1088     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1089         """Remove trailing comma if there is one and it's safe."""
1090         if not (
1091             self.leaves
1092             and self.leaves[-1].type == token.COMMA
1093             and closing.type in CLOSING_BRACKETS
1094         ):
1095             return False
1096
1097         if closing.type == token.RBRACE:
1098             self.remove_trailing_comma()
1099             return True
1100
1101         if closing.type == token.RSQB:
1102             comma = self.leaves[-1]
1103             if comma.parent and comma.parent.type == syms.listmaker:
1104                 self.remove_trailing_comma()
1105                 return True
1106
1107         # For parens let's check if it's safe to remove the comma.
1108         # Imports are always safe.
1109         if self.is_import:
1110             self.remove_trailing_comma()
1111             return True
1112
1113         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1114         # change a tuple into a different type by removing the comma.
1115         depth = closing.bracket_depth + 1
1116         commas = 0
1117         opening = closing.opening_bracket
1118         for _opening_index, leaf in enumerate(self.leaves):
1119             if leaf is opening:
1120                 break
1121
1122         else:
1123             return False
1124
1125         for leaf in self.leaves[_opening_index + 1 :]:
1126             if leaf is closing:
1127                 break
1128
1129             bracket_depth = leaf.bracket_depth
1130             if bracket_depth == depth and leaf.type == token.COMMA:
1131                 commas += 1
1132                 if leaf.parent and leaf.parent.type == syms.arglist:
1133                     commas += 1
1134                     break
1135
1136         if commas > 1:
1137             self.remove_trailing_comma()
1138             return True
1139
1140         return False
1141
1142     def append_comment(self, comment: Leaf) -> bool:
1143         """Add an inline or standalone comment to the line."""
1144         if (
1145             comment.type == STANDALONE_COMMENT
1146             and self.bracket_tracker.any_open_brackets()
1147         ):
1148             comment.prefix = ""
1149             return False
1150
1151         if comment.type != token.COMMENT:
1152             return False
1153
1154         after = len(self.leaves) - 1
1155         if after == -1:
1156             comment.type = STANDALONE_COMMENT
1157             comment.prefix = ""
1158             return False
1159
1160         else:
1161             self.comments.append((after, comment))
1162             return True
1163
1164     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1165         """Generate comments that should appear directly after `leaf`.
1166
1167         Provide a non-negative leaf `_index` to speed up the function.
1168         """
1169         if _index == -1:
1170             for _index, _leaf in enumerate(self.leaves):
1171                 if leaf is _leaf:
1172                     break
1173
1174             else:
1175                 return
1176
1177         for index, comment_after in self.comments:
1178             if _index == index:
1179                 yield comment_after
1180
1181     def remove_trailing_comma(self) -> None:
1182         """Remove the trailing comma and moves the comments attached to it."""
1183         comma_index = len(self.leaves) - 1
1184         for i in range(len(self.comments)):
1185             comment_index, comment = self.comments[i]
1186             if comment_index == comma_index:
1187                 self.comments[i] = (comma_index - 1, comment)
1188         self.leaves.pop()
1189
1190     def is_complex_subscript(self, leaf: Leaf) -> bool:
1191         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1192         open_lsqb = (
1193             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1194         )
1195         if open_lsqb is None:
1196             return False
1197
1198         subscript_start = open_lsqb.next_sibling
1199         if (
1200             isinstance(subscript_start, Node)
1201             and subscript_start.type == syms.subscriptlist
1202         ):
1203             subscript_start = child_towards(subscript_start, leaf)
1204         return subscript_start is not None and any(
1205             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1206         )
1207
1208     def __str__(self) -> str:
1209         """Render the line."""
1210         if not self:
1211             return "\n"
1212
1213         indent = "    " * self.depth
1214         leaves = iter(self.leaves)
1215         first = next(leaves)
1216         res = f"{first.prefix}{indent}{first.value}"
1217         for leaf in leaves:
1218             res += str(leaf)
1219         for _, comment in self.comments:
1220             res += str(comment)
1221         return res + "\n"
1222
1223     def __bool__(self) -> bool:
1224         """Return True if the line has leaves or comments."""
1225         return bool(self.leaves or self.comments)
1226
1227
1228 class UnformattedLines(Line):
1229     """Just like :class:`Line` but stores lines which aren't reformatted."""
1230
1231     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1232         """Just add a new `leaf` to the end of the lines.
1233
1234         The `preformatted` argument is ignored.
1235
1236         Keeps track of indentation `depth`, which is useful when the user
1237         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1238         """
1239         try:
1240             list(generate_comments(leaf))
1241         except FormatOn as f_on:
1242             self.leaves.append(f_on.leaf_from_consumed(leaf))
1243             raise
1244
1245         self.leaves.append(leaf)
1246         if leaf.type == token.INDENT:
1247             self.depth += 1
1248         elif leaf.type == token.DEDENT:
1249             self.depth -= 1
1250
1251     def __str__(self) -> str:
1252         """Render unformatted lines from leaves which were added with `append()`.
1253
1254         `depth` is not used for indentation in this case.
1255         """
1256         if not self:
1257             return "\n"
1258
1259         res = ""
1260         for leaf in self.leaves:
1261             res += str(leaf)
1262         return res
1263
1264     def append_comment(self, comment: Leaf) -> bool:
1265         """Not implemented in this class. Raises `NotImplementedError`."""
1266         raise NotImplementedError("Unformatted lines don't store comments separately.")
1267
1268     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1269         """Does nothing and returns False."""
1270         return False
1271
1272     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1273         """Does nothing and returns False."""
1274         return False
1275
1276
1277 @dataclass
1278 class EmptyLineTracker:
1279     """Provides a stateful method that returns the number of potential extra
1280     empty lines needed before and after the currently processed line.
1281
1282     Note: this tracker works on lines that haven't been split yet.  It assumes
1283     the prefix of the first leaf consists of optional newlines.  Those newlines
1284     are consumed by `maybe_empty_lines()` and included in the computation.
1285     """
1286
1287     is_pyi: bool = False
1288     previous_line: Optional[Line] = None
1289     previous_after: int = 0
1290     previous_defs: List[int] = Factory(list)
1291
1292     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1293         """Return the number of extra empty lines before and after the `current_line`.
1294
1295         This is for separating `def`, `async def` and `class` with extra empty
1296         lines (two on module-level).
1297         """
1298         if isinstance(current_line, UnformattedLines):
1299             return 0, 0
1300
1301         before, after = self._maybe_empty_lines(current_line)
1302         before -= self.previous_after
1303         self.previous_after = after
1304         self.previous_line = current_line
1305         return before, after
1306
1307     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1308         max_allowed = 1
1309         if current_line.depth == 0:
1310             max_allowed = 1 if self.is_pyi else 2
1311         if current_line.leaves:
1312             # Consume the first leaf's extra newlines.
1313             first_leaf = current_line.leaves[0]
1314             before = first_leaf.prefix.count("\n")
1315             before = min(before, max_allowed)
1316             first_leaf.prefix = ""
1317         else:
1318             before = 0
1319         depth = current_line.depth
1320         while self.previous_defs and self.previous_defs[-1] >= depth:
1321             self.previous_defs.pop()
1322             if self.is_pyi:
1323                 before = 0 if depth else 1
1324             else:
1325                 before = 1 if depth else 2
1326         is_decorator = current_line.is_decorator
1327         if is_decorator or current_line.is_def or current_line.is_class:
1328             if not is_decorator:
1329                 self.previous_defs.append(depth)
1330             if self.previous_line is None:
1331                 # Don't insert empty lines before the first line in the file.
1332                 return 0, 0
1333
1334             if self.previous_line.is_decorator:
1335                 return 0, 0
1336
1337             if self.previous_line.depth < current_line.depth and (
1338                 self.previous_line.is_class or self.previous_line.is_def
1339             ):
1340                 return 0, 0
1341
1342             if (
1343                 self.previous_line.is_comment
1344                 and self.previous_line.depth == current_line.depth
1345                 and before == 0
1346             ):
1347                 return 0, 0
1348
1349             if self.is_pyi:
1350                 if self.previous_line.depth > current_line.depth:
1351                     newlines = 1
1352                 elif current_line.is_class or self.previous_line.is_class:
1353                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1354                         newlines = 0
1355                     else:
1356                         newlines = 1
1357                 else:
1358                     newlines = 0
1359             else:
1360                 newlines = 2
1361             if current_line.depth and newlines:
1362                 newlines -= 1
1363             return newlines, 0
1364
1365         if (
1366             self.previous_line
1367             and self.previous_line.is_import
1368             and not current_line.is_import
1369             and depth == self.previous_line.depth
1370         ):
1371             return (before or 1), 0
1372
1373         if (
1374             self.previous_line
1375             and self.previous_line.is_class
1376             and current_line.is_triple_quoted_string
1377         ):
1378             return before, 1
1379
1380         return before, 0
1381
1382
1383 @dataclass
1384 class LineGenerator(Visitor[Line]):
1385     """Generates reformatted Line objects.  Empty lines are not emitted.
1386
1387     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1388     in ways that will no longer stringify to valid Python code on the tree.
1389     """
1390
1391     is_pyi: bool = False
1392     normalize_strings: bool = True
1393     current_line: Line = Factory(Line)
1394     remove_u_prefix: bool = False
1395
1396     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1397         """Generate a line.
1398
1399         If the line is empty, only emit if it makes sense.
1400         If the line is too long, split it first and then generate.
1401
1402         If any lines were generated, set up a new current_line.
1403         """
1404         if not self.current_line:
1405             if self.current_line.__class__ == type:
1406                 self.current_line.depth += indent
1407             else:
1408                 self.current_line = type(depth=self.current_line.depth + indent)
1409             return  # Line is empty, don't emit. Creating a new one unnecessary.
1410
1411         complete_line = self.current_line
1412         self.current_line = type(depth=complete_line.depth + indent)
1413         yield complete_line
1414
1415     def visit(self, node: LN) -> Iterator[Line]:
1416         """Main method to visit `node` and its children.
1417
1418         Yields :class:`Line` objects.
1419         """
1420         if isinstance(self.current_line, UnformattedLines):
1421             # File contained `# fmt: off`
1422             yield from self.visit_unformatted(node)
1423
1424         else:
1425             yield from super().visit(node)
1426
1427     def visit_default(self, node: LN) -> Iterator[Line]:
1428         """Default `visit_*()` implementation. Recurses to children of `node`."""
1429         if isinstance(node, Leaf):
1430             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1431             try:
1432                 for comment in generate_comments(node):
1433                     if any_open_brackets:
1434                         # any comment within brackets is subject to splitting
1435                         self.current_line.append(comment)
1436                     elif comment.type == token.COMMENT:
1437                         # regular trailing comment
1438                         self.current_line.append(comment)
1439                         yield from self.line()
1440
1441                     else:
1442                         # regular standalone comment
1443                         yield from self.line()
1444
1445                         self.current_line.append(comment)
1446                         yield from self.line()
1447
1448             except FormatOff as f_off:
1449                 f_off.trim_prefix(node)
1450                 yield from self.line(type=UnformattedLines)
1451                 yield from self.visit(node)
1452
1453             except FormatOn as f_on:
1454                 # This only happens here if somebody says "fmt: on" multiple
1455                 # times in a row.
1456                 f_on.trim_prefix(node)
1457                 yield from self.visit_default(node)
1458
1459             else:
1460                 normalize_prefix(node, inside_brackets=any_open_brackets)
1461                 if self.normalize_strings and node.type == token.STRING:
1462                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1463                     normalize_string_quotes(node)
1464                 if node.type not in WHITESPACE:
1465                     self.current_line.append(node)
1466         yield from super().visit_default(node)
1467
1468     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1469         """Increase indentation level, maybe yield a line."""
1470         # In blib2to3 INDENT never holds comments.
1471         yield from self.line(+1)
1472         yield from self.visit_default(node)
1473
1474     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1475         """Decrease indentation level, maybe yield a line."""
1476         # The current line might still wait for trailing comments.  At DEDENT time
1477         # there won't be any (they would be prefixes on the preceding NEWLINE).
1478         # Emit the line then.
1479         yield from self.line()
1480
1481         # While DEDENT has no value, its prefix may contain standalone comments
1482         # that belong to the current indentation level.  Get 'em.
1483         yield from self.visit_default(node)
1484
1485         # Finally, emit the dedent.
1486         yield from self.line(-1)
1487
1488     def visit_stmt(
1489         self, node: Node, keywords: Set[str], parens: Set[str]
1490     ) -> Iterator[Line]:
1491         """Visit a statement.
1492
1493         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1494         `def`, `with`, `class`, `assert` and assignments.
1495
1496         The relevant Python language `keywords` for a given statement will be
1497         NAME leaves within it. This methods puts those on a separate line.
1498
1499         `parens` holds a set of string leaf values immediately after which
1500         invisible parens should be put.
1501         """
1502         normalize_invisible_parens(node, parens_after=parens)
1503         for child in node.children:
1504             if child.type == token.NAME and child.value in keywords:  # type: ignore
1505                 yield from self.line()
1506
1507             yield from self.visit(child)
1508
1509     def visit_suite(self, node: Node) -> Iterator[Line]:
1510         """Visit a suite."""
1511         if self.is_pyi and is_stub_suite(node):
1512             yield from self.visit(node.children[2])
1513         else:
1514             yield from self.visit_default(node)
1515
1516     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1517         """Visit a statement without nested statements."""
1518         is_suite_like = node.parent and node.parent.type in STATEMENT
1519         if is_suite_like:
1520             if self.is_pyi and is_stub_body(node):
1521                 yield from self.visit_default(node)
1522             else:
1523                 yield from self.line(+1)
1524                 yield from self.visit_default(node)
1525                 yield from self.line(-1)
1526
1527         else:
1528             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1529                 yield from self.line()
1530             yield from self.visit_default(node)
1531
1532     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1533         """Visit `async def`, `async for`, `async with`."""
1534         yield from self.line()
1535
1536         children = iter(node.children)
1537         for child in children:
1538             yield from self.visit(child)
1539
1540             if child.type == token.ASYNC:
1541                 break
1542
1543         internal_stmt = next(children)
1544         for child in internal_stmt.children:
1545             yield from self.visit(child)
1546
1547     def visit_decorators(self, node: Node) -> Iterator[Line]:
1548         """Visit decorators."""
1549         for child in node.children:
1550             yield from self.line()
1551             yield from self.visit(child)
1552
1553     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1554         """Remove a semicolon and put the other statement on a separate line."""
1555         yield from self.line()
1556
1557     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1558         """End of file. Process outstanding comments and end with a newline."""
1559         yield from self.visit_default(leaf)
1560         yield from self.line()
1561
1562     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1563         """Used when file contained a `# fmt: off`."""
1564         if isinstance(node, Node):
1565             for child in node.children:
1566                 yield from self.visit(child)
1567
1568         else:
1569             try:
1570                 self.current_line.append(node)
1571             except FormatOn as f_on:
1572                 f_on.trim_prefix(node)
1573                 yield from self.line()
1574                 yield from self.visit(node)
1575
1576             if node.type == token.ENDMARKER:
1577                 # somebody decided not to put a final `# fmt: on`
1578                 yield from self.line()
1579
1580     def __attrs_post_init__(self) -> None:
1581         """You are in a twisty little maze of passages."""
1582         v = self.visit_stmt
1583         Ø: Set[str] = set()
1584         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1585         self.visit_if_stmt = partial(
1586             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1587         )
1588         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1589         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1590         self.visit_try_stmt = partial(
1591             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1592         )
1593         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1594         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1595         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1596         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1597         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1598         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1599         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1600         self.visit_async_funcdef = self.visit_async_stmt
1601         self.visit_decorated = self.visit_decorators
1602
1603
1604 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1605 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1606 OPENING_BRACKETS = set(BRACKET.keys())
1607 CLOSING_BRACKETS = set(BRACKET.values())
1608 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1609 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1610
1611
1612 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1613     """Return whitespace prefix if needed for the given `leaf`.
1614
1615     `complex_subscript` signals whether the given leaf is part of a subscription
1616     which has non-trivial arguments, like arithmetic expressions or function calls.
1617     """
1618     NO = ""
1619     SPACE = " "
1620     DOUBLESPACE = "  "
1621     t = leaf.type
1622     p = leaf.parent
1623     v = leaf.value
1624     if t in ALWAYS_NO_SPACE:
1625         return NO
1626
1627     if t == token.COMMENT:
1628         return DOUBLESPACE
1629
1630     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1631     if t == token.COLON and p.type not in {
1632         syms.subscript,
1633         syms.subscriptlist,
1634         syms.sliceop,
1635     }:
1636         return NO
1637
1638     prev = leaf.prev_sibling
1639     if not prev:
1640         prevp = preceding_leaf(p)
1641         if not prevp or prevp.type in OPENING_BRACKETS:
1642             return NO
1643
1644         if t == token.COLON:
1645             if prevp.type == token.COLON:
1646                 return NO
1647
1648             elif prevp.type != token.COMMA and not complex_subscript:
1649                 return NO
1650
1651             return SPACE
1652
1653         if prevp.type == token.EQUAL:
1654             if prevp.parent:
1655                 if prevp.parent.type in {
1656                     syms.arglist,
1657                     syms.argument,
1658                     syms.parameters,
1659                     syms.varargslist,
1660                 }:
1661                     return NO
1662
1663                 elif prevp.parent.type == syms.typedargslist:
1664                     # A bit hacky: if the equal sign has whitespace, it means we
1665                     # previously found it's a typed argument.  So, we're using
1666                     # that, too.
1667                     return prevp.prefix
1668
1669         elif prevp.type in STARS:
1670             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1671                 return NO
1672
1673         elif prevp.type == token.COLON:
1674             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1675                 return SPACE if complex_subscript else NO
1676
1677         elif (
1678             prevp.parent
1679             and prevp.parent.type == syms.factor
1680             and prevp.type in MATH_OPERATORS
1681         ):
1682             return NO
1683
1684         elif (
1685             prevp.type == token.RIGHTSHIFT
1686             and prevp.parent
1687             and prevp.parent.type == syms.shift_expr
1688             and prevp.prev_sibling
1689             and prevp.prev_sibling.type == token.NAME
1690             and prevp.prev_sibling.value == "print"  # type: ignore
1691         ):
1692             # Python 2 print chevron
1693             return NO
1694
1695     elif prev.type in OPENING_BRACKETS:
1696         return NO
1697
1698     if p.type in {syms.parameters, syms.arglist}:
1699         # untyped function signatures or calls
1700         if not prev or prev.type != token.COMMA:
1701             return NO
1702
1703     elif p.type == syms.varargslist:
1704         # lambdas
1705         if prev and prev.type != token.COMMA:
1706             return NO
1707
1708     elif p.type == syms.typedargslist:
1709         # typed function signatures
1710         if not prev:
1711             return NO
1712
1713         if t == token.EQUAL:
1714             if prev.type != syms.tname:
1715                 return NO
1716
1717         elif prev.type == token.EQUAL:
1718             # A bit hacky: if the equal sign has whitespace, it means we
1719             # previously found it's a typed argument.  So, we're using that, too.
1720             return prev.prefix
1721
1722         elif prev.type != token.COMMA:
1723             return NO
1724
1725     elif p.type == syms.tname:
1726         # type names
1727         if not prev:
1728             prevp = preceding_leaf(p)
1729             if not prevp or prevp.type != token.COMMA:
1730                 return NO
1731
1732     elif p.type == syms.trailer:
1733         # attributes and calls
1734         if t == token.LPAR or t == token.RPAR:
1735             return NO
1736
1737         if not prev:
1738             if t == token.DOT:
1739                 prevp = preceding_leaf(p)
1740                 if not prevp or prevp.type != token.NUMBER:
1741                     return NO
1742
1743             elif t == token.LSQB:
1744                 return NO
1745
1746         elif prev.type != token.COMMA:
1747             return NO
1748
1749     elif p.type == syms.argument:
1750         # single argument
1751         if t == token.EQUAL:
1752             return NO
1753
1754         if not prev:
1755             prevp = preceding_leaf(p)
1756             if not prevp or prevp.type == token.LPAR:
1757                 return NO
1758
1759         elif prev.type in {token.EQUAL} | STARS:
1760             return NO
1761
1762     elif p.type == syms.decorator:
1763         # decorators
1764         return NO
1765
1766     elif p.type == syms.dotted_name:
1767         if prev:
1768             return NO
1769
1770         prevp = preceding_leaf(p)
1771         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1772             return NO
1773
1774     elif p.type == syms.classdef:
1775         if t == token.LPAR:
1776             return NO
1777
1778         if prev and prev.type == token.LPAR:
1779             return NO
1780
1781     elif p.type in {syms.subscript, syms.sliceop}:
1782         # indexing
1783         if not prev:
1784             assert p.parent is not None, "subscripts are always parented"
1785             if p.parent.type == syms.subscriptlist:
1786                 return SPACE
1787
1788             return NO
1789
1790         elif not complex_subscript:
1791             return NO
1792
1793     elif p.type == syms.atom:
1794         if prev and t == token.DOT:
1795             # dots, but not the first one.
1796             return NO
1797
1798     elif p.type == syms.dictsetmaker:
1799         # dict unpacking
1800         if prev and prev.type == token.DOUBLESTAR:
1801             return NO
1802
1803     elif p.type in {syms.factor, syms.star_expr}:
1804         # unary ops
1805         if not prev:
1806             prevp = preceding_leaf(p)
1807             if not prevp or prevp.type in OPENING_BRACKETS:
1808                 return NO
1809
1810             prevp_parent = prevp.parent
1811             assert prevp_parent is not None
1812             if prevp.type == token.COLON and prevp_parent.type in {
1813                 syms.subscript,
1814                 syms.sliceop,
1815             }:
1816                 return NO
1817
1818             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1819                 return NO
1820
1821         elif t == token.NAME or t == token.NUMBER:
1822             return NO
1823
1824     elif p.type == syms.import_from:
1825         if t == token.DOT:
1826             if prev and prev.type == token.DOT:
1827                 return NO
1828
1829         elif t == token.NAME:
1830             if v == "import":
1831                 return SPACE
1832
1833             if prev and prev.type == token.DOT:
1834                 return NO
1835
1836     elif p.type == syms.sliceop:
1837         return NO
1838
1839     return SPACE
1840
1841
1842 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1843     """Return the first leaf that precedes `node`, if any."""
1844     while node:
1845         res = node.prev_sibling
1846         if res:
1847             if isinstance(res, Leaf):
1848                 return res
1849
1850             try:
1851                 return list(res.leaves())[-1]
1852
1853             except IndexError:
1854                 return None
1855
1856         node = node.parent
1857     return None
1858
1859
1860 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1861     """Return the child of `ancestor` that contains `descendant`."""
1862     node: Optional[LN] = descendant
1863     while node and node.parent != ancestor:
1864         node = node.parent
1865     return node
1866
1867
1868 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1869     """Return the priority of the `leaf` delimiter, given a line break after it.
1870
1871     The delimiter priorities returned here are from those delimiters that would
1872     cause a line break after themselves.
1873
1874     Higher numbers are higher priority.
1875     """
1876     if leaf.type == token.COMMA:
1877         return COMMA_PRIORITY
1878
1879     return 0
1880
1881
1882 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1883     """Return the priority of the `leaf` delimiter, given a line before after it.
1884
1885     The delimiter priorities returned here are from those delimiters that would
1886     cause a line break before themselves.
1887
1888     Higher numbers are higher priority.
1889     """
1890     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1891         # * and ** might also be MATH_OPERATORS but in this case they are not.
1892         # Don't treat them as a delimiter.
1893         return 0
1894
1895     if (
1896         leaf.type == token.DOT
1897         and leaf.parent
1898         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1899         and (previous is None or previous.type in CLOSING_BRACKETS)
1900     ):
1901         return DOT_PRIORITY
1902
1903     if (
1904         leaf.type in MATH_OPERATORS
1905         and leaf.parent
1906         and leaf.parent.type not in {syms.factor, syms.star_expr}
1907     ):
1908         return MATH_PRIORITIES[leaf.type]
1909
1910     if leaf.type in COMPARATORS:
1911         return COMPARATOR_PRIORITY
1912
1913     if (
1914         leaf.type == token.STRING
1915         and previous is not None
1916         and previous.type == token.STRING
1917     ):
1918         return STRING_PRIORITY
1919
1920     if leaf.type != token.NAME:
1921         return 0
1922
1923     if (
1924         leaf.value == "for"
1925         and leaf.parent
1926         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1927     ):
1928         return COMPREHENSION_PRIORITY
1929
1930     if (
1931         leaf.value == "if"
1932         and leaf.parent
1933         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1934     ):
1935         return COMPREHENSION_PRIORITY
1936
1937     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1938         return TERNARY_PRIORITY
1939
1940     if leaf.value == "is":
1941         return COMPARATOR_PRIORITY
1942
1943     if (
1944         leaf.value == "in"
1945         and leaf.parent
1946         and leaf.parent.type in {syms.comp_op, syms.comparison}
1947         and not (
1948             previous is not None
1949             and previous.type == token.NAME
1950             and previous.value == "not"
1951         )
1952     ):
1953         return COMPARATOR_PRIORITY
1954
1955     if (
1956         leaf.value == "not"
1957         and leaf.parent
1958         and leaf.parent.type == syms.comp_op
1959         and not (
1960             previous is not None
1961             and previous.type == token.NAME
1962             and previous.value == "is"
1963         )
1964     ):
1965         return COMPARATOR_PRIORITY
1966
1967     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1968         return LOGIC_PRIORITY
1969
1970     return 0
1971
1972
1973 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1974     """Clean the prefix of the `leaf` and generate comments from it, if any.
1975
1976     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1977     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1978     move because it does away with modifying the grammar to include all the
1979     possible places in which comments can be placed.
1980
1981     The sad consequence for us though is that comments don't "belong" anywhere.
1982     This is why this function generates simple parentless Leaf objects for
1983     comments.  We simply don't know what the correct parent should be.
1984
1985     No matter though, we can live without this.  We really only need to
1986     differentiate between inline and standalone comments.  The latter don't
1987     share the line with any code.
1988
1989     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1990     are emitted with a fake STANDALONE_COMMENT token identifier.
1991     """
1992     p = leaf.prefix
1993     if not p:
1994         return
1995
1996     if "#" not in p:
1997         return
1998
1999     consumed = 0
2000     nlines = 0
2001     for index, line in enumerate(p.split("\n")):
2002         consumed += len(line) + 1  # adding the length of the split '\n'
2003         line = line.lstrip()
2004         if not line:
2005             nlines += 1
2006         if not line.startswith("#"):
2007             continue
2008
2009         if index == 0 and leaf.type != token.ENDMARKER:
2010             comment_type = token.COMMENT  # simple trailing comment
2011         else:
2012             comment_type = STANDALONE_COMMENT
2013         comment = make_comment(line)
2014         yield Leaf(comment_type, comment, prefix="\n" * nlines)
2015
2016         if comment in {"# fmt: on", "# yapf: enable"}:
2017             raise FormatOn(consumed)
2018
2019         if comment in {"# fmt: off", "# yapf: disable"}:
2020             if comment_type == STANDALONE_COMMENT:
2021                 raise FormatOff(consumed)
2022
2023             prev = preceding_leaf(leaf)
2024             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
2025                 raise FormatOff(consumed)
2026
2027         nlines = 0
2028
2029
2030 def make_comment(content: str) -> str:
2031     """Return a consistently formatted comment from the given `content` string.
2032
2033     All comments (except for "##", "#!", "#:") should have a single space between
2034     the hash sign and the content.
2035
2036     If `content` didn't start with a hash sign, one is provided.
2037     """
2038     content = content.rstrip()
2039     if not content:
2040         return "#"
2041
2042     if content[0] == "#":
2043         content = content[1:]
2044     if content and content[0] not in " !:#":
2045         content = " " + content
2046     return "#" + content
2047
2048
2049 def split_line(
2050     line: Line, line_length: int, inner: bool = False, py36: bool = False
2051 ) -> Iterator[Line]:
2052     """Split a `line` into potentially many lines.
2053
2054     They should fit in the allotted `line_length` but might not be able to.
2055     `inner` signifies that there were a pair of brackets somewhere around the
2056     current `line`, possibly transitively. This means we can fallback to splitting
2057     by delimiters if the LHS/RHS don't yield any results.
2058
2059     If `py36` is True, splitting may generate syntax that is only compatible
2060     with Python 3.6 and later.
2061     """
2062     if isinstance(line, UnformattedLines) or line.is_comment:
2063         yield line
2064         return
2065
2066     line_str = str(line).strip("\n")
2067     if not line.should_explode and is_line_short_enough(
2068         line, line_length=line_length, line_str=line_str
2069     ):
2070         yield line
2071         return
2072
2073     split_funcs: List[SplitFunc]
2074     if line.is_def:
2075         split_funcs = [left_hand_split]
2076     else:
2077
2078         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2079             for omit in generate_trailers_to_omit(line, line_length):
2080                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2081                 if is_line_short_enough(lines[0], line_length=line_length):
2082                     yield from lines
2083                     return
2084
2085             # All splits failed, best effort split with no omits.
2086             # This mostly happens to multiline strings that are by definition
2087             # reported as not fitting a single line.
2088             yield from right_hand_split(line, py36)
2089
2090         if line.inside_brackets:
2091             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2092         else:
2093             split_funcs = [rhs]
2094     for split_func in split_funcs:
2095         # We are accumulating lines in `result` because we might want to abort
2096         # mission and return the original line in the end, or attempt a different
2097         # split altogether.
2098         result: List[Line] = []
2099         try:
2100             for l in split_func(line, py36):
2101                 if str(l).strip("\n") == line_str:
2102                     raise CannotSplit("Split function returned an unchanged result")
2103
2104                 result.extend(
2105                     split_line(l, line_length=line_length, inner=True, py36=py36)
2106                 )
2107         except CannotSplit as cs:
2108             continue
2109
2110         else:
2111             yield from result
2112             break
2113
2114     else:
2115         yield line
2116
2117
2118 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2119     """Split line into many lines, starting with the first matching bracket pair.
2120
2121     Note: this usually looks weird, only use this for function definitions.
2122     Prefer RHS otherwise.  This is why this function is not symmetrical with
2123     :func:`right_hand_split` which also handles optional parentheses.
2124     """
2125     head = Line(depth=line.depth)
2126     body = Line(depth=line.depth + 1, inside_brackets=True)
2127     tail = Line(depth=line.depth)
2128     tail_leaves: List[Leaf] = []
2129     body_leaves: List[Leaf] = []
2130     head_leaves: List[Leaf] = []
2131     current_leaves = head_leaves
2132     matching_bracket = None
2133     for leaf in line.leaves:
2134         if (
2135             current_leaves is body_leaves
2136             and leaf.type in CLOSING_BRACKETS
2137             and leaf.opening_bracket is matching_bracket
2138         ):
2139             current_leaves = tail_leaves if body_leaves else head_leaves
2140         current_leaves.append(leaf)
2141         if current_leaves is head_leaves:
2142             if leaf.type in OPENING_BRACKETS:
2143                 matching_bracket = leaf
2144                 current_leaves = body_leaves
2145     # Since body is a new indent level, remove spurious leading whitespace.
2146     if body_leaves:
2147         normalize_prefix(body_leaves[0], inside_brackets=True)
2148     # Build the new lines.
2149     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2150         for leaf in leaves:
2151             result.append(leaf, preformatted=True)
2152             for comment_after in line.comments_after(leaf):
2153                 result.append(comment_after, preformatted=True)
2154     bracket_split_succeeded_or_raise(head, body, tail)
2155     for result in (head, body, tail):
2156         if result:
2157             yield result
2158
2159
2160 def right_hand_split(
2161     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2162 ) -> Iterator[Line]:
2163     """Split line into many lines, starting with the last matching bracket pair.
2164
2165     If the split was by optional parentheses, attempt splitting without them, too.
2166     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2167     this split.
2168
2169     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2170     """
2171     head = Line(depth=line.depth)
2172     body = Line(depth=line.depth + 1, inside_brackets=True)
2173     tail = Line(depth=line.depth)
2174     tail_leaves: List[Leaf] = []
2175     body_leaves: List[Leaf] = []
2176     head_leaves: List[Leaf] = []
2177     current_leaves = tail_leaves
2178     opening_bracket = None
2179     closing_bracket = None
2180     for leaf in reversed(line.leaves):
2181         if current_leaves is body_leaves:
2182             if leaf is opening_bracket:
2183                 current_leaves = head_leaves if body_leaves else tail_leaves
2184         current_leaves.append(leaf)
2185         if current_leaves is tail_leaves:
2186             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2187                 opening_bracket = leaf.opening_bracket
2188                 closing_bracket = leaf
2189                 current_leaves = body_leaves
2190     tail_leaves.reverse()
2191     body_leaves.reverse()
2192     head_leaves.reverse()
2193     # Since body is a new indent level, remove spurious leading whitespace.
2194     if body_leaves:
2195         normalize_prefix(body_leaves[0], inside_brackets=True)
2196     if not head_leaves:
2197         # No `head` means the split failed. Either `tail` has all content or
2198         # the matching `opening_bracket` wasn't available on `line` anymore.
2199         raise CannotSplit("No brackets found")
2200
2201     # Build the new lines.
2202     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2203         for leaf in leaves:
2204             result.append(leaf, preformatted=True)
2205             for comment_after in line.comments_after(leaf):
2206                 result.append(comment_after, preformatted=True)
2207     bracket_split_succeeded_or_raise(head, body, tail)
2208     assert opening_bracket and closing_bracket
2209     if (
2210         # the opening bracket is an optional paren
2211         opening_bracket.type == token.LPAR
2212         and not opening_bracket.value
2213         # the closing bracket is an optional paren
2214         and closing_bracket.type == token.RPAR
2215         and not closing_bracket.value
2216         # there are no standalone comments in the body
2217         and not line.contains_standalone_comments(0)
2218         # and it's not an import (optional parens are the only thing we can split
2219         # on in this case; attempting a split without them is a waste of time)
2220         and not line.is_import
2221     ):
2222         omit = {id(closing_bracket), *omit}
2223         if can_omit_invisible_parens(body, line_length):
2224             try:
2225                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2226                 return
2227             except CannotSplit:
2228                 pass
2229
2230     ensure_visible(opening_bracket)
2231     ensure_visible(closing_bracket)
2232     body.should_explode = should_explode(body, opening_bracket)
2233     for result in (head, body, tail):
2234         if result:
2235             yield result
2236
2237
2238 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2239     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2240
2241     Do nothing otherwise.
2242
2243     A left- or right-hand split is based on a pair of brackets. Content before
2244     (and including) the opening bracket is left on one line, content inside the
2245     brackets is put on a separate line, and finally content starting with and
2246     following the closing bracket is put on a separate line.
2247
2248     Those are called `head`, `body`, and `tail`, respectively. If the split
2249     produced the same line (all content in `head`) or ended up with an empty `body`
2250     and the `tail` is just the closing bracket, then it's considered failed.
2251     """
2252     tail_len = len(str(tail).strip())
2253     if not body:
2254         if tail_len == 0:
2255             raise CannotSplit("Splitting brackets produced the same line")
2256
2257         elif tail_len < 3:
2258             raise CannotSplit(
2259                 f"Splitting brackets on an empty body to save "
2260                 f"{tail_len} characters is not worth it"
2261             )
2262
2263
2264 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2265     """Normalize prefix of the first leaf in every line returned by `split_func`.
2266
2267     This is a decorator over relevant split functions.
2268     """
2269
2270     @wraps(split_func)
2271     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2272         for l in split_func(line, py36):
2273             normalize_prefix(l.leaves[0], inside_brackets=True)
2274             yield l
2275
2276     return split_wrapper
2277
2278
2279 @dont_increase_indentation
2280 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2281     """Split according to delimiters of the highest priority.
2282
2283     If `py36` is True, the split will add trailing commas also in function
2284     signatures that contain `*` and `**`.
2285     """
2286     try:
2287         last_leaf = line.leaves[-1]
2288     except IndexError:
2289         raise CannotSplit("Line empty")
2290
2291     bt = line.bracket_tracker
2292     try:
2293         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2294     except ValueError:
2295         raise CannotSplit("No delimiters found")
2296
2297     if delimiter_priority == DOT_PRIORITY:
2298         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2299             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2300
2301     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2302     lowest_depth = sys.maxsize
2303     trailing_comma_safe = True
2304
2305     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2306         """Append `leaf` to current line or to new line if appending impossible."""
2307         nonlocal current_line
2308         try:
2309             current_line.append_safe(leaf, preformatted=True)
2310         except ValueError as ve:
2311             yield current_line
2312
2313             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2314             current_line.append(leaf)
2315
2316     for index, leaf in enumerate(line.leaves):
2317         yield from append_to_line(leaf)
2318
2319         for comment_after in line.comments_after(leaf, index):
2320             yield from append_to_line(comment_after)
2321
2322         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2323         if leaf.bracket_depth == lowest_depth and is_vararg(
2324             leaf, within=VARARGS_PARENTS
2325         ):
2326             trailing_comma_safe = trailing_comma_safe and py36
2327         leaf_priority = bt.delimiters.get(id(leaf))
2328         if leaf_priority == delimiter_priority:
2329             yield current_line
2330
2331             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2332     if current_line:
2333         if (
2334             trailing_comma_safe
2335             and delimiter_priority == COMMA_PRIORITY
2336             and current_line.leaves[-1].type != token.COMMA
2337             and current_line.leaves[-1].type != STANDALONE_COMMENT
2338         ):
2339             current_line.append(Leaf(token.COMMA, ","))
2340         yield current_line
2341
2342
2343 @dont_increase_indentation
2344 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2345     """Split standalone comments from the rest of the line."""
2346     if not line.contains_standalone_comments(0):
2347         raise CannotSplit("Line does not have any standalone comments")
2348
2349     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2350
2351     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2352         """Append `leaf` to current line or to new line if appending impossible."""
2353         nonlocal current_line
2354         try:
2355             current_line.append_safe(leaf, preformatted=True)
2356         except ValueError as ve:
2357             yield current_line
2358
2359             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2360             current_line.append(leaf)
2361
2362     for index, leaf in enumerate(line.leaves):
2363         yield from append_to_line(leaf)
2364
2365         for comment_after in line.comments_after(leaf, index):
2366             yield from append_to_line(comment_after)
2367
2368     if current_line:
2369         yield current_line
2370
2371
2372 def is_import(leaf: Leaf) -> bool:
2373     """Return True if the given leaf starts an import statement."""
2374     p = leaf.parent
2375     t = leaf.type
2376     v = leaf.value
2377     return bool(
2378         t == token.NAME
2379         and (
2380             (v == "import" and p and p.type == syms.import_name)
2381             or (v == "from" and p and p.type == syms.import_from)
2382         )
2383     )
2384
2385
2386 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2387     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2388     else.
2389
2390     Note: don't use backslashes for formatting or you'll lose your voting rights.
2391     """
2392     if not inside_brackets:
2393         spl = leaf.prefix.split("#")
2394         if "\\" not in spl[0]:
2395             nl_count = spl[-1].count("\n")
2396             if len(spl) > 1:
2397                 nl_count -= 1
2398             leaf.prefix = "\n" * nl_count
2399             return
2400
2401     leaf.prefix = ""
2402
2403
2404 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2405     """Make all string prefixes lowercase.
2406
2407     If remove_u_prefix is given, also removes any u prefix from the string.
2408
2409     Note: Mutates its argument.
2410     """
2411     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2412     assert match is not None, f"failed to match string {leaf.value!r}"
2413     orig_prefix = match.group(1)
2414     new_prefix = orig_prefix.lower()
2415     if remove_u_prefix:
2416         new_prefix = new_prefix.replace("u", "")
2417     leaf.value = f"{new_prefix}{match.group(2)}"
2418
2419
2420 def normalize_string_quotes(leaf: Leaf) -> None:
2421     """Prefer double quotes but only if it doesn't cause more escaping.
2422
2423     Adds or removes backslashes as appropriate. Doesn't parse and fix
2424     strings nested in f-strings (yet).
2425
2426     Note: Mutates its argument.
2427     """
2428     value = leaf.value.lstrip("furbFURB")
2429     if value[:3] == '"""':
2430         return
2431
2432     elif value[:3] == "'''":
2433         orig_quote = "'''"
2434         new_quote = '"""'
2435     elif value[0] == '"':
2436         orig_quote = '"'
2437         new_quote = "'"
2438     else:
2439         orig_quote = "'"
2440         new_quote = '"'
2441     first_quote_pos = leaf.value.find(orig_quote)
2442     if first_quote_pos == -1:
2443         return  # There's an internal error
2444
2445     prefix = leaf.value[:first_quote_pos]
2446     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2447     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2448     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2449     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2450     if "r" in prefix.casefold():
2451         if unescaped_new_quote.search(body):
2452             # There's at least one unescaped new_quote in this raw string
2453             # so converting is impossible
2454             return
2455
2456         # Do not introduce or remove backslashes in raw strings
2457         new_body = body
2458     else:
2459         # remove unnecessary quotes
2460         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2461         if body != new_body:
2462             # Consider the string without unnecessary quotes as the original
2463             body = new_body
2464             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2465         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2466         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2467     if new_quote == '"""' and new_body[-1] == '"':
2468         # edge case:
2469         new_body = new_body[:-1] + '\\"'
2470     orig_escape_count = body.count("\\")
2471     new_escape_count = new_body.count("\\")
2472     if new_escape_count > orig_escape_count:
2473         return  # Do not introduce more escaping
2474
2475     if new_escape_count == orig_escape_count and orig_quote == '"':
2476         return  # Prefer double quotes
2477
2478     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2479
2480
2481 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2482     """Make existing optional parentheses invisible or create new ones.
2483
2484     `parens_after` is a set of string leaf values immeditely after which parens
2485     should be put.
2486
2487     Standardizes on visible parentheses for single-element tuples, and keeps
2488     existing visible parentheses for other tuples and generator expressions.
2489     """
2490     try:
2491         list(generate_comments(node))
2492     except FormatOff:
2493         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2494
2495     check_lpar = False
2496     for index, child in enumerate(list(node.children)):
2497         if check_lpar:
2498             if child.type == syms.atom:
2499                 maybe_make_parens_invisible_in_atom(child)
2500             elif is_one_tuple(child):
2501                 # wrap child in visible parentheses
2502                 lpar = Leaf(token.LPAR, "(")
2503                 rpar = Leaf(token.RPAR, ")")
2504                 child.remove()
2505                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2506             elif node.type == syms.import_from:
2507                 # "import from" nodes store parentheses directly as part of
2508                 # the statement
2509                 if child.type == token.LPAR:
2510                     # make parentheses invisible
2511                     child.value = ""  # type: ignore
2512                     node.children[-1].value = ""  # type: ignore
2513                 elif child.type != token.STAR:
2514                     # insert invisible parentheses
2515                     node.insert_child(index, Leaf(token.LPAR, ""))
2516                     node.append_child(Leaf(token.RPAR, ""))
2517                 break
2518
2519             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2520                 # wrap child in invisible parentheses
2521                 lpar = Leaf(token.LPAR, "")
2522                 rpar = Leaf(token.RPAR, "")
2523                 index = child.remove() or 0
2524                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2525
2526         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2527
2528
2529 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2530     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2531     if (
2532         node.type != syms.atom
2533         or is_empty_tuple(node)
2534         or is_one_tuple(node)
2535         or is_yield(node)
2536         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2537     ):
2538         return False
2539
2540     first = node.children[0]
2541     last = node.children[-1]
2542     if first.type == token.LPAR and last.type == token.RPAR:
2543         # make parentheses invisible
2544         first.value = ""  # type: ignore
2545         last.value = ""  # type: ignore
2546         if len(node.children) > 1:
2547             maybe_make_parens_invisible_in_atom(node.children[1])
2548         return True
2549
2550     return False
2551
2552
2553 def is_empty_tuple(node: LN) -> bool:
2554     """Return True if `node` holds an empty tuple."""
2555     return (
2556         node.type == syms.atom
2557         and len(node.children) == 2
2558         and node.children[0].type == token.LPAR
2559         and node.children[1].type == token.RPAR
2560     )
2561
2562
2563 def is_one_tuple(node: LN) -> bool:
2564     """Return True if `node` holds a tuple with one element, with or without parens."""
2565     if node.type == syms.atom:
2566         if len(node.children) != 3:
2567             return False
2568
2569         lpar, gexp, rpar = node.children
2570         if not (
2571             lpar.type == token.LPAR
2572             and gexp.type == syms.testlist_gexp
2573             and rpar.type == token.RPAR
2574         ):
2575             return False
2576
2577         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2578
2579     return (
2580         node.type in IMPLICIT_TUPLE
2581         and len(node.children) == 2
2582         and node.children[1].type == token.COMMA
2583     )
2584
2585
2586 def is_yield(node: LN) -> bool:
2587     """Return True if `node` holds a `yield` or `yield from` expression."""
2588     if node.type == syms.yield_expr:
2589         return True
2590
2591     if node.type == token.NAME and node.value == "yield":  # type: ignore
2592         return True
2593
2594     if node.type != syms.atom:
2595         return False
2596
2597     if len(node.children) != 3:
2598         return False
2599
2600     lpar, expr, rpar = node.children
2601     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2602         return is_yield(expr)
2603
2604     return False
2605
2606
2607 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2608     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2609
2610     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2611     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2612     extended iterable unpacking (PEP 3132) and additional unpacking
2613     generalizations (PEP 448).
2614     """
2615     if leaf.type not in STARS or not leaf.parent:
2616         return False
2617
2618     p = leaf.parent
2619     if p.type == syms.star_expr:
2620         # Star expressions are also used as assignment targets in extended
2621         # iterable unpacking (PEP 3132).  See what its parent is instead.
2622         if not p.parent:
2623             return False
2624
2625         p = p.parent
2626
2627     return p.type in within
2628
2629
2630 def is_multiline_string(leaf: Leaf) -> bool:
2631     """Return True if `leaf` is a multiline string that actually spans many lines."""
2632     value = leaf.value.lstrip("furbFURB")
2633     return value[:3] in {'"""', "'''"} and "\n" in value
2634
2635
2636 def is_stub_suite(node: Node) -> bool:
2637     """Return True if `node` is a suite with a stub body."""
2638     if (
2639         len(node.children) != 4
2640         or node.children[0].type != token.NEWLINE
2641         or node.children[1].type != token.INDENT
2642         or node.children[3].type != token.DEDENT
2643     ):
2644         return False
2645
2646     return is_stub_body(node.children[2])
2647
2648
2649 def is_stub_body(node: LN) -> bool:
2650     """Return True if `node` is a simple statement containing an ellipsis."""
2651     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2652         return False
2653
2654     if len(node.children) != 2:
2655         return False
2656
2657     child = node.children[0]
2658     return (
2659         child.type == syms.atom
2660         and len(child.children) == 3
2661         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2662     )
2663
2664
2665 def max_delimiter_priority_in_atom(node: LN) -> int:
2666     """Return maximum delimiter priority inside `node`.
2667
2668     This is specific to atoms with contents contained in a pair of parentheses.
2669     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2670     """
2671     if node.type != syms.atom:
2672         return 0
2673
2674     first = node.children[0]
2675     last = node.children[-1]
2676     if not (first.type == token.LPAR and last.type == token.RPAR):
2677         return 0
2678
2679     bt = BracketTracker()
2680     for c in node.children[1:-1]:
2681         if isinstance(c, Leaf):
2682             bt.mark(c)
2683         else:
2684             for leaf in c.leaves():
2685                 bt.mark(leaf)
2686     try:
2687         return bt.max_delimiter_priority()
2688
2689     except ValueError:
2690         return 0
2691
2692
2693 def ensure_visible(leaf: Leaf) -> None:
2694     """Make sure parentheses are visible.
2695
2696     They could be invisible as part of some statements (see
2697     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2698     """
2699     if leaf.type == token.LPAR:
2700         leaf.value = "("
2701     elif leaf.type == token.RPAR:
2702         leaf.value = ")"
2703
2704
2705 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2706     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2707     if not (
2708         opening_bracket.parent
2709         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2710         and opening_bracket.value in "[{("
2711     ):
2712         return False
2713
2714     try:
2715         last_leaf = line.leaves[-1]
2716         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2717         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2718     except (IndexError, ValueError):
2719         return False
2720
2721     return max_priority == COMMA_PRIORITY
2722
2723
2724 def is_python36(node: Node) -> bool:
2725     """Return True if the current file is using Python 3.6+ features.
2726
2727     Currently looking for:
2728     - f-strings; and
2729     - trailing commas after * or ** in function signatures and calls.
2730     """
2731     for n in node.pre_order():
2732         if n.type == token.STRING:
2733             value_head = n.value[:2]  # type: ignore
2734             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2735                 return True
2736
2737         elif (
2738             n.type in {syms.typedargslist, syms.arglist}
2739             and n.children
2740             and n.children[-1].type == token.COMMA
2741         ):
2742             for ch in n.children:
2743                 if ch.type in STARS:
2744                     return True
2745
2746                 if ch.type == syms.argument:
2747                     for argch in ch.children:
2748                         if argch.type in STARS:
2749                             return True
2750
2751     return False
2752
2753
2754 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2755     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2756
2757     Brackets can be omitted if the entire trailer up to and including
2758     a preceding closing bracket fits in one line.
2759
2760     Yielded sets are cumulative (contain results of previous yields, too).  First
2761     set is empty.
2762     """
2763
2764     omit: Set[LeafID] = set()
2765     yield omit
2766
2767     length = 4 * line.depth
2768     opening_bracket = None
2769     closing_bracket = None
2770     optional_brackets: Set[LeafID] = set()
2771     inner_brackets: Set[LeafID] = set()
2772     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2773         length += leaf_length
2774         if length > line_length:
2775             break
2776
2777         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2778         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2779             break
2780
2781         optional_brackets.discard(id(leaf))
2782         if opening_bracket:
2783             if leaf is opening_bracket:
2784                 opening_bracket = None
2785             elif leaf.type in CLOSING_BRACKETS:
2786                 inner_brackets.add(id(leaf))
2787         elif leaf.type in CLOSING_BRACKETS:
2788             if not leaf.value:
2789                 optional_brackets.add(id(opening_bracket))
2790                 continue
2791
2792             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2793                 # Empty brackets would fail a split so treat them as "inner"
2794                 # brackets (e.g. only add them to the `omit` set if another
2795                 # pair of brackets was good enough.
2796                 inner_brackets.add(id(leaf))
2797                 continue
2798
2799             opening_bracket = leaf.opening_bracket
2800             if closing_bracket:
2801                 omit.add(id(closing_bracket))
2802                 omit.update(inner_brackets)
2803                 inner_brackets.clear()
2804                 yield omit
2805             closing_bracket = leaf
2806
2807
2808 def get_future_imports(node: Node) -> Set[str]:
2809     """Return a set of __future__ imports in the file."""
2810     imports = set()
2811     for child in node.children:
2812         if child.type != syms.simple_stmt:
2813             break
2814         first_child = child.children[0]
2815         if isinstance(first_child, Leaf):
2816             # Continue looking if we see a docstring; otherwise stop.
2817             if (
2818                 len(child.children) == 2
2819                 and first_child.type == token.STRING
2820                 and child.children[1].type == token.NEWLINE
2821             ):
2822                 continue
2823             else:
2824                 break
2825         elif first_child.type == syms.import_from:
2826             module_name = first_child.children[1]
2827             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2828                 break
2829             for import_from_child in first_child.children[3:]:
2830                 if isinstance(import_from_child, Leaf):
2831                     if import_from_child.type == token.NAME:
2832                         imports.add(import_from_child.value)
2833                 else:
2834                     assert import_from_child.type == syms.import_as_names
2835                     for leaf in import_from_child.children:
2836                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2837                             imports.add(leaf.value)
2838         else:
2839             break
2840     return imports
2841
2842
2843 def gen_python_files_in_dir(
2844     path: Path,
2845     root: Path,
2846     include: Pattern[str],
2847     exclude: Pattern[str],
2848     report: "Report",
2849 ) -> Iterator[Path]:
2850     """Generate all files under `path` whose paths are not excluded by the
2851     `exclude` regex, but are included by the `include` regex.
2852
2853     `report` is where output about exclusions goes.
2854     """
2855     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
2856     for child in path.iterdir():
2857         normalized_path = "/" + child.resolve().relative_to(root).as_posix()
2858         if child.is_dir():
2859             normalized_path += "/"
2860         exclude_match = exclude.search(normalized_path)
2861         if exclude_match and exclude_match.group(0):
2862             report.path_ignored(child, f"matches --exclude={exclude.pattern}")
2863             continue
2864
2865         if child.is_dir():
2866             yield from gen_python_files_in_dir(child, root, include, exclude, report)
2867
2868         elif child.is_file():
2869             include_match = include.search(normalized_path)
2870             if include_match:
2871                 yield child
2872
2873
2874 def find_project_root(srcs: List[str]) -> Path:
2875     """Return a directory containing .git, .hg, or pyproject.toml.
2876
2877     That directory can be one of the directories passed in `srcs` or their
2878     common parent.
2879
2880     If no directory in the tree contains a marker that would specify it's the
2881     project root, the root of the file system is returned.
2882     """
2883     if not srcs:
2884         return Path("/").resolve()
2885
2886     common_base = min(Path(src).resolve() for src in srcs)
2887     if common_base.is_dir():
2888         # Append a fake file so `parents` below returns `common_base_dir`, too.
2889         common_base /= "fake-file"
2890     for directory in common_base.parents:
2891         if (directory / ".git").is_dir():
2892             return directory
2893
2894         if (directory / ".hg").is_dir():
2895             return directory
2896
2897         if (directory / "pyproject.toml").is_file():
2898             return directory
2899
2900     return directory
2901
2902
2903 @dataclass
2904 class Report:
2905     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2906
2907     check: bool = False
2908     quiet: bool = False
2909     verbose: bool = False
2910     change_count: int = 0
2911     same_count: int = 0
2912     failure_count: int = 0
2913
2914     def done(self, src: Path, changed: Changed) -> None:
2915         """Increment the counter for successful reformatting. Write out a message."""
2916         if changed is Changed.YES:
2917             reformatted = "would reformat" if self.check else "reformatted"
2918             if self.verbose or not self.quiet:
2919                 out(f"{reformatted} {src}")
2920             self.change_count += 1
2921         else:
2922             if self.verbose:
2923                 if changed is Changed.NO:
2924                     msg = f"{src} already well formatted, good job."
2925                 else:
2926                     msg = f"{src} wasn't modified on disk since last run."
2927                 out(msg, bold=False)
2928             self.same_count += 1
2929
2930     def failed(self, src: Path, message: str) -> None:
2931         """Increment the counter for failed reformatting. Write out a message."""
2932         err(f"error: cannot format {src}: {message}")
2933         self.failure_count += 1
2934
2935     def path_ignored(self, path: Path, message: str) -> None:
2936         if self.verbose:
2937             out(f"{path} ignored: {message}", bold=False)
2938
2939     @property
2940     def return_code(self) -> int:
2941         """Return the exit code that the app should use.
2942
2943         This considers the current state of changed files and failures:
2944         - if there were any failures, return 123;
2945         - if any files were changed and --check is being used, return 1;
2946         - otherwise return 0.
2947         """
2948         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2949         # 126 we have special returncodes reserved by the shell.
2950         if self.failure_count:
2951             return 123
2952
2953         elif self.change_count and self.check:
2954             return 1
2955
2956         return 0
2957
2958     def __str__(self) -> str:
2959         """Render a color report of the current state.
2960
2961         Use `click.unstyle` to remove colors.
2962         """
2963         if self.check:
2964             reformatted = "would be reformatted"
2965             unchanged = "would be left unchanged"
2966             failed = "would fail to reformat"
2967         else:
2968             reformatted = "reformatted"
2969             unchanged = "left unchanged"
2970             failed = "failed to reformat"
2971         report = []
2972         if self.change_count:
2973             s = "s" if self.change_count > 1 else ""
2974             report.append(
2975                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2976             )
2977         if self.same_count:
2978             s = "s" if self.same_count > 1 else ""
2979             report.append(f"{self.same_count} file{s} {unchanged}")
2980         if self.failure_count:
2981             s = "s" if self.failure_count > 1 else ""
2982             report.append(
2983                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2984             )
2985         return ", ".join(report) + "."
2986
2987
2988 def assert_equivalent(src: str, dst: str) -> None:
2989     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2990
2991     import ast
2992     import traceback
2993
2994     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2995         """Simple visitor generating strings to compare ASTs by content."""
2996         yield f"{'  ' * depth}{node.__class__.__name__}("
2997
2998         for field in sorted(node._fields):
2999             try:
3000                 value = getattr(node, field)
3001             except AttributeError:
3002                 continue
3003
3004             yield f"{'  ' * (depth+1)}{field}="
3005
3006             if isinstance(value, list):
3007                 for item in value:
3008                     if isinstance(item, ast.AST):
3009                         yield from _v(item, depth + 2)
3010
3011             elif isinstance(value, ast.AST):
3012                 yield from _v(value, depth + 2)
3013
3014             else:
3015                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
3016
3017         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
3018
3019     try:
3020         src_ast = ast.parse(src)
3021     except Exception as exc:
3022         major, minor = sys.version_info[:2]
3023         raise AssertionError(
3024             f"cannot use --safe with this file; failed to parse source file "
3025             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
3026             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
3027         )
3028
3029     try:
3030         dst_ast = ast.parse(dst)
3031     except Exception as exc:
3032         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3033         raise AssertionError(
3034             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3035             f"Please report a bug on https://github.com/ambv/black/issues.  "
3036             f"This invalid output might be helpful: {log}"
3037         ) from None
3038
3039     src_ast_str = "\n".join(_v(src_ast))
3040     dst_ast_str = "\n".join(_v(dst_ast))
3041     if src_ast_str != dst_ast_str:
3042         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3043         raise AssertionError(
3044             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3045             f"the source.  "
3046             f"Please report a bug on https://github.com/ambv/black/issues.  "
3047             f"This diff might be helpful: {log}"
3048         ) from None
3049
3050
3051 def assert_stable(
3052     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3053 ) -> None:
3054     """Raise AssertionError if `dst` reformats differently the second time."""
3055     newdst = format_str(dst, line_length=line_length, mode=mode)
3056     if dst != newdst:
3057         log = dump_to_file(
3058             diff(src, dst, "source", "first pass"),
3059             diff(dst, newdst, "first pass", "second pass"),
3060         )
3061         raise AssertionError(
3062             f"INTERNAL ERROR: Black produced different code on the second pass "
3063             f"of the formatter.  "
3064             f"Please report a bug on https://github.com/ambv/black/issues.  "
3065             f"This diff might be helpful: {log}"
3066         ) from None
3067
3068
3069 def dump_to_file(*output: str) -> str:
3070     """Dump `output` to a temporary file. Return path to the file."""
3071     import tempfile
3072
3073     with tempfile.NamedTemporaryFile(
3074         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3075     ) as f:
3076         for lines in output:
3077             f.write(lines)
3078             if lines and lines[-1] != "\n":
3079                 f.write("\n")
3080     return f.name
3081
3082
3083 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3084     """Return a unified diff string between strings `a` and `b`."""
3085     import difflib
3086
3087     a_lines = [line + "\n" for line in a.split("\n")]
3088     b_lines = [line + "\n" for line in b.split("\n")]
3089     return "".join(
3090         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3091     )
3092
3093
3094 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3095     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3096     err("Aborted!")
3097     for task in tasks:
3098         task.cancel()
3099
3100
3101 def shutdown(loop: BaseEventLoop) -> None:
3102     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3103     try:
3104         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3105         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3106         if not to_cancel:
3107             return
3108
3109         for task in to_cancel:
3110             task.cancel()
3111         loop.run_until_complete(
3112             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3113         )
3114     finally:
3115         # `concurrent.futures.Future` objects cannot be cancelled once they
3116         # are already running. There might be some when the `shutdown()` happened.
3117         # Silence their logger's spew about the event loop being closed.
3118         cf_logger = logging.getLogger("concurrent.futures")
3119         cf_logger.setLevel(logging.CRITICAL)
3120         loop.close()
3121
3122
3123 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3124     """Replace `regex` with `replacement` twice on `original`.
3125
3126     This is used by string normalization to perform replaces on
3127     overlapping matches.
3128     """
3129     return regex.sub(replacement, regex.sub(replacement, original))
3130
3131
3132 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3133     """Like `reversed(enumerate(sequence))` if that were possible."""
3134     index = len(sequence) - 1
3135     for element in reversed(sequence):
3136         yield (index, element)
3137         index -= 1
3138
3139
3140 def enumerate_with_length(
3141     line: Line, reversed: bool = False
3142 ) -> Iterator[Tuple[Index, Leaf, int]]:
3143     """Return an enumeration of leaves with their length.
3144
3145     Stops prematurely on multiline strings and standalone comments.
3146     """
3147     op = cast(
3148         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3149         enumerate_reversed if reversed else enumerate,
3150     )
3151     for index, leaf in op(line.leaves):
3152         length = len(leaf.prefix) + len(leaf.value)
3153         if "\n" in leaf.value:
3154             return  # Multiline strings, we can't continue.
3155
3156         comment: Optional[Leaf]
3157         for comment in line.comments_after(leaf, index):
3158             length += len(comment.value)
3159
3160         yield index, leaf, length
3161
3162
3163 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3164     """Return True if `line` is no longer than `line_length`.
3165
3166     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3167     """
3168     if not line_str:
3169         line_str = str(line).strip("\n")
3170     return (
3171         len(line_str) <= line_length
3172         and "\n" not in line_str  # multiline strings
3173         and not line.contains_standalone_comments()
3174     )
3175
3176
3177 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3178     """Does `line` have a shape safe to reformat without optional parens around it?
3179
3180     Returns True for only a subset of potentially nice looking formattings but
3181     the point is to not return false positives that end up producing lines that
3182     are too long.
3183     """
3184     bt = line.bracket_tracker
3185     if not bt.delimiters:
3186         # Without delimiters the optional parentheses are useless.
3187         return True
3188
3189     max_priority = bt.max_delimiter_priority()
3190     if bt.delimiter_count_with_priority(max_priority) > 1:
3191         # With more than one delimiter of a kind the optional parentheses read better.
3192         return False
3193
3194     if max_priority == DOT_PRIORITY:
3195         # A single stranded method call doesn't require optional parentheses.
3196         return True
3197
3198     assert len(line.leaves) >= 2, "Stranded delimiter"
3199
3200     first = line.leaves[0]
3201     second = line.leaves[1]
3202     penultimate = line.leaves[-2]
3203     last = line.leaves[-1]
3204
3205     # With a single delimiter, omit if the expression starts or ends with
3206     # a bracket.
3207     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3208         remainder = False
3209         length = 4 * line.depth
3210         for _index, leaf, leaf_length in enumerate_with_length(line):
3211             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3212                 remainder = True
3213             if remainder:
3214                 length += leaf_length
3215                 if length > line_length:
3216                     break
3217
3218                 if leaf.type in OPENING_BRACKETS:
3219                     # There are brackets we can further split on.
3220                     remainder = False
3221
3222         else:
3223             # checked the entire string and line length wasn't exceeded
3224             if len(line.leaves) == _index + 1:
3225                 return True
3226
3227         # Note: we are not returning False here because a line might have *both*
3228         # a leading opening bracket and a trailing closing bracket.  If the
3229         # opening bracket doesn't match our rule, maybe the closing will.
3230
3231     if (
3232         last.type == token.RPAR
3233         or last.type == token.RBRACE
3234         or (
3235             # don't use indexing for omitting optional parentheses;
3236             # it looks weird
3237             last.type == token.RSQB
3238             and last.parent
3239             and last.parent.type != syms.trailer
3240         )
3241     ):
3242         if penultimate.type in OPENING_BRACKETS:
3243             # Empty brackets don't help.
3244             return False
3245
3246         if is_multiline_string(first):
3247             # Additional wrapping of a multiline string in this situation is
3248             # unnecessary.
3249             return True
3250
3251         length = 4 * line.depth
3252         seen_other_brackets = False
3253         for _index, leaf, leaf_length in enumerate_with_length(line):
3254             length += leaf_length
3255             if leaf is last.opening_bracket:
3256                 if seen_other_brackets or length <= line_length:
3257                     return True
3258
3259             elif leaf.type in OPENING_BRACKETS:
3260                 # There are brackets we can further split on.
3261                 seen_other_brackets = True
3262
3263     return False
3264
3265
3266 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3267     pyi = bool(mode & FileMode.PYI)
3268     py36 = bool(mode & FileMode.PYTHON36)
3269     return (
3270         CACHE_DIR
3271         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3272     )
3273
3274
3275 def read_cache(line_length: int, mode: FileMode) -> Cache:
3276     """Read the cache if it exists and is well formed.
3277
3278     If it is not well formed, the call to write_cache later should resolve the issue.
3279     """
3280     cache_file = get_cache_file(line_length, mode)
3281     if not cache_file.exists():
3282         return {}
3283
3284     with cache_file.open("rb") as fobj:
3285         try:
3286             cache: Cache = pickle.load(fobj)
3287         except pickle.UnpicklingError:
3288             return {}
3289
3290     return cache
3291
3292
3293 def get_cache_info(path: Path) -> CacheInfo:
3294     """Return the information used to check if a file is already formatted or not."""
3295     stat = path.stat()
3296     return stat.st_mtime, stat.st_size
3297
3298
3299 def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
3300     """Split an iterable of paths in `sources` into two sets.
3301
3302     The first contains paths of files that modified on disk or are not in the
3303     cache. The other contains paths to non-modified files.
3304     """
3305     todo, done = set(), set()
3306     for src in sources:
3307         src = src.resolve()
3308         if cache.get(src) != get_cache_info(src):
3309             todo.add(src)
3310         else:
3311             done.add(src)
3312     return todo, done
3313
3314
3315 def write_cache(
3316     cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
3317 ) -> None:
3318     """Update the cache file."""
3319     cache_file = get_cache_file(line_length, mode)
3320     try:
3321         if not CACHE_DIR.exists():
3322             CACHE_DIR.mkdir(parents=True)
3323         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3324         with cache_file.open("wb") as fobj:
3325             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3326     except OSError:
3327         pass
3328
3329
3330 if __name__ == "__main__":
3331     main()