]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Split ternary expressions
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import pickle
5 from asyncio.base_events import BaseEventLoop
6 from concurrent.futures import Executor, ProcessPoolExecutor
7 from enum import Enum
8 from functools import partial, wraps
9 import keyword
10 import logging
11 from multiprocessing import Manager
12 import os
13 from pathlib import Path
14 import re
15 import tokenize
16 import signal
17 import sys
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46 __version__ = "18.4a2"
47 DEFAULT_LINE_LENGTH = 88
48 # types
49 syms = pygram.python_symbols
50 FileContent = str
51 Encoding = str
52 Depth = int
53 NodeType = int
54 LeafID = int
55 Priority = int
56 Index = int
57 LN = Union[Leaf, Node]
58 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
59 Timestamp = float
60 FileSize = int
61 CacheInfo = Tuple[Timestamp, FileSize]
62 Cache = Dict[Path, CacheInfo]
63 out = partial(click.secho, bold=True, err=True)
64 err = partial(click.secho, fg="red", err=True)
65
66
67 class NothingChanged(UserWarning):
68     """Raised by :func:`format_file` when reformatted code is the same as source."""
69
70
71 class CannotSplit(Exception):
72     """A readable split that fits the allotted line length is impossible.
73
74     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
75     :func:`delimiter_split`.
76     """
77
78
79 class FormatError(Exception):
80     """Base exception for `# fmt: on` and `# fmt: off` handling.
81
82     It holds the number of bytes of the prefix consumed before the format
83     control comment appeared.
84     """
85
86     def __init__(self, consumed: int) -> None:
87         super().__init__(consumed)
88         self.consumed = consumed
89
90     def trim_prefix(self, leaf: Leaf) -> None:
91         leaf.prefix = leaf.prefix[self.consumed:]
92
93     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
94         """Returns a new Leaf from the consumed part of the prefix."""
95         unformatted_prefix = leaf.prefix[:self.consumed]
96         return Leaf(token.NEWLINE, unformatted_prefix)
97
98
99 class FormatOn(FormatError):
100     """Found a comment like `# fmt: on` in the file."""
101
102
103 class FormatOff(FormatError):
104     """Found a comment like `# fmt: off` in the file."""
105
106
107 class WriteBack(Enum):
108     NO = 0
109     YES = 1
110     DIFF = 2
111
112
113 class Changed(Enum):
114     NO = 0
115     CACHED = 1
116     YES = 2
117
118
119 @click.command()
120 @click.option(
121     "-l",
122     "--line-length",
123     type=int,
124     default=DEFAULT_LINE_LENGTH,
125     help="How many character per line to allow.",
126     show_default=True,
127 )
128 @click.option(
129     "--check",
130     is_flag=True,
131     help=(
132         "Don't write the files back, just return the status.  Return code 0 "
133         "means nothing would change.  Return code 1 means some files would be "
134         "reformatted.  Return code 123 means there was an internal error."
135     ),
136 )
137 @click.option(
138     "--diff",
139     is_flag=True,
140     help="Don't write the files back, just output a diff for each file on stdout.",
141 )
142 @click.option(
143     "--fast/--safe",
144     is_flag=True,
145     help="If --fast given, skip temporary sanity checks. [default: --safe]",
146 )
147 @click.option(
148     "-q",
149     "--quiet",
150     is_flag=True,
151     help=(
152         "Don't emit non-error messages to stderr. Errors are still emitted, "
153         "silence those with 2>/dev/null."
154     ),
155 )
156 @click.version_option(version=__version__)
157 @click.argument(
158     "src",
159     nargs=-1,
160     type=click.Path(
161         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
162     ),
163 )
164 @click.pass_context
165 def main(
166     ctx: click.Context,
167     line_length: int,
168     check: bool,
169     diff: bool,
170     fast: bool,
171     quiet: bool,
172     src: List[str],
173 ) -> None:
174     """The uncompromising code formatter."""
175     sources: List[Path] = []
176     for s in src:
177         p = Path(s)
178         if p.is_dir():
179             sources.extend(gen_python_files_in_dir(p))
180         elif p.is_file():
181             # if a file was explicitly given, we don't care about its extension
182             sources.append(p)
183         elif s == "-":
184             sources.append(Path("-"))
185         else:
186             err(f"invalid path: {s}")
187
188     if check and not diff:
189         write_back = WriteBack.NO
190     elif diff:
191         write_back = WriteBack.DIFF
192     else:
193         write_back = WriteBack.YES
194     report = Report(check=check, quiet=quiet)
195     if len(sources) == 0:
196         ctx.exit(0)
197         return
198
199     elif len(sources) == 1:
200         reformat_one(sources[0], line_length, fast, write_back, report)
201     else:
202         loop = asyncio.get_event_loop()
203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
204         try:
205             loop.run_until_complete(
206                 schedule_formatting(
207                     sources, line_length, fast, write_back, report, loop, executor
208                 )
209             )
210         finally:
211             shutdown(loop)
212         if not quiet:
213             out("All done! ✨ 🍰 ✨")
214             click.echo(str(report))
215     ctx.exit(report.return_code)
216
217
218 def reformat_one(
219     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
220 ) -> None:
221     """Reformat a single file under `src` without spawning child processes.
222
223     If `quiet` is True, non-error messages are not output. `line_length`,
224     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
225     """
226     try:
227         changed = Changed.NO
228         if not src.is_file() and str(src) == "-":
229             if format_stdin_to_stdout(
230                 line_length=line_length, fast=fast, write_back=write_back
231             ):
232                 changed = Changed.YES
233         else:
234             cache: Cache = {}
235             if write_back != WriteBack.DIFF:
236                 cache = read_cache(line_length)
237                 src = src.resolve()
238                 if src in cache and cache[src] == get_cache_info(src):
239                     changed = Changed.CACHED
240             if (
241                 changed is not Changed.CACHED
242                 and format_file_in_place(
243                     src, line_length=line_length, fast=fast, write_back=write_back
244                 )
245             ):
246                 changed = Changed.YES
247             if write_back != WriteBack.DIFF and changed is not Changed.NO:
248                 write_cache(cache, [src], line_length)
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache(line_length)
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             src: loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             )
289             for src in sources
290         }
291         _task_values = list(tasks.values())
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
294             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     if write_back != WriteBack.DIFF and formatted:
315         write_cache(cache, formatted, line_length)
316
317
318 def format_file_in_place(
319     src: Path,
320     line_length: int,
321     fast: bool,
322     write_back: WriteBack = WriteBack.NO,
323     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
324 ) -> bool:
325     """Format file under `src` path. Return True if changed.
326
327     If `write_back` is True, write reformatted code back to stdout.
328     `line_length` and `fast` options are passed to :func:`format_file_contents`.
329     """
330
331     with tokenize.open(src) as src_buffer:
332         src_contents = src_buffer.read()
333     try:
334         dst_contents = format_file_contents(
335             src_contents, line_length=line_length, fast=fast
336         )
337     except NothingChanged:
338         return False
339
340     if write_back == write_back.YES:
341         with open(src, "w", encoding=src_buffer.encoding) as f:
342             f.write(dst_contents)
343     elif write_back == write_back.DIFF:
344         src_name = f"{src}  (original)"
345         dst_name = f"{src}  (formatted)"
346         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347         if lock:
348             lock.acquire()
349         try:
350             sys.stdout.write(diff_contents)
351         finally:
352             if lock:
353                 lock.release()
354     return True
355
356
357 def format_stdin_to_stdout(
358     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
359 ) -> bool:
360     """Format file on stdin. Return True if changed.
361
362     If `write_back` is True, write reformatted code back to stdout.
363     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
364     """
365     src = sys.stdin.read()
366     dst = src
367     try:
368         dst = format_file_contents(src, line_length=line_length, fast=fast)
369         return True
370
371     except NothingChanged:
372         return False
373
374     finally:
375         if write_back == WriteBack.YES:
376             sys.stdout.write(dst)
377         elif write_back == WriteBack.DIFF:
378             src_name = "<stdin>  (original)"
379             dst_name = "<stdin>  (formatted)"
380             sys.stdout.write(diff(src, dst, src_name, dst_name))
381
382
383 def format_file_contents(
384     src_contents: str, line_length: int, fast: bool
385 ) -> FileContent:
386     """Reformat contents a file and return new contents.
387
388     If `fast` is False, additionally confirm that the reformatted code is
389     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
390     `line_length` is passed to :func:`format_str`.
391     """
392     if src_contents.strip() == "":
393         raise NothingChanged
394
395     dst_contents = format_str(src_contents, line_length=line_length)
396     if src_contents == dst_contents:
397         raise NothingChanged
398
399     if not fast:
400         assert_equivalent(src_contents, dst_contents)
401         assert_stable(src_contents, dst_contents, line_length=line_length)
402     return dst_contents
403
404
405 def format_str(src_contents: str, line_length: int) -> FileContent:
406     """Reformat a string and return new contents.
407
408     `line_length` determines how many characters per line are allowed.
409     """
410     src_node = lib2to3_parse(src_contents)
411     dst_contents = ""
412     lines = LineGenerator()
413     elt = EmptyLineTracker()
414     py36 = is_python36(src_node)
415     empty_line = Line()
416     after = 0
417     for current_line in lines.visit(src_node):
418         for _ in range(after):
419             dst_contents += str(empty_line)
420         before, after = elt.maybe_empty_lines(current_line)
421         for _ in range(before):
422             dst_contents += str(empty_line)
423         for line in split_line(current_line, line_length=line_length, py36=py36):
424             dst_contents += str(line)
425     return dst_contents
426
427
428 GRAMMARS = [
429     pygram.python_grammar_no_print_statement_no_exec_statement,
430     pygram.python_grammar_no_print_statement,
431     pygram.python_grammar,
432 ]
433
434
435 def lib2to3_parse(src_txt: str) -> Node:
436     """Given a string with source, return the lib2to3 Node."""
437     grammar = pygram.python_grammar_no_print_statement
438     if src_txt[-1] != "\n":
439         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
440         src_txt += nl
441     for grammar in GRAMMARS:
442         drv = driver.Driver(grammar, pytree.convert)
443         try:
444             result = drv.parse_string(src_txt, True)
445             break
446
447         except ParseError as pe:
448             lineno, column = pe.context[1]
449             lines = src_txt.splitlines()
450             try:
451                 faulty_line = lines[lineno - 1]
452             except IndexError:
453                 faulty_line = "<line number missing in source>"
454             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
455     else:
456         raise exc from None
457
458     if isinstance(result, Leaf):
459         result = Node(syms.file_input, [result])
460     return result
461
462
463 def lib2to3_unparse(node: Node) -> str:
464     """Given a lib2to3 node, return its string representation."""
465     code = str(node)
466     return code
467
468
469 T = TypeVar("T")
470
471
472 class Visitor(Generic[T]):
473     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
474
475     def visit(self, node: LN) -> Iterator[T]:
476         """Main method to visit `node` and its children.
477
478         It tries to find a `visit_*()` method for the given `node.type`, like
479         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
480         If no dedicated `visit_*()` method is found, chooses `visit_default()`
481         instead.
482
483         Then yields objects of type `T` from the selected visitor.
484         """
485         if node.type < 256:
486             name = token.tok_name[node.type]
487         else:
488             name = type_repr(node.type)
489         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
490
491     def visit_default(self, node: LN) -> Iterator[T]:
492         """Default `visit_*()` implementation. Recurses to children of `node`."""
493         if isinstance(node, Node):
494             for child in node.children:
495                 yield from self.visit(child)
496
497
498 @dataclass
499 class DebugVisitor(Visitor[T]):
500     tree_depth: int = 0
501
502     def visit_default(self, node: LN) -> Iterator[T]:
503         indent = " " * (2 * self.tree_depth)
504         if isinstance(node, Node):
505             _type = type_repr(node.type)
506             out(f"{indent}{_type}", fg="yellow")
507             self.tree_depth += 1
508             for child in node.children:
509                 yield from self.visit(child)
510
511             self.tree_depth -= 1
512             out(f"{indent}/{_type}", fg="yellow", bold=False)
513         else:
514             _type = token.tok_name.get(node.type, str(node.type))
515             out(f"{indent}{_type}", fg="blue", nl=False)
516             if node.prefix:
517                 # We don't have to handle prefixes for `Node` objects since
518                 # that delegates to the first child anyway.
519                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
520             out(f" {node.value!r}", fg="blue", bold=False)
521
522     @classmethod
523     def show(cls, code: str) -> None:
524         """Pretty-print the lib2to3 AST of a given string of `code`.
525
526         Convenience method for debugging.
527         """
528         v: DebugVisitor[None] = DebugVisitor()
529         list(v.visit(lib2to3_parse(code)))
530
531
532 KEYWORDS = set(keyword.kwlist)
533 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
534 FLOW_CONTROL = {"return", "raise", "break", "continue"}
535 STATEMENT = {
536     syms.if_stmt,
537     syms.while_stmt,
538     syms.for_stmt,
539     syms.try_stmt,
540     syms.except_clause,
541     syms.with_stmt,
542     syms.funcdef,
543     syms.classdef,
544 }
545 STANDALONE_COMMENT = 153
546 LOGIC_OPERATORS = {"and", "or"}
547 COMPARATORS = {
548     token.LESS,
549     token.GREATER,
550     token.EQEQUAL,
551     token.NOTEQUAL,
552     token.LESSEQUAL,
553     token.GREATEREQUAL,
554 }
555 MATH_OPERATORS = {
556     token.PLUS,
557     token.MINUS,
558     token.STAR,
559     token.SLASH,
560     token.VBAR,
561     token.AMPER,
562     token.PERCENT,
563     token.CIRCUMFLEX,
564     token.TILDE,
565     token.LEFTSHIFT,
566     token.RIGHTSHIFT,
567     token.DOUBLESTAR,
568     token.DOUBLESLASH,
569 }
570 STARS = {token.STAR, token.DOUBLESTAR}
571 VARARGS_PARENTS = {
572     syms.arglist,
573     syms.argument,  # double star in arglist
574     syms.trailer,  # single argument to call
575     syms.typedargslist,
576     syms.varargslist,  # lambdas
577 }
578 UNPACKING_PARENTS = {
579     syms.atom,  # single element of a list or set literal
580     syms.dictsetmaker,
581     syms.listmaker,
582     syms.testlist_gexp,
583 }
584 COMPREHENSION_PRIORITY = 20
585 COMMA_PRIORITY = 10
586 TERNARY_PRIORITY = 7
587 LOGIC_PRIORITY = 5
588 STRING_PRIORITY = 4
589 COMPARATOR_PRIORITY = 3
590 MATH_PRIORITY = 1
591
592
593 @dataclass
594 class BracketTracker:
595     """Keeps track of brackets on a line."""
596
597     depth: int = 0
598     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
599     delimiters: Dict[LeafID, Priority] = Factory(dict)
600     previous: Optional[Leaf] = None
601     _for_loop_variable: bool = False
602     _lambda_arguments: bool = False
603
604     def mark(self, leaf: Leaf) -> None:
605         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
606
607         All leaves receive an int `bracket_depth` field that stores how deep
608         within brackets a given leaf is. 0 means there are no enclosing brackets
609         that started on this line.
610
611         If a leaf is itself a closing bracket, it receives an `opening_bracket`
612         field that it forms a pair with. This is a one-directional link to
613         avoid reference cycles.
614
615         If a leaf is a delimiter (a token on which Black can split the line if
616         needed) and it's on depth 0, its `id()` is stored in the tracker's
617         `delimiters` field.
618         """
619         if leaf.type == token.COMMENT:
620             return
621
622         self.maybe_decrement_after_for_loop_variable(leaf)
623         self.maybe_decrement_after_lambda_arguments(leaf)
624         if leaf.type in CLOSING_BRACKETS:
625             self.depth -= 1
626             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
627             leaf.opening_bracket = opening_bracket
628         leaf.bracket_depth = self.depth
629         if self.depth == 0:
630             delim = is_split_before_delimiter(leaf, self.previous)
631             if delim and self.previous is not None:
632                 self.delimiters[id(self.previous)] = delim
633             else:
634                 delim = is_split_after_delimiter(leaf, self.previous)
635                 if delim:
636                     self.delimiters[id(leaf)] = delim
637         if leaf.type in OPENING_BRACKETS:
638             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
639             self.depth += 1
640         self.previous = leaf
641         self.maybe_increment_lambda_arguments(leaf)
642         self.maybe_increment_for_loop_variable(leaf)
643
644     def any_open_brackets(self) -> bool:
645         """Return True if there is an yet unmatched open bracket on the line."""
646         return bool(self.bracket_match)
647
648     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
649         """Return the highest priority of a delimiter found on the line.
650
651         Values are consistent with what `is_split_*_delimiter()` return.
652         Raises ValueError on no delimiters.
653         """
654         return max(v for k, v in self.delimiters.items() if k not in exclude)
655
656     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
657         """In a for loop, or comprehension, the variables are often unpacks.
658
659         To avoid splitting on the comma in this situation, increase the depth of
660         tokens between `for` and `in`.
661         """
662         if leaf.type == token.NAME and leaf.value == "for":
663             self.depth += 1
664             self._for_loop_variable = True
665             return True
666
667         return False
668
669     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
670         """See `maybe_increment_for_loop_variable` above for explanation."""
671         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
672             self.depth -= 1
673             self._for_loop_variable = False
674             return True
675
676         return False
677
678     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
679         """In a lambda expression, there might be more than one argument.
680
681         To avoid splitting on the comma in this situation, increase the depth of
682         tokens between `lambda` and `:`.
683         """
684         if leaf.type == token.NAME and leaf.value == "lambda":
685             self.depth += 1
686             self._lambda_arguments = True
687             return True
688
689         return False
690
691     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
692         """See `maybe_increment_lambda_arguments` above for explanation."""
693         if self._lambda_arguments and leaf.type == token.COLON:
694             self.depth -= 1
695             self._lambda_arguments = False
696             return True
697
698         return False
699
700
701 @dataclass
702 class Line:
703     """Holds leaves and comments. Can be printed with `str(line)`."""
704
705     depth: int = 0
706     leaves: List[Leaf] = Factory(list)
707     comments: List[Tuple[Index, Leaf]] = Factory(list)
708     bracket_tracker: BracketTracker = Factory(BracketTracker)
709     inside_brackets: bool = False
710
711     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
712         """Add a new `leaf` to the end of the line.
713
714         Unless `preformatted` is True, the `leaf` will receive a new consistent
715         whitespace prefix and metadata applied by :class:`BracketTracker`.
716         Trailing commas are maybe removed, unpacked for loop variables are
717         demoted from being delimiters.
718
719         Inline comments are put aside.
720         """
721         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
722         if not has_value:
723             return
724
725         if self.leaves and not preformatted:
726             # Note: at this point leaf.prefix should be empty except for
727             # imports, for which we only preserve newlines.
728             leaf.prefix += whitespace(leaf)
729         if self.inside_brackets or not preformatted:
730             self.bracket_tracker.mark(leaf)
731             self.maybe_remove_trailing_comma(leaf)
732
733         if not self.append_comment(leaf):
734             self.leaves.append(leaf)
735
736     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
737         """Like :func:`append()` but disallow invalid standalone comment structure.
738
739         Raises ValueError when any `leaf` is appended after a standalone comment
740         or when a standalone comment is not the first leaf on the line.
741         """
742         if self.bracket_tracker.depth == 0:
743             if self.is_comment:
744                 raise ValueError("cannot append to standalone comments")
745
746             if self.leaves and leaf.type == STANDALONE_COMMENT:
747                 raise ValueError(
748                     "cannot append standalone comments to a populated line"
749                 )
750
751         self.append(leaf, preformatted=preformatted)
752
753     @property
754     def is_comment(self) -> bool:
755         """Is this line a standalone comment?"""
756         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
757
758     @property
759     def is_decorator(self) -> bool:
760         """Is this line a decorator?"""
761         return bool(self) and self.leaves[0].type == token.AT
762
763     @property
764     def is_import(self) -> bool:
765         """Is this an import line?"""
766         return bool(self) and is_import(self.leaves[0])
767
768     @property
769     def is_class(self) -> bool:
770         """Is this line a class definition?"""
771         return (
772             bool(self)
773             and self.leaves[0].type == token.NAME
774             and self.leaves[0].value == "class"
775         )
776
777     @property
778     def is_def(self) -> bool:
779         """Is this a function definition? (Also returns True for async defs.)"""
780         try:
781             first_leaf = self.leaves[0]
782         except IndexError:
783             return False
784
785         try:
786             second_leaf: Optional[Leaf] = self.leaves[1]
787         except IndexError:
788             second_leaf = None
789         return (
790             (first_leaf.type == token.NAME and first_leaf.value == "def")
791             or (
792                 first_leaf.type == token.ASYNC
793                 and second_leaf is not None
794                 and second_leaf.type == token.NAME
795                 and second_leaf.value == "def"
796             )
797         )
798
799     @property
800     def is_flow_control(self) -> bool:
801         """Is this line a flow control statement?
802
803         Those are `return`, `raise`, `break`, and `continue`.
804         """
805         return (
806             bool(self)
807             and self.leaves[0].type == token.NAME
808             and self.leaves[0].value in FLOW_CONTROL
809         )
810
811     @property
812     def is_yield(self) -> bool:
813         """Is this line a yield statement?"""
814         return (
815             bool(self)
816             and self.leaves[0].type == token.NAME
817             and self.leaves[0].value == "yield"
818         )
819
820     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
821         """If so, needs to be split before emitting."""
822         for leaf in self.leaves:
823             if leaf.type == STANDALONE_COMMENT:
824                 if leaf.bracket_depth <= depth_limit:
825                     return True
826
827         return False
828
829     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
830         """Remove trailing comma if there is one and it's safe."""
831         if not (
832             self.leaves
833             and self.leaves[-1].type == token.COMMA
834             and closing.type in CLOSING_BRACKETS
835         ):
836             return False
837
838         if closing.type == token.RBRACE:
839             self.remove_trailing_comma()
840             return True
841
842         if closing.type == token.RSQB:
843             comma = self.leaves[-1]
844             if comma.parent and comma.parent.type == syms.listmaker:
845                 self.remove_trailing_comma()
846                 return True
847
848         # For parens let's check if it's safe to remove the comma.  If the
849         # trailing one is the only one, we might mistakenly change a tuple
850         # into a different type by removing the comma.
851         depth = closing.bracket_depth + 1
852         commas = 0
853         opening = closing.opening_bracket
854         for _opening_index, leaf in enumerate(self.leaves):
855             if leaf is opening:
856                 break
857
858         else:
859             return False
860
861         for leaf in self.leaves[_opening_index + 1:]:
862             if leaf is closing:
863                 break
864
865             bracket_depth = leaf.bracket_depth
866             if bracket_depth == depth and leaf.type == token.COMMA:
867                 commas += 1
868                 if leaf.parent and leaf.parent.type == syms.arglist:
869                     commas += 1
870                     break
871
872         if commas > 1:
873             self.remove_trailing_comma()
874             return True
875
876         return False
877
878     def append_comment(self, comment: Leaf) -> bool:
879         """Add an inline or standalone comment to the line."""
880         if (
881             comment.type == STANDALONE_COMMENT
882             and self.bracket_tracker.any_open_brackets()
883         ):
884             comment.prefix = ""
885             return False
886
887         if comment.type != token.COMMENT:
888             return False
889
890         after = len(self.leaves) - 1
891         if after == -1:
892             comment.type = STANDALONE_COMMENT
893             comment.prefix = ""
894             return False
895
896         else:
897             self.comments.append((after, comment))
898             return True
899
900     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
901         """Generate comments that should appear directly after `leaf`."""
902         for _leaf_index, _leaf in enumerate(self.leaves):
903             if leaf is _leaf:
904                 break
905
906         else:
907             return
908
909         for index, comment_after in self.comments:
910             if _leaf_index == index:
911                 yield comment_after
912
913     def remove_trailing_comma(self) -> None:
914         """Remove the trailing comma and moves the comments attached to it."""
915         comma_index = len(self.leaves) - 1
916         for i in range(len(self.comments)):
917             comment_index, comment = self.comments[i]
918             if comment_index == comma_index:
919                 self.comments[i] = (comma_index - 1, comment)
920         self.leaves.pop()
921
922     def __str__(self) -> str:
923         """Render the line."""
924         if not self:
925             return "\n"
926
927         indent = "    " * self.depth
928         leaves = iter(self.leaves)
929         first = next(leaves)
930         res = f"{first.prefix}{indent}{first.value}"
931         for leaf in leaves:
932             res += str(leaf)
933         for _, comment in self.comments:
934             res += str(comment)
935         return res + "\n"
936
937     def __bool__(self) -> bool:
938         """Return True if the line has leaves or comments."""
939         return bool(self.leaves or self.comments)
940
941
942 class UnformattedLines(Line):
943     """Just like :class:`Line` but stores lines which aren't reformatted."""
944
945     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
946         """Just add a new `leaf` to the end of the lines.
947
948         The `preformatted` argument is ignored.
949
950         Keeps track of indentation `depth`, which is useful when the user
951         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
952         """
953         try:
954             list(generate_comments(leaf))
955         except FormatOn as f_on:
956             self.leaves.append(f_on.leaf_from_consumed(leaf))
957             raise
958
959         self.leaves.append(leaf)
960         if leaf.type == token.INDENT:
961             self.depth += 1
962         elif leaf.type == token.DEDENT:
963             self.depth -= 1
964
965     def __str__(self) -> str:
966         """Render unformatted lines from leaves which were added with `append()`.
967
968         `depth` is not used for indentation in this case.
969         """
970         if not self:
971             return "\n"
972
973         res = ""
974         for leaf in self.leaves:
975             res += str(leaf)
976         return res
977
978     def append_comment(self, comment: Leaf) -> bool:
979         """Not implemented in this class. Raises `NotImplementedError`."""
980         raise NotImplementedError("Unformatted lines don't store comments separately.")
981
982     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
983         """Does nothing and returns False."""
984         return False
985
986     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
987         """Does nothing and returns False."""
988         return False
989
990
991 @dataclass
992 class EmptyLineTracker:
993     """Provides a stateful method that returns the number of potential extra
994     empty lines needed before and after the currently processed line.
995
996     Note: this tracker works on lines that haven't been split yet.  It assumes
997     the prefix of the first leaf consists of optional newlines.  Those newlines
998     are consumed by `maybe_empty_lines()` and included in the computation.
999     """
1000     previous_line: Optional[Line] = None
1001     previous_after: int = 0
1002     previous_defs: List[int] = Factory(list)
1003
1004     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1005         """Return the number of extra empty lines before and after the `current_line`.
1006
1007         This is for separating `def`, `async def` and `class` with extra empty
1008         lines (two on module-level), as well as providing an extra empty line
1009         after flow control keywords to make them more prominent.
1010         """
1011         if isinstance(current_line, UnformattedLines):
1012             return 0, 0
1013
1014         before, after = self._maybe_empty_lines(current_line)
1015         before -= self.previous_after
1016         self.previous_after = after
1017         self.previous_line = current_line
1018         return before, after
1019
1020     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1021         max_allowed = 1
1022         if current_line.depth == 0:
1023             max_allowed = 2
1024         if current_line.leaves:
1025             # Consume the first leaf's extra newlines.
1026             first_leaf = current_line.leaves[0]
1027             before = first_leaf.prefix.count("\n")
1028             before = min(before, max_allowed)
1029             first_leaf.prefix = ""
1030         else:
1031             before = 0
1032         depth = current_line.depth
1033         while self.previous_defs and self.previous_defs[-1] >= depth:
1034             self.previous_defs.pop()
1035             before = 1 if depth else 2
1036         is_decorator = current_line.is_decorator
1037         if is_decorator or current_line.is_def or current_line.is_class:
1038             if not is_decorator:
1039                 self.previous_defs.append(depth)
1040             if self.previous_line is None:
1041                 # Don't insert empty lines before the first line in the file.
1042                 return 0, 0
1043
1044             if self.previous_line.is_decorator:
1045                 return 0, 0
1046
1047             if (
1048                 self.previous_line.is_comment
1049                 and self.previous_line.depth == current_line.depth
1050                 and before == 0
1051             ):
1052                 return 0, 0
1053
1054             newlines = 2
1055             if current_line.depth:
1056                 newlines -= 1
1057             return newlines, 0
1058
1059         if current_line.is_flow_control:
1060             return before, 1
1061
1062         if (
1063             self.previous_line
1064             and self.previous_line.is_import
1065             and not current_line.is_import
1066             and depth == self.previous_line.depth
1067         ):
1068             return (before or 1), 0
1069
1070         if (
1071             self.previous_line
1072             and self.previous_line.is_yield
1073             and (not current_line.is_yield or depth != self.previous_line.depth)
1074         ):
1075             return (before or 1), 0
1076
1077         return before, 0
1078
1079
1080 @dataclass
1081 class LineGenerator(Visitor[Line]):
1082     """Generates reformatted Line objects.  Empty lines are not emitted.
1083
1084     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1085     in ways that will no longer stringify to valid Python code on the tree.
1086     """
1087     current_line: Line = Factory(Line)
1088
1089     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1090         """Generate a line.
1091
1092         If the line is empty, only emit if it makes sense.
1093         If the line is too long, split it first and then generate.
1094
1095         If any lines were generated, set up a new current_line.
1096         """
1097         if not self.current_line:
1098             if self.current_line.__class__ == type:
1099                 self.current_line.depth += indent
1100             else:
1101                 self.current_line = type(depth=self.current_line.depth + indent)
1102             return  # Line is empty, don't emit. Creating a new one unnecessary.
1103
1104         complete_line = self.current_line
1105         self.current_line = type(depth=complete_line.depth + indent)
1106         yield complete_line
1107
1108     def visit(self, node: LN) -> Iterator[Line]:
1109         """Main method to visit `node` and its children.
1110
1111         Yields :class:`Line` objects.
1112         """
1113         if isinstance(self.current_line, UnformattedLines):
1114             # File contained `# fmt: off`
1115             yield from self.visit_unformatted(node)
1116
1117         else:
1118             yield from super().visit(node)
1119
1120     def visit_default(self, node: LN) -> Iterator[Line]:
1121         """Default `visit_*()` implementation. Recurses to children of `node`."""
1122         if isinstance(node, Leaf):
1123             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1124             try:
1125                 for comment in generate_comments(node):
1126                     if any_open_brackets:
1127                         # any comment within brackets is subject to splitting
1128                         self.current_line.append(comment)
1129                     elif comment.type == token.COMMENT:
1130                         # regular trailing comment
1131                         self.current_line.append(comment)
1132                         yield from self.line()
1133
1134                     else:
1135                         # regular standalone comment
1136                         yield from self.line()
1137
1138                         self.current_line.append(comment)
1139                         yield from self.line()
1140
1141             except FormatOff as f_off:
1142                 f_off.trim_prefix(node)
1143                 yield from self.line(type=UnformattedLines)
1144                 yield from self.visit(node)
1145
1146             except FormatOn as f_on:
1147                 # This only happens here if somebody says "fmt: on" multiple
1148                 # times in a row.
1149                 f_on.trim_prefix(node)
1150                 yield from self.visit_default(node)
1151
1152             else:
1153                 normalize_prefix(node, inside_brackets=any_open_brackets)
1154                 if node.type == token.STRING:
1155                     normalize_string_quotes(node)
1156                 if node.type not in WHITESPACE:
1157                     self.current_line.append(node)
1158         yield from super().visit_default(node)
1159
1160     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1161         """Increase indentation level, maybe yield a line."""
1162         # In blib2to3 INDENT never holds comments.
1163         yield from self.line(+1)
1164         yield from self.visit_default(node)
1165
1166     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1167         """Decrease indentation level, maybe yield a line."""
1168         # The current line might still wait for trailing comments.  At DEDENT time
1169         # there won't be any (they would be prefixes on the preceding NEWLINE).
1170         # Emit the line then.
1171         yield from self.line()
1172
1173         # While DEDENT has no value, its prefix may contain standalone comments
1174         # that belong to the current indentation level.  Get 'em.
1175         yield from self.visit_default(node)
1176
1177         # Finally, emit the dedent.
1178         yield from self.line(-1)
1179
1180     def visit_stmt(
1181         self, node: Node, keywords: Set[str], parens: Set[str]
1182     ) -> Iterator[Line]:
1183         """Visit a statement.
1184
1185         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1186         `def`, `with`, `class`, and `assert`.
1187
1188         The relevant Python language `keywords` for a given statement will be
1189         NAME leaves within it. This methods puts those on a separate line.
1190
1191         `parens` holds pairs of nodes where invisible parentheses should be put.
1192         Keys hold nodes after which opening parentheses should be put, values
1193         hold nodes before which closing parentheses should be put.
1194         """
1195         normalize_invisible_parens(node, parens_after=parens)
1196         for child in node.children:
1197             if child.type == token.NAME and child.value in keywords:  # type: ignore
1198                 yield from self.line()
1199
1200             yield from self.visit(child)
1201
1202     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1203         """Visit a statement without nested statements."""
1204         is_suite_like = node.parent and node.parent.type in STATEMENT
1205         if is_suite_like:
1206             yield from self.line(+1)
1207             yield from self.visit_default(node)
1208             yield from self.line(-1)
1209
1210         else:
1211             yield from self.line()
1212             yield from self.visit_default(node)
1213
1214     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1215         """Visit `async def`, `async for`, `async with`."""
1216         yield from self.line()
1217
1218         children = iter(node.children)
1219         for child in children:
1220             yield from self.visit(child)
1221
1222             if child.type == token.ASYNC:
1223                 break
1224
1225         internal_stmt = next(children)
1226         for child in internal_stmt.children:
1227             yield from self.visit(child)
1228
1229     def visit_decorators(self, node: Node) -> Iterator[Line]:
1230         """Visit decorators."""
1231         for child in node.children:
1232             yield from self.line()
1233             yield from self.visit(child)
1234
1235     def visit_import_from(self, node: Node) -> Iterator[Line]:
1236         """Visit import_from and maybe put invisible parentheses.
1237
1238         This is separate from `visit_stmt` because import statements don't
1239         support arbitrary atoms and thus handling of parentheses is custom.
1240         """
1241         check_lpar = False
1242         for index, child in enumerate(node.children):
1243             if check_lpar:
1244                 if child.type == token.LPAR:
1245                     # make parentheses invisible
1246                     child.value = ""  # type: ignore
1247                     node.children[-1].value = ""  # type: ignore
1248                 else:
1249                     # insert invisible parentheses
1250                     node.insert_child(index, Leaf(token.LPAR, ""))
1251                     node.append_child(Leaf(token.RPAR, ""))
1252                 break
1253
1254             check_lpar = (
1255                 child.type == token.NAME and child.value == "import"  # type: ignore
1256             )
1257
1258         for child in node.children:
1259             yield from self.visit(child)
1260
1261     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1262         """Remove a semicolon and put the other statement on a separate line."""
1263         yield from self.line()
1264
1265     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1266         """End of file. Process outstanding comments and end with a newline."""
1267         yield from self.visit_default(leaf)
1268         yield from self.line()
1269
1270     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1271         """Used when file contained a `# fmt: off`."""
1272         if isinstance(node, Node):
1273             for child in node.children:
1274                 yield from self.visit(child)
1275
1276         else:
1277             try:
1278                 self.current_line.append(node)
1279             except FormatOn as f_on:
1280                 f_on.trim_prefix(node)
1281                 yield from self.line()
1282                 yield from self.visit(node)
1283
1284             if node.type == token.ENDMARKER:
1285                 # somebody decided not to put a final `# fmt: on`
1286                 yield from self.line()
1287
1288     def __attrs_post_init__(self) -> None:
1289         """You are in a twisty little maze of passages."""
1290         v = self.visit_stmt
1291         Ø: Set[str] = set()
1292         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1293         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1294         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1295         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1296         self.visit_try_stmt = partial(
1297             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1298         )
1299         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1300         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1301         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1302         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1303         self.visit_async_funcdef = self.visit_async_stmt
1304         self.visit_decorated = self.visit_decorators
1305
1306
1307 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1308 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1309 OPENING_BRACKETS = set(BRACKET.keys())
1310 CLOSING_BRACKETS = set(BRACKET.values())
1311 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1312 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1313
1314
1315 def whitespace(leaf: Leaf) -> str:  # noqa C901
1316     """Return whitespace prefix if needed for the given `leaf`."""
1317     NO = ""
1318     SPACE = " "
1319     DOUBLESPACE = "  "
1320     t = leaf.type
1321     p = leaf.parent
1322     v = leaf.value
1323     if t in ALWAYS_NO_SPACE:
1324         return NO
1325
1326     if t == token.COMMENT:
1327         return DOUBLESPACE
1328
1329     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1330     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1331         return NO
1332
1333     prev = leaf.prev_sibling
1334     if not prev:
1335         prevp = preceding_leaf(p)
1336         if not prevp or prevp.type in OPENING_BRACKETS:
1337             return NO
1338
1339         if t == token.COLON:
1340             return SPACE if prevp.type == token.COMMA else NO
1341
1342         if prevp.type == token.EQUAL:
1343             if prevp.parent:
1344                 if prevp.parent.type in {
1345                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1346                 }:
1347                     return NO
1348
1349                 elif prevp.parent.type == syms.typedargslist:
1350                     # A bit hacky: if the equal sign has whitespace, it means we
1351                     # previously found it's a typed argument.  So, we're using
1352                     # that, too.
1353                     return prevp.prefix
1354
1355         elif prevp.type in STARS:
1356             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1357                 return NO
1358
1359         elif prevp.type == token.COLON:
1360             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1361                 return NO
1362
1363         elif (
1364             prevp.parent
1365             and prevp.parent.type == syms.factor
1366             and prevp.type in MATH_OPERATORS
1367         ):
1368             return NO
1369
1370         elif (
1371             prevp.type == token.RIGHTSHIFT
1372             and prevp.parent
1373             and prevp.parent.type == syms.shift_expr
1374             and prevp.prev_sibling
1375             and prevp.prev_sibling.type == token.NAME
1376             and prevp.prev_sibling.value == "print"  # type: ignore
1377         ):
1378             # Python 2 print chevron
1379             return NO
1380
1381     elif prev.type in OPENING_BRACKETS:
1382         return NO
1383
1384     if p.type in {syms.parameters, syms.arglist}:
1385         # untyped function signatures or calls
1386         if not prev or prev.type != token.COMMA:
1387             return NO
1388
1389     elif p.type == syms.varargslist:
1390         # lambdas
1391         if prev and prev.type != token.COMMA:
1392             return NO
1393
1394     elif p.type == syms.typedargslist:
1395         # typed function signatures
1396         if not prev:
1397             return NO
1398
1399         if t == token.EQUAL:
1400             if prev.type != syms.tname:
1401                 return NO
1402
1403         elif prev.type == token.EQUAL:
1404             # A bit hacky: if the equal sign has whitespace, it means we
1405             # previously found it's a typed argument.  So, we're using that, too.
1406             return prev.prefix
1407
1408         elif prev.type != token.COMMA:
1409             return NO
1410
1411     elif p.type == syms.tname:
1412         # type names
1413         if not prev:
1414             prevp = preceding_leaf(p)
1415             if not prevp or prevp.type != token.COMMA:
1416                 return NO
1417
1418     elif p.type == syms.trailer:
1419         # attributes and calls
1420         if t == token.LPAR or t == token.RPAR:
1421             return NO
1422
1423         if not prev:
1424             if t == token.DOT:
1425                 prevp = preceding_leaf(p)
1426                 if not prevp or prevp.type != token.NUMBER:
1427                     return NO
1428
1429             elif t == token.LSQB:
1430                 return NO
1431
1432         elif prev.type != token.COMMA:
1433             return NO
1434
1435     elif p.type == syms.argument:
1436         # single argument
1437         if t == token.EQUAL:
1438             return NO
1439
1440         if not prev:
1441             prevp = preceding_leaf(p)
1442             if not prevp or prevp.type == token.LPAR:
1443                 return NO
1444
1445         elif prev.type in {token.EQUAL} | STARS:
1446             return NO
1447
1448     elif p.type == syms.decorator:
1449         # decorators
1450         return NO
1451
1452     elif p.type == syms.dotted_name:
1453         if prev:
1454             return NO
1455
1456         prevp = preceding_leaf(p)
1457         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1458             return NO
1459
1460     elif p.type == syms.classdef:
1461         if t == token.LPAR:
1462             return NO
1463
1464         if prev and prev.type == token.LPAR:
1465             return NO
1466
1467     elif p.type == syms.subscript:
1468         # indexing
1469         if not prev:
1470             assert p.parent is not None, "subscripts are always parented"
1471             if p.parent.type == syms.subscriptlist:
1472                 return SPACE
1473
1474             return NO
1475
1476         else:
1477             return NO
1478
1479     elif p.type == syms.atom:
1480         if prev and t == token.DOT:
1481             # dots, but not the first one.
1482             return NO
1483
1484     elif p.type == syms.dictsetmaker:
1485         # dict unpacking
1486         if prev and prev.type == token.DOUBLESTAR:
1487             return NO
1488
1489     elif p.type in {syms.factor, syms.star_expr}:
1490         # unary ops
1491         if not prev:
1492             prevp = preceding_leaf(p)
1493             if not prevp or prevp.type in OPENING_BRACKETS:
1494                 return NO
1495
1496             prevp_parent = prevp.parent
1497             assert prevp_parent is not None
1498             if (
1499                 prevp.type == token.COLON
1500                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1501             ):
1502                 return NO
1503
1504             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1505                 return NO
1506
1507         elif t == token.NAME or t == token.NUMBER:
1508             return NO
1509
1510     elif p.type == syms.import_from:
1511         if t == token.DOT:
1512             if prev and prev.type == token.DOT:
1513                 return NO
1514
1515         elif t == token.NAME:
1516             if v == "import":
1517                 return SPACE
1518
1519             if prev and prev.type == token.DOT:
1520                 return NO
1521
1522     elif p.type == syms.sliceop:
1523         return NO
1524
1525     return SPACE
1526
1527
1528 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1529     """Return the first leaf that precedes `node`, if any."""
1530     while node:
1531         res = node.prev_sibling
1532         if res:
1533             if isinstance(res, Leaf):
1534                 return res
1535
1536             try:
1537                 return list(res.leaves())[-1]
1538
1539             except IndexError:
1540                 return None
1541
1542         node = node.parent
1543     return None
1544
1545
1546 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1547     """Return the priority of the `leaf` delimiter, given a line break after it.
1548
1549     The delimiter priorities returned here are from those delimiters that would
1550     cause a line break after themselves.
1551
1552     Higher numbers are higher priority.
1553     """
1554     if leaf.type == token.COMMA:
1555         return COMMA_PRIORITY
1556
1557     return 0
1558
1559
1560 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1561     """Return the priority of the `leaf` delimiter, given a line before after it.
1562
1563     The delimiter priorities returned here are from those delimiters that would
1564     cause a line break before themselves.
1565
1566     Higher numbers are higher priority.
1567     """
1568     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1569         # * and ** might also be MATH_OPERATORS but in this case they are not.
1570         # Don't treat them as a delimiter.
1571         return 0
1572
1573     if (
1574         leaf.type in MATH_OPERATORS
1575         and leaf.parent
1576         and leaf.parent.type not in {syms.factor, syms.star_expr}
1577     ):
1578         return MATH_PRIORITY
1579
1580     if leaf.type in COMPARATORS:
1581         return COMPARATOR_PRIORITY
1582
1583     if (
1584         leaf.type == token.STRING
1585         and previous is not None
1586         and previous.type == token.STRING
1587     ):
1588         return STRING_PRIORITY
1589
1590     if (
1591         leaf.type == token.NAME
1592         and leaf.value == "for"
1593         and leaf.parent
1594         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1595     ):
1596         return COMPREHENSION_PRIORITY
1597
1598     if (
1599         leaf.type == token.NAME
1600         and leaf.value == "if"
1601         and leaf.parent
1602         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1603     ):
1604         return COMPREHENSION_PRIORITY
1605
1606     if (
1607         leaf.type == token.NAME
1608         and leaf.value in {"if", "else"}
1609         and leaf.parent
1610         and leaf.parent.type == syms.test
1611     ):
1612         return TERNARY_PRIORITY
1613
1614     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1615         return LOGIC_PRIORITY
1616
1617     return 0
1618
1619
1620 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1621     """Clean the prefix of the `leaf` and generate comments from it, if any.
1622
1623     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1624     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1625     move because it does away with modifying the grammar to include all the
1626     possible places in which comments can be placed.
1627
1628     The sad consequence for us though is that comments don't "belong" anywhere.
1629     This is why this function generates simple parentless Leaf objects for
1630     comments.  We simply don't know what the correct parent should be.
1631
1632     No matter though, we can live without this.  We really only need to
1633     differentiate between inline and standalone comments.  The latter don't
1634     share the line with any code.
1635
1636     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1637     are emitted with a fake STANDALONE_COMMENT token identifier.
1638     """
1639     p = leaf.prefix
1640     if not p:
1641         return
1642
1643     if "#" not in p:
1644         return
1645
1646     consumed = 0
1647     nlines = 0
1648     for index, line in enumerate(p.split("\n")):
1649         consumed += len(line) + 1  # adding the length of the split '\n'
1650         line = line.lstrip()
1651         if not line:
1652             nlines += 1
1653         if not line.startswith("#"):
1654             continue
1655
1656         if index == 0 and leaf.type != token.ENDMARKER:
1657             comment_type = token.COMMENT  # simple trailing comment
1658         else:
1659             comment_type = STANDALONE_COMMENT
1660         comment = make_comment(line)
1661         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1662
1663         if comment in {"# fmt: on", "# yapf: enable"}:
1664             raise FormatOn(consumed)
1665
1666         if comment in {"# fmt: off", "# yapf: disable"}:
1667             if comment_type == STANDALONE_COMMENT:
1668                 raise FormatOff(consumed)
1669
1670             prev = preceding_leaf(leaf)
1671             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1672                 raise FormatOff(consumed)
1673
1674         nlines = 0
1675
1676
1677 def make_comment(content: str) -> str:
1678     """Return a consistently formatted comment from the given `content` string.
1679
1680     All comments (except for "##", "#!", "#:") should have a single space between
1681     the hash sign and the content.
1682
1683     If `content` didn't start with a hash sign, one is provided.
1684     """
1685     content = content.rstrip()
1686     if not content:
1687         return "#"
1688
1689     if content[0] == "#":
1690         content = content[1:]
1691     if content and content[0] not in " !:#":
1692         content = " " + content
1693     return "#" + content
1694
1695
1696 def split_line(
1697     line: Line, line_length: int, inner: bool = False, py36: bool = False
1698 ) -> Iterator[Line]:
1699     """Split a `line` into potentially many lines.
1700
1701     They should fit in the allotted `line_length` but might not be able to.
1702     `inner` signifies that there were a pair of brackets somewhere around the
1703     current `line`, possibly transitively. This means we can fallback to splitting
1704     by delimiters if the LHS/RHS don't yield any results.
1705
1706     If `py36` is True, splitting may generate syntax that is only compatible
1707     with Python 3.6 and later.
1708     """
1709     if isinstance(line, UnformattedLines) or line.is_comment:
1710         yield line
1711         return
1712
1713     line_str = str(line).strip("\n")
1714     if (
1715         len(line_str) <= line_length
1716         and "\n" not in line_str  # multiline strings
1717         and not line.contains_standalone_comments()
1718     ):
1719         yield line
1720         return
1721
1722     split_funcs: List[SplitFunc]
1723     if line.is_def:
1724         split_funcs = [left_hand_split]
1725     elif line.inside_brackets:
1726         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1727     else:
1728         split_funcs = [right_hand_split]
1729     for split_func in split_funcs:
1730         # We are accumulating lines in `result` because we might want to abort
1731         # mission and return the original line in the end, or attempt a different
1732         # split altogether.
1733         result: List[Line] = []
1734         try:
1735             for l in split_func(line, py36):
1736                 if str(l).strip("\n") == line_str:
1737                     raise CannotSplit("Split function returned an unchanged result")
1738
1739                 result.extend(
1740                     split_line(l, line_length=line_length, inner=True, py36=py36)
1741                 )
1742         except CannotSplit as cs:
1743             continue
1744
1745         else:
1746             yield from result
1747             break
1748
1749     else:
1750         yield line
1751
1752
1753 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1754     """Split line into many lines, starting with the first matching bracket pair.
1755
1756     Note: this usually looks weird, only use this for function definitions.
1757     Prefer RHS otherwise.
1758     """
1759     head = Line(depth=line.depth)
1760     body = Line(depth=line.depth + 1, inside_brackets=True)
1761     tail = Line(depth=line.depth)
1762     tail_leaves: List[Leaf] = []
1763     body_leaves: List[Leaf] = []
1764     head_leaves: List[Leaf] = []
1765     current_leaves = head_leaves
1766     matching_bracket = None
1767     for leaf in line.leaves:
1768         if (
1769             current_leaves is body_leaves
1770             and leaf.type in CLOSING_BRACKETS
1771             and leaf.opening_bracket is matching_bracket
1772         ):
1773             current_leaves = tail_leaves if body_leaves else head_leaves
1774         current_leaves.append(leaf)
1775         if current_leaves is head_leaves:
1776             if leaf.type in OPENING_BRACKETS:
1777                 matching_bracket = leaf
1778                 current_leaves = body_leaves
1779     # Since body is a new indent level, remove spurious leading whitespace.
1780     if body_leaves:
1781         normalize_prefix(body_leaves[0], inside_brackets=True)
1782     # Build the new lines.
1783     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1784         for leaf in leaves:
1785             result.append(leaf, preformatted=True)
1786             for comment_after in line.comments_after(leaf):
1787                 result.append(comment_after, preformatted=True)
1788     bracket_split_succeeded_or_raise(head, body, tail)
1789     for result in (head, body, tail):
1790         if result:
1791             yield result
1792
1793
1794 def right_hand_split(
1795     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1796 ) -> Iterator[Line]:
1797     """Split line into many lines, starting with the last matching bracket pair."""
1798     head = Line(depth=line.depth)
1799     body = Line(depth=line.depth + 1, inside_brackets=True)
1800     tail = Line(depth=line.depth)
1801     tail_leaves: List[Leaf] = []
1802     body_leaves: List[Leaf] = []
1803     head_leaves: List[Leaf] = []
1804     current_leaves = tail_leaves
1805     opening_bracket = None
1806     closing_bracket = None
1807     for leaf in reversed(line.leaves):
1808         if current_leaves is body_leaves:
1809             if leaf is opening_bracket:
1810                 current_leaves = head_leaves if body_leaves else tail_leaves
1811         current_leaves.append(leaf)
1812         if current_leaves is tail_leaves:
1813             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1814                 opening_bracket = leaf.opening_bracket
1815                 closing_bracket = leaf
1816                 current_leaves = body_leaves
1817     tail_leaves.reverse()
1818     body_leaves.reverse()
1819     head_leaves.reverse()
1820     # Since body is a new indent level, remove spurious leading whitespace.
1821     if body_leaves:
1822         normalize_prefix(body_leaves[0], inside_brackets=True)
1823     elif not head_leaves:
1824         # No `head` and no `body` means the split failed. `tail` has all content.
1825         raise CannotSplit("No brackets found")
1826
1827     # Build the new lines.
1828     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1829         for leaf in leaves:
1830             result.append(leaf, preformatted=True)
1831             for comment_after in line.comments_after(leaf):
1832                 result.append(comment_after, preformatted=True)
1833     bracket_split_succeeded_or_raise(head, body, tail)
1834     assert opening_bracket and closing_bracket
1835     if (
1836         opening_bracket.type == token.LPAR
1837         and not opening_bracket.value
1838         and closing_bracket.type == token.RPAR
1839         and not closing_bracket.value
1840     ):
1841         # These parens were optional. If there aren't any delimiters or standalone
1842         # comments in the body, they were unnecessary and another split without
1843         # them should be attempted.
1844         if not (
1845             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1846         ):
1847             omit = {id(closing_bracket), *omit}
1848             yield from right_hand_split(line, py36=py36, omit=omit)
1849             return
1850
1851     ensure_visible(opening_bracket)
1852     ensure_visible(closing_bracket)
1853     for result in (head, body, tail):
1854         if result:
1855             yield result
1856
1857
1858 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1859     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1860
1861     Do nothing otherwise.
1862
1863     A left- or right-hand split is based on a pair of brackets. Content before
1864     (and including) the opening bracket is left on one line, content inside the
1865     brackets is put on a separate line, and finally content starting with and
1866     following the closing bracket is put on a separate line.
1867
1868     Those are called `head`, `body`, and `tail`, respectively. If the split
1869     produced the same line (all content in `head`) or ended up with an empty `body`
1870     and the `tail` is just the closing bracket, then it's considered failed.
1871     """
1872     tail_len = len(str(tail).strip())
1873     if not body:
1874         if tail_len == 0:
1875             raise CannotSplit("Splitting brackets produced the same line")
1876
1877         elif tail_len < 3:
1878             raise CannotSplit(
1879                 f"Splitting brackets on an empty body to save "
1880                 f"{tail_len} characters is not worth it"
1881             )
1882
1883
1884 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1885     """Normalize prefix of the first leaf in every line returned by `split_func`.
1886
1887     This is a decorator over relevant split functions.
1888     """
1889
1890     @wraps(split_func)
1891     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1892         for l in split_func(line, py36):
1893             normalize_prefix(l.leaves[0], inside_brackets=True)
1894             yield l
1895
1896     return split_wrapper
1897
1898
1899 @dont_increase_indentation
1900 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1901     """Split according to delimiters of the highest priority.
1902
1903     If `py36` is True, the split will add trailing commas also in function
1904     signatures that contain `*` and `**`.
1905     """
1906     try:
1907         last_leaf = line.leaves[-1]
1908     except IndexError:
1909         raise CannotSplit("Line empty")
1910
1911     delimiters = line.bracket_tracker.delimiters
1912     try:
1913         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1914             exclude={id(last_leaf)}
1915         )
1916     except ValueError:
1917         raise CannotSplit("No delimiters found")
1918
1919     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1920     lowest_depth = sys.maxsize
1921     trailing_comma_safe = True
1922
1923     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1924         """Append `leaf` to current line or to new line if appending impossible."""
1925         nonlocal current_line
1926         try:
1927             current_line.append_safe(leaf, preformatted=True)
1928         except ValueError as ve:
1929             yield current_line
1930
1931             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1932             current_line.append(leaf)
1933
1934     for leaf in line.leaves:
1935         yield from append_to_line(leaf)
1936
1937         for comment_after in line.comments_after(leaf):
1938             yield from append_to_line(comment_after)
1939
1940         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1941         if (
1942             leaf.bracket_depth == lowest_depth
1943             and is_vararg(leaf, within=VARARGS_PARENTS)
1944         ):
1945             trailing_comma_safe = trailing_comma_safe and py36
1946         leaf_priority = delimiters.get(id(leaf))
1947         if leaf_priority == delimiter_priority:
1948             yield current_line
1949
1950             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1951     if current_line:
1952         if (
1953             trailing_comma_safe
1954             and delimiter_priority == COMMA_PRIORITY
1955             and current_line.leaves[-1].type != token.COMMA
1956             and current_line.leaves[-1].type != STANDALONE_COMMENT
1957         ):
1958             current_line.append(Leaf(token.COMMA, ","))
1959         yield current_line
1960
1961
1962 @dont_increase_indentation
1963 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1964     """Split standalone comments from the rest of the line."""
1965     if not line.contains_standalone_comments(0):
1966         raise CannotSplit("Line does not have any standalone comments")
1967
1968     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1969
1970     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1971         """Append `leaf` to current line or to new line if appending impossible."""
1972         nonlocal current_line
1973         try:
1974             current_line.append_safe(leaf, preformatted=True)
1975         except ValueError as ve:
1976             yield current_line
1977
1978             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1979             current_line.append(leaf)
1980
1981     for leaf in line.leaves:
1982         yield from append_to_line(leaf)
1983
1984         for comment_after in line.comments_after(leaf):
1985             yield from append_to_line(comment_after)
1986
1987     if current_line:
1988         yield current_line
1989
1990
1991 def is_import(leaf: Leaf) -> bool:
1992     """Return True if the given leaf starts an import statement."""
1993     p = leaf.parent
1994     t = leaf.type
1995     v = leaf.value
1996     return bool(
1997         t == token.NAME
1998         and (
1999             (v == "import" and p and p.type == syms.import_name)
2000             or (v == "from" and p and p.type == syms.import_from)
2001         )
2002     )
2003
2004
2005 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2006     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2007     else.
2008
2009     Note: don't use backslashes for formatting or you'll lose your voting rights.
2010     """
2011     if not inside_brackets:
2012         spl = leaf.prefix.split("#")
2013         if "\\" not in spl[0]:
2014             nl_count = spl[-1].count("\n")
2015             if len(spl) > 1:
2016                 nl_count -= 1
2017             leaf.prefix = "\n" * nl_count
2018             return
2019
2020     leaf.prefix = ""
2021
2022
2023 def normalize_string_quotes(leaf: Leaf) -> None:
2024     """Prefer double quotes but only if it doesn't cause more escaping.
2025
2026     Adds or removes backslashes as appropriate. Doesn't parse and fix
2027     strings nested in f-strings (yet).
2028
2029     Note: Mutates its argument.
2030     """
2031     value = leaf.value.lstrip("furbFURB")
2032     if value[:3] == '"""':
2033         return
2034
2035     elif value[:3] == "'''":
2036         orig_quote = "'''"
2037         new_quote = '"""'
2038     elif value[0] == '"':
2039         orig_quote = '"'
2040         new_quote = "'"
2041     else:
2042         orig_quote = "'"
2043         new_quote = '"'
2044     first_quote_pos = leaf.value.find(orig_quote)
2045     if first_quote_pos == -1:
2046         return  # There's an internal error
2047
2048     prefix = leaf.value[:first_quote_pos]
2049     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2050     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2051     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2052     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
2053     if "r" in prefix.casefold():
2054         if unescaped_new_quote.search(body):
2055             # There's at least one unescaped new_quote in this raw string
2056             # so converting is impossible
2057             return
2058
2059         # Do not introduce or remove backslashes in raw strings
2060         new_body = body
2061     else:
2062         # remove unnecessary quotes
2063         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2064         if body != new_body:
2065             # Consider the string without unnecessary quotes as the original
2066             body = new_body
2067             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2068         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2069         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2070     if new_quote == '"""' and new_body[-1] == '"':
2071         # edge case:
2072         new_body = new_body[:-1] + '\\"'
2073     orig_escape_count = body.count("\\")
2074     new_escape_count = new_body.count("\\")
2075     if new_escape_count > orig_escape_count:
2076         return  # Do not introduce more escaping
2077
2078     if new_escape_count == orig_escape_count and orig_quote == '"':
2079         return  # Prefer double quotes
2080
2081     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2082
2083
2084 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2085     """Make existing optional parentheses invisible or create new ones.
2086
2087     Standardizes on visible parentheses for single-element tuples, and keeps
2088     existing visible parentheses for other tuples and generator expressions.
2089     """
2090     check_lpar = False
2091     for child in list(node.children):
2092         if check_lpar:
2093             if child.type == syms.atom:
2094                 if not (
2095                     is_empty_tuple(child)
2096                     or is_one_tuple(child)
2097                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2098                 ):
2099                     first = child.children[0]
2100                     last = child.children[-1]
2101                     if first.type == token.LPAR and last.type == token.RPAR:
2102                         # make parentheses invisible
2103                         first.value = ""  # type: ignore
2104                         last.value = ""  # type: ignore
2105             elif is_one_tuple(child):
2106                 # wrap child in visible parentheses
2107                 lpar = Leaf(token.LPAR, "(")
2108                 rpar = Leaf(token.RPAR, ")")
2109                 index = child.remove() or 0
2110                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2111             else:
2112                 # wrap child in invisible parentheses
2113                 lpar = Leaf(token.LPAR, "")
2114                 rpar = Leaf(token.RPAR, "")
2115                 index = child.remove() or 0
2116                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2117
2118         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2119
2120
2121 def is_empty_tuple(node: LN) -> bool:
2122     """Return True if `node` holds an empty tuple."""
2123     return (
2124         node.type == syms.atom
2125         and len(node.children) == 2
2126         and node.children[0].type == token.LPAR
2127         and node.children[1].type == token.RPAR
2128     )
2129
2130
2131 def is_one_tuple(node: LN) -> bool:
2132     """Return True if `node` holds a tuple with one element, with or without parens."""
2133     if node.type == syms.atom:
2134         if len(node.children) != 3:
2135             return False
2136
2137         lpar, gexp, rpar = node.children
2138         if not (
2139             lpar.type == token.LPAR
2140             and gexp.type == syms.testlist_gexp
2141             and rpar.type == token.RPAR
2142         ):
2143             return False
2144
2145         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2146
2147     return (
2148         node.type in IMPLICIT_TUPLE
2149         and len(node.children) == 2
2150         and node.children[1].type == token.COMMA
2151     )
2152
2153
2154 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2155     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2156
2157     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2158     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2159     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2160     generalizations (PEP 448).
2161     """
2162     if leaf.type not in STARS or not leaf.parent:
2163         return False
2164
2165     p = leaf.parent
2166     if p.type == syms.star_expr:
2167         # Star expressions are also used as assignment targets in extended
2168         # iterable unpacking (PEP 3132).  See what its parent is instead.
2169         if not p.parent:
2170             return False
2171
2172         p = p.parent
2173
2174     return p.type in within
2175
2176
2177 def max_delimiter_priority_in_atom(node: LN) -> int:
2178     """Return maximum delimiter priority inside `node`.
2179
2180     This is specific to atoms with contents contained in a pair of parentheses.
2181     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2182     """
2183     if node.type != syms.atom:
2184         return 0
2185
2186     first = node.children[0]
2187     last = node.children[-1]
2188     if not (first.type == token.LPAR and last.type == token.RPAR):
2189         return 0
2190
2191     bt = BracketTracker()
2192     for c in node.children[1:-1]:
2193         if isinstance(c, Leaf):
2194             bt.mark(c)
2195         else:
2196             for leaf in c.leaves():
2197                 bt.mark(leaf)
2198     try:
2199         return bt.max_delimiter_priority()
2200
2201     except ValueError:
2202         return 0
2203
2204
2205 def ensure_visible(leaf: Leaf) -> None:
2206     """Make sure parentheses are visible.
2207
2208     They could be invisible as part of some statements (see
2209     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2210     """
2211     if leaf.type == token.LPAR:
2212         leaf.value = "("
2213     elif leaf.type == token.RPAR:
2214         leaf.value = ")"
2215
2216
2217 def is_python36(node: Node) -> bool:
2218     """Return True if the current file is using Python 3.6+ features.
2219
2220     Currently looking for:
2221     - f-strings; and
2222     - trailing commas after * or ** in function signatures.
2223     """
2224     for n in node.pre_order():
2225         if n.type == token.STRING:
2226             value_head = n.value[:2]  # type: ignore
2227             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2228                 return True
2229
2230         elif (
2231             n.type == syms.typedargslist
2232             and n.children
2233             and n.children[-1].type == token.COMMA
2234         ):
2235             for ch in n.children:
2236                 if ch.type in STARS:
2237                     return True
2238
2239     return False
2240
2241
2242 PYTHON_EXTENSIONS = {".py"}
2243 BLACKLISTED_DIRECTORIES = {
2244     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2245 }
2246
2247
2248 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2249     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2250     and have one of the PYTHON_EXTENSIONS.
2251     """
2252     for child in path.iterdir():
2253         if child.is_dir():
2254             if child.name in BLACKLISTED_DIRECTORIES:
2255                 continue
2256
2257             yield from gen_python_files_in_dir(child)
2258
2259         elif child.suffix in PYTHON_EXTENSIONS:
2260             yield child
2261
2262
2263 @dataclass
2264 class Report:
2265     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2266     check: bool = False
2267     quiet: bool = False
2268     change_count: int = 0
2269     same_count: int = 0
2270     failure_count: int = 0
2271
2272     def done(self, src: Path, changed: Changed) -> None:
2273         """Increment the counter for successful reformatting. Write out a message."""
2274         if changed is Changed.YES:
2275             reformatted = "would reformat" if self.check else "reformatted"
2276             if not self.quiet:
2277                 out(f"{reformatted} {src}")
2278             self.change_count += 1
2279         else:
2280             if not self.quiet:
2281                 if changed is Changed.NO:
2282                     msg = f"{src} already well formatted, good job."
2283                 else:
2284                     msg = f"{src} wasn't modified on disk since last run."
2285                 out(msg, bold=False)
2286             self.same_count += 1
2287
2288     def failed(self, src: Path, message: str) -> None:
2289         """Increment the counter for failed reformatting. Write out a message."""
2290         err(f"error: cannot format {src}: {message}")
2291         self.failure_count += 1
2292
2293     @property
2294     def return_code(self) -> int:
2295         """Return the exit code that the app should use.
2296
2297         This considers the current state of changed files and failures:
2298         - if there were any failures, return 123;
2299         - if any files were changed and --check is being used, return 1;
2300         - otherwise return 0.
2301         """
2302         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2303         # 126 we have special returncodes reserved by the shell.
2304         if self.failure_count:
2305             return 123
2306
2307         elif self.change_count and self.check:
2308             return 1
2309
2310         return 0
2311
2312     def __str__(self) -> str:
2313         """Render a color report of the current state.
2314
2315         Use `click.unstyle` to remove colors.
2316         """
2317         if self.check:
2318             reformatted = "would be reformatted"
2319             unchanged = "would be left unchanged"
2320             failed = "would fail to reformat"
2321         else:
2322             reformatted = "reformatted"
2323             unchanged = "left unchanged"
2324             failed = "failed to reformat"
2325         report = []
2326         if self.change_count:
2327             s = "s" if self.change_count > 1 else ""
2328             report.append(
2329                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2330             )
2331         if self.same_count:
2332             s = "s" if self.same_count > 1 else ""
2333             report.append(f"{self.same_count} file{s} {unchanged}")
2334         if self.failure_count:
2335             s = "s" if self.failure_count > 1 else ""
2336             report.append(
2337                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2338             )
2339         return ", ".join(report) + "."
2340
2341
2342 def assert_equivalent(src: str, dst: str) -> None:
2343     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2344
2345     import ast
2346     import traceback
2347
2348     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2349         """Simple visitor generating strings to compare ASTs by content."""
2350         yield f"{'  ' * depth}{node.__class__.__name__}("
2351
2352         for field in sorted(node._fields):
2353             try:
2354                 value = getattr(node, field)
2355             except AttributeError:
2356                 continue
2357
2358             yield f"{'  ' * (depth+1)}{field}="
2359
2360             if isinstance(value, list):
2361                 for item in value:
2362                     if isinstance(item, ast.AST):
2363                         yield from _v(item, depth + 2)
2364
2365             elif isinstance(value, ast.AST):
2366                 yield from _v(value, depth + 2)
2367
2368             else:
2369                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2370
2371         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2372
2373     try:
2374         src_ast = ast.parse(src)
2375     except Exception as exc:
2376         major, minor = sys.version_info[:2]
2377         raise AssertionError(
2378             f"cannot use --safe with this file; failed to parse source file "
2379             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2380             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2381         )
2382
2383     try:
2384         dst_ast = ast.parse(dst)
2385     except Exception as exc:
2386         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2387         raise AssertionError(
2388             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2389             f"Please report a bug on https://github.com/ambv/black/issues.  "
2390             f"This invalid output might be helpful: {log}"
2391         ) from None
2392
2393     src_ast_str = "\n".join(_v(src_ast))
2394     dst_ast_str = "\n".join(_v(dst_ast))
2395     if src_ast_str != dst_ast_str:
2396         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2397         raise AssertionError(
2398             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2399             f"the source.  "
2400             f"Please report a bug on https://github.com/ambv/black/issues.  "
2401             f"This diff might be helpful: {log}"
2402         ) from None
2403
2404
2405 def assert_stable(src: str, dst: str, line_length: int) -> None:
2406     """Raise AssertionError if `dst` reformats differently the second time."""
2407     newdst = format_str(dst, line_length=line_length)
2408     if dst != newdst:
2409         log = dump_to_file(
2410             diff(src, dst, "source", "first pass"),
2411             diff(dst, newdst, "first pass", "second pass"),
2412         )
2413         raise AssertionError(
2414             f"INTERNAL ERROR: Black produced different code on the second pass "
2415             f"of the formatter.  "
2416             f"Please report a bug on https://github.com/ambv/black/issues.  "
2417             f"This diff might be helpful: {log}"
2418         ) from None
2419
2420
2421 def dump_to_file(*output: str) -> str:
2422     """Dump `output` to a temporary file. Return path to the file."""
2423     import tempfile
2424
2425     with tempfile.NamedTemporaryFile(
2426         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2427     ) as f:
2428         for lines in output:
2429             f.write(lines)
2430             if lines and lines[-1] != "\n":
2431                 f.write("\n")
2432     return f.name
2433
2434
2435 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2436     """Return a unified diff string between strings `a` and `b`."""
2437     import difflib
2438
2439     a_lines = [line + "\n" for line in a.split("\n")]
2440     b_lines = [line + "\n" for line in b.split("\n")]
2441     return "".join(
2442         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2443     )
2444
2445
2446 def cancel(tasks: List[asyncio.Task]) -> None:
2447     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2448     err("Aborted!")
2449     for task in tasks:
2450         task.cancel()
2451
2452
2453 def shutdown(loop: BaseEventLoop) -> None:
2454     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2455     try:
2456         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2457         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2458         if not to_cancel:
2459             return
2460
2461         for task in to_cancel:
2462             task.cancel()
2463         loop.run_until_complete(
2464             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2465         )
2466     finally:
2467         # `concurrent.futures.Future` objects cannot be cancelled once they
2468         # are already running. There might be some when the `shutdown()` happened.
2469         # Silence their logger's spew about the event loop being closed.
2470         cf_logger = logging.getLogger("concurrent.futures")
2471         cf_logger.setLevel(logging.CRITICAL)
2472         loop.close()
2473
2474
2475 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2476     """Replace `regex` with `replacement` twice on `original`.
2477
2478     This is used by string normalization to perform replaces on
2479     overlapping matches.
2480     """
2481     return regex.sub(replacement, regex.sub(replacement, original))
2482
2483
2484 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2485
2486
2487 def get_cache_file(line_length: int) -> Path:
2488     return CACHE_DIR / f"cache.{line_length}.pickle"
2489
2490
2491 def read_cache(line_length: int) -> Cache:
2492     """Read the cache if it exists and is well formed.
2493
2494     If it is not well formed, the call to write_cache later should resolve the issue.
2495     """
2496     cache_file = get_cache_file(line_length)
2497     if not cache_file.exists():
2498         return {}
2499
2500     with cache_file.open("rb") as fobj:
2501         try:
2502             cache: Cache = pickle.load(fobj)
2503         except pickle.UnpicklingError:
2504             return {}
2505
2506     return cache
2507
2508
2509 def get_cache_info(path: Path) -> CacheInfo:
2510     """Return the information used to check if a file is already formatted or not."""
2511     stat = path.stat()
2512     return stat.st_mtime, stat.st_size
2513
2514
2515 def filter_cached(
2516     cache: Cache, sources: Iterable[Path]
2517 ) -> Tuple[List[Path], List[Path]]:
2518     """Split a list of paths into two.
2519
2520     The first list contains paths of files that modified on disk or are not in the
2521     cache. The other list contains paths to non-modified files.
2522     """
2523     todo, done = [], []
2524     for src in sources:
2525         src = src.resolve()
2526         if cache.get(src) != get_cache_info(src):
2527             todo.append(src)
2528         else:
2529             done.append(src)
2530     return todo, done
2531
2532
2533 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2534     """Update the cache file."""
2535     cache_file = get_cache_file(line_length)
2536     try:
2537         if not CACHE_DIR.exists():
2538             CACHE_DIR.mkdir(parents=True)
2539         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2540         with cache_file.open("wb") as fobj:
2541             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2542     except OSError:
2543         pass
2544
2545
2546 if __name__ == "__main__":
2547     main()