]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Allow standalone comments to close code blocks
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import pickle
5 from asyncio.base_events import BaseEventLoop
6 from concurrent.futures import Executor, ProcessPoolExecutor
7 from enum import Enum
8 from functools import partial, wraps
9 import keyword
10 import logging
11 from multiprocessing import Manager
12 import os
13 from pathlib import Path
14 import re
15 import tokenize
16 import signal
17 import sys
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46 __version__ = "18.4a2"
47 DEFAULT_LINE_LENGTH = 88
48 # types
49 syms = pygram.python_symbols
50 FileContent = str
51 Encoding = str
52 Depth = int
53 NodeType = int
54 LeafID = int
55 Priority = int
56 Index = int
57 LN = Union[Leaf, Node]
58 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
59 Timestamp = float
60 FileSize = int
61 CacheInfo = Tuple[Timestamp, FileSize]
62 Cache = Dict[Path, CacheInfo]
63 out = partial(click.secho, bold=True, err=True)
64 err = partial(click.secho, fg="red", err=True)
65
66
67 class NothingChanged(UserWarning):
68     """Raised by :func:`format_file` when reformatted code is the same as source."""
69
70
71 class CannotSplit(Exception):
72     """A readable split that fits the allotted line length is impossible.
73
74     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
75     :func:`delimiter_split`.
76     """
77
78
79 class FormatError(Exception):
80     """Base exception for `# fmt: on` and `# fmt: off` handling.
81
82     It holds the number of bytes of the prefix consumed before the format
83     control comment appeared.
84     """
85
86     def __init__(self, consumed: int) -> None:
87         super().__init__(consumed)
88         self.consumed = consumed
89
90     def trim_prefix(self, leaf: Leaf) -> None:
91         leaf.prefix = leaf.prefix[self.consumed:]
92
93     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
94         """Returns a new Leaf from the consumed part of the prefix."""
95         unformatted_prefix = leaf.prefix[:self.consumed]
96         return Leaf(token.NEWLINE, unformatted_prefix)
97
98
99 class FormatOn(FormatError):
100     """Found a comment like `# fmt: on` in the file."""
101
102
103 class FormatOff(FormatError):
104     """Found a comment like `# fmt: off` in the file."""
105
106
107 class WriteBack(Enum):
108     NO = 0
109     YES = 1
110     DIFF = 2
111
112
113 class Changed(Enum):
114     NO = 0
115     CACHED = 1
116     YES = 2
117
118
119 @click.command()
120 @click.option(
121     "-l",
122     "--line-length",
123     type=int,
124     default=DEFAULT_LINE_LENGTH,
125     help="How many character per line to allow.",
126     show_default=True,
127 )
128 @click.option(
129     "--check",
130     is_flag=True,
131     help=(
132         "Don't write the files back, just return the status.  Return code 0 "
133         "means nothing would change.  Return code 1 means some files would be "
134         "reformatted.  Return code 123 means there was an internal error."
135     ),
136 )
137 @click.option(
138     "--diff",
139     is_flag=True,
140     help="Don't write the files back, just output a diff for each file on stdout.",
141 )
142 @click.option(
143     "--fast/--safe",
144     is_flag=True,
145     help="If --fast given, skip temporary sanity checks. [default: --safe]",
146 )
147 @click.option(
148     "-q",
149     "--quiet",
150     is_flag=True,
151     help=(
152         "Don't emit non-error messages to stderr. Errors are still emitted, "
153         "silence those with 2>/dev/null."
154     ),
155 )
156 @click.version_option(version=__version__)
157 @click.argument(
158     "src",
159     nargs=-1,
160     type=click.Path(
161         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
162     ),
163 )
164 @click.pass_context
165 def main(
166     ctx: click.Context,
167     line_length: int,
168     check: bool,
169     diff: bool,
170     fast: bool,
171     quiet: bool,
172     src: List[str],
173 ) -> None:
174     """The uncompromising code formatter."""
175     sources: List[Path] = []
176     for s in src:
177         p = Path(s)
178         if p.is_dir():
179             sources.extend(gen_python_files_in_dir(p))
180         elif p.is_file():
181             # if a file was explicitly given, we don't care about its extension
182             sources.append(p)
183         elif s == "-":
184             sources.append(Path("-"))
185         else:
186             err(f"invalid path: {s}")
187
188     if check and not diff:
189         write_back = WriteBack.NO
190     elif diff:
191         write_back = WriteBack.DIFF
192     else:
193         write_back = WriteBack.YES
194     report = Report(check=check, quiet=quiet)
195     if len(sources) == 0:
196         ctx.exit(0)
197         return
198
199     elif len(sources) == 1:
200         reformat_one(sources[0], line_length, fast, write_back, report)
201     else:
202         loop = asyncio.get_event_loop()
203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
204         try:
205             loop.run_until_complete(
206                 schedule_formatting(
207                     sources, line_length, fast, write_back, report, loop, executor
208                 )
209             )
210         finally:
211             shutdown(loop)
212         if not quiet:
213             out("All done! ✨ 🍰 ✨")
214             click.echo(str(report))
215     ctx.exit(report.return_code)
216
217
218 def reformat_one(
219     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
220 ) -> None:
221     """Reformat a single file under `src` without spawning child processes.
222
223     If `quiet` is True, non-error messages are not output. `line_length`,
224     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
225     """
226     try:
227         changed = Changed.NO
228         if not src.is_file() and str(src) == "-":
229             if format_stdin_to_stdout(
230                 line_length=line_length, fast=fast, write_back=write_back
231             ):
232                 changed = Changed.YES
233         else:
234             cache: Cache = {}
235             if write_back != WriteBack.DIFF:
236                 cache = read_cache()
237                 src = src.resolve()
238                 if src in cache and cache[src] == get_cache_info(src):
239                     changed = Changed.CACHED
240             if (
241                 changed is not Changed.CACHED
242                 and format_file_in_place(
243                     src, line_length=line_length, fast=fast, write_back=write_back
244                 )
245             ):
246                 changed = Changed.YES
247             if write_back != WriteBack.DIFF and changed is not Changed.NO:
248                 write_cache(cache, [src])
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache()
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             src: loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             )
289             for src in sources
290         }
291         _task_values = list(tasks.values())
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
294             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     if write_back != WriteBack.DIFF and formatted:
315         write_cache(cache, formatted)
316
317
318 def format_file_in_place(
319     src: Path,
320     line_length: int,
321     fast: bool,
322     write_back: WriteBack = WriteBack.NO,
323     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
324 ) -> bool:
325     """Format file under `src` path. Return True if changed.
326
327     If `write_back` is True, write reformatted code back to stdout.
328     `line_length` and `fast` options are passed to :func:`format_file_contents`.
329     """
330
331     with tokenize.open(src) as src_buffer:
332         src_contents = src_buffer.read()
333     try:
334         dst_contents = format_file_contents(
335             src_contents, line_length=line_length, fast=fast
336         )
337     except NothingChanged:
338         return False
339
340     if write_back == write_back.YES:
341         with open(src, "w", encoding=src_buffer.encoding) as f:
342             f.write(dst_contents)
343     elif write_back == write_back.DIFF:
344         src_name = f"{src}  (original)"
345         dst_name = f"{src}  (formatted)"
346         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347         if lock:
348             lock.acquire()
349         try:
350             sys.stdout.write(diff_contents)
351         finally:
352             if lock:
353                 lock.release()
354     return True
355
356
357 def format_stdin_to_stdout(
358     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
359 ) -> bool:
360     """Format file on stdin. Return True if changed.
361
362     If `write_back` is True, write reformatted code back to stdout.
363     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
364     """
365     src = sys.stdin.read()
366     dst = src
367     try:
368         dst = format_file_contents(src, line_length=line_length, fast=fast)
369         return True
370
371     except NothingChanged:
372         return False
373
374     finally:
375         if write_back == WriteBack.YES:
376             sys.stdout.write(dst)
377         elif write_back == WriteBack.DIFF:
378             src_name = "<stdin>  (original)"
379             dst_name = "<stdin>  (formatted)"
380             sys.stdout.write(diff(src, dst, src_name, dst_name))
381
382
383 def format_file_contents(
384     src_contents: str, line_length: int, fast: bool
385 ) -> FileContent:
386     """Reformat contents a file and return new contents.
387
388     If `fast` is False, additionally confirm that the reformatted code is
389     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
390     `line_length` is passed to :func:`format_str`.
391     """
392     if src_contents.strip() == "":
393         raise NothingChanged
394
395     dst_contents = format_str(src_contents, line_length=line_length)
396     if src_contents == dst_contents:
397         raise NothingChanged
398
399     if not fast:
400         assert_equivalent(src_contents, dst_contents)
401         assert_stable(src_contents, dst_contents, line_length=line_length)
402     return dst_contents
403
404
405 def format_str(src_contents: str, line_length: int) -> FileContent:
406     """Reformat a string and return new contents.
407
408     `line_length` determines how many characters per line are allowed.
409     """
410     src_node = lib2to3_parse(src_contents)
411     dst_contents = ""
412     lines = LineGenerator()
413     elt = EmptyLineTracker()
414     py36 = is_python36(src_node)
415     empty_line = Line()
416     after = 0
417     for current_line in lines.visit(src_node):
418         for _ in range(after):
419             dst_contents += str(empty_line)
420         before, after = elt.maybe_empty_lines(current_line)
421         for _ in range(before):
422             dst_contents += str(empty_line)
423         for line in split_line(current_line, line_length=line_length, py36=py36):
424             dst_contents += str(line)
425     return dst_contents
426
427
428 GRAMMARS = [
429     pygram.python_grammar_no_print_statement_no_exec_statement,
430     pygram.python_grammar_no_print_statement,
431     pygram.python_grammar,
432 ]
433
434
435 def lib2to3_parse(src_txt: str) -> Node:
436     """Given a string with source, return the lib2to3 Node."""
437     grammar = pygram.python_grammar_no_print_statement
438     if src_txt[-1] != "\n":
439         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
440         src_txt += nl
441     for grammar in GRAMMARS:
442         drv = driver.Driver(grammar, pytree.convert)
443         try:
444             result = drv.parse_string(src_txt, True)
445             break
446
447         except ParseError as pe:
448             lineno, column = pe.context[1]
449             lines = src_txt.splitlines()
450             try:
451                 faulty_line = lines[lineno - 1]
452             except IndexError:
453                 faulty_line = "<line number missing in source>"
454             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
455     else:
456         raise exc from None
457
458     if isinstance(result, Leaf):
459         result = Node(syms.file_input, [result])
460     return result
461
462
463 def lib2to3_unparse(node: Node) -> str:
464     """Given a lib2to3 node, return its string representation."""
465     code = str(node)
466     return code
467
468
469 T = TypeVar("T")
470
471
472 class Visitor(Generic[T]):
473     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
474
475     def visit(self, node: LN) -> Iterator[T]:
476         """Main method to visit `node` and its children.
477
478         It tries to find a `visit_*()` method for the given `node.type`, like
479         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
480         If no dedicated `visit_*()` method is found, chooses `visit_default()`
481         instead.
482
483         Then yields objects of type `T` from the selected visitor.
484         """
485         if node.type < 256:
486             name = token.tok_name[node.type]
487         else:
488             name = type_repr(node.type)
489         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
490
491     def visit_default(self, node: LN) -> Iterator[T]:
492         """Default `visit_*()` implementation. Recurses to children of `node`."""
493         if isinstance(node, Node):
494             for child in node.children:
495                 yield from self.visit(child)
496
497
498 @dataclass
499 class DebugVisitor(Visitor[T]):
500     tree_depth: int = 0
501
502     def visit_default(self, node: LN) -> Iterator[T]:
503         indent = " " * (2 * self.tree_depth)
504         if isinstance(node, Node):
505             _type = type_repr(node.type)
506             out(f"{indent}{_type}", fg="yellow")
507             self.tree_depth += 1
508             for child in node.children:
509                 yield from self.visit(child)
510
511             self.tree_depth -= 1
512             out(f"{indent}/{_type}", fg="yellow", bold=False)
513         else:
514             _type = token.tok_name.get(node.type, str(node.type))
515             out(f"{indent}{_type}", fg="blue", nl=False)
516             if node.prefix:
517                 # We don't have to handle prefixes for `Node` objects since
518                 # that delegates to the first child anyway.
519                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
520             out(f" {node.value!r}", fg="blue", bold=False)
521
522     @classmethod
523     def show(cls, code: str) -> None:
524         """Pretty-print the lib2to3 AST of a given string of `code`.
525
526         Convenience method for debugging.
527         """
528         v: DebugVisitor[None] = DebugVisitor()
529         list(v.visit(lib2to3_parse(code)))
530
531
532 KEYWORDS = set(keyword.kwlist)
533 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
534 FLOW_CONTROL = {"return", "raise", "break", "continue"}
535 STATEMENT = {
536     syms.if_stmt,
537     syms.while_stmt,
538     syms.for_stmt,
539     syms.try_stmt,
540     syms.except_clause,
541     syms.with_stmt,
542     syms.funcdef,
543     syms.classdef,
544 }
545 STANDALONE_COMMENT = 153
546 LOGIC_OPERATORS = {"and", "or"}
547 COMPARATORS = {
548     token.LESS,
549     token.GREATER,
550     token.EQEQUAL,
551     token.NOTEQUAL,
552     token.LESSEQUAL,
553     token.GREATEREQUAL,
554 }
555 MATH_OPERATORS = {
556     token.PLUS,
557     token.MINUS,
558     token.STAR,
559     token.SLASH,
560     token.VBAR,
561     token.AMPER,
562     token.PERCENT,
563     token.CIRCUMFLEX,
564     token.TILDE,
565     token.LEFTSHIFT,
566     token.RIGHTSHIFT,
567     token.DOUBLESTAR,
568     token.DOUBLESLASH,
569 }
570 STARS = {token.STAR, token.DOUBLESTAR}
571 VARARGS_PARENTS = {
572     syms.arglist,
573     syms.argument,  # double star in arglist
574     syms.trailer,  # single argument to call
575     syms.typedargslist,
576     syms.varargslist,  # lambdas
577 }
578 UNPACKING_PARENTS = {
579     syms.atom,  # single element of a list or set literal
580     syms.dictsetmaker,
581     syms.listmaker,
582     syms.testlist_gexp,
583 }
584 COMPREHENSION_PRIORITY = 20
585 COMMA_PRIORITY = 10
586 LOGIC_PRIORITY = 5
587 STRING_PRIORITY = 4
588 COMPARATOR_PRIORITY = 3
589 MATH_PRIORITY = 1
590
591
592 @dataclass
593 class BracketTracker:
594     """Keeps track of brackets on a line."""
595
596     depth: int = 0
597     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
598     delimiters: Dict[LeafID, Priority] = Factory(dict)
599     previous: Optional[Leaf] = None
600     _for_loop_variable: bool = False
601     _lambda_arguments: bool = False
602
603     def mark(self, leaf: Leaf) -> None:
604         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
605
606         All leaves receive an int `bracket_depth` field that stores how deep
607         within brackets a given leaf is. 0 means there are no enclosing brackets
608         that started on this line.
609
610         If a leaf is itself a closing bracket, it receives an `opening_bracket`
611         field that it forms a pair with. This is a one-directional link to
612         avoid reference cycles.
613
614         If a leaf is a delimiter (a token on which Black can split the line if
615         needed) and it's on depth 0, its `id()` is stored in the tracker's
616         `delimiters` field.
617         """
618         if leaf.type == token.COMMENT:
619             return
620
621         self.maybe_decrement_after_for_loop_variable(leaf)
622         self.maybe_decrement_after_lambda_arguments(leaf)
623         if leaf.type in CLOSING_BRACKETS:
624             self.depth -= 1
625             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
626             leaf.opening_bracket = opening_bracket
627         leaf.bracket_depth = self.depth
628         if self.depth == 0:
629             delim = is_split_before_delimiter(leaf, self.previous)
630             if delim and self.previous is not None:
631                 self.delimiters[id(self.previous)] = delim
632             else:
633                 delim = is_split_after_delimiter(leaf, self.previous)
634                 if delim:
635                     self.delimiters[id(leaf)] = delim
636         if leaf.type in OPENING_BRACKETS:
637             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
638             self.depth += 1
639         self.previous = leaf
640         self.maybe_increment_lambda_arguments(leaf)
641         self.maybe_increment_for_loop_variable(leaf)
642
643     def any_open_brackets(self) -> bool:
644         """Return True if there is an yet unmatched open bracket on the line."""
645         return bool(self.bracket_match)
646
647     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
648         """Return the highest priority of a delimiter found on the line.
649
650         Values are consistent with what `is_split_*_delimiter()` return.
651         Raises ValueError on no delimiters.
652         """
653         return max(v for k, v in self.delimiters.items() if k not in exclude)
654
655     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
656         """In a for loop, or comprehension, the variables are often unpacks.
657
658         To avoid splitting on the comma in this situation, increase the depth of
659         tokens between `for` and `in`.
660         """
661         if leaf.type == token.NAME and leaf.value == "for":
662             self.depth += 1
663             self._for_loop_variable = True
664             return True
665
666         return False
667
668     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
669         """See `maybe_increment_for_loop_variable` above for explanation."""
670         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
671             self.depth -= 1
672             self._for_loop_variable = False
673             return True
674
675         return False
676
677     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
678         """In a lambda expression, there might be more than one argument.
679
680         To avoid splitting on the comma in this situation, increase the depth of
681         tokens between `lambda` and `:`.
682         """
683         if leaf.type == token.NAME and leaf.value == "lambda":
684             self.depth += 1
685             self._lambda_arguments = True
686             return True
687
688         return False
689
690     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
691         """See `maybe_increment_lambda_arguments` above for explanation."""
692         if self._lambda_arguments and leaf.type == token.COLON:
693             self.depth -= 1
694             self._lambda_arguments = False
695             return True
696
697         return False
698
699
700 @dataclass
701 class Line:
702     """Holds leaves and comments. Can be printed with `str(line)`."""
703
704     depth: int = 0
705     leaves: List[Leaf] = Factory(list)
706     comments: List[Tuple[Index, Leaf]] = Factory(list)
707     bracket_tracker: BracketTracker = Factory(BracketTracker)
708     inside_brackets: bool = False
709
710     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
711         """Add a new `leaf` to the end of the line.
712
713         Unless `preformatted` is True, the `leaf` will receive a new consistent
714         whitespace prefix and metadata applied by :class:`BracketTracker`.
715         Trailing commas are maybe removed, unpacked for loop variables are
716         demoted from being delimiters.
717
718         Inline comments are put aside.
719         """
720         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
721         if not has_value:
722             return
723
724         if self.leaves and not preformatted:
725             # Note: at this point leaf.prefix should be empty except for
726             # imports, for which we only preserve newlines.
727             leaf.prefix += whitespace(leaf)
728         if self.inside_brackets or not preformatted:
729             self.bracket_tracker.mark(leaf)
730             self.maybe_remove_trailing_comma(leaf)
731
732         if not self.append_comment(leaf):
733             self.leaves.append(leaf)
734
735     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
736         """Like :func:`append()` but disallow invalid standalone comment structure.
737
738         Raises ValueError when any `leaf` is appended after a standalone comment
739         or when a standalone comment is not the first leaf on the line.
740         """
741         if self.bracket_tracker.depth == 0:
742             if self.is_comment:
743                 raise ValueError("cannot append to standalone comments")
744
745             if self.leaves and leaf.type == STANDALONE_COMMENT:
746                 raise ValueError(
747                     "cannot append standalone comments to a populated line"
748                 )
749
750         self.append(leaf, preformatted=preformatted)
751
752     @property
753     def is_comment(self) -> bool:
754         """Is this line a standalone comment?"""
755         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
756
757     @property
758     def is_decorator(self) -> bool:
759         """Is this line a decorator?"""
760         return bool(self) and self.leaves[0].type == token.AT
761
762     @property
763     def is_import(self) -> bool:
764         """Is this an import line?"""
765         return bool(self) and is_import(self.leaves[0])
766
767     @property
768     def is_class(self) -> bool:
769         """Is this line a class definition?"""
770         return (
771             bool(self)
772             and self.leaves[0].type == token.NAME
773             and self.leaves[0].value == "class"
774         )
775
776     @property
777     def is_def(self) -> bool:
778         """Is this a function definition? (Also returns True for async defs.)"""
779         try:
780             first_leaf = self.leaves[0]
781         except IndexError:
782             return False
783
784         try:
785             second_leaf: Optional[Leaf] = self.leaves[1]
786         except IndexError:
787             second_leaf = None
788         return (
789             (first_leaf.type == token.NAME and first_leaf.value == "def")
790             or (
791                 first_leaf.type == token.ASYNC
792                 and second_leaf is not None
793                 and second_leaf.type == token.NAME
794                 and second_leaf.value == "def"
795             )
796         )
797
798     @property
799     def is_flow_control(self) -> bool:
800         """Is this line a flow control statement?
801
802         Those are `return`, `raise`, `break`, and `continue`.
803         """
804         return (
805             bool(self)
806             and self.leaves[0].type == token.NAME
807             and self.leaves[0].value in FLOW_CONTROL
808         )
809
810     @property
811     def is_yield(self) -> bool:
812         """Is this line a yield statement?"""
813         return (
814             bool(self)
815             and self.leaves[0].type == token.NAME
816             and self.leaves[0].value == "yield"
817         )
818
819     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
820         """If so, needs to be split before emitting."""
821         for leaf in self.leaves:
822             if leaf.type == STANDALONE_COMMENT:
823                 if leaf.bracket_depth <= depth_limit:
824                     return True
825
826         return False
827
828     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
829         """Remove trailing comma if there is one and it's safe."""
830         if not (
831             self.leaves
832             and self.leaves[-1].type == token.COMMA
833             and closing.type in CLOSING_BRACKETS
834         ):
835             return False
836
837         if closing.type == token.RBRACE:
838             self.remove_trailing_comma()
839             return True
840
841         if closing.type == token.RSQB:
842             comma = self.leaves[-1]
843             if comma.parent and comma.parent.type == syms.listmaker:
844                 self.remove_trailing_comma()
845                 return True
846
847         # For parens let's check if it's safe to remove the comma.  If the
848         # trailing one is the only one, we might mistakenly change a tuple
849         # into a different type by removing the comma.
850         depth = closing.bracket_depth + 1
851         commas = 0
852         opening = closing.opening_bracket
853         for _opening_index, leaf in enumerate(self.leaves):
854             if leaf is opening:
855                 break
856
857         else:
858             return False
859
860         for leaf in self.leaves[_opening_index + 1:]:
861             if leaf is closing:
862                 break
863
864             bracket_depth = leaf.bracket_depth
865             if bracket_depth == depth and leaf.type == token.COMMA:
866                 commas += 1
867                 if leaf.parent and leaf.parent.type == syms.arglist:
868                     commas += 1
869                     break
870
871         if commas > 1:
872             self.remove_trailing_comma()
873             return True
874
875         return False
876
877     def append_comment(self, comment: Leaf) -> bool:
878         """Add an inline or standalone comment to the line."""
879         if (
880             comment.type == STANDALONE_COMMENT
881             and self.bracket_tracker.any_open_brackets()
882         ):
883             comment.prefix = ""
884             return False
885
886         if comment.type != token.COMMENT:
887             return False
888
889         after = len(self.leaves) - 1
890         if after == -1:
891             comment.type = STANDALONE_COMMENT
892             comment.prefix = ""
893             return False
894
895         else:
896             self.comments.append((after, comment))
897             return True
898
899     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
900         """Generate comments that should appear directly after `leaf`."""
901         for _leaf_index, _leaf in enumerate(self.leaves):
902             if leaf is _leaf:
903                 break
904
905         else:
906             return
907
908         for index, comment_after in self.comments:
909             if _leaf_index == index:
910                 yield comment_after
911
912     def remove_trailing_comma(self) -> None:
913         """Remove the trailing comma and moves the comments attached to it."""
914         comma_index = len(self.leaves) - 1
915         for i in range(len(self.comments)):
916             comment_index, comment = self.comments[i]
917             if comment_index == comma_index:
918                 self.comments[i] = (comma_index - 1, comment)
919         self.leaves.pop()
920
921     def __str__(self) -> str:
922         """Render the line."""
923         if not self:
924             return "\n"
925
926         indent = "    " * self.depth
927         leaves = iter(self.leaves)
928         first = next(leaves)
929         res = f"{first.prefix}{indent}{first.value}"
930         for leaf in leaves:
931             res += str(leaf)
932         for _, comment in self.comments:
933             res += str(comment)
934         return res + "\n"
935
936     def __bool__(self) -> bool:
937         """Return True if the line has leaves or comments."""
938         return bool(self.leaves or self.comments)
939
940
941 class UnformattedLines(Line):
942     """Just like :class:`Line` but stores lines which aren't reformatted."""
943
944     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
945         """Just add a new `leaf` to the end of the lines.
946
947         The `preformatted` argument is ignored.
948
949         Keeps track of indentation `depth`, which is useful when the user
950         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
951         """
952         try:
953             list(generate_comments(leaf))
954         except FormatOn as f_on:
955             self.leaves.append(f_on.leaf_from_consumed(leaf))
956             raise
957
958         self.leaves.append(leaf)
959         if leaf.type == token.INDENT:
960             self.depth += 1
961         elif leaf.type == token.DEDENT:
962             self.depth -= 1
963
964     def __str__(self) -> str:
965         """Render unformatted lines from leaves which were added with `append()`.
966
967         `depth` is not used for indentation in this case.
968         """
969         if not self:
970             return "\n"
971
972         res = ""
973         for leaf in self.leaves:
974             res += str(leaf)
975         return res
976
977     def append_comment(self, comment: Leaf) -> bool:
978         """Not implemented in this class. Raises `NotImplementedError`."""
979         raise NotImplementedError("Unformatted lines don't store comments separately.")
980
981     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
982         """Does nothing and returns False."""
983         return False
984
985     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
986         """Does nothing and returns False."""
987         return False
988
989
990 @dataclass
991 class EmptyLineTracker:
992     """Provides a stateful method that returns the number of potential extra
993     empty lines needed before and after the currently processed line.
994
995     Note: this tracker works on lines that haven't been split yet.  It assumes
996     the prefix of the first leaf consists of optional newlines.  Those newlines
997     are consumed by `maybe_empty_lines()` and included in the computation.
998     """
999     previous_line: Optional[Line] = None
1000     previous_after: int = 0
1001     previous_defs: List[int] = Factory(list)
1002
1003     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1004         """Return the number of extra empty lines before and after the `current_line`.
1005
1006         This is for separating `def`, `async def` and `class` with extra empty
1007         lines (two on module-level), as well as providing an extra empty line
1008         after flow control keywords to make them more prominent.
1009         """
1010         if isinstance(current_line, UnformattedLines):
1011             return 0, 0
1012
1013         before, after = self._maybe_empty_lines(current_line)
1014         before -= self.previous_after
1015         self.previous_after = after
1016         self.previous_line = current_line
1017         return before, after
1018
1019     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1020         max_allowed = 1
1021         if current_line.depth == 0:
1022             max_allowed = 2
1023         if current_line.leaves:
1024             # Consume the first leaf's extra newlines.
1025             first_leaf = current_line.leaves[0]
1026             before = first_leaf.prefix.count("\n")
1027             before = min(before, max_allowed)
1028             first_leaf.prefix = ""
1029         else:
1030             before = 0
1031         depth = current_line.depth
1032         while self.previous_defs and self.previous_defs[-1] >= depth:
1033             self.previous_defs.pop()
1034             before = 1 if depth else 2
1035         is_decorator = current_line.is_decorator
1036         if is_decorator or current_line.is_def or current_line.is_class:
1037             if not is_decorator:
1038                 self.previous_defs.append(depth)
1039             if self.previous_line is None:
1040                 # Don't insert empty lines before the first line in the file.
1041                 return 0, 0
1042
1043             if self.previous_line and self.previous_line.is_decorator:
1044                 # Don't insert empty lines between decorators.
1045                 return 0, 0
1046
1047             newlines = 2
1048             if current_line.depth:
1049                 newlines -= 1
1050             return newlines, 0
1051
1052         if current_line.is_flow_control:
1053             return before, 1
1054
1055         if (
1056             self.previous_line
1057             and self.previous_line.is_import
1058             and not current_line.is_import
1059             and depth == self.previous_line.depth
1060         ):
1061             return (before or 1), 0
1062
1063         if (
1064             self.previous_line
1065             and self.previous_line.is_yield
1066             and (not current_line.is_yield or depth != self.previous_line.depth)
1067         ):
1068             return (before or 1), 0
1069
1070         return before, 0
1071
1072
1073 @dataclass
1074 class LineGenerator(Visitor[Line]):
1075     """Generates reformatted Line objects.  Empty lines are not emitted.
1076
1077     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1078     in ways that will no longer stringify to valid Python code on the tree.
1079     """
1080     current_line: Line = Factory(Line)
1081
1082     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1083         """Generate a line.
1084
1085         If the line is empty, only emit if it makes sense.
1086         If the line is too long, split it first and then generate.
1087
1088         If any lines were generated, set up a new current_line.
1089         """
1090         if not self.current_line:
1091             if self.current_line.__class__ == type:
1092                 self.current_line.depth += indent
1093             else:
1094                 self.current_line = type(depth=self.current_line.depth + indent)
1095             return  # Line is empty, don't emit. Creating a new one unnecessary.
1096
1097         complete_line = self.current_line
1098         self.current_line = type(depth=complete_line.depth + indent)
1099         yield complete_line
1100
1101     def visit(self, node: LN) -> Iterator[Line]:
1102         """Main method to visit `node` and its children.
1103
1104         Yields :class:`Line` objects.
1105         """
1106         if isinstance(self.current_line, UnformattedLines):
1107             # File contained `# fmt: off`
1108             yield from self.visit_unformatted(node)
1109
1110         else:
1111             yield from super().visit(node)
1112
1113     def visit_default(self, node: LN) -> Iterator[Line]:
1114         """Default `visit_*()` implementation. Recurses to children of `node`."""
1115         if isinstance(node, Leaf):
1116             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1117             try:
1118                 for comment in generate_comments(node):
1119                     if any_open_brackets:
1120                         # any comment within brackets is subject to splitting
1121                         self.current_line.append(comment)
1122                     elif comment.type == token.COMMENT:
1123                         # regular trailing comment
1124                         self.current_line.append(comment)
1125                         yield from self.line()
1126
1127                     else:
1128                         # regular standalone comment
1129                         yield from self.line()
1130
1131                         self.current_line.append(comment)
1132                         yield from self.line()
1133
1134             except FormatOff as f_off:
1135                 f_off.trim_prefix(node)
1136                 yield from self.line(type=UnformattedLines)
1137                 yield from self.visit(node)
1138
1139             except FormatOn as f_on:
1140                 # This only happens here if somebody says "fmt: on" multiple
1141                 # times in a row.
1142                 f_on.trim_prefix(node)
1143                 yield from self.visit_default(node)
1144
1145             else:
1146                 normalize_prefix(node, inside_brackets=any_open_brackets)
1147                 if node.type == token.STRING:
1148                     normalize_string_quotes(node)
1149                 if node.type not in WHITESPACE:
1150                     self.current_line.append(node)
1151         yield from super().visit_default(node)
1152
1153     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1154         """Increase indentation level, maybe yield a line."""
1155         # In blib2to3 INDENT never holds comments.
1156         yield from self.line(+1)
1157         yield from self.visit_default(node)
1158
1159     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1160         """Decrease indentation level, maybe yield a line."""
1161         # The current line might still wait for trailing comments.  At DEDENT time
1162         # there won't be any (they would be prefixes on the preceding NEWLINE).
1163         # Emit the line then.
1164         yield from self.line()
1165
1166         # While DEDENT has no value, its prefix may contain standalone comments
1167         # that belong to the current indentation level.  Get 'em.
1168         yield from self.visit_default(node)
1169
1170         # Finally, emit the dedent.
1171         yield from self.line(-1)
1172
1173     def visit_stmt(
1174         self, node: Node, keywords: Set[str], parens: Set[str]
1175     ) -> Iterator[Line]:
1176         """Visit a statement.
1177
1178         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1179         `def`, `with`, `class`, and `assert`.
1180
1181         The relevant Python language `keywords` for a given statement will be
1182         NAME leaves within it. This methods puts those on a separate line.
1183
1184         `parens` holds pairs of nodes where invisible parentheses should be put.
1185         Keys hold nodes after which opening parentheses should be put, values
1186         hold nodes before which closing parentheses should be put.
1187         """
1188         normalize_invisible_parens(node, parens_after=parens)
1189         for child in node.children:
1190             if child.type == token.NAME and child.value in keywords:  # type: ignore
1191                 yield from self.line()
1192
1193             yield from self.visit(child)
1194
1195     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1196         """Visit a statement without nested statements."""
1197         is_suite_like = node.parent and node.parent.type in STATEMENT
1198         if is_suite_like:
1199             yield from self.line(+1)
1200             yield from self.visit_default(node)
1201             yield from self.line(-1)
1202
1203         else:
1204             yield from self.line()
1205             yield from self.visit_default(node)
1206
1207     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1208         """Visit `async def`, `async for`, `async with`."""
1209         yield from self.line()
1210
1211         children = iter(node.children)
1212         for child in children:
1213             yield from self.visit(child)
1214
1215             if child.type == token.ASYNC:
1216                 break
1217
1218         internal_stmt = next(children)
1219         for child in internal_stmt.children:
1220             yield from self.visit(child)
1221
1222     def visit_decorators(self, node: Node) -> Iterator[Line]:
1223         """Visit decorators."""
1224         for child in node.children:
1225             yield from self.line()
1226             yield from self.visit(child)
1227
1228     def visit_import_from(self, node: Node) -> Iterator[Line]:
1229         """Visit import_from and maybe put invisible parentheses.
1230
1231         This is separate from `visit_stmt` because import statements don't
1232         support arbitrary atoms and thus handling of parentheses is custom.
1233         """
1234         check_lpar = False
1235         for index, child in enumerate(node.children):
1236             if check_lpar:
1237                 if child.type == token.LPAR:
1238                     # make parentheses invisible
1239                     child.value = ""  # type: ignore
1240                     node.children[-1].value = ""  # type: ignore
1241                 else:
1242                     # insert invisible parentheses
1243                     node.insert_child(index, Leaf(token.LPAR, ""))
1244                     node.append_child(Leaf(token.RPAR, ""))
1245                 break
1246
1247             check_lpar = (
1248                 child.type == token.NAME and child.value == "import"  # type: ignore
1249             )
1250
1251         for child in node.children:
1252             yield from self.visit(child)
1253
1254     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1255         """Remove a semicolon and put the other statement on a separate line."""
1256         yield from self.line()
1257
1258     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1259         """End of file. Process outstanding comments and end with a newline."""
1260         yield from self.visit_default(leaf)
1261         yield from self.line()
1262
1263     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1264         """Used when file contained a `# fmt: off`."""
1265         if isinstance(node, Node):
1266             for child in node.children:
1267                 yield from self.visit(child)
1268
1269         else:
1270             try:
1271                 self.current_line.append(node)
1272             except FormatOn as f_on:
1273                 f_on.trim_prefix(node)
1274                 yield from self.line()
1275                 yield from self.visit(node)
1276
1277             if node.type == token.ENDMARKER:
1278                 # somebody decided not to put a final `# fmt: on`
1279                 yield from self.line()
1280
1281     def __attrs_post_init__(self) -> None:
1282         """You are in a twisty little maze of passages."""
1283         v = self.visit_stmt
1284         Ø: Set[str] = set()
1285         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1286         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1287         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1288         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1289         self.visit_try_stmt = partial(
1290             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1291         )
1292         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1293         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1294         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1295         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1296         self.visit_async_funcdef = self.visit_async_stmt
1297         self.visit_decorated = self.visit_decorators
1298
1299
1300 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1301 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1302 OPENING_BRACKETS = set(BRACKET.keys())
1303 CLOSING_BRACKETS = set(BRACKET.values())
1304 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1305 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1306
1307
1308 def whitespace(leaf: Leaf) -> str:  # noqa C901
1309     """Return whitespace prefix if needed for the given `leaf`."""
1310     NO = ""
1311     SPACE = " "
1312     DOUBLESPACE = "  "
1313     t = leaf.type
1314     p = leaf.parent
1315     v = leaf.value
1316     if t in ALWAYS_NO_SPACE:
1317         return NO
1318
1319     if t == token.COMMENT:
1320         return DOUBLESPACE
1321
1322     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1323     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1324         return NO
1325
1326     prev = leaf.prev_sibling
1327     if not prev:
1328         prevp = preceding_leaf(p)
1329         if not prevp or prevp.type in OPENING_BRACKETS:
1330             return NO
1331
1332         if t == token.COLON:
1333             return SPACE if prevp.type == token.COMMA else NO
1334
1335         if prevp.type == token.EQUAL:
1336             if prevp.parent:
1337                 if prevp.parent.type in {
1338                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1339                 }:
1340                     return NO
1341
1342                 elif prevp.parent.type == syms.typedargslist:
1343                     # A bit hacky: if the equal sign has whitespace, it means we
1344                     # previously found it's a typed argument.  So, we're using
1345                     # that, too.
1346                     return prevp.prefix
1347
1348         elif prevp.type in STARS:
1349             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1350                 return NO
1351
1352         elif prevp.type == token.COLON:
1353             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1354                 return NO
1355
1356         elif (
1357             prevp.parent
1358             and prevp.parent.type == syms.factor
1359             and prevp.type in MATH_OPERATORS
1360         ):
1361             return NO
1362
1363         elif (
1364             prevp.type == token.RIGHTSHIFT
1365             and prevp.parent
1366             and prevp.parent.type == syms.shift_expr
1367             and prevp.prev_sibling
1368             and prevp.prev_sibling.type == token.NAME
1369             and prevp.prev_sibling.value == "print"  # type: ignore
1370         ):
1371             # Python 2 print chevron
1372             return NO
1373
1374     elif prev.type in OPENING_BRACKETS:
1375         return NO
1376
1377     if p.type in {syms.parameters, syms.arglist}:
1378         # untyped function signatures or calls
1379         if not prev or prev.type != token.COMMA:
1380             return NO
1381
1382     elif p.type == syms.varargslist:
1383         # lambdas
1384         if prev and prev.type != token.COMMA:
1385             return NO
1386
1387     elif p.type == syms.typedargslist:
1388         # typed function signatures
1389         if not prev:
1390             return NO
1391
1392         if t == token.EQUAL:
1393             if prev.type != syms.tname:
1394                 return NO
1395
1396         elif prev.type == token.EQUAL:
1397             # A bit hacky: if the equal sign has whitespace, it means we
1398             # previously found it's a typed argument.  So, we're using that, too.
1399             return prev.prefix
1400
1401         elif prev.type != token.COMMA:
1402             return NO
1403
1404     elif p.type == syms.tname:
1405         # type names
1406         if not prev:
1407             prevp = preceding_leaf(p)
1408             if not prevp or prevp.type != token.COMMA:
1409                 return NO
1410
1411     elif p.type == syms.trailer:
1412         # attributes and calls
1413         if t == token.LPAR or t == token.RPAR:
1414             return NO
1415
1416         if not prev:
1417             if t == token.DOT:
1418                 prevp = preceding_leaf(p)
1419                 if not prevp or prevp.type != token.NUMBER:
1420                     return NO
1421
1422             elif t == token.LSQB:
1423                 return NO
1424
1425         elif prev.type != token.COMMA:
1426             return NO
1427
1428     elif p.type == syms.argument:
1429         # single argument
1430         if t == token.EQUAL:
1431             return NO
1432
1433         if not prev:
1434             prevp = preceding_leaf(p)
1435             if not prevp or prevp.type == token.LPAR:
1436                 return NO
1437
1438         elif prev.type in {token.EQUAL} | STARS:
1439             return NO
1440
1441     elif p.type == syms.decorator:
1442         # decorators
1443         return NO
1444
1445     elif p.type == syms.dotted_name:
1446         if prev:
1447             return NO
1448
1449         prevp = preceding_leaf(p)
1450         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1451             return NO
1452
1453     elif p.type == syms.classdef:
1454         if t == token.LPAR:
1455             return NO
1456
1457         if prev and prev.type == token.LPAR:
1458             return NO
1459
1460     elif p.type == syms.subscript:
1461         # indexing
1462         if not prev:
1463             assert p.parent is not None, "subscripts are always parented"
1464             if p.parent.type == syms.subscriptlist:
1465                 return SPACE
1466
1467             return NO
1468
1469         else:
1470             return NO
1471
1472     elif p.type == syms.atom:
1473         if prev and t == token.DOT:
1474             # dots, but not the first one.
1475             return NO
1476
1477     elif p.type == syms.dictsetmaker:
1478         # dict unpacking
1479         if prev and prev.type == token.DOUBLESTAR:
1480             return NO
1481
1482     elif p.type in {syms.factor, syms.star_expr}:
1483         # unary ops
1484         if not prev:
1485             prevp = preceding_leaf(p)
1486             if not prevp or prevp.type in OPENING_BRACKETS:
1487                 return NO
1488
1489             prevp_parent = prevp.parent
1490             assert prevp_parent is not None
1491             if (
1492                 prevp.type == token.COLON
1493                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1494             ):
1495                 return NO
1496
1497             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1498                 return NO
1499
1500         elif t == token.NAME or t == token.NUMBER:
1501             return NO
1502
1503     elif p.type == syms.import_from:
1504         if t == token.DOT:
1505             if prev and prev.type == token.DOT:
1506                 return NO
1507
1508         elif t == token.NAME:
1509             if v == "import":
1510                 return SPACE
1511
1512             if prev and prev.type == token.DOT:
1513                 return NO
1514
1515     elif p.type == syms.sliceop:
1516         return NO
1517
1518     return SPACE
1519
1520
1521 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1522     """Return the first leaf that precedes `node`, if any."""
1523     while node:
1524         res = node.prev_sibling
1525         if res:
1526             if isinstance(res, Leaf):
1527                 return res
1528
1529             try:
1530                 return list(res.leaves())[-1]
1531
1532             except IndexError:
1533                 return None
1534
1535         node = node.parent
1536     return None
1537
1538
1539 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1540     """Return the priority of the `leaf` delimiter, given a line break after it.
1541
1542     The delimiter priorities returned here are from those delimiters that would
1543     cause a line break after themselves.
1544
1545     Higher numbers are higher priority.
1546     """
1547     if leaf.type == token.COMMA:
1548         return COMMA_PRIORITY
1549
1550     return 0
1551
1552
1553 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1554     """Return the priority of the `leaf` delimiter, given a line before after it.
1555
1556     The delimiter priorities returned here are from those delimiters that would
1557     cause a line break before themselves.
1558
1559     Higher numbers are higher priority.
1560     """
1561     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1562         # * and ** might also be MATH_OPERATORS but in this case they are not.
1563         # Don't treat them as a delimiter.
1564         return 0
1565
1566     if (
1567         leaf.type in MATH_OPERATORS
1568         and leaf.parent
1569         and leaf.parent.type not in {syms.factor, syms.star_expr}
1570     ):
1571         return MATH_PRIORITY
1572
1573     if leaf.type in COMPARATORS:
1574         return COMPARATOR_PRIORITY
1575
1576     if (
1577         leaf.type == token.STRING
1578         and previous is not None
1579         and previous.type == token.STRING
1580     ):
1581         return STRING_PRIORITY
1582
1583     if (
1584         leaf.type == token.NAME
1585         and leaf.value == "for"
1586         and leaf.parent
1587         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1588     ):
1589         return COMPREHENSION_PRIORITY
1590
1591     if (
1592         leaf.type == token.NAME
1593         and leaf.value == "if"
1594         and leaf.parent
1595         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1596     ):
1597         return COMPREHENSION_PRIORITY
1598
1599     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1600         return LOGIC_PRIORITY
1601
1602     return 0
1603
1604
1605 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1606     """Clean the prefix of the `leaf` and generate comments from it, if any.
1607
1608     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1609     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1610     move because it does away with modifying the grammar to include all the
1611     possible places in which comments can be placed.
1612
1613     The sad consequence for us though is that comments don't "belong" anywhere.
1614     This is why this function generates simple parentless Leaf objects for
1615     comments.  We simply don't know what the correct parent should be.
1616
1617     No matter though, we can live without this.  We really only need to
1618     differentiate between inline and standalone comments.  The latter don't
1619     share the line with any code.
1620
1621     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1622     are emitted with a fake STANDALONE_COMMENT token identifier.
1623     """
1624     p = leaf.prefix
1625     if not p:
1626         return
1627
1628     if "#" not in p:
1629         return
1630
1631     consumed = 0
1632     nlines = 0
1633     for index, line in enumerate(p.split("\n")):
1634         consumed += len(line) + 1  # adding the length of the split '\n'
1635         line = line.lstrip()
1636         if not line:
1637             nlines += 1
1638         if not line.startswith("#"):
1639             continue
1640
1641         if index == 0 and leaf.type != token.ENDMARKER:
1642             comment_type = token.COMMENT  # simple trailing comment
1643         else:
1644             comment_type = STANDALONE_COMMENT
1645         comment = make_comment(line)
1646         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1647
1648         if comment in {"# fmt: on", "# yapf: enable"}:
1649             raise FormatOn(consumed)
1650
1651         if comment in {"# fmt: off", "# yapf: disable"}:
1652             if comment_type == STANDALONE_COMMENT:
1653                 raise FormatOff(consumed)
1654
1655             prev = preceding_leaf(leaf)
1656             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1657                 raise FormatOff(consumed)
1658
1659         nlines = 0
1660
1661
1662 def make_comment(content: str) -> str:
1663     """Return a consistently formatted comment from the given `content` string.
1664
1665     All comments (except for "##", "#!", "#:") should have a single space between
1666     the hash sign and the content.
1667
1668     If `content` didn't start with a hash sign, one is provided.
1669     """
1670     content = content.rstrip()
1671     if not content:
1672         return "#"
1673
1674     if content[0] == "#":
1675         content = content[1:]
1676     if content and content[0] not in " !:#":
1677         content = " " + content
1678     return "#" + content
1679
1680
1681 def split_line(
1682     line: Line, line_length: int, inner: bool = False, py36: bool = False
1683 ) -> Iterator[Line]:
1684     """Split a `line` into potentially many lines.
1685
1686     They should fit in the allotted `line_length` but might not be able to.
1687     `inner` signifies that there were a pair of brackets somewhere around the
1688     current `line`, possibly transitively. This means we can fallback to splitting
1689     by delimiters if the LHS/RHS don't yield any results.
1690
1691     If `py36` is True, splitting may generate syntax that is only compatible
1692     with Python 3.6 and later.
1693     """
1694     if isinstance(line, UnformattedLines) or line.is_comment:
1695         yield line
1696         return
1697
1698     line_str = str(line).strip("\n")
1699     if (
1700         len(line_str) <= line_length
1701         and "\n" not in line_str  # multiline strings
1702         and not line.contains_standalone_comments()
1703     ):
1704         yield line
1705         return
1706
1707     split_funcs: List[SplitFunc]
1708     if line.is_def:
1709         split_funcs = [left_hand_split]
1710     elif line.inside_brackets:
1711         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1712     else:
1713         split_funcs = [right_hand_split]
1714     for split_func in split_funcs:
1715         # We are accumulating lines in `result` because we might want to abort
1716         # mission and return the original line in the end, or attempt a different
1717         # split altogether.
1718         result: List[Line] = []
1719         try:
1720             for l in split_func(line, py36):
1721                 if str(l).strip("\n") == line_str:
1722                     raise CannotSplit("Split function returned an unchanged result")
1723
1724                 result.extend(
1725                     split_line(l, line_length=line_length, inner=True, py36=py36)
1726                 )
1727         except CannotSplit as cs:
1728             continue
1729
1730         else:
1731             yield from result
1732             break
1733
1734     else:
1735         yield line
1736
1737
1738 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1739     """Split line into many lines, starting with the first matching bracket pair.
1740
1741     Note: this usually looks weird, only use this for function definitions.
1742     Prefer RHS otherwise.
1743     """
1744     head = Line(depth=line.depth)
1745     body = Line(depth=line.depth + 1, inside_brackets=True)
1746     tail = Line(depth=line.depth)
1747     tail_leaves: List[Leaf] = []
1748     body_leaves: List[Leaf] = []
1749     head_leaves: List[Leaf] = []
1750     current_leaves = head_leaves
1751     matching_bracket = None
1752     for leaf in line.leaves:
1753         if (
1754             current_leaves is body_leaves
1755             and leaf.type in CLOSING_BRACKETS
1756             and leaf.opening_bracket is matching_bracket
1757         ):
1758             current_leaves = tail_leaves if body_leaves else head_leaves
1759         current_leaves.append(leaf)
1760         if current_leaves is head_leaves:
1761             if leaf.type in OPENING_BRACKETS:
1762                 matching_bracket = leaf
1763                 current_leaves = body_leaves
1764     # Since body is a new indent level, remove spurious leading whitespace.
1765     if body_leaves:
1766         normalize_prefix(body_leaves[0], inside_brackets=True)
1767     # Build the new lines.
1768     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1769         for leaf in leaves:
1770             result.append(leaf, preformatted=True)
1771             for comment_after in line.comments_after(leaf):
1772                 result.append(comment_after, preformatted=True)
1773     bracket_split_succeeded_or_raise(head, body, tail)
1774     for result in (head, body, tail):
1775         if result:
1776             yield result
1777
1778
1779 def right_hand_split(
1780     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1781 ) -> Iterator[Line]:
1782     """Split line into many lines, starting with the last matching bracket pair."""
1783     head = Line(depth=line.depth)
1784     body = Line(depth=line.depth + 1, inside_brackets=True)
1785     tail = Line(depth=line.depth)
1786     tail_leaves: List[Leaf] = []
1787     body_leaves: List[Leaf] = []
1788     head_leaves: List[Leaf] = []
1789     current_leaves = tail_leaves
1790     opening_bracket = None
1791     closing_bracket = None
1792     for leaf in reversed(line.leaves):
1793         if current_leaves is body_leaves:
1794             if leaf is opening_bracket:
1795                 current_leaves = head_leaves if body_leaves else tail_leaves
1796         current_leaves.append(leaf)
1797         if current_leaves is tail_leaves:
1798             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1799                 opening_bracket = leaf.opening_bracket
1800                 closing_bracket = leaf
1801                 current_leaves = body_leaves
1802     tail_leaves.reverse()
1803     body_leaves.reverse()
1804     head_leaves.reverse()
1805     # Since body is a new indent level, remove spurious leading whitespace.
1806     if body_leaves:
1807         normalize_prefix(body_leaves[0], inside_brackets=True)
1808     elif not head_leaves:
1809         # No `head` and no `body` means the split failed. `tail` has all content.
1810         raise CannotSplit("No brackets found")
1811
1812     # Build the new lines.
1813     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1814         for leaf in leaves:
1815             result.append(leaf, preformatted=True)
1816             for comment_after in line.comments_after(leaf):
1817                 result.append(comment_after, preformatted=True)
1818     bracket_split_succeeded_or_raise(head, body, tail)
1819     assert opening_bracket and closing_bracket
1820     if (
1821         opening_bracket.type == token.LPAR
1822         and not opening_bracket.value
1823         and closing_bracket.type == token.RPAR
1824         and not closing_bracket.value
1825     ):
1826         # These parens were optional. If there aren't any delimiters or standalone
1827         # comments in the body, they were unnecessary and another split without
1828         # them should be attempted.
1829         if not (
1830             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1831         ):
1832             omit = {id(closing_bracket), *omit}
1833             yield from right_hand_split(line, py36=py36, omit=omit)
1834             return
1835
1836     ensure_visible(opening_bracket)
1837     ensure_visible(closing_bracket)
1838     for result in (head, body, tail):
1839         if result:
1840             yield result
1841
1842
1843 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1844     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1845
1846     Do nothing otherwise.
1847
1848     A left- or right-hand split is based on a pair of brackets. Content before
1849     (and including) the opening bracket is left on one line, content inside the
1850     brackets is put on a separate line, and finally content starting with and
1851     following the closing bracket is put on a separate line.
1852
1853     Those are called `head`, `body`, and `tail`, respectively. If the split
1854     produced the same line (all content in `head`) or ended up with an empty `body`
1855     and the `tail` is just the closing bracket, then it's considered failed.
1856     """
1857     tail_len = len(str(tail).strip())
1858     if not body:
1859         if tail_len == 0:
1860             raise CannotSplit("Splitting brackets produced the same line")
1861
1862         elif tail_len < 3:
1863             raise CannotSplit(
1864                 f"Splitting brackets on an empty body to save "
1865                 f"{tail_len} characters is not worth it"
1866             )
1867
1868
1869 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1870     """Normalize prefix of the first leaf in every line returned by `split_func`.
1871
1872     This is a decorator over relevant split functions.
1873     """
1874
1875     @wraps(split_func)
1876     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1877         for l in split_func(line, py36):
1878             normalize_prefix(l.leaves[0], inside_brackets=True)
1879             yield l
1880
1881     return split_wrapper
1882
1883
1884 @dont_increase_indentation
1885 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1886     """Split according to delimiters of the highest priority.
1887
1888     If `py36` is True, the split will add trailing commas also in function
1889     signatures that contain `*` and `**`.
1890     """
1891     try:
1892         last_leaf = line.leaves[-1]
1893     except IndexError:
1894         raise CannotSplit("Line empty")
1895
1896     delimiters = line.bracket_tracker.delimiters
1897     try:
1898         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1899             exclude={id(last_leaf)}
1900         )
1901     except ValueError:
1902         raise CannotSplit("No delimiters found")
1903
1904     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1905     lowest_depth = sys.maxsize
1906     trailing_comma_safe = True
1907
1908     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1909         """Append `leaf` to current line or to new line if appending impossible."""
1910         nonlocal current_line
1911         try:
1912             current_line.append_safe(leaf, preformatted=True)
1913         except ValueError as ve:
1914             yield current_line
1915
1916             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1917             current_line.append(leaf)
1918
1919     for leaf in line.leaves:
1920         yield from append_to_line(leaf)
1921
1922         for comment_after in line.comments_after(leaf):
1923             yield from append_to_line(comment_after)
1924
1925         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1926         if (
1927             leaf.bracket_depth == lowest_depth
1928             and is_vararg(leaf, within=VARARGS_PARENTS)
1929         ):
1930             trailing_comma_safe = trailing_comma_safe and py36
1931         leaf_priority = delimiters.get(id(leaf))
1932         if leaf_priority == delimiter_priority:
1933             yield current_line
1934
1935             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1936     if current_line:
1937         if (
1938             trailing_comma_safe
1939             and delimiter_priority == COMMA_PRIORITY
1940             and current_line.leaves[-1].type != token.COMMA
1941             and current_line.leaves[-1].type != STANDALONE_COMMENT
1942         ):
1943             current_line.append(Leaf(token.COMMA, ","))
1944         yield current_line
1945
1946
1947 @dont_increase_indentation
1948 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1949     """Split standalone comments from the rest of the line."""
1950     if not line.contains_standalone_comments(0):
1951         raise CannotSplit("Line does not have any standalone comments")
1952
1953     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1954
1955     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1956         """Append `leaf` to current line or to new line if appending impossible."""
1957         nonlocal current_line
1958         try:
1959             current_line.append_safe(leaf, preformatted=True)
1960         except ValueError as ve:
1961             yield current_line
1962
1963             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1964             current_line.append(leaf)
1965
1966     for leaf in line.leaves:
1967         yield from append_to_line(leaf)
1968
1969         for comment_after in line.comments_after(leaf):
1970             yield from append_to_line(comment_after)
1971
1972     if current_line:
1973         yield current_line
1974
1975
1976 def is_import(leaf: Leaf) -> bool:
1977     """Return True if the given leaf starts an import statement."""
1978     p = leaf.parent
1979     t = leaf.type
1980     v = leaf.value
1981     return bool(
1982         t == token.NAME
1983         and (
1984             (v == "import" and p and p.type == syms.import_name)
1985             or (v == "from" and p and p.type == syms.import_from)
1986         )
1987     )
1988
1989
1990 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1991     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1992     else.
1993
1994     Note: don't use backslashes for formatting or you'll lose your voting rights.
1995     """
1996     if not inside_brackets:
1997         spl = leaf.prefix.split("#")
1998         if "\\" not in spl[0]:
1999             nl_count = spl[-1].count("\n")
2000             if len(spl) > 1:
2001                 nl_count -= 1
2002             leaf.prefix = "\n" * nl_count
2003             return
2004
2005     leaf.prefix = ""
2006
2007
2008 def normalize_string_quotes(leaf: Leaf) -> None:
2009     """Prefer double quotes but only if it doesn't cause more escaping.
2010
2011     Adds or removes backslashes as appropriate. Doesn't parse and fix
2012     strings nested in f-strings (yet).
2013
2014     Note: Mutates its argument.
2015     """
2016     value = leaf.value.lstrip("furbFURB")
2017     if value[:3] == '"""':
2018         return
2019
2020     elif value[:3] == "'''":
2021         orig_quote = "'''"
2022         new_quote = '"""'
2023     elif value[0] == '"':
2024         orig_quote = '"'
2025         new_quote = "'"
2026     else:
2027         orig_quote = "'"
2028         new_quote = '"'
2029     first_quote_pos = leaf.value.find(orig_quote)
2030     if first_quote_pos == -1:
2031         return  # There's an internal error
2032
2033     prefix = leaf.value[:first_quote_pos]
2034     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2035     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2036     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2037     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
2038     if "r" in prefix.casefold():
2039         if unescaped_new_quote.search(body):
2040             # There's at least one unescaped new_quote in this raw string
2041             # so converting is impossible
2042             return
2043
2044         # Do not introduce or remove backslashes in raw strings
2045         new_body = body
2046     else:
2047         # remove unnecessary quotes
2048         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2049         if body != new_body:
2050             # Consider the string without unnecessary quotes as the original
2051             body = new_body
2052             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2053         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2054         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2055     if new_quote == '"""' and new_body[-1] == '"':
2056         # edge case:
2057         new_body = new_body[:-1] + '\\"'
2058     orig_escape_count = body.count("\\")
2059     new_escape_count = new_body.count("\\")
2060     if new_escape_count > orig_escape_count:
2061         return  # Do not introduce more escaping
2062
2063     if new_escape_count == orig_escape_count and orig_quote == '"':
2064         return  # Prefer double quotes
2065
2066     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2067
2068
2069 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2070     """Make existing optional parentheses invisible or create new ones.
2071
2072     Standardizes on visible parentheses for single-element tuples, and keeps
2073     existing visible parentheses for other tuples and generator expressions.
2074     """
2075     check_lpar = False
2076     for child in list(node.children):
2077         if check_lpar:
2078             if child.type == syms.atom:
2079                 if not (
2080                     is_empty_tuple(child)
2081                     or is_one_tuple(child)
2082                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2083                 ):
2084                     first = child.children[0]
2085                     last = child.children[-1]
2086                     if first.type == token.LPAR and last.type == token.RPAR:
2087                         # make parentheses invisible
2088                         first.value = ""  # type: ignore
2089                         last.value = ""  # type: ignore
2090             elif is_one_tuple(child):
2091                 # wrap child in visible parentheses
2092                 lpar = Leaf(token.LPAR, "(")
2093                 rpar = Leaf(token.RPAR, ")")
2094                 index = child.remove() or 0
2095                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2096             else:
2097                 # wrap child in invisible parentheses
2098                 lpar = Leaf(token.LPAR, "")
2099                 rpar = Leaf(token.RPAR, "")
2100                 index = child.remove() or 0
2101                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2102
2103         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2104
2105
2106 def is_empty_tuple(node: LN) -> bool:
2107     """Return True if `node` holds an empty tuple."""
2108     return (
2109         node.type == syms.atom
2110         and len(node.children) == 2
2111         and node.children[0].type == token.LPAR
2112         and node.children[1].type == token.RPAR
2113     )
2114
2115
2116 def is_one_tuple(node: LN) -> bool:
2117     """Return True if `node` holds a tuple with one element, with or without parens."""
2118     if node.type == syms.atom:
2119         if len(node.children) != 3:
2120             return False
2121
2122         lpar, gexp, rpar = node.children
2123         if not (
2124             lpar.type == token.LPAR
2125             and gexp.type == syms.testlist_gexp
2126             and rpar.type == token.RPAR
2127         ):
2128             return False
2129
2130         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2131
2132     return (
2133         node.type in IMPLICIT_TUPLE
2134         and len(node.children) == 2
2135         and node.children[1].type == token.COMMA
2136     )
2137
2138
2139 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2140     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2141
2142     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2143     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2144     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2145     generalizations (PEP 448).
2146     """
2147     if leaf.type not in STARS or not leaf.parent:
2148         return False
2149
2150     p = leaf.parent
2151     if p.type == syms.star_expr:
2152         # Star expressions are also used as assignment targets in extended
2153         # iterable unpacking (PEP 3132).  See what its parent is instead.
2154         if not p.parent:
2155             return False
2156
2157         p = p.parent
2158
2159     return p.type in within
2160
2161
2162 def max_delimiter_priority_in_atom(node: LN) -> int:
2163     """Return maximum delimiter priority inside `node`.
2164
2165     This is specific to atoms with contents contained in a pair of parentheses.
2166     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2167     """
2168     if node.type != syms.atom:
2169         return 0
2170
2171     first = node.children[0]
2172     last = node.children[-1]
2173     if not (first.type == token.LPAR and last.type == token.RPAR):
2174         return 0
2175
2176     bt = BracketTracker()
2177     for c in node.children[1:-1]:
2178         if isinstance(c, Leaf):
2179             bt.mark(c)
2180         else:
2181             for leaf in c.leaves():
2182                 bt.mark(leaf)
2183     try:
2184         return bt.max_delimiter_priority()
2185
2186     except ValueError:
2187         return 0
2188
2189
2190 def ensure_visible(leaf: Leaf) -> None:
2191     """Make sure parentheses are visible.
2192
2193     They could be invisible as part of some statements (see
2194     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2195     """
2196     if leaf.type == token.LPAR:
2197         leaf.value = "("
2198     elif leaf.type == token.RPAR:
2199         leaf.value = ")"
2200
2201
2202 def is_python36(node: Node) -> bool:
2203     """Return True if the current file is using Python 3.6+ features.
2204
2205     Currently looking for:
2206     - f-strings; and
2207     - trailing commas after * or ** in function signatures.
2208     """
2209     for n in node.pre_order():
2210         if n.type == token.STRING:
2211             value_head = n.value[:2]  # type: ignore
2212             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2213                 return True
2214
2215         elif (
2216             n.type == syms.typedargslist
2217             and n.children
2218             and n.children[-1].type == token.COMMA
2219         ):
2220             for ch in n.children:
2221                 if ch.type in STARS:
2222                     return True
2223
2224     return False
2225
2226
2227 PYTHON_EXTENSIONS = {".py"}
2228 BLACKLISTED_DIRECTORIES = {
2229     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2230 }
2231
2232
2233 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2234     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2235     and have one of the PYTHON_EXTENSIONS.
2236     """
2237     for child in path.iterdir():
2238         if child.is_dir():
2239             if child.name in BLACKLISTED_DIRECTORIES:
2240                 continue
2241
2242             yield from gen_python_files_in_dir(child)
2243
2244         elif child.suffix in PYTHON_EXTENSIONS:
2245             yield child
2246
2247
2248 @dataclass
2249 class Report:
2250     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2251     check: bool = False
2252     quiet: bool = False
2253     change_count: int = 0
2254     same_count: int = 0
2255     failure_count: int = 0
2256
2257     def done(self, src: Path, changed: Changed) -> None:
2258         """Increment the counter for successful reformatting. Write out a message."""
2259         if changed is Changed.YES:
2260             reformatted = "would reformat" if self.check else "reformatted"
2261             if not self.quiet:
2262                 out(f"{reformatted} {src}")
2263             self.change_count += 1
2264         else:
2265             if not self.quiet:
2266                 if changed is Changed.NO:
2267                     msg = f"{src} already well formatted, good job."
2268                 else:
2269                     msg = f"{src} wasn't modified on disk since last run."
2270                 out(msg, bold=False)
2271             self.same_count += 1
2272
2273     def failed(self, src: Path, message: str) -> None:
2274         """Increment the counter for failed reformatting. Write out a message."""
2275         err(f"error: cannot format {src}: {message}")
2276         self.failure_count += 1
2277
2278     @property
2279     def return_code(self) -> int:
2280         """Return the exit code that the app should use.
2281
2282         This considers the current state of changed files and failures:
2283         - if there were any failures, return 123;
2284         - if any files were changed and --check is being used, return 1;
2285         - otherwise return 0.
2286         """
2287         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2288         # 126 we have special returncodes reserved by the shell.
2289         if self.failure_count:
2290             return 123
2291
2292         elif self.change_count and self.check:
2293             return 1
2294
2295         return 0
2296
2297     def __str__(self) -> str:
2298         """Render a color report of the current state.
2299
2300         Use `click.unstyle` to remove colors.
2301         """
2302         if self.check:
2303             reformatted = "would be reformatted"
2304             unchanged = "would be left unchanged"
2305             failed = "would fail to reformat"
2306         else:
2307             reformatted = "reformatted"
2308             unchanged = "left unchanged"
2309             failed = "failed to reformat"
2310         report = []
2311         if self.change_count:
2312             s = "s" if self.change_count > 1 else ""
2313             report.append(
2314                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2315             )
2316         if self.same_count:
2317             s = "s" if self.same_count > 1 else ""
2318             report.append(f"{self.same_count} file{s} {unchanged}")
2319         if self.failure_count:
2320             s = "s" if self.failure_count > 1 else ""
2321             report.append(
2322                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2323             )
2324         return ", ".join(report) + "."
2325
2326
2327 def assert_equivalent(src: str, dst: str) -> None:
2328     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2329
2330     import ast
2331     import traceback
2332
2333     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2334         """Simple visitor generating strings to compare ASTs by content."""
2335         yield f"{'  ' * depth}{node.__class__.__name__}("
2336
2337         for field in sorted(node._fields):
2338             try:
2339                 value = getattr(node, field)
2340             except AttributeError:
2341                 continue
2342
2343             yield f"{'  ' * (depth+1)}{field}="
2344
2345             if isinstance(value, list):
2346                 for item in value:
2347                     if isinstance(item, ast.AST):
2348                         yield from _v(item, depth + 2)
2349
2350             elif isinstance(value, ast.AST):
2351                 yield from _v(value, depth + 2)
2352
2353             else:
2354                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2355
2356         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2357
2358     try:
2359         src_ast = ast.parse(src)
2360     except Exception as exc:
2361         major, minor = sys.version_info[:2]
2362         raise AssertionError(
2363             f"cannot use --safe with this file; failed to parse source file "
2364             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2365             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2366         )
2367
2368     try:
2369         dst_ast = ast.parse(dst)
2370     except Exception as exc:
2371         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2372         raise AssertionError(
2373             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2374             f"Please report a bug on https://github.com/ambv/black/issues.  "
2375             f"This invalid output might be helpful: {log}"
2376         ) from None
2377
2378     src_ast_str = "\n".join(_v(src_ast))
2379     dst_ast_str = "\n".join(_v(dst_ast))
2380     if src_ast_str != dst_ast_str:
2381         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2382         raise AssertionError(
2383             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2384             f"the source.  "
2385             f"Please report a bug on https://github.com/ambv/black/issues.  "
2386             f"This diff might be helpful: {log}"
2387         ) from None
2388
2389
2390 def assert_stable(src: str, dst: str, line_length: int) -> None:
2391     """Raise AssertionError if `dst` reformats differently the second time."""
2392     newdst = format_str(dst, line_length=line_length)
2393     if dst != newdst:
2394         log = dump_to_file(
2395             diff(src, dst, "source", "first pass"),
2396             diff(dst, newdst, "first pass", "second pass"),
2397         )
2398         raise AssertionError(
2399             f"INTERNAL ERROR: Black produced different code on the second pass "
2400             f"of the formatter.  "
2401             f"Please report a bug on https://github.com/ambv/black/issues.  "
2402             f"This diff might be helpful: {log}"
2403         ) from None
2404
2405
2406 def dump_to_file(*output: str) -> str:
2407     """Dump `output` to a temporary file. Return path to the file."""
2408     import tempfile
2409
2410     with tempfile.NamedTemporaryFile(
2411         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2412     ) as f:
2413         for lines in output:
2414             f.write(lines)
2415             if lines and lines[-1] != "\n":
2416                 f.write("\n")
2417     return f.name
2418
2419
2420 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2421     """Return a unified diff string between strings `a` and `b`."""
2422     import difflib
2423
2424     a_lines = [line + "\n" for line in a.split("\n")]
2425     b_lines = [line + "\n" for line in b.split("\n")]
2426     return "".join(
2427         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2428     )
2429
2430
2431 def cancel(tasks: List[asyncio.Task]) -> None:
2432     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2433     err("Aborted!")
2434     for task in tasks:
2435         task.cancel()
2436
2437
2438 def shutdown(loop: BaseEventLoop) -> None:
2439     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2440     try:
2441         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2442         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2443         if not to_cancel:
2444             return
2445
2446         for task in to_cancel:
2447             task.cancel()
2448         loop.run_until_complete(
2449             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2450         )
2451     finally:
2452         # `concurrent.futures.Future` objects cannot be cancelled once they
2453         # are already running. There might be some when the `shutdown()` happened.
2454         # Silence their logger's spew about the event loop being closed.
2455         cf_logger = logging.getLogger("concurrent.futures")
2456         cf_logger.setLevel(logging.CRITICAL)
2457         loop.close()
2458
2459
2460 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2461     """Replace `regex` with `replacement` twice on `original`.
2462
2463     This is used by string normalization to perform replaces on
2464     overlapping matches.
2465     """
2466     return regex.sub(replacement, regex.sub(replacement, original))
2467
2468
2469 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2470 CACHE_FILE = CACHE_DIR / "cache.pickle"
2471
2472
2473 def read_cache() -> Cache:
2474     """Read the cache if it exists and is well formed.
2475
2476     If it is not well formed, the call to write_cache later should resolve the issue.
2477     """
2478     if not CACHE_FILE.exists():
2479         return {}
2480
2481     with CACHE_FILE.open("rb") as fobj:
2482         try:
2483             cache: Cache = pickle.load(fobj)
2484         except pickle.UnpicklingError:
2485             return {}
2486
2487     return cache
2488
2489
2490 def get_cache_info(path: Path) -> CacheInfo:
2491     """Return the information used to check if a file is already formatted or not."""
2492     stat = path.stat()
2493     return stat.st_mtime, stat.st_size
2494
2495
2496 def filter_cached(
2497     cache: Cache, sources: Iterable[Path]
2498 ) -> Tuple[List[Path], List[Path]]:
2499     """Split a list of paths into two.
2500
2501     The first list contains paths of files that modified on disk or are not in the
2502     cache. The other list contains paths to non-modified files.
2503     """
2504     todo, done = [], []
2505     for src in sources:
2506         src = src.resolve()
2507         if cache.get(src) != get_cache_info(src):
2508             todo.append(src)
2509         else:
2510             done.append(src)
2511     return todo, done
2512
2513
2514 def write_cache(cache: Cache, sources: List[Path]) -> None:
2515     """Update the cache file."""
2516     try:
2517         if not CACHE_DIR.exists():
2518             CACHE_DIR.mkdir(parents=True)
2519         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2520         with CACHE_FILE.open("wb") as fobj:
2521             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2522     except OSError:
2523         pass
2524
2525
2526 if __name__ == "__main__":
2527     main()