]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add `--verbose` and report excluded paths in it, too
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum, Flag
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33     cast,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46
47 __version__ = "18.5b1"
48 DEFAULT_LINE_LENGTH = 88
49 DEFAULT_EXCLUDES = (
50     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
51 )
52 DEFAULT_INCLUDES = r"\.pyi?$"
53 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
54
55
56 # types
57 FileContent = str
58 Encoding = str
59 Depth = int
60 NodeType = int
61 LeafID = int
62 Priority = int
63 Index = int
64 LN = Union[Leaf, Node]
65 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
66 Timestamp = float
67 FileSize = int
68 CacheInfo = Tuple[Timestamp, FileSize]
69 Cache = Dict[Path, CacheInfo]
70 out = partial(click.secho, bold=True, err=True)
71 err = partial(click.secho, fg="red", err=True)
72
73 pygram.initialize(CACHE_DIR)
74 syms = pygram.python_symbols
75
76
77 class NothingChanged(UserWarning):
78     """Raised by :func:`format_file` when reformatted code is the same as source."""
79
80
81 class CannotSplit(Exception):
82     """A readable split that fits the allotted line length is impossible.
83
84     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
85     :func:`delimiter_split`.
86     """
87
88
89 class FormatError(Exception):
90     """Base exception for `# fmt: on` and `# fmt: off` handling.
91
92     It holds the number of bytes of the prefix consumed before the format
93     control comment appeared.
94     """
95
96     def __init__(self, consumed: int) -> None:
97         super().__init__(consumed)
98         self.consumed = consumed
99
100     def trim_prefix(self, leaf: Leaf) -> None:
101         leaf.prefix = leaf.prefix[self.consumed :]
102
103     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
104         """Returns a new Leaf from the consumed part of the prefix."""
105         unformatted_prefix = leaf.prefix[: self.consumed]
106         return Leaf(token.NEWLINE, unformatted_prefix)
107
108
109 class FormatOn(FormatError):
110     """Found a comment like `# fmt: on` in the file."""
111
112
113 class FormatOff(FormatError):
114     """Found a comment like `# fmt: off` in the file."""
115
116
117 class WriteBack(Enum):
118     NO = 0
119     YES = 1
120     DIFF = 2
121
122     @classmethod
123     def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
124         if check and not diff:
125             return cls.NO
126
127         return cls.DIFF if diff else cls.YES
128
129
130 class Changed(Enum):
131     NO = 0
132     CACHED = 1
133     YES = 2
134
135
136 class FileMode(Flag):
137     AUTO_DETECT = 0
138     PYTHON36 = 1
139     PYI = 2
140     NO_STRING_NORMALIZATION = 4
141
142     @classmethod
143     def from_configuration(
144         cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
145     ) -> "FileMode":
146         mode = cls.AUTO_DETECT
147         if py36:
148             mode |= cls.PYTHON36
149         if pyi:
150             mode |= cls.PYI
151         if skip_string_normalization:
152             mode |= cls.NO_STRING_NORMALIZATION
153         return mode
154
155
156 @click.command()
157 @click.option(
158     "-l",
159     "--line-length",
160     type=int,
161     default=DEFAULT_LINE_LENGTH,
162     help="How many character per line to allow.",
163     show_default=True,
164 )
165 @click.option(
166     "--py36",
167     is_flag=True,
168     help=(
169         "Allow using Python 3.6-only syntax on all input files.  This will put "
170         "trailing commas in function signatures and calls also after *args and "
171         "**kwargs.  [default: per-file auto-detection]"
172     ),
173 )
174 @click.option(
175     "--pyi",
176     is_flag=True,
177     help=(
178         "Format all input files like typing stubs regardless of file extension "
179         "(useful when piping source on standard input)."
180     ),
181 )
182 @click.option(
183     "-S",
184     "--skip-string-normalization",
185     is_flag=True,
186     help="Don't normalize string quotes or prefixes.",
187 )
188 @click.option(
189     "--check",
190     is_flag=True,
191     help=(
192         "Don't write the files back, just return the status.  Return code 0 "
193         "means nothing would change.  Return code 1 means some files would be "
194         "reformatted.  Return code 123 means there was an internal error."
195     ),
196 )
197 @click.option(
198     "--diff",
199     is_flag=True,
200     help="Don't write the files back, just output a diff for each file on stdout.",
201 )
202 @click.option(
203     "--fast/--safe",
204     is_flag=True,
205     help="If --fast given, skip temporary sanity checks. [default: --safe]",
206 )
207 @click.option(
208     "--include",
209     type=str,
210     default=DEFAULT_INCLUDES,
211     help=(
212         "A regular expression that matches files and directories that should be "
213         "included on recursive searches.  An empty value means all files are "
214         "included regardless of the name.  Use forward slashes for directories on "
215         "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
216         "later."
217     ),
218     show_default=True,
219 )
220 @click.option(
221     "--exclude",
222     type=str,
223     default=DEFAULT_EXCLUDES,
224     help=(
225         "A regular expression that matches files and directories that should be "
226         "excluded on recursive searches.  An empty value means no paths are excluded. "
227         "Use forward slashes for directories on all platforms (Windows, too).  "
228         "Exclusions are calculated first, inclusions later."
229     ),
230     show_default=True,
231 )
232 @click.option(
233     "-q",
234     "--quiet",
235     is_flag=True,
236     help=(
237         "Don't emit non-error messages to stderr. Errors are still emitted, "
238         "silence those with 2>/dev/null."
239     ),
240 )
241 @click.option(
242     "-v",
243     "--verbose",
244     is_flag=True,
245     help=(
246         "Also emit messages to stderr about files that were not changed or were "
247         "ignored due to --exclude=."
248     ),
249 )
250 @click.version_option(version=__version__)
251 @click.argument(
252     "src",
253     nargs=-1,
254     type=click.Path(
255         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
256     ),
257 )
258 @click.pass_context
259 def main(
260     ctx: click.Context,
261     line_length: int,
262     check: bool,
263     diff: bool,
264     fast: bool,
265     pyi: bool,
266     py36: bool,
267     skip_string_normalization: bool,
268     quiet: bool,
269     verbose: bool,
270     include: str,
271     exclude: str,
272     src: List[str],
273 ) -> None:
274     """The uncompromising code formatter."""
275     write_back = WriteBack.from_configuration(check=check, diff=diff)
276     mode = FileMode.from_configuration(
277         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
278     )
279     report = Report(check=check, quiet=quiet, verbose=verbose)
280     sources: List[Path] = []
281     try:
282         include_regex = re.compile(include)
283     except re.error:
284         err(f"Invalid regular expression for include given: {include!r}")
285         ctx.exit(2)
286     try:
287         exclude_regex = re.compile(exclude)
288     except re.error:
289         err(f"Invalid regular expression for exclude given: {exclude!r}")
290         ctx.exit(2)
291     root = find_project_root(src)
292     for s in src:
293         p = Path(s)
294         if p.is_dir():
295             sources.extend(
296                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
297             )
298         elif p.is_file() or s == "-":
299             # if a file was explicitly given, we don't care about its extension
300             sources.append(p)
301         else:
302             err(f"invalid path: {s}")
303     if len(sources) == 0:
304         out("No paths given. Nothing to do 😴")
305         ctx.exit(0)
306         return
307
308     elif len(sources) == 1:
309         reformat_one(
310             src=sources[0],
311             line_length=line_length,
312             fast=fast,
313             write_back=write_back,
314             mode=mode,
315             report=report,
316         )
317     else:
318         loop = asyncio.get_event_loop()
319         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
320         try:
321             loop.run_until_complete(
322                 schedule_formatting(
323                     sources=sources,
324                     line_length=line_length,
325                     fast=fast,
326                     write_back=write_back,
327                     mode=mode,
328                     report=report,
329                     loop=loop,
330                     executor=executor,
331                 )
332             )
333         finally:
334             shutdown(loop)
335         if not quiet:
336             out("All done! ✨ 🍰 ✨")
337             click.echo(str(report))
338     ctx.exit(report.return_code)
339
340
341 def reformat_one(
342     src: Path,
343     line_length: int,
344     fast: bool,
345     write_back: WriteBack,
346     mode: FileMode,
347     report: "Report",
348 ) -> None:
349     """Reformat a single file under `src` without spawning child processes.
350
351     If `quiet` is True, non-error messages are not output. `line_length`,
352     `write_back`, `fast` and `pyi` options are passed to
353     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
354     """
355     try:
356         changed = Changed.NO
357         if not src.is_file() and str(src) == "-":
358             if format_stdin_to_stdout(
359                 line_length=line_length, fast=fast, write_back=write_back, mode=mode
360             ):
361                 changed = Changed.YES
362         else:
363             cache: Cache = {}
364             if write_back != WriteBack.DIFF:
365                 cache = read_cache(line_length, mode)
366                 res_src = src.resolve()
367                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
368                     changed = Changed.CACHED
369             if changed is not Changed.CACHED and format_file_in_place(
370                 src,
371                 line_length=line_length,
372                 fast=fast,
373                 write_back=write_back,
374                 mode=mode,
375             ):
376                 changed = Changed.YES
377             if write_back == WriteBack.YES and changed is not Changed.NO:
378                 write_cache(cache, [src], line_length, mode)
379         report.done(src, changed)
380     except Exception as exc:
381         report.failed(src, str(exc))
382
383
384 async def schedule_formatting(
385     sources: List[Path],
386     line_length: int,
387     fast: bool,
388     write_back: WriteBack,
389     mode: FileMode,
390     report: "Report",
391     loop: BaseEventLoop,
392     executor: Executor,
393 ) -> None:
394     """Run formatting of `sources` in parallel using the provided `executor`.
395
396     (Use ProcessPoolExecutors for actual parallelism.)
397
398     `line_length`, `write_back`, `fast`, and `pyi` options are passed to
399     :func:`format_file_in_place`.
400     """
401     cache: Cache = {}
402     if write_back != WriteBack.DIFF:
403         cache = read_cache(line_length, mode)
404         sources, cached = filter_cached(cache, sources)
405         for src in cached:
406             report.done(src, Changed.CACHED)
407     cancelled = []
408     formatted = []
409     if sources:
410         lock = None
411         if write_back == WriteBack.DIFF:
412             # For diff output, we need locks to ensure we don't interleave output
413             # from different processes.
414             manager = Manager()
415             lock = manager.Lock()
416         tasks = {
417             loop.run_in_executor(
418                 executor,
419                 format_file_in_place,
420                 src,
421                 line_length,
422                 fast,
423                 write_back,
424                 mode,
425                 lock,
426             ): src
427             for src in sorted(sources)
428         }
429         pending: Iterable[asyncio.Task] = tasks.keys()
430         try:
431             loop.add_signal_handler(signal.SIGINT, cancel, pending)
432             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
433         except NotImplementedError:
434             # There are no good alternatives for these on Windows
435             pass
436         while pending:
437             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
438             for task in done:
439                 src = tasks.pop(task)
440                 if task.cancelled():
441                     cancelled.append(task)
442                 elif task.exception():
443                     report.failed(src, str(task.exception()))
444                 else:
445                     formatted.append(src)
446                     report.done(src, Changed.YES if task.result() else Changed.NO)
447     if cancelled:
448         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
449     if write_back == WriteBack.YES and formatted:
450         write_cache(cache, formatted, line_length, mode)
451
452
453 def format_file_in_place(
454     src: Path,
455     line_length: int,
456     fast: bool,
457     write_back: WriteBack = WriteBack.NO,
458     mode: FileMode = FileMode.AUTO_DETECT,
459     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
460 ) -> bool:
461     """Format file under `src` path. Return True if changed.
462
463     If `write_back` is True, write reformatted code back to stdout.
464     `line_length` and `fast` options are passed to :func:`format_file_contents`.
465     """
466     if src.suffix == ".pyi":
467         mode |= FileMode.PYI
468     with tokenize.open(src) as src_buffer:
469         src_contents = src_buffer.read()
470     try:
471         dst_contents = format_file_contents(
472             src_contents, line_length=line_length, fast=fast, mode=mode
473         )
474     except NothingChanged:
475         return False
476
477     if write_back == write_back.YES:
478         with open(src, "w", encoding=src_buffer.encoding) as f:
479             f.write(dst_contents)
480     elif write_back == write_back.DIFF:
481         src_name = f"{src}  (original)"
482         dst_name = f"{src}  (formatted)"
483         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
484         if lock:
485             lock.acquire()
486         try:
487             sys.stdout.write(diff_contents)
488         finally:
489             if lock:
490                 lock.release()
491     return True
492
493
494 def format_stdin_to_stdout(
495     line_length: int,
496     fast: bool,
497     write_back: WriteBack = WriteBack.NO,
498     mode: FileMode = FileMode.AUTO_DETECT,
499 ) -> bool:
500     """Format file on stdin. Return True if changed.
501
502     If `write_back` is True, write reformatted code back to stdout.
503     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
504     :func:`format_file_contents`.
505     """
506     src = sys.stdin.read()
507     dst = src
508     try:
509         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
510         return True
511
512     except NothingChanged:
513         return False
514
515     finally:
516         if write_back == WriteBack.YES:
517             sys.stdout.write(dst)
518         elif write_back == WriteBack.DIFF:
519             src_name = "<stdin>  (original)"
520             dst_name = "<stdin>  (formatted)"
521             sys.stdout.write(diff(src, dst, src_name, dst_name))
522
523
524 def format_file_contents(
525     src_contents: str,
526     *,
527     line_length: int,
528     fast: bool,
529     mode: FileMode = FileMode.AUTO_DETECT,
530 ) -> FileContent:
531     """Reformat contents a file and return new contents.
532
533     If `fast` is False, additionally confirm that the reformatted code is
534     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
535     `line_length` is passed to :func:`format_str`.
536     """
537     if src_contents.strip() == "":
538         raise NothingChanged
539
540     dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
541     if src_contents == dst_contents:
542         raise NothingChanged
543
544     if not fast:
545         assert_equivalent(src_contents, dst_contents)
546         assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
547     return dst_contents
548
549
550 def format_str(
551     src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
552 ) -> FileContent:
553     """Reformat a string and return new contents.
554
555     `line_length` determines how many characters per line are allowed.
556     """
557     src_node = lib2to3_parse(src_contents)
558     dst_contents = ""
559     future_imports = get_future_imports(src_node)
560     is_pyi = bool(mode & FileMode.PYI)
561     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
562     normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
563     lines = LineGenerator(
564         remove_u_prefix=py36 or "unicode_literals" in future_imports,
565         is_pyi=is_pyi,
566         normalize_strings=normalize_strings,
567     )
568     elt = EmptyLineTracker(is_pyi=is_pyi)
569     empty_line = Line()
570     after = 0
571     for current_line in lines.visit(src_node):
572         for _ in range(after):
573             dst_contents += str(empty_line)
574         before, after = elt.maybe_empty_lines(current_line)
575         for _ in range(before):
576             dst_contents += str(empty_line)
577         for line in split_line(current_line, line_length=line_length, py36=py36):
578             dst_contents += str(line)
579     return dst_contents
580
581
582 GRAMMARS = [
583     pygram.python_grammar_no_print_statement_no_exec_statement,
584     pygram.python_grammar_no_print_statement,
585     pygram.python_grammar,
586 ]
587
588
589 def lib2to3_parse(src_txt: str) -> Node:
590     """Given a string with source, return the lib2to3 Node."""
591     grammar = pygram.python_grammar_no_print_statement
592     if src_txt[-1] != "\n":
593         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
594         src_txt += nl
595     for grammar in GRAMMARS:
596         drv = driver.Driver(grammar, pytree.convert)
597         try:
598             result = drv.parse_string(src_txt, True)
599             break
600
601         except ParseError as pe:
602             lineno, column = pe.context[1]
603             lines = src_txt.splitlines()
604             try:
605                 faulty_line = lines[lineno - 1]
606             except IndexError:
607                 faulty_line = "<line number missing in source>"
608             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
609     else:
610         raise exc from None
611
612     if isinstance(result, Leaf):
613         result = Node(syms.file_input, [result])
614     return result
615
616
617 def lib2to3_unparse(node: Node) -> str:
618     """Given a lib2to3 node, return its string representation."""
619     code = str(node)
620     return code
621
622
623 T = TypeVar("T")
624
625
626 class Visitor(Generic[T]):
627     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
628
629     def visit(self, node: LN) -> Iterator[T]:
630         """Main method to visit `node` and its children.
631
632         It tries to find a `visit_*()` method for the given `node.type`, like
633         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
634         If no dedicated `visit_*()` method is found, chooses `visit_default()`
635         instead.
636
637         Then yields objects of type `T` from the selected visitor.
638         """
639         if node.type < 256:
640             name = token.tok_name[node.type]
641         else:
642             name = type_repr(node.type)
643         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
644
645     def visit_default(self, node: LN) -> Iterator[T]:
646         """Default `visit_*()` implementation. Recurses to children of `node`."""
647         if isinstance(node, Node):
648             for child in node.children:
649                 yield from self.visit(child)
650
651
652 @dataclass
653 class DebugVisitor(Visitor[T]):
654     tree_depth: int = 0
655
656     def visit_default(self, node: LN) -> Iterator[T]:
657         indent = " " * (2 * self.tree_depth)
658         if isinstance(node, Node):
659             _type = type_repr(node.type)
660             out(f"{indent}{_type}", fg="yellow")
661             self.tree_depth += 1
662             for child in node.children:
663                 yield from self.visit(child)
664
665             self.tree_depth -= 1
666             out(f"{indent}/{_type}", fg="yellow", bold=False)
667         else:
668             _type = token.tok_name.get(node.type, str(node.type))
669             out(f"{indent}{_type}", fg="blue", nl=False)
670             if node.prefix:
671                 # We don't have to handle prefixes for `Node` objects since
672                 # that delegates to the first child anyway.
673                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
674             out(f" {node.value!r}", fg="blue", bold=False)
675
676     @classmethod
677     def show(cls, code: str) -> None:
678         """Pretty-print the lib2to3 AST of a given string of `code`.
679
680         Convenience method for debugging.
681         """
682         v: DebugVisitor[None] = DebugVisitor()
683         list(v.visit(lib2to3_parse(code)))
684
685
686 KEYWORDS = set(keyword.kwlist)
687 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
688 FLOW_CONTROL = {"return", "raise", "break", "continue"}
689 STATEMENT = {
690     syms.if_stmt,
691     syms.while_stmt,
692     syms.for_stmt,
693     syms.try_stmt,
694     syms.except_clause,
695     syms.with_stmt,
696     syms.funcdef,
697     syms.classdef,
698 }
699 STANDALONE_COMMENT = 153
700 LOGIC_OPERATORS = {"and", "or"}
701 COMPARATORS = {
702     token.LESS,
703     token.GREATER,
704     token.EQEQUAL,
705     token.NOTEQUAL,
706     token.LESSEQUAL,
707     token.GREATEREQUAL,
708 }
709 MATH_OPERATORS = {
710     token.VBAR,
711     token.CIRCUMFLEX,
712     token.AMPER,
713     token.LEFTSHIFT,
714     token.RIGHTSHIFT,
715     token.PLUS,
716     token.MINUS,
717     token.STAR,
718     token.SLASH,
719     token.DOUBLESLASH,
720     token.PERCENT,
721     token.AT,
722     token.TILDE,
723     token.DOUBLESTAR,
724 }
725 STARS = {token.STAR, token.DOUBLESTAR}
726 VARARGS_PARENTS = {
727     syms.arglist,
728     syms.argument,  # double star in arglist
729     syms.trailer,  # single argument to call
730     syms.typedargslist,
731     syms.varargslist,  # lambdas
732 }
733 UNPACKING_PARENTS = {
734     syms.atom,  # single element of a list or set literal
735     syms.dictsetmaker,
736     syms.listmaker,
737     syms.testlist_gexp,
738 }
739 TEST_DESCENDANTS = {
740     syms.test,
741     syms.lambdef,
742     syms.or_test,
743     syms.and_test,
744     syms.not_test,
745     syms.comparison,
746     syms.star_expr,
747     syms.expr,
748     syms.xor_expr,
749     syms.and_expr,
750     syms.shift_expr,
751     syms.arith_expr,
752     syms.trailer,
753     syms.term,
754     syms.power,
755 }
756 ASSIGNMENTS = {
757     "=",
758     "+=",
759     "-=",
760     "*=",
761     "@=",
762     "/=",
763     "%=",
764     "&=",
765     "|=",
766     "^=",
767     "<<=",
768     ">>=",
769     "**=",
770     "//=",
771 }
772 COMPREHENSION_PRIORITY = 20
773 COMMA_PRIORITY = 18
774 TERNARY_PRIORITY = 16
775 LOGIC_PRIORITY = 14
776 STRING_PRIORITY = 12
777 COMPARATOR_PRIORITY = 10
778 MATH_PRIORITIES = {
779     token.VBAR: 9,
780     token.CIRCUMFLEX: 8,
781     token.AMPER: 7,
782     token.LEFTSHIFT: 6,
783     token.RIGHTSHIFT: 6,
784     token.PLUS: 5,
785     token.MINUS: 5,
786     token.STAR: 4,
787     token.SLASH: 4,
788     token.DOUBLESLASH: 4,
789     token.PERCENT: 4,
790     token.AT: 4,
791     token.TILDE: 3,
792     token.DOUBLESTAR: 2,
793 }
794 DOT_PRIORITY = 1
795
796
797 @dataclass
798 class BracketTracker:
799     """Keeps track of brackets on a line."""
800
801     depth: int = 0
802     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
803     delimiters: Dict[LeafID, Priority] = Factory(dict)
804     previous: Optional[Leaf] = None
805     _for_loop_variable: int = 0
806     _lambda_arguments: int = 0
807
808     def mark(self, leaf: Leaf) -> None:
809         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
810
811         All leaves receive an int `bracket_depth` field that stores how deep
812         within brackets a given leaf is. 0 means there are no enclosing brackets
813         that started on this line.
814
815         If a leaf is itself a closing bracket, it receives an `opening_bracket`
816         field that it forms a pair with. This is a one-directional link to
817         avoid reference cycles.
818
819         If a leaf is a delimiter (a token on which Black can split the line if
820         needed) and it's on depth 0, its `id()` is stored in the tracker's
821         `delimiters` field.
822         """
823         if leaf.type == token.COMMENT:
824             return
825
826         self.maybe_decrement_after_for_loop_variable(leaf)
827         self.maybe_decrement_after_lambda_arguments(leaf)
828         if leaf.type in CLOSING_BRACKETS:
829             self.depth -= 1
830             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
831             leaf.opening_bracket = opening_bracket
832         leaf.bracket_depth = self.depth
833         if self.depth == 0:
834             delim = is_split_before_delimiter(leaf, self.previous)
835             if delim and self.previous is not None:
836                 self.delimiters[id(self.previous)] = delim
837             else:
838                 delim = is_split_after_delimiter(leaf, self.previous)
839                 if delim:
840                     self.delimiters[id(leaf)] = delim
841         if leaf.type in OPENING_BRACKETS:
842             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
843             self.depth += 1
844         self.previous = leaf
845         self.maybe_increment_lambda_arguments(leaf)
846         self.maybe_increment_for_loop_variable(leaf)
847
848     def any_open_brackets(self) -> bool:
849         """Return True if there is an yet unmatched open bracket on the line."""
850         return bool(self.bracket_match)
851
852     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
853         """Return the highest priority of a delimiter found on the line.
854
855         Values are consistent with what `is_split_*_delimiter()` return.
856         Raises ValueError on no delimiters.
857         """
858         return max(v for k, v in self.delimiters.items() if k not in exclude)
859
860     def delimiter_count_with_priority(self, priority: int = 0) -> int:
861         """Return the number of delimiters with the given `priority`.
862
863         If no `priority` is passed, defaults to max priority on the line.
864         """
865         if not self.delimiters:
866             return 0
867
868         priority = priority or self.max_delimiter_priority()
869         return sum(1 for p in self.delimiters.values() if p == priority)
870
871     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
872         """In a for loop, or comprehension, the variables are often unpacks.
873
874         To avoid splitting on the comma in this situation, increase the depth of
875         tokens between `for` and `in`.
876         """
877         if leaf.type == token.NAME and leaf.value == "for":
878             self.depth += 1
879             self._for_loop_variable += 1
880             return True
881
882         return False
883
884     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
885         """See `maybe_increment_for_loop_variable` above for explanation."""
886         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
887             self.depth -= 1
888             self._for_loop_variable -= 1
889             return True
890
891         return False
892
893     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
894         """In a lambda expression, there might be more than one argument.
895
896         To avoid splitting on the comma in this situation, increase the depth of
897         tokens between `lambda` and `:`.
898         """
899         if leaf.type == token.NAME and leaf.value == "lambda":
900             self.depth += 1
901             self._lambda_arguments += 1
902             return True
903
904         return False
905
906     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
907         """See `maybe_increment_lambda_arguments` above for explanation."""
908         if self._lambda_arguments and leaf.type == token.COLON:
909             self.depth -= 1
910             self._lambda_arguments -= 1
911             return True
912
913         return False
914
915     def get_open_lsqb(self) -> Optional[Leaf]:
916         """Return the most recent opening square bracket (if any)."""
917         return self.bracket_match.get((self.depth - 1, token.RSQB))
918
919
920 @dataclass
921 class Line:
922     """Holds leaves and comments. Can be printed with `str(line)`."""
923
924     depth: int = 0
925     leaves: List[Leaf] = Factory(list)
926     comments: List[Tuple[Index, Leaf]] = Factory(list)
927     bracket_tracker: BracketTracker = Factory(BracketTracker)
928     inside_brackets: bool = False
929     should_explode: bool = False
930
931     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
932         """Add a new `leaf` to the end of the line.
933
934         Unless `preformatted` is True, the `leaf` will receive a new consistent
935         whitespace prefix and metadata applied by :class:`BracketTracker`.
936         Trailing commas are maybe removed, unpacked for loop variables are
937         demoted from being delimiters.
938
939         Inline comments are put aside.
940         """
941         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
942         if not has_value:
943             return
944
945         if token.COLON == leaf.type and self.is_class_paren_empty:
946             del self.leaves[-2:]
947         if self.leaves and not preformatted:
948             # Note: at this point leaf.prefix should be empty except for
949             # imports, for which we only preserve newlines.
950             leaf.prefix += whitespace(
951                 leaf, complex_subscript=self.is_complex_subscript(leaf)
952             )
953         if self.inside_brackets or not preformatted:
954             self.bracket_tracker.mark(leaf)
955             self.maybe_remove_trailing_comma(leaf)
956         if not self.append_comment(leaf):
957             self.leaves.append(leaf)
958
959     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
960         """Like :func:`append()` but disallow invalid standalone comment structure.
961
962         Raises ValueError when any `leaf` is appended after a standalone comment
963         or when a standalone comment is not the first leaf on the line.
964         """
965         if self.bracket_tracker.depth == 0:
966             if self.is_comment:
967                 raise ValueError("cannot append to standalone comments")
968
969             if self.leaves and leaf.type == STANDALONE_COMMENT:
970                 raise ValueError(
971                     "cannot append standalone comments to a populated line"
972                 )
973
974         self.append(leaf, preformatted=preformatted)
975
976     @property
977     def is_comment(self) -> bool:
978         """Is this line a standalone comment?"""
979         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
980
981     @property
982     def is_decorator(self) -> bool:
983         """Is this line a decorator?"""
984         return bool(self) and self.leaves[0].type == token.AT
985
986     @property
987     def is_import(self) -> bool:
988         """Is this an import line?"""
989         return bool(self) and is_import(self.leaves[0])
990
991     @property
992     def is_class(self) -> bool:
993         """Is this line a class definition?"""
994         return (
995             bool(self)
996             and self.leaves[0].type == token.NAME
997             and self.leaves[0].value == "class"
998         )
999
1000     @property
1001     def is_stub_class(self) -> bool:
1002         """Is this line a class definition with a body consisting only of "..."?"""
1003         return self.is_class and self.leaves[-3:] == [
1004             Leaf(token.DOT, ".") for _ in range(3)
1005         ]
1006
1007     @property
1008     def is_def(self) -> bool:
1009         """Is this a function definition? (Also returns True for async defs.)"""
1010         try:
1011             first_leaf = self.leaves[0]
1012         except IndexError:
1013             return False
1014
1015         try:
1016             second_leaf: Optional[Leaf] = self.leaves[1]
1017         except IndexError:
1018             second_leaf = None
1019         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
1020             first_leaf.type == token.ASYNC
1021             and second_leaf is not None
1022             and second_leaf.type == token.NAME
1023             and second_leaf.value == "def"
1024         )
1025
1026     @property
1027     def is_class_paren_empty(self) -> bool:
1028         """Is this a class with no base classes but using parentheses?
1029
1030         Those are unnecessary and should be removed.
1031         """
1032         return (
1033             bool(self)
1034             and len(self.leaves) == 4
1035             and self.is_class
1036             and self.leaves[2].type == token.LPAR
1037             and self.leaves[2].value == "("
1038             and self.leaves[3].type == token.RPAR
1039             and self.leaves[3].value == ")"
1040         )
1041
1042     @property
1043     def is_triple_quoted_string(self) -> bool:
1044         """Is the line a triple quoted string?"""
1045         return (
1046             bool(self)
1047             and self.leaves[0].type == token.STRING
1048             and self.leaves[0].value.startswith(('"""', "'''"))
1049         )
1050
1051     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
1052         """If so, needs to be split before emitting."""
1053         for leaf in self.leaves:
1054             if leaf.type == STANDALONE_COMMENT:
1055                 if leaf.bracket_depth <= depth_limit:
1056                     return True
1057
1058         return False
1059
1060     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1061         """Remove trailing comma if there is one and it's safe."""
1062         if not (
1063             self.leaves
1064             and self.leaves[-1].type == token.COMMA
1065             and closing.type in CLOSING_BRACKETS
1066         ):
1067             return False
1068
1069         if closing.type == token.RBRACE:
1070             self.remove_trailing_comma()
1071             return True
1072
1073         if closing.type == token.RSQB:
1074             comma = self.leaves[-1]
1075             if comma.parent and comma.parent.type == syms.listmaker:
1076                 self.remove_trailing_comma()
1077                 return True
1078
1079         # For parens let's check if it's safe to remove the comma.
1080         # Imports are always safe.
1081         if self.is_import:
1082             self.remove_trailing_comma()
1083             return True
1084
1085         # Otheriwsse, if the trailing one is the only one, we might mistakenly
1086         # change a tuple into a different type by removing the comma.
1087         depth = closing.bracket_depth + 1
1088         commas = 0
1089         opening = closing.opening_bracket
1090         for _opening_index, leaf in enumerate(self.leaves):
1091             if leaf is opening:
1092                 break
1093
1094         else:
1095             return False
1096
1097         for leaf in self.leaves[_opening_index + 1 :]:
1098             if leaf is closing:
1099                 break
1100
1101             bracket_depth = leaf.bracket_depth
1102             if bracket_depth == depth and leaf.type == token.COMMA:
1103                 commas += 1
1104                 if leaf.parent and leaf.parent.type == syms.arglist:
1105                     commas += 1
1106                     break
1107
1108         if commas > 1:
1109             self.remove_trailing_comma()
1110             return True
1111
1112         return False
1113
1114     def append_comment(self, comment: Leaf) -> bool:
1115         """Add an inline or standalone comment to the line."""
1116         if (
1117             comment.type == STANDALONE_COMMENT
1118             and self.bracket_tracker.any_open_brackets()
1119         ):
1120             comment.prefix = ""
1121             return False
1122
1123         if comment.type != token.COMMENT:
1124             return False
1125
1126         after = len(self.leaves) - 1
1127         if after == -1:
1128             comment.type = STANDALONE_COMMENT
1129             comment.prefix = ""
1130             return False
1131
1132         else:
1133             self.comments.append((after, comment))
1134             return True
1135
1136     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1137         """Generate comments that should appear directly after `leaf`.
1138
1139         Provide a non-negative leaf `_index` to speed up the function.
1140         """
1141         if _index == -1:
1142             for _index, _leaf in enumerate(self.leaves):
1143                 if leaf is _leaf:
1144                     break
1145
1146             else:
1147                 return
1148
1149         for index, comment_after in self.comments:
1150             if _index == index:
1151                 yield comment_after
1152
1153     def remove_trailing_comma(self) -> None:
1154         """Remove the trailing comma and moves the comments attached to it."""
1155         comma_index = len(self.leaves) - 1
1156         for i in range(len(self.comments)):
1157             comment_index, comment = self.comments[i]
1158             if comment_index == comma_index:
1159                 self.comments[i] = (comma_index - 1, comment)
1160         self.leaves.pop()
1161
1162     def is_complex_subscript(self, leaf: Leaf) -> bool:
1163         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1164         open_lsqb = (
1165             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1166         )
1167         if open_lsqb is None:
1168             return False
1169
1170         subscript_start = open_lsqb.next_sibling
1171         if (
1172             isinstance(subscript_start, Node)
1173             and subscript_start.type == syms.subscriptlist
1174         ):
1175             subscript_start = child_towards(subscript_start, leaf)
1176         return subscript_start is not None and any(
1177             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1178         )
1179
1180     def __str__(self) -> str:
1181         """Render the line."""
1182         if not self:
1183             return "\n"
1184
1185         indent = "    " * self.depth
1186         leaves = iter(self.leaves)
1187         first = next(leaves)
1188         res = f"{first.prefix}{indent}{first.value}"
1189         for leaf in leaves:
1190             res += str(leaf)
1191         for _, comment in self.comments:
1192             res += str(comment)
1193         return res + "\n"
1194
1195     def __bool__(self) -> bool:
1196         """Return True if the line has leaves or comments."""
1197         return bool(self.leaves or self.comments)
1198
1199
1200 class UnformattedLines(Line):
1201     """Just like :class:`Line` but stores lines which aren't reformatted."""
1202
1203     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1204         """Just add a new `leaf` to the end of the lines.
1205
1206         The `preformatted` argument is ignored.
1207
1208         Keeps track of indentation `depth`, which is useful when the user
1209         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1210         """
1211         try:
1212             list(generate_comments(leaf))
1213         except FormatOn as f_on:
1214             self.leaves.append(f_on.leaf_from_consumed(leaf))
1215             raise
1216
1217         self.leaves.append(leaf)
1218         if leaf.type == token.INDENT:
1219             self.depth += 1
1220         elif leaf.type == token.DEDENT:
1221             self.depth -= 1
1222
1223     def __str__(self) -> str:
1224         """Render unformatted lines from leaves which were added with `append()`.
1225
1226         `depth` is not used for indentation in this case.
1227         """
1228         if not self:
1229             return "\n"
1230
1231         res = ""
1232         for leaf in self.leaves:
1233             res += str(leaf)
1234         return res
1235
1236     def append_comment(self, comment: Leaf) -> bool:
1237         """Not implemented in this class. Raises `NotImplementedError`."""
1238         raise NotImplementedError("Unformatted lines don't store comments separately.")
1239
1240     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1241         """Does nothing and returns False."""
1242         return False
1243
1244     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1245         """Does nothing and returns False."""
1246         return False
1247
1248
1249 @dataclass
1250 class EmptyLineTracker:
1251     """Provides a stateful method that returns the number of potential extra
1252     empty lines needed before and after the currently processed line.
1253
1254     Note: this tracker works on lines that haven't been split yet.  It assumes
1255     the prefix of the first leaf consists of optional newlines.  Those newlines
1256     are consumed by `maybe_empty_lines()` and included in the computation.
1257     """
1258
1259     is_pyi: bool = False
1260     previous_line: Optional[Line] = None
1261     previous_after: int = 0
1262     previous_defs: List[int] = Factory(list)
1263
1264     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1265         """Return the number of extra empty lines before and after the `current_line`.
1266
1267         This is for separating `def`, `async def` and `class` with extra empty
1268         lines (two on module-level).
1269         """
1270         if isinstance(current_line, UnformattedLines):
1271             return 0, 0
1272
1273         before, after = self._maybe_empty_lines(current_line)
1274         before -= self.previous_after
1275         self.previous_after = after
1276         self.previous_line = current_line
1277         return before, after
1278
1279     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1280         max_allowed = 1
1281         if current_line.depth == 0:
1282             max_allowed = 1 if self.is_pyi else 2
1283         if current_line.leaves:
1284             # Consume the first leaf's extra newlines.
1285             first_leaf = current_line.leaves[0]
1286             before = first_leaf.prefix.count("\n")
1287             before = min(before, max_allowed)
1288             first_leaf.prefix = ""
1289         else:
1290             before = 0
1291         depth = current_line.depth
1292         while self.previous_defs and self.previous_defs[-1] >= depth:
1293             self.previous_defs.pop()
1294             if self.is_pyi:
1295                 before = 0 if depth else 1
1296             else:
1297                 before = 1 if depth else 2
1298         is_decorator = current_line.is_decorator
1299         if is_decorator or current_line.is_def or current_line.is_class:
1300             if not is_decorator:
1301                 self.previous_defs.append(depth)
1302             if self.previous_line is None:
1303                 # Don't insert empty lines before the first line in the file.
1304                 return 0, 0
1305
1306             if self.previous_line.is_decorator:
1307                 return 0, 0
1308
1309             if self.previous_line.depth < current_line.depth and (
1310                 self.previous_line.is_class or self.previous_line.is_def
1311             ):
1312                 return 0, 0
1313
1314             if (
1315                 self.previous_line.is_comment
1316                 and self.previous_line.depth == current_line.depth
1317                 and before == 0
1318             ):
1319                 return 0, 0
1320
1321             if self.is_pyi:
1322                 if self.previous_line.depth > current_line.depth:
1323                     newlines = 1
1324                 elif current_line.is_class or self.previous_line.is_class:
1325                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1326                         newlines = 0
1327                     else:
1328                         newlines = 1
1329                 else:
1330                     newlines = 0
1331             else:
1332                 newlines = 2
1333             if current_line.depth and newlines:
1334                 newlines -= 1
1335             return newlines, 0
1336
1337         if (
1338             self.previous_line
1339             and self.previous_line.is_import
1340             and not current_line.is_import
1341             and depth == self.previous_line.depth
1342         ):
1343             return (before or 1), 0
1344
1345         if (
1346             self.previous_line
1347             and self.previous_line.is_class
1348             and current_line.is_triple_quoted_string
1349         ):
1350             return before, 1
1351
1352         return before, 0
1353
1354
1355 @dataclass
1356 class LineGenerator(Visitor[Line]):
1357     """Generates reformatted Line objects.  Empty lines are not emitted.
1358
1359     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1360     in ways that will no longer stringify to valid Python code on the tree.
1361     """
1362
1363     is_pyi: bool = False
1364     normalize_strings: bool = True
1365     current_line: Line = Factory(Line)
1366     remove_u_prefix: bool = False
1367
1368     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1369         """Generate a line.
1370
1371         If the line is empty, only emit if it makes sense.
1372         If the line is too long, split it first and then generate.
1373
1374         If any lines were generated, set up a new current_line.
1375         """
1376         if not self.current_line:
1377             if self.current_line.__class__ == type:
1378                 self.current_line.depth += indent
1379             else:
1380                 self.current_line = type(depth=self.current_line.depth + indent)
1381             return  # Line is empty, don't emit. Creating a new one unnecessary.
1382
1383         complete_line = self.current_line
1384         self.current_line = type(depth=complete_line.depth + indent)
1385         yield complete_line
1386
1387     def visit(self, node: LN) -> Iterator[Line]:
1388         """Main method to visit `node` and its children.
1389
1390         Yields :class:`Line` objects.
1391         """
1392         if isinstance(self.current_line, UnformattedLines):
1393             # File contained `# fmt: off`
1394             yield from self.visit_unformatted(node)
1395
1396         else:
1397             yield from super().visit(node)
1398
1399     def visit_default(self, node: LN) -> Iterator[Line]:
1400         """Default `visit_*()` implementation. Recurses to children of `node`."""
1401         if isinstance(node, Leaf):
1402             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1403             try:
1404                 for comment in generate_comments(node):
1405                     if any_open_brackets:
1406                         # any comment within brackets is subject to splitting
1407                         self.current_line.append(comment)
1408                     elif comment.type == token.COMMENT:
1409                         # regular trailing comment
1410                         self.current_line.append(comment)
1411                         yield from self.line()
1412
1413                     else:
1414                         # regular standalone comment
1415                         yield from self.line()
1416
1417                         self.current_line.append(comment)
1418                         yield from self.line()
1419
1420             except FormatOff as f_off:
1421                 f_off.trim_prefix(node)
1422                 yield from self.line(type=UnformattedLines)
1423                 yield from self.visit(node)
1424
1425             except FormatOn as f_on:
1426                 # This only happens here if somebody says "fmt: on" multiple
1427                 # times in a row.
1428                 f_on.trim_prefix(node)
1429                 yield from self.visit_default(node)
1430
1431             else:
1432                 normalize_prefix(node, inside_brackets=any_open_brackets)
1433                 if self.normalize_strings and node.type == token.STRING:
1434                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1435                     normalize_string_quotes(node)
1436                 if node.type not in WHITESPACE:
1437                     self.current_line.append(node)
1438         yield from super().visit_default(node)
1439
1440     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1441         """Increase indentation level, maybe yield a line."""
1442         # In blib2to3 INDENT never holds comments.
1443         yield from self.line(+1)
1444         yield from self.visit_default(node)
1445
1446     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1447         """Decrease indentation level, maybe yield a line."""
1448         # The current line might still wait for trailing comments.  At DEDENT time
1449         # there won't be any (they would be prefixes on the preceding NEWLINE).
1450         # Emit the line then.
1451         yield from self.line()
1452
1453         # While DEDENT has no value, its prefix may contain standalone comments
1454         # that belong to the current indentation level.  Get 'em.
1455         yield from self.visit_default(node)
1456
1457         # Finally, emit the dedent.
1458         yield from self.line(-1)
1459
1460     def visit_stmt(
1461         self, node: Node, keywords: Set[str], parens: Set[str]
1462     ) -> Iterator[Line]:
1463         """Visit a statement.
1464
1465         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1466         `def`, `with`, `class`, `assert` and assignments.
1467
1468         The relevant Python language `keywords` for a given statement will be
1469         NAME leaves within it. This methods puts those on a separate line.
1470
1471         `parens` holds a set of string leaf values immediately after which
1472         invisible parens should be put.
1473         """
1474         normalize_invisible_parens(node, parens_after=parens)
1475         for child in node.children:
1476             if child.type == token.NAME and child.value in keywords:  # type: ignore
1477                 yield from self.line()
1478
1479             yield from self.visit(child)
1480
1481     def visit_suite(self, node: Node) -> Iterator[Line]:
1482         """Visit a suite."""
1483         if self.is_pyi and is_stub_suite(node):
1484             yield from self.visit(node.children[2])
1485         else:
1486             yield from self.visit_default(node)
1487
1488     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1489         """Visit a statement without nested statements."""
1490         is_suite_like = node.parent and node.parent.type in STATEMENT
1491         if is_suite_like:
1492             if self.is_pyi and is_stub_body(node):
1493                 yield from self.visit_default(node)
1494             else:
1495                 yield from self.line(+1)
1496                 yield from self.visit_default(node)
1497                 yield from self.line(-1)
1498
1499         else:
1500             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1501                 yield from self.line()
1502             yield from self.visit_default(node)
1503
1504     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1505         """Visit `async def`, `async for`, `async with`."""
1506         yield from self.line()
1507
1508         children = iter(node.children)
1509         for child in children:
1510             yield from self.visit(child)
1511
1512             if child.type == token.ASYNC:
1513                 break
1514
1515         internal_stmt = next(children)
1516         for child in internal_stmt.children:
1517             yield from self.visit(child)
1518
1519     def visit_decorators(self, node: Node) -> Iterator[Line]:
1520         """Visit decorators."""
1521         for child in node.children:
1522             yield from self.line()
1523             yield from self.visit(child)
1524
1525     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1526         """Remove a semicolon and put the other statement on a separate line."""
1527         yield from self.line()
1528
1529     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1530         """End of file. Process outstanding comments and end with a newline."""
1531         yield from self.visit_default(leaf)
1532         yield from self.line()
1533
1534     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1535         """Used when file contained a `# fmt: off`."""
1536         if isinstance(node, Node):
1537             for child in node.children:
1538                 yield from self.visit(child)
1539
1540         else:
1541             try:
1542                 self.current_line.append(node)
1543             except FormatOn as f_on:
1544                 f_on.trim_prefix(node)
1545                 yield from self.line()
1546                 yield from self.visit(node)
1547
1548             if node.type == token.ENDMARKER:
1549                 # somebody decided not to put a final `# fmt: on`
1550                 yield from self.line()
1551
1552     def __attrs_post_init__(self) -> None:
1553         """You are in a twisty little maze of passages."""
1554         v = self.visit_stmt
1555         Ø: Set[str] = set()
1556         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1557         self.visit_if_stmt = partial(
1558             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1559         )
1560         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1561         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1562         self.visit_try_stmt = partial(
1563             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1564         )
1565         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1566         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1567         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1568         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1569         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1570         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1571         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
1572         self.visit_async_funcdef = self.visit_async_stmt
1573         self.visit_decorated = self.visit_decorators
1574
1575
1576 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1577 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1578 OPENING_BRACKETS = set(BRACKET.keys())
1579 CLOSING_BRACKETS = set(BRACKET.values())
1580 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1581 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1582
1583
1584 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1585     """Return whitespace prefix if needed for the given `leaf`.
1586
1587     `complex_subscript` signals whether the given leaf is part of a subscription
1588     which has non-trivial arguments, like arithmetic expressions or function calls.
1589     """
1590     NO = ""
1591     SPACE = " "
1592     DOUBLESPACE = "  "
1593     t = leaf.type
1594     p = leaf.parent
1595     v = leaf.value
1596     if t in ALWAYS_NO_SPACE:
1597         return NO
1598
1599     if t == token.COMMENT:
1600         return DOUBLESPACE
1601
1602     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1603     if t == token.COLON and p.type not in {
1604         syms.subscript,
1605         syms.subscriptlist,
1606         syms.sliceop,
1607     }:
1608         return NO
1609
1610     prev = leaf.prev_sibling
1611     if not prev:
1612         prevp = preceding_leaf(p)
1613         if not prevp or prevp.type in OPENING_BRACKETS:
1614             return NO
1615
1616         if t == token.COLON:
1617             if prevp.type == token.COLON:
1618                 return NO
1619
1620             elif prevp.type != token.COMMA and not complex_subscript:
1621                 return NO
1622
1623             return SPACE
1624
1625         if prevp.type == token.EQUAL:
1626             if prevp.parent:
1627                 if prevp.parent.type in {
1628                     syms.arglist,
1629                     syms.argument,
1630                     syms.parameters,
1631                     syms.varargslist,
1632                 }:
1633                     return NO
1634
1635                 elif prevp.parent.type == syms.typedargslist:
1636                     # A bit hacky: if the equal sign has whitespace, it means we
1637                     # previously found it's a typed argument.  So, we're using
1638                     # that, too.
1639                     return prevp.prefix
1640
1641         elif prevp.type in STARS:
1642             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1643                 return NO
1644
1645         elif prevp.type == token.COLON:
1646             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1647                 return SPACE if complex_subscript else NO
1648
1649         elif (
1650             prevp.parent
1651             and prevp.parent.type == syms.factor
1652             and prevp.type in MATH_OPERATORS
1653         ):
1654             return NO
1655
1656         elif (
1657             prevp.type == token.RIGHTSHIFT
1658             and prevp.parent
1659             and prevp.parent.type == syms.shift_expr
1660             and prevp.prev_sibling
1661             and prevp.prev_sibling.type == token.NAME
1662             and prevp.prev_sibling.value == "print"  # type: ignore
1663         ):
1664             # Python 2 print chevron
1665             return NO
1666
1667     elif prev.type in OPENING_BRACKETS:
1668         return NO
1669
1670     if p.type in {syms.parameters, syms.arglist}:
1671         # untyped function signatures or calls
1672         if not prev or prev.type != token.COMMA:
1673             return NO
1674
1675     elif p.type == syms.varargslist:
1676         # lambdas
1677         if prev and prev.type != token.COMMA:
1678             return NO
1679
1680     elif p.type == syms.typedargslist:
1681         # typed function signatures
1682         if not prev:
1683             return NO
1684
1685         if t == token.EQUAL:
1686             if prev.type != syms.tname:
1687                 return NO
1688
1689         elif prev.type == token.EQUAL:
1690             # A bit hacky: if the equal sign has whitespace, it means we
1691             # previously found it's a typed argument.  So, we're using that, too.
1692             return prev.prefix
1693
1694         elif prev.type != token.COMMA:
1695             return NO
1696
1697     elif p.type == syms.tname:
1698         # type names
1699         if not prev:
1700             prevp = preceding_leaf(p)
1701             if not prevp or prevp.type != token.COMMA:
1702                 return NO
1703
1704     elif p.type == syms.trailer:
1705         # attributes and calls
1706         if t == token.LPAR or t == token.RPAR:
1707             return NO
1708
1709         if not prev:
1710             if t == token.DOT:
1711                 prevp = preceding_leaf(p)
1712                 if not prevp or prevp.type != token.NUMBER:
1713                     return NO
1714
1715             elif t == token.LSQB:
1716                 return NO
1717
1718         elif prev.type != token.COMMA:
1719             return NO
1720
1721     elif p.type == syms.argument:
1722         # single argument
1723         if t == token.EQUAL:
1724             return NO
1725
1726         if not prev:
1727             prevp = preceding_leaf(p)
1728             if not prevp or prevp.type == token.LPAR:
1729                 return NO
1730
1731         elif prev.type in {token.EQUAL} | STARS:
1732             return NO
1733
1734     elif p.type == syms.decorator:
1735         # decorators
1736         return NO
1737
1738     elif p.type == syms.dotted_name:
1739         if prev:
1740             return NO
1741
1742         prevp = preceding_leaf(p)
1743         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1744             return NO
1745
1746     elif p.type == syms.classdef:
1747         if t == token.LPAR:
1748             return NO
1749
1750         if prev and prev.type == token.LPAR:
1751             return NO
1752
1753     elif p.type in {syms.subscript, syms.sliceop}:
1754         # indexing
1755         if not prev:
1756             assert p.parent is not None, "subscripts are always parented"
1757             if p.parent.type == syms.subscriptlist:
1758                 return SPACE
1759
1760             return NO
1761
1762         elif not complex_subscript:
1763             return NO
1764
1765     elif p.type == syms.atom:
1766         if prev and t == token.DOT:
1767             # dots, but not the first one.
1768             return NO
1769
1770     elif p.type == syms.dictsetmaker:
1771         # dict unpacking
1772         if prev and prev.type == token.DOUBLESTAR:
1773             return NO
1774
1775     elif p.type in {syms.factor, syms.star_expr}:
1776         # unary ops
1777         if not prev:
1778             prevp = preceding_leaf(p)
1779             if not prevp or prevp.type in OPENING_BRACKETS:
1780                 return NO
1781
1782             prevp_parent = prevp.parent
1783             assert prevp_parent is not None
1784             if prevp.type == token.COLON and prevp_parent.type in {
1785                 syms.subscript,
1786                 syms.sliceop,
1787             }:
1788                 return NO
1789
1790             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1791                 return NO
1792
1793         elif t == token.NAME or t == token.NUMBER:
1794             return NO
1795
1796     elif p.type == syms.import_from:
1797         if t == token.DOT:
1798             if prev and prev.type == token.DOT:
1799                 return NO
1800
1801         elif t == token.NAME:
1802             if v == "import":
1803                 return SPACE
1804
1805             if prev and prev.type == token.DOT:
1806                 return NO
1807
1808     elif p.type == syms.sliceop:
1809         return NO
1810
1811     return SPACE
1812
1813
1814 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1815     """Return the first leaf that precedes `node`, if any."""
1816     while node:
1817         res = node.prev_sibling
1818         if res:
1819             if isinstance(res, Leaf):
1820                 return res
1821
1822             try:
1823                 return list(res.leaves())[-1]
1824
1825             except IndexError:
1826                 return None
1827
1828         node = node.parent
1829     return None
1830
1831
1832 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1833     """Return the child of `ancestor` that contains `descendant`."""
1834     node: Optional[LN] = descendant
1835     while node and node.parent != ancestor:
1836         node = node.parent
1837     return node
1838
1839
1840 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1841     """Return the priority of the `leaf` delimiter, given a line break after it.
1842
1843     The delimiter priorities returned here are from those delimiters that would
1844     cause a line break after themselves.
1845
1846     Higher numbers are higher priority.
1847     """
1848     if leaf.type == token.COMMA:
1849         return COMMA_PRIORITY
1850
1851     return 0
1852
1853
1854 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1855     """Return the priority of the `leaf` delimiter, given a line before after it.
1856
1857     The delimiter priorities returned here are from those delimiters that would
1858     cause a line break before themselves.
1859
1860     Higher numbers are higher priority.
1861     """
1862     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1863         # * and ** might also be MATH_OPERATORS but in this case they are not.
1864         # Don't treat them as a delimiter.
1865         return 0
1866
1867     if (
1868         leaf.type == token.DOT
1869         and leaf.parent
1870         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1871         and (previous is None or previous.type in CLOSING_BRACKETS)
1872     ):
1873         return DOT_PRIORITY
1874
1875     if (
1876         leaf.type in MATH_OPERATORS
1877         and leaf.parent
1878         and leaf.parent.type not in {syms.factor, syms.star_expr}
1879     ):
1880         return MATH_PRIORITIES[leaf.type]
1881
1882     if leaf.type in COMPARATORS:
1883         return COMPARATOR_PRIORITY
1884
1885     if (
1886         leaf.type == token.STRING
1887         and previous is not None
1888         and previous.type == token.STRING
1889     ):
1890         return STRING_PRIORITY
1891
1892     if leaf.type != token.NAME:
1893         return 0
1894
1895     if (
1896         leaf.value == "for"
1897         and leaf.parent
1898         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1899     ):
1900         return COMPREHENSION_PRIORITY
1901
1902     if (
1903         leaf.value == "if"
1904         and leaf.parent
1905         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1906     ):
1907         return COMPREHENSION_PRIORITY
1908
1909     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1910         return TERNARY_PRIORITY
1911
1912     if leaf.value == "is":
1913         return COMPARATOR_PRIORITY
1914
1915     if (
1916         leaf.value == "in"
1917         and leaf.parent
1918         and leaf.parent.type in {syms.comp_op, syms.comparison}
1919         and not (
1920             previous is not None
1921             and previous.type == token.NAME
1922             and previous.value == "not"
1923         )
1924     ):
1925         return COMPARATOR_PRIORITY
1926
1927     if (
1928         leaf.value == "not"
1929         and leaf.parent
1930         and leaf.parent.type == syms.comp_op
1931         and not (
1932             previous is not None
1933             and previous.type == token.NAME
1934             and previous.value == "is"
1935         )
1936     ):
1937         return COMPARATOR_PRIORITY
1938
1939     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1940         return LOGIC_PRIORITY
1941
1942     return 0
1943
1944
1945 def generate_comments(leaf: LN) -> Iterator[Leaf]:
1946     """Clean the prefix of the `leaf` and generate comments from it, if any.
1947
1948     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1949     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1950     move because it does away with modifying the grammar to include all the
1951     possible places in which comments can be placed.
1952
1953     The sad consequence for us though is that comments don't "belong" anywhere.
1954     This is why this function generates simple parentless Leaf objects for
1955     comments.  We simply don't know what the correct parent should be.
1956
1957     No matter though, we can live without this.  We really only need to
1958     differentiate between inline and standalone comments.  The latter don't
1959     share the line with any code.
1960
1961     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1962     are emitted with a fake STANDALONE_COMMENT token identifier.
1963     """
1964     p = leaf.prefix
1965     if not p:
1966         return
1967
1968     if "#" not in p:
1969         return
1970
1971     consumed = 0
1972     nlines = 0
1973     for index, line in enumerate(p.split("\n")):
1974         consumed += len(line) + 1  # adding the length of the split '\n'
1975         line = line.lstrip()
1976         if not line:
1977             nlines += 1
1978         if not line.startswith("#"):
1979             continue
1980
1981         if index == 0 and leaf.type != token.ENDMARKER:
1982             comment_type = token.COMMENT  # simple trailing comment
1983         else:
1984             comment_type = STANDALONE_COMMENT
1985         comment = make_comment(line)
1986         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1987
1988         if comment in {"# fmt: on", "# yapf: enable"}:
1989             raise FormatOn(consumed)
1990
1991         if comment in {"# fmt: off", "# yapf: disable"}:
1992             if comment_type == STANDALONE_COMMENT:
1993                 raise FormatOff(consumed)
1994
1995             prev = preceding_leaf(leaf)
1996             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1997                 raise FormatOff(consumed)
1998
1999         nlines = 0
2000
2001
2002 def make_comment(content: str) -> str:
2003     """Return a consistently formatted comment from the given `content` string.
2004
2005     All comments (except for "##", "#!", "#:") should have a single space between
2006     the hash sign and the content.
2007
2008     If `content` didn't start with a hash sign, one is provided.
2009     """
2010     content = content.rstrip()
2011     if not content:
2012         return "#"
2013
2014     if content[0] == "#":
2015         content = content[1:]
2016     if content and content[0] not in " !:#":
2017         content = " " + content
2018     return "#" + content
2019
2020
2021 def split_line(
2022     line: Line, line_length: int, inner: bool = False, py36: bool = False
2023 ) -> Iterator[Line]:
2024     """Split a `line` into potentially many lines.
2025
2026     They should fit in the allotted `line_length` but might not be able to.
2027     `inner` signifies that there were a pair of brackets somewhere around the
2028     current `line`, possibly transitively. This means we can fallback to splitting
2029     by delimiters if the LHS/RHS don't yield any results.
2030
2031     If `py36` is True, splitting may generate syntax that is only compatible
2032     with Python 3.6 and later.
2033     """
2034     if isinstance(line, UnformattedLines) or line.is_comment:
2035         yield line
2036         return
2037
2038     line_str = str(line).strip("\n")
2039     if not line.should_explode and is_line_short_enough(
2040         line, line_length=line_length, line_str=line_str
2041     ):
2042         yield line
2043         return
2044
2045     split_funcs: List[SplitFunc]
2046     if line.is_def:
2047         split_funcs = [left_hand_split]
2048     else:
2049
2050         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
2051             for omit in generate_trailers_to_omit(line, line_length):
2052                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
2053                 if is_line_short_enough(lines[0], line_length=line_length):
2054                     yield from lines
2055                     return
2056
2057             # All splits failed, best effort split with no omits.
2058             # This mostly happens to multiline strings that are by definition
2059             # reported as not fitting a single line.
2060             yield from right_hand_split(line, py36)
2061
2062         if line.inside_brackets:
2063             split_funcs = [delimiter_split, standalone_comment_split, rhs]
2064         else:
2065             split_funcs = [rhs]
2066     for split_func in split_funcs:
2067         # We are accumulating lines in `result` because we might want to abort
2068         # mission and return the original line in the end, or attempt a different
2069         # split altogether.
2070         result: List[Line] = []
2071         try:
2072             for l in split_func(line, py36):
2073                 if str(l).strip("\n") == line_str:
2074                     raise CannotSplit("Split function returned an unchanged result")
2075
2076                 result.extend(
2077                     split_line(l, line_length=line_length, inner=True, py36=py36)
2078                 )
2079         except CannotSplit as cs:
2080             continue
2081
2082         else:
2083             yield from result
2084             break
2085
2086     else:
2087         yield line
2088
2089
2090 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
2091     """Split line into many lines, starting with the first matching bracket pair.
2092
2093     Note: this usually looks weird, only use this for function definitions.
2094     Prefer RHS otherwise.  This is why this function is not symmetrical with
2095     :func:`right_hand_split` which also handles optional parentheses.
2096     """
2097     head = Line(depth=line.depth)
2098     body = Line(depth=line.depth + 1, inside_brackets=True)
2099     tail = Line(depth=line.depth)
2100     tail_leaves: List[Leaf] = []
2101     body_leaves: List[Leaf] = []
2102     head_leaves: List[Leaf] = []
2103     current_leaves = head_leaves
2104     matching_bracket = None
2105     for leaf in line.leaves:
2106         if (
2107             current_leaves is body_leaves
2108             and leaf.type in CLOSING_BRACKETS
2109             and leaf.opening_bracket is matching_bracket
2110         ):
2111             current_leaves = tail_leaves if body_leaves else head_leaves
2112         current_leaves.append(leaf)
2113         if current_leaves is head_leaves:
2114             if leaf.type in OPENING_BRACKETS:
2115                 matching_bracket = leaf
2116                 current_leaves = body_leaves
2117     # Since body is a new indent level, remove spurious leading whitespace.
2118     if body_leaves:
2119         normalize_prefix(body_leaves[0], inside_brackets=True)
2120     # Build the new lines.
2121     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2122         for leaf in leaves:
2123             result.append(leaf, preformatted=True)
2124             for comment_after in line.comments_after(leaf):
2125                 result.append(comment_after, preformatted=True)
2126     bracket_split_succeeded_or_raise(head, body, tail)
2127     for result in (head, body, tail):
2128         if result:
2129             yield result
2130
2131
2132 def right_hand_split(
2133     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2134 ) -> Iterator[Line]:
2135     """Split line into many lines, starting with the last matching bracket pair.
2136
2137     If the split was by optional parentheses, attempt splitting without them, too.
2138     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2139     this split.
2140
2141     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2142     """
2143     head = Line(depth=line.depth)
2144     body = Line(depth=line.depth + 1, inside_brackets=True)
2145     tail = Line(depth=line.depth)
2146     tail_leaves: List[Leaf] = []
2147     body_leaves: List[Leaf] = []
2148     head_leaves: List[Leaf] = []
2149     current_leaves = tail_leaves
2150     opening_bracket = None
2151     closing_bracket = None
2152     for leaf in reversed(line.leaves):
2153         if current_leaves is body_leaves:
2154             if leaf is opening_bracket:
2155                 current_leaves = head_leaves if body_leaves else tail_leaves
2156         current_leaves.append(leaf)
2157         if current_leaves is tail_leaves:
2158             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2159                 opening_bracket = leaf.opening_bracket
2160                 closing_bracket = leaf
2161                 current_leaves = body_leaves
2162     tail_leaves.reverse()
2163     body_leaves.reverse()
2164     head_leaves.reverse()
2165     # Since body is a new indent level, remove spurious leading whitespace.
2166     if body_leaves:
2167         normalize_prefix(body_leaves[0], inside_brackets=True)
2168     if not head_leaves:
2169         # No `head` means the split failed. Either `tail` has all content or
2170         # the matching `opening_bracket` wasn't available on `line` anymore.
2171         raise CannotSplit("No brackets found")
2172
2173     # Build the new lines.
2174     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2175         for leaf in leaves:
2176             result.append(leaf, preformatted=True)
2177             for comment_after in line.comments_after(leaf):
2178                 result.append(comment_after, preformatted=True)
2179     bracket_split_succeeded_or_raise(head, body, tail)
2180     assert opening_bracket and closing_bracket
2181     if (
2182         # the opening bracket is an optional paren
2183         opening_bracket.type == token.LPAR
2184         and not opening_bracket.value
2185         # the closing bracket is an optional paren
2186         and closing_bracket.type == token.RPAR
2187         and not closing_bracket.value
2188         # there are no standalone comments in the body
2189         and not line.contains_standalone_comments(0)
2190         # and it's not an import (optional parens are the only thing we can split
2191         # on in this case; attempting a split without them is a waste of time)
2192         and not line.is_import
2193     ):
2194         omit = {id(closing_bracket), *omit}
2195         if can_omit_invisible_parens(body, line_length):
2196             try:
2197                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2198                 return
2199             except CannotSplit:
2200                 pass
2201
2202     ensure_visible(opening_bracket)
2203     ensure_visible(closing_bracket)
2204     body.should_explode = should_explode(body, opening_bracket)
2205     for result in (head, body, tail):
2206         if result:
2207             yield result
2208
2209
2210 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2211     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2212
2213     Do nothing otherwise.
2214
2215     A left- or right-hand split is based on a pair of brackets. Content before
2216     (and including) the opening bracket is left on one line, content inside the
2217     brackets is put on a separate line, and finally content starting with and
2218     following the closing bracket is put on a separate line.
2219
2220     Those are called `head`, `body`, and `tail`, respectively. If the split
2221     produced the same line (all content in `head`) or ended up with an empty `body`
2222     and the `tail` is just the closing bracket, then it's considered failed.
2223     """
2224     tail_len = len(str(tail).strip())
2225     if not body:
2226         if tail_len == 0:
2227             raise CannotSplit("Splitting brackets produced the same line")
2228
2229         elif tail_len < 3:
2230             raise CannotSplit(
2231                 f"Splitting brackets on an empty body to save "
2232                 f"{tail_len} characters is not worth it"
2233             )
2234
2235
2236 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2237     """Normalize prefix of the first leaf in every line returned by `split_func`.
2238
2239     This is a decorator over relevant split functions.
2240     """
2241
2242     @wraps(split_func)
2243     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2244         for l in split_func(line, py36):
2245             normalize_prefix(l.leaves[0], inside_brackets=True)
2246             yield l
2247
2248     return split_wrapper
2249
2250
2251 @dont_increase_indentation
2252 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2253     """Split according to delimiters of the highest priority.
2254
2255     If `py36` is True, the split will add trailing commas also in function
2256     signatures that contain `*` and `**`.
2257     """
2258     try:
2259         last_leaf = line.leaves[-1]
2260     except IndexError:
2261         raise CannotSplit("Line empty")
2262
2263     bt = line.bracket_tracker
2264     try:
2265         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2266     except ValueError:
2267         raise CannotSplit("No delimiters found")
2268
2269     if delimiter_priority == DOT_PRIORITY:
2270         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2271             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2272
2273     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2274     lowest_depth = sys.maxsize
2275     trailing_comma_safe = True
2276
2277     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2278         """Append `leaf` to current line or to new line if appending impossible."""
2279         nonlocal current_line
2280         try:
2281             current_line.append_safe(leaf, preformatted=True)
2282         except ValueError as ve:
2283             yield current_line
2284
2285             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2286             current_line.append(leaf)
2287
2288     for index, leaf in enumerate(line.leaves):
2289         yield from append_to_line(leaf)
2290
2291         for comment_after in line.comments_after(leaf, index):
2292             yield from append_to_line(comment_after)
2293
2294         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2295         if leaf.bracket_depth == lowest_depth and is_vararg(
2296             leaf, within=VARARGS_PARENTS
2297         ):
2298             trailing_comma_safe = trailing_comma_safe and py36
2299         leaf_priority = bt.delimiters.get(id(leaf))
2300         if leaf_priority == delimiter_priority:
2301             yield current_line
2302
2303             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2304     if current_line:
2305         if (
2306             trailing_comma_safe
2307             and delimiter_priority == COMMA_PRIORITY
2308             and current_line.leaves[-1].type != token.COMMA
2309             and current_line.leaves[-1].type != STANDALONE_COMMENT
2310         ):
2311             current_line.append(Leaf(token.COMMA, ","))
2312         yield current_line
2313
2314
2315 @dont_increase_indentation
2316 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2317     """Split standalone comments from the rest of the line."""
2318     if not line.contains_standalone_comments(0):
2319         raise CannotSplit("Line does not have any standalone comments")
2320
2321     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2322
2323     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2324         """Append `leaf` to current line or to new line if appending impossible."""
2325         nonlocal current_line
2326         try:
2327             current_line.append_safe(leaf, preformatted=True)
2328         except ValueError as ve:
2329             yield current_line
2330
2331             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2332             current_line.append(leaf)
2333
2334     for index, leaf in enumerate(line.leaves):
2335         yield from append_to_line(leaf)
2336
2337         for comment_after in line.comments_after(leaf, index):
2338             yield from append_to_line(comment_after)
2339
2340     if current_line:
2341         yield current_line
2342
2343
2344 def is_import(leaf: Leaf) -> bool:
2345     """Return True if the given leaf starts an import statement."""
2346     p = leaf.parent
2347     t = leaf.type
2348     v = leaf.value
2349     return bool(
2350         t == token.NAME
2351         and (
2352             (v == "import" and p and p.type == syms.import_name)
2353             or (v == "from" and p and p.type == syms.import_from)
2354         )
2355     )
2356
2357
2358 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2359     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2360     else.
2361
2362     Note: don't use backslashes for formatting or you'll lose your voting rights.
2363     """
2364     if not inside_brackets:
2365         spl = leaf.prefix.split("#")
2366         if "\\" not in spl[0]:
2367             nl_count = spl[-1].count("\n")
2368             if len(spl) > 1:
2369                 nl_count -= 1
2370             leaf.prefix = "\n" * nl_count
2371             return
2372
2373     leaf.prefix = ""
2374
2375
2376 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2377     """Make all string prefixes lowercase.
2378
2379     If remove_u_prefix is given, also removes any u prefix from the string.
2380
2381     Note: Mutates its argument.
2382     """
2383     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2384     assert match is not None, f"failed to match string {leaf.value!r}"
2385     orig_prefix = match.group(1)
2386     new_prefix = orig_prefix.lower()
2387     if remove_u_prefix:
2388         new_prefix = new_prefix.replace("u", "")
2389     leaf.value = f"{new_prefix}{match.group(2)}"
2390
2391
2392 def normalize_string_quotes(leaf: Leaf) -> None:
2393     """Prefer double quotes but only if it doesn't cause more escaping.
2394
2395     Adds or removes backslashes as appropriate. Doesn't parse and fix
2396     strings nested in f-strings (yet).
2397
2398     Note: Mutates its argument.
2399     """
2400     value = leaf.value.lstrip("furbFURB")
2401     if value[:3] == '"""':
2402         return
2403
2404     elif value[:3] == "'''":
2405         orig_quote = "'''"
2406         new_quote = '"""'
2407     elif value[0] == '"':
2408         orig_quote = '"'
2409         new_quote = "'"
2410     else:
2411         orig_quote = "'"
2412         new_quote = '"'
2413     first_quote_pos = leaf.value.find(orig_quote)
2414     if first_quote_pos == -1:
2415         return  # There's an internal error
2416
2417     prefix = leaf.value[:first_quote_pos]
2418     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2419     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2420     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2421     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2422     if "r" in prefix.casefold():
2423         if unescaped_new_quote.search(body):
2424             # There's at least one unescaped new_quote in this raw string
2425             # so converting is impossible
2426             return
2427
2428         # Do not introduce or remove backslashes in raw strings
2429         new_body = body
2430     else:
2431         # remove unnecessary quotes
2432         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2433         if body != new_body:
2434             # Consider the string without unnecessary quotes as the original
2435             body = new_body
2436             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2437         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2438         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2439     if new_quote == '"""' and new_body[-1] == '"':
2440         # edge case:
2441         new_body = new_body[:-1] + '\\"'
2442     orig_escape_count = body.count("\\")
2443     new_escape_count = new_body.count("\\")
2444     if new_escape_count > orig_escape_count:
2445         return  # Do not introduce more escaping
2446
2447     if new_escape_count == orig_escape_count and orig_quote == '"':
2448         return  # Prefer double quotes
2449
2450     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2451
2452
2453 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2454     """Make existing optional parentheses invisible or create new ones.
2455
2456     `parens_after` is a set of string leaf values immeditely after which parens
2457     should be put.
2458
2459     Standardizes on visible parentheses for single-element tuples, and keeps
2460     existing visible parentheses for other tuples and generator expressions.
2461     """
2462     try:
2463         list(generate_comments(node))
2464     except FormatOff:
2465         return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
2466
2467     check_lpar = False
2468     for index, child in enumerate(list(node.children)):
2469         if check_lpar:
2470             if child.type == syms.atom:
2471                 maybe_make_parens_invisible_in_atom(child)
2472             elif is_one_tuple(child):
2473                 # wrap child in visible parentheses
2474                 lpar = Leaf(token.LPAR, "(")
2475                 rpar = Leaf(token.RPAR, ")")
2476                 child.remove()
2477                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2478             elif node.type == syms.import_from:
2479                 # "import from" nodes store parentheses directly as part of
2480                 # the statement
2481                 if child.type == token.LPAR:
2482                     # make parentheses invisible
2483                     child.value = ""  # type: ignore
2484                     node.children[-1].value = ""  # type: ignore
2485                 elif child.type != token.STAR:
2486                     # insert invisible parentheses
2487                     node.insert_child(index, Leaf(token.LPAR, ""))
2488                     node.append_child(Leaf(token.RPAR, ""))
2489                 break
2490
2491             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2492                 # wrap child in invisible parentheses
2493                 lpar = Leaf(token.LPAR, "")
2494                 rpar = Leaf(token.RPAR, "")
2495                 index = child.remove() or 0
2496                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2497
2498         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2499
2500
2501 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2502     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2503     if (
2504         node.type != syms.atom
2505         or is_empty_tuple(node)
2506         or is_one_tuple(node)
2507         or is_yield(node)
2508         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2509     ):
2510         return False
2511
2512     first = node.children[0]
2513     last = node.children[-1]
2514     if first.type == token.LPAR and last.type == token.RPAR:
2515         # make parentheses invisible
2516         first.value = ""  # type: ignore
2517         last.value = ""  # type: ignore
2518         if len(node.children) > 1:
2519             maybe_make_parens_invisible_in_atom(node.children[1])
2520         return True
2521
2522     return False
2523
2524
2525 def is_empty_tuple(node: LN) -> bool:
2526     """Return True if `node` holds an empty tuple."""
2527     return (
2528         node.type == syms.atom
2529         and len(node.children) == 2
2530         and node.children[0].type == token.LPAR
2531         and node.children[1].type == token.RPAR
2532     )
2533
2534
2535 def is_one_tuple(node: LN) -> bool:
2536     """Return True if `node` holds a tuple with one element, with or without parens."""
2537     if node.type == syms.atom:
2538         if len(node.children) != 3:
2539             return False
2540
2541         lpar, gexp, rpar = node.children
2542         if not (
2543             lpar.type == token.LPAR
2544             and gexp.type == syms.testlist_gexp
2545             and rpar.type == token.RPAR
2546         ):
2547             return False
2548
2549         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2550
2551     return (
2552         node.type in IMPLICIT_TUPLE
2553         and len(node.children) == 2
2554         and node.children[1].type == token.COMMA
2555     )
2556
2557
2558 def is_yield(node: LN) -> bool:
2559     """Return True if `node` holds a `yield` or `yield from` expression."""
2560     if node.type == syms.yield_expr:
2561         return True
2562
2563     if node.type == token.NAME and node.value == "yield":  # type: ignore
2564         return True
2565
2566     if node.type != syms.atom:
2567         return False
2568
2569     if len(node.children) != 3:
2570         return False
2571
2572     lpar, expr, rpar = node.children
2573     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2574         return is_yield(expr)
2575
2576     return False
2577
2578
2579 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2580     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2581
2582     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2583     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2584     extended iterable unpacking (PEP 3132) and additional unpacking
2585     generalizations (PEP 448).
2586     """
2587     if leaf.type not in STARS or not leaf.parent:
2588         return False
2589
2590     p = leaf.parent
2591     if p.type == syms.star_expr:
2592         # Star expressions are also used as assignment targets in extended
2593         # iterable unpacking (PEP 3132).  See what its parent is instead.
2594         if not p.parent:
2595             return False
2596
2597         p = p.parent
2598
2599     return p.type in within
2600
2601
2602 def is_multiline_string(leaf: Leaf) -> bool:
2603     """Return True if `leaf` is a multiline string that actually spans many lines."""
2604     value = leaf.value.lstrip("furbFURB")
2605     return value[:3] in {'"""', "'''"} and "\n" in value
2606
2607
2608 def is_stub_suite(node: Node) -> bool:
2609     """Return True if `node` is a suite with a stub body."""
2610     if (
2611         len(node.children) != 4
2612         or node.children[0].type != token.NEWLINE
2613         or node.children[1].type != token.INDENT
2614         or node.children[3].type != token.DEDENT
2615     ):
2616         return False
2617
2618     return is_stub_body(node.children[2])
2619
2620
2621 def is_stub_body(node: LN) -> bool:
2622     """Return True if `node` is a simple statement containing an ellipsis."""
2623     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2624         return False
2625
2626     if len(node.children) != 2:
2627         return False
2628
2629     child = node.children[0]
2630     return (
2631         child.type == syms.atom
2632         and len(child.children) == 3
2633         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2634     )
2635
2636
2637 def max_delimiter_priority_in_atom(node: LN) -> int:
2638     """Return maximum delimiter priority inside `node`.
2639
2640     This is specific to atoms with contents contained in a pair of parentheses.
2641     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2642     """
2643     if node.type != syms.atom:
2644         return 0
2645
2646     first = node.children[0]
2647     last = node.children[-1]
2648     if not (first.type == token.LPAR and last.type == token.RPAR):
2649         return 0
2650
2651     bt = BracketTracker()
2652     for c in node.children[1:-1]:
2653         if isinstance(c, Leaf):
2654             bt.mark(c)
2655         else:
2656             for leaf in c.leaves():
2657                 bt.mark(leaf)
2658     try:
2659         return bt.max_delimiter_priority()
2660
2661     except ValueError:
2662         return 0
2663
2664
2665 def ensure_visible(leaf: Leaf) -> None:
2666     """Make sure parentheses are visible.
2667
2668     They could be invisible as part of some statements (see
2669     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2670     """
2671     if leaf.type == token.LPAR:
2672         leaf.value = "("
2673     elif leaf.type == token.RPAR:
2674         leaf.value = ")"
2675
2676
2677 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2678     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2679     if not (
2680         opening_bracket.parent
2681         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2682         and opening_bracket.value in "[{("
2683     ):
2684         return False
2685
2686     try:
2687         last_leaf = line.leaves[-1]
2688         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2689         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2690     except (IndexError, ValueError):
2691         return False
2692
2693     return max_priority == COMMA_PRIORITY
2694
2695
2696 def is_python36(node: Node) -> bool:
2697     """Return True if the current file is using Python 3.6+ features.
2698
2699     Currently looking for:
2700     - f-strings; and
2701     - trailing commas after * or ** in function signatures and calls.
2702     """
2703     for n in node.pre_order():
2704         if n.type == token.STRING:
2705             value_head = n.value[:2]  # type: ignore
2706             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2707                 return True
2708
2709         elif (
2710             n.type in {syms.typedargslist, syms.arglist}
2711             and n.children
2712             and n.children[-1].type == token.COMMA
2713         ):
2714             for ch in n.children:
2715                 if ch.type in STARS:
2716                     return True
2717
2718                 if ch.type == syms.argument:
2719                     for argch in ch.children:
2720                         if argch.type in STARS:
2721                             return True
2722
2723     return False
2724
2725
2726 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2727     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2728
2729     Brackets can be omitted if the entire trailer up to and including
2730     a preceding closing bracket fits in one line.
2731
2732     Yielded sets are cumulative (contain results of previous yields, too).  First
2733     set is empty.
2734     """
2735
2736     omit: Set[LeafID] = set()
2737     yield omit
2738
2739     length = 4 * line.depth
2740     opening_bracket = None
2741     closing_bracket = None
2742     optional_brackets: Set[LeafID] = set()
2743     inner_brackets: Set[LeafID] = set()
2744     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2745         length += leaf_length
2746         if length > line_length:
2747             break
2748
2749         has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
2750         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
2751             break
2752
2753         optional_brackets.discard(id(leaf))
2754         if opening_bracket:
2755             if leaf is opening_bracket:
2756                 opening_bracket = None
2757             elif leaf.type in CLOSING_BRACKETS:
2758                 inner_brackets.add(id(leaf))
2759         elif leaf.type in CLOSING_BRACKETS:
2760             if not leaf.value:
2761                 optional_brackets.add(id(opening_bracket))
2762                 continue
2763
2764             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2765                 # Empty brackets would fail a split so treat them as "inner"
2766                 # brackets (e.g. only add them to the `omit` set if another
2767                 # pair of brackets was good enough.
2768                 inner_brackets.add(id(leaf))
2769                 continue
2770
2771             opening_bracket = leaf.opening_bracket
2772             if closing_bracket:
2773                 omit.add(id(closing_bracket))
2774                 omit.update(inner_brackets)
2775                 inner_brackets.clear()
2776                 yield omit
2777             closing_bracket = leaf
2778
2779
2780 def get_future_imports(node: Node) -> Set[str]:
2781     """Return a set of __future__ imports in the file."""
2782     imports = set()
2783     for child in node.children:
2784         if child.type != syms.simple_stmt:
2785             break
2786         first_child = child.children[0]
2787         if isinstance(first_child, Leaf):
2788             # Continue looking if we see a docstring; otherwise stop.
2789             if (
2790                 len(child.children) == 2
2791                 and first_child.type == token.STRING
2792                 and child.children[1].type == token.NEWLINE
2793             ):
2794                 continue
2795             else:
2796                 break
2797         elif first_child.type == syms.import_from:
2798             module_name = first_child.children[1]
2799             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2800                 break
2801             for import_from_child in first_child.children[3:]:
2802                 if isinstance(import_from_child, Leaf):
2803                     if import_from_child.type == token.NAME:
2804                         imports.add(import_from_child.value)
2805                 else:
2806                     assert import_from_child.type == syms.import_as_names
2807                     for leaf in import_from_child.children:
2808                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2809                             imports.add(leaf.value)
2810         else:
2811             break
2812     return imports
2813
2814
2815 def gen_python_files_in_dir(
2816     path: Path,
2817     root: Path,
2818     include: Pattern[str],
2819     exclude: Pattern[str],
2820     report: "Report",
2821 ) -> Iterator[Path]:
2822     """Generate all files under `path` whose paths are not excluded by the
2823     `exclude` regex, but are included by the `include` regex.
2824
2825     `report` is where output about exclusions goes.
2826     """
2827     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
2828     for child in path.iterdir():
2829         normalized_path = child.resolve().relative_to(root).as_posix()
2830         if child.is_dir():
2831             normalized_path += "/"
2832         exclude_match = exclude.search(normalized_path)
2833         if exclude_match and exclude_match.group(0):
2834             report.path_ignored(child, f"matches --exclude={exclude.pattern}")
2835             continue
2836
2837         if child.is_dir():
2838             yield from gen_python_files_in_dir(child, root, include, exclude, report)
2839
2840         elif child.is_file():
2841             include_match = include.search(normalized_path)
2842             if include_match:
2843                 yield child
2844
2845
2846 def find_project_root(srcs: List[str]) -> Path:
2847     """Return a directory containing .git, .hg, or pyproject.toml.
2848
2849     That directory can be one of the directories passed in `srcs` or their
2850     common parent.
2851
2852     If no directory in the tree contains a marker that would specify it's the
2853     project root, the root of the file system is returned.
2854     """
2855     if not srcs:
2856         return Path("/").resolve()
2857
2858     common_base = min(Path(src).resolve() for src in srcs)
2859     if common_base.is_dir():
2860         # Append a fake file so `parents` below returns `common_base_dir`, too.
2861         common_base /= "fake-file"
2862     for directory in common_base.parents:
2863         if (directory / ".git").is_dir():
2864             return directory
2865
2866         if (directory / ".hg").is_dir():
2867             return directory
2868
2869         if (directory / "pyproject.toml").is_file():
2870             return directory
2871
2872     return directory
2873
2874
2875 @dataclass
2876 class Report:
2877     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2878
2879     check: bool = False
2880     quiet: bool = False
2881     verbose: bool = False
2882     change_count: int = 0
2883     same_count: int = 0
2884     failure_count: int = 0
2885
2886     def done(self, src: Path, changed: Changed) -> None:
2887         """Increment the counter for successful reformatting. Write out a message."""
2888         if changed is Changed.YES:
2889             reformatted = "would reformat" if self.check else "reformatted"
2890             if self.verbose or not self.quiet:
2891                 out(f"{reformatted} {src}")
2892             self.change_count += 1
2893         else:
2894             if self.verbose:
2895                 if changed is Changed.NO:
2896                     msg = f"{src} already well formatted, good job."
2897                 else:
2898                     msg = f"{src} wasn't modified on disk since last run."
2899                 out(msg, bold=False)
2900             self.same_count += 1
2901
2902     def failed(self, src: Path, message: str) -> None:
2903         """Increment the counter for failed reformatting. Write out a message."""
2904         err(f"error: cannot format {src}: {message}")
2905         self.failure_count += 1
2906
2907     def path_ignored(self, path: Path, message: str) -> None:
2908         if self.verbose:
2909             out(f"{path} ignored: {message}", bold=False)
2910
2911     @property
2912     def return_code(self) -> int:
2913         """Return the exit code that the app should use.
2914
2915         This considers the current state of changed files and failures:
2916         - if there were any failures, return 123;
2917         - if any files were changed and --check is being used, return 1;
2918         - otherwise return 0.
2919         """
2920         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2921         # 126 we have special returncodes reserved by the shell.
2922         if self.failure_count:
2923             return 123
2924
2925         elif self.change_count and self.check:
2926             return 1
2927
2928         return 0
2929
2930     def __str__(self) -> str:
2931         """Render a color report of the current state.
2932
2933         Use `click.unstyle` to remove colors.
2934         """
2935         if self.check:
2936             reformatted = "would be reformatted"
2937             unchanged = "would be left unchanged"
2938             failed = "would fail to reformat"
2939         else:
2940             reformatted = "reformatted"
2941             unchanged = "left unchanged"
2942             failed = "failed to reformat"
2943         report = []
2944         if self.change_count:
2945             s = "s" if self.change_count > 1 else ""
2946             report.append(
2947                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2948             )
2949         if self.same_count:
2950             s = "s" if self.same_count > 1 else ""
2951             report.append(f"{self.same_count} file{s} {unchanged}")
2952         if self.failure_count:
2953             s = "s" if self.failure_count > 1 else ""
2954             report.append(
2955                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2956             )
2957         return ", ".join(report) + "."
2958
2959
2960 def assert_equivalent(src: str, dst: str) -> None:
2961     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2962
2963     import ast
2964     import traceback
2965
2966     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2967         """Simple visitor generating strings to compare ASTs by content."""
2968         yield f"{'  ' * depth}{node.__class__.__name__}("
2969
2970         for field in sorted(node._fields):
2971             try:
2972                 value = getattr(node, field)
2973             except AttributeError:
2974                 continue
2975
2976             yield f"{'  ' * (depth+1)}{field}="
2977
2978             if isinstance(value, list):
2979                 for item in value:
2980                     if isinstance(item, ast.AST):
2981                         yield from _v(item, depth + 2)
2982
2983             elif isinstance(value, ast.AST):
2984                 yield from _v(value, depth + 2)
2985
2986             else:
2987                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2988
2989         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2990
2991     try:
2992         src_ast = ast.parse(src)
2993     except Exception as exc:
2994         major, minor = sys.version_info[:2]
2995         raise AssertionError(
2996             f"cannot use --safe with this file; failed to parse source file "
2997             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2998             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2999         )
3000
3001     try:
3002         dst_ast = ast.parse(dst)
3003     except Exception as exc:
3004         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
3005         raise AssertionError(
3006             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
3007             f"Please report a bug on https://github.com/ambv/black/issues.  "
3008             f"This invalid output might be helpful: {log}"
3009         ) from None
3010
3011     src_ast_str = "\n".join(_v(src_ast))
3012     dst_ast_str = "\n".join(_v(dst_ast))
3013     if src_ast_str != dst_ast_str:
3014         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
3015         raise AssertionError(
3016             f"INTERNAL ERROR: Black produced code that is not equivalent to "
3017             f"the source.  "
3018             f"Please report a bug on https://github.com/ambv/black/issues.  "
3019             f"This diff might be helpful: {log}"
3020         ) from None
3021
3022
3023 def assert_stable(
3024     src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
3025 ) -> None:
3026     """Raise AssertionError if `dst` reformats differently the second time."""
3027     newdst = format_str(dst, line_length=line_length, mode=mode)
3028     if dst != newdst:
3029         log = dump_to_file(
3030             diff(src, dst, "source", "first pass"),
3031             diff(dst, newdst, "first pass", "second pass"),
3032         )
3033         raise AssertionError(
3034             f"INTERNAL ERROR: Black produced different code on the second pass "
3035             f"of the formatter.  "
3036             f"Please report a bug on https://github.com/ambv/black/issues.  "
3037             f"This diff might be helpful: {log}"
3038         ) from None
3039
3040
3041 def dump_to_file(*output: str) -> str:
3042     """Dump `output` to a temporary file. Return path to the file."""
3043     import tempfile
3044
3045     with tempfile.NamedTemporaryFile(
3046         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
3047     ) as f:
3048         for lines in output:
3049             f.write(lines)
3050             if lines and lines[-1] != "\n":
3051                 f.write("\n")
3052     return f.name
3053
3054
3055 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
3056     """Return a unified diff string between strings `a` and `b`."""
3057     import difflib
3058
3059     a_lines = [line + "\n" for line in a.split("\n")]
3060     b_lines = [line + "\n" for line in b.split("\n")]
3061     return "".join(
3062         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
3063     )
3064
3065
3066 def cancel(tasks: Iterable[asyncio.Task]) -> None:
3067     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
3068     err("Aborted!")
3069     for task in tasks:
3070         task.cancel()
3071
3072
3073 def shutdown(loop: BaseEventLoop) -> None:
3074     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
3075     try:
3076         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
3077         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
3078         if not to_cancel:
3079             return
3080
3081         for task in to_cancel:
3082             task.cancel()
3083         loop.run_until_complete(
3084             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
3085         )
3086     finally:
3087         # `concurrent.futures.Future` objects cannot be cancelled once they
3088         # are already running. There might be some when the `shutdown()` happened.
3089         # Silence their logger's spew about the event loop being closed.
3090         cf_logger = logging.getLogger("concurrent.futures")
3091         cf_logger.setLevel(logging.CRITICAL)
3092         loop.close()
3093
3094
3095 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
3096     """Replace `regex` with `replacement` twice on `original`.
3097
3098     This is used by string normalization to perform replaces on
3099     overlapping matches.
3100     """
3101     return regex.sub(replacement, regex.sub(replacement, original))
3102
3103
3104 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
3105     """Like `reversed(enumerate(sequence))` if that were possible."""
3106     index = len(sequence) - 1
3107     for element in reversed(sequence):
3108         yield (index, element)
3109         index -= 1
3110
3111
3112 def enumerate_with_length(
3113     line: Line, reversed: bool = False
3114 ) -> Iterator[Tuple[Index, Leaf, int]]:
3115     """Return an enumeration of leaves with their length.
3116
3117     Stops prematurely on multiline strings and standalone comments.
3118     """
3119     op = cast(
3120         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
3121         enumerate_reversed if reversed else enumerate,
3122     )
3123     for index, leaf in op(line.leaves):
3124         length = len(leaf.prefix) + len(leaf.value)
3125         if "\n" in leaf.value:
3126             return  # Multiline strings, we can't continue.
3127
3128         comment: Optional[Leaf]
3129         for comment in line.comments_after(leaf, index):
3130             length += len(comment.value)
3131
3132         yield index, leaf, length
3133
3134
3135 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
3136     """Return True if `line` is no longer than `line_length`.
3137
3138     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
3139     """
3140     if not line_str:
3141         line_str = str(line).strip("\n")
3142     return (
3143         len(line_str) <= line_length
3144         and "\n" not in line_str  # multiline strings
3145         and not line.contains_standalone_comments()
3146     )
3147
3148
3149 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
3150     """Does `line` have a shape safe to reformat without optional parens around it?
3151
3152     Returns True for only a subset of potentially nice looking formattings but
3153     the point is to not return false positives that end up producing lines that
3154     are too long.
3155     """
3156     bt = line.bracket_tracker
3157     if not bt.delimiters:
3158         # Without delimiters the optional parentheses are useless.
3159         return True
3160
3161     max_priority = bt.max_delimiter_priority()
3162     if bt.delimiter_count_with_priority(max_priority) > 1:
3163         # With more than one delimiter of a kind the optional parentheses read better.
3164         return False
3165
3166     if max_priority == DOT_PRIORITY:
3167         # A single stranded method call doesn't require optional parentheses.
3168         return True
3169
3170     assert len(line.leaves) >= 2, "Stranded delimiter"
3171
3172     first = line.leaves[0]
3173     second = line.leaves[1]
3174     penultimate = line.leaves[-2]
3175     last = line.leaves[-1]
3176
3177     # With a single delimiter, omit if the expression starts or ends with
3178     # a bracket.
3179     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3180         remainder = False
3181         length = 4 * line.depth
3182         for _index, leaf, leaf_length in enumerate_with_length(line):
3183             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3184                 remainder = True
3185             if remainder:
3186                 length += leaf_length
3187                 if length > line_length:
3188                     break
3189
3190                 if leaf.type in OPENING_BRACKETS:
3191                     # There are brackets we can further split on.
3192                     remainder = False
3193
3194         else:
3195             # checked the entire string and line length wasn't exceeded
3196             if len(line.leaves) == _index + 1:
3197                 return True
3198
3199         # Note: we are not returning False here because a line might have *both*
3200         # a leading opening bracket and a trailing closing bracket.  If the
3201         # opening bracket doesn't match our rule, maybe the closing will.
3202
3203     if (
3204         last.type == token.RPAR
3205         or last.type == token.RBRACE
3206         or (
3207             # don't use indexing for omitting optional parentheses;
3208             # it looks weird
3209             last.type == token.RSQB
3210             and last.parent
3211             and last.parent.type != syms.trailer
3212         )
3213     ):
3214         if penultimate.type in OPENING_BRACKETS:
3215             # Empty brackets don't help.
3216             return False
3217
3218         if is_multiline_string(first):
3219             # Additional wrapping of a multiline string in this situation is
3220             # unnecessary.
3221             return True
3222
3223         length = 4 * line.depth
3224         seen_other_brackets = False
3225         for _index, leaf, leaf_length in enumerate_with_length(line):
3226             length += leaf_length
3227             if leaf is last.opening_bracket:
3228                 if seen_other_brackets or length <= line_length:
3229                     return True
3230
3231             elif leaf.type in OPENING_BRACKETS:
3232                 # There are brackets we can further split on.
3233                 seen_other_brackets = True
3234
3235     return False
3236
3237
3238 def get_cache_file(line_length: int, mode: FileMode) -> Path:
3239     pyi = bool(mode & FileMode.PYI)
3240     py36 = bool(mode & FileMode.PYTHON36)
3241     return (
3242         CACHE_DIR
3243         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
3244     )
3245
3246
3247 def read_cache(line_length: int, mode: FileMode) -> Cache:
3248     """Read the cache if it exists and is well formed.
3249
3250     If it is not well formed, the call to write_cache later should resolve the issue.
3251     """
3252     cache_file = get_cache_file(line_length, mode)
3253     if not cache_file.exists():
3254         return {}
3255
3256     with cache_file.open("rb") as fobj:
3257         try:
3258             cache: Cache = pickle.load(fobj)
3259         except pickle.UnpicklingError:
3260             return {}
3261
3262     return cache
3263
3264
3265 def get_cache_info(path: Path) -> CacheInfo:
3266     """Return the information used to check if a file is already formatted or not."""
3267     stat = path.stat()
3268     return stat.st_mtime, stat.st_size
3269
3270
3271 def filter_cached(
3272     cache: Cache, sources: Iterable[Path]
3273 ) -> Tuple[List[Path], List[Path]]:
3274     """Split a list of paths into two.
3275
3276     The first list contains paths of files that modified on disk or are not in the
3277     cache. The other list contains paths to non-modified files.
3278     """
3279     todo, done = [], []
3280     for src in sources:
3281         src = src.resolve()
3282         if cache.get(src) != get_cache_info(src):
3283             todo.append(src)
3284         else:
3285             done.append(src)
3286     return todo, done
3287
3288
3289 def write_cache(
3290     cache: Cache, sources: List[Path], line_length: int, mode: FileMode
3291 ) -> None:
3292     """Update the cache file."""
3293     cache_file = get_cache_file(line_length, mode)
3294     try:
3295         if not CACHE_DIR.exists():
3296             CACHE_DIR.mkdir(parents=True)
3297         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3298         with cache_file.open("wb") as fobj:
3299             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3300     except OSError:
3301         pass
3302
3303
3304 if __name__ == "__main__":
3305     main()