]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

21e37435105164883391ecabe83e1f3510e24434
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import pickle
5 from asyncio.base_events import BaseEventLoop
6 from concurrent.futures import Executor, ProcessPoolExecutor
7 from enum import Enum
8 from functools import partial, wraps
9 import keyword
10 import logging
11 from multiprocessing import Manager
12 import os
13 from pathlib import Path
14 import re
15 import tokenize
16 import signal
17 import sys
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46 __version__ = "18.4a2"
47 DEFAULT_LINE_LENGTH = 88
48 # types
49 syms = pygram.python_symbols
50 FileContent = str
51 Encoding = str
52 Depth = int
53 NodeType = int
54 LeafID = int
55 Priority = int
56 Index = int
57 LN = Union[Leaf, Node]
58 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
59 Timestamp = float
60 FileSize = int
61 CacheInfo = Tuple[Timestamp, FileSize]
62 Cache = Dict[Path, CacheInfo]
63 out = partial(click.secho, bold=True, err=True)
64 err = partial(click.secho, fg="red", err=True)
65
66
67 class NothingChanged(UserWarning):
68     """Raised by :func:`format_file` when reformatted code is the same as source."""
69
70
71 class CannotSplit(Exception):
72     """A readable split that fits the allotted line length is impossible.
73
74     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
75     :func:`delimiter_split`.
76     """
77
78
79 class FormatError(Exception):
80     """Base exception for `# fmt: on` and `# fmt: off` handling.
81
82     It holds the number of bytes of the prefix consumed before the format
83     control comment appeared.
84     """
85
86     def __init__(self, consumed: int) -> None:
87         super().__init__(consumed)
88         self.consumed = consumed
89
90     def trim_prefix(self, leaf: Leaf) -> None:
91         leaf.prefix = leaf.prefix[self.consumed:]
92
93     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
94         """Returns a new Leaf from the consumed part of the prefix."""
95         unformatted_prefix = leaf.prefix[:self.consumed]
96         return Leaf(token.NEWLINE, unformatted_prefix)
97
98
99 class FormatOn(FormatError):
100     """Found a comment like `# fmt: on` in the file."""
101
102
103 class FormatOff(FormatError):
104     """Found a comment like `# fmt: off` in the file."""
105
106
107 class WriteBack(Enum):
108     NO = 0
109     YES = 1
110     DIFF = 2
111
112
113 class Changed(Enum):
114     NO = 0
115     CACHED = 1
116     YES = 2
117
118
119 @click.command()
120 @click.option(
121     "-l",
122     "--line-length",
123     type=int,
124     default=DEFAULT_LINE_LENGTH,
125     help="How many character per line to allow.",
126     show_default=True,
127 )
128 @click.option(
129     "--check",
130     is_flag=True,
131     help=(
132         "Don't write the files back, just return the status.  Return code 0 "
133         "means nothing would change.  Return code 1 means some files would be "
134         "reformatted.  Return code 123 means there was an internal error."
135     ),
136 )
137 @click.option(
138     "--diff",
139     is_flag=True,
140     help="Don't write the files back, just output a diff for each file on stdout.",
141 )
142 @click.option(
143     "--fast/--safe",
144     is_flag=True,
145     help="If --fast given, skip temporary sanity checks. [default: --safe]",
146 )
147 @click.option(
148     "-q",
149     "--quiet",
150     is_flag=True,
151     help=(
152         "Don't emit non-error messages to stderr. Errors are still emitted, "
153         "silence those with 2>/dev/null."
154     ),
155 )
156 @click.version_option(version=__version__)
157 @click.argument(
158     "src",
159     nargs=-1,
160     type=click.Path(
161         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
162     ),
163 )
164 @click.pass_context
165 def main(
166     ctx: click.Context,
167     line_length: int,
168     check: bool,
169     diff: bool,
170     fast: bool,
171     quiet: bool,
172     src: List[str],
173 ) -> None:
174     """The uncompromising code formatter."""
175     sources: List[Path] = []
176     for s in src:
177         p = Path(s)
178         if p.is_dir():
179             sources.extend(gen_python_files_in_dir(p))
180         elif p.is_file():
181             # if a file was explicitly given, we don't care about its extension
182             sources.append(p)
183         elif s == "-":
184             sources.append(Path("-"))
185         else:
186             err(f"invalid path: {s}")
187
188     if check and not diff:
189         write_back = WriteBack.NO
190     elif diff:
191         write_back = WriteBack.DIFF
192     else:
193         write_back = WriteBack.YES
194     report = Report(check=check, quiet=quiet)
195     if len(sources) == 0:
196         ctx.exit(0)
197         return
198
199     elif len(sources) == 1:
200         reformat_one(sources[0], line_length, fast, write_back, report)
201     else:
202         loop = asyncio.get_event_loop()
203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
204         try:
205             loop.run_until_complete(
206                 schedule_formatting(
207                     sources, line_length, fast, write_back, report, loop, executor
208                 )
209             )
210         finally:
211             shutdown(loop)
212         if not quiet:
213             out("All done! ✨ 🍰 ✨")
214             click.echo(str(report))
215     ctx.exit(report.return_code)
216
217
218 def reformat_one(
219     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
220 ) -> None:
221     """Reformat a single file under `src` without spawning child processes.
222
223     If `quiet` is True, non-error messages are not output. `line_length`,
224     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
225     """
226     try:
227         changed = Changed.NO
228         if not src.is_file() and str(src) == "-":
229             if format_stdin_to_stdout(
230                 line_length=line_length, fast=fast, write_back=write_back
231             ):
232                 changed = Changed.YES
233         else:
234             cache: Cache = {}
235             if write_back != WriteBack.DIFF:
236                 cache = read_cache(line_length)
237                 src = src.resolve()
238                 if src in cache and cache[src] == get_cache_info(src):
239                     changed = Changed.CACHED
240             if (
241                 changed is not Changed.CACHED
242                 and format_file_in_place(
243                     src, line_length=line_length, fast=fast, write_back=write_back
244                 )
245             ):
246                 changed = Changed.YES
247             if write_back != WriteBack.DIFF and changed is not Changed.NO:
248                 write_cache(cache, [src], line_length)
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache(line_length)
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             src: loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             )
289             for src in sources
290         }
291         _task_values = list(tasks.values())
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
294             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     if write_back != WriteBack.DIFF and formatted:
315         write_cache(cache, formatted, line_length)
316
317
318 def format_file_in_place(
319     src: Path,
320     line_length: int,
321     fast: bool,
322     write_back: WriteBack = WriteBack.NO,
323     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
324 ) -> bool:
325     """Format file under `src` path. Return True if changed.
326
327     If `write_back` is True, write reformatted code back to stdout.
328     `line_length` and `fast` options are passed to :func:`format_file_contents`.
329     """
330
331     with tokenize.open(src) as src_buffer:
332         src_contents = src_buffer.read()
333     try:
334         dst_contents = format_file_contents(
335             src_contents, line_length=line_length, fast=fast
336         )
337     except NothingChanged:
338         return False
339
340     if write_back == write_back.YES:
341         with open(src, "w", encoding=src_buffer.encoding) as f:
342             f.write(dst_contents)
343     elif write_back == write_back.DIFF:
344         src_name = f"{src}  (original)"
345         dst_name = f"{src}  (formatted)"
346         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347         if lock:
348             lock.acquire()
349         try:
350             sys.stdout.write(diff_contents)
351         finally:
352             if lock:
353                 lock.release()
354     return True
355
356
357 def format_stdin_to_stdout(
358     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
359 ) -> bool:
360     """Format file on stdin. Return True if changed.
361
362     If `write_back` is True, write reformatted code back to stdout.
363     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
364     """
365     src = sys.stdin.read()
366     dst = src
367     try:
368         dst = format_file_contents(src, line_length=line_length, fast=fast)
369         return True
370
371     except NothingChanged:
372         return False
373
374     finally:
375         if write_back == WriteBack.YES:
376             sys.stdout.write(dst)
377         elif write_back == WriteBack.DIFF:
378             src_name = "<stdin>  (original)"
379             dst_name = "<stdin>  (formatted)"
380             sys.stdout.write(diff(src, dst, src_name, dst_name))
381
382
383 def format_file_contents(
384     src_contents: str, line_length: int, fast: bool
385 ) -> FileContent:
386     """Reformat contents a file and return new contents.
387
388     If `fast` is False, additionally confirm that the reformatted code is
389     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
390     `line_length` is passed to :func:`format_str`.
391     """
392     if src_contents.strip() == "":
393         raise NothingChanged
394
395     dst_contents = format_str(src_contents, line_length=line_length)
396     if src_contents == dst_contents:
397         raise NothingChanged
398
399     if not fast:
400         assert_equivalent(src_contents, dst_contents)
401         assert_stable(src_contents, dst_contents, line_length=line_length)
402     return dst_contents
403
404
405 def format_str(src_contents: str, line_length: int) -> FileContent:
406     """Reformat a string and return new contents.
407
408     `line_length` determines how many characters per line are allowed.
409     """
410     src_node = lib2to3_parse(src_contents)
411     dst_contents = ""
412     lines = LineGenerator()
413     elt = EmptyLineTracker()
414     py36 = is_python36(src_node)
415     empty_line = Line()
416     after = 0
417     for current_line in lines.visit(src_node):
418         for _ in range(after):
419             dst_contents += str(empty_line)
420         before, after = elt.maybe_empty_lines(current_line)
421         for _ in range(before):
422             dst_contents += str(empty_line)
423         for line in split_line(current_line, line_length=line_length, py36=py36):
424             dst_contents += str(line)
425     return dst_contents
426
427
428 GRAMMARS = [
429     pygram.python_grammar_no_print_statement_no_exec_statement,
430     pygram.python_grammar_no_print_statement,
431     pygram.python_grammar,
432 ]
433
434
435 def lib2to3_parse(src_txt: str) -> Node:
436     """Given a string with source, return the lib2to3 Node."""
437     grammar = pygram.python_grammar_no_print_statement
438     if src_txt[-1] != "\n":
439         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
440         src_txt += nl
441     for grammar in GRAMMARS:
442         drv = driver.Driver(grammar, pytree.convert)
443         try:
444             result = drv.parse_string(src_txt, True)
445             break
446
447         except ParseError as pe:
448             lineno, column = pe.context[1]
449             lines = src_txt.splitlines()
450             try:
451                 faulty_line = lines[lineno - 1]
452             except IndexError:
453                 faulty_line = "<line number missing in source>"
454             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
455     else:
456         raise exc from None
457
458     if isinstance(result, Leaf):
459         result = Node(syms.file_input, [result])
460     return result
461
462
463 def lib2to3_unparse(node: Node) -> str:
464     """Given a lib2to3 node, return its string representation."""
465     code = str(node)
466     return code
467
468
469 T = TypeVar("T")
470
471
472 class Visitor(Generic[T]):
473     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
474
475     def visit(self, node: LN) -> Iterator[T]:
476         """Main method to visit `node` and its children.
477
478         It tries to find a `visit_*()` method for the given `node.type`, like
479         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
480         If no dedicated `visit_*()` method is found, chooses `visit_default()`
481         instead.
482
483         Then yields objects of type `T` from the selected visitor.
484         """
485         if node.type < 256:
486             name = token.tok_name[node.type]
487         else:
488             name = type_repr(node.type)
489         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
490
491     def visit_default(self, node: LN) -> Iterator[T]:
492         """Default `visit_*()` implementation. Recurses to children of `node`."""
493         if isinstance(node, Node):
494             for child in node.children:
495                 yield from self.visit(child)
496
497
498 @dataclass
499 class DebugVisitor(Visitor[T]):
500     tree_depth: int = 0
501
502     def visit_default(self, node: LN) -> Iterator[T]:
503         indent = " " * (2 * self.tree_depth)
504         if isinstance(node, Node):
505             _type = type_repr(node.type)
506             out(f"{indent}{_type}", fg="yellow")
507             self.tree_depth += 1
508             for child in node.children:
509                 yield from self.visit(child)
510
511             self.tree_depth -= 1
512             out(f"{indent}/{_type}", fg="yellow", bold=False)
513         else:
514             _type = token.tok_name.get(node.type, str(node.type))
515             out(f"{indent}{_type}", fg="blue", nl=False)
516             if node.prefix:
517                 # We don't have to handle prefixes for `Node` objects since
518                 # that delegates to the first child anyway.
519                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
520             out(f" {node.value!r}", fg="blue", bold=False)
521
522     @classmethod
523     def show(cls, code: str) -> None:
524         """Pretty-print the lib2to3 AST of a given string of `code`.
525
526         Convenience method for debugging.
527         """
528         v: DebugVisitor[None] = DebugVisitor()
529         list(v.visit(lib2to3_parse(code)))
530
531
532 KEYWORDS = set(keyword.kwlist)
533 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
534 FLOW_CONTROL = {"return", "raise", "break", "continue"}
535 STATEMENT = {
536     syms.if_stmt,
537     syms.while_stmt,
538     syms.for_stmt,
539     syms.try_stmt,
540     syms.except_clause,
541     syms.with_stmt,
542     syms.funcdef,
543     syms.classdef,
544 }
545 STANDALONE_COMMENT = 153
546 LOGIC_OPERATORS = {"and", "or"}
547 COMPARATORS = {
548     token.LESS,
549     token.GREATER,
550     token.EQEQUAL,
551     token.NOTEQUAL,
552     token.LESSEQUAL,
553     token.GREATEREQUAL,
554 }
555 MATH_OPERATORS = {
556     token.PLUS,
557     token.MINUS,
558     token.STAR,
559     token.SLASH,
560     token.VBAR,
561     token.AMPER,
562     token.PERCENT,
563     token.CIRCUMFLEX,
564     token.TILDE,
565     token.LEFTSHIFT,
566     token.RIGHTSHIFT,
567     token.DOUBLESTAR,
568     token.DOUBLESLASH,
569 }
570 STARS = {token.STAR, token.DOUBLESTAR}
571 VARARGS_PARENTS = {
572     syms.arglist,
573     syms.argument,  # double star in arglist
574     syms.trailer,  # single argument to call
575     syms.typedargslist,
576     syms.varargslist,  # lambdas
577 }
578 UNPACKING_PARENTS = {
579     syms.atom,  # single element of a list or set literal
580     syms.dictsetmaker,
581     syms.listmaker,
582     syms.testlist_gexp,
583 }
584 COMPREHENSION_PRIORITY = 20
585 COMMA_PRIORITY = 10
586 TERNARY_PRIORITY = 7
587 LOGIC_PRIORITY = 5
588 STRING_PRIORITY = 4
589 COMPARATOR_PRIORITY = 3
590 MATH_PRIORITY = 1
591
592
593 @dataclass
594 class BracketTracker:
595     """Keeps track of brackets on a line."""
596
597     depth: int = 0
598     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
599     delimiters: Dict[LeafID, Priority] = Factory(dict)
600     previous: Optional[Leaf] = None
601     _for_loop_variable: bool = False
602     _lambda_arguments: bool = False
603
604     def mark(self, leaf: Leaf) -> None:
605         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
606
607         All leaves receive an int `bracket_depth` field that stores how deep
608         within brackets a given leaf is. 0 means there are no enclosing brackets
609         that started on this line.
610
611         If a leaf is itself a closing bracket, it receives an `opening_bracket`
612         field that it forms a pair with. This is a one-directional link to
613         avoid reference cycles.
614
615         If a leaf is a delimiter (a token on which Black can split the line if
616         needed) and it's on depth 0, its `id()` is stored in the tracker's
617         `delimiters` field.
618         """
619         if leaf.type == token.COMMENT:
620             return
621
622         self.maybe_decrement_after_for_loop_variable(leaf)
623         self.maybe_decrement_after_lambda_arguments(leaf)
624         if leaf.type in CLOSING_BRACKETS:
625             self.depth -= 1
626             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
627             leaf.opening_bracket = opening_bracket
628         leaf.bracket_depth = self.depth
629         if self.depth == 0:
630             delim = is_split_before_delimiter(leaf, self.previous)
631             if delim and self.previous is not None:
632                 self.delimiters[id(self.previous)] = delim
633             else:
634                 delim = is_split_after_delimiter(leaf, self.previous)
635                 if delim:
636                     self.delimiters[id(leaf)] = delim
637         if leaf.type in OPENING_BRACKETS:
638             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
639             self.depth += 1
640         self.previous = leaf
641         self.maybe_increment_lambda_arguments(leaf)
642         self.maybe_increment_for_loop_variable(leaf)
643
644     def any_open_brackets(self) -> bool:
645         """Return True if there is an yet unmatched open bracket on the line."""
646         return bool(self.bracket_match)
647
648     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
649         """Return the highest priority of a delimiter found on the line.
650
651         Values are consistent with what `is_split_*_delimiter()` return.
652         Raises ValueError on no delimiters.
653         """
654         return max(v for k, v in self.delimiters.items() if k not in exclude)
655
656     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
657         """In a for loop, or comprehension, the variables are often unpacks.
658
659         To avoid splitting on the comma in this situation, increase the depth of
660         tokens between `for` and `in`.
661         """
662         if leaf.type == token.NAME and leaf.value == "for":
663             self.depth += 1
664             self._for_loop_variable = True
665             return True
666
667         return False
668
669     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
670         """See `maybe_increment_for_loop_variable` above for explanation."""
671         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
672             self.depth -= 1
673             self._for_loop_variable = False
674             return True
675
676         return False
677
678     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
679         """In a lambda expression, there might be more than one argument.
680
681         To avoid splitting on the comma in this situation, increase the depth of
682         tokens between `lambda` and `:`.
683         """
684         if leaf.type == token.NAME and leaf.value == "lambda":
685             self.depth += 1
686             self._lambda_arguments = True
687             return True
688
689         return False
690
691     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
692         """See `maybe_increment_lambda_arguments` above for explanation."""
693         if self._lambda_arguments and leaf.type == token.COLON:
694             self.depth -= 1
695             self._lambda_arguments = False
696             return True
697
698         return False
699
700
701 @dataclass
702 class Line:
703     """Holds leaves and comments. Can be printed with `str(line)`."""
704
705     depth: int = 0
706     leaves: List[Leaf] = Factory(list)
707     comments: List[Tuple[Index, Leaf]] = Factory(list)
708     bracket_tracker: BracketTracker = Factory(BracketTracker)
709     inside_brackets: bool = False
710
711     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
712         """Add a new `leaf` to the end of the line.
713
714         Unless `preformatted` is True, the `leaf` will receive a new consistent
715         whitespace prefix and metadata applied by :class:`BracketTracker`.
716         Trailing commas are maybe removed, unpacked for loop variables are
717         demoted from being delimiters.
718
719         Inline comments are put aside.
720         """
721         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
722         if not has_value:
723             return
724
725         if self.leaves and not preformatted:
726             # Note: at this point leaf.prefix should be empty except for
727             # imports, for which we only preserve newlines.
728             leaf.prefix += whitespace(leaf)
729         if self.inside_brackets or not preformatted:
730             self.bracket_tracker.mark(leaf)
731             self.maybe_remove_trailing_comma(leaf)
732
733         if not self.append_comment(leaf):
734             self.leaves.append(leaf)
735
736     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
737         """Like :func:`append()` but disallow invalid standalone comment structure.
738
739         Raises ValueError when any `leaf` is appended after a standalone comment
740         or when a standalone comment is not the first leaf on the line.
741         """
742         if self.bracket_tracker.depth == 0:
743             if self.is_comment:
744                 raise ValueError("cannot append to standalone comments")
745
746             if self.leaves and leaf.type == STANDALONE_COMMENT:
747                 raise ValueError(
748                     "cannot append standalone comments to a populated line"
749                 )
750
751         self.append(leaf, preformatted=preformatted)
752
753     @property
754     def is_comment(self) -> bool:
755         """Is this line a standalone comment?"""
756         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
757
758     @property
759     def is_decorator(self) -> bool:
760         """Is this line a decorator?"""
761         return bool(self) and self.leaves[0].type == token.AT
762
763     @property
764     def is_import(self) -> bool:
765         """Is this an import line?"""
766         return bool(self) and is_import(self.leaves[0])
767
768     @property
769     def is_class(self) -> bool:
770         """Is this line a class definition?"""
771         return (
772             bool(self)
773             and self.leaves[0].type == token.NAME
774             and self.leaves[0].value == "class"
775         )
776
777     @property
778     def is_def(self) -> bool:
779         """Is this a function definition? (Also returns True for async defs.)"""
780         try:
781             first_leaf = self.leaves[0]
782         except IndexError:
783             return False
784
785         try:
786             second_leaf: Optional[Leaf] = self.leaves[1]
787         except IndexError:
788             second_leaf = None
789         return (
790             (first_leaf.type == token.NAME and first_leaf.value == "def")
791             or (
792                 first_leaf.type == token.ASYNC
793                 and second_leaf is not None
794                 and second_leaf.type == token.NAME
795                 and second_leaf.value == "def"
796             )
797         )
798
799     @property
800     def is_flow_control(self) -> bool:
801         """Is this line a flow control statement?
802
803         Those are `return`, `raise`, `break`, and `continue`.
804         """
805         return (
806             bool(self)
807             and self.leaves[0].type == token.NAME
808             and self.leaves[0].value in FLOW_CONTROL
809         )
810
811     @property
812     def is_yield(self) -> bool:
813         """Is this line a yield statement?"""
814         return (
815             bool(self)
816             and self.leaves[0].type == token.NAME
817             and self.leaves[0].value == "yield"
818         )
819
820     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
821         """If so, needs to be split before emitting."""
822         for leaf in self.leaves:
823             if leaf.type == STANDALONE_COMMENT:
824                 if leaf.bracket_depth <= depth_limit:
825                     return True
826
827         return False
828
829     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
830         """Remove trailing comma if there is one and it's safe."""
831         if not (
832             self.leaves
833             and self.leaves[-1].type == token.COMMA
834             and closing.type in CLOSING_BRACKETS
835         ):
836             return False
837
838         if closing.type == token.RBRACE:
839             self.remove_trailing_comma()
840             return True
841
842         if closing.type == token.RSQB:
843             comma = self.leaves[-1]
844             if comma.parent and comma.parent.type == syms.listmaker:
845                 self.remove_trailing_comma()
846                 return True
847
848         # For parens let's check if it's safe to remove the comma.  If the
849         # trailing one is the only one, we might mistakenly change a tuple
850         # into a different type by removing the comma.
851         depth = closing.bracket_depth + 1
852         commas = 0
853         opening = closing.opening_bracket
854         for _opening_index, leaf in enumerate(self.leaves):
855             if leaf is opening:
856                 break
857
858         else:
859             return False
860
861         for leaf in self.leaves[_opening_index + 1:]:
862             if leaf is closing:
863                 break
864
865             bracket_depth = leaf.bracket_depth
866             if bracket_depth == depth and leaf.type == token.COMMA:
867                 commas += 1
868                 if leaf.parent and leaf.parent.type == syms.arglist:
869                     commas += 1
870                     break
871
872         if commas > 1:
873             self.remove_trailing_comma()
874             return True
875
876         return False
877
878     def append_comment(self, comment: Leaf) -> bool:
879         """Add an inline or standalone comment to the line."""
880         if (
881             comment.type == STANDALONE_COMMENT
882             and self.bracket_tracker.any_open_brackets()
883         ):
884             comment.prefix = ""
885             return False
886
887         if comment.type != token.COMMENT:
888             return False
889
890         after = len(self.leaves) - 1
891         if after == -1:
892             comment.type = STANDALONE_COMMENT
893             comment.prefix = ""
894             return False
895
896         else:
897             self.comments.append((after, comment))
898             return True
899
900     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
901         """Generate comments that should appear directly after `leaf`."""
902         for _leaf_index, _leaf in enumerate(self.leaves):
903             if leaf is _leaf:
904                 break
905
906         else:
907             return
908
909         for index, comment_after in self.comments:
910             if _leaf_index == index:
911                 yield comment_after
912
913     def remove_trailing_comma(self) -> None:
914         """Remove the trailing comma and moves the comments attached to it."""
915         comma_index = len(self.leaves) - 1
916         for i in range(len(self.comments)):
917             comment_index, comment = self.comments[i]
918             if comment_index == comma_index:
919                 self.comments[i] = (comma_index - 1, comment)
920         self.leaves.pop()
921
922     def __str__(self) -> str:
923         """Render the line."""
924         if not self:
925             return "\n"
926
927         indent = "    " * self.depth
928         leaves = iter(self.leaves)
929         first = next(leaves)
930         res = f"{first.prefix}{indent}{first.value}"
931         for leaf in leaves:
932             res += str(leaf)
933         for _, comment in self.comments:
934             res += str(comment)
935         return res + "\n"
936
937     def __bool__(self) -> bool:
938         """Return True if the line has leaves or comments."""
939         return bool(self.leaves or self.comments)
940
941
942 class UnformattedLines(Line):
943     """Just like :class:`Line` but stores lines which aren't reformatted."""
944
945     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
946         """Just add a new `leaf` to the end of the lines.
947
948         The `preformatted` argument is ignored.
949
950         Keeps track of indentation `depth`, which is useful when the user
951         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
952         """
953         try:
954             list(generate_comments(leaf))
955         except FormatOn as f_on:
956             self.leaves.append(f_on.leaf_from_consumed(leaf))
957             raise
958
959         self.leaves.append(leaf)
960         if leaf.type == token.INDENT:
961             self.depth += 1
962         elif leaf.type == token.DEDENT:
963             self.depth -= 1
964
965     def __str__(self) -> str:
966         """Render unformatted lines from leaves which were added with `append()`.
967
968         `depth` is not used for indentation in this case.
969         """
970         if not self:
971             return "\n"
972
973         res = ""
974         for leaf in self.leaves:
975             res += str(leaf)
976         return res
977
978     def append_comment(self, comment: Leaf) -> bool:
979         """Not implemented in this class. Raises `NotImplementedError`."""
980         raise NotImplementedError("Unformatted lines don't store comments separately.")
981
982     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
983         """Does nothing and returns False."""
984         return False
985
986     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
987         """Does nothing and returns False."""
988         return False
989
990
991 @dataclass
992 class EmptyLineTracker:
993     """Provides a stateful method that returns the number of potential extra
994     empty lines needed before and after the currently processed line.
995
996     Note: this tracker works on lines that haven't been split yet.  It assumes
997     the prefix of the first leaf consists of optional newlines.  Those newlines
998     are consumed by `maybe_empty_lines()` and included in the computation.
999     """
1000     previous_line: Optional[Line] = None
1001     previous_after: int = 0
1002     previous_defs: List[int] = Factory(list)
1003
1004     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1005         """Return the number of extra empty lines before and after the `current_line`.
1006
1007         This is for separating `def`, `async def` and `class` with extra empty
1008         lines (two on module-level), as well as providing an extra empty line
1009         after flow control keywords to make them more prominent.
1010         """
1011         if isinstance(current_line, UnformattedLines):
1012             return 0, 0
1013
1014         before, after = self._maybe_empty_lines(current_line)
1015         before -= self.previous_after
1016         self.previous_after = after
1017         self.previous_line = current_line
1018         return before, after
1019
1020     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1021         max_allowed = 1
1022         if current_line.depth == 0:
1023             max_allowed = 2
1024         if current_line.leaves:
1025             # Consume the first leaf's extra newlines.
1026             first_leaf = current_line.leaves[0]
1027             before = first_leaf.prefix.count("\n")
1028             before = min(before, max_allowed)
1029             first_leaf.prefix = ""
1030         else:
1031             before = 0
1032         depth = current_line.depth
1033         while self.previous_defs and self.previous_defs[-1] >= depth:
1034             self.previous_defs.pop()
1035             before = 1 if depth else 2
1036         is_decorator = current_line.is_decorator
1037         if is_decorator or current_line.is_def or current_line.is_class:
1038             if not is_decorator:
1039                 self.previous_defs.append(depth)
1040             if self.previous_line is None:
1041                 # Don't insert empty lines before the first line in the file.
1042                 return 0, 0
1043
1044             if self.previous_line.is_decorator:
1045                 return 0, 0
1046
1047             if (
1048                 self.previous_line.is_comment
1049                 and self.previous_line.depth == current_line.depth
1050                 and before == 0
1051             ):
1052                 return 0, 0
1053
1054             newlines = 2
1055             if current_line.depth:
1056                 newlines -= 1
1057             return newlines, 0
1058
1059         if (
1060             self.previous_line
1061             and self.previous_line.is_import
1062             and not current_line.is_import
1063             and depth == self.previous_line.depth
1064         ):
1065             return (before or 1), 0
1066
1067         return before, 0
1068
1069
1070 @dataclass
1071 class LineGenerator(Visitor[Line]):
1072     """Generates reformatted Line objects.  Empty lines are not emitted.
1073
1074     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1075     in ways that will no longer stringify to valid Python code on the tree.
1076     """
1077     current_line: Line = Factory(Line)
1078
1079     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1080         """Generate a line.
1081
1082         If the line is empty, only emit if it makes sense.
1083         If the line is too long, split it first and then generate.
1084
1085         If any lines were generated, set up a new current_line.
1086         """
1087         if not self.current_line:
1088             if self.current_line.__class__ == type:
1089                 self.current_line.depth += indent
1090             else:
1091                 self.current_line = type(depth=self.current_line.depth + indent)
1092             return  # Line is empty, don't emit. Creating a new one unnecessary.
1093
1094         complete_line = self.current_line
1095         self.current_line = type(depth=complete_line.depth + indent)
1096         yield complete_line
1097
1098     def visit(self, node: LN) -> Iterator[Line]:
1099         """Main method to visit `node` and its children.
1100
1101         Yields :class:`Line` objects.
1102         """
1103         if isinstance(self.current_line, UnformattedLines):
1104             # File contained `# fmt: off`
1105             yield from self.visit_unformatted(node)
1106
1107         else:
1108             yield from super().visit(node)
1109
1110     def visit_default(self, node: LN) -> Iterator[Line]:
1111         """Default `visit_*()` implementation. Recurses to children of `node`."""
1112         if isinstance(node, Leaf):
1113             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1114             try:
1115                 for comment in generate_comments(node):
1116                     if any_open_brackets:
1117                         # any comment within brackets is subject to splitting
1118                         self.current_line.append(comment)
1119                     elif comment.type == token.COMMENT:
1120                         # regular trailing comment
1121                         self.current_line.append(comment)
1122                         yield from self.line()
1123
1124                     else:
1125                         # regular standalone comment
1126                         yield from self.line()
1127
1128                         self.current_line.append(comment)
1129                         yield from self.line()
1130
1131             except FormatOff as f_off:
1132                 f_off.trim_prefix(node)
1133                 yield from self.line(type=UnformattedLines)
1134                 yield from self.visit(node)
1135
1136             except FormatOn as f_on:
1137                 # This only happens here if somebody says "fmt: on" multiple
1138                 # times in a row.
1139                 f_on.trim_prefix(node)
1140                 yield from self.visit_default(node)
1141
1142             else:
1143                 normalize_prefix(node, inside_brackets=any_open_brackets)
1144                 if node.type == token.STRING:
1145                     normalize_string_quotes(node)
1146                 if node.type not in WHITESPACE:
1147                     self.current_line.append(node)
1148         yield from super().visit_default(node)
1149
1150     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1151         """Increase indentation level, maybe yield a line."""
1152         # In blib2to3 INDENT never holds comments.
1153         yield from self.line(+1)
1154         yield from self.visit_default(node)
1155
1156     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1157         """Decrease indentation level, maybe yield a line."""
1158         # The current line might still wait for trailing comments.  At DEDENT time
1159         # there won't be any (they would be prefixes on the preceding NEWLINE).
1160         # Emit the line then.
1161         yield from self.line()
1162
1163         # While DEDENT has no value, its prefix may contain standalone comments
1164         # that belong to the current indentation level.  Get 'em.
1165         yield from self.visit_default(node)
1166
1167         # Finally, emit the dedent.
1168         yield from self.line(-1)
1169
1170     def visit_stmt(
1171         self, node: Node, keywords: Set[str], parens: Set[str]
1172     ) -> Iterator[Line]:
1173         """Visit a statement.
1174
1175         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1176         `def`, `with`, `class`, and `assert`.
1177
1178         The relevant Python language `keywords` for a given statement will be
1179         NAME leaves within it. This methods puts those on a separate line.
1180
1181         `parens` holds pairs of nodes where invisible parentheses should be put.
1182         Keys hold nodes after which opening parentheses should be put, values
1183         hold nodes before which closing parentheses should be put.
1184         """
1185         normalize_invisible_parens(node, parens_after=parens)
1186         for child in node.children:
1187             if child.type == token.NAME and child.value in keywords:  # type: ignore
1188                 yield from self.line()
1189
1190             yield from self.visit(child)
1191
1192     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1193         """Visit a statement without nested statements."""
1194         is_suite_like = node.parent and node.parent.type in STATEMENT
1195         if is_suite_like:
1196             yield from self.line(+1)
1197             yield from self.visit_default(node)
1198             yield from self.line(-1)
1199
1200         else:
1201             yield from self.line()
1202             yield from self.visit_default(node)
1203
1204     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1205         """Visit `async def`, `async for`, `async with`."""
1206         yield from self.line()
1207
1208         children = iter(node.children)
1209         for child in children:
1210             yield from self.visit(child)
1211
1212             if child.type == token.ASYNC:
1213                 break
1214
1215         internal_stmt = next(children)
1216         for child in internal_stmt.children:
1217             yield from self.visit(child)
1218
1219     def visit_decorators(self, node: Node) -> Iterator[Line]:
1220         """Visit decorators."""
1221         for child in node.children:
1222             yield from self.line()
1223             yield from self.visit(child)
1224
1225     def visit_import_from(self, node: Node) -> Iterator[Line]:
1226         """Visit import_from and maybe put invisible parentheses.
1227
1228         This is separate from `visit_stmt` because import statements don't
1229         support arbitrary atoms and thus handling of parentheses is custom.
1230         """
1231         check_lpar = False
1232         for index, child in enumerate(node.children):
1233             if check_lpar:
1234                 if child.type == token.LPAR:
1235                     # make parentheses invisible
1236                     child.value = ""  # type: ignore
1237                     node.children[-1].value = ""  # type: ignore
1238                 else:
1239                     # insert invisible parentheses
1240                     node.insert_child(index, Leaf(token.LPAR, ""))
1241                     node.append_child(Leaf(token.RPAR, ""))
1242                 break
1243
1244             check_lpar = (
1245                 child.type == token.NAME and child.value == "import"  # type: ignore
1246             )
1247
1248         for child in node.children:
1249             yield from self.visit(child)
1250
1251     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1252         """Remove a semicolon and put the other statement on a separate line."""
1253         yield from self.line()
1254
1255     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1256         """End of file. Process outstanding comments and end with a newline."""
1257         yield from self.visit_default(leaf)
1258         yield from self.line()
1259
1260     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1261         """Used when file contained a `# fmt: off`."""
1262         if isinstance(node, Node):
1263             for child in node.children:
1264                 yield from self.visit(child)
1265
1266         else:
1267             try:
1268                 self.current_line.append(node)
1269             except FormatOn as f_on:
1270                 f_on.trim_prefix(node)
1271                 yield from self.line()
1272                 yield from self.visit(node)
1273
1274             if node.type == token.ENDMARKER:
1275                 # somebody decided not to put a final `# fmt: on`
1276                 yield from self.line()
1277
1278     def __attrs_post_init__(self) -> None:
1279         """You are in a twisty little maze of passages."""
1280         v = self.visit_stmt
1281         Ø: Set[str] = set()
1282         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1283         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1284         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1285         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1286         self.visit_try_stmt = partial(
1287             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1288         )
1289         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1290         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1291         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1292         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1293         self.visit_async_funcdef = self.visit_async_stmt
1294         self.visit_decorated = self.visit_decorators
1295
1296
1297 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1298 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1299 OPENING_BRACKETS = set(BRACKET.keys())
1300 CLOSING_BRACKETS = set(BRACKET.values())
1301 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1302 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1303
1304
1305 def whitespace(leaf: Leaf) -> str:  # noqa C901
1306     """Return whitespace prefix if needed for the given `leaf`."""
1307     NO = ""
1308     SPACE = " "
1309     DOUBLESPACE = "  "
1310     t = leaf.type
1311     p = leaf.parent
1312     v = leaf.value
1313     if t in ALWAYS_NO_SPACE:
1314         return NO
1315
1316     if t == token.COMMENT:
1317         return DOUBLESPACE
1318
1319     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1320     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1321         return NO
1322
1323     prev = leaf.prev_sibling
1324     if not prev:
1325         prevp = preceding_leaf(p)
1326         if not prevp or prevp.type in OPENING_BRACKETS:
1327             return NO
1328
1329         if t == token.COLON:
1330             return SPACE if prevp.type == token.COMMA else NO
1331
1332         if prevp.type == token.EQUAL:
1333             if prevp.parent:
1334                 if prevp.parent.type in {
1335                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1336                 }:
1337                     return NO
1338
1339                 elif prevp.parent.type == syms.typedargslist:
1340                     # A bit hacky: if the equal sign has whitespace, it means we
1341                     # previously found it's a typed argument.  So, we're using
1342                     # that, too.
1343                     return prevp.prefix
1344
1345         elif prevp.type in STARS:
1346             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1347                 return NO
1348
1349         elif prevp.type == token.COLON:
1350             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1351                 return NO
1352
1353         elif (
1354             prevp.parent
1355             and prevp.parent.type == syms.factor
1356             and prevp.type in MATH_OPERATORS
1357         ):
1358             return NO
1359
1360         elif (
1361             prevp.type == token.RIGHTSHIFT
1362             and prevp.parent
1363             and prevp.parent.type == syms.shift_expr
1364             and prevp.prev_sibling
1365             and prevp.prev_sibling.type == token.NAME
1366             and prevp.prev_sibling.value == "print"  # type: ignore
1367         ):
1368             # Python 2 print chevron
1369             return NO
1370
1371     elif prev.type in OPENING_BRACKETS:
1372         return NO
1373
1374     if p.type in {syms.parameters, syms.arglist}:
1375         # untyped function signatures or calls
1376         if not prev or prev.type != token.COMMA:
1377             return NO
1378
1379     elif p.type == syms.varargslist:
1380         # lambdas
1381         if prev and prev.type != token.COMMA:
1382             return NO
1383
1384     elif p.type == syms.typedargslist:
1385         # typed function signatures
1386         if not prev:
1387             return NO
1388
1389         if t == token.EQUAL:
1390             if prev.type != syms.tname:
1391                 return NO
1392
1393         elif prev.type == token.EQUAL:
1394             # A bit hacky: if the equal sign has whitespace, it means we
1395             # previously found it's a typed argument.  So, we're using that, too.
1396             return prev.prefix
1397
1398         elif prev.type != token.COMMA:
1399             return NO
1400
1401     elif p.type == syms.tname:
1402         # type names
1403         if not prev:
1404             prevp = preceding_leaf(p)
1405             if not prevp or prevp.type != token.COMMA:
1406                 return NO
1407
1408     elif p.type == syms.trailer:
1409         # attributes and calls
1410         if t == token.LPAR or t == token.RPAR:
1411             return NO
1412
1413         if not prev:
1414             if t == token.DOT:
1415                 prevp = preceding_leaf(p)
1416                 if not prevp or prevp.type != token.NUMBER:
1417                     return NO
1418
1419             elif t == token.LSQB:
1420                 return NO
1421
1422         elif prev.type != token.COMMA:
1423             return NO
1424
1425     elif p.type == syms.argument:
1426         # single argument
1427         if t == token.EQUAL:
1428             return NO
1429
1430         if not prev:
1431             prevp = preceding_leaf(p)
1432             if not prevp or prevp.type == token.LPAR:
1433                 return NO
1434
1435         elif prev.type in {token.EQUAL} | STARS:
1436             return NO
1437
1438     elif p.type == syms.decorator:
1439         # decorators
1440         return NO
1441
1442     elif p.type == syms.dotted_name:
1443         if prev:
1444             return NO
1445
1446         prevp = preceding_leaf(p)
1447         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1448             return NO
1449
1450     elif p.type == syms.classdef:
1451         if t == token.LPAR:
1452             return NO
1453
1454         if prev and prev.type == token.LPAR:
1455             return NO
1456
1457     elif p.type == syms.subscript:
1458         # indexing
1459         if not prev:
1460             assert p.parent is not None, "subscripts are always parented"
1461             if p.parent.type == syms.subscriptlist:
1462                 return SPACE
1463
1464             return NO
1465
1466         else:
1467             return NO
1468
1469     elif p.type == syms.atom:
1470         if prev and t == token.DOT:
1471             # dots, but not the first one.
1472             return NO
1473
1474     elif p.type == syms.dictsetmaker:
1475         # dict unpacking
1476         if prev and prev.type == token.DOUBLESTAR:
1477             return NO
1478
1479     elif p.type in {syms.factor, syms.star_expr}:
1480         # unary ops
1481         if not prev:
1482             prevp = preceding_leaf(p)
1483             if not prevp or prevp.type in OPENING_BRACKETS:
1484                 return NO
1485
1486             prevp_parent = prevp.parent
1487             assert prevp_parent is not None
1488             if (
1489                 prevp.type == token.COLON
1490                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1491             ):
1492                 return NO
1493
1494             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1495                 return NO
1496
1497         elif t == token.NAME or t == token.NUMBER:
1498             return NO
1499
1500     elif p.type == syms.import_from:
1501         if t == token.DOT:
1502             if prev and prev.type == token.DOT:
1503                 return NO
1504
1505         elif t == token.NAME:
1506             if v == "import":
1507                 return SPACE
1508
1509             if prev and prev.type == token.DOT:
1510                 return NO
1511
1512     elif p.type == syms.sliceop:
1513         return NO
1514
1515     return SPACE
1516
1517
1518 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1519     """Return the first leaf that precedes `node`, if any."""
1520     while node:
1521         res = node.prev_sibling
1522         if res:
1523             if isinstance(res, Leaf):
1524                 return res
1525
1526             try:
1527                 return list(res.leaves())[-1]
1528
1529             except IndexError:
1530                 return None
1531
1532         node = node.parent
1533     return None
1534
1535
1536 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1537     """Return the priority of the `leaf` delimiter, given a line break after it.
1538
1539     The delimiter priorities returned here are from those delimiters that would
1540     cause a line break after themselves.
1541
1542     Higher numbers are higher priority.
1543     """
1544     if leaf.type == token.COMMA:
1545         return COMMA_PRIORITY
1546
1547     return 0
1548
1549
1550 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1551     """Return the priority of the `leaf` delimiter, given a line before after it.
1552
1553     The delimiter priorities returned here are from those delimiters that would
1554     cause a line break before themselves.
1555
1556     Higher numbers are higher priority.
1557     """
1558     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1559         # * and ** might also be MATH_OPERATORS but in this case they are not.
1560         # Don't treat them as a delimiter.
1561         return 0
1562
1563     if (
1564         leaf.type in MATH_OPERATORS
1565         and leaf.parent
1566         and leaf.parent.type not in {syms.factor, syms.star_expr}
1567     ):
1568         return MATH_PRIORITY
1569
1570     if leaf.type in COMPARATORS:
1571         return COMPARATOR_PRIORITY
1572
1573     if (
1574         leaf.type == token.STRING
1575         and previous is not None
1576         and previous.type == token.STRING
1577     ):
1578         return STRING_PRIORITY
1579
1580     if (
1581         leaf.type == token.NAME
1582         and leaf.value == "for"
1583         and leaf.parent
1584         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1585     ):
1586         return COMPREHENSION_PRIORITY
1587
1588     if (
1589         leaf.type == token.NAME
1590         and leaf.value == "if"
1591         and leaf.parent
1592         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1593     ):
1594         return COMPREHENSION_PRIORITY
1595
1596     if (
1597         leaf.type == token.NAME
1598         and leaf.value in {"if", "else"}
1599         and leaf.parent
1600         and leaf.parent.type == syms.test
1601     ):
1602         return TERNARY_PRIORITY
1603
1604     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1605         return LOGIC_PRIORITY
1606
1607     return 0
1608
1609
1610 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1611     """Clean the prefix of the `leaf` and generate comments from it, if any.
1612
1613     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1614     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1615     move because it does away with modifying the grammar to include all the
1616     possible places in which comments can be placed.
1617
1618     The sad consequence for us though is that comments don't "belong" anywhere.
1619     This is why this function generates simple parentless Leaf objects for
1620     comments.  We simply don't know what the correct parent should be.
1621
1622     No matter though, we can live without this.  We really only need to
1623     differentiate between inline and standalone comments.  The latter don't
1624     share the line with any code.
1625
1626     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1627     are emitted with a fake STANDALONE_COMMENT token identifier.
1628     """
1629     p = leaf.prefix
1630     if not p:
1631         return
1632
1633     if "#" not in p:
1634         return
1635
1636     consumed = 0
1637     nlines = 0
1638     for index, line in enumerate(p.split("\n")):
1639         consumed += len(line) + 1  # adding the length of the split '\n'
1640         line = line.lstrip()
1641         if not line:
1642             nlines += 1
1643         if not line.startswith("#"):
1644             continue
1645
1646         if index == 0 and leaf.type != token.ENDMARKER:
1647             comment_type = token.COMMENT  # simple trailing comment
1648         else:
1649             comment_type = STANDALONE_COMMENT
1650         comment = make_comment(line)
1651         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1652
1653         if comment in {"# fmt: on", "# yapf: enable"}:
1654             raise FormatOn(consumed)
1655
1656         if comment in {"# fmt: off", "# yapf: disable"}:
1657             if comment_type == STANDALONE_COMMENT:
1658                 raise FormatOff(consumed)
1659
1660             prev = preceding_leaf(leaf)
1661             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1662                 raise FormatOff(consumed)
1663
1664         nlines = 0
1665
1666
1667 def make_comment(content: str) -> str:
1668     """Return a consistently formatted comment from the given `content` string.
1669
1670     All comments (except for "##", "#!", "#:") should have a single space between
1671     the hash sign and the content.
1672
1673     If `content` didn't start with a hash sign, one is provided.
1674     """
1675     content = content.rstrip()
1676     if not content:
1677         return "#"
1678
1679     if content[0] == "#":
1680         content = content[1:]
1681     if content and content[0] not in " !:#":
1682         content = " " + content
1683     return "#" + content
1684
1685
1686 def split_line(
1687     line: Line, line_length: int, inner: bool = False, py36: bool = False
1688 ) -> Iterator[Line]:
1689     """Split a `line` into potentially many lines.
1690
1691     They should fit in the allotted `line_length` but might not be able to.
1692     `inner` signifies that there were a pair of brackets somewhere around the
1693     current `line`, possibly transitively. This means we can fallback to splitting
1694     by delimiters if the LHS/RHS don't yield any results.
1695
1696     If `py36` is True, splitting may generate syntax that is only compatible
1697     with Python 3.6 and later.
1698     """
1699     if isinstance(line, UnformattedLines) or line.is_comment:
1700         yield line
1701         return
1702
1703     line_str = str(line).strip("\n")
1704     if (
1705         len(line_str) <= line_length
1706         and "\n" not in line_str  # multiline strings
1707         and not line.contains_standalone_comments()
1708     ):
1709         yield line
1710         return
1711
1712     split_funcs: List[SplitFunc]
1713     if line.is_def:
1714         split_funcs = [left_hand_split]
1715     elif line.is_import:
1716         split_funcs = [explode_split]
1717     elif line.inside_brackets:
1718         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1719     else:
1720         split_funcs = [right_hand_split]
1721     for split_func in split_funcs:
1722         # We are accumulating lines in `result` because we might want to abort
1723         # mission and return the original line in the end, or attempt a different
1724         # split altogether.
1725         result: List[Line] = []
1726         try:
1727             for l in split_func(line, py36):
1728                 if str(l).strip("\n") == line_str:
1729                     raise CannotSplit("Split function returned an unchanged result")
1730
1731                 result.extend(
1732                     split_line(l, line_length=line_length, inner=True, py36=py36)
1733                 )
1734         except CannotSplit as cs:
1735             continue
1736
1737         else:
1738             yield from result
1739             break
1740
1741     else:
1742         yield line
1743
1744
1745 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1746     """Split line into many lines, starting with the first matching bracket pair.
1747
1748     Note: this usually looks weird, only use this for function definitions.
1749     Prefer RHS otherwise.
1750     """
1751     head = Line(depth=line.depth)
1752     body = Line(depth=line.depth + 1, inside_brackets=True)
1753     tail = Line(depth=line.depth)
1754     tail_leaves: List[Leaf] = []
1755     body_leaves: List[Leaf] = []
1756     head_leaves: List[Leaf] = []
1757     current_leaves = head_leaves
1758     matching_bracket = None
1759     for leaf in line.leaves:
1760         if (
1761             current_leaves is body_leaves
1762             and leaf.type in CLOSING_BRACKETS
1763             and leaf.opening_bracket is matching_bracket
1764         ):
1765             current_leaves = tail_leaves if body_leaves else head_leaves
1766         current_leaves.append(leaf)
1767         if current_leaves is head_leaves:
1768             if leaf.type in OPENING_BRACKETS:
1769                 matching_bracket = leaf
1770                 current_leaves = body_leaves
1771     # Since body is a new indent level, remove spurious leading whitespace.
1772     if body_leaves:
1773         normalize_prefix(body_leaves[0], inside_brackets=True)
1774     # Build the new lines.
1775     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1776         for leaf in leaves:
1777             result.append(leaf, preformatted=True)
1778             for comment_after in line.comments_after(leaf):
1779                 result.append(comment_after, preformatted=True)
1780     bracket_split_succeeded_or_raise(head, body, tail)
1781     for result in (head, body, tail):
1782         if result:
1783             yield result
1784
1785
1786 def right_hand_split(
1787     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1788 ) -> Iterator[Line]:
1789     """Split line into many lines, starting with the last matching bracket pair."""
1790     head = Line(depth=line.depth)
1791     body = Line(depth=line.depth + 1, inside_brackets=True)
1792     tail = Line(depth=line.depth)
1793     tail_leaves: List[Leaf] = []
1794     body_leaves: List[Leaf] = []
1795     head_leaves: List[Leaf] = []
1796     current_leaves = tail_leaves
1797     opening_bracket = None
1798     closing_bracket = None
1799     for leaf in reversed(line.leaves):
1800         if current_leaves is body_leaves:
1801             if leaf is opening_bracket:
1802                 current_leaves = head_leaves if body_leaves else tail_leaves
1803         current_leaves.append(leaf)
1804         if current_leaves is tail_leaves:
1805             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1806                 opening_bracket = leaf.opening_bracket
1807                 closing_bracket = leaf
1808                 current_leaves = body_leaves
1809     tail_leaves.reverse()
1810     body_leaves.reverse()
1811     head_leaves.reverse()
1812     # Since body is a new indent level, remove spurious leading whitespace.
1813     if body_leaves:
1814         normalize_prefix(body_leaves[0], inside_brackets=True)
1815     elif not head_leaves:
1816         # No `head` and no `body` means the split failed. `tail` has all content.
1817         raise CannotSplit("No brackets found")
1818
1819     # Build the new lines.
1820     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1821         for leaf in leaves:
1822             result.append(leaf, preformatted=True)
1823             for comment_after in line.comments_after(leaf):
1824                 result.append(comment_after, preformatted=True)
1825     bracket_split_succeeded_or_raise(head, body, tail)
1826     assert opening_bracket and closing_bracket
1827     if (
1828         opening_bracket.type == token.LPAR
1829         and not opening_bracket.value
1830         and closing_bracket.type == token.RPAR
1831         and not closing_bracket.value
1832     ):
1833         # These parens were optional. If there aren't any delimiters or standalone
1834         # comments in the body, they were unnecessary and another split without
1835         # them should be attempted.
1836         if not (
1837             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1838         ):
1839             omit = {id(closing_bracket), *omit}
1840             yield from right_hand_split(line, py36=py36, omit=omit)
1841             return
1842
1843     ensure_visible(opening_bracket)
1844     ensure_visible(closing_bracket)
1845     for result in (head, body, tail):
1846         if result:
1847             yield result
1848
1849
1850 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1851     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1852
1853     Do nothing otherwise.
1854
1855     A left- or right-hand split is based on a pair of brackets. Content before
1856     (and including) the opening bracket is left on one line, content inside the
1857     brackets is put on a separate line, and finally content starting with and
1858     following the closing bracket is put on a separate line.
1859
1860     Those are called `head`, `body`, and `tail`, respectively. If the split
1861     produced the same line (all content in `head`) or ended up with an empty `body`
1862     and the `tail` is just the closing bracket, then it's considered failed.
1863     """
1864     tail_len = len(str(tail).strip())
1865     if not body:
1866         if tail_len == 0:
1867             raise CannotSplit("Splitting brackets produced the same line")
1868
1869         elif tail_len < 3:
1870             raise CannotSplit(
1871                 f"Splitting brackets on an empty body to save "
1872                 f"{tail_len} characters is not worth it"
1873             )
1874
1875
1876 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1877     """Normalize prefix of the first leaf in every line returned by `split_func`.
1878
1879     This is a decorator over relevant split functions.
1880     """
1881
1882     @wraps(split_func)
1883     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1884         for l in split_func(line, py36):
1885             normalize_prefix(l.leaves[0], inside_brackets=True)
1886             yield l
1887
1888     return split_wrapper
1889
1890
1891 @dont_increase_indentation
1892 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1893     """Split according to delimiters of the highest priority.
1894
1895     If `py36` is True, the split will add trailing commas also in function
1896     signatures that contain `*` and `**`.
1897     """
1898     try:
1899         last_leaf = line.leaves[-1]
1900     except IndexError:
1901         raise CannotSplit("Line empty")
1902
1903     delimiters = line.bracket_tracker.delimiters
1904     try:
1905         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1906             exclude={id(last_leaf)}
1907         )
1908     except ValueError:
1909         raise CannotSplit("No delimiters found")
1910
1911     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1912     lowest_depth = sys.maxsize
1913     trailing_comma_safe = True
1914
1915     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1916         """Append `leaf` to current line or to new line if appending impossible."""
1917         nonlocal current_line
1918         try:
1919             current_line.append_safe(leaf, preformatted=True)
1920         except ValueError as ve:
1921             yield current_line
1922
1923             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1924             current_line.append(leaf)
1925
1926     for leaf in line.leaves:
1927         yield from append_to_line(leaf)
1928
1929         for comment_after in line.comments_after(leaf):
1930             yield from append_to_line(comment_after)
1931
1932         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1933         if (
1934             leaf.bracket_depth == lowest_depth
1935             and is_vararg(leaf, within=VARARGS_PARENTS)
1936         ):
1937             trailing_comma_safe = trailing_comma_safe and py36
1938         leaf_priority = delimiters.get(id(leaf))
1939         if leaf_priority == delimiter_priority:
1940             yield current_line
1941
1942             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1943     if current_line:
1944         if (
1945             trailing_comma_safe
1946             and delimiter_priority == COMMA_PRIORITY
1947             and current_line.leaves[-1].type != token.COMMA
1948             and current_line.leaves[-1].type != STANDALONE_COMMENT
1949         ):
1950             current_line.append(Leaf(token.COMMA, ","))
1951         yield current_line
1952
1953
1954 @dont_increase_indentation
1955 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1956     """Split standalone comments from the rest of the line."""
1957     if not line.contains_standalone_comments(0):
1958         raise CannotSplit("Line does not have any standalone comments")
1959
1960     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1961
1962     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1963         """Append `leaf` to current line or to new line if appending impossible."""
1964         nonlocal current_line
1965         try:
1966             current_line.append_safe(leaf, preformatted=True)
1967         except ValueError as ve:
1968             yield current_line
1969
1970             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1971             current_line.append(leaf)
1972
1973     for leaf in line.leaves:
1974         yield from append_to_line(leaf)
1975
1976         for comment_after in line.comments_after(leaf):
1977             yield from append_to_line(comment_after)
1978
1979     if current_line:
1980         yield current_line
1981
1982
1983 def explode_split(
1984     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1985 ) -> Iterator[Line]:
1986     """Split by RHS and immediately split contents by a delimiter."""
1987     new_lines = list(right_hand_split(line, py36, omit))
1988     if len(new_lines) != 3:
1989         yield from new_lines
1990         return
1991
1992     yield new_lines[0]
1993     try:
1994         yield from delimiter_split(new_lines[1], py36)
1995     except CannotSplit:
1996         yield new_lines[1]
1997
1998     yield new_lines[2]
1999
2000
2001 def is_import(leaf: Leaf) -> bool:
2002     """Return True if the given leaf starts an import statement."""
2003     p = leaf.parent
2004     t = leaf.type
2005     v = leaf.value
2006     return bool(
2007         t == token.NAME
2008         and (
2009             (v == "import" and p and p.type == syms.import_name)
2010             or (v == "from" and p and p.type == syms.import_from)
2011         )
2012     )
2013
2014
2015 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2016     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2017     else.
2018
2019     Note: don't use backslashes for formatting or you'll lose your voting rights.
2020     """
2021     if not inside_brackets:
2022         spl = leaf.prefix.split("#")
2023         if "\\" not in spl[0]:
2024             nl_count = spl[-1].count("\n")
2025             if len(spl) > 1:
2026                 nl_count -= 1
2027             leaf.prefix = "\n" * nl_count
2028             return
2029
2030     leaf.prefix = ""
2031
2032
2033 def normalize_string_quotes(leaf: Leaf) -> None:
2034     """Prefer double quotes but only if it doesn't cause more escaping.
2035
2036     Adds or removes backslashes as appropriate. Doesn't parse and fix
2037     strings nested in f-strings (yet).
2038
2039     Note: Mutates its argument.
2040     """
2041     value = leaf.value.lstrip("furbFURB")
2042     if value[:3] == '"""':
2043         return
2044
2045     elif value[:3] == "'''":
2046         orig_quote = "'''"
2047         new_quote = '"""'
2048     elif value[0] == '"':
2049         orig_quote = '"'
2050         new_quote = "'"
2051     else:
2052         orig_quote = "'"
2053         new_quote = '"'
2054     first_quote_pos = leaf.value.find(orig_quote)
2055     if first_quote_pos == -1:
2056         return  # There's an internal error
2057
2058     prefix = leaf.value[:first_quote_pos]
2059     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2060     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2061     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2062     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
2063     if "r" in prefix.casefold():
2064         if unescaped_new_quote.search(body):
2065             # There's at least one unescaped new_quote in this raw string
2066             # so converting is impossible
2067             return
2068
2069         # Do not introduce or remove backslashes in raw strings
2070         new_body = body
2071     else:
2072         # remove unnecessary quotes
2073         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2074         if body != new_body:
2075             # Consider the string without unnecessary quotes as the original
2076             body = new_body
2077             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2078         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2079         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2080     if new_quote == '"""' and new_body[-1] == '"':
2081         # edge case:
2082         new_body = new_body[:-1] + '\\"'
2083     orig_escape_count = body.count("\\")
2084     new_escape_count = new_body.count("\\")
2085     if new_escape_count > orig_escape_count:
2086         return  # Do not introduce more escaping
2087
2088     if new_escape_count == orig_escape_count and orig_quote == '"':
2089         return  # Prefer double quotes
2090
2091     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2092
2093
2094 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2095     """Make existing optional parentheses invisible or create new ones.
2096
2097     Standardizes on visible parentheses for single-element tuples, and keeps
2098     existing visible parentheses for other tuples and generator expressions.
2099     """
2100     check_lpar = False
2101     for child in list(node.children):
2102         if check_lpar:
2103             if child.type == syms.atom:
2104                 if not (
2105                     is_empty_tuple(child)
2106                     or is_one_tuple(child)
2107                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2108                 ):
2109                     first = child.children[0]
2110                     last = child.children[-1]
2111                     if first.type == token.LPAR and last.type == token.RPAR:
2112                         # make parentheses invisible
2113                         first.value = ""  # type: ignore
2114                         last.value = ""  # type: ignore
2115             elif is_one_tuple(child):
2116                 # wrap child in visible parentheses
2117                 lpar = Leaf(token.LPAR, "(")
2118                 rpar = Leaf(token.RPAR, ")")
2119                 index = child.remove() or 0
2120                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2121             else:
2122                 # wrap child in invisible parentheses
2123                 lpar = Leaf(token.LPAR, "")
2124                 rpar = Leaf(token.RPAR, "")
2125                 index = child.remove() or 0
2126                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2127
2128         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2129
2130
2131 def is_empty_tuple(node: LN) -> bool:
2132     """Return True if `node` holds an empty tuple."""
2133     return (
2134         node.type == syms.atom
2135         and len(node.children) == 2
2136         and node.children[0].type == token.LPAR
2137         and node.children[1].type == token.RPAR
2138     )
2139
2140
2141 def is_one_tuple(node: LN) -> bool:
2142     """Return True if `node` holds a tuple with one element, with or without parens."""
2143     if node.type == syms.atom:
2144         if len(node.children) != 3:
2145             return False
2146
2147         lpar, gexp, rpar = node.children
2148         if not (
2149             lpar.type == token.LPAR
2150             and gexp.type == syms.testlist_gexp
2151             and rpar.type == token.RPAR
2152         ):
2153             return False
2154
2155         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2156
2157     return (
2158         node.type in IMPLICIT_TUPLE
2159         and len(node.children) == 2
2160         and node.children[1].type == token.COMMA
2161     )
2162
2163
2164 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2165     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2166
2167     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2168     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2169     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2170     generalizations (PEP 448).
2171     """
2172     if leaf.type not in STARS or not leaf.parent:
2173         return False
2174
2175     p = leaf.parent
2176     if p.type == syms.star_expr:
2177         # Star expressions are also used as assignment targets in extended
2178         # iterable unpacking (PEP 3132).  See what its parent is instead.
2179         if not p.parent:
2180             return False
2181
2182         p = p.parent
2183
2184     return p.type in within
2185
2186
2187 def max_delimiter_priority_in_atom(node: LN) -> int:
2188     """Return maximum delimiter priority inside `node`.
2189
2190     This is specific to atoms with contents contained in a pair of parentheses.
2191     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2192     """
2193     if node.type != syms.atom:
2194         return 0
2195
2196     first = node.children[0]
2197     last = node.children[-1]
2198     if not (first.type == token.LPAR and last.type == token.RPAR):
2199         return 0
2200
2201     bt = BracketTracker()
2202     for c in node.children[1:-1]:
2203         if isinstance(c, Leaf):
2204             bt.mark(c)
2205         else:
2206             for leaf in c.leaves():
2207                 bt.mark(leaf)
2208     try:
2209         return bt.max_delimiter_priority()
2210
2211     except ValueError:
2212         return 0
2213
2214
2215 def ensure_visible(leaf: Leaf) -> None:
2216     """Make sure parentheses are visible.
2217
2218     They could be invisible as part of some statements (see
2219     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2220     """
2221     if leaf.type == token.LPAR:
2222         leaf.value = "("
2223     elif leaf.type == token.RPAR:
2224         leaf.value = ")"
2225
2226
2227 def is_python36(node: Node) -> bool:
2228     """Return True if the current file is using Python 3.6+ features.
2229
2230     Currently looking for:
2231     - f-strings; and
2232     - trailing commas after * or ** in function signatures.
2233     """
2234     for n in node.pre_order():
2235         if n.type == token.STRING:
2236             value_head = n.value[:2]  # type: ignore
2237             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2238                 return True
2239
2240         elif (
2241             n.type == syms.typedargslist
2242             and n.children
2243             and n.children[-1].type == token.COMMA
2244         ):
2245             for ch in n.children:
2246                 if ch.type in STARS:
2247                     return True
2248
2249     return False
2250
2251
2252 PYTHON_EXTENSIONS = {".py"}
2253 BLACKLISTED_DIRECTORIES = {
2254     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2255 }
2256
2257
2258 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2259     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2260     and have one of the PYTHON_EXTENSIONS.
2261     """
2262     for child in path.iterdir():
2263         if child.is_dir():
2264             if child.name in BLACKLISTED_DIRECTORIES:
2265                 continue
2266
2267             yield from gen_python_files_in_dir(child)
2268
2269         elif child.suffix in PYTHON_EXTENSIONS:
2270             yield child
2271
2272
2273 @dataclass
2274 class Report:
2275     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2276     check: bool = False
2277     quiet: bool = False
2278     change_count: int = 0
2279     same_count: int = 0
2280     failure_count: int = 0
2281
2282     def done(self, src: Path, changed: Changed) -> None:
2283         """Increment the counter for successful reformatting. Write out a message."""
2284         if changed is Changed.YES:
2285             reformatted = "would reformat" if self.check else "reformatted"
2286             if not self.quiet:
2287                 out(f"{reformatted} {src}")
2288             self.change_count += 1
2289         else:
2290             if not self.quiet:
2291                 if changed is Changed.NO:
2292                     msg = f"{src} already well formatted, good job."
2293                 else:
2294                     msg = f"{src} wasn't modified on disk since last run."
2295                 out(msg, bold=False)
2296             self.same_count += 1
2297
2298     def failed(self, src: Path, message: str) -> None:
2299         """Increment the counter for failed reformatting. Write out a message."""
2300         err(f"error: cannot format {src}: {message}")
2301         self.failure_count += 1
2302
2303     @property
2304     def return_code(self) -> int:
2305         """Return the exit code that the app should use.
2306
2307         This considers the current state of changed files and failures:
2308         - if there were any failures, return 123;
2309         - if any files were changed and --check is being used, return 1;
2310         - otherwise return 0.
2311         """
2312         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2313         # 126 we have special returncodes reserved by the shell.
2314         if self.failure_count:
2315             return 123
2316
2317         elif self.change_count and self.check:
2318             return 1
2319
2320         return 0
2321
2322     def __str__(self) -> str:
2323         """Render a color report of the current state.
2324
2325         Use `click.unstyle` to remove colors.
2326         """
2327         if self.check:
2328             reformatted = "would be reformatted"
2329             unchanged = "would be left unchanged"
2330             failed = "would fail to reformat"
2331         else:
2332             reformatted = "reformatted"
2333             unchanged = "left unchanged"
2334             failed = "failed to reformat"
2335         report = []
2336         if self.change_count:
2337             s = "s" if self.change_count > 1 else ""
2338             report.append(
2339                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2340             )
2341         if self.same_count:
2342             s = "s" if self.same_count > 1 else ""
2343             report.append(f"{self.same_count} file{s} {unchanged}")
2344         if self.failure_count:
2345             s = "s" if self.failure_count > 1 else ""
2346             report.append(
2347                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2348             )
2349         return ", ".join(report) + "."
2350
2351
2352 def assert_equivalent(src: str, dst: str) -> None:
2353     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2354
2355     import ast
2356     import traceback
2357
2358     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2359         """Simple visitor generating strings to compare ASTs by content."""
2360         yield f"{'  ' * depth}{node.__class__.__name__}("
2361
2362         for field in sorted(node._fields):
2363             try:
2364                 value = getattr(node, field)
2365             except AttributeError:
2366                 continue
2367
2368             yield f"{'  ' * (depth+1)}{field}="
2369
2370             if isinstance(value, list):
2371                 for item in value:
2372                     if isinstance(item, ast.AST):
2373                         yield from _v(item, depth + 2)
2374
2375             elif isinstance(value, ast.AST):
2376                 yield from _v(value, depth + 2)
2377
2378             else:
2379                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2380
2381         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2382
2383     try:
2384         src_ast = ast.parse(src)
2385     except Exception as exc:
2386         major, minor = sys.version_info[:2]
2387         raise AssertionError(
2388             f"cannot use --safe with this file; failed to parse source file "
2389             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2390             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2391         )
2392
2393     try:
2394         dst_ast = ast.parse(dst)
2395     except Exception as exc:
2396         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2397         raise AssertionError(
2398             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2399             f"Please report a bug on https://github.com/ambv/black/issues.  "
2400             f"This invalid output might be helpful: {log}"
2401         ) from None
2402
2403     src_ast_str = "\n".join(_v(src_ast))
2404     dst_ast_str = "\n".join(_v(dst_ast))
2405     if src_ast_str != dst_ast_str:
2406         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2407         raise AssertionError(
2408             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2409             f"the source.  "
2410             f"Please report a bug on https://github.com/ambv/black/issues.  "
2411             f"This diff might be helpful: {log}"
2412         ) from None
2413
2414
2415 def assert_stable(src: str, dst: str, line_length: int) -> None:
2416     """Raise AssertionError if `dst` reformats differently the second time."""
2417     newdst = format_str(dst, line_length=line_length)
2418     if dst != newdst:
2419         log = dump_to_file(
2420             diff(src, dst, "source", "first pass"),
2421             diff(dst, newdst, "first pass", "second pass"),
2422         )
2423         raise AssertionError(
2424             f"INTERNAL ERROR: Black produced different code on the second pass "
2425             f"of the formatter.  "
2426             f"Please report a bug on https://github.com/ambv/black/issues.  "
2427             f"This diff might be helpful: {log}"
2428         ) from None
2429
2430
2431 def dump_to_file(*output: str) -> str:
2432     """Dump `output` to a temporary file. Return path to the file."""
2433     import tempfile
2434
2435     with tempfile.NamedTemporaryFile(
2436         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2437     ) as f:
2438         for lines in output:
2439             f.write(lines)
2440             if lines and lines[-1] != "\n":
2441                 f.write("\n")
2442     return f.name
2443
2444
2445 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2446     """Return a unified diff string between strings `a` and `b`."""
2447     import difflib
2448
2449     a_lines = [line + "\n" for line in a.split("\n")]
2450     b_lines = [line + "\n" for line in b.split("\n")]
2451     return "".join(
2452         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2453     )
2454
2455
2456 def cancel(tasks: List[asyncio.Task]) -> None:
2457     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2458     err("Aborted!")
2459     for task in tasks:
2460         task.cancel()
2461
2462
2463 def shutdown(loop: BaseEventLoop) -> None:
2464     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2465     try:
2466         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2467         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2468         if not to_cancel:
2469             return
2470
2471         for task in to_cancel:
2472             task.cancel()
2473         loop.run_until_complete(
2474             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2475         )
2476     finally:
2477         # `concurrent.futures.Future` objects cannot be cancelled once they
2478         # are already running. There might be some when the `shutdown()` happened.
2479         # Silence their logger's spew about the event loop being closed.
2480         cf_logger = logging.getLogger("concurrent.futures")
2481         cf_logger.setLevel(logging.CRITICAL)
2482         loop.close()
2483
2484
2485 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2486     """Replace `regex` with `replacement` twice on `original`.
2487
2488     This is used by string normalization to perform replaces on
2489     overlapping matches.
2490     """
2491     return regex.sub(replacement, regex.sub(replacement, original))
2492
2493
2494 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2495
2496
2497 def get_cache_file(line_length: int) -> Path:
2498     return CACHE_DIR / f"cache.{line_length}.pickle"
2499
2500
2501 def read_cache(line_length: int) -> Cache:
2502     """Read the cache if it exists and is well formed.
2503
2504     If it is not well formed, the call to write_cache later should resolve the issue.
2505     """
2506     cache_file = get_cache_file(line_length)
2507     if not cache_file.exists():
2508         return {}
2509
2510     with cache_file.open("rb") as fobj:
2511         try:
2512             cache: Cache = pickle.load(fobj)
2513         except pickle.UnpicklingError:
2514             return {}
2515
2516     return cache
2517
2518
2519 def get_cache_info(path: Path) -> CacheInfo:
2520     """Return the information used to check if a file is already formatted or not."""
2521     stat = path.stat()
2522     return stat.st_mtime, stat.st_size
2523
2524
2525 def filter_cached(
2526     cache: Cache, sources: Iterable[Path]
2527 ) -> Tuple[List[Path], List[Path]]:
2528     """Split a list of paths into two.
2529
2530     The first list contains paths of files that modified on disk or are not in the
2531     cache. The other list contains paths to non-modified files.
2532     """
2533     todo, done = [], []
2534     for src in sources:
2535         src = src.resolve()
2536         if cache.get(src) != get_cache_info(src):
2537             todo.append(src)
2538         else:
2539             done.append(src)
2540     return todo, done
2541
2542
2543 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2544     """Update the cache file."""
2545     cache_file = get_cache_file(line_length)
2546     try:
2547         if not CACHE_DIR.exists():
2548             CACHE_DIR.mkdir(parents=True)
2549         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2550         with cache_file.open("wb") as fobj:
2551             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2552     except OSError:
2553         pass
2554
2555
2556 if __name__ == "__main__":
2557     main()