]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

dd2e2d19d8c8f52b7ced373e161b7ee7f9760dd2
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import pickle
5 from asyncio.base_events import BaseEventLoop
6 from concurrent.futures import Executor, ProcessPoolExecutor
7 from enum import Enum
8 from functools import partial, wraps
9 import keyword
10 import logging
11 from multiprocessing import Manager
12 import os
13 from pathlib import Path
14 import re
15 import tokenize
16 import signal
17 import sys
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46 __version__ = "18.4a2"
47 DEFAULT_LINE_LENGTH = 88
48 # types
49 syms = pygram.python_symbols
50 FileContent = str
51 Encoding = str
52 Depth = int
53 NodeType = int
54 LeafID = int
55 Priority = int
56 Index = int
57 LN = Union[Leaf, Node]
58 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
59 Timestamp = float
60 FileSize = int
61 CacheInfo = Tuple[Timestamp, FileSize]
62 Cache = Dict[Path, CacheInfo]
63 out = partial(click.secho, bold=True, err=True)
64 err = partial(click.secho, fg="red", err=True)
65
66
67 class NothingChanged(UserWarning):
68     """Raised by :func:`format_file` when reformatted code is the same as source."""
69
70
71 class CannotSplit(Exception):
72     """A readable split that fits the allotted line length is impossible.
73
74     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
75     :func:`delimiter_split`.
76     """
77
78
79 class FormatError(Exception):
80     """Base exception for `# fmt: on` and `# fmt: off` handling.
81
82     It holds the number of bytes of the prefix consumed before the format
83     control comment appeared.
84     """
85
86     def __init__(self, consumed: int) -> None:
87         super().__init__(consumed)
88         self.consumed = consumed
89
90     def trim_prefix(self, leaf: Leaf) -> None:
91         leaf.prefix = leaf.prefix[self.consumed:]
92
93     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
94         """Returns a new Leaf from the consumed part of the prefix."""
95         unformatted_prefix = leaf.prefix[:self.consumed]
96         return Leaf(token.NEWLINE, unformatted_prefix)
97
98
99 class FormatOn(FormatError):
100     """Found a comment like `# fmt: on` in the file."""
101
102
103 class FormatOff(FormatError):
104     """Found a comment like `# fmt: off` in the file."""
105
106
107 class WriteBack(Enum):
108     NO = 0
109     YES = 1
110     DIFF = 2
111
112
113 class Changed(Enum):
114     NO = 0
115     CACHED = 1
116     YES = 2
117
118
119 @click.command()
120 @click.option(
121     "-l",
122     "--line-length",
123     type=int,
124     default=DEFAULT_LINE_LENGTH,
125     help="How many character per line to allow.",
126     show_default=True,
127 )
128 @click.option(
129     "--check",
130     is_flag=True,
131     help=(
132         "Don't write the files back, just return the status.  Return code 0 "
133         "means nothing would change.  Return code 1 means some files would be "
134         "reformatted.  Return code 123 means there was an internal error."
135     ),
136 )
137 @click.option(
138     "--diff",
139     is_flag=True,
140     help="Don't write the files back, just output a diff for each file on stdout.",
141 )
142 @click.option(
143     "--fast/--safe",
144     is_flag=True,
145     help="If --fast given, skip temporary sanity checks. [default: --safe]",
146 )
147 @click.option(
148     "-q",
149     "--quiet",
150     is_flag=True,
151     help=(
152         "Don't emit non-error messages to stderr. Errors are still emitted, "
153         "silence those with 2>/dev/null."
154     ),
155 )
156 @click.version_option(version=__version__)
157 @click.argument(
158     "src",
159     nargs=-1,
160     type=click.Path(
161         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
162     ),
163 )
164 @click.pass_context
165 def main(
166     ctx: click.Context,
167     line_length: int,
168     check: bool,
169     diff: bool,
170     fast: bool,
171     quiet: bool,
172     src: List[str],
173 ) -> None:
174     """The uncompromising code formatter."""
175     sources: List[Path] = []
176     for s in src:
177         p = Path(s)
178         if p.is_dir():
179             sources.extend(gen_python_files_in_dir(p))
180         elif p.is_file():
181             # if a file was explicitly given, we don't care about its extension
182             sources.append(p)
183         elif s == "-":
184             sources.append(Path("-"))
185         else:
186             err(f"invalid path: {s}")
187
188     if check and not diff:
189         write_back = WriteBack.NO
190     elif diff:
191         write_back = WriteBack.DIFF
192     else:
193         write_back = WriteBack.YES
194     report = Report(check=check, quiet=quiet)
195     if len(sources) == 0:
196         ctx.exit(0)
197         return
198
199     elif len(sources) == 1:
200         reformat_one(sources[0], line_length, fast, write_back, report)
201     else:
202         loop = asyncio.get_event_loop()
203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
204         try:
205             loop.run_until_complete(
206                 schedule_formatting(
207                     sources, line_length, fast, write_back, report, loop, executor
208                 )
209             )
210         finally:
211             shutdown(loop)
212         if not quiet:
213             out("All done! ✨ 🍰 ✨")
214             click.echo(str(report))
215     ctx.exit(report.return_code)
216
217
218 def reformat_one(
219     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
220 ) -> None:
221     """Reformat a single file under `src` without spawning child processes.
222
223     If `quiet` is True, non-error messages are not output. `line_length`,
224     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
225     """
226     try:
227         changed = Changed.NO
228         if not src.is_file() and str(src) == "-":
229             if format_stdin_to_stdout(
230                 line_length=line_length, fast=fast, write_back=write_back
231             ):
232                 changed = Changed.YES
233         else:
234             cache: Cache = {}
235             if write_back != WriteBack.DIFF:
236                 cache = read_cache()
237                 src = src.resolve()
238                 if src in cache and cache[src] == get_cache_info(src):
239                     changed = Changed.CACHED
240             if (
241                 changed is not Changed.CACHED
242                 and format_file_in_place(
243                     src, line_length=line_length, fast=fast, write_back=write_back
244                 )
245             ):
246                 changed = Changed.YES
247             if write_back != WriteBack.DIFF and changed is not Changed.NO:
248                 write_cache(cache, [src])
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache()
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             src: loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             )
289             for src in sources
290         }
291         _task_values = list(tasks.values())
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
294             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     if write_back != WriteBack.DIFF and formatted:
315         write_cache(cache, formatted)
316
317
318 def format_file_in_place(
319     src: Path,
320     line_length: int,
321     fast: bool,
322     write_back: WriteBack = WriteBack.NO,
323     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
324 ) -> bool:
325     """Format file under `src` path. Return True if changed.
326
327     If `write_back` is True, write reformatted code back to stdout.
328     `line_length` and `fast` options are passed to :func:`format_file_contents`.
329     """
330
331     with tokenize.open(src) as src_buffer:
332         src_contents = src_buffer.read()
333     try:
334         dst_contents = format_file_contents(
335             src_contents, line_length=line_length, fast=fast
336         )
337     except NothingChanged:
338         return False
339
340     if write_back == write_back.YES:
341         with open(src, "w", encoding=src_buffer.encoding) as f:
342             f.write(dst_contents)
343     elif write_back == write_back.DIFF:
344         src_name = f"{src.name}  (original)"
345         dst_name = f"{src.name}  (formatted)"
346         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347         if lock:
348             lock.acquire()
349         try:
350             sys.stdout.write(diff_contents)
351         finally:
352             if lock:
353                 lock.release()
354     return True
355
356
357 def format_stdin_to_stdout(
358     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
359 ) -> bool:
360     """Format file on stdin. Return True if changed.
361
362     If `write_back` is True, write reformatted code back to stdout.
363     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
364     """
365     src = sys.stdin.read()
366     dst = src
367     try:
368         dst = format_file_contents(src, line_length=line_length, fast=fast)
369         return True
370
371     except NothingChanged:
372         return False
373
374     finally:
375         if write_back == WriteBack.YES:
376             sys.stdout.write(dst)
377         elif write_back == WriteBack.DIFF:
378             src_name = "<stdin>  (original)"
379             dst_name = "<stdin>  (formatted)"
380             sys.stdout.write(diff(src, dst, src_name, dst_name))
381
382
383 def format_file_contents(
384     src_contents: str, line_length: int, fast: bool
385 ) -> FileContent:
386     """Reformat contents a file and return new contents.
387
388     If `fast` is False, additionally confirm that the reformatted code is
389     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
390     `line_length` is passed to :func:`format_str`.
391     """
392     if src_contents.strip() == "":
393         raise NothingChanged
394
395     dst_contents = format_str(src_contents, line_length=line_length)
396     if src_contents == dst_contents:
397         raise NothingChanged
398
399     if not fast:
400         assert_equivalent(src_contents, dst_contents)
401         assert_stable(src_contents, dst_contents, line_length=line_length)
402     return dst_contents
403
404
405 def format_str(src_contents: str, line_length: int) -> FileContent:
406     """Reformat a string and return new contents.
407
408     `line_length` determines how many characters per line are allowed.
409     """
410     src_node = lib2to3_parse(src_contents)
411     dst_contents = ""
412     lines = LineGenerator()
413     elt = EmptyLineTracker()
414     py36 = is_python36(src_node)
415     empty_line = Line()
416     after = 0
417     for current_line in lines.visit(src_node):
418         for _ in range(after):
419             dst_contents += str(empty_line)
420         before, after = elt.maybe_empty_lines(current_line)
421         for _ in range(before):
422             dst_contents += str(empty_line)
423         for line in split_line(current_line, line_length=line_length, py36=py36):
424             dst_contents += str(line)
425     return dst_contents
426
427
428 GRAMMARS = [
429     pygram.python_grammar_no_print_statement_no_exec_statement,
430     pygram.python_grammar_no_print_statement,
431     pygram.python_grammar_no_exec_statement,
432     pygram.python_grammar,
433 ]
434
435
436 def lib2to3_parse(src_txt: str) -> Node:
437     """Given a string with source, return the lib2to3 Node."""
438     grammar = pygram.python_grammar_no_print_statement
439     if src_txt[-1] != "\n":
440         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
441         src_txt += nl
442     for grammar in GRAMMARS:
443         drv = driver.Driver(grammar, pytree.convert)
444         try:
445             result = drv.parse_string(src_txt, True)
446             break
447
448         except ParseError as pe:
449             lineno, column = pe.context[1]
450             lines = src_txt.splitlines()
451             try:
452                 faulty_line = lines[lineno - 1]
453             except IndexError:
454                 faulty_line = "<line number missing in source>"
455             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
456     else:
457         raise exc from None
458
459     if isinstance(result, Leaf):
460         result = Node(syms.file_input, [result])
461     return result
462
463
464 def lib2to3_unparse(node: Node) -> str:
465     """Given a lib2to3 node, return its string representation."""
466     code = str(node)
467     return code
468
469
470 T = TypeVar("T")
471
472
473 class Visitor(Generic[T]):
474     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
475
476     def visit(self, node: LN) -> Iterator[T]:
477         """Main method to visit `node` and its children.
478
479         It tries to find a `visit_*()` method for the given `node.type`, like
480         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
481         If no dedicated `visit_*()` method is found, chooses `visit_default()`
482         instead.
483
484         Then yields objects of type `T` from the selected visitor.
485         """
486         if node.type < 256:
487             name = token.tok_name[node.type]
488         else:
489             name = type_repr(node.type)
490         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
491
492     def visit_default(self, node: LN) -> Iterator[T]:
493         """Default `visit_*()` implementation. Recurses to children of `node`."""
494         if isinstance(node, Node):
495             for child in node.children:
496                 yield from self.visit(child)
497
498
499 @dataclass
500 class DebugVisitor(Visitor[T]):
501     tree_depth: int = 0
502
503     def visit_default(self, node: LN) -> Iterator[T]:
504         indent = " " * (2 * self.tree_depth)
505         if isinstance(node, Node):
506             _type = type_repr(node.type)
507             out(f"{indent}{_type}", fg="yellow")
508             self.tree_depth += 1
509             for child in node.children:
510                 yield from self.visit(child)
511
512             self.tree_depth -= 1
513             out(f"{indent}/{_type}", fg="yellow", bold=False)
514         else:
515             _type = token.tok_name.get(node.type, str(node.type))
516             out(f"{indent}{_type}", fg="blue", nl=False)
517             if node.prefix:
518                 # We don't have to handle prefixes for `Node` objects since
519                 # that delegates to the first child anyway.
520                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
521             out(f" {node.value!r}", fg="blue", bold=False)
522
523     @classmethod
524     def show(cls, code: str) -> None:
525         """Pretty-print the lib2to3 AST of a given string of `code`.
526
527         Convenience method for debugging.
528         """
529         v: DebugVisitor[None] = DebugVisitor()
530         list(v.visit(lib2to3_parse(code)))
531
532
533 KEYWORDS = set(keyword.kwlist)
534 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
535 FLOW_CONTROL = {"return", "raise", "break", "continue"}
536 STATEMENT = {
537     syms.if_stmt,
538     syms.while_stmt,
539     syms.for_stmt,
540     syms.try_stmt,
541     syms.except_clause,
542     syms.with_stmt,
543     syms.funcdef,
544     syms.classdef,
545 }
546 STANDALONE_COMMENT = 153
547 LOGIC_OPERATORS = {"and", "or"}
548 COMPARATORS = {
549     token.LESS,
550     token.GREATER,
551     token.EQEQUAL,
552     token.NOTEQUAL,
553     token.LESSEQUAL,
554     token.GREATEREQUAL,
555 }
556 MATH_OPERATORS = {
557     token.PLUS,
558     token.MINUS,
559     token.STAR,
560     token.SLASH,
561     token.VBAR,
562     token.AMPER,
563     token.PERCENT,
564     token.CIRCUMFLEX,
565     token.TILDE,
566     token.LEFTSHIFT,
567     token.RIGHTSHIFT,
568     token.DOUBLESTAR,
569     token.DOUBLESLASH,
570 }
571 STARS = {token.STAR, token.DOUBLESTAR}
572 VARARGS_PARENTS = {
573     syms.arglist,
574     syms.argument,  # double star in arglist
575     syms.trailer,  # single argument to call
576     syms.typedargslist,
577     syms.varargslist,  # lambdas
578 }
579 UNPACKING_PARENTS = {
580     syms.atom,  # single element of a list or set literal
581     syms.dictsetmaker,
582     syms.listmaker,
583     syms.testlist_gexp,
584 }
585 COMPREHENSION_PRIORITY = 20
586 COMMA_PRIORITY = 10
587 LOGIC_PRIORITY = 5
588 STRING_PRIORITY = 4
589 COMPARATOR_PRIORITY = 3
590 MATH_PRIORITY = 1
591
592
593 @dataclass
594 class BracketTracker:
595     """Keeps track of brackets on a line."""
596
597     depth: int = 0
598     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
599     delimiters: Dict[LeafID, Priority] = Factory(dict)
600     previous: Optional[Leaf] = None
601     _for_loop_variable: bool = False
602     _lambda_arguments: bool = False
603
604     def mark(self, leaf: Leaf) -> None:
605         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
606
607         All leaves receive an int `bracket_depth` field that stores how deep
608         within brackets a given leaf is. 0 means there are no enclosing brackets
609         that started on this line.
610
611         If a leaf is itself a closing bracket, it receives an `opening_bracket`
612         field that it forms a pair with. This is a one-directional link to
613         avoid reference cycles.
614
615         If a leaf is a delimiter (a token on which Black can split the line if
616         needed) and it's on depth 0, its `id()` is stored in the tracker's
617         `delimiters` field.
618         """
619         if leaf.type == token.COMMENT:
620             return
621
622         self.maybe_decrement_after_for_loop_variable(leaf)
623         self.maybe_decrement_after_lambda_arguments(leaf)
624         if leaf.type in CLOSING_BRACKETS:
625             self.depth -= 1
626             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
627             leaf.opening_bracket = opening_bracket
628         leaf.bracket_depth = self.depth
629         if self.depth == 0:
630             delim = is_split_before_delimiter(leaf, self.previous)
631             if delim and self.previous is not None:
632                 self.delimiters[id(self.previous)] = delim
633             else:
634                 delim = is_split_after_delimiter(leaf, self.previous)
635                 if delim:
636                     self.delimiters[id(leaf)] = delim
637         if leaf.type in OPENING_BRACKETS:
638             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
639             self.depth += 1
640         self.previous = leaf
641         self.maybe_increment_lambda_arguments(leaf)
642         self.maybe_increment_for_loop_variable(leaf)
643
644     def any_open_brackets(self) -> bool:
645         """Return True if there is an yet unmatched open bracket on the line."""
646         return bool(self.bracket_match)
647
648     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
649         """Return the highest priority of a delimiter found on the line.
650
651         Values are consistent with what `is_split_*_delimiter()` return.
652         Raises ValueError on no delimiters.
653         """
654         return max(v for k, v in self.delimiters.items() if k not in exclude)
655
656     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
657         """In a for loop, or comprehension, the variables are often unpacks.
658
659         To avoid splitting on the comma in this situation, increase the depth of
660         tokens between `for` and `in`.
661         """
662         if leaf.type == token.NAME and leaf.value == "for":
663             self.depth += 1
664             self._for_loop_variable = True
665             return True
666
667         return False
668
669     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
670         """See `maybe_increment_for_loop_variable` above for explanation."""
671         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
672             self.depth -= 1
673             self._for_loop_variable = False
674             return True
675
676         return False
677
678     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
679         """In a lambda expression, there might be more than one argument.
680
681         To avoid splitting on the comma in this situation, increase the depth of
682         tokens between `lambda` and `:`.
683         """
684         if leaf.type == token.NAME and leaf.value == "lambda":
685             self.depth += 1
686             self._lambda_arguments = True
687             return True
688
689         return False
690
691     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
692         """See `maybe_increment_lambda_arguments` above for explanation."""
693         if self._lambda_arguments and leaf.type == token.COLON:
694             self.depth -= 1
695             self._lambda_arguments = False
696             return True
697
698         return False
699
700
701 @dataclass
702 class Line:
703     """Holds leaves and comments. Can be printed with `str(line)`."""
704
705     depth: int = 0
706     leaves: List[Leaf] = Factory(list)
707     comments: List[Tuple[Index, Leaf]] = Factory(list)
708     bracket_tracker: BracketTracker = Factory(BracketTracker)
709     inside_brackets: bool = False
710
711     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
712         """Add a new `leaf` to the end of the line.
713
714         Unless `preformatted` is True, the `leaf` will receive a new consistent
715         whitespace prefix and metadata applied by :class:`BracketTracker`.
716         Trailing commas are maybe removed, unpacked for loop variables are
717         demoted from being delimiters.
718
719         Inline comments are put aside.
720         """
721         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
722         if not has_value:
723             return
724
725         if self.leaves and not preformatted:
726             # Note: at this point leaf.prefix should be empty except for
727             # imports, for which we only preserve newlines.
728             leaf.prefix += whitespace(leaf)
729         if self.inside_brackets or not preformatted:
730             self.bracket_tracker.mark(leaf)
731             self.maybe_remove_trailing_comma(leaf)
732
733         if not self.append_comment(leaf):
734             self.leaves.append(leaf)
735
736     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
737         """Like :func:`append()` but disallow invalid standalone comment structure.
738
739         Raises ValueError when any `leaf` is appended after a standalone comment
740         or when a standalone comment is not the first leaf on the line.
741         """
742         if self.bracket_tracker.depth == 0:
743             if self.is_comment:
744                 raise ValueError("cannot append to standalone comments")
745
746             if self.leaves and leaf.type == STANDALONE_COMMENT:
747                 raise ValueError(
748                     "cannot append standalone comments to a populated line"
749                 )
750
751         self.append(leaf, preformatted=preformatted)
752
753     @property
754     def is_comment(self) -> bool:
755         """Is this line a standalone comment?"""
756         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
757
758     @property
759     def is_decorator(self) -> bool:
760         """Is this line a decorator?"""
761         return bool(self) and self.leaves[0].type == token.AT
762
763     @property
764     def is_import(self) -> bool:
765         """Is this an import line?"""
766         return bool(self) and is_import(self.leaves[0])
767
768     @property
769     def is_class(self) -> bool:
770         """Is this line a class definition?"""
771         return (
772             bool(self)
773             and self.leaves[0].type == token.NAME
774             and self.leaves[0].value == "class"
775         )
776
777     @property
778     def is_def(self) -> bool:
779         """Is this a function definition? (Also returns True for async defs.)"""
780         try:
781             first_leaf = self.leaves[0]
782         except IndexError:
783             return False
784
785         try:
786             second_leaf: Optional[Leaf] = self.leaves[1]
787         except IndexError:
788             second_leaf = None
789         return (
790             (first_leaf.type == token.NAME and first_leaf.value == "def")
791             or (
792                 first_leaf.type == token.ASYNC
793                 and second_leaf is not None
794                 and second_leaf.type == token.NAME
795                 and second_leaf.value == "def"
796             )
797         )
798
799     @property
800     def is_flow_control(self) -> bool:
801         """Is this line a flow control statement?
802
803         Those are `return`, `raise`, `break`, and `continue`.
804         """
805         return (
806             bool(self)
807             and self.leaves[0].type == token.NAME
808             and self.leaves[0].value in FLOW_CONTROL
809         )
810
811     @property
812     def is_yield(self) -> bool:
813         """Is this line a yield statement?"""
814         return (
815             bool(self)
816             and self.leaves[0].type == token.NAME
817             and self.leaves[0].value == "yield"
818         )
819
820     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
821         """If so, needs to be split before emitting."""
822         for leaf in self.leaves:
823             if leaf.type == STANDALONE_COMMENT:
824                 if leaf.bracket_depth <= depth_limit:
825                     return True
826
827         return False
828
829     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
830         """Remove trailing comma if there is one and it's safe."""
831         if not (
832             self.leaves
833             and self.leaves[-1].type == token.COMMA
834             and closing.type in CLOSING_BRACKETS
835         ):
836             return False
837
838         if closing.type == token.RBRACE:
839             self.remove_trailing_comma()
840             return True
841
842         if closing.type == token.RSQB:
843             comma = self.leaves[-1]
844             if comma.parent and comma.parent.type == syms.listmaker:
845                 self.remove_trailing_comma()
846                 return True
847
848         # For parens let's check if it's safe to remove the comma.  If the
849         # trailing one is the only one, we might mistakenly change a tuple
850         # into a different type by removing the comma.
851         depth = closing.bracket_depth + 1
852         commas = 0
853         opening = closing.opening_bracket
854         for _opening_index, leaf in enumerate(self.leaves):
855             if leaf is opening:
856                 break
857
858         else:
859             return False
860
861         for leaf in self.leaves[_opening_index + 1:]:
862             if leaf is closing:
863                 break
864
865             bracket_depth = leaf.bracket_depth
866             if bracket_depth == depth and leaf.type == token.COMMA:
867                 commas += 1
868                 if leaf.parent and leaf.parent.type == syms.arglist:
869                     commas += 1
870                     break
871
872         if commas > 1:
873             self.remove_trailing_comma()
874             return True
875
876         return False
877
878     def append_comment(self, comment: Leaf) -> bool:
879         """Add an inline or standalone comment to the line."""
880         if (
881             comment.type == STANDALONE_COMMENT
882             and self.bracket_tracker.any_open_brackets()
883         ):
884             comment.prefix = ""
885             return False
886
887         if comment.type != token.COMMENT:
888             return False
889
890         after = len(self.leaves) - 1
891         if after == -1:
892             comment.type = STANDALONE_COMMENT
893             comment.prefix = ""
894             return False
895
896         else:
897             self.comments.append((after, comment))
898             return True
899
900     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
901         """Generate comments that should appear directly after `leaf`."""
902         for _leaf_index, _leaf in enumerate(self.leaves):
903             if leaf is _leaf:
904                 break
905
906         else:
907             return
908
909         for index, comment_after in self.comments:
910             if _leaf_index == index:
911                 yield comment_after
912
913     def remove_trailing_comma(self) -> None:
914         """Remove the trailing comma and moves the comments attached to it."""
915         comma_index = len(self.leaves) - 1
916         for i in range(len(self.comments)):
917             comment_index, comment = self.comments[i]
918             if comment_index == comma_index:
919                 self.comments[i] = (comma_index - 1, comment)
920         self.leaves.pop()
921
922     def __str__(self) -> str:
923         """Render the line."""
924         if not self:
925             return "\n"
926
927         indent = "    " * self.depth
928         leaves = iter(self.leaves)
929         first = next(leaves)
930         res = f"{first.prefix}{indent}{first.value}"
931         for leaf in leaves:
932             res += str(leaf)
933         for _, comment in self.comments:
934             res += str(comment)
935         return res + "\n"
936
937     def __bool__(self) -> bool:
938         """Return True if the line has leaves or comments."""
939         return bool(self.leaves or self.comments)
940
941
942 class UnformattedLines(Line):
943     """Just like :class:`Line` but stores lines which aren't reformatted."""
944
945     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
946         """Just add a new `leaf` to the end of the lines.
947
948         The `preformatted` argument is ignored.
949
950         Keeps track of indentation `depth`, which is useful when the user
951         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
952         """
953         try:
954             list(generate_comments(leaf))
955         except FormatOn as f_on:
956             self.leaves.append(f_on.leaf_from_consumed(leaf))
957             raise
958
959         self.leaves.append(leaf)
960         if leaf.type == token.INDENT:
961             self.depth += 1
962         elif leaf.type == token.DEDENT:
963             self.depth -= 1
964
965     def __str__(self) -> str:
966         """Render unformatted lines from leaves which were added with `append()`.
967
968         `depth` is not used for indentation in this case.
969         """
970         if not self:
971             return "\n"
972
973         res = ""
974         for leaf in self.leaves:
975             res += str(leaf)
976         return res
977
978     def append_comment(self, comment: Leaf) -> bool:
979         """Not implemented in this class. Raises `NotImplementedError`."""
980         raise NotImplementedError("Unformatted lines don't store comments separately.")
981
982     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
983         """Does nothing and returns False."""
984         return False
985
986     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
987         """Does nothing and returns False."""
988         return False
989
990
991 @dataclass
992 class EmptyLineTracker:
993     """Provides a stateful method that returns the number of potential extra
994     empty lines needed before and after the currently processed line.
995
996     Note: this tracker works on lines that haven't been split yet.  It assumes
997     the prefix of the first leaf consists of optional newlines.  Those newlines
998     are consumed by `maybe_empty_lines()` and included in the computation.
999     """
1000     previous_line: Optional[Line] = None
1001     previous_after: int = 0
1002     previous_defs: List[int] = Factory(list)
1003
1004     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1005         """Return the number of extra empty lines before and after the `current_line`.
1006
1007         This is for separating `def`, `async def` and `class` with extra empty
1008         lines (two on module-level), as well as providing an extra empty line
1009         after flow control keywords to make them more prominent.
1010         """
1011         if isinstance(current_line, UnformattedLines):
1012             return 0, 0
1013
1014         before, after = self._maybe_empty_lines(current_line)
1015         before -= self.previous_after
1016         self.previous_after = after
1017         self.previous_line = current_line
1018         return before, after
1019
1020     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1021         max_allowed = 1
1022         if current_line.depth == 0:
1023             max_allowed = 2
1024         if current_line.leaves:
1025             # Consume the first leaf's extra newlines.
1026             first_leaf = current_line.leaves[0]
1027             before = first_leaf.prefix.count("\n")
1028             before = min(before, max_allowed)
1029             first_leaf.prefix = ""
1030         else:
1031             before = 0
1032         depth = current_line.depth
1033         while self.previous_defs and self.previous_defs[-1] >= depth:
1034             self.previous_defs.pop()
1035             before = 1 if depth else 2
1036         is_decorator = current_line.is_decorator
1037         if is_decorator or current_line.is_def or current_line.is_class:
1038             if not is_decorator:
1039                 self.previous_defs.append(depth)
1040             if self.previous_line is None:
1041                 # Don't insert empty lines before the first line in the file.
1042                 return 0, 0
1043
1044             if self.previous_line and self.previous_line.is_decorator:
1045                 # Don't insert empty lines between decorators.
1046                 return 0, 0
1047
1048             newlines = 2
1049             if current_line.depth:
1050                 newlines -= 1
1051             return newlines, 0
1052
1053         if current_line.is_flow_control:
1054             return before, 1
1055
1056         if (
1057             self.previous_line
1058             and self.previous_line.is_import
1059             and not current_line.is_import
1060             and depth == self.previous_line.depth
1061         ):
1062             return (before or 1), 0
1063
1064         if (
1065             self.previous_line
1066             and self.previous_line.is_yield
1067             and (not current_line.is_yield or depth != self.previous_line.depth)
1068         ):
1069             return (before or 1), 0
1070
1071         return before, 0
1072
1073
1074 @dataclass
1075 class LineGenerator(Visitor[Line]):
1076     """Generates reformatted Line objects.  Empty lines are not emitted.
1077
1078     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1079     in ways that will no longer stringify to valid Python code on the tree.
1080     """
1081     current_line: Line = Factory(Line)
1082
1083     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1084         """Generate a line.
1085
1086         If the line is empty, only emit if it makes sense.
1087         If the line is too long, split it first and then generate.
1088
1089         If any lines were generated, set up a new current_line.
1090         """
1091         if not self.current_line:
1092             if self.current_line.__class__ == type:
1093                 self.current_line.depth += indent
1094             else:
1095                 self.current_line = type(depth=self.current_line.depth + indent)
1096             return  # Line is empty, don't emit. Creating a new one unnecessary.
1097
1098         complete_line = self.current_line
1099         self.current_line = type(depth=complete_line.depth + indent)
1100         yield complete_line
1101
1102     def visit(self, node: LN) -> Iterator[Line]:
1103         """Main method to visit `node` and its children.
1104
1105         Yields :class:`Line` objects.
1106         """
1107         if isinstance(self.current_line, UnformattedLines):
1108             # File contained `# fmt: off`
1109             yield from self.visit_unformatted(node)
1110
1111         else:
1112             yield from super().visit(node)
1113
1114     def visit_default(self, node: LN) -> Iterator[Line]:
1115         """Default `visit_*()` implementation. Recurses to children of `node`."""
1116         if isinstance(node, Leaf):
1117             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1118             try:
1119                 for comment in generate_comments(node):
1120                     if any_open_brackets:
1121                         # any comment within brackets is subject to splitting
1122                         self.current_line.append(comment)
1123                     elif comment.type == token.COMMENT:
1124                         # regular trailing comment
1125                         self.current_line.append(comment)
1126                         yield from self.line()
1127
1128                     else:
1129                         # regular standalone comment
1130                         yield from self.line()
1131
1132                         self.current_line.append(comment)
1133                         yield from self.line()
1134
1135             except FormatOff as f_off:
1136                 f_off.trim_prefix(node)
1137                 yield from self.line(type=UnformattedLines)
1138                 yield from self.visit(node)
1139
1140             except FormatOn as f_on:
1141                 # This only happens here if somebody says "fmt: on" multiple
1142                 # times in a row.
1143                 f_on.trim_prefix(node)
1144                 yield from self.visit_default(node)
1145
1146             else:
1147                 normalize_prefix(node, inside_brackets=any_open_brackets)
1148                 if node.type == token.STRING:
1149                     normalize_string_quotes(node)
1150                 if node.type not in WHITESPACE:
1151                     self.current_line.append(node)
1152         yield from super().visit_default(node)
1153
1154     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1155         """Increase indentation level, maybe yield a line."""
1156         # In blib2to3 INDENT never holds comments.
1157         yield from self.line(+1)
1158         yield from self.visit_default(node)
1159
1160     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1161         """Decrease indentation level, maybe yield a line."""
1162         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1163         yield from self.line(-1)
1164
1165     def visit_stmt(
1166         self, node: Node, keywords: Set[str], parens: Set[str]
1167     ) -> Iterator[Line]:
1168         """Visit a statement.
1169
1170         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1171         `def`, `with`, `class`, and `assert`.
1172
1173         The relevant Python language `keywords` for a given statement will be
1174         NAME leaves within it. This methods puts those on a separate line.
1175
1176         `parens` holds pairs of nodes where invisible parentheses should be put.
1177         Keys hold nodes after which opening parentheses should be put, values
1178         hold nodes before which closing parentheses should be put.
1179         """
1180         normalize_invisible_parens(node, parens_after=parens)
1181         for child in node.children:
1182             if child.type == token.NAME and child.value in keywords:  # type: ignore
1183                 yield from self.line()
1184
1185             yield from self.visit(child)
1186
1187     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1188         """Visit a statement without nested statements."""
1189         is_suite_like = node.parent and node.parent.type in STATEMENT
1190         if is_suite_like:
1191             yield from self.line(+1)
1192             yield from self.visit_default(node)
1193             yield from self.line(-1)
1194
1195         else:
1196             yield from self.line()
1197             yield from self.visit_default(node)
1198
1199     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1200         """Visit `async def`, `async for`, `async with`."""
1201         yield from self.line()
1202
1203         children = iter(node.children)
1204         for child in children:
1205             yield from self.visit(child)
1206
1207             if child.type == token.ASYNC:
1208                 break
1209
1210         internal_stmt = next(children)
1211         for child in internal_stmt.children:
1212             yield from self.visit(child)
1213
1214     def visit_decorators(self, node: Node) -> Iterator[Line]:
1215         """Visit decorators."""
1216         for child in node.children:
1217             yield from self.line()
1218             yield from self.visit(child)
1219
1220     def visit_import_from(self, node: Node) -> Iterator[Line]:
1221         """Visit import_from and maybe put invisible parentheses.
1222
1223         This is separate from `visit_stmt` because import statements don't
1224         support arbitrary atoms and thus handling of parentheses is custom.
1225         """
1226         check_lpar = False
1227         for index, child in enumerate(node.children):
1228             if check_lpar:
1229                 if child.type == token.LPAR:
1230                     # make parentheses invisible
1231                     child.value = ""  # type: ignore
1232                     node.children[-1].value = ""  # type: ignore
1233                 else:
1234                     # insert invisible parentheses
1235                     node.insert_child(index, Leaf(token.LPAR, ""))
1236                     node.append_child(Leaf(token.RPAR, ""))
1237                 break
1238
1239             check_lpar = (
1240                 child.type == token.NAME and child.value == "import"  # type: ignore
1241             )
1242
1243         for child in node.children:
1244             yield from self.visit(child)
1245
1246     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1247         """Remove a semicolon and put the other statement on a separate line."""
1248         yield from self.line()
1249
1250     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1251         """End of file. Process outstanding comments and end with a newline."""
1252         yield from self.visit_default(leaf)
1253         yield from self.line()
1254
1255     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1256         """Used when file contained a `# fmt: off`."""
1257         if isinstance(node, Node):
1258             for child in node.children:
1259                 yield from self.visit(child)
1260
1261         else:
1262             try:
1263                 self.current_line.append(node)
1264             except FormatOn as f_on:
1265                 f_on.trim_prefix(node)
1266                 yield from self.line()
1267                 yield from self.visit(node)
1268
1269             if node.type == token.ENDMARKER:
1270                 # somebody decided not to put a final `# fmt: on`
1271                 yield from self.line()
1272
1273     def __attrs_post_init__(self) -> None:
1274         """You are in a twisty little maze of passages."""
1275         v = self.visit_stmt
1276         Ø: Set[str] = set()
1277         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1278         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1279         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1280         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1281         self.visit_try_stmt = partial(
1282             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1283         )
1284         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1285         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1286         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1287         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1288         self.visit_async_funcdef = self.visit_async_stmt
1289         self.visit_decorated = self.visit_decorators
1290
1291
1292 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1293 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1294 OPENING_BRACKETS = set(BRACKET.keys())
1295 CLOSING_BRACKETS = set(BRACKET.values())
1296 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1297 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1298
1299
1300 def whitespace(leaf: Leaf) -> str:  # noqa C901
1301     """Return whitespace prefix if needed for the given `leaf`."""
1302     NO = ""
1303     SPACE = " "
1304     DOUBLESPACE = "  "
1305     t = leaf.type
1306     p = leaf.parent
1307     v = leaf.value
1308     if t in ALWAYS_NO_SPACE:
1309         return NO
1310
1311     if t == token.COMMENT:
1312         return DOUBLESPACE
1313
1314     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1315     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1316         return NO
1317
1318     prev = leaf.prev_sibling
1319     if not prev:
1320         prevp = preceding_leaf(p)
1321         if not prevp or prevp.type in OPENING_BRACKETS:
1322             return NO
1323
1324         if t == token.COLON:
1325             return SPACE if prevp.type == token.COMMA else NO
1326
1327         if prevp.type == token.EQUAL:
1328             if prevp.parent:
1329                 if prevp.parent.type in {
1330                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1331                 }:
1332                     return NO
1333
1334                 elif prevp.parent.type == syms.typedargslist:
1335                     # A bit hacky: if the equal sign has whitespace, it means we
1336                     # previously found it's a typed argument.  So, we're using
1337                     # that, too.
1338                     return prevp.prefix
1339
1340         elif prevp.type in STARS:
1341             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1342                 return NO
1343
1344         elif prevp.type == token.COLON:
1345             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1346                 return NO
1347
1348         elif (
1349             prevp.parent
1350             and prevp.parent.type == syms.factor
1351             and prevp.type in MATH_OPERATORS
1352         ):
1353             return NO
1354
1355         elif (
1356             prevp.type == token.RIGHTSHIFT
1357             and prevp.parent
1358             and prevp.parent.type == syms.shift_expr
1359             and prevp.prev_sibling
1360             and prevp.prev_sibling.type == token.NAME
1361             and prevp.prev_sibling.value == "print"  # type: ignore
1362         ):
1363             # Python 2 print chevron
1364             return NO
1365
1366     elif prev.type in OPENING_BRACKETS:
1367         return NO
1368
1369     if p.type in {syms.parameters, syms.arglist}:
1370         # untyped function signatures or calls
1371         if not prev or prev.type != token.COMMA:
1372             return NO
1373
1374     elif p.type == syms.varargslist:
1375         # lambdas
1376         if prev and prev.type != token.COMMA:
1377             return NO
1378
1379     elif p.type == syms.typedargslist:
1380         # typed function signatures
1381         if not prev:
1382             return NO
1383
1384         if t == token.EQUAL:
1385             if prev.type != syms.tname:
1386                 return NO
1387
1388         elif prev.type == token.EQUAL:
1389             # A bit hacky: if the equal sign has whitespace, it means we
1390             # previously found it's a typed argument.  So, we're using that, too.
1391             return prev.prefix
1392
1393         elif prev.type != token.COMMA:
1394             return NO
1395
1396     elif p.type == syms.tname:
1397         # type names
1398         if not prev:
1399             prevp = preceding_leaf(p)
1400             if not prevp or prevp.type != token.COMMA:
1401                 return NO
1402
1403     elif p.type == syms.trailer:
1404         # attributes and calls
1405         if t == token.LPAR or t == token.RPAR:
1406             return NO
1407
1408         if not prev:
1409             if t == token.DOT:
1410                 prevp = preceding_leaf(p)
1411                 if not prevp or prevp.type != token.NUMBER:
1412                     return NO
1413
1414             elif t == token.LSQB:
1415                 return NO
1416
1417         elif prev.type != token.COMMA:
1418             return NO
1419
1420     elif p.type == syms.argument:
1421         # single argument
1422         if t == token.EQUAL:
1423             return NO
1424
1425         if not prev:
1426             prevp = preceding_leaf(p)
1427             if not prevp or prevp.type == token.LPAR:
1428                 return NO
1429
1430         elif prev.type in {token.EQUAL} | STARS:
1431             return NO
1432
1433     elif p.type == syms.decorator:
1434         # decorators
1435         return NO
1436
1437     elif p.type == syms.dotted_name:
1438         if prev:
1439             return NO
1440
1441         prevp = preceding_leaf(p)
1442         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1443             return NO
1444
1445     elif p.type == syms.classdef:
1446         if t == token.LPAR:
1447             return NO
1448
1449         if prev and prev.type == token.LPAR:
1450             return NO
1451
1452     elif p.type == syms.subscript:
1453         # indexing
1454         if not prev:
1455             assert p.parent is not None, "subscripts are always parented"
1456             if p.parent.type == syms.subscriptlist:
1457                 return SPACE
1458
1459             return NO
1460
1461         else:
1462             return NO
1463
1464     elif p.type == syms.atom:
1465         if prev and t == token.DOT:
1466             # dots, but not the first one.
1467             return NO
1468
1469     elif p.type == syms.dictsetmaker:
1470         # dict unpacking
1471         if prev and prev.type == token.DOUBLESTAR:
1472             return NO
1473
1474     elif p.type in {syms.factor, syms.star_expr}:
1475         # unary ops
1476         if not prev:
1477             prevp = preceding_leaf(p)
1478             if not prevp or prevp.type in OPENING_BRACKETS:
1479                 return NO
1480
1481             prevp_parent = prevp.parent
1482             assert prevp_parent is not None
1483             if (
1484                 prevp.type == token.COLON
1485                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1486             ):
1487                 return NO
1488
1489             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1490                 return NO
1491
1492         elif t == token.NAME or t == token.NUMBER:
1493             return NO
1494
1495     elif p.type == syms.import_from:
1496         if t == token.DOT:
1497             if prev and prev.type == token.DOT:
1498                 return NO
1499
1500         elif t == token.NAME:
1501             if v == "import":
1502                 return SPACE
1503
1504             if prev and prev.type == token.DOT:
1505                 return NO
1506
1507     elif p.type == syms.sliceop:
1508         return NO
1509
1510     return SPACE
1511
1512
1513 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1514     """Return the first leaf that precedes `node`, if any."""
1515     while node:
1516         res = node.prev_sibling
1517         if res:
1518             if isinstance(res, Leaf):
1519                 return res
1520
1521             try:
1522                 return list(res.leaves())[-1]
1523
1524             except IndexError:
1525                 return None
1526
1527         node = node.parent
1528     return None
1529
1530
1531 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1532     """Return the priority of the `leaf` delimiter, given a line break after it.
1533
1534     The delimiter priorities returned here are from those delimiters that would
1535     cause a line break after themselves.
1536
1537     Higher numbers are higher priority.
1538     """
1539     if leaf.type == token.COMMA:
1540         return COMMA_PRIORITY
1541
1542     return 0
1543
1544
1545 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1546     """Return the priority of the `leaf` delimiter, given a line before after it.
1547
1548     The delimiter priorities returned here are from those delimiters that would
1549     cause a line break before themselves.
1550
1551     Higher numbers are higher priority.
1552     """
1553     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1554         # * and ** might also be MATH_OPERATORS but in this case they are not.
1555         # Don't treat them as a delimiter.
1556         return 0
1557
1558     if (
1559         leaf.type in MATH_OPERATORS
1560         and leaf.parent
1561         and leaf.parent.type not in {syms.factor, syms.star_expr}
1562     ):
1563         return MATH_PRIORITY
1564
1565     if leaf.type in COMPARATORS:
1566         return COMPARATOR_PRIORITY
1567
1568     if (
1569         leaf.type == token.STRING
1570         and previous is not None
1571         and previous.type == token.STRING
1572     ):
1573         return STRING_PRIORITY
1574
1575     if (
1576         leaf.type == token.NAME
1577         and leaf.value == "for"
1578         and leaf.parent
1579         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1580     ):
1581         return COMPREHENSION_PRIORITY
1582
1583     if (
1584         leaf.type == token.NAME
1585         and leaf.value == "if"
1586         and leaf.parent
1587         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1588     ):
1589         return COMPREHENSION_PRIORITY
1590
1591     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1592         return LOGIC_PRIORITY
1593
1594     return 0
1595
1596
1597 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1598     """Clean the prefix of the `leaf` and generate comments from it, if any.
1599
1600     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1601     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1602     move because it does away with modifying the grammar to include all the
1603     possible places in which comments can be placed.
1604
1605     The sad consequence for us though is that comments don't "belong" anywhere.
1606     This is why this function generates simple parentless Leaf objects for
1607     comments.  We simply don't know what the correct parent should be.
1608
1609     No matter though, we can live without this.  We really only need to
1610     differentiate between inline and standalone comments.  The latter don't
1611     share the line with any code.
1612
1613     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1614     are emitted with a fake STANDALONE_COMMENT token identifier.
1615     """
1616     p = leaf.prefix
1617     if not p:
1618         return
1619
1620     if "#" not in p:
1621         return
1622
1623     consumed = 0
1624     nlines = 0
1625     for index, line in enumerate(p.split("\n")):
1626         consumed += len(line) + 1  # adding the length of the split '\n'
1627         line = line.lstrip()
1628         if not line:
1629             nlines += 1
1630         if not line.startswith("#"):
1631             continue
1632
1633         if index == 0 and leaf.type != token.ENDMARKER:
1634             comment_type = token.COMMENT  # simple trailing comment
1635         else:
1636             comment_type = STANDALONE_COMMENT
1637         comment = make_comment(line)
1638         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1639
1640         if comment in {"# fmt: on", "# yapf: enable"}:
1641             raise FormatOn(consumed)
1642
1643         if comment in {"# fmt: off", "# yapf: disable"}:
1644             if comment_type == STANDALONE_COMMENT:
1645                 raise FormatOff(consumed)
1646
1647             prev = preceding_leaf(leaf)
1648             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1649                 raise FormatOff(consumed)
1650
1651         nlines = 0
1652
1653
1654 def make_comment(content: str) -> str:
1655     """Return a consistently formatted comment from the given `content` string.
1656
1657     All comments (except for "##", "#!", "#:") should have a single space between
1658     the hash sign and the content.
1659
1660     If `content` didn't start with a hash sign, one is provided.
1661     """
1662     content = content.rstrip()
1663     if not content:
1664         return "#"
1665
1666     if content[0] == "#":
1667         content = content[1:]
1668     if content and content[0] not in " !:#":
1669         content = " " + content
1670     return "#" + content
1671
1672
1673 def split_line(
1674     line: Line, line_length: int, inner: bool = False, py36: bool = False
1675 ) -> Iterator[Line]:
1676     """Split a `line` into potentially many lines.
1677
1678     They should fit in the allotted `line_length` but might not be able to.
1679     `inner` signifies that there were a pair of brackets somewhere around the
1680     current `line`, possibly transitively. This means we can fallback to splitting
1681     by delimiters if the LHS/RHS don't yield any results.
1682
1683     If `py36` is True, splitting may generate syntax that is only compatible
1684     with Python 3.6 and later.
1685     """
1686     if isinstance(line, UnformattedLines) or line.is_comment:
1687         yield line
1688         return
1689
1690     line_str = str(line).strip("\n")
1691     if (
1692         len(line_str) <= line_length
1693         and "\n" not in line_str  # multiline strings
1694         and not line.contains_standalone_comments()
1695     ):
1696         yield line
1697         return
1698
1699     split_funcs: List[SplitFunc]
1700     if line.is_def:
1701         split_funcs = [left_hand_split]
1702     elif line.inside_brackets:
1703         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1704     else:
1705         split_funcs = [right_hand_split]
1706     for split_func in split_funcs:
1707         # We are accumulating lines in `result` because we might want to abort
1708         # mission and return the original line in the end, or attempt a different
1709         # split altogether.
1710         result: List[Line] = []
1711         try:
1712             for l in split_func(line, py36):
1713                 if str(l).strip("\n") == line_str:
1714                     raise CannotSplit("Split function returned an unchanged result")
1715
1716                 result.extend(
1717                     split_line(l, line_length=line_length, inner=True, py36=py36)
1718                 )
1719         except CannotSplit as cs:
1720             continue
1721
1722         else:
1723             yield from result
1724             break
1725
1726     else:
1727         yield line
1728
1729
1730 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1731     """Split line into many lines, starting with the first matching bracket pair.
1732
1733     Note: this usually looks weird, only use this for function definitions.
1734     Prefer RHS otherwise.
1735     """
1736     head = Line(depth=line.depth)
1737     body = Line(depth=line.depth + 1, inside_brackets=True)
1738     tail = Line(depth=line.depth)
1739     tail_leaves: List[Leaf] = []
1740     body_leaves: List[Leaf] = []
1741     head_leaves: List[Leaf] = []
1742     current_leaves = head_leaves
1743     matching_bracket = None
1744     for leaf in line.leaves:
1745         if (
1746             current_leaves is body_leaves
1747             and leaf.type in CLOSING_BRACKETS
1748             and leaf.opening_bracket is matching_bracket
1749         ):
1750             current_leaves = tail_leaves if body_leaves else head_leaves
1751         current_leaves.append(leaf)
1752         if current_leaves is head_leaves:
1753             if leaf.type in OPENING_BRACKETS:
1754                 matching_bracket = leaf
1755                 current_leaves = body_leaves
1756     # Since body is a new indent level, remove spurious leading whitespace.
1757     if body_leaves:
1758         normalize_prefix(body_leaves[0], inside_brackets=True)
1759     # Build the new lines.
1760     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1761         for leaf in leaves:
1762             result.append(leaf, preformatted=True)
1763             for comment_after in line.comments_after(leaf):
1764                 result.append(comment_after, preformatted=True)
1765     bracket_split_succeeded_or_raise(head, body, tail)
1766     for result in (head, body, tail):
1767         if result:
1768             yield result
1769
1770
1771 def right_hand_split(
1772     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1773 ) -> Iterator[Line]:
1774     """Split line into many lines, starting with the last matching bracket pair."""
1775     head = Line(depth=line.depth)
1776     body = Line(depth=line.depth + 1, inside_brackets=True)
1777     tail = Line(depth=line.depth)
1778     tail_leaves: List[Leaf] = []
1779     body_leaves: List[Leaf] = []
1780     head_leaves: List[Leaf] = []
1781     current_leaves = tail_leaves
1782     opening_bracket = None
1783     closing_bracket = None
1784     for leaf in reversed(line.leaves):
1785         if current_leaves is body_leaves:
1786             if leaf is opening_bracket:
1787                 current_leaves = head_leaves if body_leaves else tail_leaves
1788         current_leaves.append(leaf)
1789         if current_leaves is tail_leaves:
1790             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1791                 opening_bracket = leaf.opening_bracket
1792                 closing_bracket = leaf
1793                 current_leaves = body_leaves
1794     tail_leaves.reverse()
1795     body_leaves.reverse()
1796     head_leaves.reverse()
1797     # Since body is a new indent level, remove spurious leading whitespace.
1798     if body_leaves:
1799         normalize_prefix(body_leaves[0], inside_brackets=True)
1800     elif not head_leaves:
1801         # No `head` and no `body` means the split failed. `tail` has all content.
1802         raise CannotSplit("No brackets found")
1803
1804     # Build the new lines.
1805     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1806         for leaf in leaves:
1807             result.append(leaf, preformatted=True)
1808             for comment_after in line.comments_after(leaf):
1809                 result.append(comment_after, preformatted=True)
1810     bracket_split_succeeded_or_raise(head, body, tail)
1811     assert opening_bracket and closing_bracket
1812     if (
1813         opening_bracket.type == token.LPAR
1814         and not opening_bracket.value
1815         and closing_bracket.type == token.RPAR
1816         and not closing_bracket.value
1817     ):
1818         # These parens were optional. If there aren't any delimiters or standalone
1819         # comments in the body, they were unnecessary and another split without
1820         # them should be attempted.
1821         if not (
1822             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1823         ):
1824             omit = {id(closing_bracket), *omit}
1825             yield from right_hand_split(line, py36=py36, omit=omit)
1826             return
1827
1828     ensure_visible(opening_bracket)
1829     ensure_visible(closing_bracket)
1830     for result in (head, body, tail):
1831         if result:
1832             yield result
1833
1834
1835 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1836     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1837
1838     Do nothing otherwise.
1839
1840     A left- or right-hand split is based on a pair of brackets. Content before
1841     (and including) the opening bracket is left on one line, content inside the
1842     brackets is put on a separate line, and finally content starting with and
1843     following the closing bracket is put on a separate line.
1844
1845     Those are called `head`, `body`, and `tail`, respectively. If the split
1846     produced the same line (all content in `head`) or ended up with an empty `body`
1847     and the `tail` is just the closing bracket, then it's considered failed.
1848     """
1849     tail_len = len(str(tail).strip())
1850     if not body:
1851         if tail_len == 0:
1852             raise CannotSplit("Splitting brackets produced the same line")
1853
1854         elif tail_len < 3:
1855             raise CannotSplit(
1856                 f"Splitting brackets on an empty body to save "
1857                 f"{tail_len} characters is not worth it"
1858             )
1859
1860
1861 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1862     """Normalize prefix of the first leaf in every line returned by `split_func`.
1863
1864     This is a decorator over relevant split functions.
1865     """
1866
1867     @wraps(split_func)
1868     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1869         for l in split_func(line, py36):
1870             normalize_prefix(l.leaves[0], inside_brackets=True)
1871             yield l
1872
1873     return split_wrapper
1874
1875
1876 @dont_increase_indentation
1877 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1878     """Split according to delimiters of the highest priority.
1879
1880     If `py36` is True, the split will add trailing commas also in function
1881     signatures that contain `*` and `**`.
1882     """
1883     try:
1884         last_leaf = line.leaves[-1]
1885     except IndexError:
1886         raise CannotSplit("Line empty")
1887
1888     delimiters = line.bracket_tracker.delimiters
1889     try:
1890         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1891             exclude={id(last_leaf)}
1892         )
1893     except ValueError:
1894         raise CannotSplit("No delimiters found")
1895
1896     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1897     lowest_depth = sys.maxsize
1898     trailing_comma_safe = True
1899
1900     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1901         """Append `leaf` to current line or to new line if appending impossible."""
1902         nonlocal current_line
1903         try:
1904             current_line.append_safe(leaf, preformatted=True)
1905         except ValueError as ve:
1906             yield current_line
1907
1908             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1909             current_line.append(leaf)
1910
1911     for leaf in line.leaves:
1912         yield from append_to_line(leaf)
1913
1914         for comment_after in line.comments_after(leaf):
1915             yield from append_to_line(comment_after)
1916
1917         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1918         if (
1919             leaf.bracket_depth == lowest_depth
1920             and is_vararg(leaf, within=VARARGS_PARENTS)
1921         ):
1922             trailing_comma_safe = trailing_comma_safe and py36
1923         leaf_priority = delimiters.get(id(leaf))
1924         if leaf_priority == delimiter_priority:
1925             yield current_line
1926
1927             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1928     if current_line:
1929         if (
1930             trailing_comma_safe
1931             and delimiter_priority == COMMA_PRIORITY
1932             and current_line.leaves[-1].type != token.COMMA
1933             and current_line.leaves[-1].type != STANDALONE_COMMENT
1934         ):
1935             current_line.append(Leaf(token.COMMA, ","))
1936         yield current_line
1937
1938
1939 @dont_increase_indentation
1940 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1941     """Split standalone comments from the rest of the line."""
1942     if not line.contains_standalone_comments(0):
1943         raise CannotSplit("Line does not have any standalone comments")
1944
1945     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1946
1947     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1948         """Append `leaf` to current line or to new line if appending impossible."""
1949         nonlocal current_line
1950         try:
1951             current_line.append_safe(leaf, preformatted=True)
1952         except ValueError as ve:
1953             yield current_line
1954
1955             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1956             current_line.append(leaf)
1957
1958     for leaf in line.leaves:
1959         yield from append_to_line(leaf)
1960
1961         for comment_after in line.comments_after(leaf):
1962             yield from append_to_line(comment_after)
1963
1964     if current_line:
1965         yield current_line
1966
1967
1968 def is_import(leaf: Leaf) -> bool:
1969     """Return True if the given leaf starts an import statement."""
1970     p = leaf.parent
1971     t = leaf.type
1972     v = leaf.value
1973     return bool(
1974         t == token.NAME
1975         and (
1976             (v == "import" and p and p.type == syms.import_name)
1977             or (v == "from" and p and p.type == syms.import_from)
1978         )
1979     )
1980
1981
1982 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1983     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1984     else.
1985
1986     Note: don't use backslashes for formatting or you'll lose your voting rights.
1987     """
1988     if not inside_brackets:
1989         spl = leaf.prefix.split("#")
1990         if "\\" not in spl[0]:
1991             nl_count = spl[-1].count("\n")
1992             if len(spl) > 1:
1993                 nl_count -= 1
1994             leaf.prefix = "\n" * nl_count
1995             return
1996
1997     leaf.prefix = ""
1998
1999
2000 def normalize_string_quotes(leaf: Leaf) -> None:
2001     """Prefer double quotes but only if it doesn't cause more escaping.
2002
2003     Adds or removes backslashes as appropriate. Doesn't parse and fix
2004     strings nested in f-strings (yet).
2005
2006     Note: Mutates its argument.
2007     """
2008     value = leaf.value.lstrip("furbFURB")
2009     if value[:3] == '"""':
2010         return
2011
2012     elif value[:3] == "'''":
2013         orig_quote = "'''"
2014         new_quote = '"""'
2015     elif value[0] == '"':
2016         orig_quote = '"'
2017         new_quote = "'"
2018     else:
2019         orig_quote = "'"
2020         new_quote = '"'
2021     first_quote_pos = leaf.value.find(orig_quote)
2022     if first_quote_pos == -1:
2023         return  # There's an internal error
2024
2025     prefix = leaf.value[:first_quote_pos]
2026     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2027     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2028     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2029     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
2030     if "r" in prefix.casefold():
2031         if unescaped_new_quote.search(body):
2032             # There's at least one unescaped new_quote in this raw string
2033             # so converting is impossible
2034             return
2035
2036         # Do not introduce or remove backslashes in raw strings
2037         new_body = body
2038     else:
2039         # remove unnecessary quotes
2040         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2041         if body != new_body:
2042             # Consider the string without unnecessary quotes as the original
2043             body = new_body
2044             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2045         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2046         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2047     if new_quote == '"""' and new_body[-1] == '"':
2048         # edge case:
2049         new_body = new_body[:-1] + '\\"'
2050     orig_escape_count = body.count("\\")
2051     new_escape_count = new_body.count("\\")
2052     if new_escape_count > orig_escape_count:
2053         return  # Do not introduce more escaping
2054
2055     if new_escape_count == orig_escape_count and orig_quote == '"':
2056         return  # Prefer double quotes
2057
2058     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2059
2060
2061 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2062     """Make existing optional parentheses invisible or create new ones.
2063
2064     Standardizes on visible parentheses for single-element tuples, and keeps
2065     existing visible parentheses for other tuples and generator expressions.
2066     """
2067     check_lpar = False
2068     for child in list(node.children):
2069         if check_lpar:
2070             if child.type == syms.atom:
2071                 if not (
2072                     is_empty_tuple(child)
2073                     or is_one_tuple(child)
2074                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2075                 ):
2076                     first = child.children[0]
2077                     last = child.children[-1]
2078                     if first.type == token.LPAR and last.type == token.RPAR:
2079                         # make parentheses invisible
2080                         first.value = ""  # type: ignore
2081                         last.value = ""  # type: ignore
2082             elif is_one_tuple(child):
2083                 # wrap child in visible parentheses
2084                 lpar = Leaf(token.LPAR, "(")
2085                 rpar = Leaf(token.RPAR, ")")
2086                 index = child.remove() or 0
2087                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2088             else:
2089                 # wrap child in invisible parentheses
2090                 lpar = Leaf(token.LPAR, "")
2091                 rpar = Leaf(token.RPAR, "")
2092                 index = child.remove() or 0
2093                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2094
2095         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2096
2097
2098 def is_empty_tuple(node: LN) -> bool:
2099     """Return True if `node` holds an empty tuple."""
2100     return (
2101         node.type == syms.atom
2102         and len(node.children) == 2
2103         and node.children[0].type == token.LPAR
2104         and node.children[1].type == token.RPAR
2105     )
2106
2107
2108 def is_one_tuple(node: LN) -> bool:
2109     """Return True if `node` holds a tuple with one element, with or without parens."""
2110     if node.type == syms.atom:
2111         if len(node.children) != 3:
2112             return False
2113
2114         lpar, gexp, rpar = node.children
2115         if not (
2116             lpar.type == token.LPAR
2117             and gexp.type == syms.testlist_gexp
2118             and rpar.type == token.RPAR
2119         ):
2120             return False
2121
2122         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2123
2124     return (
2125         node.type in IMPLICIT_TUPLE
2126         and len(node.children) == 2
2127         and node.children[1].type == token.COMMA
2128     )
2129
2130
2131 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2132     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2133
2134     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2135     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2136     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2137     generalizations (PEP 448).
2138     """
2139     if leaf.type not in STARS or not leaf.parent:
2140         return False
2141
2142     p = leaf.parent
2143     if p.type == syms.star_expr:
2144         # Star expressions are also used as assignment targets in extended
2145         # iterable unpacking (PEP 3132).  See what its parent is instead.
2146         if not p.parent:
2147             return False
2148
2149         p = p.parent
2150
2151     return p.type in within
2152
2153
2154 def max_delimiter_priority_in_atom(node: LN) -> int:
2155     """Return maximum delimiter priority inside `node`.
2156
2157     This is specific to atoms with contents contained in a pair of parentheses.
2158     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2159     """
2160     if node.type != syms.atom:
2161         return 0
2162
2163     first = node.children[0]
2164     last = node.children[-1]
2165     if not (first.type == token.LPAR and last.type == token.RPAR):
2166         return 0
2167
2168     bt = BracketTracker()
2169     for c in node.children[1:-1]:
2170         if isinstance(c, Leaf):
2171             bt.mark(c)
2172         else:
2173             for leaf in c.leaves():
2174                 bt.mark(leaf)
2175     try:
2176         return bt.max_delimiter_priority()
2177
2178     except ValueError:
2179         return 0
2180
2181
2182 def ensure_visible(leaf: Leaf) -> None:
2183     """Make sure parentheses are visible.
2184
2185     They could be invisible as part of some statements (see
2186     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2187     """
2188     if leaf.type == token.LPAR:
2189         leaf.value = "("
2190     elif leaf.type == token.RPAR:
2191         leaf.value = ")"
2192
2193
2194 def is_python36(node: Node) -> bool:
2195     """Return True if the current file is using Python 3.6+ features.
2196
2197     Currently looking for:
2198     - f-strings; and
2199     - trailing commas after * or ** in function signatures.
2200     """
2201     for n in node.pre_order():
2202         if n.type == token.STRING:
2203             value_head = n.value[:2]  # type: ignore
2204             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2205                 return True
2206
2207         elif (
2208             n.type == syms.typedargslist
2209             and n.children
2210             and n.children[-1].type == token.COMMA
2211         ):
2212             for ch in n.children:
2213                 if ch.type in STARS:
2214                     return True
2215
2216     return False
2217
2218
2219 PYTHON_EXTENSIONS = {".py"}
2220 BLACKLISTED_DIRECTORIES = {
2221     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2222 }
2223
2224
2225 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2226     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2227     and have one of the PYTHON_EXTENSIONS.
2228     """
2229     for child in path.iterdir():
2230         if child.is_dir():
2231             if child.name in BLACKLISTED_DIRECTORIES:
2232                 continue
2233
2234             yield from gen_python_files_in_dir(child)
2235
2236         elif child.suffix in PYTHON_EXTENSIONS:
2237             yield child
2238
2239
2240 @dataclass
2241 class Report:
2242     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2243     check: bool = False
2244     quiet: bool = False
2245     change_count: int = 0
2246     same_count: int = 0
2247     failure_count: int = 0
2248
2249     def done(self, src: Path, changed: Changed) -> None:
2250         """Increment the counter for successful reformatting. Write out a message."""
2251         if changed is Changed.YES:
2252             reformatted = "would reformat" if self.check else "reformatted"
2253             if not self.quiet:
2254                 out(f"{reformatted} {src}")
2255             self.change_count += 1
2256         else:
2257             if not self.quiet:
2258                 if changed is Changed.NO:
2259                     msg = f"{src} already well formatted, good job."
2260                 else:
2261                     msg = f"{src} wasn't modified on disk since last run."
2262                 out(msg, bold=False)
2263             self.same_count += 1
2264
2265     def failed(self, src: Path, message: str) -> None:
2266         """Increment the counter for failed reformatting. Write out a message."""
2267         err(f"error: cannot format {src}: {message}")
2268         self.failure_count += 1
2269
2270     @property
2271     def return_code(self) -> int:
2272         """Return the exit code that the app should use.
2273
2274         This considers the current state of changed files and failures:
2275         - if there were any failures, return 123;
2276         - if any files were changed and --check is being used, return 1;
2277         - otherwise return 0.
2278         """
2279         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2280         # 126 we have special returncodes reserved by the shell.
2281         if self.failure_count:
2282             return 123
2283
2284         elif self.change_count and self.check:
2285             return 1
2286
2287         return 0
2288
2289     def __str__(self) -> str:
2290         """Render a color report of the current state.
2291
2292         Use `click.unstyle` to remove colors.
2293         """
2294         if self.check:
2295             reformatted = "would be reformatted"
2296             unchanged = "would be left unchanged"
2297             failed = "would fail to reformat"
2298         else:
2299             reformatted = "reformatted"
2300             unchanged = "left unchanged"
2301             failed = "failed to reformat"
2302         report = []
2303         if self.change_count:
2304             s = "s" if self.change_count > 1 else ""
2305             report.append(
2306                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2307             )
2308         if self.same_count:
2309             s = "s" if self.same_count > 1 else ""
2310             report.append(f"{self.same_count} file{s} {unchanged}")
2311         if self.failure_count:
2312             s = "s" if self.failure_count > 1 else ""
2313             report.append(
2314                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2315             )
2316         return ", ".join(report) + "."
2317
2318
2319 def assert_equivalent(src: str, dst: str) -> None:
2320     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2321
2322     import ast
2323     import traceback
2324
2325     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2326         """Simple visitor generating strings to compare ASTs by content."""
2327         yield f"{'  ' * depth}{node.__class__.__name__}("
2328
2329         for field in sorted(node._fields):
2330             try:
2331                 value = getattr(node, field)
2332             except AttributeError:
2333                 continue
2334
2335             yield f"{'  ' * (depth+1)}{field}="
2336
2337             if isinstance(value, list):
2338                 for item in value:
2339                     if isinstance(item, ast.AST):
2340                         yield from _v(item, depth + 2)
2341
2342             elif isinstance(value, ast.AST):
2343                 yield from _v(value, depth + 2)
2344
2345             else:
2346                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2347
2348         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2349
2350     try:
2351         src_ast = ast.parse(src)
2352     except Exception as exc:
2353         major, minor = sys.version_info[:2]
2354         raise AssertionError(
2355             f"cannot use --safe with this file; failed to parse source file "
2356             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2357             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2358         )
2359
2360     try:
2361         dst_ast = ast.parse(dst)
2362     except Exception as exc:
2363         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2364         raise AssertionError(
2365             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2366             f"Please report a bug on https://github.com/ambv/black/issues.  "
2367             f"This invalid output might be helpful: {log}"
2368         ) from None
2369
2370     src_ast_str = "\n".join(_v(src_ast))
2371     dst_ast_str = "\n".join(_v(dst_ast))
2372     if src_ast_str != dst_ast_str:
2373         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2374         raise AssertionError(
2375             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2376             f"the source.  "
2377             f"Please report a bug on https://github.com/ambv/black/issues.  "
2378             f"This diff might be helpful: {log}"
2379         ) from None
2380
2381
2382 def assert_stable(src: str, dst: str, line_length: int) -> None:
2383     """Raise AssertionError if `dst` reformats differently the second time."""
2384     newdst = format_str(dst, line_length=line_length)
2385     if dst != newdst:
2386         log = dump_to_file(
2387             diff(src, dst, "source", "first pass"),
2388             diff(dst, newdst, "first pass", "second pass"),
2389         )
2390         raise AssertionError(
2391             f"INTERNAL ERROR: Black produced different code on the second pass "
2392             f"of the formatter.  "
2393             f"Please report a bug on https://github.com/ambv/black/issues.  "
2394             f"This diff might be helpful: {log}"
2395         ) from None
2396
2397
2398 def dump_to_file(*output: str) -> str:
2399     """Dump `output` to a temporary file. Return path to the file."""
2400     import tempfile
2401
2402     with tempfile.NamedTemporaryFile(
2403         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2404     ) as f:
2405         for lines in output:
2406             f.write(lines)
2407             if lines and lines[-1] != "\n":
2408                 f.write("\n")
2409     return f.name
2410
2411
2412 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2413     """Return a unified diff string between strings `a` and `b`."""
2414     import difflib
2415
2416     a_lines = [line + "\n" for line in a.split("\n")]
2417     b_lines = [line + "\n" for line in b.split("\n")]
2418     return "".join(
2419         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2420     )
2421
2422
2423 def cancel(tasks: List[asyncio.Task]) -> None:
2424     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2425     err("Aborted!")
2426     for task in tasks:
2427         task.cancel()
2428
2429
2430 def shutdown(loop: BaseEventLoop) -> None:
2431     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2432     try:
2433         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2434         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2435         if not to_cancel:
2436             return
2437
2438         for task in to_cancel:
2439             task.cancel()
2440         loop.run_until_complete(
2441             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2442         )
2443     finally:
2444         # `concurrent.futures.Future` objects cannot be cancelled once they
2445         # are already running. There might be some when the `shutdown()` happened.
2446         # Silence their logger's spew about the event loop being closed.
2447         cf_logger = logging.getLogger("concurrent.futures")
2448         cf_logger.setLevel(logging.CRITICAL)
2449         loop.close()
2450
2451
2452 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2453     """Replace `regex` with `replacement` twice on `original`.
2454
2455     This is used by string normalization to perform replaces on
2456     overlapping matches.
2457     """
2458     return regex.sub(replacement, regex.sub(replacement, original))
2459
2460
2461 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2462 CACHE_FILE = CACHE_DIR / "cache.pickle"
2463
2464
2465 def read_cache() -> Cache:
2466     """Read the cache if it exists and is well formed.
2467
2468     If it is not well formed, the call to write_cache later should resolve the issue.
2469     """
2470     if not CACHE_FILE.exists():
2471         return {}
2472
2473     with CACHE_FILE.open("rb") as fobj:
2474         try:
2475             cache: Cache = pickle.load(fobj)
2476         except pickle.UnpicklingError:
2477             return {}
2478
2479     return cache
2480
2481
2482 def get_cache_info(path: Path) -> CacheInfo:
2483     """Return the information used to check if a file is already formatted or not."""
2484     stat = path.stat()
2485     return stat.st_mtime, stat.st_size
2486
2487
2488 def filter_cached(
2489     cache: Cache, sources: Iterable[Path]
2490 ) -> Tuple[List[Path], List[Path]]:
2491     """Split a list of paths into two.
2492
2493     The first list contains paths of files that modified on disk or are not in the
2494     cache. The other list contains paths to non-modified files.
2495     """
2496     todo, done = [], []
2497     for src in sources:
2498         src = src.resolve()
2499         if cache.get(src) != get_cache_info(src):
2500             todo.append(src)
2501         else:
2502             done.append(src)
2503     return todo, done
2504
2505
2506 def write_cache(cache: Cache, sources: List[Path]) -> None:
2507     """Update the cache file."""
2508     try:
2509         if not CACHE_DIR.exists():
2510             CACHE_DIR.mkdir(parents=True)
2511         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2512         with CACHE_FILE.open("wb") as fobj:
2513             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2514     except OSError:
2515         pass
2516
2517
2518 if __name__ == "__main__":
2519     main()