]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

5458198b17560621545fdcb0d15a4c32fcdb6522
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import pickle
5 from asyncio.base_events import BaseEventLoop
6 from concurrent.futures import Executor, ProcessPoolExecutor
7 from enum import Enum
8 from functools import partial, wraps
9 import keyword
10 import logging
11 from multiprocessing import Manager
12 import os
13 from pathlib import Path
14 import re
15 import tokenize
16 import signal
17 import sys
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46 __version__ = "18.4a3"
47 DEFAULT_LINE_LENGTH = 88
48
49 # types
50 syms = pygram.python_symbols
51 FileContent = str
52 Encoding = str
53 Depth = int
54 NodeType = int
55 LeafID = int
56 Priority = int
57 Index = int
58 LN = Union[Leaf, Node]
59 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
60 Timestamp = float
61 FileSize = int
62 CacheInfo = Tuple[Timestamp, FileSize]
63 Cache = Dict[Path, CacheInfo]
64 out = partial(click.secho, bold=True, err=True)
65 err = partial(click.secho, fg="red", err=True)
66
67
68 class NothingChanged(UserWarning):
69     """Raised by :func:`format_file` when reformatted code is the same as source."""
70
71
72 class CannotSplit(Exception):
73     """A readable split that fits the allotted line length is impossible.
74
75     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
76     :func:`delimiter_split`.
77     """
78
79
80 class FormatError(Exception):
81     """Base exception for `# fmt: on` and `# fmt: off` handling.
82
83     It holds the number of bytes of the prefix consumed before the format
84     control comment appeared.
85     """
86
87     def __init__(self, consumed: int) -> None:
88         super().__init__(consumed)
89         self.consumed = consumed
90
91     def trim_prefix(self, leaf: Leaf) -> None:
92         leaf.prefix = leaf.prefix[self.consumed:]
93
94     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
95         """Returns a new Leaf from the consumed part of the prefix."""
96         unformatted_prefix = leaf.prefix[:self.consumed]
97         return Leaf(token.NEWLINE, unformatted_prefix)
98
99
100 class FormatOn(FormatError):
101     """Found a comment like `# fmt: on` in the file."""
102
103
104 class FormatOff(FormatError):
105     """Found a comment like `# fmt: off` in the file."""
106
107
108 class WriteBack(Enum):
109     NO = 0
110     YES = 1
111     DIFF = 2
112
113
114 class Changed(Enum):
115     NO = 0
116     CACHED = 1
117     YES = 2
118
119
120 @click.command()
121 @click.option(
122     "-l",
123     "--line-length",
124     type=int,
125     default=DEFAULT_LINE_LENGTH,
126     help="How many character per line to allow.",
127     show_default=True,
128 )
129 @click.option(
130     "--check",
131     is_flag=True,
132     help=(
133         "Don't write the files back, just return the status.  Return code 0 "
134         "means nothing would change.  Return code 1 means some files would be "
135         "reformatted.  Return code 123 means there was an internal error."
136     ),
137 )
138 @click.option(
139     "--diff",
140     is_flag=True,
141     help="Don't write the files back, just output a diff for each file on stdout.",
142 )
143 @click.option(
144     "--fast/--safe",
145     is_flag=True,
146     help="If --fast given, skip temporary sanity checks. [default: --safe]",
147 )
148 @click.option(
149     "-q",
150     "--quiet",
151     is_flag=True,
152     help=(
153         "Don't emit non-error messages to stderr. Errors are still emitted, "
154         "silence those with 2>/dev/null."
155     ),
156 )
157 @click.version_option(version=__version__)
158 @click.argument(
159     "src",
160     nargs=-1,
161     type=click.Path(
162         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
163     ),
164 )
165 @click.pass_context
166 def main(
167     ctx: click.Context,
168     line_length: int,
169     check: bool,
170     diff: bool,
171     fast: bool,
172     quiet: bool,
173     src: List[str],
174 ) -> None:
175     """The uncompromising code formatter."""
176     sources: List[Path] = []
177     for s in src:
178         p = Path(s)
179         if p.is_dir():
180             sources.extend(gen_python_files_in_dir(p))
181         elif p.is_file():
182             # if a file was explicitly given, we don't care about its extension
183             sources.append(p)
184         elif s == "-":
185             sources.append(Path("-"))
186         else:
187             err(f"invalid path: {s}")
188
189     if check and not diff:
190         write_back = WriteBack.NO
191     elif diff:
192         write_back = WriteBack.DIFF
193     else:
194         write_back = WriteBack.YES
195     report = Report(check=check, quiet=quiet)
196     if len(sources) == 0:
197         ctx.exit(0)
198         return
199
200     elif len(sources) == 1:
201         reformat_one(sources[0], line_length, fast, write_back, report)
202     else:
203         loop = asyncio.get_event_loop()
204         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
205         try:
206             loop.run_until_complete(
207                 schedule_formatting(
208                     sources, line_length, fast, write_back, report, loop, executor
209                 )
210             )
211         finally:
212             shutdown(loop)
213         if not quiet:
214             out("All done! ✨ 🍰 ✨")
215             click.echo(str(report))
216     ctx.exit(report.return_code)
217
218
219 def reformat_one(
220     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
221 ) -> None:
222     """Reformat a single file under `src` without spawning child processes.
223
224     If `quiet` is True, non-error messages are not output. `line_length`,
225     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
226     """
227     try:
228         changed = Changed.NO
229         if not src.is_file() and str(src) == "-":
230             if format_stdin_to_stdout(
231                 line_length=line_length, fast=fast, write_back=write_back
232             ):
233                 changed = Changed.YES
234         else:
235             cache: Cache = {}
236             if write_back != WriteBack.DIFF:
237                 cache = read_cache(line_length)
238                 src = src.resolve()
239                 if src in cache and cache[src] == get_cache_info(src):
240                     changed = Changed.CACHED
241             if (
242                 changed is not Changed.CACHED
243                 and format_file_in_place(
244                     src, line_length=line_length, fast=fast, write_back=write_back
245                 )
246             ):
247                 changed = Changed.YES
248             if write_back != WriteBack.DIFF and changed is not Changed.NO:
249                 write_cache(cache, [src], line_length)
250         report.done(src, changed)
251     except Exception as exc:
252         report.failed(src, str(exc))
253
254
255 async def schedule_formatting(
256     sources: List[Path],
257     line_length: int,
258     fast: bool,
259     write_back: WriteBack,
260     report: "Report",
261     loop: BaseEventLoop,
262     executor: Executor,
263 ) -> None:
264     """Run formatting of `sources` in parallel using the provided `executor`.
265
266     (Use ProcessPoolExecutors for actual parallelism.)
267
268     `line_length`, `write_back`, and `fast` options are passed to
269     :func:`format_file_in_place`.
270     """
271     cache: Cache = {}
272     if write_back != WriteBack.DIFF:
273         cache = read_cache(line_length)
274         sources, cached = filter_cached(cache, sources)
275         for src in cached:
276             report.done(src, Changed.CACHED)
277     cancelled = []
278     formatted = []
279     if sources:
280         lock = None
281         if write_back == WriteBack.DIFF:
282             # For diff output, we need locks to ensure we don't interleave output
283             # from different processes.
284             manager = Manager()
285             lock = manager.Lock()
286         tasks = {
287             src: loop.run_in_executor(
288                 executor, format_file_in_place, src, line_length, fast, write_back, lock
289             )
290             for src in sources
291         }
292         _task_values = list(tasks.values())
293         try:
294             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
295             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
296         except NotImplementedError:
297             # There are no good alternatives for these on Windows
298             pass
299         await asyncio.wait(_task_values)
300         for src, task in tasks.items():
301             if not task.done():
302                 report.failed(src, "timed out, cancelling")
303                 task.cancel()
304                 cancelled.append(task)
305             elif task.cancelled():
306                 cancelled.append(task)
307             elif task.exception():
308                 report.failed(src, str(task.exception()))
309             else:
310                 formatted.append(src)
311                 report.done(src, Changed.YES if task.result() else Changed.NO)
312
313     if cancelled:
314         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
315     if write_back != WriteBack.DIFF and formatted:
316         write_cache(cache, formatted, line_length)
317
318
319 def format_file_in_place(
320     src: Path,
321     line_length: int,
322     fast: bool,
323     write_back: WriteBack = WriteBack.NO,
324     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
325 ) -> bool:
326     """Format file under `src` path. Return True if changed.
327
328     If `write_back` is True, write reformatted code back to stdout.
329     `line_length` and `fast` options are passed to :func:`format_file_contents`.
330     """
331
332     with tokenize.open(src) as src_buffer:
333         src_contents = src_buffer.read()
334     try:
335         dst_contents = format_file_contents(
336             src_contents, line_length=line_length, fast=fast
337         )
338     except NothingChanged:
339         return False
340
341     if write_back == write_back.YES:
342         with open(src, "w", encoding=src_buffer.encoding) as f:
343             f.write(dst_contents)
344     elif write_back == write_back.DIFF:
345         src_name = f"{src}  (original)"
346         dst_name = f"{src}  (formatted)"
347         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
348         if lock:
349             lock.acquire()
350         try:
351             sys.stdout.write(diff_contents)
352         finally:
353             if lock:
354                 lock.release()
355     return True
356
357
358 def format_stdin_to_stdout(
359     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
360 ) -> bool:
361     """Format file on stdin. Return True if changed.
362
363     If `write_back` is True, write reformatted code back to stdout.
364     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
365     """
366     src = sys.stdin.read()
367     dst = src
368     try:
369         dst = format_file_contents(src, line_length=line_length, fast=fast)
370         return True
371
372     except NothingChanged:
373         return False
374
375     finally:
376         if write_back == WriteBack.YES:
377             sys.stdout.write(dst)
378         elif write_back == WriteBack.DIFF:
379             src_name = "<stdin>  (original)"
380             dst_name = "<stdin>  (formatted)"
381             sys.stdout.write(diff(src, dst, src_name, dst_name))
382
383
384 def format_file_contents(
385     src_contents: str, line_length: int, fast: bool
386 ) -> FileContent:
387     """Reformat contents a file and return new contents.
388
389     If `fast` is False, additionally confirm that the reformatted code is
390     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
391     `line_length` is passed to :func:`format_str`.
392     """
393     if src_contents.strip() == "":
394         raise NothingChanged
395
396     dst_contents = format_str(src_contents, line_length=line_length)
397     if src_contents == dst_contents:
398         raise NothingChanged
399
400     if not fast:
401         assert_equivalent(src_contents, dst_contents)
402         assert_stable(src_contents, dst_contents, line_length=line_length)
403     return dst_contents
404
405
406 def format_str(src_contents: str, line_length: int) -> FileContent:
407     """Reformat a string and return new contents.
408
409     `line_length` determines how many characters per line are allowed.
410     """
411     src_node = lib2to3_parse(src_contents)
412     dst_contents = ""
413     lines = LineGenerator()
414     elt = EmptyLineTracker()
415     py36 = is_python36(src_node)
416     empty_line = Line()
417     after = 0
418     for current_line in lines.visit(src_node):
419         for _ in range(after):
420             dst_contents += str(empty_line)
421         before, after = elt.maybe_empty_lines(current_line)
422         for _ in range(before):
423             dst_contents += str(empty_line)
424         for line in split_line(current_line, line_length=line_length, py36=py36):
425             dst_contents += str(line)
426     return dst_contents
427
428
429 GRAMMARS = [
430     pygram.python_grammar_no_print_statement_no_exec_statement,
431     pygram.python_grammar_no_print_statement,
432     pygram.python_grammar,
433 ]
434
435
436 def lib2to3_parse(src_txt: str) -> Node:
437     """Given a string with source, return the lib2to3 Node."""
438     grammar = pygram.python_grammar_no_print_statement
439     if src_txt[-1] != "\n":
440         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
441         src_txt += nl
442     for grammar in GRAMMARS:
443         drv = driver.Driver(grammar, pytree.convert)
444         try:
445             result = drv.parse_string(src_txt, True)
446             break
447
448         except ParseError as pe:
449             lineno, column = pe.context[1]
450             lines = src_txt.splitlines()
451             try:
452                 faulty_line = lines[lineno - 1]
453             except IndexError:
454                 faulty_line = "<line number missing in source>"
455             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
456     else:
457         raise exc from None
458
459     if isinstance(result, Leaf):
460         result = Node(syms.file_input, [result])
461     return result
462
463
464 def lib2to3_unparse(node: Node) -> str:
465     """Given a lib2to3 node, return its string representation."""
466     code = str(node)
467     return code
468
469
470 T = TypeVar("T")
471
472
473 class Visitor(Generic[T]):
474     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
475
476     def visit(self, node: LN) -> Iterator[T]:
477         """Main method to visit `node` and its children.
478
479         It tries to find a `visit_*()` method for the given `node.type`, like
480         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
481         If no dedicated `visit_*()` method is found, chooses `visit_default()`
482         instead.
483
484         Then yields objects of type `T` from the selected visitor.
485         """
486         if node.type < 256:
487             name = token.tok_name[node.type]
488         else:
489             name = type_repr(node.type)
490         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
491
492     def visit_default(self, node: LN) -> Iterator[T]:
493         """Default `visit_*()` implementation. Recurses to children of `node`."""
494         if isinstance(node, Node):
495             for child in node.children:
496                 yield from self.visit(child)
497
498
499 @dataclass
500 class DebugVisitor(Visitor[T]):
501     tree_depth: int = 0
502
503     def visit_default(self, node: LN) -> Iterator[T]:
504         indent = " " * (2 * self.tree_depth)
505         if isinstance(node, Node):
506             _type = type_repr(node.type)
507             out(f"{indent}{_type}", fg="yellow")
508             self.tree_depth += 1
509             for child in node.children:
510                 yield from self.visit(child)
511
512             self.tree_depth -= 1
513             out(f"{indent}/{_type}", fg="yellow", bold=False)
514         else:
515             _type = token.tok_name.get(node.type, str(node.type))
516             out(f"{indent}{_type}", fg="blue", nl=False)
517             if node.prefix:
518                 # We don't have to handle prefixes for `Node` objects since
519                 # that delegates to the first child anyway.
520                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
521             out(f" {node.value!r}", fg="blue", bold=False)
522
523     @classmethod
524     def show(cls, code: str) -> None:
525         """Pretty-print the lib2to3 AST of a given string of `code`.
526
527         Convenience method for debugging.
528         """
529         v: DebugVisitor[None] = DebugVisitor()
530         list(v.visit(lib2to3_parse(code)))
531
532
533 KEYWORDS = set(keyword.kwlist)
534 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
535 FLOW_CONTROL = {"return", "raise", "break", "continue"}
536 STATEMENT = {
537     syms.if_stmt,
538     syms.while_stmt,
539     syms.for_stmt,
540     syms.try_stmt,
541     syms.except_clause,
542     syms.with_stmt,
543     syms.funcdef,
544     syms.classdef,
545 }
546 STANDALONE_COMMENT = 153
547 LOGIC_OPERATORS = {"and", "or"}
548 COMPARATORS = {
549     token.LESS,
550     token.GREATER,
551     token.EQEQUAL,
552     token.NOTEQUAL,
553     token.LESSEQUAL,
554     token.GREATEREQUAL,
555 }
556 MATH_OPERATORS = {
557     token.PLUS,
558     token.MINUS,
559     token.STAR,
560     token.SLASH,
561     token.VBAR,
562     token.AMPER,
563     token.PERCENT,
564     token.CIRCUMFLEX,
565     token.TILDE,
566     token.LEFTSHIFT,
567     token.RIGHTSHIFT,
568     token.DOUBLESTAR,
569     token.DOUBLESLASH,
570 }
571 STARS = {token.STAR, token.DOUBLESTAR}
572 VARARGS_PARENTS = {
573     syms.arglist,
574     syms.argument,  # double star in arglist
575     syms.trailer,  # single argument to call
576     syms.typedargslist,
577     syms.varargslist,  # lambdas
578 }
579 UNPACKING_PARENTS = {
580     syms.atom,  # single element of a list or set literal
581     syms.dictsetmaker,
582     syms.listmaker,
583     syms.testlist_gexp,
584 }
585 COMPREHENSION_PRIORITY = 20
586 COMMA_PRIORITY = 10
587 TERNARY_PRIORITY = 7
588 LOGIC_PRIORITY = 5
589 STRING_PRIORITY = 4
590 COMPARATOR_PRIORITY = 3
591 MATH_PRIORITY = 1
592
593
594 @dataclass
595 class BracketTracker:
596     """Keeps track of brackets on a line."""
597
598     depth: int = 0
599     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
600     delimiters: Dict[LeafID, Priority] = Factory(dict)
601     previous: Optional[Leaf] = None
602     _for_loop_variable: bool = False
603     _lambda_arguments: bool = False
604
605     def mark(self, leaf: Leaf) -> None:
606         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
607
608         All leaves receive an int `bracket_depth` field that stores how deep
609         within brackets a given leaf is. 0 means there are no enclosing brackets
610         that started on this line.
611
612         If a leaf is itself a closing bracket, it receives an `opening_bracket`
613         field that it forms a pair with. This is a one-directional link to
614         avoid reference cycles.
615
616         If a leaf is a delimiter (a token on which Black can split the line if
617         needed) and it's on depth 0, its `id()` is stored in the tracker's
618         `delimiters` field.
619         """
620         if leaf.type == token.COMMENT:
621             return
622
623         self.maybe_decrement_after_for_loop_variable(leaf)
624         self.maybe_decrement_after_lambda_arguments(leaf)
625         if leaf.type in CLOSING_BRACKETS:
626             self.depth -= 1
627             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
628             leaf.opening_bracket = opening_bracket
629         leaf.bracket_depth = self.depth
630         if self.depth == 0:
631             delim = is_split_before_delimiter(leaf, self.previous)
632             if delim and self.previous is not None:
633                 self.delimiters[id(self.previous)] = delim
634             else:
635                 delim = is_split_after_delimiter(leaf, self.previous)
636                 if delim:
637                     self.delimiters[id(leaf)] = delim
638         if leaf.type in OPENING_BRACKETS:
639             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
640             self.depth += 1
641         self.previous = leaf
642         self.maybe_increment_lambda_arguments(leaf)
643         self.maybe_increment_for_loop_variable(leaf)
644
645     def any_open_brackets(self) -> bool:
646         """Return True if there is an yet unmatched open bracket on the line."""
647         return bool(self.bracket_match)
648
649     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
650         """Return the highest priority of a delimiter found on the line.
651
652         Values are consistent with what `is_split_*_delimiter()` return.
653         Raises ValueError on no delimiters.
654         """
655         return max(v for k, v in self.delimiters.items() if k not in exclude)
656
657     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
658         """In a for loop, or comprehension, the variables are often unpacks.
659
660         To avoid splitting on the comma in this situation, increase the depth of
661         tokens between `for` and `in`.
662         """
663         if leaf.type == token.NAME and leaf.value == "for":
664             self.depth += 1
665             self._for_loop_variable = True
666             return True
667
668         return False
669
670     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
671         """See `maybe_increment_for_loop_variable` above for explanation."""
672         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
673             self.depth -= 1
674             self._for_loop_variable = False
675             return True
676
677         return False
678
679     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
680         """In a lambda expression, there might be more than one argument.
681
682         To avoid splitting on the comma in this situation, increase the depth of
683         tokens between `lambda` and `:`.
684         """
685         if leaf.type == token.NAME and leaf.value == "lambda":
686             self.depth += 1
687             self._lambda_arguments = True
688             return True
689
690         return False
691
692     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
693         """See `maybe_increment_lambda_arguments` above for explanation."""
694         if self._lambda_arguments and leaf.type == token.COLON:
695             self.depth -= 1
696             self._lambda_arguments = False
697             return True
698
699         return False
700
701
702 @dataclass
703 class Line:
704     """Holds leaves and comments. Can be printed with `str(line)`."""
705
706     depth: int = 0
707     leaves: List[Leaf] = Factory(list)
708     comments: List[Tuple[Index, Leaf]] = Factory(list)
709     bracket_tracker: BracketTracker = Factory(BracketTracker)
710     inside_brackets: bool = False
711
712     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
713         """Add a new `leaf` to the end of the line.
714
715         Unless `preformatted` is True, the `leaf` will receive a new consistent
716         whitespace prefix and metadata applied by :class:`BracketTracker`.
717         Trailing commas are maybe removed, unpacked for loop variables are
718         demoted from being delimiters.
719
720         Inline comments are put aside.
721         """
722         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
723         if not has_value:
724             return
725
726         if self.leaves and not preformatted:
727             # Note: at this point leaf.prefix should be empty except for
728             # imports, for which we only preserve newlines.
729             leaf.prefix += whitespace(leaf)
730         if self.inside_brackets or not preformatted:
731             self.bracket_tracker.mark(leaf)
732             self.maybe_remove_trailing_comma(leaf)
733
734         if not self.append_comment(leaf):
735             self.leaves.append(leaf)
736
737     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
738         """Like :func:`append()` but disallow invalid standalone comment structure.
739
740         Raises ValueError when any `leaf` is appended after a standalone comment
741         or when a standalone comment is not the first leaf on the line.
742         """
743         if self.bracket_tracker.depth == 0:
744             if self.is_comment:
745                 raise ValueError("cannot append to standalone comments")
746
747             if self.leaves and leaf.type == STANDALONE_COMMENT:
748                 raise ValueError(
749                     "cannot append standalone comments to a populated line"
750                 )
751
752         self.append(leaf, preformatted=preformatted)
753
754     @property
755     def is_comment(self) -> bool:
756         """Is this line a standalone comment?"""
757         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
758
759     @property
760     def is_decorator(self) -> bool:
761         """Is this line a decorator?"""
762         return bool(self) and self.leaves[0].type == token.AT
763
764     @property
765     def is_import(self) -> bool:
766         """Is this an import line?"""
767         return bool(self) and is_import(self.leaves[0])
768
769     @property
770     def is_class(self) -> bool:
771         """Is this line a class definition?"""
772         return (
773             bool(self)
774             and self.leaves[0].type == token.NAME
775             and self.leaves[0].value == "class"
776         )
777
778     @property
779     def is_def(self) -> bool:
780         """Is this a function definition? (Also returns True for async defs.)"""
781         try:
782             first_leaf = self.leaves[0]
783         except IndexError:
784             return False
785
786         try:
787             second_leaf: Optional[Leaf] = self.leaves[1]
788         except IndexError:
789             second_leaf = None
790         return (
791             (first_leaf.type == token.NAME and first_leaf.value == "def")
792             or (
793                 first_leaf.type == token.ASYNC
794                 and second_leaf is not None
795                 and second_leaf.type == token.NAME
796                 and second_leaf.value == "def"
797             )
798         )
799
800     @property
801     def is_flow_control(self) -> bool:
802         """Is this line a flow control statement?
803
804         Those are `return`, `raise`, `break`, and `continue`.
805         """
806         return (
807             bool(self)
808             and self.leaves[0].type == token.NAME
809             and self.leaves[0].value in FLOW_CONTROL
810         )
811
812     @property
813     def is_yield(self) -> bool:
814         """Is this line a yield statement?"""
815         return (
816             bool(self)
817             and self.leaves[0].type == token.NAME
818             and self.leaves[0].value == "yield"
819         )
820
821     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
822         """If so, needs to be split before emitting."""
823         for leaf in self.leaves:
824             if leaf.type == STANDALONE_COMMENT:
825                 if leaf.bracket_depth <= depth_limit:
826                     return True
827
828         return False
829
830     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
831         """Remove trailing comma if there is one and it's safe."""
832         if not (
833             self.leaves
834             and self.leaves[-1].type == token.COMMA
835             and closing.type in CLOSING_BRACKETS
836         ):
837             return False
838
839         if closing.type == token.RBRACE:
840             self.remove_trailing_comma()
841             return True
842
843         if closing.type == token.RSQB:
844             comma = self.leaves[-1]
845             if comma.parent and comma.parent.type == syms.listmaker:
846                 self.remove_trailing_comma()
847                 return True
848
849         # For parens let's check if it's safe to remove the comma.  If the
850         # trailing one is the only one, we might mistakenly change a tuple
851         # into a different type by removing the comma.
852         depth = closing.bracket_depth + 1
853         commas = 0
854         opening = closing.opening_bracket
855         for _opening_index, leaf in enumerate(self.leaves):
856             if leaf is opening:
857                 break
858
859         else:
860             return False
861
862         for leaf in self.leaves[_opening_index + 1:]:
863             if leaf is closing:
864                 break
865
866             bracket_depth = leaf.bracket_depth
867             if bracket_depth == depth and leaf.type == token.COMMA:
868                 commas += 1
869                 if leaf.parent and leaf.parent.type == syms.arglist:
870                     commas += 1
871                     break
872
873         if commas > 1:
874             self.remove_trailing_comma()
875             return True
876
877         return False
878
879     def append_comment(self, comment: Leaf) -> bool:
880         """Add an inline or standalone comment to the line."""
881         if (
882             comment.type == STANDALONE_COMMENT
883             and self.bracket_tracker.any_open_brackets()
884         ):
885             comment.prefix = ""
886             return False
887
888         if comment.type != token.COMMENT:
889             return False
890
891         after = len(self.leaves) - 1
892         if after == -1:
893             comment.type = STANDALONE_COMMENT
894             comment.prefix = ""
895             return False
896
897         else:
898             self.comments.append((after, comment))
899             return True
900
901     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
902         """Generate comments that should appear directly after `leaf`."""
903         for _leaf_index, _leaf in enumerate(self.leaves):
904             if leaf is _leaf:
905                 break
906
907         else:
908             return
909
910         for index, comment_after in self.comments:
911             if _leaf_index == index:
912                 yield comment_after
913
914     def remove_trailing_comma(self) -> None:
915         """Remove the trailing comma and moves the comments attached to it."""
916         comma_index = len(self.leaves) - 1
917         for i in range(len(self.comments)):
918             comment_index, comment = self.comments[i]
919             if comment_index == comma_index:
920                 self.comments[i] = (comma_index - 1, comment)
921         self.leaves.pop()
922
923     def __str__(self) -> str:
924         """Render the line."""
925         if not self:
926             return "\n"
927
928         indent = "    " * self.depth
929         leaves = iter(self.leaves)
930         first = next(leaves)
931         res = f"{first.prefix}{indent}{first.value}"
932         for leaf in leaves:
933             res += str(leaf)
934         for _, comment in self.comments:
935             res += str(comment)
936         return res + "\n"
937
938     def __bool__(self) -> bool:
939         """Return True if the line has leaves or comments."""
940         return bool(self.leaves or self.comments)
941
942
943 class UnformattedLines(Line):
944     """Just like :class:`Line` but stores lines which aren't reformatted."""
945
946     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
947         """Just add a new `leaf` to the end of the lines.
948
949         The `preformatted` argument is ignored.
950
951         Keeps track of indentation `depth`, which is useful when the user
952         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
953         """
954         try:
955             list(generate_comments(leaf))
956         except FormatOn as f_on:
957             self.leaves.append(f_on.leaf_from_consumed(leaf))
958             raise
959
960         self.leaves.append(leaf)
961         if leaf.type == token.INDENT:
962             self.depth += 1
963         elif leaf.type == token.DEDENT:
964             self.depth -= 1
965
966     def __str__(self) -> str:
967         """Render unformatted lines from leaves which were added with `append()`.
968
969         `depth` is not used for indentation in this case.
970         """
971         if not self:
972             return "\n"
973
974         res = ""
975         for leaf in self.leaves:
976             res += str(leaf)
977         return res
978
979     def append_comment(self, comment: Leaf) -> bool:
980         """Not implemented in this class. Raises `NotImplementedError`."""
981         raise NotImplementedError("Unformatted lines don't store comments separately.")
982
983     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
984         """Does nothing and returns False."""
985         return False
986
987     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
988         """Does nothing and returns False."""
989         return False
990
991
992 @dataclass
993 class EmptyLineTracker:
994     """Provides a stateful method that returns the number of potential extra
995     empty lines needed before and after the currently processed line.
996
997     Note: this tracker works on lines that haven't been split yet.  It assumes
998     the prefix of the first leaf consists of optional newlines.  Those newlines
999     are consumed by `maybe_empty_lines()` and included in the computation.
1000     """
1001     previous_line: Optional[Line] = None
1002     previous_after: int = 0
1003     previous_defs: List[int] = Factory(list)
1004
1005     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1006         """Return the number of extra empty lines before and after the `current_line`.
1007
1008         This is for separating `def`, `async def` and `class` with extra empty
1009         lines (two on module-level), as well as providing an extra empty line
1010         after flow control keywords to make them more prominent.
1011         """
1012         if isinstance(current_line, UnformattedLines):
1013             return 0, 0
1014
1015         before, after = self._maybe_empty_lines(current_line)
1016         before -= self.previous_after
1017         self.previous_after = after
1018         self.previous_line = current_line
1019         return before, after
1020
1021     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1022         max_allowed = 1
1023         if current_line.depth == 0:
1024             max_allowed = 2
1025         if current_line.leaves:
1026             # Consume the first leaf's extra newlines.
1027             first_leaf = current_line.leaves[0]
1028             before = first_leaf.prefix.count("\n")
1029             before = min(before, max_allowed)
1030             first_leaf.prefix = ""
1031         else:
1032             before = 0
1033         depth = current_line.depth
1034         while self.previous_defs and self.previous_defs[-1] >= depth:
1035             self.previous_defs.pop()
1036             before = 1 if depth else 2
1037         is_decorator = current_line.is_decorator
1038         if is_decorator or current_line.is_def or current_line.is_class:
1039             if not is_decorator:
1040                 self.previous_defs.append(depth)
1041             if self.previous_line is None:
1042                 # Don't insert empty lines before the first line in the file.
1043                 return 0, 0
1044
1045             if self.previous_line.is_decorator:
1046                 return 0, 0
1047
1048             if (
1049                 self.previous_line.is_comment
1050                 and self.previous_line.depth == current_line.depth
1051                 and before == 0
1052             ):
1053                 return 0, 0
1054
1055             newlines = 2
1056             if current_line.depth:
1057                 newlines -= 1
1058             return newlines, 0
1059
1060         if (
1061             self.previous_line
1062             and self.previous_line.is_import
1063             and not current_line.is_import
1064             and depth == self.previous_line.depth
1065         ):
1066             return (before or 1), 0
1067
1068         return before, 0
1069
1070
1071 @dataclass
1072 class LineGenerator(Visitor[Line]):
1073     """Generates reformatted Line objects.  Empty lines are not emitted.
1074
1075     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1076     in ways that will no longer stringify to valid Python code on the tree.
1077     """
1078     current_line: Line = Factory(Line)
1079
1080     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1081         """Generate a line.
1082
1083         If the line is empty, only emit if it makes sense.
1084         If the line is too long, split it first and then generate.
1085
1086         If any lines were generated, set up a new current_line.
1087         """
1088         if not self.current_line:
1089             if self.current_line.__class__ == type:
1090                 self.current_line.depth += indent
1091             else:
1092                 self.current_line = type(depth=self.current_line.depth + indent)
1093             return  # Line is empty, don't emit. Creating a new one unnecessary.
1094
1095         complete_line = self.current_line
1096         self.current_line = type(depth=complete_line.depth + indent)
1097         yield complete_line
1098
1099     def visit(self, node: LN) -> Iterator[Line]:
1100         """Main method to visit `node` and its children.
1101
1102         Yields :class:`Line` objects.
1103         """
1104         if isinstance(self.current_line, UnformattedLines):
1105             # File contained `# fmt: off`
1106             yield from self.visit_unformatted(node)
1107
1108         else:
1109             yield from super().visit(node)
1110
1111     def visit_default(self, node: LN) -> Iterator[Line]:
1112         """Default `visit_*()` implementation. Recurses to children of `node`."""
1113         if isinstance(node, Leaf):
1114             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1115             try:
1116                 for comment in generate_comments(node):
1117                     if any_open_brackets:
1118                         # any comment within brackets is subject to splitting
1119                         self.current_line.append(comment)
1120                     elif comment.type == token.COMMENT:
1121                         # regular trailing comment
1122                         self.current_line.append(comment)
1123                         yield from self.line()
1124
1125                     else:
1126                         # regular standalone comment
1127                         yield from self.line()
1128
1129                         self.current_line.append(comment)
1130                         yield from self.line()
1131
1132             except FormatOff as f_off:
1133                 f_off.trim_prefix(node)
1134                 yield from self.line(type=UnformattedLines)
1135                 yield from self.visit(node)
1136
1137             except FormatOn as f_on:
1138                 # This only happens here if somebody says "fmt: on" multiple
1139                 # times in a row.
1140                 f_on.trim_prefix(node)
1141                 yield from self.visit_default(node)
1142
1143             else:
1144                 normalize_prefix(node, inside_brackets=any_open_brackets)
1145                 if node.type == token.STRING:
1146                     normalize_string_quotes(node)
1147                 if node.type not in WHITESPACE:
1148                     self.current_line.append(node)
1149         yield from super().visit_default(node)
1150
1151     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1152         """Increase indentation level, maybe yield a line."""
1153         # In blib2to3 INDENT never holds comments.
1154         yield from self.line(+1)
1155         yield from self.visit_default(node)
1156
1157     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1158         """Decrease indentation level, maybe yield a line."""
1159         # The current line might still wait for trailing comments.  At DEDENT time
1160         # there won't be any (they would be prefixes on the preceding NEWLINE).
1161         # Emit the line then.
1162         yield from self.line()
1163
1164         # While DEDENT has no value, its prefix may contain standalone comments
1165         # that belong to the current indentation level.  Get 'em.
1166         yield from self.visit_default(node)
1167
1168         # Finally, emit the dedent.
1169         yield from self.line(-1)
1170
1171     def visit_stmt(
1172         self, node: Node, keywords: Set[str], parens: Set[str]
1173     ) -> Iterator[Line]:
1174         """Visit a statement.
1175
1176         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1177         `def`, `with`, `class`, and `assert`.
1178
1179         The relevant Python language `keywords` for a given statement will be
1180         NAME leaves within it. This methods puts those on a separate line.
1181
1182         `parens` holds pairs of nodes where invisible parentheses should be put.
1183         Keys hold nodes after which opening parentheses should be put, values
1184         hold nodes before which closing parentheses should be put.
1185         """
1186         normalize_invisible_parens(node, parens_after=parens)
1187         for child in node.children:
1188             if child.type == token.NAME and child.value in keywords:  # type: ignore
1189                 yield from self.line()
1190
1191             yield from self.visit(child)
1192
1193     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1194         """Visit a statement without nested statements."""
1195         is_suite_like = node.parent and node.parent.type in STATEMENT
1196         if is_suite_like:
1197             yield from self.line(+1)
1198             yield from self.visit_default(node)
1199             yield from self.line(-1)
1200
1201         else:
1202             yield from self.line()
1203             yield from self.visit_default(node)
1204
1205     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1206         """Visit `async def`, `async for`, `async with`."""
1207         yield from self.line()
1208
1209         children = iter(node.children)
1210         for child in children:
1211             yield from self.visit(child)
1212
1213             if child.type == token.ASYNC:
1214                 break
1215
1216         internal_stmt = next(children)
1217         for child in internal_stmt.children:
1218             yield from self.visit(child)
1219
1220     def visit_decorators(self, node: Node) -> Iterator[Line]:
1221         """Visit decorators."""
1222         for child in node.children:
1223             yield from self.line()
1224             yield from self.visit(child)
1225
1226     def visit_import_from(self, node: Node) -> Iterator[Line]:
1227         """Visit import_from and maybe put invisible parentheses.
1228
1229         This is separate from `visit_stmt` because import statements don't
1230         support arbitrary atoms and thus handling of parentheses is custom.
1231         """
1232         check_lpar = False
1233         for index, child in enumerate(node.children):
1234             if check_lpar:
1235                 if child.type == token.LPAR:
1236                     # make parentheses invisible
1237                     child.value = ""  # type: ignore
1238                     node.children[-1].value = ""  # type: ignore
1239                 else:
1240                     # insert invisible parentheses
1241                     node.insert_child(index, Leaf(token.LPAR, ""))
1242                     node.append_child(Leaf(token.RPAR, ""))
1243                 break
1244
1245             check_lpar = (
1246                 child.type == token.NAME and child.value == "import"  # type: ignore
1247             )
1248
1249         for child in node.children:
1250             yield from self.visit(child)
1251
1252     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1253         """Remove a semicolon and put the other statement on a separate line."""
1254         yield from self.line()
1255
1256     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1257         """End of file. Process outstanding comments and end with a newline."""
1258         yield from self.visit_default(leaf)
1259         yield from self.line()
1260
1261     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1262         """Used when file contained a `# fmt: off`."""
1263         if isinstance(node, Node):
1264             for child in node.children:
1265                 yield from self.visit(child)
1266
1267         else:
1268             try:
1269                 self.current_line.append(node)
1270             except FormatOn as f_on:
1271                 f_on.trim_prefix(node)
1272                 yield from self.line()
1273                 yield from self.visit(node)
1274
1275             if node.type == token.ENDMARKER:
1276                 # somebody decided not to put a final `# fmt: on`
1277                 yield from self.line()
1278
1279     def __attrs_post_init__(self) -> None:
1280         """You are in a twisty little maze of passages."""
1281         v = self.visit_stmt
1282         Ø: Set[str] = set()
1283         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1284         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1285         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1286         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1287         self.visit_try_stmt = partial(
1288             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1289         )
1290         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1291         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1292         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1293         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1294         self.visit_async_funcdef = self.visit_async_stmt
1295         self.visit_decorated = self.visit_decorators
1296
1297
1298 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1299 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1300 OPENING_BRACKETS = set(BRACKET.keys())
1301 CLOSING_BRACKETS = set(BRACKET.values())
1302 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1303 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1304
1305
1306 def whitespace(leaf: Leaf) -> str:  # noqa C901
1307     """Return whitespace prefix if needed for the given `leaf`."""
1308     NO = ""
1309     SPACE = " "
1310     DOUBLESPACE = "  "
1311     t = leaf.type
1312     p = leaf.parent
1313     v = leaf.value
1314     if t in ALWAYS_NO_SPACE:
1315         return NO
1316
1317     if t == token.COMMENT:
1318         return DOUBLESPACE
1319
1320     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1321     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1322         return NO
1323
1324     prev = leaf.prev_sibling
1325     if not prev:
1326         prevp = preceding_leaf(p)
1327         if not prevp or prevp.type in OPENING_BRACKETS:
1328             return NO
1329
1330         if t == token.COLON:
1331             return SPACE if prevp.type == token.COMMA else NO
1332
1333         if prevp.type == token.EQUAL:
1334             if prevp.parent:
1335                 if prevp.parent.type in {
1336                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1337                 }:
1338                     return NO
1339
1340                 elif prevp.parent.type == syms.typedargslist:
1341                     # A bit hacky: if the equal sign has whitespace, it means we
1342                     # previously found it's a typed argument.  So, we're using
1343                     # that, too.
1344                     return prevp.prefix
1345
1346         elif prevp.type in STARS:
1347             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1348                 return NO
1349
1350         elif prevp.type == token.COLON:
1351             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1352                 return NO
1353
1354         elif (
1355             prevp.parent
1356             and prevp.parent.type == syms.factor
1357             and prevp.type in MATH_OPERATORS
1358         ):
1359             return NO
1360
1361         elif (
1362             prevp.type == token.RIGHTSHIFT
1363             and prevp.parent
1364             and prevp.parent.type == syms.shift_expr
1365             and prevp.prev_sibling
1366             and prevp.prev_sibling.type == token.NAME
1367             and prevp.prev_sibling.value == "print"  # type: ignore
1368         ):
1369             # Python 2 print chevron
1370             return NO
1371
1372     elif prev.type in OPENING_BRACKETS:
1373         return NO
1374
1375     if p.type in {syms.parameters, syms.arglist}:
1376         # untyped function signatures or calls
1377         if not prev or prev.type != token.COMMA:
1378             return NO
1379
1380     elif p.type == syms.varargslist:
1381         # lambdas
1382         if prev and prev.type != token.COMMA:
1383             return NO
1384
1385     elif p.type == syms.typedargslist:
1386         # typed function signatures
1387         if not prev:
1388             return NO
1389
1390         if t == token.EQUAL:
1391             if prev.type != syms.tname:
1392                 return NO
1393
1394         elif prev.type == token.EQUAL:
1395             # A bit hacky: if the equal sign has whitespace, it means we
1396             # previously found it's a typed argument.  So, we're using that, too.
1397             return prev.prefix
1398
1399         elif prev.type != token.COMMA:
1400             return NO
1401
1402     elif p.type == syms.tname:
1403         # type names
1404         if not prev:
1405             prevp = preceding_leaf(p)
1406             if not prevp or prevp.type != token.COMMA:
1407                 return NO
1408
1409     elif p.type == syms.trailer:
1410         # attributes and calls
1411         if t == token.LPAR or t == token.RPAR:
1412             return NO
1413
1414         if not prev:
1415             if t == token.DOT:
1416                 prevp = preceding_leaf(p)
1417                 if not prevp or prevp.type != token.NUMBER:
1418                     return NO
1419
1420             elif t == token.LSQB:
1421                 return NO
1422
1423         elif prev.type != token.COMMA:
1424             return NO
1425
1426     elif p.type == syms.argument:
1427         # single argument
1428         if t == token.EQUAL:
1429             return NO
1430
1431         if not prev:
1432             prevp = preceding_leaf(p)
1433             if not prevp or prevp.type == token.LPAR:
1434                 return NO
1435
1436         elif prev.type in {token.EQUAL} | STARS:
1437             return NO
1438
1439     elif p.type == syms.decorator:
1440         # decorators
1441         return NO
1442
1443     elif p.type == syms.dotted_name:
1444         if prev:
1445             return NO
1446
1447         prevp = preceding_leaf(p)
1448         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1449             return NO
1450
1451     elif p.type == syms.classdef:
1452         if t == token.LPAR:
1453             return NO
1454
1455         if prev and prev.type == token.LPAR:
1456             return NO
1457
1458     elif p.type == syms.subscript:
1459         # indexing
1460         if not prev:
1461             assert p.parent is not None, "subscripts are always parented"
1462             if p.parent.type == syms.subscriptlist:
1463                 return SPACE
1464
1465             return NO
1466
1467         else:
1468             return NO
1469
1470     elif p.type == syms.atom:
1471         if prev and t == token.DOT:
1472             # dots, but not the first one.
1473             return NO
1474
1475     elif p.type == syms.dictsetmaker:
1476         # dict unpacking
1477         if prev and prev.type == token.DOUBLESTAR:
1478             return NO
1479
1480     elif p.type in {syms.factor, syms.star_expr}:
1481         # unary ops
1482         if not prev:
1483             prevp = preceding_leaf(p)
1484             if not prevp or prevp.type in OPENING_BRACKETS:
1485                 return NO
1486
1487             prevp_parent = prevp.parent
1488             assert prevp_parent is not None
1489             if (
1490                 prevp.type == token.COLON
1491                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1492             ):
1493                 return NO
1494
1495             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1496                 return NO
1497
1498         elif t == token.NAME or t == token.NUMBER:
1499             return NO
1500
1501     elif p.type == syms.import_from:
1502         if t == token.DOT:
1503             if prev and prev.type == token.DOT:
1504                 return NO
1505
1506         elif t == token.NAME:
1507             if v == "import":
1508                 return SPACE
1509
1510             if prev and prev.type == token.DOT:
1511                 return NO
1512
1513     elif p.type == syms.sliceop:
1514         return NO
1515
1516     return SPACE
1517
1518
1519 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1520     """Return the first leaf that precedes `node`, if any."""
1521     while node:
1522         res = node.prev_sibling
1523         if res:
1524             if isinstance(res, Leaf):
1525                 return res
1526
1527             try:
1528                 return list(res.leaves())[-1]
1529
1530             except IndexError:
1531                 return None
1532
1533         node = node.parent
1534     return None
1535
1536
1537 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1538     """Return the priority of the `leaf` delimiter, given a line break after it.
1539
1540     The delimiter priorities returned here are from those delimiters that would
1541     cause a line break after themselves.
1542
1543     Higher numbers are higher priority.
1544     """
1545     if leaf.type == token.COMMA:
1546         return COMMA_PRIORITY
1547
1548     return 0
1549
1550
1551 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1552     """Return the priority of the `leaf` delimiter, given a line before after it.
1553
1554     The delimiter priorities returned here are from those delimiters that would
1555     cause a line break before themselves.
1556
1557     Higher numbers are higher priority.
1558     """
1559     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1560         # * and ** might also be MATH_OPERATORS but in this case they are not.
1561         # Don't treat them as a delimiter.
1562         return 0
1563
1564     if (
1565         leaf.type in MATH_OPERATORS
1566         and leaf.parent
1567         and leaf.parent.type not in {syms.factor, syms.star_expr}
1568     ):
1569         return MATH_PRIORITY
1570
1571     if leaf.type in COMPARATORS:
1572         return COMPARATOR_PRIORITY
1573
1574     if (
1575         leaf.type == token.STRING
1576         and previous is not None
1577         and previous.type == token.STRING
1578     ):
1579         return STRING_PRIORITY
1580
1581     if (
1582         leaf.type == token.NAME
1583         and leaf.value == "for"
1584         and leaf.parent
1585         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1586     ):
1587         return COMPREHENSION_PRIORITY
1588
1589     if (
1590         leaf.type == token.NAME
1591         and leaf.value == "if"
1592         and leaf.parent
1593         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1594     ):
1595         return COMPREHENSION_PRIORITY
1596
1597     if (
1598         leaf.type == token.NAME
1599         and leaf.value in {"if", "else"}
1600         and leaf.parent
1601         and leaf.parent.type == syms.test
1602     ):
1603         return TERNARY_PRIORITY
1604
1605     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1606         return LOGIC_PRIORITY
1607
1608     return 0
1609
1610
1611 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1612     """Clean the prefix of the `leaf` and generate comments from it, if any.
1613
1614     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1615     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1616     move because it does away with modifying the grammar to include all the
1617     possible places in which comments can be placed.
1618
1619     The sad consequence for us though is that comments don't "belong" anywhere.
1620     This is why this function generates simple parentless Leaf objects for
1621     comments.  We simply don't know what the correct parent should be.
1622
1623     No matter though, we can live without this.  We really only need to
1624     differentiate between inline and standalone comments.  The latter don't
1625     share the line with any code.
1626
1627     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1628     are emitted with a fake STANDALONE_COMMENT token identifier.
1629     """
1630     p = leaf.prefix
1631     if not p:
1632         return
1633
1634     if "#" not in p:
1635         return
1636
1637     consumed = 0
1638     nlines = 0
1639     for index, line in enumerate(p.split("\n")):
1640         consumed += len(line) + 1  # adding the length of the split '\n'
1641         line = line.lstrip()
1642         if not line:
1643             nlines += 1
1644         if not line.startswith("#"):
1645             continue
1646
1647         if index == 0 and leaf.type != token.ENDMARKER:
1648             comment_type = token.COMMENT  # simple trailing comment
1649         else:
1650             comment_type = STANDALONE_COMMENT
1651         comment = make_comment(line)
1652         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1653
1654         if comment in {"# fmt: on", "# yapf: enable"}:
1655             raise FormatOn(consumed)
1656
1657         if comment in {"# fmt: off", "# yapf: disable"}:
1658             if comment_type == STANDALONE_COMMENT:
1659                 raise FormatOff(consumed)
1660
1661             prev = preceding_leaf(leaf)
1662             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1663                 raise FormatOff(consumed)
1664
1665         nlines = 0
1666
1667
1668 def make_comment(content: str) -> str:
1669     """Return a consistently formatted comment from the given `content` string.
1670
1671     All comments (except for "##", "#!", "#:") should have a single space between
1672     the hash sign and the content.
1673
1674     If `content` didn't start with a hash sign, one is provided.
1675     """
1676     content = content.rstrip()
1677     if not content:
1678         return "#"
1679
1680     if content[0] == "#":
1681         content = content[1:]
1682     if content and content[0] not in " !:#":
1683         content = " " + content
1684     return "#" + content
1685
1686
1687 def split_line(
1688     line: Line, line_length: int, inner: bool = False, py36: bool = False
1689 ) -> Iterator[Line]:
1690     """Split a `line` into potentially many lines.
1691
1692     They should fit in the allotted `line_length` but might not be able to.
1693     `inner` signifies that there were a pair of brackets somewhere around the
1694     current `line`, possibly transitively. This means we can fallback to splitting
1695     by delimiters if the LHS/RHS don't yield any results.
1696
1697     If `py36` is True, splitting may generate syntax that is only compatible
1698     with Python 3.6 and later.
1699     """
1700     if isinstance(line, UnformattedLines) or line.is_comment:
1701         yield line
1702         return
1703
1704     line_str = str(line).strip("\n")
1705     if (
1706         len(line_str) <= line_length
1707         and "\n" not in line_str  # multiline strings
1708         and not line.contains_standalone_comments()
1709     ):
1710         yield line
1711         return
1712
1713     split_funcs: List[SplitFunc]
1714     if line.is_def:
1715         split_funcs = [left_hand_split]
1716     elif line.is_import:
1717         split_funcs = [explode_split]
1718     elif line.inside_brackets:
1719         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1720     else:
1721         split_funcs = [right_hand_split]
1722     for split_func in split_funcs:
1723         # We are accumulating lines in `result` because we might want to abort
1724         # mission and return the original line in the end, or attempt a different
1725         # split altogether.
1726         result: List[Line] = []
1727         try:
1728             for l in split_func(line, py36):
1729                 if str(l).strip("\n") == line_str:
1730                     raise CannotSplit("Split function returned an unchanged result")
1731
1732                 result.extend(
1733                     split_line(l, line_length=line_length, inner=True, py36=py36)
1734                 )
1735         except CannotSplit as cs:
1736             continue
1737
1738         else:
1739             yield from result
1740             break
1741
1742     else:
1743         yield line
1744
1745
1746 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1747     """Split line into many lines, starting with the first matching bracket pair.
1748
1749     Note: this usually looks weird, only use this for function definitions.
1750     Prefer RHS otherwise.
1751     """
1752     head = Line(depth=line.depth)
1753     body = Line(depth=line.depth + 1, inside_brackets=True)
1754     tail = Line(depth=line.depth)
1755     tail_leaves: List[Leaf] = []
1756     body_leaves: List[Leaf] = []
1757     head_leaves: List[Leaf] = []
1758     current_leaves = head_leaves
1759     matching_bracket = None
1760     for leaf in line.leaves:
1761         if (
1762             current_leaves is body_leaves
1763             and leaf.type in CLOSING_BRACKETS
1764             and leaf.opening_bracket is matching_bracket
1765         ):
1766             current_leaves = tail_leaves if body_leaves else head_leaves
1767         current_leaves.append(leaf)
1768         if current_leaves is head_leaves:
1769             if leaf.type in OPENING_BRACKETS:
1770                 matching_bracket = leaf
1771                 current_leaves = body_leaves
1772     # Since body is a new indent level, remove spurious leading whitespace.
1773     if body_leaves:
1774         normalize_prefix(body_leaves[0], inside_brackets=True)
1775     # Build the new lines.
1776     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1777         for leaf in leaves:
1778             result.append(leaf, preformatted=True)
1779             for comment_after in line.comments_after(leaf):
1780                 result.append(comment_after, preformatted=True)
1781     bracket_split_succeeded_or_raise(head, body, tail)
1782     for result in (head, body, tail):
1783         if result:
1784             yield result
1785
1786
1787 def right_hand_split(
1788     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1789 ) -> Iterator[Line]:
1790     """Split line into many lines, starting with the last matching bracket pair."""
1791     head = Line(depth=line.depth)
1792     body = Line(depth=line.depth + 1, inside_brackets=True)
1793     tail = Line(depth=line.depth)
1794     tail_leaves: List[Leaf] = []
1795     body_leaves: List[Leaf] = []
1796     head_leaves: List[Leaf] = []
1797     current_leaves = tail_leaves
1798     opening_bracket = None
1799     closing_bracket = None
1800     for leaf in reversed(line.leaves):
1801         if current_leaves is body_leaves:
1802             if leaf is opening_bracket:
1803                 current_leaves = head_leaves if body_leaves else tail_leaves
1804         current_leaves.append(leaf)
1805         if current_leaves is tail_leaves:
1806             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1807                 opening_bracket = leaf.opening_bracket
1808                 closing_bracket = leaf
1809                 current_leaves = body_leaves
1810     tail_leaves.reverse()
1811     body_leaves.reverse()
1812     head_leaves.reverse()
1813     # Since body is a new indent level, remove spurious leading whitespace.
1814     if body_leaves:
1815         normalize_prefix(body_leaves[0], inside_brackets=True)
1816     elif not head_leaves:
1817         # No `head` and no `body` means the split failed. `tail` has all content.
1818         raise CannotSplit("No brackets found")
1819
1820     # Build the new lines.
1821     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1822         for leaf in leaves:
1823             result.append(leaf, preformatted=True)
1824             for comment_after in line.comments_after(leaf):
1825                 result.append(comment_after, preformatted=True)
1826     bracket_split_succeeded_or_raise(head, body, tail)
1827     assert opening_bracket and closing_bracket
1828     if (
1829         opening_bracket.type == token.LPAR
1830         and not opening_bracket.value
1831         and closing_bracket.type == token.RPAR
1832         and not closing_bracket.value
1833     ):
1834         # These parens were optional. If there aren't any delimiters or standalone
1835         # comments in the body, they were unnecessary and another split without
1836         # them should be attempted.
1837         if not (
1838             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1839         ):
1840             omit = {id(closing_bracket), *omit}
1841             yield from right_hand_split(line, py36=py36, omit=omit)
1842             return
1843
1844     ensure_visible(opening_bracket)
1845     ensure_visible(closing_bracket)
1846     for result in (head, body, tail):
1847         if result:
1848             yield result
1849
1850
1851 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1852     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1853
1854     Do nothing otherwise.
1855
1856     A left- or right-hand split is based on a pair of brackets. Content before
1857     (and including) the opening bracket is left on one line, content inside the
1858     brackets is put on a separate line, and finally content starting with and
1859     following the closing bracket is put on a separate line.
1860
1861     Those are called `head`, `body`, and `tail`, respectively. If the split
1862     produced the same line (all content in `head`) or ended up with an empty `body`
1863     and the `tail` is just the closing bracket, then it's considered failed.
1864     """
1865     tail_len = len(str(tail).strip())
1866     if not body:
1867         if tail_len == 0:
1868             raise CannotSplit("Splitting brackets produced the same line")
1869
1870         elif tail_len < 3:
1871             raise CannotSplit(
1872                 f"Splitting brackets on an empty body to save "
1873                 f"{tail_len} characters is not worth it"
1874             )
1875
1876
1877 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1878     """Normalize prefix of the first leaf in every line returned by `split_func`.
1879
1880     This is a decorator over relevant split functions.
1881     """
1882
1883     @wraps(split_func)
1884     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1885         for l in split_func(line, py36):
1886             normalize_prefix(l.leaves[0], inside_brackets=True)
1887             yield l
1888
1889     return split_wrapper
1890
1891
1892 @dont_increase_indentation
1893 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1894     """Split according to delimiters of the highest priority.
1895
1896     If `py36` is True, the split will add trailing commas also in function
1897     signatures that contain `*` and `**`.
1898     """
1899     try:
1900         last_leaf = line.leaves[-1]
1901     except IndexError:
1902         raise CannotSplit("Line empty")
1903
1904     delimiters = line.bracket_tracker.delimiters
1905     try:
1906         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1907             exclude={id(last_leaf)}
1908         )
1909     except ValueError:
1910         raise CannotSplit("No delimiters found")
1911
1912     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1913     lowest_depth = sys.maxsize
1914     trailing_comma_safe = True
1915
1916     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1917         """Append `leaf` to current line or to new line if appending impossible."""
1918         nonlocal current_line
1919         try:
1920             current_line.append_safe(leaf, preformatted=True)
1921         except ValueError as ve:
1922             yield current_line
1923
1924             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1925             current_line.append(leaf)
1926
1927     for leaf in line.leaves:
1928         yield from append_to_line(leaf)
1929
1930         for comment_after in line.comments_after(leaf):
1931             yield from append_to_line(comment_after)
1932
1933         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1934         if (
1935             leaf.bracket_depth == lowest_depth
1936             and is_vararg(leaf, within=VARARGS_PARENTS)
1937         ):
1938             trailing_comma_safe = trailing_comma_safe and py36
1939         leaf_priority = delimiters.get(id(leaf))
1940         if leaf_priority == delimiter_priority:
1941             yield current_line
1942
1943             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1944     if current_line:
1945         if (
1946             trailing_comma_safe
1947             and delimiter_priority == COMMA_PRIORITY
1948             and current_line.leaves[-1].type != token.COMMA
1949             and current_line.leaves[-1].type != STANDALONE_COMMENT
1950         ):
1951             current_line.append(Leaf(token.COMMA, ","))
1952         yield current_line
1953
1954
1955 @dont_increase_indentation
1956 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1957     """Split standalone comments from the rest of the line."""
1958     if not line.contains_standalone_comments(0):
1959         raise CannotSplit("Line does not have any standalone comments")
1960
1961     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1962
1963     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1964         """Append `leaf` to current line or to new line if appending impossible."""
1965         nonlocal current_line
1966         try:
1967             current_line.append_safe(leaf, preformatted=True)
1968         except ValueError as ve:
1969             yield current_line
1970
1971             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1972             current_line.append(leaf)
1973
1974     for leaf in line.leaves:
1975         yield from append_to_line(leaf)
1976
1977         for comment_after in line.comments_after(leaf):
1978             yield from append_to_line(comment_after)
1979
1980     if current_line:
1981         yield current_line
1982
1983
1984 def explode_split(
1985     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1986 ) -> Iterator[Line]:
1987     """Split by rightmost bracket and immediately split contents by a delimiter."""
1988     new_lines = list(right_hand_split(line, py36, omit))
1989     if len(new_lines) != 3:
1990         yield from new_lines
1991         return
1992
1993     yield new_lines[0]
1994
1995     try:
1996         yield from delimiter_split(new_lines[1], py36)
1997     except CannotSplit:
1998         yield new_lines[1]
1999
2000     yield new_lines[2]
2001
2002
2003 def is_import(leaf: Leaf) -> bool:
2004     """Return True if the given leaf starts an import statement."""
2005     p = leaf.parent
2006     t = leaf.type
2007     v = leaf.value
2008     return bool(
2009         t == token.NAME
2010         and (
2011             (v == "import" and p and p.type == syms.import_name)
2012             or (v == "from" and p and p.type == syms.import_from)
2013         )
2014     )
2015
2016
2017 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2018     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2019     else.
2020
2021     Note: don't use backslashes for formatting or you'll lose your voting rights.
2022     """
2023     if not inside_brackets:
2024         spl = leaf.prefix.split("#")
2025         if "\\" not in spl[0]:
2026             nl_count = spl[-1].count("\n")
2027             if len(spl) > 1:
2028                 nl_count -= 1
2029             leaf.prefix = "\n" * nl_count
2030             return
2031
2032     leaf.prefix = ""
2033
2034
2035 def normalize_string_quotes(leaf: Leaf) -> None:
2036     """Prefer double quotes but only if it doesn't cause more escaping.
2037
2038     Adds or removes backslashes as appropriate. Doesn't parse and fix
2039     strings nested in f-strings (yet).
2040
2041     Note: Mutates its argument.
2042     """
2043     value = leaf.value.lstrip("furbFURB")
2044     if value[:3] == '"""':
2045         return
2046
2047     elif value[:3] == "'''":
2048         orig_quote = "'''"
2049         new_quote = '"""'
2050     elif value[0] == '"':
2051         orig_quote = '"'
2052         new_quote = "'"
2053     else:
2054         orig_quote = "'"
2055         new_quote = '"'
2056     first_quote_pos = leaf.value.find(orig_quote)
2057     if first_quote_pos == -1:
2058         return  # There's an internal error
2059
2060     prefix = leaf.value[:first_quote_pos]
2061     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2062     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2063     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2064     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
2065     if "r" in prefix.casefold():
2066         if unescaped_new_quote.search(body):
2067             # There's at least one unescaped new_quote in this raw string
2068             # so converting is impossible
2069             return
2070
2071         # Do not introduce or remove backslashes in raw strings
2072         new_body = body
2073     else:
2074         # remove unnecessary quotes
2075         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2076         if body != new_body:
2077             # Consider the string without unnecessary quotes as the original
2078             body = new_body
2079             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2080         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2081         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2082     if new_quote == '"""' and new_body[-1] == '"':
2083         # edge case:
2084         new_body = new_body[:-1] + '\\"'
2085     orig_escape_count = body.count("\\")
2086     new_escape_count = new_body.count("\\")
2087     if new_escape_count > orig_escape_count:
2088         return  # Do not introduce more escaping
2089
2090     if new_escape_count == orig_escape_count and orig_quote == '"':
2091         return  # Prefer double quotes
2092
2093     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2094
2095
2096 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2097     """Make existing optional parentheses invisible or create new ones.
2098
2099     Standardizes on visible parentheses for single-element tuples, and keeps
2100     existing visible parentheses for other tuples and generator expressions.
2101     """
2102     check_lpar = False
2103     for child in list(node.children):
2104         if check_lpar:
2105             if child.type == syms.atom:
2106                 if not (
2107                     is_empty_tuple(child)
2108                     or is_one_tuple(child)
2109                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2110                 ):
2111                     first = child.children[0]
2112                     last = child.children[-1]
2113                     if first.type == token.LPAR and last.type == token.RPAR:
2114                         # make parentheses invisible
2115                         first.value = ""  # type: ignore
2116                         last.value = ""  # type: ignore
2117             elif is_one_tuple(child):
2118                 # wrap child in visible parentheses
2119                 lpar = Leaf(token.LPAR, "(")
2120                 rpar = Leaf(token.RPAR, ")")
2121                 index = child.remove() or 0
2122                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2123             else:
2124                 # wrap child in invisible parentheses
2125                 lpar = Leaf(token.LPAR, "")
2126                 rpar = Leaf(token.RPAR, "")
2127                 index = child.remove() or 0
2128                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2129
2130         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2131
2132
2133 def is_empty_tuple(node: LN) -> bool:
2134     """Return True if `node` holds an empty tuple."""
2135     return (
2136         node.type == syms.atom
2137         and len(node.children) == 2
2138         and node.children[0].type == token.LPAR
2139         and node.children[1].type == token.RPAR
2140     )
2141
2142
2143 def is_one_tuple(node: LN) -> bool:
2144     """Return True if `node` holds a tuple with one element, with or without parens."""
2145     if node.type == syms.atom:
2146         if len(node.children) != 3:
2147             return False
2148
2149         lpar, gexp, rpar = node.children
2150         if not (
2151             lpar.type == token.LPAR
2152             and gexp.type == syms.testlist_gexp
2153             and rpar.type == token.RPAR
2154         ):
2155             return False
2156
2157         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2158
2159     return (
2160         node.type in IMPLICIT_TUPLE
2161         and len(node.children) == 2
2162         and node.children[1].type == token.COMMA
2163     )
2164
2165
2166 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2167     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2168
2169     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2170     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2171     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2172     generalizations (PEP 448).
2173     """
2174     if leaf.type not in STARS or not leaf.parent:
2175         return False
2176
2177     p = leaf.parent
2178     if p.type == syms.star_expr:
2179         # Star expressions are also used as assignment targets in extended
2180         # iterable unpacking (PEP 3132).  See what its parent is instead.
2181         if not p.parent:
2182             return False
2183
2184         p = p.parent
2185
2186     return p.type in within
2187
2188
2189 def max_delimiter_priority_in_atom(node: LN) -> int:
2190     """Return maximum delimiter priority inside `node`.
2191
2192     This is specific to atoms with contents contained in a pair of parentheses.
2193     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2194     """
2195     if node.type != syms.atom:
2196         return 0
2197
2198     first = node.children[0]
2199     last = node.children[-1]
2200     if not (first.type == token.LPAR and last.type == token.RPAR):
2201         return 0
2202
2203     bt = BracketTracker()
2204     for c in node.children[1:-1]:
2205         if isinstance(c, Leaf):
2206             bt.mark(c)
2207         else:
2208             for leaf in c.leaves():
2209                 bt.mark(leaf)
2210     try:
2211         return bt.max_delimiter_priority()
2212
2213     except ValueError:
2214         return 0
2215
2216
2217 def ensure_visible(leaf: Leaf) -> None:
2218     """Make sure parentheses are visible.
2219
2220     They could be invisible as part of some statements (see
2221     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2222     """
2223     if leaf.type == token.LPAR:
2224         leaf.value = "("
2225     elif leaf.type == token.RPAR:
2226         leaf.value = ")"
2227
2228
2229 def is_python36(node: Node) -> bool:
2230     """Return True if the current file is using Python 3.6+ features.
2231
2232     Currently looking for:
2233     - f-strings; and
2234     - trailing commas after * or ** in function signatures.
2235     """
2236     for n in node.pre_order():
2237         if n.type == token.STRING:
2238             value_head = n.value[:2]  # type: ignore
2239             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2240                 return True
2241
2242         elif (
2243             n.type == syms.typedargslist
2244             and n.children
2245             and n.children[-1].type == token.COMMA
2246         ):
2247             for ch in n.children:
2248                 if ch.type in STARS:
2249                     return True
2250
2251     return False
2252
2253
2254 PYTHON_EXTENSIONS = {".py"}
2255 BLACKLISTED_DIRECTORIES = {
2256     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2257 }
2258
2259
2260 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2261     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2262     and have one of the PYTHON_EXTENSIONS.
2263     """
2264     for child in path.iterdir():
2265         if child.is_dir():
2266             if child.name in BLACKLISTED_DIRECTORIES:
2267                 continue
2268
2269             yield from gen_python_files_in_dir(child)
2270
2271         elif child.suffix in PYTHON_EXTENSIONS:
2272             yield child
2273
2274
2275 @dataclass
2276 class Report:
2277     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2278     check: bool = False
2279     quiet: bool = False
2280     change_count: int = 0
2281     same_count: int = 0
2282     failure_count: int = 0
2283
2284     def done(self, src: Path, changed: Changed) -> None:
2285         """Increment the counter for successful reformatting. Write out a message."""
2286         if changed is Changed.YES:
2287             reformatted = "would reformat" if self.check else "reformatted"
2288             if not self.quiet:
2289                 out(f"{reformatted} {src}")
2290             self.change_count += 1
2291         else:
2292             if not self.quiet:
2293                 if changed is Changed.NO:
2294                     msg = f"{src} already well formatted, good job."
2295                 else:
2296                     msg = f"{src} wasn't modified on disk since last run."
2297                 out(msg, bold=False)
2298             self.same_count += 1
2299
2300     def failed(self, src: Path, message: str) -> None:
2301         """Increment the counter for failed reformatting. Write out a message."""
2302         err(f"error: cannot format {src}: {message}")
2303         self.failure_count += 1
2304
2305     @property
2306     def return_code(self) -> int:
2307         """Return the exit code that the app should use.
2308
2309         This considers the current state of changed files and failures:
2310         - if there were any failures, return 123;
2311         - if any files were changed and --check is being used, return 1;
2312         - otherwise return 0.
2313         """
2314         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2315         # 126 we have special returncodes reserved by the shell.
2316         if self.failure_count:
2317             return 123
2318
2319         elif self.change_count and self.check:
2320             return 1
2321
2322         return 0
2323
2324     def __str__(self) -> str:
2325         """Render a color report of the current state.
2326
2327         Use `click.unstyle` to remove colors.
2328         """
2329         if self.check:
2330             reformatted = "would be reformatted"
2331             unchanged = "would be left unchanged"
2332             failed = "would fail to reformat"
2333         else:
2334             reformatted = "reformatted"
2335             unchanged = "left unchanged"
2336             failed = "failed to reformat"
2337         report = []
2338         if self.change_count:
2339             s = "s" if self.change_count > 1 else ""
2340             report.append(
2341                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2342             )
2343         if self.same_count:
2344             s = "s" if self.same_count > 1 else ""
2345             report.append(f"{self.same_count} file{s} {unchanged}")
2346         if self.failure_count:
2347             s = "s" if self.failure_count > 1 else ""
2348             report.append(
2349                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2350             )
2351         return ", ".join(report) + "."
2352
2353
2354 def assert_equivalent(src: str, dst: str) -> None:
2355     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2356
2357     import ast
2358     import traceback
2359
2360     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2361         """Simple visitor generating strings to compare ASTs by content."""
2362         yield f"{'  ' * depth}{node.__class__.__name__}("
2363
2364         for field in sorted(node._fields):
2365             try:
2366                 value = getattr(node, field)
2367             except AttributeError:
2368                 continue
2369
2370             yield f"{'  ' * (depth+1)}{field}="
2371
2372             if isinstance(value, list):
2373                 for item in value:
2374                     if isinstance(item, ast.AST):
2375                         yield from _v(item, depth + 2)
2376
2377             elif isinstance(value, ast.AST):
2378                 yield from _v(value, depth + 2)
2379
2380             else:
2381                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2382
2383         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2384
2385     try:
2386         src_ast = ast.parse(src)
2387     except Exception as exc:
2388         major, minor = sys.version_info[:2]
2389         raise AssertionError(
2390             f"cannot use --safe with this file; failed to parse source file "
2391             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2392             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2393         )
2394
2395     try:
2396         dst_ast = ast.parse(dst)
2397     except Exception as exc:
2398         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2399         raise AssertionError(
2400             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2401             f"Please report a bug on https://github.com/ambv/black/issues.  "
2402             f"This invalid output might be helpful: {log}"
2403         ) from None
2404
2405     src_ast_str = "\n".join(_v(src_ast))
2406     dst_ast_str = "\n".join(_v(dst_ast))
2407     if src_ast_str != dst_ast_str:
2408         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2409         raise AssertionError(
2410             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2411             f"the source.  "
2412             f"Please report a bug on https://github.com/ambv/black/issues.  "
2413             f"This diff might be helpful: {log}"
2414         ) from None
2415
2416
2417 def assert_stable(src: str, dst: str, line_length: int) -> None:
2418     """Raise AssertionError if `dst` reformats differently the second time."""
2419     newdst = format_str(dst, line_length=line_length)
2420     if dst != newdst:
2421         log = dump_to_file(
2422             diff(src, dst, "source", "first pass"),
2423             diff(dst, newdst, "first pass", "second pass"),
2424         )
2425         raise AssertionError(
2426             f"INTERNAL ERROR: Black produced different code on the second pass "
2427             f"of the formatter.  "
2428             f"Please report a bug on https://github.com/ambv/black/issues.  "
2429             f"This diff might be helpful: {log}"
2430         ) from None
2431
2432
2433 def dump_to_file(*output: str) -> str:
2434     """Dump `output` to a temporary file. Return path to the file."""
2435     import tempfile
2436
2437     with tempfile.NamedTemporaryFile(
2438         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2439     ) as f:
2440         for lines in output:
2441             f.write(lines)
2442             if lines and lines[-1] != "\n":
2443                 f.write("\n")
2444     return f.name
2445
2446
2447 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2448     """Return a unified diff string between strings `a` and `b`."""
2449     import difflib
2450
2451     a_lines = [line + "\n" for line in a.split("\n")]
2452     b_lines = [line + "\n" for line in b.split("\n")]
2453     return "".join(
2454         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2455     )
2456
2457
2458 def cancel(tasks: List[asyncio.Task]) -> None:
2459     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2460     err("Aborted!")
2461     for task in tasks:
2462         task.cancel()
2463
2464
2465 def shutdown(loop: BaseEventLoop) -> None:
2466     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2467     try:
2468         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2469         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2470         if not to_cancel:
2471             return
2472
2473         for task in to_cancel:
2474             task.cancel()
2475         loop.run_until_complete(
2476             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2477         )
2478     finally:
2479         # `concurrent.futures.Future` objects cannot be cancelled once they
2480         # are already running. There might be some when the `shutdown()` happened.
2481         # Silence their logger's spew about the event loop being closed.
2482         cf_logger = logging.getLogger("concurrent.futures")
2483         cf_logger.setLevel(logging.CRITICAL)
2484         loop.close()
2485
2486
2487 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2488     """Replace `regex` with `replacement` twice on `original`.
2489
2490     This is used by string normalization to perform replaces on
2491     overlapping matches.
2492     """
2493     return regex.sub(replacement, regex.sub(replacement, original))
2494
2495
2496 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2497
2498
2499 def get_cache_file(line_length: int) -> Path:
2500     return CACHE_DIR / f"cache.{line_length}.pickle"
2501
2502
2503 def read_cache(line_length: int) -> Cache:
2504     """Read the cache if it exists and is well formed.
2505
2506     If it is not well formed, the call to write_cache later should resolve the issue.
2507     """
2508     cache_file = get_cache_file(line_length)
2509     if not cache_file.exists():
2510         return {}
2511
2512     with cache_file.open("rb") as fobj:
2513         try:
2514             cache: Cache = pickle.load(fobj)
2515         except pickle.UnpicklingError:
2516             return {}
2517
2518     return cache
2519
2520
2521 def get_cache_info(path: Path) -> CacheInfo:
2522     """Return the information used to check if a file is already formatted or not."""
2523     stat = path.stat()
2524     return stat.st_mtime, stat.st_size
2525
2526
2527 def filter_cached(
2528     cache: Cache, sources: Iterable[Path]
2529 ) -> Tuple[List[Path], List[Path]]:
2530     """Split a list of paths into two.
2531
2532     The first list contains paths of files that modified on disk or are not in the
2533     cache. The other list contains paths to non-modified files.
2534     """
2535     todo, done = [], []
2536     for src in sources:
2537         src = src.resolve()
2538         if cache.get(src) != get_cache_info(src):
2539             todo.append(src)
2540         else:
2541             done.append(src)
2542     return todo, done
2543
2544
2545 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2546     """Update the cache file."""
2547     cache_file = get_cache_file(line_length)
2548     try:
2549         if not CACHE_DIR.exists():
2550             CACHE_DIR.mkdir(parents=True)
2551         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2552         with cache_file.open("wb") as fobj:
2553             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2554     except OSError:
2555         pass
2556
2557
2558 if __name__ == "__main__":
2559     main()