]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

README updates (#235)
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33     cast,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46
47 __version__ = "18.5b0"
48 DEFAULT_LINE_LENGTH = 88
49 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
50
51
52 # types
53 FileContent = str
54 Encoding = str
55 Depth = int
56 NodeType = int
57 LeafID = int
58 Priority = int
59 Index = int
60 LN = Union[Leaf, Node]
61 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
62 Timestamp = float
63 FileSize = int
64 CacheInfo = Tuple[Timestamp, FileSize]
65 Cache = Dict[Path, CacheInfo]
66 out = partial(click.secho, bold=True, err=True)
67 err = partial(click.secho, fg="red", err=True)
68
69 pygram.initialize(CACHE_DIR)
70 syms = pygram.python_symbols
71
72
73 class NothingChanged(UserWarning):
74     """Raised by :func:`format_file` when reformatted code is the same as source."""
75
76
77 class CannotSplit(Exception):
78     """A readable split that fits the allotted line length is impossible.
79
80     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
81     :func:`delimiter_split`.
82     """
83
84
85 class FormatError(Exception):
86     """Base exception for `# fmt: on` and `# fmt: off` handling.
87
88     It holds the number of bytes of the prefix consumed before the format
89     control comment appeared.
90     """
91
92     def __init__(self, consumed: int) -> None:
93         super().__init__(consumed)
94         self.consumed = consumed
95
96     def trim_prefix(self, leaf: Leaf) -> None:
97         leaf.prefix = leaf.prefix[self.consumed :]
98
99     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
100         """Returns a new Leaf from the consumed part of the prefix."""
101         unformatted_prefix = leaf.prefix[: self.consumed]
102         return Leaf(token.NEWLINE, unformatted_prefix)
103
104
105 class FormatOn(FormatError):
106     """Found a comment like `# fmt: on` in the file."""
107
108
109 class FormatOff(FormatError):
110     """Found a comment like `# fmt: off` in the file."""
111
112
113 class WriteBack(Enum):
114     NO = 0
115     YES = 1
116     DIFF = 2
117
118
119 class Changed(Enum):
120     NO = 0
121     CACHED = 1
122     YES = 2
123
124
125 @click.command()
126 @click.option(
127     "-l",
128     "--line-length",
129     type=int,
130     default=DEFAULT_LINE_LENGTH,
131     help="How many character per line to allow.",
132     show_default=True,
133 )
134 @click.option(
135     "--check",
136     is_flag=True,
137     help=(
138         "Don't write the files back, just return the status.  Return code 0 "
139         "means nothing would change.  Return code 1 means some files would be "
140         "reformatted.  Return code 123 means there was an internal error."
141     ),
142 )
143 @click.option(
144     "--diff",
145     is_flag=True,
146     help="Don't write the files back, just output a diff for each file on stdout.",
147 )
148 @click.option(
149     "--fast/--safe",
150     is_flag=True,
151     help="If --fast given, skip temporary sanity checks. [default: --safe]",
152 )
153 @click.option(
154     "-q",
155     "--quiet",
156     is_flag=True,
157     help=(
158         "Don't emit non-error messages to stderr. Errors are still emitted, "
159         "silence those with 2>/dev/null."
160     ),
161 )
162 @click.version_option(version=__version__)
163 @click.argument(
164     "src",
165     nargs=-1,
166     type=click.Path(
167         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
168     ),
169 )
170 @click.pass_context
171 def main(
172     ctx: click.Context,
173     line_length: int,
174     check: bool,
175     diff: bool,
176     fast: bool,
177     quiet: bool,
178     src: List[str],
179 ) -> None:
180     """The uncompromising code formatter."""
181     sources: List[Path] = []
182     for s in src:
183         p = Path(s)
184         if p.is_dir():
185             sources.extend(gen_python_files_in_dir(p))
186         elif p.is_file():
187             # if a file was explicitly given, we don't care about its extension
188             sources.append(p)
189         elif s == "-":
190             sources.append(Path("-"))
191         else:
192             err(f"invalid path: {s}")
193
194     if check and not diff:
195         write_back = WriteBack.NO
196     elif diff:
197         write_back = WriteBack.DIFF
198     else:
199         write_back = WriteBack.YES
200     report = Report(check=check, quiet=quiet)
201     if len(sources) == 0:
202         out("No paths given. Nothing to do 😴")
203         ctx.exit(0)
204         return
205
206     elif len(sources) == 1:
207         reformat_one(sources[0], line_length, fast, write_back, report)
208     else:
209         loop = asyncio.get_event_loop()
210         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
211         try:
212             loop.run_until_complete(
213                 schedule_formatting(
214                     sources, line_length, fast, write_back, report, loop, executor
215                 )
216             )
217         finally:
218             shutdown(loop)
219         if not quiet:
220             out("All done! ✨ 🍰 ✨")
221             click.echo(str(report))
222     ctx.exit(report.return_code)
223
224
225 def reformat_one(
226     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
227 ) -> None:
228     """Reformat a single file under `src` without spawning child processes.
229
230     If `quiet` is True, non-error messages are not output. `line_length`,
231     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
232     """
233     try:
234         changed = Changed.NO
235         if not src.is_file() and str(src) == "-":
236             if format_stdin_to_stdout(
237                 line_length=line_length, fast=fast, write_back=write_back
238             ):
239                 changed = Changed.YES
240         else:
241             cache: Cache = {}
242             if write_back != WriteBack.DIFF:
243                 cache = read_cache(line_length)
244                 src = src.resolve()
245                 if src in cache and cache[src] == get_cache_info(src):
246                     changed = Changed.CACHED
247             if changed is not Changed.CACHED and format_file_in_place(
248                 src, line_length=line_length, fast=fast, write_back=write_back
249             ):
250                 changed = Changed.YES
251             if write_back == WriteBack.YES and changed is not Changed.NO:
252                 write_cache(cache, [src], line_length)
253         report.done(src, changed)
254     except Exception as exc:
255         report.failed(src, str(exc))
256
257
258 async def schedule_formatting(
259     sources: List[Path],
260     line_length: int,
261     fast: bool,
262     write_back: WriteBack,
263     report: "Report",
264     loop: BaseEventLoop,
265     executor: Executor,
266 ) -> None:
267     """Run formatting of `sources` in parallel using the provided `executor`.
268
269     (Use ProcessPoolExecutors for actual parallelism.)
270
271     `line_length`, `write_back`, and `fast` options are passed to
272     :func:`format_file_in_place`.
273     """
274     cache: Cache = {}
275     if write_back != WriteBack.DIFF:
276         cache = read_cache(line_length)
277         sources, cached = filter_cached(cache, sources)
278         for src in cached:
279             report.done(src, Changed.CACHED)
280     cancelled = []
281     formatted = []
282     if sources:
283         lock = None
284         if write_back == WriteBack.DIFF:
285             # For diff output, we need locks to ensure we don't interleave output
286             # from different processes.
287             manager = Manager()
288             lock = manager.Lock()
289         tasks = {
290             loop.run_in_executor(
291                 executor, format_file_in_place, src, line_length, fast, write_back, lock
292             ): src
293             for src in sorted(sources)
294         }
295         pending: Iterable[asyncio.Task] = tasks.keys()
296         try:
297             loop.add_signal_handler(signal.SIGINT, cancel, pending)
298             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
299         except NotImplementedError:
300             # There are no good alternatives for these on Windows
301             pass
302         while pending:
303             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
304             for task in done:
305                 src = tasks.pop(task)
306                 if task.cancelled():
307                     cancelled.append(task)
308                 elif task.exception():
309                     report.failed(src, str(task.exception()))
310                 else:
311                     formatted.append(src)
312                     report.done(src, Changed.YES if task.result() else Changed.NO)
313     if cancelled:
314         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
315     if write_back == WriteBack.YES and formatted:
316         write_cache(cache, formatted, line_length)
317
318
319 def format_file_in_place(
320     src: Path,
321     line_length: int,
322     fast: bool,
323     write_back: WriteBack = WriteBack.NO,
324     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
325 ) -> bool:
326     """Format file under `src` path. Return True if changed.
327
328     If `write_back` is True, write reformatted code back to stdout.
329     `line_length` and `fast` options are passed to :func:`format_file_contents`.
330     """
331     is_pyi = src.suffix == ".pyi"
332
333     with tokenize.open(src) as src_buffer:
334         src_contents = src_buffer.read()
335     try:
336         dst_contents = format_file_contents(
337             src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
338         )
339     except NothingChanged:
340         return False
341
342     if write_back == write_back.YES:
343         with open(src, "w", encoding=src_buffer.encoding) as f:
344             f.write(dst_contents)
345     elif write_back == write_back.DIFF:
346         src_name = f"{src}  (original)"
347         dst_name = f"{src}  (formatted)"
348         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
349         if lock:
350             lock.acquire()
351         try:
352             sys.stdout.write(diff_contents)
353         finally:
354             if lock:
355                 lock.release()
356     return True
357
358
359 def format_stdin_to_stdout(
360     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
361 ) -> bool:
362     """Format file on stdin. Return True if changed.
363
364     If `write_back` is True, write reformatted code back to stdout.
365     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
366     """
367     src = sys.stdin.read()
368     dst = src
369     try:
370         dst = format_file_contents(src, line_length=line_length, fast=fast)
371         return True
372
373     except NothingChanged:
374         return False
375
376     finally:
377         if write_back == WriteBack.YES:
378             sys.stdout.write(dst)
379         elif write_back == WriteBack.DIFF:
380             src_name = "<stdin>  (original)"
381             dst_name = "<stdin>  (formatted)"
382             sys.stdout.write(diff(src, dst, src_name, dst_name))
383
384
385 def format_file_contents(
386     src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
387 ) -> FileContent:
388     """Reformat contents a file and return new contents.
389
390     If `fast` is False, additionally confirm that the reformatted code is
391     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
392     `line_length` is passed to :func:`format_str`.
393     """
394     if src_contents.strip() == "":
395         raise NothingChanged
396
397     dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
398     if src_contents == dst_contents:
399         raise NothingChanged
400
401     if not fast:
402         assert_equivalent(src_contents, dst_contents)
403         assert_stable(
404             src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
405         )
406     return dst_contents
407
408
409 def format_str(
410     src_contents: str, line_length: int, *, is_pyi: bool = False
411 ) -> FileContent:
412     """Reformat a string and return new contents.
413
414     `line_length` determines how many characters per line are allowed.
415     """
416     src_node = lib2to3_parse(src_contents)
417     dst_contents = ""
418     future_imports = get_future_imports(src_node)
419     elt = EmptyLineTracker(is_pyi=is_pyi)
420     py36 = is_python36(src_node)
421     lines = LineGenerator(
422         remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
423     )
424     empty_line = Line()
425     after = 0
426     for current_line in lines.visit(src_node):
427         for _ in range(after):
428             dst_contents += str(empty_line)
429         before, after = elt.maybe_empty_lines(current_line)
430         for _ in range(before):
431             dst_contents += str(empty_line)
432         for line in split_line(current_line, line_length=line_length, py36=py36):
433             dst_contents += str(line)
434     return dst_contents
435
436
437 GRAMMARS = [
438     pygram.python_grammar_no_print_statement_no_exec_statement,
439     pygram.python_grammar_no_print_statement,
440     pygram.python_grammar,
441 ]
442
443
444 def lib2to3_parse(src_txt: str) -> Node:
445     """Given a string with source, return the lib2to3 Node."""
446     grammar = pygram.python_grammar_no_print_statement
447     if src_txt[-1] != "\n":
448         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
449         src_txt += nl
450     for grammar in GRAMMARS:
451         drv = driver.Driver(grammar, pytree.convert)
452         try:
453             result = drv.parse_string(src_txt, True)
454             break
455
456         except ParseError as pe:
457             lineno, column = pe.context[1]
458             lines = src_txt.splitlines()
459             try:
460                 faulty_line = lines[lineno - 1]
461             except IndexError:
462                 faulty_line = "<line number missing in source>"
463             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
464     else:
465         raise exc from None
466
467     if isinstance(result, Leaf):
468         result = Node(syms.file_input, [result])
469     return result
470
471
472 def lib2to3_unparse(node: Node) -> str:
473     """Given a lib2to3 node, return its string representation."""
474     code = str(node)
475     return code
476
477
478 T = TypeVar("T")
479
480
481 class Visitor(Generic[T]):
482     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
483
484     def visit(self, node: LN) -> Iterator[T]:
485         """Main method to visit `node` and its children.
486
487         It tries to find a `visit_*()` method for the given `node.type`, like
488         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
489         If no dedicated `visit_*()` method is found, chooses `visit_default()`
490         instead.
491
492         Then yields objects of type `T` from the selected visitor.
493         """
494         if node.type < 256:
495             name = token.tok_name[node.type]
496         else:
497             name = type_repr(node.type)
498         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
499
500     def visit_default(self, node: LN) -> Iterator[T]:
501         """Default `visit_*()` implementation. Recurses to children of `node`."""
502         if isinstance(node, Node):
503             for child in node.children:
504                 yield from self.visit(child)
505
506
507 @dataclass
508 class DebugVisitor(Visitor[T]):
509     tree_depth: int = 0
510
511     def visit_default(self, node: LN) -> Iterator[T]:
512         indent = " " * (2 * self.tree_depth)
513         if isinstance(node, Node):
514             _type = type_repr(node.type)
515             out(f"{indent}{_type}", fg="yellow")
516             self.tree_depth += 1
517             for child in node.children:
518                 yield from self.visit(child)
519
520             self.tree_depth -= 1
521             out(f"{indent}/{_type}", fg="yellow", bold=False)
522         else:
523             _type = token.tok_name.get(node.type, str(node.type))
524             out(f"{indent}{_type}", fg="blue", nl=False)
525             if node.prefix:
526                 # We don't have to handle prefixes for `Node` objects since
527                 # that delegates to the first child anyway.
528                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
529             out(f" {node.value!r}", fg="blue", bold=False)
530
531     @classmethod
532     def show(cls, code: str) -> None:
533         """Pretty-print the lib2to3 AST of a given string of `code`.
534
535         Convenience method for debugging.
536         """
537         v: DebugVisitor[None] = DebugVisitor()
538         list(v.visit(lib2to3_parse(code)))
539
540
541 KEYWORDS = set(keyword.kwlist)
542 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
543 FLOW_CONTROL = {"return", "raise", "break", "continue"}
544 STATEMENT = {
545     syms.if_stmt,
546     syms.while_stmt,
547     syms.for_stmt,
548     syms.try_stmt,
549     syms.except_clause,
550     syms.with_stmt,
551     syms.funcdef,
552     syms.classdef,
553 }
554 STANDALONE_COMMENT = 153
555 LOGIC_OPERATORS = {"and", "or"}
556 COMPARATORS = {
557     token.LESS,
558     token.GREATER,
559     token.EQEQUAL,
560     token.NOTEQUAL,
561     token.LESSEQUAL,
562     token.GREATEREQUAL,
563 }
564 MATH_OPERATORS = {
565     token.VBAR,
566     token.CIRCUMFLEX,
567     token.AMPER,
568     token.LEFTSHIFT,
569     token.RIGHTSHIFT,
570     token.PLUS,
571     token.MINUS,
572     token.STAR,
573     token.SLASH,
574     token.DOUBLESLASH,
575     token.PERCENT,
576     token.AT,
577     token.TILDE,
578     token.DOUBLESTAR,
579 }
580 STARS = {token.STAR, token.DOUBLESTAR}
581 VARARGS_PARENTS = {
582     syms.arglist,
583     syms.argument,  # double star in arglist
584     syms.trailer,  # single argument to call
585     syms.typedargslist,
586     syms.varargslist,  # lambdas
587 }
588 UNPACKING_PARENTS = {
589     syms.atom,  # single element of a list or set literal
590     syms.dictsetmaker,
591     syms.listmaker,
592     syms.testlist_gexp,
593 }
594 TEST_DESCENDANTS = {
595     syms.test,
596     syms.lambdef,
597     syms.or_test,
598     syms.and_test,
599     syms.not_test,
600     syms.comparison,
601     syms.star_expr,
602     syms.expr,
603     syms.xor_expr,
604     syms.and_expr,
605     syms.shift_expr,
606     syms.arith_expr,
607     syms.trailer,
608     syms.term,
609     syms.power,
610 }
611 ASSIGNMENTS = {
612     "=",
613     "+=",
614     "-=",
615     "*=",
616     "@=",
617     "/=",
618     "%=",
619     "&=",
620     "|=",
621     "^=",
622     "<<=",
623     ">>=",
624     "**=",
625     "//=",
626 }
627 COMPREHENSION_PRIORITY = 20
628 COMMA_PRIORITY = 18
629 TERNARY_PRIORITY = 16
630 LOGIC_PRIORITY = 14
631 STRING_PRIORITY = 12
632 COMPARATOR_PRIORITY = 10
633 MATH_PRIORITIES = {
634     token.VBAR: 9,
635     token.CIRCUMFLEX: 8,
636     token.AMPER: 7,
637     token.LEFTSHIFT: 6,
638     token.RIGHTSHIFT: 6,
639     token.PLUS: 5,
640     token.MINUS: 5,
641     token.STAR: 4,
642     token.SLASH: 4,
643     token.DOUBLESLASH: 4,
644     token.PERCENT: 4,
645     token.AT: 4,
646     token.TILDE: 3,
647     token.DOUBLESTAR: 2,
648 }
649 DOT_PRIORITY = 1
650
651
652 @dataclass
653 class BracketTracker:
654     """Keeps track of brackets on a line."""
655
656     depth: int = 0
657     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
658     delimiters: Dict[LeafID, Priority] = Factory(dict)
659     previous: Optional[Leaf] = None
660     _for_loop_variable: int = 0
661     _lambda_arguments: int = 0
662
663     def mark(self, leaf: Leaf) -> None:
664         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
665
666         All leaves receive an int `bracket_depth` field that stores how deep
667         within brackets a given leaf is. 0 means there are no enclosing brackets
668         that started on this line.
669
670         If a leaf is itself a closing bracket, it receives an `opening_bracket`
671         field that it forms a pair with. This is a one-directional link to
672         avoid reference cycles.
673
674         If a leaf is a delimiter (a token on which Black can split the line if
675         needed) and it's on depth 0, its `id()` is stored in the tracker's
676         `delimiters` field.
677         """
678         if leaf.type == token.COMMENT:
679             return
680
681         self.maybe_decrement_after_for_loop_variable(leaf)
682         self.maybe_decrement_after_lambda_arguments(leaf)
683         if leaf.type in CLOSING_BRACKETS:
684             self.depth -= 1
685             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
686             leaf.opening_bracket = opening_bracket
687         leaf.bracket_depth = self.depth
688         if self.depth == 0:
689             delim = is_split_before_delimiter(leaf, self.previous)
690             if delim and self.previous is not None:
691                 self.delimiters[id(self.previous)] = delim
692             else:
693                 delim = is_split_after_delimiter(leaf, self.previous)
694                 if delim:
695                     self.delimiters[id(leaf)] = delim
696         if leaf.type in OPENING_BRACKETS:
697             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
698             self.depth += 1
699         self.previous = leaf
700         self.maybe_increment_lambda_arguments(leaf)
701         self.maybe_increment_for_loop_variable(leaf)
702
703     def any_open_brackets(self) -> bool:
704         """Return True if there is an yet unmatched open bracket on the line."""
705         return bool(self.bracket_match)
706
707     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
708         """Return the highest priority of a delimiter found on the line.
709
710         Values are consistent with what `is_split_*_delimiter()` return.
711         Raises ValueError on no delimiters.
712         """
713         return max(v for k, v in self.delimiters.items() if k not in exclude)
714
715     def delimiter_count_with_priority(self, priority: int = 0) -> int:
716         """Return the number of delimiters with the given `priority`.
717
718         If no `priority` is passed, defaults to max priority on the line.
719         """
720         if not self.delimiters:
721             return 0
722
723         priority = priority or self.max_delimiter_priority()
724         return sum(1 for p in self.delimiters.values() if p == priority)
725
726     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
727         """In a for loop, or comprehension, the variables are often unpacks.
728
729         To avoid splitting on the comma in this situation, increase the depth of
730         tokens between `for` and `in`.
731         """
732         if leaf.type == token.NAME and leaf.value == "for":
733             self.depth += 1
734             self._for_loop_variable += 1
735             return True
736
737         return False
738
739     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
740         """See `maybe_increment_for_loop_variable` above for explanation."""
741         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
742             self.depth -= 1
743             self._for_loop_variable -= 1
744             return True
745
746         return False
747
748     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
749         """In a lambda expression, there might be more than one argument.
750
751         To avoid splitting on the comma in this situation, increase the depth of
752         tokens between `lambda` and `:`.
753         """
754         if leaf.type == token.NAME and leaf.value == "lambda":
755             self.depth += 1
756             self._lambda_arguments += 1
757             return True
758
759         return False
760
761     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
762         """See `maybe_increment_lambda_arguments` above for explanation."""
763         if self._lambda_arguments and leaf.type == token.COLON:
764             self.depth -= 1
765             self._lambda_arguments -= 1
766             return True
767
768         return False
769
770     def get_open_lsqb(self) -> Optional[Leaf]:
771         """Return the most recent opening square bracket (if any)."""
772         return self.bracket_match.get((self.depth - 1, token.RSQB))
773
774
775 @dataclass
776 class Line:
777     """Holds leaves and comments. Can be printed with `str(line)`."""
778
779     depth: int = 0
780     leaves: List[Leaf] = Factory(list)
781     comments: List[Tuple[Index, Leaf]] = Factory(list)
782     bracket_tracker: BracketTracker = Factory(BracketTracker)
783     inside_brackets: bool = False
784     should_explode: bool = False
785
786     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
787         """Add a new `leaf` to the end of the line.
788
789         Unless `preformatted` is True, the `leaf` will receive a new consistent
790         whitespace prefix and metadata applied by :class:`BracketTracker`.
791         Trailing commas are maybe removed, unpacked for loop variables are
792         demoted from being delimiters.
793
794         Inline comments are put aside.
795         """
796         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
797         if not has_value:
798             return
799
800         if token.COLON == leaf.type and self.is_class_paren_empty:
801             del self.leaves[-2:]
802         if self.leaves and not preformatted:
803             # Note: at this point leaf.prefix should be empty except for
804             # imports, for which we only preserve newlines.
805             leaf.prefix += whitespace(
806                 leaf, complex_subscript=self.is_complex_subscript(leaf)
807             )
808         if self.inside_brackets or not preformatted:
809             self.bracket_tracker.mark(leaf)
810             self.maybe_remove_trailing_comma(leaf)
811         if not self.append_comment(leaf):
812             self.leaves.append(leaf)
813
814     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
815         """Like :func:`append()` but disallow invalid standalone comment structure.
816
817         Raises ValueError when any `leaf` is appended after a standalone comment
818         or when a standalone comment is not the first leaf on the line.
819         """
820         if self.bracket_tracker.depth == 0:
821             if self.is_comment:
822                 raise ValueError("cannot append to standalone comments")
823
824             if self.leaves and leaf.type == STANDALONE_COMMENT:
825                 raise ValueError(
826                     "cannot append standalone comments to a populated line"
827                 )
828
829         self.append(leaf, preformatted=preformatted)
830
831     @property
832     def is_comment(self) -> bool:
833         """Is this line a standalone comment?"""
834         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
835
836     @property
837     def is_decorator(self) -> bool:
838         """Is this line a decorator?"""
839         return bool(self) and self.leaves[0].type == token.AT
840
841     @property
842     def is_import(self) -> bool:
843         """Is this an import line?"""
844         return bool(self) and is_import(self.leaves[0])
845
846     @property
847     def is_class(self) -> bool:
848         """Is this line a class definition?"""
849         return (
850             bool(self)
851             and self.leaves[0].type == token.NAME
852             and self.leaves[0].value == "class"
853         )
854
855     @property
856     def is_stub_class(self) -> bool:
857         """Is this line a class definition with a body consisting only of "..."?"""
858         return self.is_class and self.leaves[-3:] == [
859             Leaf(token.DOT, ".") for _ in range(3)
860         ]
861
862     @property
863     def is_def(self) -> bool:
864         """Is this a function definition? (Also returns True for async defs.)"""
865         try:
866             first_leaf = self.leaves[0]
867         except IndexError:
868             return False
869
870         try:
871             second_leaf: Optional[Leaf] = self.leaves[1]
872         except IndexError:
873             second_leaf = None
874         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
875             first_leaf.type == token.ASYNC
876             and second_leaf is not None
877             and second_leaf.type == token.NAME
878             and second_leaf.value == "def"
879         )
880
881     @property
882     def is_flow_control(self) -> bool:
883         """Is this line a flow control statement?
884
885         Those are `return`, `raise`, `break`, and `continue`.
886         """
887         return (
888             bool(self)
889             and self.leaves[0].type == token.NAME
890             and self.leaves[0].value in FLOW_CONTROL
891         )
892
893     @property
894     def is_yield(self) -> bool:
895         """Is this line a yield statement?"""
896         return (
897             bool(self)
898             and self.leaves[0].type == token.NAME
899             and self.leaves[0].value == "yield"
900         )
901
902     @property
903     def is_class_paren_empty(self) -> bool:
904         """Is this a class with no base classes but using parentheses?
905
906         Those are unnecessary and should be removed.
907         """
908         return (
909             bool(self)
910             and len(self.leaves) == 4
911             and self.is_class
912             and self.leaves[2].type == token.LPAR
913             and self.leaves[2].value == "("
914             and self.leaves[3].type == token.RPAR
915             and self.leaves[3].value == ")"
916         )
917
918     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
919         """If so, needs to be split before emitting."""
920         for leaf in self.leaves:
921             if leaf.type == STANDALONE_COMMENT:
922                 if leaf.bracket_depth <= depth_limit:
923                     return True
924
925         return False
926
927     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
928         """Remove trailing comma if there is one and it's safe."""
929         if not (
930             self.leaves
931             and self.leaves[-1].type == token.COMMA
932             and closing.type in CLOSING_BRACKETS
933         ):
934             return False
935
936         if closing.type == token.RBRACE:
937             self.remove_trailing_comma()
938             return True
939
940         if closing.type == token.RSQB:
941             comma = self.leaves[-1]
942             if comma.parent and comma.parent.type == syms.listmaker:
943                 self.remove_trailing_comma()
944                 return True
945
946         # For parens let's check if it's safe to remove the comma.
947         # Imports are always safe.
948         if self.is_import:
949             self.remove_trailing_comma()
950             return True
951
952         # Otheriwsse, if the trailing one is the only one, we might mistakenly
953         # change a tuple into a different type by removing the comma.
954         depth = closing.bracket_depth + 1
955         commas = 0
956         opening = closing.opening_bracket
957         for _opening_index, leaf in enumerate(self.leaves):
958             if leaf is opening:
959                 break
960
961         else:
962             return False
963
964         for leaf in self.leaves[_opening_index + 1 :]:
965             if leaf is closing:
966                 break
967
968             bracket_depth = leaf.bracket_depth
969             if bracket_depth == depth and leaf.type == token.COMMA:
970                 commas += 1
971                 if leaf.parent and leaf.parent.type == syms.arglist:
972                     commas += 1
973                     break
974
975         if commas > 1:
976             self.remove_trailing_comma()
977             return True
978
979         return False
980
981     def append_comment(self, comment: Leaf) -> bool:
982         """Add an inline or standalone comment to the line."""
983         if (
984             comment.type == STANDALONE_COMMENT
985             and self.bracket_tracker.any_open_brackets()
986         ):
987             comment.prefix = ""
988             return False
989
990         if comment.type != token.COMMENT:
991             return False
992
993         after = len(self.leaves) - 1
994         if after == -1:
995             comment.type = STANDALONE_COMMENT
996             comment.prefix = ""
997             return False
998
999         else:
1000             self.comments.append((after, comment))
1001             return True
1002
1003     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1004         """Generate comments that should appear directly after `leaf`.
1005
1006         Provide a non-negative leaf `_index` to speed up the function.
1007         """
1008         if _index == -1:
1009             for _index, _leaf in enumerate(self.leaves):
1010                 if leaf is _leaf:
1011                     break
1012
1013             else:
1014                 return
1015
1016         for index, comment_after in self.comments:
1017             if _index == index:
1018                 yield comment_after
1019
1020     def remove_trailing_comma(self) -> None:
1021         """Remove the trailing comma and moves the comments attached to it."""
1022         comma_index = len(self.leaves) - 1
1023         for i in range(len(self.comments)):
1024             comment_index, comment = self.comments[i]
1025             if comment_index == comma_index:
1026                 self.comments[i] = (comma_index - 1, comment)
1027         self.leaves.pop()
1028
1029     def is_complex_subscript(self, leaf: Leaf) -> bool:
1030         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1031         open_lsqb = (
1032             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1033         )
1034         if open_lsqb is None:
1035             return False
1036
1037         subscript_start = open_lsqb.next_sibling
1038         if (
1039             isinstance(subscript_start, Node)
1040             and subscript_start.type == syms.subscriptlist
1041         ):
1042             subscript_start = child_towards(subscript_start, leaf)
1043         return subscript_start is not None and any(
1044             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1045         )
1046
1047     def __str__(self) -> str:
1048         """Render the line."""
1049         if not self:
1050             return "\n"
1051
1052         indent = "    " * self.depth
1053         leaves = iter(self.leaves)
1054         first = next(leaves)
1055         res = f"{first.prefix}{indent}{first.value}"
1056         for leaf in leaves:
1057             res += str(leaf)
1058         for _, comment in self.comments:
1059             res += str(comment)
1060         return res + "\n"
1061
1062     def __bool__(self) -> bool:
1063         """Return True if the line has leaves or comments."""
1064         return bool(self.leaves or self.comments)
1065
1066
1067 class UnformattedLines(Line):
1068     """Just like :class:`Line` but stores lines which aren't reformatted."""
1069
1070     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1071         """Just add a new `leaf` to the end of the lines.
1072
1073         The `preformatted` argument is ignored.
1074
1075         Keeps track of indentation `depth`, which is useful when the user
1076         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1077         """
1078         try:
1079             list(generate_comments(leaf))
1080         except FormatOn as f_on:
1081             self.leaves.append(f_on.leaf_from_consumed(leaf))
1082             raise
1083
1084         self.leaves.append(leaf)
1085         if leaf.type == token.INDENT:
1086             self.depth += 1
1087         elif leaf.type == token.DEDENT:
1088             self.depth -= 1
1089
1090     def __str__(self) -> str:
1091         """Render unformatted lines from leaves which were added with `append()`.
1092
1093         `depth` is not used for indentation in this case.
1094         """
1095         if not self:
1096             return "\n"
1097
1098         res = ""
1099         for leaf in self.leaves:
1100             res += str(leaf)
1101         return res
1102
1103     def append_comment(self, comment: Leaf) -> bool:
1104         """Not implemented in this class. Raises `NotImplementedError`."""
1105         raise NotImplementedError("Unformatted lines don't store comments separately.")
1106
1107     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1108         """Does nothing and returns False."""
1109         return False
1110
1111     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1112         """Does nothing and returns False."""
1113         return False
1114
1115
1116 @dataclass
1117 class EmptyLineTracker:
1118     """Provides a stateful method that returns the number of potential extra
1119     empty lines needed before and after the currently processed line.
1120
1121     Note: this tracker works on lines that haven't been split yet.  It assumes
1122     the prefix of the first leaf consists of optional newlines.  Those newlines
1123     are consumed by `maybe_empty_lines()` and included in the computation.
1124     """
1125     is_pyi: bool = False
1126     previous_line: Optional[Line] = None
1127     previous_after: int = 0
1128     previous_defs: List[int] = Factory(list)
1129
1130     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1131         """Return the number of extra empty lines before and after the `current_line`.
1132
1133         This is for separating `def`, `async def` and `class` with extra empty
1134         lines (two on module-level), as well as providing an extra empty line
1135         after flow control keywords to make them more prominent.
1136         """
1137         if isinstance(current_line, UnformattedLines):
1138             return 0, 0
1139
1140         before, after = self._maybe_empty_lines(current_line)
1141         before -= self.previous_after
1142         self.previous_after = after
1143         self.previous_line = current_line
1144         return before, after
1145
1146     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1147         max_allowed = 1
1148         if current_line.depth == 0:
1149             max_allowed = 1 if self.is_pyi else 2
1150         if current_line.leaves:
1151             # Consume the first leaf's extra newlines.
1152             first_leaf = current_line.leaves[0]
1153             before = first_leaf.prefix.count("\n")
1154             before = min(before, max_allowed)
1155             first_leaf.prefix = ""
1156         else:
1157             before = 0
1158         depth = current_line.depth
1159         while self.previous_defs and self.previous_defs[-1] >= depth:
1160             self.previous_defs.pop()
1161             if self.is_pyi:
1162                 before = 0 if depth else 1
1163             else:
1164                 before = 1 if depth else 2
1165         is_decorator = current_line.is_decorator
1166         if is_decorator or current_line.is_def or current_line.is_class:
1167             if not is_decorator:
1168                 self.previous_defs.append(depth)
1169             if self.previous_line is None:
1170                 # Don't insert empty lines before the first line in the file.
1171                 return 0, 0
1172
1173             if self.previous_line.is_decorator:
1174                 return 0, 0
1175
1176             if (
1177                 self.previous_line.is_comment
1178                 and self.previous_line.depth == current_line.depth
1179                 and before == 0
1180             ):
1181                 return 0, 0
1182
1183             if self.is_pyi:
1184                 if self.previous_line.depth > current_line.depth:
1185                     newlines = 1
1186                 elif current_line.is_class or self.previous_line.is_class:
1187                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1188                         newlines = 0
1189                     else:
1190                         newlines = 1
1191                 else:
1192                     newlines = 0
1193             else:
1194                 newlines = 2
1195             if current_line.depth and newlines:
1196                 newlines -= 1
1197             return newlines, 0
1198
1199         if (
1200             self.previous_line
1201             and self.previous_line.is_import
1202             and not current_line.is_import
1203             and depth == self.previous_line.depth
1204         ):
1205             return (before or 1), 0
1206
1207         return before, 0
1208
1209
1210 @dataclass
1211 class LineGenerator(Visitor[Line]):
1212     """Generates reformatted Line objects.  Empty lines are not emitted.
1213
1214     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1215     in ways that will no longer stringify to valid Python code on the tree.
1216     """
1217     is_pyi: bool = False
1218     current_line: Line = Factory(Line)
1219     remove_u_prefix: bool = False
1220
1221     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1222         """Generate a line.
1223
1224         If the line is empty, only emit if it makes sense.
1225         If the line is too long, split it first and then generate.
1226
1227         If any lines were generated, set up a new current_line.
1228         """
1229         if not self.current_line:
1230             if self.current_line.__class__ == type:
1231                 self.current_line.depth += indent
1232             else:
1233                 self.current_line = type(depth=self.current_line.depth + indent)
1234             return  # Line is empty, don't emit. Creating a new one unnecessary.
1235
1236         complete_line = self.current_line
1237         self.current_line = type(depth=complete_line.depth + indent)
1238         yield complete_line
1239
1240     def visit(self, node: LN) -> Iterator[Line]:
1241         """Main method to visit `node` and its children.
1242
1243         Yields :class:`Line` objects.
1244         """
1245         if isinstance(self.current_line, UnformattedLines):
1246             # File contained `# fmt: off`
1247             yield from self.visit_unformatted(node)
1248
1249         else:
1250             yield from super().visit(node)
1251
1252     def visit_default(self, node: LN) -> Iterator[Line]:
1253         """Default `visit_*()` implementation. Recurses to children of `node`."""
1254         if isinstance(node, Leaf):
1255             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1256             try:
1257                 for comment in generate_comments(node):
1258                     if any_open_brackets:
1259                         # any comment within brackets is subject to splitting
1260                         self.current_line.append(comment)
1261                     elif comment.type == token.COMMENT:
1262                         # regular trailing comment
1263                         self.current_line.append(comment)
1264                         yield from self.line()
1265
1266                     else:
1267                         # regular standalone comment
1268                         yield from self.line()
1269
1270                         self.current_line.append(comment)
1271                         yield from self.line()
1272
1273             except FormatOff as f_off:
1274                 f_off.trim_prefix(node)
1275                 yield from self.line(type=UnformattedLines)
1276                 yield from self.visit(node)
1277
1278             except FormatOn as f_on:
1279                 # This only happens here if somebody says "fmt: on" multiple
1280                 # times in a row.
1281                 f_on.trim_prefix(node)
1282                 yield from self.visit_default(node)
1283
1284             else:
1285                 normalize_prefix(node, inside_brackets=any_open_brackets)
1286                 if node.type == token.STRING:
1287                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1288                     normalize_string_quotes(node)
1289                 if node.type not in WHITESPACE:
1290                     self.current_line.append(node)
1291         yield from super().visit_default(node)
1292
1293     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1294         """Increase indentation level, maybe yield a line."""
1295         # In blib2to3 INDENT never holds comments.
1296         yield from self.line(+1)
1297         yield from self.visit_default(node)
1298
1299     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1300         """Decrease indentation level, maybe yield a line."""
1301         # The current line might still wait for trailing comments.  At DEDENT time
1302         # there won't be any (they would be prefixes on the preceding NEWLINE).
1303         # Emit the line then.
1304         yield from self.line()
1305
1306         # While DEDENT has no value, its prefix may contain standalone comments
1307         # that belong to the current indentation level.  Get 'em.
1308         yield from self.visit_default(node)
1309
1310         # Finally, emit the dedent.
1311         yield from self.line(-1)
1312
1313     def visit_stmt(
1314         self, node: Node, keywords: Set[str], parens: Set[str]
1315     ) -> Iterator[Line]:
1316         """Visit a statement.
1317
1318         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1319         `def`, `with`, `class`, `assert` and assignments.
1320
1321         The relevant Python language `keywords` for a given statement will be
1322         NAME leaves within it. This methods puts those on a separate line.
1323
1324         `parens` holds a set of string leaf values immediately after which
1325         invisible parens should be put.
1326         """
1327         normalize_invisible_parens(node, parens_after=parens)
1328         for child in node.children:
1329             if child.type == token.NAME and child.value in keywords:  # type: ignore
1330                 yield from self.line()
1331
1332             yield from self.visit(child)
1333
1334     def visit_suite(self, node: Node) -> Iterator[Line]:
1335         """Visit a suite."""
1336         if self.is_pyi and is_stub_suite(node):
1337             yield from self.visit(node.children[2])
1338         else:
1339             yield from self.visit_default(node)
1340
1341     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1342         """Visit a statement without nested statements."""
1343         is_suite_like = node.parent and node.parent.type in STATEMENT
1344         if is_suite_like:
1345             if self.is_pyi and is_stub_body(node):
1346                 yield from self.visit_default(node)
1347             else:
1348                 yield from self.line(+1)
1349                 yield from self.visit_default(node)
1350                 yield from self.line(-1)
1351
1352         else:
1353             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1354                 yield from self.line()
1355             yield from self.visit_default(node)
1356
1357     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1358         """Visit `async def`, `async for`, `async with`."""
1359         yield from self.line()
1360
1361         children = iter(node.children)
1362         for child in children:
1363             yield from self.visit(child)
1364
1365             if child.type == token.ASYNC:
1366                 break
1367
1368         internal_stmt = next(children)
1369         for child in internal_stmt.children:
1370             yield from self.visit(child)
1371
1372     def visit_decorators(self, node: Node) -> Iterator[Line]:
1373         """Visit decorators."""
1374         for child in node.children:
1375             yield from self.line()
1376             yield from self.visit(child)
1377
1378     def visit_import_from(self, node: Node) -> Iterator[Line]:
1379         """Visit import_from and maybe put invisible parentheses.
1380
1381         This is separate from `visit_stmt` because import statements don't
1382         support arbitrary atoms and thus handling of parentheses is custom.
1383         """
1384         check_lpar = False
1385         for index, child in enumerate(node.children):
1386             if check_lpar:
1387                 if child.type == token.LPAR:
1388                     # make parentheses invisible
1389                     child.value = ""  # type: ignore
1390                     node.children[-1].value = ""  # type: ignore
1391                 else:
1392                     # insert invisible parentheses
1393                     node.insert_child(index, Leaf(token.LPAR, ""))
1394                     node.append_child(Leaf(token.RPAR, ""))
1395                 break
1396
1397             check_lpar = (
1398                 child.type == token.NAME and child.value == "import"  # type: ignore
1399             )
1400
1401         for child in node.children:
1402             yield from self.visit(child)
1403
1404     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1405         """Remove a semicolon and put the other statement on a separate line."""
1406         yield from self.line()
1407
1408     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1409         """End of file. Process outstanding comments and end with a newline."""
1410         yield from self.visit_default(leaf)
1411         yield from self.line()
1412
1413     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1414         """Used when file contained a `# fmt: off`."""
1415         if isinstance(node, Node):
1416             for child in node.children:
1417                 yield from self.visit(child)
1418
1419         else:
1420             try:
1421                 self.current_line.append(node)
1422             except FormatOn as f_on:
1423                 f_on.trim_prefix(node)
1424                 yield from self.line()
1425                 yield from self.visit(node)
1426
1427             if node.type == token.ENDMARKER:
1428                 # somebody decided not to put a final `# fmt: on`
1429                 yield from self.line()
1430
1431     def __attrs_post_init__(self) -> None:
1432         """You are in a twisty little maze of passages."""
1433         v = self.visit_stmt
1434         Ø: Set[str] = set()
1435         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1436         self.visit_if_stmt = partial(
1437             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1438         )
1439         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1440         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1441         self.visit_try_stmt = partial(
1442             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1443         )
1444         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1445         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1446         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1447         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1448         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1449         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1450         self.visit_async_funcdef = self.visit_async_stmt
1451         self.visit_decorated = self.visit_decorators
1452
1453
1454 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1455 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1456 OPENING_BRACKETS = set(BRACKET.keys())
1457 CLOSING_BRACKETS = set(BRACKET.values())
1458 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1459 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1460
1461
1462 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1463     """Return whitespace prefix if needed for the given `leaf`.
1464
1465     `complex_subscript` signals whether the given leaf is part of a subscription
1466     which has non-trivial arguments, like arithmetic expressions or function calls.
1467     """
1468     NO = ""
1469     SPACE = " "
1470     DOUBLESPACE = "  "
1471     t = leaf.type
1472     p = leaf.parent
1473     v = leaf.value
1474     if t in ALWAYS_NO_SPACE:
1475         return NO
1476
1477     if t == token.COMMENT:
1478         return DOUBLESPACE
1479
1480     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1481     if t == token.COLON and p.type not in {
1482         syms.subscript,
1483         syms.subscriptlist,
1484         syms.sliceop,
1485     }:
1486         return NO
1487
1488     prev = leaf.prev_sibling
1489     if not prev:
1490         prevp = preceding_leaf(p)
1491         if not prevp or prevp.type in OPENING_BRACKETS:
1492             return NO
1493
1494         if t == token.COLON:
1495             if prevp.type == token.COLON:
1496                 return NO
1497
1498             elif prevp.type != token.COMMA and not complex_subscript:
1499                 return NO
1500
1501             return SPACE
1502
1503         if prevp.type == token.EQUAL:
1504             if prevp.parent:
1505                 if prevp.parent.type in {
1506                     syms.arglist,
1507                     syms.argument,
1508                     syms.parameters,
1509                     syms.varargslist,
1510                 }:
1511                     return NO
1512
1513                 elif prevp.parent.type == syms.typedargslist:
1514                     # A bit hacky: if the equal sign has whitespace, it means we
1515                     # previously found it's a typed argument.  So, we're using
1516                     # that, too.
1517                     return prevp.prefix
1518
1519         elif prevp.type in STARS:
1520             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1521                 return NO
1522
1523         elif prevp.type == token.COLON:
1524             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1525                 return SPACE if complex_subscript else NO
1526
1527         elif (
1528             prevp.parent
1529             and prevp.parent.type == syms.factor
1530             and prevp.type in MATH_OPERATORS
1531         ):
1532             return NO
1533
1534         elif (
1535             prevp.type == token.RIGHTSHIFT
1536             and prevp.parent
1537             and prevp.parent.type == syms.shift_expr
1538             and prevp.prev_sibling
1539             and prevp.prev_sibling.type == token.NAME
1540             and prevp.prev_sibling.value == "print"  # type: ignore
1541         ):
1542             # Python 2 print chevron
1543             return NO
1544
1545     elif prev.type in OPENING_BRACKETS:
1546         return NO
1547
1548     if p.type in {syms.parameters, syms.arglist}:
1549         # untyped function signatures or calls
1550         if not prev or prev.type != token.COMMA:
1551             return NO
1552
1553     elif p.type == syms.varargslist:
1554         # lambdas
1555         if prev and prev.type != token.COMMA:
1556             return NO
1557
1558     elif p.type == syms.typedargslist:
1559         # typed function signatures
1560         if not prev:
1561             return NO
1562
1563         if t == token.EQUAL:
1564             if prev.type != syms.tname:
1565                 return NO
1566
1567         elif prev.type == token.EQUAL:
1568             # A bit hacky: if the equal sign has whitespace, it means we
1569             # previously found it's a typed argument.  So, we're using that, too.
1570             return prev.prefix
1571
1572         elif prev.type != token.COMMA:
1573             return NO
1574
1575     elif p.type == syms.tname:
1576         # type names
1577         if not prev:
1578             prevp = preceding_leaf(p)
1579             if not prevp or prevp.type != token.COMMA:
1580                 return NO
1581
1582     elif p.type == syms.trailer:
1583         # attributes and calls
1584         if t == token.LPAR or t == token.RPAR:
1585             return NO
1586
1587         if not prev:
1588             if t == token.DOT:
1589                 prevp = preceding_leaf(p)
1590                 if not prevp or prevp.type != token.NUMBER:
1591                     return NO
1592
1593             elif t == token.LSQB:
1594                 return NO
1595
1596         elif prev.type != token.COMMA:
1597             return NO
1598
1599     elif p.type == syms.argument:
1600         # single argument
1601         if t == token.EQUAL:
1602             return NO
1603
1604         if not prev:
1605             prevp = preceding_leaf(p)
1606             if not prevp or prevp.type == token.LPAR:
1607                 return NO
1608
1609         elif prev.type in {token.EQUAL} | STARS:
1610             return NO
1611
1612     elif p.type == syms.decorator:
1613         # decorators
1614         return NO
1615
1616     elif p.type == syms.dotted_name:
1617         if prev:
1618             return NO
1619
1620         prevp = preceding_leaf(p)
1621         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1622             return NO
1623
1624     elif p.type == syms.classdef:
1625         if t == token.LPAR:
1626             return NO
1627
1628         if prev and prev.type == token.LPAR:
1629             return NO
1630
1631     elif p.type in {syms.subscript, syms.sliceop}:
1632         # indexing
1633         if not prev:
1634             assert p.parent is not None, "subscripts are always parented"
1635             if p.parent.type == syms.subscriptlist:
1636                 return SPACE
1637
1638             return NO
1639
1640         elif not complex_subscript:
1641             return NO
1642
1643     elif p.type == syms.atom:
1644         if prev and t == token.DOT:
1645             # dots, but not the first one.
1646             return NO
1647
1648     elif p.type == syms.dictsetmaker:
1649         # dict unpacking
1650         if prev and prev.type == token.DOUBLESTAR:
1651             return NO
1652
1653     elif p.type in {syms.factor, syms.star_expr}:
1654         # unary ops
1655         if not prev:
1656             prevp = preceding_leaf(p)
1657             if not prevp or prevp.type in OPENING_BRACKETS:
1658                 return NO
1659
1660             prevp_parent = prevp.parent
1661             assert prevp_parent is not None
1662             if prevp.type == token.COLON and prevp_parent.type in {
1663                 syms.subscript,
1664                 syms.sliceop,
1665             }:
1666                 return NO
1667
1668             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1669                 return NO
1670
1671         elif t == token.NAME or t == token.NUMBER:
1672             return NO
1673
1674     elif p.type == syms.import_from:
1675         if t == token.DOT:
1676             if prev and prev.type == token.DOT:
1677                 return NO
1678
1679         elif t == token.NAME:
1680             if v == "import":
1681                 return SPACE
1682
1683             if prev and prev.type == token.DOT:
1684                 return NO
1685
1686     elif p.type == syms.sliceop:
1687         return NO
1688
1689     return SPACE
1690
1691
1692 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1693     """Return the first leaf that precedes `node`, if any."""
1694     while node:
1695         res = node.prev_sibling
1696         if res:
1697             if isinstance(res, Leaf):
1698                 return res
1699
1700             try:
1701                 return list(res.leaves())[-1]
1702
1703             except IndexError:
1704                 return None
1705
1706         node = node.parent
1707     return None
1708
1709
1710 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1711     """Return the child of `ancestor` that contains `descendant`."""
1712     node: Optional[LN] = descendant
1713     while node and node.parent != ancestor:
1714         node = node.parent
1715     return node
1716
1717
1718 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1719     """Return the priority of the `leaf` delimiter, given a line break after it.
1720
1721     The delimiter priorities returned here are from those delimiters that would
1722     cause a line break after themselves.
1723
1724     Higher numbers are higher priority.
1725     """
1726     if leaf.type == token.COMMA:
1727         return COMMA_PRIORITY
1728
1729     return 0
1730
1731
1732 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1733     """Return the priority of the `leaf` delimiter, given a line before after it.
1734
1735     The delimiter priorities returned here are from those delimiters that would
1736     cause a line break before themselves.
1737
1738     Higher numbers are higher priority.
1739     """
1740     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1741         # * and ** might also be MATH_OPERATORS but in this case they are not.
1742         # Don't treat them as a delimiter.
1743         return 0
1744
1745     if (
1746         leaf.type == token.DOT
1747         and leaf.parent
1748         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1749         and (previous is None or previous.type in CLOSING_BRACKETS)
1750     ):
1751         return DOT_PRIORITY
1752
1753     if (
1754         leaf.type in MATH_OPERATORS
1755         and leaf.parent
1756         and leaf.parent.type not in {syms.factor, syms.star_expr}
1757     ):
1758         return MATH_PRIORITIES[leaf.type]
1759
1760     if leaf.type in COMPARATORS:
1761         return COMPARATOR_PRIORITY
1762
1763     if (
1764         leaf.type == token.STRING
1765         and previous is not None
1766         and previous.type == token.STRING
1767     ):
1768         return STRING_PRIORITY
1769
1770     if leaf.type != token.NAME:
1771         return 0
1772
1773     if (
1774         leaf.value == "for"
1775         and leaf.parent
1776         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1777     ):
1778         return COMPREHENSION_PRIORITY
1779
1780     if (
1781         leaf.value == "if"
1782         and leaf.parent
1783         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1784     ):
1785         return COMPREHENSION_PRIORITY
1786
1787     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1788         return TERNARY_PRIORITY
1789
1790     if leaf.value == "is":
1791         return COMPARATOR_PRIORITY
1792
1793     if (
1794         leaf.value == "in"
1795         and leaf.parent
1796         and leaf.parent.type in {syms.comp_op, syms.comparison}
1797         and not (
1798             previous is not None
1799             and previous.type == token.NAME
1800             and previous.value == "not"
1801         )
1802     ):
1803         return COMPARATOR_PRIORITY
1804
1805     if (
1806         leaf.value == "not"
1807         and leaf.parent
1808         and leaf.parent.type == syms.comp_op
1809         and not (
1810             previous is not None
1811             and previous.type == token.NAME
1812             and previous.value == "is"
1813         )
1814     ):
1815         return COMPARATOR_PRIORITY
1816
1817     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1818         return LOGIC_PRIORITY
1819
1820     return 0
1821
1822
1823 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1824     """Clean the prefix of the `leaf` and generate comments from it, if any.
1825
1826     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1827     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1828     move because it does away with modifying the grammar to include all the
1829     possible places in which comments can be placed.
1830
1831     The sad consequence for us though is that comments don't "belong" anywhere.
1832     This is why this function generates simple parentless Leaf objects for
1833     comments.  We simply don't know what the correct parent should be.
1834
1835     No matter though, we can live without this.  We really only need to
1836     differentiate between inline and standalone comments.  The latter don't
1837     share the line with any code.
1838
1839     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1840     are emitted with a fake STANDALONE_COMMENT token identifier.
1841     """
1842     p = leaf.prefix
1843     if not p:
1844         return
1845
1846     if "#" not in p:
1847         return
1848
1849     consumed = 0
1850     nlines = 0
1851     for index, line in enumerate(p.split("\n")):
1852         consumed += len(line) + 1  # adding the length of the split '\n'
1853         line = line.lstrip()
1854         if not line:
1855             nlines += 1
1856         if not line.startswith("#"):
1857             continue
1858
1859         if index == 0 and leaf.type != token.ENDMARKER:
1860             comment_type = token.COMMENT  # simple trailing comment
1861         else:
1862             comment_type = STANDALONE_COMMENT
1863         comment = make_comment(line)
1864         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1865
1866         if comment in {"# fmt: on", "# yapf: enable"}:
1867             raise FormatOn(consumed)
1868
1869         if comment in {"# fmt: off", "# yapf: disable"}:
1870             if comment_type == STANDALONE_COMMENT:
1871                 raise FormatOff(consumed)
1872
1873             prev = preceding_leaf(leaf)
1874             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1875                 raise FormatOff(consumed)
1876
1877         nlines = 0
1878
1879
1880 def make_comment(content: str) -> str:
1881     """Return a consistently formatted comment from the given `content` string.
1882
1883     All comments (except for "##", "#!", "#:") should have a single space between
1884     the hash sign and the content.
1885
1886     If `content` didn't start with a hash sign, one is provided.
1887     """
1888     content = content.rstrip()
1889     if not content:
1890         return "#"
1891
1892     if content[0] == "#":
1893         content = content[1:]
1894     if content and content[0] not in " !:#":
1895         content = " " + content
1896     return "#" + content
1897
1898
1899 def split_line(
1900     line: Line, line_length: int, inner: bool = False, py36: bool = False
1901 ) -> Iterator[Line]:
1902     """Split a `line` into potentially many lines.
1903
1904     They should fit in the allotted `line_length` but might not be able to.
1905     `inner` signifies that there were a pair of brackets somewhere around the
1906     current `line`, possibly transitively. This means we can fallback to splitting
1907     by delimiters if the LHS/RHS don't yield any results.
1908
1909     If `py36` is True, splitting may generate syntax that is only compatible
1910     with Python 3.6 and later.
1911     """
1912     if isinstance(line, UnformattedLines) or line.is_comment:
1913         yield line
1914         return
1915
1916     line_str = str(line).strip("\n")
1917     if not line.should_explode and is_line_short_enough(
1918         line, line_length=line_length, line_str=line_str
1919     ):
1920         yield line
1921         return
1922
1923     split_funcs: List[SplitFunc]
1924     if line.is_def:
1925         split_funcs = [left_hand_split]
1926     else:
1927
1928         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
1929             for omit in generate_trailers_to_omit(line, line_length):
1930                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
1931                 if is_line_short_enough(lines[0], line_length=line_length):
1932                     yield from lines
1933                     return
1934
1935             # All splits failed, best effort split with no omits.
1936             # This mostly happens to multiline strings that are by definition
1937             # reported as not fitting a single line.
1938             yield from right_hand_split(line, py36)
1939
1940         if line.inside_brackets:
1941             split_funcs = [delimiter_split, standalone_comment_split, rhs]
1942         else:
1943             split_funcs = [rhs]
1944     for split_func in split_funcs:
1945         # We are accumulating lines in `result` because we might want to abort
1946         # mission and return the original line in the end, or attempt a different
1947         # split altogether.
1948         result: List[Line] = []
1949         try:
1950             for l in split_func(line, py36):
1951                 if str(l).strip("\n") == line_str:
1952                     raise CannotSplit("Split function returned an unchanged result")
1953
1954                 result.extend(
1955                     split_line(l, line_length=line_length, inner=True, py36=py36)
1956                 )
1957         except CannotSplit as cs:
1958             continue
1959
1960         else:
1961             yield from result
1962             break
1963
1964     else:
1965         yield line
1966
1967
1968 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1969     """Split line into many lines, starting with the first matching bracket pair.
1970
1971     Note: this usually looks weird, only use this for function definitions.
1972     Prefer RHS otherwise.  This is why this function is not symmetrical with
1973     :func:`right_hand_split` which also handles optional parentheses.
1974     """
1975     head = Line(depth=line.depth)
1976     body = Line(depth=line.depth + 1, inside_brackets=True)
1977     tail = Line(depth=line.depth)
1978     tail_leaves: List[Leaf] = []
1979     body_leaves: List[Leaf] = []
1980     head_leaves: List[Leaf] = []
1981     current_leaves = head_leaves
1982     matching_bracket = None
1983     for leaf in line.leaves:
1984         if (
1985             current_leaves is body_leaves
1986             and leaf.type in CLOSING_BRACKETS
1987             and leaf.opening_bracket is matching_bracket
1988         ):
1989             current_leaves = tail_leaves if body_leaves else head_leaves
1990         current_leaves.append(leaf)
1991         if current_leaves is head_leaves:
1992             if leaf.type in OPENING_BRACKETS:
1993                 matching_bracket = leaf
1994                 current_leaves = body_leaves
1995     # Since body is a new indent level, remove spurious leading whitespace.
1996     if body_leaves:
1997         normalize_prefix(body_leaves[0], inside_brackets=True)
1998     # Build the new lines.
1999     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2000         for leaf in leaves:
2001             result.append(leaf, preformatted=True)
2002             for comment_after in line.comments_after(leaf):
2003                 result.append(comment_after, preformatted=True)
2004     bracket_split_succeeded_or_raise(head, body, tail)
2005     for result in (head, body, tail):
2006         if result:
2007             yield result
2008
2009
2010 def right_hand_split(
2011     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2012 ) -> Iterator[Line]:
2013     """Split line into many lines, starting with the last matching bracket pair.
2014
2015     If the split was by optional parentheses, attempt splitting without them, too.
2016     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2017     this split.
2018
2019     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2020     """
2021     head = Line(depth=line.depth)
2022     body = Line(depth=line.depth + 1, inside_brackets=True)
2023     tail = Line(depth=line.depth)
2024     tail_leaves: List[Leaf] = []
2025     body_leaves: List[Leaf] = []
2026     head_leaves: List[Leaf] = []
2027     current_leaves = tail_leaves
2028     opening_bracket = None
2029     closing_bracket = None
2030     for leaf in reversed(line.leaves):
2031         if current_leaves is body_leaves:
2032             if leaf is opening_bracket:
2033                 current_leaves = head_leaves if body_leaves else tail_leaves
2034         current_leaves.append(leaf)
2035         if current_leaves is tail_leaves:
2036             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2037                 opening_bracket = leaf.opening_bracket
2038                 closing_bracket = leaf
2039                 current_leaves = body_leaves
2040     tail_leaves.reverse()
2041     body_leaves.reverse()
2042     head_leaves.reverse()
2043     # Since body is a new indent level, remove spurious leading whitespace.
2044     if body_leaves:
2045         normalize_prefix(body_leaves[0], inside_brackets=True)
2046     if not head_leaves:
2047         # No `head` means the split failed. Either `tail` has all content or
2048         # the matching `opening_bracket` wasn't available on `line` anymore.
2049         raise CannotSplit("No brackets found")
2050
2051     # Build the new lines.
2052     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2053         for leaf in leaves:
2054             result.append(leaf, preformatted=True)
2055             for comment_after in line.comments_after(leaf):
2056                 result.append(comment_after, preformatted=True)
2057     bracket_split_succeeded_or_raise(head, body, tail)
2058     assert opening_bracket and closing_bracket
2059     if (
2060         # the opening bracket is an optional paren
2061         opening_bracket.type == token.LPAR
2062         and not opening_bracket.value
2063         # the closing bracket is an optional paren
2064         and closing_bracket.type == token.RPAR
2065         and not closing_bracket.value
2066         # there are no standalone comments in the body
2067         and not line.contains_standalone_comments(0)
2068         # and it's not an import (optional parens are the only thing we can split
2069         # on in this case; attempting a split without them is a waste of time)
2070         and not line.is_import
2071     ):
2072         omit = {id(closing_bracket), *omit}
2073         if can_omit_invisible_parens(body, line_length):
2074             try:
2075                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2076                 return
2077             except CannotSplit:
2078                 pass
2079
2080     ensure_visible(opening_bracket)
2081     ensure_visible(closing_bracket)
2082     body.should_explode = should_explode(body, opening_bracket)
2083     for result in (head, body, tail):
2084         if result:
2085             yield result
2086
2087
2088 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2089     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2090
2091     Do nothing otherwise.
2092
2093     A left- or right-hand split is based on a pair of brackets. Content before
2094     (and including) the opening bracket is left on one line, content inside the
2095     brackets is put on a separate line, and finally content starting with and
2096     following the closing bracket is put on a separate line.
2097
2098     Those are called `head`, `body`, and `tail`, respectively. If the split
2099     produced the same line (all content in `head`) or ended up with an empty `body`
2100     and the `tail` is just the closing bracket, then it's considered failed.
2101     """
2102     tail_len = len(str(tail).strip())
2103     if not body:
2104         if tail_len == 0:
2105             raise CannotSplit("Splitting brackets produced the same line")
2106
2107         elif tail_len < 3:
2108             raise CannotSplit(
2109                 f"Splitting brackets on an empty body to save "
2110                 f"{tail_len} characters is not worth it"
2111             )
2112
2113
2114 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2115     """Normalize prefix of the first leaf in every line returned by `split_func`.
2116
2117     This is a decorator over relevant split functions.
2118     """
2119
2120     @wraps(split_func)
2121     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2122         for l in split_func(line, py36):
2123             normalize_prefix(l.leaves[0], inside_brackets=True)
2124             yield l
2125
2126     return split_wrapper
2127
2128
2129 @dont_increase_indentation
2130 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2131     """Split according to delimiters of the highest priority.
2132
2133     If `py36` is True, the split will add trailing commas also in function
2134     signatures that contain `*` and `**`.
2135     """
2136     try:
2137         last_leaf = line.leaves[-1]
2138     except IndexError:
2139         raise CannotSplit("Line empty")
2140
2141     bt = line.bracket_tracker
2142     try:
2143         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2144     except ValueError:
2145         raise CannotSplit("No delimiters found")
2146
2147     if delimiter_priority == DOT_PRIORITY:
2148         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2149             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2150
2151     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2152     lowest_depth = sys.maxsize
2153     trailing_comma_safe = True
2154
2155     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2156         """Append `leaf` to current line or to new line if appending impossible."""
2157         nonlocal current_line
2158         try:
2159             current_line.append_safe(leaf, preformatted=True)
2160         except ValueError as ve:
2161             yield current_line
2162
2163             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2164             current_line.append(leaf)
2165
2166     for index, leaf in enumerate(line.leaves):
2167         yield from append_to_line(leaf)
2168
2169         for comment_after in line.comments_after(leaf, index):
2170             yield from append_to_line(comment_after)
2171
2172         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2173         if leaf.bracket_depth == lowest_depth and is_vararg(
2174             leaf, within=VARARGS_PARENTS
2175         ):
2176             trailing_comma_safe = trailing_comma_safe and py36
2177         leaf_priority = bt.delimiters.get(id(leaf))
2178         if leaf_priority == delimiter_priority:
2179             yield current_line
2180
2181             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2182     if current_line:
2183         if (
2184             trailing_comma_safe
2185             and delimiter_priority == COMMA_PRIORITY
2186             and current_line.leaves[-1].type != token.COMMA
2187             and current_line.leaves[-1].type != STANDALONE_COMMENT
2188         ):
2189             current_line.append(Leaf(token.COMMA, ","))
2190         yield current_line
2191
2192
2193 @dont_increase_indentation
2194 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2195     """Split standalone comments from the rest of the line."""
2196     if not line.contains_standalone_comments(0):
2197         raise CannotSplit("Line does not have any standalone comments")
2198
2199     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2200
2201     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2202         """Append `leaf` to current line or to new line if appending impossible."""
2203         nonlocal current_line
2204         try:
2205             current_line.append_safe(leaf, preformatted=True)
2206         except ValueError as ve:
2207             yield current_line
2208
2209             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2210             current_line.append(leaf)
2211
2212     for index, leaf in enumerate(line.leaves):
2213         yield from append_to_line(leaf)
2214
2215         for comment_after in line.comments_after(leaf, index):
2216             yield from append_to_line(comment_after)
2217
2218     if current_line:
2219         yield current_line
2220
2221
2222 def is_import(leaf: Leaf) -> bool:
2223     """Return True if the given leaf starts an import statement."""
2224     p = leaf.parent
2225     t = leaf.type
2226     v = leaf.value
2227     return bool(
2228         t == token.NAME
2229         and (
2230             (v == "import" and p and p.type == syms.import_name)
2231             or (v == "from" and p and p.type == syms.import_from)
2232         )
2233     )
2234
2235
2236 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2237     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2238     else.
2239
2240     Note: don't use backslashes for formatting or you'll lose your voting rights.
2241     """
2242     if not inside_brackets:
2243         spl = leaf.prefix.split("#")
2244         if "\\" not in spl[0]:
2245             nl_count = spl[-1].count("\n")
2246             if len(spl) > 1:
2247                 nl_count -= 1
2248             leaf.prefix = "\n" * nl_count
2249             return
2250
2251     leaf.prefix = ""
2252
2253
2254 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2255     """Make all string prefixes lowercase.
2256
2257     If remove_u_prefix is given, also removes any u prefix from the string.
2258
2259     Note: Mutates its argument.
2260     """
2261     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2262     assert match is not None, f"failed to match string {leaf.value!r}"
2263     orig_prefix = match.group(1)
2264     new_prefix = orig_prefix.lower()
2265     if remove_u_prefix:
2266         new_prefix = new_prefix.replace("u", "")
2267     leaf.value = f"{new_prefix}{match.group(2)}"
2268
2269
2270 def normalize_string_quotes(leaf: Leaf) -> None:
2271     """Prefer double quotes but only if it doesn't cause more escaping.
2272
2273     Adds or removes backslashes as appropriate. Doesn't parse and fix
2274     strings nested in f-strings (yet).
2275
2276     Note: Mutates its argument.
2277     """
2278     value = leaf.value.lstrip("furbFURB")
2279     if value[:3] == '"""':
2280         return
2281
2282     elif value[:3] == "'''":
2283         orig_quote = "'''"
2284         new_quote = '"""'
2285     elif value[0] == '"':
2286         orig_quote = '"'
2287         new_quote = "'"
2288     else:
2289         orig_quote = "'"
2290         new_quote = '"'
2291     first_quote_pos = leaf.value.find(orig_quote)
2292     if first_quote_pos == -1:
2293         return  # There's an internal error
2294
2295     prefix = leaf.value[:first_quote_pos]
2296     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2297     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2298     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2299     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2300     if "r" in prefix.casefold():
2301         if unescaped_new_quote.search(body):
2302             # There's at least one unescaped new_quote in this raw string
2303             # so converting is impossible
2304             return
2305
2306         # Do not introduce or remove backslashes in raw strings
2307         new_body = body
2308     else:
2309         # remove unnecessary quotes
2310         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2311         if body != new_body:
2312             # Consider the string without unnecessary quotes as the original
2313             body = new_body
2314             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2315         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2316         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2317     if new_quote == '"""' and new_body[-1] == '"':
2318         # edge case:
2319         new_body = new_body[:-1] + '\\"'
2320     orig_escape_count = body.count("\\")
2321     new_escape_count = new_body.count("\\")
2322     if new_escape_count > orig_escape_count:
2323         return  # Do not introduce more escaping
2324
2325     if new_escape_count == orig_escape_count and orig_quote == '"':
2326         return  # Prefer double quotes
2327
2328     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2329
2330
2331 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2332     """Make existing optional parentheses invisible or create new ones.
2333
2334     `parens_after` is a set of string leaf values immeditely after which parens
2335     should be put.
2336
2337     Standardizes on visible parentheses for single-element tuples, and keeps
2338     existing visible parentheses for other tuples and generator expressions.
2339     """
2340     check_lpar = False
2341     for child in list(node.children):
2342         if check_lpar:
2343             if child.type == syms.atom:
2344                 maybe_make_parens_invisible_in_atom(child)
2345             elif is_one_tuple(child):
2346                 # wrap child in visible parentheses
2347                 lpar = Leaf(token.LPAR, "(")
2348                 rpar = Leaf(token.RPAR, ")")
2349                 index = child.remove() or 0
2350                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2351             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2352                 # wrap child in invisible parentheses
2353                 lpar = Leaf(token.LPAR, "")
2354                 rpar = Leaf(token.RPAR, "")
2355                 index = child.remove() or 0
2356                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2357
2358         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2359
2360
2361 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2362     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2363     if (
2364         node.type != syms.atom
2365         or is_empty_tuple(node)
2366         or is_one_tuple(node)
2367         or is_yield(node)
2368         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2369     ):
2370         return False
2371
2372     first = node.children[0]
2373     last = node.children[-1]
2374     if first.type == token.LPAR and last.type == token.RPAR:
2375         # make parentheses invisible
2376         first.value = ""  # type: ignore
2377         last.value = ""  # type: ignore
2378         if len(node.children) > 1:
2379             maybe_make_parens_invisible_in_atom(node.children[1])
2380         return True
2381
2382     return False
2383
2384
2385 def is_empty_tuple(node: LN) -> bool:
2386     """Return True if `node` holds an empty tuple."""
2387     return (
2388         node.type == syms.atom
2389         and len(node.children) == 2
2390         and node.children[0].type == token.LPAR
2391         and node.children[1].type == token.RPAR
2392     )
2393
2394
2395 def is_one_tuple(node: LN) -> bool:
2396     """Return True if `node` holds a tuple with one element, with or without parens."""
2397     if node.type == syms.atom:
2398         if len(node.children) != 3:
2399             return False
2400
2401         lpar, gexp, rpar = node.children
2402         if not (
2403             lpar.type == token.LPAR
2404             and gexp.type == syms.testlist_gexp
2405             and rpar.type == token.RPAR
2406         ):
2407             return False
2408
2409         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2410
2411     return (
2412         node.type in IMPLICIT_TUPLE
2413         and len(node.children) == 2
2414         and node.children[1].type == token.COMMA
2415     )
2416
2417
2418 def is_yield(node: LN) -> bool:
2419     """Return True if `node` holds a `yield` or `yield from` expression."""
2420     if node.type == syms.yield_expr:
2421         return True
2422
2423     if node.type == token.NAME and node.value == "yield":  # type: ignore
2424         return True
2425
2426     if node.type != syms.atom:
2427         return False
2428
2429     if len(node.children) != 3:
2430         return False
2431
2432     lpar, expr, rpar = node.children
2433     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2434         return is_yield(expr)
2435
2436     return False
2437
2438
2439 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2440     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2441
2442     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2443     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2444     extended iterable unpacking (PEP 3132) and additional unpacking
2445     generalizations (PEP 448).
2446     """
2447     if leaf.type not in STARS or not leaf.parent:
2448         return False
2449
2450     p = leaf.parent
2451     if p.type == syms.star_expr:
2452         # Star expressions are also used as assignment targets in extended
2453         # iterable unpacking (PEP 3132).  See what its parent is instead.
2454         if not p.parent:
2455             return False
2456
2457         p = p.parent
2458
2459     return p.type in within
2460
2461
2462 def is_multiline_string(leaf: Leaf) -> bool:
2463     """Return True if `leaf` is a multiline string that actually spans many lines."""
2464     value = leaf.value.lstrip("furbFURB")
2465     return value[:3] in {'"""', "'''"} and "\n" in value
2466
2467
2468 def is_stub_suite(node: Node) -> bool:
2469     """Return True if `node` is a suite with a stub body."""
2470     if (
2471         len(node.children) != 4
2472         or node.children[0].type != token.NEWLINE
2473         or node.children[1].type != token.INDENT
2474         or node.children[3].type != token.DEDENT
2475     ):
2476         return False
2477
2478     return is_stub_body(node.children[2])
2479
2480
2481 def is_stub_body(node: LN) -> bool:
2482     """Return True if `node` is a simple statement containing an ellipsis."""
2483     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2484         return False
2485
2486     if len(node.children) != 2:
2487         return False
2488
2489     child = node.children[0]
2490     return (
2491         child.type == syms.atom
2492         and len(child.children) == 3
2493         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2494     )
2495
2496
2497 def max_delimiter_priority_in_atom(node: LN) -> int:
2498     """Return maximum delimiter priority inside `node`.
2499
2500     This is specific to atoms with contents contained in a pair of parentheses.
2501     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2502     """
2503     if node.type != syms.atom:
2504         return 0
2505
2506     first = node.children[0]
2507     last = node.children[-1]
2508     if not (first.type == token.LPAR and last.type == token.RPAR):
2509         return 0
2510
2511     bt = BracketTracker()
2512     for c in node.children[1:-1]:
2513         if isinstance(c, Leaf):
2514             bt.mark(c)
2515         else:
2516             for leaf in c.leaves():
2517                 bt.mark(leaf)
2518     try:
2519         return bt.max_delimiter_priority()
2520
2521     except ValueError:
2522         return 0
2523
2524
2525 def ensure_visible(leaf: Leaf) -> None:
2526     """Make sure parentheses are visible.
2527
2528     They could be invisible as part of some statements (see
2529     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2530     """
2531     if leaf.type == token.LPAR:
2532         leaf.value = "("
2533     elif leaf.type == token.RPAR:
2534         leaf.value = ")"
2535
2536
2537 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2538     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2539     if not (
2540         opening_bracket.parent
2541         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2542         and opening_bracket.value in "[{("
2543     ):
2544         return False
2545
2546     try:
2547         last_leaf = line.leaves[-1]
2548         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2549         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2550     except (IndexError, ValueError):
2551         return False
2552
2553     return max_priority == COMMA_PRIORITY
2554
2555
2556 def is_python36(node: Node) -> bool:
2557     """Return True if the current file is using Python 3.6+ features.
2558
2559     Currently looking for:
2560     - f-strings; and
2561     - trailing commas after * or ** in function signatures and calls.
2562     """
2563     for n in node.pre_order():
2564         if n.type == token.STRING:
2565             value_head = n.value[:2]  # type: ignore
2566             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2567                 return True
2568
2569         elif (
2570             n.type in {syms.typedargslist, syms.arglist}
2571             and n.children
2572             and n.children[-1].type == token.COMMA
2573         ):
2574             for ch in n.children:
2575                 if ch.type in STARS:
2576                     return True
2577
2578                 if ch.type == syms.argument:
2579                     for argch in ch.children:
2580                         if argch.type in STARS:
2581                             return True
2582
2583     return False
2584
2585
2586 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2587     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2588
2589     Brackets can be omitted if the entire trailer up to and including
2590     a preceding closing bracket fits in one line.
2591
2592     Yielded sets are cumulative (contain results of previous yields, too).  First
2593     set is empty.
2594     """
2595
2596     omit: Set[LeafID] = set()
2597     yield omit
2598
2599     length = 4 * line.depth
2600     opening_bracket = None
2601     closing_bracket = None
2602     optional_brackets: Set[LeafID] = set()
2603     inner_brackets: Set[LeafID] = set()
2604     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2605         length += leaf_length
2606         if length > line_length:
2607             break
2608
2609         optional_brackets.discard(id(leaf))
2610         if opening_bracket:
2611             if leaf is opening_bracket:
2612                 opening_bracket = None
2613             elif leaf.type in CLOSING_BRACKETS:
2614                 inner_brackets.add(id(leaf))
2615         elif leaf.type in CLOSING_BRACKETS:
2616             if not leaf.value:
2617                 optional_brackets.add(id(opening_bracket))
2618                 continue
2619
2620             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2621                 # Empty brackets would fail a split so treat them as "inner"
2622                 # brackets (e.g. only add them to the `omit` set if another
2623                 # pair of brackets was good enough.
2624                 inner_brackets.add(id(leaf))
2625                 continue
2626
2627             opening_bracket = leaf.opening_bracket
2628             if closing_bracket:
2629                 omit.add(id(closing_bracket))
2630                 omit.update(inner_brackets)
2631                 inner_brackets.clear()
2632                 yield omit
2633             closing_bracket = leaf
2634
2635
2636 def get_future_imports(node: Node) -> Set[str]:
2637     """Return a set of __future__ imports in the file."""
2638     imports = set()
2639     for child in node.children:
2640         if child.type != syms.simple_stmt:
2641             break
2642         first_child = child.children[0]
2643         if isinstance(first_child, Leaf):
2644             # Continue looking if we see a docstring; otherwise stop.
2645             if (
2646                 len(child.children) == 2
2647                 and first_child.type == token.STRING
2648                 and child.children[1].type == token.NEWLINE
2649             ):
2650                 continue
2651             else:
2652                 break
2653         elif first_child.type == syms.import_from:
2654             module_name = first_child.children[1]
2655             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2656                 break
2657             for import_from_child in first_child.children[3:]:
2658                 if isinstance(import_from_child, Leaf):
2659                     if import_from_child.type == token.NAME:
2660                         imports.add(import_from_child.value)
2661                 else:
2662                     assert import_from_child.type == syms.import_as_names
2663                     for leaf in import_from_child.children:
2664                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2665                             imports.add(leaf.value)
2666         else:
2667             break
2668     return imports
2669
2670
2671 PYTHON_EXTENSIONS = {".py", ".pyi"}
2672 BLACKLISTED_DIRECTORIES = {
2673     "build",
2674     "buck-out",
2675     "dist",
2676     "_build",
2677     ".git",
2678     ".hg",
2679     ".mypy_cache",
2680     ".tox",
2681     ".venv",
2682 }
2683
2684
2685 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2686     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2687     and have one of the PYTHON_EXTENSIONS.
2688     """
2689     for child in path.iterdir():
2690         if child.is_dir():
2691             if child.name in BLACKLISTED_DIRECTORIES:
2692                 continue
2693
2694             yield from gen_python_files_in_dir(child)
2695
2696         elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
2697             yield child
2698
2699
2700 @dataclass
2701 class Report:
2702     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2703     check: bool = False
2704     quiet: bool = False
2705     change_count: int = 0
2706     same_count: int = 0
2707     failure_count: int = 0
2708
2709     def done(self, src: Path, changed: Changed) -> None:
2710         """Increment the counter for successful reformatting. Write out a message."""
2711         if changed is Changed.YES:
2712             reformatted = "would reformat" if self.check else "reformatted"
2713             if not self.quiet:
2714                 out(f"{reformatted} {src}")
2715             self.change_count += 1
2716         else:
2717             if not self.quiet:
2718                 if changed is Changed.NO:
2719                     msg = f"{src} already well formatted, good job."
2720                 else:
2721                     msg = f"{src} wasn't modified on disk since last run."
2722                 out(msg, bold=False)
2723             self.same_count += 1
2724
2725     def failed(self, src: Path, message: str) -> None:
2726         """Increment the counter for failed reformatting. Write out a message."""
2727         err(f"error: cannot format {src}: {message}")
2728         self.failure_count += 1
2729
2730     @property
2731     def return_code(self) -> int:
2732         """Return the exit code that the app should use.
2733
2734         This considers the current state of changed files and failures:
2735         - if there were any failures, return 123;
2736         - if any files were changed and --check is being used, return 1;
2737         - otherwise return 0.
2738         """
2739         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2740         # 126 we have special returncodes reserved by the shell.
2741         if self.failure_count:
2742             return 123
2743
2744         elif self.change_count and self.check:
2745             return 1
2746
2747         return 0
2748
2749     def __str__(self) -> str:
2750         """Render a color report of the current state.
2751
2752         Use `click.unstyle` to remove colors.
2753         """
2754         if self.check:
2755             reformatted = "would be reformatted"
2756             unchanged = "would be left unchanged"
2757             failed = "would fail to reformat"
2758         else:
2759             reformatted = "reformatted"
2760             unchanged = "left unchanged"
2761             failed = "failed to reformat"
2762         report = []
2763         if self.change_count:
2764             s = "s" if self.change_count > 1 else ""
2765             report.append(
2766                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2767             )
2768         if self.same_count:
2769             s = "s" if self.same_count > 1 else ""
2770             report.append(f"{self.same_count} file{s} {unchanged}")
2771         if self.failure_count:
2772             s = "s" if self.failure_count > 1 else ""
2773             report.append(
2774                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2775             )
2776         return ", ".join(report) + "."
2777
2778
2779 def assert_equivalent(src: str, dst: str) -> None:
2780     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2781
2782     import ast
2783     import traceback
2784
2785     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2786         """Simple visitor generating strings to compare ASTs by content."""
2787         yield f"{'  ' * depth}{node.__class__.__name__}("
2788
2789         for field in sorted(node._fields):
2790             try:
2791                 value = getattr(node, field)
2792             except AttributeError:
2793                 continue
2794
2795             yield f"{'  ' * (depth+1)}{field}="
2796
2797             if isinstance(value, list):
2798                 for item in value:
2799                     if isinstance(item, ast.AST):
2800                         yield from _v(item, depth + 2)
2801
2802             elif isinstance(value, ast.AST):
2803                 yield from _v(value, depth + 2)
2804
2805             else:
2806                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2807
2808         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2809
2810     try:
2811         src_ast = ast.parse(src)
2812     except Exception as exc:
2813         major, minor = sys.version_info[:2]
2814         raise AssertionError(
2815             f"cannot use --safe with this file; failed to parse source file "
2816             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2817             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2818         )
2819
2820     try:
2821         dst_ast = ast.parse(dst)
2822     except Exception as exc:
2823         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2824         raise AssertionError(
2825             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2826             f"Please report a bug on https://github.com/ambv/black/issues.  "
2827             f"This invalid output might be helpful: {log}"
2828         ) from None
2829
2830     src_ast_str = "\n".join(_v(src_ast))
2831     dst_ast_str = "\n".join(_v(dst_ast))
2832     if src_ast_str != dst_ast_str:
2833         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2834         raise AssertionError(
2835             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2836             f"the source.  "
2837             f"Please report a bug on https://github.com/ambv/black/issues.  "
2838             f"This diff might be helpful: {log}"
2839         ) from None
2840
2841
2842 def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
2843     """Raise AssertionError if `dst` reformats differently the second time."""
2844     newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
2845     if dst != newdst:
2846         log = dump_to_file(
2847             diff(src, dst, "source", "first pass"),
2848             diff(dst, newdst, "first pass", "second pass"),
2849         )
2850         raise AssertionError(
2851             f"INTERNAL ERROR: Black produced different code on the second pass "
2852             f"of the formatter.  "
2853             f"Please report a bug on https://github.com/ambv/black/issues.  "
2854             f"This diff might be helpful: {log}"
2855         ) from None
2856
2857
2858 def dump_to_file(*output: str) -> str:
2859     """Dump `output` to a temporary file. Return path to the file."""
2860     import tempfile
2861
2862     with tempfile.NamedTemporaryFile(
2863         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2864     ) as f:
2865         for lines in output:
2866             f.write(lines)
2867             if lines and lines[-1] != "\n":
2868                 f.write("\n")
2869     return f.name
2870
2871
2872 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2873     """Return a unified diff string between strings `a` and `b`."""
2874     import difflib
2875
2876     a_lines = [line + "\n" for line in a.split("\n")]
2877     b_lines = [line + "\n" for line in b.split("\n")]
2878     return "".join(
2879         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2880     )
2881
2882
2883 def cancel(tasks: Iterable[asyncio.Task]) -> None:
2884     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2885     err("Aborted!")
2886     for task in tasks:
2887         task.cancel()
2888
2889
2890 def shutdown(loop: BaseEventLoop) -> None:
2891     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2892     try:
2893         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2894         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2895         if not to_cancel:
2896             return
2897
2898         for task in to_cancel:
2899             task.cancel()
2900         loop.run_until_complete(
2901             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2902         )
2903     finally:
2904         # `concurrent.futures.Future` objects cannot be cancelled once they
2905         # are already running. There might be some when the `shutdown()` happened.
2906         # Silence their logger's spew about the event loop being closed.
2907         cf_logger = logging.getLogger("concurrent.futures")
2908         cf_logger.setLevel(logging.CRITICAL)
2909         loop.close()
2910
2911
2912 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2913     """Replace `regex` with `replacement` twice on `original`.
2914
2915     This is used by string normalization to perform replaces on
2916     overlapping matches.
2917     """
2918     return regex.sub(replacement, regex.sub(replacement, original))
2919
2920
2921 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
2922     """Like `reversed(enumerate(sequence))` if that were possible."""
2923     index = len(sequence) - 1
2924     for element in reversed(sequence):
2925         yield (index, element)
2926         index -= 1
2927
2928
2929 def enumerate_with_length(
2930     line: Line, reversed: bool = False
2931 ) -> Iterator[Tuple[Index, Leaf, int]]:
2932     """Return an enumeration of leaves with their length.
2933
2934     Stops prematurely on multiline strings and standalone comments.
2935     """
2936     op = cast(
2937         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
2938         enumerate_reversed if reversed else enumerate,
2939     )
2940     for index, leaf in op(line.leaves):
2941         length = len(leaf.prefix) + len(leaf.value)
2942         if "\n" in leaf.value:
2943             return  # Multiline strings, we can't continue.
2944
2945         comment: Optional[Leaf]
2946         for comment in line.comments_after(leaf, index):
2947             if "\n" in comment.prefix:
2948                 return  # Oops, standalone comment!
2949
2950             length += len(comment.value)
2951
2952         yield index, leaf, length
2953
2954
2955 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
2956     """Return True if `line` is no longer than `line_length`.
2957
2958     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
2959     """
2960     if not line_str:
2961         line_str = str(line).strip("\n")
2962     return (
2963         len(line_str) <= line_length
2964         and "\n" not in line_str  # multiline strings
2965         and not line.contains_standalone_comments()
2966     )
2967
2968
2969 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
2970     """Does `line` have a shape safe to reformat without optional parens around it?
2971
2972     Returns True for only a subset of potentially nice looking formattings but
2973     the point is to not return false positives that end up producing lines that
2974     are too long.
2975     """
2976     bt = line.bracket_tracker
2977     if not bt.delimiters:
2978         # Without delimiters the optional parentheses are useless.
2979         return True
2980
2981     max_priority = bt.max_delimiter_priority()
2982     if bt.delimiter_count_with_priority(max_priority) > 1:
2983         # With more than one delimiter of a kind the optional parentheses read better.
2984         return False
2985
2986     if max_priority == DOT_PRIORITY:
2987         # A single stranded method call doesn't require optional parentheses.
2988         return True
2989
2990     assert len(line.leaves) >= 2, "Stranded delimiter"
2991
2992     first = line.leaves[0]
2993     second = line.leaves[1]
2994     penultimate = line.leaves[-2]
2995     last = line.leaves[-1]
2996
2997     # With a single delimiter, omit if the expression starts or ends with
2998     # a bracket.
2999     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
3000         remainder = False
3001         length = 4 * line.depth
3002         for _index, leaf, leaf_length in enumerate_with_length(line):
3003             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3004                 remainder = True
3005             if remainder:
3006                 length += leaf_length
3007                 if length > line_length:
3008                     break
3009
3010                 if leaf.type in OPENING_BRACKETS:
3011                     # There are brackets we can further split on.
3012                     remainder = False
3013
3014         else:
3015             # checked the entire string and line length wasn't exceeded
3016             if len(line.leaves) == _index + 1:
3017                 return True
3018
3019         # Note: we are not returning False here because a line might have *both*
3020         # a leading opening bracket and a trailing closing bracket.  If the
3021         # opening bracket doesn't match our rule, maybe the closing will.
3022
3023     if (
3024         last.type == token.RPAR
3025         or last.type == token.RBRACE
3026         or (
3027             # don't use indexing for omitting optional parentheses;
3028             # it looks weird
3029             last.type == token.RSQB
3030             and last.parent
3031             and last.parent.type != syms.trailer
3032         )
3033     ):
3034         if penultimate.type in OPENING_BRACKETS:
3035             # Empty brackets don't help.
3036             return False
3037
3038         if is_multiline_string(first):
3039             # Additional wrapping of a multiline string in this situation is
3040             # unnecessary.
3041             return True
3042
3043         length = 4 * line.depth
3044         seen_other_brackets = False
3045         for _index, leaf, leaf_length in enumerate_with_length(line):
3046             length += leaf_length
3047             if leaf is last.opening_bracket:
3048                 if seen_other_brackets or length <= line_length:
3049                     return True
3050
3051             elif leaf.type in OPENING_BRACKETS:
3052                 # There are brackets we can further split on.
3053                 seen_other_brackets = True
3054
3055     return False
3056
3057
3058 def get_cache_file(line_length: int) -> Path:
3059     return CACHE_DIR / f"cache.{line_length}.pickle"
3060
3061
3062 def read_cache(line_length: int) -> Cache:
3063     """Read the cache if it exists and is well formed.
3064
3065     If it is not well formed, the call to write_cache later should resolve the issue.
3066     """
3067     cache_file = get_cache_file(line_length)
3068     if not cache_file.exists():
3069         return {}
3070
3071     with cache_file.open("rb") as fobj:
3072         try:
3073             cache: Cache = pickle.load(fobj)
3074         except pickle.UnpicklingError:
3075             return {}
3076
3077     return cache
3078
3079
3080 def get_cache_info(path: Path) -> CacheInfo:
3081     """Return the information used to check if a file is already formatted or not."""
3082     stat = path.stat()
3083     return stat.st_mtime, stat.st_size
3084
3085
3086 def filter_cached(
3087     cache: Cache, sources: Iterable[Path]
3088 ) -> Tuple[List[Path], List[Path]]:
3089     """Split a list of paths into two.
3090
3091     The first list contains paths of files that modified on disk or are not in the
3092     cache. The other list contains paths to non-modified files.
3093     """
3094     todo, done = [], []
3095     for src in sources:
3096         src = src.resolve()
3097         if cache.get(src) != get_cache_info(src):
3098             todo.append(src)
3099         else:
3100             done.append(src)
3101     return todo, done
3102
3103
3104 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
3105     """Update the cache file."""
3106     cache_file = get_cache_file(line_length)
3107     try:
3108         if not CACHE_DIR.exists():
3109             CACHE_DIR.mkdir(parents=True)
3110         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3111         with cache_file.open("wb") as fobj:
3112             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3113     except OSError:
3114         pass
3115
3116
3117 if __name__ == "__main__":
3118     main()