]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

95489f3c0c990a650c8efdc4aacb43159d7f8553
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import pickle
5 from asyncio.base_events import BaseEventLoop
6 from concurrent.futures import Executor, ProcessPoolExecutor
7 from enum import Enum
8 from functools import partial, wraps
9 import keyword
10 import logging
11 from multiprocessing import Manager
12 import os
13 from pathlib import Path
14 import re
15 import tokenize
16 import signal
17 import sys
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46 __version__ = "18.4a2"
47 DEFAULT_LINE_LENGTH = 88
48 # types
49 syms = pygram.python_symbols
50 FileContent = str
51 Encoding = str
52 Depth = int
53 NodeType = int
54 LeafID = int
55 Priority = int
56 Index = int
57 LN = Union[Leaf, Node]
58 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
59 Timestamp = float
60 FileSize = int
61 CacheInfo = Tuple[Timestamp, FileSize]
62 Cache = Dict[Path, CacheInfo]
63 out = partial(click.secho, bold=True, err=True)
64 err = partial(click.secho, fg="red", err=True)
65
66
67 class NothingChanged(UserWarning):
68     """Raised by :func:`format_file` when reformatted code is the same as source."""
69
70
71 class CannotSplit(Exception):
72     """A readable split that fits the allotted line length is impossible.
73
74     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
75     :func:`delimiter_split`.
76     """
77
78
79 class FormatError(Exception):
80     """Base exception for `# fmt: on` and `# fmt: off` handling.
81
82     It holds the number of bytes of the prefix consumed before the format
83     control comment appeared.
84     """
85
86     def __init__(self, consumed: int) -> None:
87         super().__init__(consumed)
88         self.consumed = consumed
89
90     def trim_prefix(self, leaf: Leaf) -> None:
91         leaf.prefix = leaf.prefix[self.consumed:]
92
93     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
94         """Returns a new Leaf from the consumed part of the prefix."""
95         unformatted_prefix = leaf.prefix[:self.consumed]
96         return Leaf(token.NEWLINE, unformatted_prefix)
97
98
99 class FormatOn(FormatError):
100     """Found a comment like `# fmt: on` in the file."""
101
102
103 class FormatOff(FormatError):
104     """Found a comment like `# fmt: off` in the file."""
105
106
107 class WriteBack(Enum):
108     NO = 0
109     YES = 1
110     DIFF = 2
111
112
113 class Changed(Enum):
114     NO = 0
115     CACHED = 1
116     YES = 2
117
118
119 @click.command()
120 @click.option(
121     "-l",
122     "--line-length",
123     type=int,
124     default=DEFAULT_LINE_LENGTH,
125     help="How many character per line to allow.",
126     show_default=True,
127 )
128 @click.option(
129     "--check",
130     is_flag=True,
131     help=(
132         "Don't write the files back, just return the status.  Return code 0 "
133         "means nothing would change.  Return code 1 means some files would be "
134         "reformatted.  Return code 123 means there was an internal error."
135     ),
136 )
137 @click.option(
138     "--diff",
139     is_flag=True,
140     help="Don't write the files back, just output a diff for each file on stdout.",
141 )
142 @click.option(
143     "--fast/--safe",
144     is_flag=True,
145     help="If --fast given, skip temporary sanity checks. [default: --safe]",
146 )
147 @click.option(
148     "-q",
149     "--quiet",
150     is_flag=True,
151     help=(
152         "Don't emit non-error messages to stderr. Errors are still emitted, "
153         "silence those with 2>/dev/null."
154     ),
155 )
156 @click.version_option(version=__version__)
157 @click.argument(
158     "src",
159     nargs=-1,
160     type=click.Path(
161         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
162     ),
163 )
164 @click.pass_context
165 def main(
166     ctx: click.Context,
167     line_length: int,
168     check: bool,
169     diff: bool,
170     fast: bool,
171     quiet: bool,
172     src: List[str],
173 ) -> None:
174     """The uncompromising code formatter."""
175     sources: List[Path] = []
176     for s in src:
177         p = Path(s)
178         if p.is_dir():
179             sources.extend(gen_python_files_in_dir(p))
180         elif p.is_file():
181             # if a file was explicitly given, we don't care about its extension
182             sources.append(p)
183         elif s == "-":
184             sources.append(Path("-"))
185         else:
186             err(f"invalid path: {s}")
187
188     if check and not diff:
189         write_back = WriteBack.NO
190     elif diff:
191         write_back = WriteBack.DIFF
192     else:
193         write_back = WriteBack.YES
194     report = Report(check=check, quiet=quiet)
195     if len(sources) == 0:
196         ctx.exit(0)
197         return
198
199     elif len(sources) == 1:
200         reformat_one(sources[0], line_length, fast, write_back, report)
201     else:
202         loop = asyncio.get_event_loop()
203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
204         try:
205             loop.run_until_complete(
206                 schedule_formatting(
207                     sources, line_length, fast, write_back, report, loop, executor
208                 )
209             )
210         finally:
211             shutdown(loop)
212         if not quiet:
213             out("All done! ✨ 🍰 ✨")
214             click.echo(str(report))
215     ctx.exit(report.return_code)
216
217
218 def reformat_one(
219     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
220 ) -> None:
221     """Reformat a single file under `src` without spawning child processes.
222
223     If `quiet` is True, non-error messages are not output. `line_length`,
224     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
225     """
226     try:
227         changed = Changed.NO
228         if not src.is_file() and str(src) == "-":
229             if format_stdin_to_stdout(
230                 line_length=line_length, fast=fast, write_back=write_back
231             ):
232                 changed = Changed.YES
233         else:
234             cache: Cache = {}
235             if write_back != WriteBack.DIFF:
236                 cache = read_cache(line_length)
237                 src = src.resolve()
238                 if src in cache and cache[src] == get_cache_info(src):
239                     changed = Changed.CACHED
240             if (
241                 changed is not Changed.CACHED
242                 and format_file_in_place(
243                     src, line_length=line_length, fast=fast, write_back=write_back
244                 )
245             ):
246                 changed = Changed.YES
247             if write_back != WriteBack.DIFF and changed is not Changed.NO:
248                 write_cache(cache, [src], line_length)
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache(line_length)
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             src: loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             )
289             for src in sources
290         }
291         _task_values = list(tasks.values())
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
294             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     if write_back != WriteBack.DIFF and formatted:
315         write_cache(cache, formatted, line_length)
316
317
318 def format_file_in_place(
319     src: Path,
320     line_length: int,
321     fast: bool,
322     write_back: WriteBack = WriteBack.NO,
323     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
324 ) -> bool:
325     """Format file under `src` path. Return True if changed.
326
327     If `write_back` is True, write reformatted code back to stdout.
328     `line_length` and `fast` options are passed to :func:`format_file_contents`.
329     """
330
331     with tokenize.open(src) as src_buffer:
332         src_contents = src_buffer.read()
333     try:
334         dst_contents = format_file_contents(
335             src_contents, line_length=line_length, fast=fast
336         )
337     except NothingChanged:
338         return False
339
340     if write_back == write_back.YES:
341         with open(src, "w", encoding=src_buffer.encoding) as f:
342             f.write(dst_contents)
343     elif write_back == write_back.DIFF:
344         src_name = f"{src}  (original)"
345         dst_name = f"{src}  (formatted)"
346         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347         if lock:
348             lock.acquire()
349         try:
350             sys.stdout.write(diff_contents)
351         finally:
352             if lock:
353                 lock.release()
354     return True
355
356
357 def format_stdin_to_stdout(
358     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
359 ) -> bool:
360     """Format file on stdin. Return True if changed.
361
362     If `write_back` is True, write reformatted code back to stdout.
363     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
364     """
365     src = sys.stdin.read()
366     dst = src
367     try:
368         dst = format_file_contents(src, line_length=line_length, fast=fast)
369         return True
370
371     except NothingChanged:
372         return False
373
374     finally:
375         if write_back == WriteBack.YES:
376             sys.stdout.write(dst)
377         elif write_back == WriteBack.DIFF:
378             src_name = "<stdin>  (original)"
379             dst_name = "<stdin>  (formatted)"
380             sys.stdout.write(diff(src, dst, src_name, dst_name))
381
382
383 def format_file_contents(
384     src_contents: str, line_length: int, fast: bool
385 ) -> FileContent:
386     """Reformat contents a file and return new contents.
387
388     If `fast` is False, additionally confirm that the reformatted code is
389     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
390     `line_length` is passed to :func:`format_str`.
391     """
392     if src_contents.strip() == "":
393         raise NothingChanged
394
395     dst_contents = format_str(src_contents, line_length=line_length)
396     if src_contents == dst_contents:
397         raise NothingChanged
398
399     if not fast:
400         assert_equivalent(src_contents, dst_contents)
401         assert_stable(src_contents, dst_contents, line_length=line_length)
402     return dst_contents
403
404
405 def format_str(src_contents: str, line_length: int) -> FileContent:
406     """Reformat a string and return new contents.
407
408     `line_length` determines how many characters per line are allowed.
409     """
410     src_node = lib2to3_parse(src_contents)
411     dst_contents = ""
412     lines = LineGenerator()
413     elt = EmptyLineTracker()
414     py36 = is_python36(src_node)
415     empty_line = Line()
416     after = 0
417     for current_line in lines.visit(src_node):
418         for _ in range(after):
419             dst_contents += str(empty_line)
420         before, after = elt.maybe_empty_lines(current_line)
421         for _ in range(before):
422             dst_contents += str(empty_line)
423         for line in split_line(current_line, line_length=line_length, py36=py36):
424             dst_contents += str(line)
425     return dst_contents
426
427
428 GRAMMARS = [
429     pygram.python_grammar_no_print_statement_no_exec_statement,
430     pygram.python_grammar_no_print_statement,
431     pygram.python_grammar,
432 ]
433
434
435 def lib2to3_parse(src_txt: str) -> Node:
436     """Given a string with source, return the lib2to3 Node."""
437     grammar = pygram.python_grammar_no_print_statement
438     if src_txt[-1] != "\n":
439         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
440         src_txt += nl
441     for grammar in GRAMMARS:
442         drv = driver.Driver(grammar, pytree.convert)
443         try:
444             result = drv.parse_string(src_txt, True)
445             break
446
447         except ParseError as pe:
448             lineno, column = pe.context[1]
449             lines = src_txt.splitlines()
450             try:
451                 faulty_line = lines[lineno - 1]
452             except IndexError:
453                 faulty_line = "<line number missing in source>"
454             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
455     else:
456         raise exc from None
457
458     if isinstance(result, Leaf):
459         result = Node(syms.file_input, [result])
460     return result
461
462
463 def lib2to3_unparse(node: Node) -> str:
464     """Given a lib2to3 node, return its string representation."""
465     code = str(node)
466     return code
467
468
469 T = TypeVar("T")
470
471
472 class Visitor(Generic[T]):
473     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
474
475     def visit(self, node: LN) -> Iterator[T]:
476         """Main method to visit `node` and its children.
477
478         It tries to find a `visit_*()` method for the given `node.type`, like
479         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
480         If no dedicated `visit_*()` method is found, chooses `visit_default()`
481         instead.
482
483         Then yields objects of type `T` from the selected visitor.
484         """
485         if node.type < 256:
486             name = token.tok_name[node.type]
487         else:
488             name = type_repr(node.type)
489         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
490
491     def visit_default(self, node: LN) -> Iterator[T]:
492         """Default `visit_*()` implementation. Recurses to children of `node`."""
493         if isinstance(node, Node):
494             for child in node.children:
495                 yield from self.visit(child)
496
497
498 @dataclass
499 class DebugVisitor(Visitor[T]):
500     tree_depth: int = 0
501
502     def visit_default(self, node: LN) -> Iterator[T]:
503         indent = " " * (2 * self.tree_depth)
504         if isinstance(node, Node):
505             _type = type_repr(node.type)
506             out(f"{indent}{_type}", fg="yellow")
507             self.tree_depth += 1
508             for child in node.children:
509                 yield from self.visit(child)
510
511             self.tree_depth -= 1
512             out(f"{indent}/{_type}", fg="yellow", bold=False)
513         else:
514             _type = token.tok_name.get(node.type, str(node.type))
515             out(f"{indent}{_type}", fg="blue", nl=False)
516             if node.prefix:
517                 # We don't have to handle prefixes for `Node` objects since
518                 # that delegates to the first child anyway.
519                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
520             out(f" {node.value!r}", fg="blue", bold=False)
521
522     @classmethod
523     def show(cls, code: str) -> None:
524         """Pretty-print the lib2to3 AST of a given string of `code`.
525
526         Convenience method for debugging.
527         """
528         v: DebugVisitor[None] = DebugVisitor()
529         list(v.visit(lib2to3_parse(code)))
530
531
532 KEYWORDS = set(keyword.kwlist)
533 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
534 FLOW_CONTROL = {"return", "raise", "break", "continue"}
535 STATEMENT = {
536     syms.if_stmt,
537     syms.while_stmt,
538     syms.for_stmt,
539     syms.try_stmt,
540     syms.except_clause,
541     syms.with_stmt,
542     syms.funcdef,
543     syms.classdef,
544 }
545 STANDALONE_COMMENT = 153
546 LOGIC_OPERATORS = {"and", "or"}
547 COMPARATORS = {
548     token.LESS,
549     token.GREATER,
550     token.EQEQUAL,
551     token.NOTEQUAL,
552     token.LESSEQUAL,
553     token.GREATEREQUAL,
554 }
555 MATH_OPERATORS = {
556     token.PLUS,
557     token.MINUS,
558     token.STAR,
559     token.SLASH,
560     token.VBAR,
561     token.AMPER,
562     token.PERCENT,
563     token.CIRCUMFLEX,
564     token.TILDE,
565     token.LEFTSHIFT,
566     token.RIGHTSHIFT,
567     token.DOUBLESTAR,
568     token.DOUBLESLASH,
569 }
570 STARS = {token.STAR, token.DOUBLESTAR}
571 VARARGS_PARENTS = {
572     syms.arglist,
573     syms.argument,  # double star in arglist
574     syms.trailer,  # single argument to call
575     syms.typedargslist,
576     syms.varargslist,  # lambdas
577 }
578 UNPACKING_PARENTS = {
579     syms.atom,  # single element of a list or set literal
580     syms.dictsetmaker,
581     syms.listmaker,
582     syms.testlist_gexp,
583 }
584 COMPREHENSION_PRIORITY = 20
585 COMMA_PRIORITY = 10
586 TERNARY_PRIORITY = 7
587 LOGIC_PRIORITY = 5
588 STRING_PRIORITY = 4
589 COMPARATOR_PRIORITY = 3
590 MATH_PRIORITY = 1
591
592
593 @dataclass
594 class BracketTracker:
595     """Keeps track of brackets on a line."""
596
597     depth: int = 0
598     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
599     delimiters: Dict[LeafID, Priority] = Factory(dict)
600     previous: Optional[Leaf] = None
601     _for_loop_variable: bool = False
602     _lambda_arguments: bool = False
603
604     def mark(self, leaf: Leaf) -> None:
605         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
606
607         All leaves receive an int `bracket_depth` field that stores how deep
608         within brackets a given leaf is. 0 means there are no enclosing brackets
609         that started on this line.
610
611         If a leaf is itself a closing bracket, it receives an `opening_bracket`
612         field that it forms a pair with. This is a one-directional link to
613         avoid reference cycles.
614
615         If a leaf is a delimiter (a token on which Black can split the line if
616         needed) and it's on depth 0, its `id()` is stored in the tracker's
617         `delimiters` field.
618         """
619         if leaf.type == token.COMMENT:
620             return
621
622         self.maybe_decrement_after_for_loop_variable(leaf)
623         self.maybe_decrement_after_lambda_arguments(leaf)
624         if leaf.type in CLOSING_BRACKETS:
625             self.depth -= 1
626             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
627             leaf.opening_bracket = opening_bracket
628         leaf.bracket_depth = self.depth
629         if self.depth == 0:
630             delim = is_split_before_delimiter(leaf, self.previous)
631             if delim and self.previous is not None:
632                 self.delimiters[id(self.previous)] = delim
633             else:
634                 delim = is_split_after_delimiter(leaf, self.previous)
635                 if delim:
636                     self.delimiters[id(leaf)] = delim
637         if leaf.type in OPENING_BRACKETS:
638             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
639             self.depth += 1
640         self.previous = leaf
641         self.maybe_increment_lambda_arguments(leaf)
642         self.maybe_increment_for_loop_variable(leaf)
643
644     def any_open_brackets(self) -> bool:
645         """Return True if there is an yet unmatched open bracket on the line."""
646         return bool(self.bracket_match)
647
648     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
649         """Return the highest priority of a delimiter found on the line.
650
651         Values are consistent with what `is_split_*_delimiter()` return.
652         Raises ValueError on no delimiters.
653         """
654         return max(v for k, v in self.delimiters.items() if k not in exclude)
655
656     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
657         """In a for loop, or comprehension, the variables are often unpacks.
658
659         To avoid splitting on the comma in this situation, increase the depth of
660         tokens between `for` and `in`.
661         """
662         if leaf.type == token.NAME and leaf.value == "for":
663             self.depth += 1
664             self._for_loop_variable = True
665             return True
666
667         return False
668
669     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
670         """See `maybe_increment_for_loop_variable` above for explanation."""
671         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
672             self.depth -= 1
673             self._for_loop_variable = False
674             return True
675
676         return False
677
678     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
679         """In a lambda expression, there might be more than one argument.
680
681         To avoid splitting on the comma in this situation, increase the depth of
682         tokens between `lambda` and `:`.
683         """
684         if leaf.type == token.NAME and leaf.value == "lambda":
685             self.depth += 1
686             self._lambda_arguments = True
687             return True
688
689         return False
690
691     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
692         """See `maybe_increment_lambda_arguments` above for explanation."""
693         if self._lambda_arguments and leaf.type == token.COLON:
694             self.depth -= 1
695             self._lambda_arguments = False
696             return True
697
698         return False
699
700
701 @dataclass
702 class Line:
703     """Holds leaves and comments. Can be printed with `str(line)`."""
704
705     depth: int = 0
706     leaves: List[Leaf] = Factory(list)
707     comments: List[Tuple[Index, Leaf]] = Factory(list)
708     bracket_tracker: BracketTracker = Factory(BracketTracker)
709     inside_brackets: bool = False
710
711     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
712         """Add a new `leaf` to the end of the line.
713
714         Unless `preformatted` is True, the `leaf` will receive a new consistent
715         whitespace prefix and metadata applied by :class:`BracketTracker`.
716         Trailing commas are maybe removed, unpacked for loop variables are
717         demoted from being delimiters.
718
719         Inline comments are put aside.
720         """
721         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
722         if not has_value:
723             return
724
725         if self.leaves and not preformatted:
726             # Note: at this point leaf.prefix should be empty except for
727             # imports, for which we only preserve newlines.
728             leaf.prefix += whitespace(leaf)
729         if self.inside_brackets or not preformatted:
730             self.bracket_tracker.mark(leaf)
731             self.maybe_remove_trailing_comma(leaf)
732
733         if not self.append_comment(leaf):
734             self.leaves.append(leaf)
735
736     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
737         """Like :func:`append()` but disallow invalid standalone comment structure.
738
739         Raises ValueError when any `leaf` is appended after a standalone comment
740         or when a standalone comment is not the first leaf on the line.
741         """
742         if self.bracket_tracker.depth == 0:
743             if self.is_comment:
744                 raise ValueError("cannot append to standalone comments")
745
746             if self.leaves and leaf.type == STANDALONE_COMMENT:
747                 raise ValueError(
748                     "cannot append standalone comments to a populated line"
749                 )
750
751         self.append(leaf, preformatted=preformatted)
752
753     @property
754     def is_comment(self) -> bool:
755         """Is this line a standalone comment?"""
756         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
757
758     @property
759     def is_decorator(self) -> bool:
760         """Is this line a decorator?"""
761         return bool(self) and self.leaves[0].type == token.AT
762
763     @property
764     def is_import(self) -> bool:
765         """Is this an import line?"""
766         return bool(self) and is_import(self.leaves[0])
767
768     @property
769     def is_class(self) -> bool:
770         """Is this line a class definition?"""
771         return (
772             bool(self)
773             and self.leaves[0].type == token.NAME
774             and self.leaves[0].value == "class"
775         )
776
777     @property
778     def is_def(self) -> bool:
779         """Is this a function definition? (Also returns True for async defs.)"""
780         try:
781             first_leaf = self.leaves[0]
782         except IndexError:
783             return False
784
785         try:
786             second_leaf: Optional[Leaf] = self.leaves[1]
787         except IndexError:
788             second_leaf = None
789         return (
790             (first_leaf.type == token.NAME and first_leaf.value == "def")
791             or (
792                 first_leaf.type == token.ASYNC
793                 and second_leaf is not None
794                 and second_leaf.type == token.NAME
795                 and second_leaf.value == "def"
796             )
797         )
798
799     @property
800     def is_flow_control(self) -> bool:
801         """Is this line a flow control statement?
802
803         Those are `return`, `raise`, `break`, and `continue`.
804         """
805         return (
806             bool(self)
807             and self.leaves[0].type == token.NAME
808             and self.leaves[0].value in FLOW_CONTROL
809         )
810
811     @property
812     def is_yield(self) -> bool:
813         """Is this line a yield statement?"""
814         return (
815             bool(self)
816             and self.leaves[0].type == token.NAME
817             and self.leaves[0].value == "yield"
818         )
819
820     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
821         """If so, needs to be split before emitting."""
822         for leaf in self.leaves:
823             if leaf.type == STANDALONE_COMMENT:
824                 if leaf.bracket_depth <= depth_limit:
825                     return True
826
827         return False
828
829     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
830         """Remove trailing comma if there is one and it's safe."""
831         if not (
832             self.leaves
833             and self.leaves[-1].type == token.COMMA
834             and closing.type in CLOSING_BRACKETS
835         ):
836             return False
837
838         if closing.type == token.RBRACE:
839             self.remove_trailing_comma()
840             return True
841
842         if closing.type == token.RSQB:
843             comma = self.leaves[-1]
844             if comma.parent and comma.parent.type == syms.listmaker:
845                 self.remove_trailing_comma()
846                 return True
847
848         # For parens let's check if it's safe to remove the comma.  If the
849         # trailing one is the only one, we might mistakenly change a tuple
850         # into a different type by removing the comma.
851         depth = closing.bracket_depth + 1
852         commas = 0
853         opening = closing.opening_bracket
854         for _opening_index, leaf in enumerate(self.leaves):
855             if leaf is opening:
856                 break
857
858         else:
859             return False
860
861         for leaf in self.leaves[_opening_index + 1:]:
862             if leaf is closing:
863                 break
864
865             bracket_depth = leaf.bracket_depth
866             if bracket_depth == depth and leaf.type == token.COMMA:
867                 commas += 1
868                 if leaf.parent and leaf.parent.type == syms.arglist:
869                     commas += 1
870                     break
871
872         if commas > 1:
873             self.remove_trailing_comma()
874             return True
875
876         return False
877
878     def append_comment(self, comment: Leaf) -> bool:
879         """Add an inline or standalone comment to the line."""
880         if (
881             comment.type == STANDALONE_COMMENT
882             and self.bracket_tracker.any_open_brackets()
883         ):
884             comment.prefix = ""
885             return False
886
887         if comment.type != token.COMMENT:
888             return False
889
890         after = len(self.leaves) - 1
891         if after == -1:
892             comment.type = STANDALONE_COMMENT
893             comment.prefix = ""
894             return False
895
896         else:
897             self.comments.append((after, comment))
898             return True
899
900     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
901         """Generate comments that should appear directly after `leaf`."""
902         for _leaf_index, _leaf in enumerate(self.leaves):
903             if leaf is _leaf:
904                 break
905
906         else:
907             return
908
909         for index, comment_after in self.comments:
910             if _leaf_index == index:
911                 yield comment_after
912
913     def remove_trailing_comma(self) -> None:
914         """Remove the trailing comma and moves the comments attached to it."""
915         comma_index = len(self.leaves) - 1
916         for i in range(len(self.comments)):
917             comment_index, comment = self.comments[i]
918             if comment_index == comma_index:
919                 self.comments[i] = (comma_index - 1, comment)
920         self.leaves.pop()
921
922     def __str__(self) -> str:
923         """Render the line."""
924         if not self:
925             return "\n"
926
927         indent = "    " * self.depth
928         leaves = iter(self.leaves)
929         first = next(leaves)
930         res = f"{first.prefix}{indent}{first.value}"
931         for leaf in leaves:
932             res += str(leaf)
933         for _, comment in self.comments:
934             res += str(comment)
935         return res + "\n"
936
937     def __bool__(self) -> bool:
938         """Return True if the line has leaves or comments."""
939         return bool(self.leaves or self.comments)
940
941
942 class UnformattedLines(Line):
943     """Just like :class:`Line` but stores lines which aren't reformatted."""
944
945     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
946         """Just add a new `leaf` to the end of the lines.
947
948         The `preformatted` argument is ignored.
949
950         Keeps track of indentation `depth`, which is useful when the user
951         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
952         """
953         try:
954             list(generate_comments(leaf))
955         except FormatOn as f_on:
956             self.leaves.append(f_on.leaf_from_consumed(leaf))
957             raise
958
959         self.leaves.append(leaf)
960         if leaf.type == token.INDENT:
961             self.depth += 1
962         elif leaf.type == token.DEDENT:
963             self.depth -= 1
964
965     def __str__(self) -> str:
966         """Render unformatted lines from leaves which were added with `append()`.
967
968         `depth` is not used for indentation in this case.
969         """
970         if not self:
971             return "\n"
972
973         res = ""
974         for leaf in self.leaves:
975             res += str(leaf)
976         return res
977
978     def append_comment(self, comment: Leaf) -> bool:
979         """Not implemented in this class. Raises `NotImplementedError`."""
980         raise NotImplementedError("Unformatted lines don't store comments separately.")
981
982     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
983         """Does nothing and returns False."""
984         return False
985
986     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
987         """Does nothing and returns False."""
988         return False
989
990
991 @dataclass
992 class EmptyLineTracker:
993     """Provides a stateful method that returns the number of potential extra
994     empty lines needed before and after the currently processed line.
995
996     Note: this tracker works on lines that haven't been split yet.  It assumes
997     the prefix of the first leaf consists of optional newlines.  Those newlines
998     are consumed by `maybe_empty_lines()` and included in the computation.
999     """
1000     previous_line: Optional[Line] = None
1001     previous_after: int = 0
1002     previous_defs: List[int] = Factory(list)
1003
1004     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1005         """Return the number of extra empty lines before and after the `current_line`.
1006
1007         This is for separating `def`, `async def` and `class` with extra empty
1008         lines (two on module-level), as well as providing an extra empty line
1009         after flow control keywords to make them more prominent.
1010         """
1011         if isinstance(current_line, UnformattedLines):
1012             return 0, 0
1013
1014         before, after = self._maybe_empty_lines(current_line)
1015         before -= self.previous_after
1016         self.previous_after = after
1017         self.previous_line = current_line
1018         return before, after
1019
1020     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1021         max_allowed = 1
1022         if current_line.depth == 0:
1023             max_allowed = 2
1024         if current_line.leaves:
1025             # Consume the first leaf's extra newlines.
1026             first_leaf = current_line.leaves[0]
1027             before = first_leaf.prefix.count("\n")
1028             before = min(before, max_allowed)
1029             first_leaf.prefix = ""
1030         else:
1031             before = 0
1032         depth = current_line.depth
1033         while self.previous_defs and self.previous_defs[-1] >= depth:
1034             self.previous_defs.pop()
1035             before = 1 if depth else 2
1036         is_decorator = current_line.is_decorator
1037         if is_decorator or current_line.is_def or current_line.is_class:
1038             if not is_decorator:
1039                 self.previous_defs.append(depth)
1040             if self.previous_line is None:
1041                 # Don't insert empty lines before the first line in the file.
1042                 return 0, 0
1043
1044             if self.previous_line.is_decorator:
1045                 return 0, 0
1046
1047             if (
1048                 self.previous_line.is_comment
1049                 and self.previous_line.depth == current_line.depth
1050                 and before == 0
1051             ):
1052                 return 0, 0
1053
1054             newlines = 2
1055             if current_line.depth:
1056                 newlines -= 1
1057             return newlines, 0
1058
1059         if (
1060             self.previous_line
1061             and self.previous_line.is_import
1062             and not current_line.is_import
1063             and depth == self.previous_line.depth
1064         ):
1065             return (before or 1), 0
1066
1067         return before, 0
1068
1069
1070 @dataclass
1071 class LineGenerator(Visitor[Line]):
1072     """Generates reformatted Line objects.  Empty lines are not emitted.
1073
1074     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1075     in ways that will no longer stringify to valid Python code on the tree.
1076     """
1077     current_line: Line = Factory(Line)
1078
1079     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1080         """Generate a line.
1081
1082         If the line is empty, only emit if it makes sense.
1083         If the line is too long, split it first and then generate.
1084
1085         If any lines were generated, set up a new current_line.
1086         """
1087         if not self.current_line:
1088             if self.current_line.__class__ == type:
1089                 self.current_line.depth += indent
1090             else:
1091                 self.current_line = type(depth=self.current_line.depth + indent)
1092             return  # Line is empty, don't emit. Creating a new one unnecessary.
1093
1094         complete_line = self.current_line
1095         self.current_line = type(depth=complete_line.depth + indent)
1096         yield complete_line
1097
1098     def visit(self, node: LN) -> Iterator[Line]:
1099         """Main method to visit `node` and its children.
1100
1101         Yields :class:`Line` objects.
1102         """
1103         if isinstance(self.current_line, UnformattedLines):
1104             # File contained `# fmt: off`
1105             yield from self.visit_unformatted(node)
1106
1107         else:
1108             yield from super().visit(node)
1109
1110     def visit_default(self, node: LN) -> Iterator[Line]:
1111         """Default `visit_*()` implementation. Recurses to children of `node`."""
1112         if isinstance(node, Leaf):
1113             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1114             try:
1115                 for comment in generate_comments(node):
1116                     if any_open_brackets:
1117                         # any comment within brackets is subject to splitting
1118                         self.current_line.append(comment)
1119                     elif comment.type == token.COMMENT:
1120                         # regular trailing comment
1121                         self.current_line.append(comment)
1122                         yield from self.line()
1123
1124                     else:
1125                         # regular standalone comment
1126                         yield from self.line()
1127
1128                         self.current_line.append(comment)
1129                         yield from self.line()
1130
1131             except FormatOff as f_off:
1132                 f_off.trim_prefix(node)
1133                 yield from self.line(type=UnformattedLines)
1134                 yield from self.visit(node)
1135
1136             except FormatOn as f_on:
1137                 # This only happens here if somebody says "fmt: on" multiple
1138                 # times in a row.
1139                 f_on.trim_prefix(node)
1140                 yield from self.visit_default(node)
1141
1142             else:
1143                 normalize_prefix(node, inside_brackets=any_open_brackets)
1144                 if node.type == token.STRING:
1145                     normalize_string_quotes(node)
1146                 if node.type not in WHITESPACE:
1147                     self.current_line.append(node)
1148         yield from super().visit_default(node)
1149
1150     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1151         """Increase indentation level, maybe yield a line."""
1152         # In blib2to3 INDENT never holds comments.
1153         yield from self.line(+1)
1154         yield from self.visit_default(node)
1155
1156     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1157         """Decrease indentation level, maybe yield a line."""
1158         # The current line might still wait for trailing comments.  At DEDENT time
1159         # there won't be any (they would be prefixes on the preceding NEWLINE).
1160         # Emit the line then.
1161         yield from self.line()
1162
1163         # While DEDENT has no value, its prefix may contain standalone comments
1164         # that belong to the current indentation level.  Get 'em.
1165         yield from self.visit_default(node)
1166
1167         # Finally, emit the dedent.
1168         yield from self.line(-1)
1169
1170     def visit_stmt(
1171         self, node: Node, keywords: Set[str], parens: Set[str]
1172     ) -> Iterator[Line]:
1173         """Visit a statement.
1174
1175         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1176         `def`, `with`, `class`, and `assert`.
1177
1178         The relevant Python language `keywords` for a given statement will be
1179         NAME leaves within it. This methods puts those on a separate line.
1180
1181         `parens` holds pairs of nodes where invisible parentheses should be put.
1182         Keys hold nodes after which opening parentheses should be put, values
1183         hold nodes before which closing parentheses should be put.
1184         """
1185         normalize_invisible_parens(node, parens_after=parens)
1186         for child in node.children:
1187             if child.type == token.NAME and child.value in keywords:  # type: ignore
1188                 yield from self.line()
1189
1190             yield from self.visit(child)
1191
1192     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1193         """Visit a statement without nested statements."""
1194         is_suite_like = node.parent and node.parent.type in STATEMENT
1195         if is_suite_like:
1196             yield from self.line(+1)
1197             yield from self.visit_default(node)
1198             yield from self.line(-1)
1199
1200         else:
1201             yield from self.line()
1202             yield from self.visit_default(node)
1203
1204     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1205         """Visit `async def`, `async for`, `async with`."""
1206         yield from self.line()
1207
1208         children = iter(node.children)
1209         for child in children:
1210             yield from self.visit(child)
1211
1212             if child.type == token.ASYNC:
1213                 break
1214
1215         internal_stmt = next(children)
1216         for child in internal_stmt.children:
1217             yield from self.visit(child)
1218
1219     def visit_decorators(self, node: Node) -> Iterator[Line]:
1220         """Visit decorators."""
1221         for child in node.children:
1222             yield from self.line()
1223             yield from self.visit(child)
1224
1225     def visit_import_from(self, node: Node) -> Iterator[Line]:
1226         """Visit import_from and maybe put invisible parentheses.
1227
1228         This is separate from `visit_stmt` because import statements don't
1229         support arbitrary atoms and thus handling of parentheses is custom.
1230         """
1231         check_lpar = False
1232         for index, child in enumerate(node.children):
1233             if check_lpar:
1234                 if child.type == token.LPAR:
1235                     # make parentheses invisible
1236                     child.value = ""  # type: ignore
1237                     node.children[-1].value = ""  # type: ignore
1238                 else:
1239                     # insert invisible parentheses
1240                     node.insert_child(index, Leaf(token.LPAR, ""))
1241                     node.append_child(Leaf(token.RPAR, ""))
1242                 break
1243
1244             check_lpar = (
1245                 child.type == token.NAME and child.value == "import"  # type: ignore
1246             )
1247
1248         for child in node.children:
1249             yield from self.visit(child)
1250
1251     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1252         """Remove a semicolon and put the other statement on a separate line."""
1253         yield from self.line()
1254
1255     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1256         """End of file. Process outstanding comments and end with a newline."""
1257         yield from self.visit_default(leaf)
1258         yield from self.line()
1259
1260     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1261         """Used when file contained a `# fmt: off`."""
1262         if isinstance(node, Node):
1263             for child in node.children:
1264                 yield from self.visit(child)
1265
1266         else:
1267             try:
1268                 self.current_line.append(node)
1269             except FormatOn as f_on:
1270                 f_on.trim_prefix(node)
1271                 yield from self.line()
1272                 yield from self.visit(node)
1273
1274             if node.type == token.ENDMARKER:
1275                 # somebody decided not to put a final `# fmt: on`
1276                 yield from self.line()
1277
1278     def __attrs_post_init__(self) -> None:
1279         """You are in a twisty little maze of passages."""
1280         v = self.visit_stmt
1281         Ø: Set[str] = set()
1282         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1283         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1284         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1285         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1286         self.visit_try_stmt = partial(
1287             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1288         )
1289         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1290         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1291         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1292         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1293         self.visit_async_funcdef = self.visit_async_stmt
1294         self.visit_decorated = self.visit_decorators
1295
1296
1297 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1298 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1299 OPENING_BRACKETS = set(BRACKET.keys())
1300 CLOSING_BRACKETS = set(BRACKET.values())
1301 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1302 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1303
1304
1305 def whitespace(leaf: Leaf) -> str:  # noqa C901
1306     """Return whitespace prefix if needed for the given `leaf`."""
1307     NO = ""
1308     SPACE = " "
1309     DOUBLESPACE = "  "
1310     t = leaf.type
1311     p = leaf.parent
1312     v = leaf.value
1313     if t in ALWAYS_NO_SPACE:
1314         return NO
1315
1316     if t == token.COMMENT:
1317         return DOUBLESPACE
1318
1319     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1320     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1321         return NO
1322
1323     prev = leaf.prev_sibling
1324     if not prev:
1325         prevp = preceding_leaf(p)
1326         if not prevp or prevp.type in OPENING_BRACKETS:
1327             return NO
1328
1329         if t == token.COLON:
1330             return SPACE if prevp.type == token.COMMA else NO
1331
1332         if prevp.type == token.EQUAL:
1333             if prevp.parent:
1334                 if prevp.parent.type in {
1335                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1336                 }:
1337                     return NO
1338
1339                 elif prevp.parent.type == syms.typedargslist:
1340                     # A bit hacky: if the equal sign has whitespace, it means we
1341                     # previously found it's a typed argument.  So, we're using
1342                     # that, too.
1343                     return prevp.prefix
1344
1345         elif prevp.type in STARS:
1346             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1347                 return NO
1348
1349         elif prevp.type == token.COLON:
1350             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1351                 return NO
1352
1353         elif (
1354             prevp.parent
1355             and prevp.parent.type == syms.factor
1356             and prevp.type in MATH_OPERATORS
1357         ):
1358             return NO
1359
1360         elif (
1361             prevp.type == token.RIGHTSHIFT
1362             and prevp.parent
1363             and prevp.parent.type == syms.shift_expr
1364             and prevp.prev_sibling
1365             and prevp.prev_sibling.type == token.NAME
1366             and prevp.prev_sibling.value == "print"  # type: ignore
1367         ):
1368             # Python 2 print chevron
1369             return NO
1370
1371     elif prev.type in OPENING_BRACKETS:
1372         return NO
1373
1374     if p.type in {syms.parameters, syms.arglist}:
1375         # untyped function signatures or calls
1376         if not prev or prev.type != token.COMMA:
1377             return NO
1378
1379     elif p.type == syms.varargslist:
1380         # lambdas
1381         if prev and prev.type != token.COMMA:
1382             return NO
1383
1384     elif p.type == syms.typedargslist:
1385         # typed function signatures
1386         if not prev:
1387             return NO
1388
1389         if t == token.EQUAL:
1390             if prev.type != syms.tname:
1391                 return NO
1392
1393         elif prev.type == token.EQUAL:
1394             # A bit hacky: if the equal sign has whitespace, it means we
1395             # previously found it's a typed argument.  So, we're using that, too.
1396             return prev.prefix
1397
1398         elif prev.type != token.COMMA:
1399             return NO
1400
1401     elif p.type == syms.tname:
1402         # type names
1403         if not prev:
1404             prevp = preceding_leaf(p)
1405             if not prevp or prevp.type != token.COMMA:
1406                 return NO
1407
1408     elif p.type == syms.trailer:
1409         # attributes and calls
1410         if t == token.LPAR or t == token.RPAR:
1411             return NO
1412
1413         if not prev:
1414             if t == token.DOT:
1415                 prevp = preceding_leaf(p)
1416                 if not prevp or prevp.type != token.NUMBER:
1417                     return NO
1418
1419             elif t == token.LSQB:
1420                 return NO
1421
1422         elif prev.type != token.COMMA:
1423             return NO
1424
1425     elif p.type == syms.argument:
1426         # single argument
1427         if t == token.EQUAL:
1428             return NO
1429
1430         if not prev:
1431             prevp = preceding_leaf(p)
1432             if not prevp or prevp.type == token.LPAR:
1433                 return NO
1434
1435         elif prev.type in {token.EQUAL} | STARS:
1436             return NO
1437
1438     elif p.type == syms.decorator:
1439         # decorators
1440         return NO
1441
1442     elif p.type == syms.dotted_name:
1443         if prev:
1444             return NO
1445
1446         prevp = preceding_leaf(p)
1447         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1448             return NO
1449
1450     elif p.type == syms.classdef:
1451         if t == token.LPAR:
1452             return NO
1453
1454         if prev and prev.type == token.LPAR:
1455             return NO
1456
1457     elif p.type == syms.subscript:
1458         # indexing
1459         if not prev:
1460             assert p.parent is not None, "subscripts are always parented"
1461             if p.parent.type == syms.subscriptlist:
1462                 return SPACE
1463
1464             return NO
1465
1466         else:
1467             return NO
1468
1469     elif p.type == syms.atom:
1470         if prev and t == token.DOT:
1471             # dots, but not the first one.
1472             return NO
1473
1474     elif p.type == syms.dictsetmaker:
1475         # dict unpacking
1476         if prev and prev.type == token.DOUBLESTAR:
1477             return NO
1478
1479     elif p.type in {syms.factor, syms.star_expr}:
1480         # unary ops
1481         if not prev:
1482             prevp = preceding_leaf(p)
1483             if not prevp or prevp.type in OPENING_BRACKETS:
1484                 return NO
1485
1486             prevp_parent = prevp.parent
1487             assert prevp_parent is not None
1488             if (
1489                 prevp.type == token.COLON
1490                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1491             ):
1492                 return NO
1493
1494             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1495                 return NO
1496
1497         elif t == token.NAME or t == token.NUMBER:
1498             return NO
1499
1500     elif p.type == syms.import_from:
1501         if t == token.DOT:
1502             if prev and prev.type == token.DOT:
1503                 return NO
1504
1505         elif t == token.NAME:
1506             if v == "import":
1507                 return SPACE
1508
1509             if prev and prev.type == token.DOT:
1510                 return NO
1511
1512     elif p.type == syms.sliceop:
1513         return NO
1514
1515     return SPACE
1516
1517
1518 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1519     """Return the first leaf that precedes `node`, if any."""
1520     while node:
1521         res = node.prev_sibling
1522         if res:
1523             if isinstance(res, Leaf):
1524                 return res
1525
1526             try:
1527                 return list(res.leaves())[-1]
1528
1529             except IndexError:
1530                 return None
1531
1532         node = node.parent
1533     return None
1534
1535
1536 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1537     """Return the priority of the `leaf` delimiter, given a line break after it.
1538
1539     The delimiter priorities returned here are from those delimiters that would
1540     cause a line break after themselves.
1541
1542     Higher numbers are higher priority.
1543     """
1544     if leaf.type == token.COMMA:
1545         return COMMA_PRIORITY
1546
1547     return 0
1548
1549
1550 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1551     """Return the priority of the `leaf` delimiter, given a line before after it.
1552
1553     The delimiter priorities returned here are from those delimiters that would
1554     cause a line break before themselves.
1555
1556     Higher numbers are higher priority.
1557     """
1558     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1559         # * and ** might also be MATH_OPERATORS but in this case they are not.
1560         # Don't treat them as a delimiter.
1561         return 0
1562
1563     if (
1564         leaf.type in MATH_OPERATORS
1565         and leaf.parent
1566         and leaf.parent.type not in {syms.factor, syms.star_expr}
1567     ):
1568         return MATH_PRIORITY
1569
1570     if leaf.type in COMPARATORS:
1571         return COMPARATOR_PRIORITY
1572
1573     if (
1574         leaf.type == token.STRING
1575         and previous is not None
1576         and previous.type == token.STRING
1577     ):
1578         return STRING_PRIORITY
1579
1580     if (
1581         leaf.type == token.NAME
1582         and leaf.value == "for"
1583         and leaf.parent
1584         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1585     ):
1586         return COMPREHENSION_PRIORITY
1587
1588     if (
1589         leaf.type == token.NAME
1590         and leaf.value == "if"
1591         and leaf.parent
1592         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1593     ):
1594         return COMPREHENSION_PRIORITY
1595
1596     if (
1597         leaf.type == token.NAME
1598         and leaf.value in {"if", "else"}
1599         and leaf.parent
1600         and leaf.parent.type == syms.test
1601     ):
1602         return TERNARY_PRIORITY
1603
1604     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1605         return LOGIC_PRIORITY
1606
1607     return 0
1608
1609
1610 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1611     """Clean the prefix of the `leaf` and generate comments from it, if any.
1612
1613     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1614     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1615     move because it does away with modifying the grammar to include all the
1616     possible places in which comments can be placed.
1617
1618     The sad consequence for us though is that comments don't "belong" anywhere.
1619     This is why this function generates simple parentless Leaf objects for
1620     comments.  We simply don't know what the correct parent should be.
1621
1622     No matter though, we can live without this.  We really only need to
1623     differentiate between inline and standalone comments.  The latter don't
1624     share the line with any code.
1625
1626     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1627     are emitted with a fake STANDALONE_COMMENT token identifier.
1628     """
1629     p = leaf.prefix
1630     if not p:
1631         return
1632
1633     if "#" not in p:
1634         return
1635
1636     consumed = 0
1637     nlines = 0
1638     for index, line in enumerate(p.split("\n")):
1639         consumed += len(line) + 1  # adding the length of the split '\n'
1640         line = line.lstrip()
1641         if not line:
1642             nlines += 1
1643         if not line.startswith("#"):
1644             continue
1645
1646         if index == 0 and leaf.type != token.ENDMARKER:
1647             comment_type = token.COMMENT  # simple trailing comment
1648         else:
1649             comment_type = STANDALONE_COMMENT
1650         comment = make_comment(line)
1651         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1652
1653         if comment in {"# fmt: on", "# yapf: enable"}:
1654             raise FormatOn(consumed)
1655
1656         if comment in {"# fmt: off", "# yapf: disable"}:
1657             if comment_type == STANDALONE_COMMENT:
1658                 raise FormatOff(consumed)
1659
1660             prev = preceding_leaf(leaf)
1661             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1662                 raise FormatOff(consumed)
1663
1664         nlines = 0
1665
1666
1667 def make_comment(content: str) -> str:
1668     """Return a consistently formatted comment from the given `content` string.
1669
1670     All comments (except for "##", "#!", "#:") should have a single space between
1671     the hash sign and the content.
1672
1673     If `content` didn't start with a hash sign, one is provided.
1674     """
1675     content = content.rstrip()
1676     if not content:
1677         return "#"
1678
1679     if content[0] == "#":
1680         content = content[1:]
1681     if content and content[0] not in " !:#":
1682         content = " " + content
1683     return "#" + content
1684
1685
1686 def split_line(
1687     line: Line, line_length: int, inner: bool = False, py36: bool = False
1688 ) -> Iterator[Line]:
1689     """Split a `line` into potentially many lines.
1690
1691     They should fit in the allotted `line_length` but might not be able to.
1692     `inner` signifies that there were a pair of brackets somewhere around the
1693     current `line`, possibly transitively. This means we can fallback to splitting
1694     by delimiters if the LHS/RHS don't yield any results.
1695
1696     If `py36` is True, splitting may generate syntax that is only compatible
1697     with Python 3.6 and later.
1698     """
1699     if isinstance(line, UnformattedLines) or line.is_comment:
1700         yield line
1701         return
1702
1703     line_str = str(line).strip("\n")
1704     if (
1705         len(line_str) <= line_length
1706         and "\n" not in line_str  # multiline strings
1707         and not line.contains_standalone_comments()
1708     ):
1709         yield line
1710         return
1711
1712     split_funcs: List[SplitFunc]
1713     if line.is_def:
1714         split_funcs = [left_hand_split]
1715     elif line.inside_brackets:
1716         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1717     else:
1718         split_funcs = [right_hand_split]
1719     for split_func in split_funcs:
1720         # We are accumulating lines in `result` because we might want to abort
1721         # mission and return the original line in the end, or attempt a different
1722         # split altogether.
1723         result: List[Line] = []
1724         try:
1725             for l in split_func(line, py36):
1726                 if str(l).strip("\n") == line_str:
1727                     raise CannotSplit("Split function returned an unchanged result")
1728
1729                 result.extend(
1730                     split_line(l, line_length=line_length, inner=True, py36=py36)
1731                 )
1732         except CannotSplit as cs:
1733             continue
1734
1735         else:
1736             yield from result
1737             break
1738
1739     else:
1740         yield line
1741
1742
1743 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1744     """Split line into many lines, starting with the first matching bracket pair.
1745
1746     Note: this usually looks weird, only use this for function definitions.
1747     Prefer RHS otherwise.
1748     """
1749     head = Line(depth=line.depth)
1750     body = Line(depth=line.depth + 1, inside_brackets=True)
1751     tail = Line(depth=line.depth)
1752     tail_leaves: List[Leaf] = []
1753     body_leaves: List[Leaf] = []
1754     head_leaves: List[Leaf] = []
1755     current_leaves = head_leaves
1756     matching_bracket = None
1757     for leaf in line.leaves:
1758         if (
1759             current_leaves is body_leaves
1760             and leaf.type in CLOSING_BRACKETS
1761             and leaf.opening_bracket is matching_bracket
1762         ):
1763             current_leaves = tail_leaves if body_leaves else head_leaves
1764         current_leaves.append(leaf)
1765         if current_leaves is head_leaves:
1766             if leaf.type in OPENING_BRACKETS:
1767                 matching_bracket = leaf
1768                 current_leaves = body_leaves
1769     # Since body is a new indent level, remove spurious leading whitespace.
1770     if body_leaves:
1771         normalize_prefix(body_leaves[0], inside_brackets=True)
1772     # Build the new lines.
1773     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1774         for leaf in leaves:
1775             result.append(leaf, preformatted=True)
1776             for comment_after in line.comments_after(leaf):
1777                 result.append(comment_after, preformatted=True)
1778     bracket_split_succeeded_or_raise(head, body, tail)
1779     for result in (head, body, tail):
1780         if result:
1781             yield result
1782
1783
1784 def right_hand_split(
1785     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1786 ) -> Iterator[Line]:
1787     """Split line into many lines, starting with the last matching bracket pair."""
1788     head = Line(depth=line.depth)
1789     body = Line(depth=line.depth + 1, inside_brackets=True)
1790     tail = Line(depth=line.depth)
1791     tail_leaves: List[Leaf] = []
1792     body_leaves: List[Leaf] = []
1793     head_leaves: List[Leaf] = []
1794     current_leaves = tail_leaves
1795     opening_bracket = None
1796     closing_bracket = None
1797     for leaf in reversed(line.leaves):
1798         if current_leaves is body_leaves:
1799             if leaf is opening_bracket:
1800                 current_leaves = head_leaves if body_leaves else tail_leaves
1801         current_leaves.append(leaf)
1802         if current_leaves is tail_leaves:
1803             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1804                 opening_bracket = leaf.opening_bracket
1805                 closing_bracket = leaf
1806                 current_leaves = body_leaves
1807     tail_leaves.reverse()
1808     body_leaves.reverse()
1809     head_leaves.reverse()
1810     # Since body is a new indent level, remove spurious leading whitespace.
1811     if body_leaves:
1812         normalize_prefix(body_leaves[0], inside_brackets=True)
1813     elif not head_leaves:
1814         # No `head` and no `body` means the split failed. `tail` has all content.
1815         raise CannotSplit("No brackets found")
1816
1817     # Build the new lines.
1818     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1819         for leaf in leaves:
1820             result.append(leaf, preformatted=True)
1821             for comment_after in line.comments_after(leaf):
1822                 result.append(comment_after, preformatted=True)
1823     bracket_split_succeeded_or_raise(head, body, tail)
1824     assert opening_bracket and closing_bracket
1825     if (
1826         opening_bracket.type == token.LPAR
1827         and not opening_bracket.value
1828         and closing_bracket.type == token.RPAR
1829         and not closing_bracket.value
1830     ):
1831         # These parens were optional. If there aren't any delimiters or standalone
1832         # comments in the body, they were unnecessary and another split without
1833         # them should be attempted.
1834         if not (
1835             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1836         ):
1837             omit = {id(closing_bracket), *omit}
1838             yield from right_hand_split(line, py36=py36, omit=omit)
1839             return
1840
1841     ensure_visible(opening_bracket)
1842     ensure_visible(closing_bracket)
1843     for result in (head, body, tail):
1844         if result:
1845             yield result
1846
1847
1848 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1849     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1850
1851     Do nothing otherwise.
1852
1853     A left- or right-hand split is based on a pair of brackets. Content before
1854     (and including) the opening bracket is left on one line, content inside the
1855     brackets is put on a separate line, and finally content starting with and
1856     following the closing bracket is put on a separate line.
1857
1858     Those are called `head`, `body`, and `tail`, respectively. If the split
1859     produced the same line (all content in `head`) or ended up with an empty `body`
1860     and the `tail` is just the closing bracket, then it's considered failed.
1861     """
1862     tail_len = len(str(tail).strip())
1863     if not body:
1864         if tail_len == 0:
1865             raise CannotSplit("Splitting brackets produced the same line")
1866
1867         elif tail_len < 3:
1868             raise CannotSplit(
1869                 f"Splitting brackets on an empty body to save "
1870                 f"{tail_len} characters is not worth it"
1871             )
1872
1873
1874 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1875     """Normalize prefix of the first leaf in every line returned by `split_func`.
1876
1877     This is a decorator over relevant split functions.
1878     """
1879
1880     @wraps(split_func)
1881     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1882         for l in split_func(line, py36):
1883             normalize_prefix(l.leaves[0], inside_brackets=True)
1884             yield l
1885
1886     return split_wrapper
1887
1888
1889 @dont_increase_indentation
1890 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1891     """Split according to delimiters of the highest priority.
1892
1893     If `py36` is True, the split will add trailing commas also in function
1894     signatures that contain `*` and `**`.
1895     """
1896     try:
1897         last_leaf = line.leaves[-1]
1898     except IndexError:
1899         raise CannotSplit("Line empty")
1900
1901     delimiters = line.bracket_tracker.delimiters
1902     try:
1903         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1904             exclude={id(last_leaf)}
1905         )
1906     except ValueError:
1907         raise CannotSplit("No delimiters found")
1908
1909     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1910     lowest_depth = sys.maxsize
1911     trailing_comma_safe = True
1912
1913     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1914         """Append `leaf` to current line or to new line if appending impossible."""
1915         nonlocal current_line
1916         try:
1917             current_line.append_safe(leaf, preformatted=True)
1918         except ValueError as ve:
1919             yield current_line
1920
1921             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1922             current_line.append(leaf)
1923
1924     for leaf in line.leaves:
1925         yield from append_to_line(leaf)
1926
1927         for comment_after in line.comments_after(leaf):
1928             yield from append_to_line(comment_after)
1929
1930         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1931         if (
1932             leaf.bracket_depth == lowest_depth
1933             and is_vararg(leaf, within=VARARGS_PARENTS)
1934         ):
1935             trailing_comma_safe = trailing_comma_safe and py36
1936         leaf_priority = delimiters.get(id(leaf))
1937         if leaf_priority == delimiter_priority:
1938             yield current_line
1939
1940             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1941     if current_line:
1942         if (
1943             trailing_comma_safe
1944             and delimiter_priority == COMMA_PRIORITY
1945             and current_line.leaves[-1].type != token.COMMA
1946             and current_line.leaves[-1].type != STANDALONE_COMMENT
1947         ):
1948             current_line.append(Leaf(token.COMMA, ","))
1949         yield current_line
1950
1951
1952 @dont_increase_indentation
1953 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1954     """Split standalone comments from the rest of the line."""
1955     if not line.contains_standalone_comments(0):
1956         raise CannotSplit("Line does not have any standalone comments")
1957
1958     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1959
1960     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1961         """Append `leaf` to current line or to new line if appending impossible."""
1962         nonlocal current_line
1963         try:
1964             current_line.append_safe(leaf, preformatted=True)
1965         except ValueError as ve:
1966             yield current_line
1967
1968             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1969             current_line.append(leaf)
1970
1971     for leaf in line.leaves:
1972         yield from append_to_line(leaf)
1973
1974         for comment_after in line.comments_after(leaf):
1975             yield from append_to_line(comment_after)
1976
1977     if current_line:
1978         yield current_line
1979
1980
1981 def is_import(leaf: Leaf) -> bool:
1982     """Return True if the given leaf starts an import statement."""
1983     p = leaf.parent
1984     t = leaf.type
1985     v = leaf.value
1986     return bool(
1987         t == token.NAME
1988         and (
1989             (v == "import" and p and p.type == syms.import_name)
1990             or (v == "from" and p and p.type == syms.import_from)
1991         )
1992     )
1993
1994
1995 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1996     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1997     else.
1998
1999     Note: don't use backslashes for formatting or you'll lose your voting rights.
2000     """
2001     if not inside_brackets:
2002         spl = leaf.prefix.split("#")
2003         if "\\" not in spl[0]:
2004             nl_count = spl[-1].count("\n")
2005             if len(spl) > 1:
2006                 nl_count -= 1
2007             leaf.prefix = "\n" * nl_count
2008             return
2009
2010     leaf.prefix = ""
2011
2012
2013 def normalize_string_quotes(leaf: Leaf) -> None:
2014     """Prefer double quotes but only if it doesn't cause more escaping.
2015
2016     Adds or removes backslashes as appropriate. Doesn't parse and fix
2017     strings nested in f-strings (yet).
2018
2019     Note: Mutates its argument.
2020     """
2021     value = leaf.value.lstrip("furbFURB")
2022     if value[:3] == '"""':
2023         return
2024
2025     elif value[:3] == "'''":
2026         orig_quote = "'''"
2027         new_quote = '"""'
2028     elif value[0] == '"':
2029         orig_quote = '"'
2030         new_quote = "'"
2031     else:
2032         orig_quote = "'"
2033         new_quote = '"'
2034     first_quote_pos = leaf.value.find(orig_quote)
2035     if first_quote_pos == -1:
2036         return  # There's an internal error
2037
2038     prefix = leaf.value[:first_quote_pos]
2039     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2040     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2041     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2042     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
2043     if "r" in prefix.casefold():
2044         if unescaped_new_quote.search(body):
2045             # There's at least one unescaped new_quote in this raw string
2046             # so converting is impossible
2047             return
2048
2049         # Do not introduce or remove backslashes in raw strings
2050         new_body = body
2051     else:
2052         # remove unnecessary quotes
2053         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2054         if body != new_body:
2055             # Consider the string without unnecessary quotes as the original
2056             body = new_body
2057             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2058         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2059         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2060     if new_quote == '"""' and new_body[-1] == '"':
2061         # edge case:
2062         new_body = new_body[:-1] + '\\"'
2063     orig_escape_count = body.count("\\")
2064     new_escape_count = new_body.count("\\")
2065     if new_escape_count > orig_escape_count:
2066         return  # Do not introduce more escaping
2067
2068     if new_escape_count == orig_escape_count and orig_quote == '"':
2069         return  # Prefer double quotes
2070
2071     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2072
2073
2074 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2075     """Make existing optional parentheses invisible or create new ones.
2076
2077     Standardizes on visible parentheses for single-element tuples, and keeps
2078     existing visible parentheses for other tuples and generator expressions.
2079     """
2080     check_lpar = False
2081     for child in list(node.children):
2082         if check_lpar:
2083             if child.type == syms.atom:
2084                 if not (
2085                     is_empty_tuple(child)
2086                     or is_one_tuple(child)
2087                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2088                 ):
2089                     first = child.children[0]
2090                     last = child.children[-1]
2091                     if first.type == token.LPAR and last.type == token.RPAR:
2092                         # make parentheses invisible
2093                         first.value = ""  # type: ignore
2094                         last.value = ""  # type: ignore
2095             elif is_one_tuple(child):
2096                 # wrap child in visible parentheses
2097                 lpar = Leaf(token.LPAR, "(")
2098                 rpar = Leaf(token.RPAR, ")")
2099                 index = child.remove() or 0
2100                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2101             else:
2102                 # wrap child in invisible parentheses
2103                 lpar = Leaf(token.LPAR, "")
2104                 rpar = Leaf(token.RPAR, "")
2105                 index = child.remove() or 0
2106                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2107
2108         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2109
2110
2111 def is_empty_tuple(node: LN) -> bool:
2112     """Return True if `node` holds an empty tuple."""
2113     return (
2114         node.type == syms.atom
2115         and len(node.children) == 2
2116         and node.children[0].type == token.LPAR
2117         and node.children[1].type == token.RPAR
2118     )
2119
2120
2121 def is_one_tuple(node: LN) -> bool:
2122     """Return True if `node` holds a tuple with one element, with or without parens."""
2123     if node.type == syms.atom:
2124         if len(node.children) != 3:
2125             return False
2126
2127         lpar, gexp, rpar = node.children
2128         if not (
2129             lpar.type == token.LPAR
2130             and gexp.type == syms.testlist_gexp
2131             and rpar.type == token.RPAR
2132         ):
2133             return False
2134
2135         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2136
2137     return (
2138         node.type in IMPLICIT_TUPLE
2139         and len(node.children) == 2
2140         and node.children[1].type == token.COMMA
2141     )
2142
2143
2144 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2145     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2146
2147     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2148     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2149     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2150     generalizations (PEP 448).
2151     """
2152     if leaf.type not in STARS or not leaf.parent:
2153         return False
2154
2155     p = leaf.parent
2156     if p.type == syms.star_expr:
2157         # Star expressions are also used as assignment targets in extended
2158         # iterable unpacking (PEP 3132).  See what its parent is instead.
2159         if not p.parent:
2160             return False
2161
2162         p = p.parent
2163
2164     return p.type in within
2165
2166
2167 def max_delimiter_priority_in_atom(node: LN) -> int:
2168     """Return maximum delimiter priority inside `node`.
2169
2170     This is specific to atoms with contents contained in a pair of parentheses.
2171     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2172     """
2173     if node.type != syms.atom:
2174         return 0
2175
2176     first = node.children[0]
2177     last = node.children[-1]
2178     if not (first.type == token.LPAR and last.type == token.RPAR):
2179         return 0
2180
2181     bt = BracketTracker()
2182     for c in node.children[1:-1]:
2183         if isinstance(c, Leaf):
2184             bt.mark(c)
2185         else:
2186             for leaf in c.leaves():
2187                 bt.mark(leaf)
2188     try:
2189         return bt.max_delimiter_priority()
2190
2191     except ValueError:
2192         return 0
2193
2194
2195 def ensure_visible(leaf: Leaf) -> None:
2196     """Make sure parentheses are visible.
2197
2198     They could be invisible as part of some statements (see
2199     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2200     """
2201     if leaf.type == token.LPAR:
2202         leaf.value = "("
2203     elif leaf.type == token.RPAR:
2204         leaf.value = ")"
2205
2206
2207 def is_python36(node: Node) -> bool:
2208     """Return True if the current file is using Python 3.6+ features.
2209
2210     Currently looking for:
2211     - f-strings; and
2212     - trailing commas after * or ** in function signatures.
2213     """
2214     for n in node.pre_order():
2215         if n.type == token.STRING:
2216             value_head = n.value[:2]  # type: ignore
2217             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2218                 return True
2219
2220         elif (
2221             n.type == syms.typedargslist
2222             and n.children
2223             and n.children[-1].type == token.COMMA
2224         ):
2225             for ch in n.children:
2226                 if ch.type in STARS:
2227                     return True
2228
2229     return False
2230
2231
2232 PYTHON_EXTENSIONS = {".py"}
2233 BLACKLISTED_DIRECTORIES = {
2234     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2235 }
2236
2237
2238 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2239     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2240     and have one of the PYTHON_EXTENSIONS.
2241     """
2242     for child in path.iterdir():
2243         if child.is_dir():
2244             if child.name in BLACKLISTED_DIRECTORIES:
2245                 continue
2246
2247             yield from gen_python_files_in_dir(child)
2248
2249         elif child.suffix in PYTHON_EXTENSIONS:
2250             yield child
2251
2252
2253 @dataclass
2254 class Report:
2255     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2256     check: bool = False
2257     quiet: bool = False
2258     change_count: int = 0
2259     same_count: int = 0
2260     failure_count: int = 0
2261
2262     def done(self, src: Path, changed: Changed) -> None:
2263         """Increment the counter for successful reformatting. Write out a message."""
2264         if changed is Changed.YES:
2265             reformatted = "would reformat" if self.check else "reformatted"
2266             if not self.quiet:
2267                 out(f"{reformatted} {src}")
2268             self.change_count += 1
2269         else:
2270             if not self.quiet:
2271                 if changed is Changed.NO:
2272                     msg = f"{src} already well formatted, good job."
2273                 else:
2274                     msg = f"{src} wasn't modified on disk since last run."
2275                 out(msg, bold=False)
2276             self.same_count += 1
2277
2278     def failed(self, src: Path, message: str) -> None:
2279         """Increment the counter for failed reformatting. Write out a message."""
2280         err(f"error: cannot format {src}: {message}")
2281         self.failure_count += 1
2282
2283     @property
2284     def return_code(self) -> int:
2285         """Return the exit code that the app should use.
2286
2287         This considers the current state of changed files and failures:
2288         - if there were any failures, return 123;
2289         - if any files were changed and --check is being used, return 1;
2290         - otherwise return 0.
2291         """
2292         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2293         # 126 we have special returncodes reserved by the shell.
2294         if self.failure_count:
2295             return 123
2296
2297         elif self.change_count and self.check:
2298             return 1
2299
2300         return 0
2301
2302     def __str__(self) -> str:
2303         """Render a color report of the current state.
2304
2305         Use `click.unstyle` to remove colors.
2306         """
2307         if self.check:
2308             reformatted = "would be reformatted"
2309             unchanged = "would be left unchanged"
2310             failed = "would fail to reformat"
2311         else:
2312             reformatted = "reformatted"
2313             unchanged = "left unchanged"
2314             failed = "failed to reformat"
2315         report = []
2316         if self.change_count:
2317             s = "s" if self.change_count > 1 else ""
2318             report.append(
2319                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2320             )
2321         if self.same_count:
2322             s = "s" if self.same_count > 1 else ""
2323             report.append(f"{self.same_count} file{s} {unchanged}")
2324         if self.failure_count:
2325             s = "s" if self.failure_count > 1 else ""
2326             report.append(
2327                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2328             )
2329         return ", ".join(report) + "."
2330
2331
2332 def assert_equivalent(src: str, dst: str) -> None:
2333     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2334
2335     import ast
2336     import traceback
2337
2338     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2339         """Simple visitor generating strings to compare ASTs by content."""
2340         yield f"{'  ' * depth}{node.__class__.__name__}("
2341
2342         for field in sorted(node._fields):
2343             try:
2344                 value = getattr(node, field)
2345             except AttributeError:
2346                 continue
2347
2348             yield f"{'  ' * (depth+1)}{field}="
2349
2350             if isinstance(value, list):
2351                 for item in value:
2352                     if isinstance(item, ast.AST):
2353                         yield from _v(item, depth + 2)
2354
2355             elif isinstance(value, ast.AST):
2356                 yield from _v(value, depth + 2)
2357
2358             else:
2359                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2360
2361         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2362
2363     try:
2364         src_ast = ast.parse(src)
2365     except Exception as exc:
2366         major, minor = sys.version_info[:2]
2367         raise AssertionError(
2368             f"cannot use --safe with this file; failed to parse source file "
2369             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2370             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2371         )
2372
2373     try:
2374         dst_ast = ast.parse(dst)
2375     except Exception as exc:
2376         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2377         raise AssertionError(
2378             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2379             f"Please report a bug on https://github.com/ambv/black/issues.  "
2380             f"This invalid output might be helpful: {log}"
2381         ) from None
2382
2383     src_ast_str = "\n".join(_v(src_ast))
2384     dst_ast_str = "\n".join(_v(dst_ast))
2385     if src_ast_str != dst_ast_str:
2386         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2387         raise AssertionError(
2388             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2389             f"the source.  "
2390             f"Please report a bug on https://github.com/ambv/black/issues.  "
2391             f"This diff might be helpful: {log}"
2392         ) from None
2393
2394
2395 def assert_stable(src: str, dst: str, line_length: int) -> None:
2396     """Raise AssertionError if `dst` reformats differently the second time."""
2397     newdst = format_str(dst, line_length=line_length)
2398     if dst != newdst:
2399         log = dump_to_file(
2400             diff(src, dst, "source", "first pass"),
2401             diff(dst, newdst, "first pass", "second pass"),
2402         )
2403         raise AssertionError(
2404             f"INTERNAL ERROR: Black produced different code on the second pass "
2405             f"of the formatter.  "
2406             f"Please report a bug on https://github.com/ambv/black/issues.  "
2407             f"This diff might be helpful: {log}"
2408         ) from None
2409
2410
2411 def dump_to_file(*output: str) -> str:
2412     """Dump `output` to a temporary file. Return path to the file."""
2413     import tempfile
2414
2415     with tempfile.NamedTemporaryFile(
2416         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2417     ) as f:
2418         for lines in output:
2419             f.write(lines)
2420             if lines and lines[-1] != "\n":
2421                 f.write("\n")
2422     return f.name
2423
2424
2425 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2426     """Return a unified diff string between strings `a` and `b`."""
2427     import difflib
2428
2429     a_lines = [line + "\n" for line in a.split("\n")]
2430     b_lines = [line + "\n" for line in b.split("\n")]
2431     return "".join(
2432         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2433     )
2434
2435
2436 def cancel(tasks: List[asyncio.Task]) -> None:
2437     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2438     err("Aborted!")
2439     for task in tasks:
2440         task.cancel()
2441
2442
2443 def shutdown(loop: BaseEventLoop) -> None:
2444     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2445     try:
2446         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2447         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2448         if not to_cancel:
2449             return
2450
2451         for task in to_cancel:
2452             task.cancel()
2453         loop.run_until_complete(
2454             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2455         )
2456     finally:
2457         # `concurrent.futures.Future` objects cannot be cancelled once they
2458         # are already running. There might be some when the `shutdown()` happened.
2459         # Silence their logger's spew about the event loop being closed.
2460         cf_logger = logging.getLogger("concurrent.futures")
2461         cf_logger.setLevel(logging.CRITICAL)
2462         loop.close()
2463
2464
2465 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2466     """Replace `regex` with `replacement` twice on `original`.
2467
2468     This is used by string normalization to perform replaces on
2469     overlapping matches.
2470     """
2471     return regex.sub(replacement, regex.sub(replacement, original))
2472
2473
2474 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2475
2476
2477 def get_cache_file(line_length: int) -> Path:
2478     return CACHE_DIR / f"cache.{line_length}.pickle"
2479
2480
2481 def read_cache(line_length: int) -> Cache:
2482     """Read the cache if it exists and is well formed.
2483
2484     If it is not well formed, the call to write_cache later should resolve the issue.
2485     """
2486     cache_file = get_cache_file(line_length)
2487     if not cache_file.exists():
2488         return {}
2489
2490     with cache_file.open("rb") as fobj:
2491         try:
2492             cache: Cache = pickle.load(fobj)
2493         except pickle.UnpicklingError:
2494             return {}
2495
2496     return cache
2497
2498
2499 def get_cache_info(path: Path) -> CacheInfo:
2500     """Return the information used to check if a file is already formatted or not."""
2501     stat = path.stat()
2502     return stat.st_mtime, stat.st_size
2503
2504
2505 def filter_cached(
2506     cache: Cache, sources: Iterable[Path]
2507 ) -> Tuple[List[Path], List[Path]]:
2508     """Split a list of paths into two.
2509
2510     The first list contains paths of files that modified on disk or are not in the
2511     cache. The other list contains paths to non-modified files.
2512     """
2513     todo, done = [], []
2514     for src in sources:
2515         src = src.resolve()
2516         if cache.get(src) != get_cache_info(src):
2517             todo.append(src)
2518         else:
2519             done.append(src)
2520     return todo, done
2521
2522
2523 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2524     """Update the cache file."""
2525     cache_file = get_cache_file(line_length)
2526     try:
2527         if not CACHE_DIR.exists():
2528             CACHE_DIR.mkdir(parents=True)
2529         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2530         with cache_file.open("wb") as fobj:
2531             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2532     except OSError:
2533         pass
2534
2535
2536 if __name__ == "__main__":
2537     main()