]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

[#154] Handle comments between decorators properly (#166)
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import pickle
5 from asyncio.base_events import BaseEventLoop
6 from concurrent.futures import Executor, ProcessPoolExecutor
7 from enum import Enum
8 from functools import partial, wraps
9 import keyword
10 import logging
11 from multiprocessing import Manager
12 import os
13 from pathlib import Path
14 import re
15 import tokenize
16 import signal
17 import sys
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46 __version__ = "18.4a2"
47 DEFAULT_LINE_LENGTH = 88
48 # types
49 syms = pygram.python_symbols
50 FileContent = str
51 Encoding = str
52 Depth = int
53 NodeType = int
54 LeafID = int
55 Priority = int
56 Index = int
57 LN = Union[Leaf, Node]
58 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
59 Timestamp = float
60 FileSize = int
61 CacheInfo = Tuple[Timestamp, FileSize]
62 Cache = Dict[Path, CacheInfo]
63 out = partial(click.secho, bold=True, err=True)
64 err = partial(click.secho, fg="red", err=True)
65
66
67 class NothingChanged(UserWarning):
68     """Raised by :func:`format_file` when reformatted code is the same as source."""
69
70
71 class CannotSplit(Exception):
72     """A readable split that fits the allotted line length is impossible.
73
74     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
75     :func:`delimiter_split`.
76     """
77
78
79 class FormatError(Exception):
80     """Base exception for `# fmt: on` and `# fmt: off` handling.
81
82     It holds the number of bytes of the prefix consumed before the format
83     control comment appeared.
84     """
85
86     def __init__(self, consumed: int) -> None:
87         super().__init__(consumed)
88         self.consumed = consumed
89
90     def trim_prefix(self, leaf: Leaf) -> None:
91         leaf.prefix = leaf.prefix[self.consumed:]
92
93     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
94         """Returns a new Leaf from the consumed part of the prefix."""
95         unformatted_prefix = leaf.prefix[:self.consumed]
96         return Leaf(token.NEWLINE, unformatted_prefix)
97
98
99 class FormatOn(FormatError):
100     """Found a comment like `# fmt: on` in the file."""
101
102
103 class FormatOff(FormatError):
104     """Found a comment like `# fmt: off` in the file."""
105
106
107 class WriteBack(Enum):
108     NO = 0
109     YES = 1
110     DIFF = 2
111
112
113 class Changed(Enum):
114     NO = 0
115     CACHED = 1
116     YES = 2
117
118
119 @click.command()
120 @click.option(
121     "-l",
122     "--line-length",
123     type=int,
124     default=DEFAULT_LINE_LENGTH,
125     help="How many character per line to allow.",
126     show_default=True,
127 )
128 @click.option(
129     "--check",
130     is_flag=True,
131     help=(
132         "Don't write the files back, just return the status.  Return code 0 "
133         "means nothing would change.  Return code 1 means some files would be "
134         "reformatted.  Return code 123 means there was an internal error."
135     ),
136 )
137 @click.option(
138     "--diff",
139     is_flag=True,
140     help="Don't write the files back, just output a diff for each file on stdout.",
141 )
142 @click.option(
143     "--fast/--safe",
144     is_flag=True,
145     help="If --fast given, skip temporary sanity checks. [default: --safe]",
146 )
147 @click.option(
148     "-q",
149     "--quiet",
150     is_flag=True,
151     help=(
152         "Don't emit non-error messages to stderr. Errors are still emitted, "
153         "silence those with 2>/dev/null."
154     ),
155 )
156 @click.version_option(version=__version__)
157 @click.argument(
158     "src",
159     nargs=-1,
160     type=click.Path(
161         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
162     ),
163 )
164 @click.pass_context
165 def main(
166     ctx: click.Context,
167     line_length: int,
168     check: bool,
169     diff: bool,
170     fast: bool,
171     quiet: bool,
172     src: List[str],
173 ) -> None:
174     """The uncompromising code formatter."""
175     sources: List[Path] = []
176     for s in src:
177         p = Path(s)
178         if p.is_dir():
179             sources.extend(gen_python_files_in_dir(p))
180         elif p.is_file():
181             # if a file was explicitly given, we don't care about its extension
182             sources.append(p)
183         elif s == "-":
184             sources.append(Path("-"))
185         else:
186             err(f"invalid path: {s}")
187
188     if check and not diff:
189         write_back = WriteBack.NO
190     elif diff:
191         write_back = WriteBack.DIFF
192     else:
193         write_back = WriteBack.YES
194     report = Report(check=check, quiet=quiet)
195     if len(sources) == 0:
196         ctx.exit(0)
197         return
198
199     elif len(sources) == 1:
200         reformat_one(sources[0], line_length, fast, write_back, report)
201     else:
202         loop = asyncio.get_event_loop()
203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
204         try:
205             loop.run_until_complete(
206                 schedule_formatting(
207                     sources, line_length, fast, write_back, report, loop, executor
208                 )
209             )
210         finally:
211             shutdown(loop)
212         if not quiet:
213             out("All done! ✨ 🍰 ✨")
214             click.echo(str(report))
215     ctx.exit(report.return_code)
216
217
218 def reformat_one(
219     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
220 ) -> None:
221     """Reformat a single file under `src` without spawning child processes.
222
223     If `quiet` is True, non-error messages are not output. `line_length`,
224     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
225     """
226     try:
227         changed = Changed.NO
228         if not src.is_file() and str(src) == "-":
229             if format_stdin_to_stdout(
230                 line_length=line_length, fast=fast, write_back=write_back
231             ):
232                 changed = Changed.YES
233         else:
234             cache: Cache = {}
235             if write_back != WriteBack.DIFF:
236                 cache = read_cache()
237                 src = src.resolve()
238                 if src in cache and cache[src] == get_cache_info(src):
239                     changed = Changed.CACHED
240             if (
241                 changed is not Changed.CACHED
242                 and format_file_in_place(
243                     src, line_length=line_length, fast=fast, write_back=write_back
244                 )
245             ):
246                 changed = Changed.YES
247             if write_back != WriteBack.DIFF and changed is not Changed.NO:
248                 write_cache(cache, [src])
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache()
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             src: loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             )
289             for src in sources
290         }
291         _task_values = list(tasks.values())
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
294             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     if write_back != WriteBack.DIFF and formatted:
315         write_cache(cache, formatted)
316
317
318 def format_file_in_place(
319     src: Path,
320     line_length: int,
321     fast: bool,
322     write_back: WriteBack = WriteBack.NO,
323     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
324 ) -> bool:
325     """Format file under `src` path. Return True if changed.
326
327     If `write_back` is True, write reformatted code back to stdout.
328     `line_length` and `fast` options are passed to :func:`format_file_contents`.
329     """
330
331     with tokenize.open(src) as src_buffer:
332         src_contents = src_buffer.read()
333     try:
334         dst_contents = format_file_contents(
335             src_contents, line_length=line_length, fast=fast
336         )
337     except NothingChanged:
338         return False
339
340     if write_back == write_back.YES:
341         with open(src, "w", encoding=src_buffer.encoding) as f:
342             f.write(dst_contents)
343     elif write_back == write_back.DIFF:
344         src_name = f"{src}  (original)"
345         dst_name = f"{src}  (formatted)"
346         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347         if lock:
348             lock.acquire()
349         try:
350             sys.stdout.write(diff_contents)
351         finally:
352             if lock:
353                 lock.release()
354     return True
355
356
357 def format_stdin_to_stdout(
358     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
359 ) -> bool:
360     """Format file on stdin. Return True if changed.
361
362     If `write_back` is True, write reformatted code back to stdout.
363     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
364     """
365     src = sys.stdin.read()
366     dst = src
367     try:
368         dst = format_file_contents(src, line_length=line_length, fast=fast)
369         return True
370
371     except NothingChanged:
372         return False
373
374     finally:
375         if write_back == WriteBack.YES:
376             sys.stdout.write(dst)
377         elif write_back == WriteBack.DIFF:
378             src_name = "<stdin>  (original)"
379             dst_name = "<stdin>  (formatted)"
380             sys.stdout.write(diff(src, dst, src_name, dst_name))
381
382
383 def format_file_contents(
384     src_contents: str, line_length: int, fast: bool
385 ) -> FileContent:
386     """Reformat contents a file and return new contents.
387
388     If `fast` is False, additionally confirm that the reformatted code is
389     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
390     `line_length` is passed to :func:`format_str`.
391     """
392     if src_contents.strip() == "":
393         raise NothingChanged
394
395     dst_contents = format_str(src_contents, line_length=line_length)
396     if src_contents == dst_contents:
397         raise NothingChanged
398
399     if not fast:
400         assert_equivalent(src_contents, dst_contents)
401         assert_stable(src_contents, dst_contents, line_length=line_length)
402     return dst_contents
403
404
405 def format_str(src_contents: str, line_length: int) -> FileContent:
406     """Reformat a string and return new contents.
407
408     `line_length` determines how many characters per line are allowed.
409     """
410     src_node = lib2to3_parse(src_contents)
411     dst_contents = ""
412     lines = LineGenerator()
413     elt = EmptyLineTracker()
414     py36 = is_python36(src_node)
415     empty_line = Line()
416     after = 0
417     for current_line in lines.visit(src_node):
418         for _ in range(after):
419             dst_contents += str(empty_line)
420         before, after = elt.maybe_empty_lines(current_line)
421         for _ in range(before):
422             dst_contents += str(empty_line)
423         for line in split_line(current_line, line_length=line_length, py36=py36):
424             dst_contents += str(line)
425     return dst_contents
426
427
428 GRAMMARS = [
429     pygram.python_grammar_no_print_statement_no_exec_statement,
430     pygram.python_grammar_no_print_statement,
431     pygram.python_grammar,
432 ]
433
434
435 def lib2to3_parse(src_txt: str) -> Node:
436     """Given a string with source, return the lib2to3 Node."""
437     grammar = pygram.python_grammar_no_print_statement
438     if src_txt[-1] != "\n":
439         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
440         src_txt += nl
441     for grammar in GRAMMARS:
442         drv = driver.Driver(grammar, pytree.convert)
443         try:
444             result = drv.parse_string(src_txt, True)
445             break
446
447         except ParseError as pe:
448             lineno, column = pe.context[1]
449             lines = src_txt.splitlines()
450             try:
451                 faulty_line = lines[lineno - 1]
452             except IndexError:
453                 faulty_line = "<line number missing in source>"
454             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
455     else:
456         raise exc from None
457
458     if isinstance(result, Leaf):
459         result = Node(syms.file_input, [result])
460     return result
461
462
463 def lib2to3_unparse(node: Node) -> str:
464     """Given a lib2to3 node, return its string representation."""
465     code = str(node)
466     return code
467
468
469 T = TypeVar("T")
470
471
472 class Visitor(Generic[T]):
473     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
474
475     def visit(self, node: LN) -> Iterator[T]:
476         """Main method to visit `node` and its children.
477
478         It tries to find a `visit_*()` method for the given `node.type`, like
479         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
480         If no dedicated `visit_*()` method is found, chooses `visit_default()`
481         instead.
482
483         Then yields objects of type `T` from the selected visitor.
484         """
485         if node.type < 256:
486             name = token.tok_name[node.type]
487         else:
488             name = type_repr(node.type)
489         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
490
491     def visit_default(self, node: LN) -> Iterator[T]:
492         """Default `visit_*()` implementation. Recurses to children of `node`."""
493         if isinstance(node, Node):
494             for child in node.children:
495                 yield from self.visit(child)
496
497
498 @dataclass
499 class DebugVisitor(Visitor[T]):
500     tree_depth: int = 0
501
502     def visit_default(self, node: LN) -> Iterator[T]:
503         indent = " " * (2 * self.tree_depth)
504         if isinstance(node, Node):
505             _type = type_repr(node.type)
506             out(f"{indent}{_type}", fg="yellow")
507             self.tree_depth += 1
508             for child in node.children:
509                 yield from self.visit(child)
510
511             self.tree_depth -= 1
512             out(f"{indent}/{_type}", fg="yellow", bold=False)
513         else:
514             _type = token.tok_name.get(node.type, str(node.type))
515             out(f"{indent}{_type}", fg="blue", nl=False)
516             if node.prefix:
517                 # We don't have to handle prefixes for `Node` objects since
518                 # that delegates to the first child anyway.
519                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
520             out(f" {node.value!r}", fg="blue", bold=False)
521
522     @classmethod
523     def show(cls, code: str) -> None:
524         """Pretty-print the lib2to3 AST of a given string of `code`.
525
526         Convenience method for debugging.
527         """
528         v: DebugVisitor[None] = DebugVisitor()
529         list(v.visit(lib2to3_parse(code)))
530
531
532 KEYWORDS = set(keyword.kwlist)
533 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
534 FLOW_CONTROL = {"return", "raise", "break", "continue"}
535 STATEMENT = {
536     syms.if_stmt,
537     syms.while_stmt,
538     syms.for_stmt,
539     syms.try_stmt,
540     syms.except_clause,
541     syms.with_stmt,
542     syms.funcdef,
543     syms.classdef,
544 }
545 STANDALONE_COMMENT = 153
546 LOGIC_OPERATORS = {"and", "or"}
547 COMPARATORS = {
548     token.LESS,
549     token.GREATER,
550     token.EQEQUAL,
551     token.NOTEQUAL,
552     token.LESSEQUAL,
553     token.GREATEREQUAL,
554 }
555 MATH_OPERATORS = {
556     token.PLUS,
557     token.MINUS,
558     token.STAR,
559     token.SLASH,
560     token.VBAR,
561     token.AMPER,
562     token.PERCENT,
563     token.CIRCUMFLEX,
564     token.TILDE,
565     token.LEFTSHIFT,
566     token.RIGHTSHIFT,
567     token.DOUBLESTAR,
568     token.DOUBLESLASH,
569 }
570 STARS = {token.STAR, token.DOUBLESTAR}
571 VARARGS_PARENTS = {
572     syms.arglist,
573     syms.argument,  # double star in arglist
574     syms.trailer,  # single argument to call
575     syms.typedargslist,
576     syms.varargslist,  # lambdas
577 }
578 UNPACKING_PARENTS = {
579     syms.atom,  # single element of a list or set literal
580     syms.dictsetmaker,
581     syms.listmaker,
582     syms.testlist_gexp,
583 }
584 COMPREHENSION_PRIORITY = 20
585 COMMA_PRIORITY = 10
586 LOGIC_PRIORITY = 5
587 STRING_PRIORITY = 4
588 COMPARATOR_PRIORITY = 3
589 MATH_PRIORITY = 1
590
591
592 @dataclass
593 class BracketTracker:
594     """Keeps track of brackets on a line."""
595
596     depth: int = 0
597     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
598     delimiters: Dict[LeafID, Priority] = Factory(dict)
599     previous: Optional[Leaf] = None
600     _for_loop_variable: bool = False
601     _lambda_arguments: bool = False
602
603     def mark(self, leaf: Leaf) -> None:
604         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
605
606         All leaves receive an int `bracket_depth` field that stores how deep
607         within brackets a given leaf is. 0 means there are no enclosing brackets
608         that started on this line.
609
610         If a leaf is itself a closing bracket, it receives an `opening_bracket`
611         field that it forms a pair with. This is a one-directional link to
612         avoid reference cycles.
613
614         If a leaf is a delimiter (a token on which Black can split the line if
615         needed) and it's on depth 0, its `id()` is stored in the tracker's
616         `delimiters` field.
617         """
618         if leaf.type == token.COMMENT:
619             return
620
621         self.maybe_decrement_after_for_loop_variable(leaf)
622         self.maybe_decrement_after_lambda_arguments(leaf)
623         if leaf.type in CLOSING_BRACKETS:
624             self.depth -= 1
625             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
626             leaf.opening_bracket = opening_bracket
627         leaf.bracket_depth = self.depth
628         if self.depth == 0:
629             delim = is_split_before_delimiter(leaf, self.previous)
630             if delim and self.previous is not None:
631                 self.delimiters[id(self.previous)] = delim
632             else:
633                 delim = is_split_after_delimiter(leaf, self.previous)
634                 if delim:
635                     self.delimiters[id(leaf)] = delim
636         if leaf.type in OPENING_BRACKETS:
637             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
638             self.depth += 1
639         self.previous = leaf
640         self.maybe_increment_lambda_arguments(leaf)
641         self.maybe_increment_for_loop_variable(leaf)
642
643     def any_open_brackets(self) -> bool:
644         """Return True if there is an yet unmatched open bracket on the line."""
645         return bool(self.bracket_match)
646
647     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
648         """Return the highest priority of a delimiter found on the line.
649
650         Values are consistent with what `is_split_*_delimiter()` return.
651         Raises ValueError on no delimiters.
652         """
653         return max(v for k, v in self.delimiters.items() if k not in exclude)
654
655     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
656         """In a for loop, or comprehension, the variables are often unpacks.
657
658         To avoid splitting on the comma in this situation, increase the depth of
659         tokens between `for` and `in`.
660         """
661         if leaf.type == token.NAME and leaf.value == "for":
662             self.depth += 1
663             self._for_loop_variable = True
664             return True
665
666         return False
667
668     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
669         """See `maybe_increment_for_loop_variable` above for explanation."""
670         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
671             self.depth -= 1
672             self._for_loop_variable = False
673             return True
674
675         return False
676
677     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
678         """In a lambda expression, there might be more than one argument.
679
680         To avoid splitting on the comma in this situation, increase the depth of
681         tokens between `lambda` and `:`.
682         """
683         if leaf.type == token.NAME and leaf.value == "lambda":
684             self.depth += 1
685             self._lambda_arguments = True
686             return True
687
688         return False
689
690     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
691         """See `maybe_increment_lambda_arguments` above for explanation."""
692         if self._lambda_arguments and leaf.type == token.COLON:
693             self.depth -= 1
694             self._lambda_arguments = False
695             return True
696
697         return False
698
699
700 @dataclass
701 class Line:
702     """Holds leaves and comments. Can be printed with `str(line)`."""
703
704     depth: int = 0
705     leaves: List[Leaf] = Factory(list)
706     comments: List[Tuple[Index, Leaf]] = Factory(list)
707     bracket_tracker: BracketTracker = Factory(BracketTracker)
708     inside_brackets: bool = False
709
710     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
711         """Add a new `leaf` to the end of the line.
712
713         Unless `preformatted` is True, the `leaf` will receive a new consistent
714         whitespace prefix and metadata applied by :class:`BracketTracker`.
715         Trailing commas are maybe removed, unpacked for loop variables are
716         demoted from being delimiters.
717
718         Inline comments are put aside.
719         """
720         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
721         if not has_value:
722             return
723
724         if self.leaves and not preformatted:
725             # Note: at this point leaf.prefix should be empty except for
726             # imports, for which we only preserve newlines.
727             leaf.prefix += whitespace(leaf)
728         if self.inside_brackets or not preformatted:
729             self.bracket_tracker.mark(leaf)
730             self.maybe_remove_trailing_comma(leaf)
731
732         if not self.append_comment(leaf):
733             self.leaves.append(leaf)
734
735     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
736         """Like :func:`append()` but disallow invalid standalone comment structure.
737
738         Raises ValueError when any `leaf` is appended after a standalone comment
739         or when a standalone comment is not the first leaf on the line.
740         """
741         if self.bracket_tracker.depth == 0:
742             if self.is_comment:
743                 raise ValueError("cannot append to standalone comments")
744
745             if self.leaves and leaf.type == STANDALONE_COMMENT:
746                 raise ValueError(
747                     "cannot append standalone comments to a populated line"
748                 )
749
750         self.append(leaf, preformatted=preformatted)
751
752     @property
753     def is_comment(self) -> bool:
754         """Is this line a standalone comment?"""
755         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
756
757     @property
758     def is_decorator(self) -> bool:
759         """Is this line a decorator?"""
760         return bool(self) and self.leaves[0].type == token.AT
761
762     @property
763     def is_import(self) -> bool:
764         """Is this an import line?"""
765         return bool(self) and is_import(self.leaves[0])
766
767     @property
768     def is_class(self) -> bool:
769         """Is this line a class definition?"""
770         return (
771             bool(self)
772             and self.leaves[0].type == token.NAME
773             and self.leaves[0].value == "class"
774         )
775
776     @property
777     def is_def(self) -> bool:
778         """Is this a function definition? (Also returns True for async defs.)"""
779         try:
780             first_leaf = self.leaves[0]
781         except IndexError:
782             return False
783
784         try:
785             second_leaf: Optional[Leaf] = self.leaves[1]
786         except IndexError:
787             second_leaf = None
788         return (
789             (first_leaf.type == token.NAME and first_leaf.value == "def")
790             or (
791                 first_leaf.type == token.ASYNC
792                 and second_leaf is not None
793                 and second_leaf.type == token.NAME
794                 and second_leaf.value == "def"
795             )
796         )
797
798     @property
799     def is_flow_control(self) -> bool:
800         """Is this line a flow control statement?
801
802         Those are `return`, `raise`, `break`, and `continue`.
803         """
804         return (
805             bool(self)
806             and self.leaves[0].type == token.NAME
807             and self.leaves[0].value in FLOW_CONTROL
808         )
809
810     @property
811     def is_yield(self) -> bool:
812         """Is this line a yield statement?"""
813         return (
814             bool(self)
815             and self.leaves[0].type == token.NAME
816             and self.leaves[0].value == "yield"
817         )
818
819     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
820         """If so, needs to be split before emitting."""
821         for leaf in self.leaves:
822             if leaf.type == STANDALONE_COMMENT:
823                 if leaf.bracket_depth <= depth_limit:
824                     return True
825
826         return False
827
828     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
829         """Remove trailing comma if there is one and it's safe."""
830         if not (
831             self.leaves
832             and self.leaves[-1].type == token.COMMA
833             and closing.type in CLOSING_BRACKETS
834         ):
835             return False
836
837         if closing.type == token.RBRACE:
838             self.remove_trailing_comma()
839             return True
840
841         if closing.type == token.RSQB:
842             comma = self.leaves[-1]
843             if comma.parent and comma.parent.type == syms.listmaker:
844                 self.remove_trailing_comma()
845                 return True
846
847         # For parens let's check if it's safe to remove the comma.  If the
848         # trailing one is the only one, we might mistakenly change a tuple
849         # into a different type by removing the comma.
850         depth = closing.bracket_depth + 1
851         commas = 0
852         opening = closing.opening_bracket
853         for _opening_index, leaf in enumerate(self.leaves):
854             if leaf is opening:
855                 break
856
857         else:
858             return False
859
860         for leaf in self.leaves[_opening_index + 1:]:
861             if leaf is closing:
862                 break
863
864             bracket_depth = leaf.bracket_depth
865             if bracket_depth == depth and leaf.type == token.COMMA:
866                 commas += 1
867                 if leaf.parent and leaf.parent.type == syms.arglist:
868                     commas += 1
869                     break
870
871         if commas > 1:
872             self.remove_trailing_comma()
873             return True
874
875         return False
876
877     def append_comment(self, comment: Leaf) -> bool:
878         """Add an inline or standalone comment to the line."""
879         if (
880             comment.type == STANDALONE_COMMENT
881             and self.bracket_tracker.any_open_brackets()
882         ):
883             comment.prefix = ""
884             return False
885
886         if comment.type != token.COMMENT:
887             return False
888
889         after = len(self.leaves) - 1
890         if after == -1:
891             comment.type = STANDALONE_COMMENT
892             comment.prefix = ""
893             return False
894
895         else:
896             self.comments.append((after, comment))
897             return True
898
899     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
900         """Generate comments that should appear directly after `leaf`."""
901         for _leaf_index, _leaf in enumerate(self.leaves):
902             if leaf is _leaf:
903                 break
904
905         else:
906             return
907
908         for index, comment_after in self.comments:
909             if _leaf_index == index:
910                 yield comment_after
911
912     def remove_trailing_comma(self) -> None:
913         """Remove the trailing comma and moves the comments attached to it."""
914         comma_index = len(self.leaves) - 1
915         for i in range(len(self.comments)):
916             comment_index, comment = self.comments[i]
917             if comment_index == comma_index:
918                 self.comments[i] = (comma_index - 1, comment)
919         self.leaves.pop()
920
921     def __str__(self) -> str:
922         """Render the line."""
923         if not self:
924             return "\n"
925
926         indent = "    " * self.depth
927         leaves = iter(self.leaves)
928         first = next(leaves)
929         res = f"{first.prefix}{indent}{first.value}"
930         for leaf in leaves:
931             res += str(leaf)
932         for _, comment in self.comments:
933             res += str(comment)
934         return res + "\n"
935
936     def __bool__(self) -> bool:
937         """Return True if the line has leaves or comments."""
938         return bool(self.leaves or self.comments)
939
940
941 class UnformattedLines(Line):
942     """Just like :class:`Line` but stores lines which aren't reformatted."""
943
944     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
945         """Just add a new `leaf` to the end of the lines.
946
947         The `preformatted` argument is ignored.
948
949         Keeps track of indentation `depth`, which is useful when the user
950         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
951         """
952         try:
953             list(generate_comments(leaf))
954         except FormatOn as f_on:
955             self.leaves.append(f_on.leaf_from_consumed(leaf))
956             raise
957
958         self.leaves.append(leaf)
959         if leaf.type == token.INDENT:
960             self.depth += 1
961         elif leaf.type == token.DEDENT:
962             self.depth -= 1
963
964     def __str__(self) -> str:
965         """Render unformatted lines from leaves which were added with `append()`.
966
967         `depth` is not used for indentation in this case.
968         """
969         if not self:
970             return "\n"
971
972         res = ""
973         for leaf in self.leaves:
974             res += str(leaf)
975         return res
976
977     def append_comment(self, comment: Leaf) -> bool:
978         """Not implemented in this class. Raises `NotImplementedError`."""
979         raise NotImplementedError("Unformatted lines don't store comments separately.")
980
981     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
982         """Does nothing and returns False."""
983         return False
984
985     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
986         """Does nothing and returns False."""
987         return False
988
989
990 @dataclass
991 class EmptyLineTracker:
992     """Provides a stateful method that returns the number of potential extra
993     empty lines needed before and after the currently processed line.
994
995     Note: this tracker works on lines that haven't been split yet.  It assumes
996     the prefix of the first leaf consists of optional newlines.  Those newlines
997     are consumed by `maybe_empty_lines()` and included in the computation.
998     """
999     previous_line: Optional[Line] = None
1000     previous_after: int = 0
1001     previous_defs: List[int] = Factory(list)
1002
1003     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1004         """Return the number of extra empty lines before and after the `current_line`.
1005
1006         This is for separating `def`, `async def` and `class` with extra empty
1007         lines (two on module-level), as well as providing an extra empty line
1008         after flow control keywords to make them more prominent.
1009         """
1010         if isinstance(current_line, UnformattedLines):
1011             return 0, 0
1012
1013         before, after = self._maybe_empty_lines(current_line)
1014         before -= self.previous_after
1015         self.previous_after = after
1016         self.previous_line = current_line
1017         return before, after
1018
1019     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1020         max_allowed = 1
1021         if current_line.depth == 0:
1022             max_allowed = 2
1023         if current_line.leaves:
1024             # Consume the first leaf's extra newlines.
1025             first_leaf = current_line.leaves[0]
1026             before = first_leaf.prefix.count("\n")
1027             before = min(before, max_allowed)
1028             first_leaf.prefix = ""
1029         else:
1030             before = 0
1031         depth = current_line.depth
1032         while self.previous_defs and self.previous_defs[-1] >= depth:
1033             self.previous_defs.pop()
1034             before = 1 if depth else 2
1035         is_decorator = current_line.is_decorator
1036         if is_decorator or current_line.is_def or current_line.is_class:
1037             if not is_decorator:
1038                 self.previous_defs.append(depth)
1039             if self.previous_line is None:
1040                 # Don't insert empty lines before the first line in the file.
1041                 return 0, 0
1042
1043             if self.previous_line and self.previous_line.is_decorator:
1044                 # Don't insert empty lines between decorators.
1045                 return 0, 0
1046
1047             if is_decorator and self.previous_line and self.previous_line.is_comment:
1048                 # Don't insert empty lines between decorator comments.
1049                 return 0, 0
1050
1051             newlines = 2
1052             if current_line.depth:
1053                 newlines -= 1
1054             return newlines, 0
1055
1056         if current_line.is_flow_control:
1057             return before, 1
1058
1059         if (
1060             self.previous_line
1061             and self.previous_line.is_import
1062             and not current_line.is_import
1063             and depth == self.previous_line.depth
1064         ):
1065             return (before or 1), 0
1066
1067         if (
1068             self.previous_line
1069             and self.previous_line.is_yield
1070             and (not current_line.is_yield or depth != self.previous_line.depth)
1071         ):
1072             return (before or 1), 0
1073
1074         return before, 0
1075
1076
1077 @dataclass
1078 class LineGenerator(Visitor[Line]):
1079     """Generates reformatted Line objects.  Empty lines are not emitted.
1080
1081     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1082     in ways that will no longer stringify to valid Python code on the tree.
1083     """
1084     current_line: Line = Factory(Line)
1085
1086     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1087         """Generate a line.
1088
1089         If the line is empty, only emit if it makes sense.
1090         If the line is too long, split it first and then generate.
1091
1092         If any lines were generated, set up a new current_line.
1093         """
1094         if not self.current_line:
1095             if self.current_line.__class__ == type:
1096                 self.current_line.depth += indent
1097             else:
1098                 self.current_line = type(depth=self.current_line.depth + indent)
1099             return  # Line is empty, don't emit. Creating a new one unnecessary.
1100
1101         complete_line = self.current_line
1102         self.current_line = type(depth=complete_line.depth + indent)
1103         yield complete_line
1104
1105     def visit(self, node: LN) -> Iterator[Line]:
1106         """Main method to visit `node` and its children.
1107
1108         Yields :class:`Line` objects.
1109         """
1110         if isinstance(self.current_line, UnformattedLines):
1111             # File contained `# fmt: off`
1112             yield from self.visit_unformatted(node)
1113
1114         else:
1115             yield from super().visit(node)
1116
1117     def visit_default(self, node: LN) -> Iterator[Line]:
1118         """Default `visit_*()` implementation. Recurses to children of `node`."""
1119         if isinstance(node, Leaf):
1120             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1121             try:
1122                 for comment in generate_comments(node):
1123                     if any_open_brackets:
1124                         # any comment within brackets is subject to splitting
1125                         self.current_line.append(comment)
1126                     elif comment.type == token.COMMENT:
1127                         # regular trailing comment
1128                         self.current_line.append(comment)
1129                         yield from self.line()
1130
1131                     else:
1132                         # regular standalone comment
1133                         yield from self.line()
1134
1135                         self.current_line.append(comment)
1136                         yield from self.line()
1137
1138             except FormatOff as f_off:
1139                 f_off.trim_prefix(node)
1140                 yield from self.line(type=UnformattedLines)
1141                 yield from self.visit(node)
1142
1143             except FormatOn as f_on:
1144                 # This only happens here if somebody says "fmt: on" multiple
1145                 # times in a row.
1146                 f_on.trim_prefix(node)
1147                 yield from self.visit_default(node)
1148
1149             else:
1150                 normalize_prefix(node, inside_brackets=any_open_brackets)
1151                 if node.type == token.STRING:
1152                     normalize_string_quotes(node)
1153                 if node.type not in WHITESPACE:
1154                     self.current_line.append(node)
1155         yield from super().visit_default(node)
1156
1157     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1158         """Increase indentation level, maybe yield a line."""
1159         # In blib2to3 INDENT never holds comments.
1160         yield from self.line(+1)
1161         yield from self.visit_default(node)
1162
1163     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1164         """Decrease indentation level, maybe yield a line."""
1165         # The current line might still wait for trailing comments.  At DEDENT time
1166         # there won't be any (they would be prefixes on the preceding NEWLINE).
1167         # Emit the line then.
1168         yield from self.line()
1169
1170         # While DEDENT has no value, its prefix may contain standalone comments
1171         # that belong to the current indentation level.  Get 'em.
1172         yield from self.visit_default(node)
1173
1174         # Finally, emit the dedent.
1175         yield from self.line(-1)
1176
1177     def visit_stmt(
1178         self, node: Node, keywords: Set[str], parens: Set[str]
1179     ) -> Iterator[Line]:
1180         """Visit a statement.
1181
1182         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1183         `def`, `with`, `class`, and `assert`.
1184
1185         The relevant Python language `keywords` for a given statement will be
1186         NAME leaves within it. This methods puts those on a separate line.
1187
1188         `parens` holds pairs of nodes where invisible parentheses should be put.
1189         Keys hold nodes after which opening parentheses should be put, values
1190         hold nodes before which closing parentheses should be put.
1191         """
1192         normalize_invisible_parens(node, parens_after=parens)
1193         for child in node.children:
1194             if child.type == token.NAME and child.value in keywords:  # type: ignore
1195                 yield from self.line()
1196
1197             yield from self.visit(child)
1198
1199     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1200         """Visit a statement without nested statements."""
1201         is_suite_like = node.parent and node.parent.type in STATEMENT
1202         if is_suite_like:
1203             yield from self.line(+1)
1204             yield from self.visit_default(node)
1205             yield from self.line(-1)
1206
1207         else:
1208             yield from self.line()
1209             yield from self.visit_default(node)
1210
1211     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1212         """Visit `async def`, `async for`, `async with`."""
1213         yield from self.line()
1214
1215         children = iter(node.children)
1216         for child in children:
1217             yield from self.visit(child)
1218
1219             if child.type == token.ASYNC:
1220                 break
1221
1222         internal_stmt = next(children)
1223         for child in internal_stmt.children:
1224             yield from self.visit(child)
1225
1226     def visit_decorators(self, node: Node) -> Iterator[Line]:
1227         """Visit decorators."""
1228         for child in node.children:
1229             yield from self.line()
1230             yield from self.visit(child)
1231
1232     def visit_import_from(self, node: Node) -> Iterator[Line]:
1233         """Visit import_from and maybe put invisible parentheses.
1234
1235         This is separate from `visit_stmt` because import statements don't
1236         support arbitrary atoms and thus handling of parentheses is custom.
1237         """
1238         check_lpar = False
1239         for index, child in enumerate(node.children):
1240             if check_lpar:
1241                 if child.type == token.LPAR:
1242                     # make parentheses invisible
1243                     child.value = ""  # type: ignore
1244                     node.children[-1].value = ""  # type: ignore
1245                 else:
1246                     # insert invisible parentheses
1247                     node.insert_child(index, Leaf(token.LPAR, ""))
1248                     node.append_child(Leaf(token.RPAR, ""))
1249                 break
1250
1251             check_lpar = (
1252                 child.type == token.NAME and child.value == "import"  # type: ignore
1253             )
1254
1255         for child in node.children:
1256             yield from self.visit(child)
1257
1258     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1259         """Remove a semicolon and put the other statement on a separate line."""
1260         yield from self.line()
1261
1262     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1263         """End of file. Process outstanding comments and end with a newline."""
1264         yield from self.visit_default(leaf)
1265         yield from self.line()
1266
1267     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1268         """Used when file contained a `# fmt: off`."""
1269         if isinstance(node, Node):
1270             for child in node.children:
1271                 yield from self.visit(child)
1272
1273         else:
1274             try:
1275                 self.current_line.append(node)
1276             except FormatOn as f_on:
1277                 f_on.trim_prefix(node)
1278                 yield from self.line()
1279                 yield from self.visit(node)
1280
1281             if node.type == token.ENDMARKER:
1282                 # somebody decided not to put a final `# fmt: on`
1283                 yield from self.line()
1284
1285     def __attrs_post_init__(self) -> None:
1286         """You are in a twisty little maze of passages."""
1287         v = self.visit_stmt
1288         Ø: Set[str] = set()
1289         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1290         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1291         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1292         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1293         self.visit_try_stmt = partial(
1294             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1295         )
1296         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1297         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1298         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1299         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1300         self.visit_async_funcdef = self.visit_async_stmt
1301         self.visit_decorated = self.visit_decorators
1302
1303
1304 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1305 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1306 OPENING_BRACKETS = set(BRACKET.keys())
1307 CLOSING_BRACKETS = set(BRACKET.values())
1308 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1309 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1310
1311
1312 def whitespace(leaf: Leaf) -> str:  # noqa C901
1313     """Return whitespace prefix if needed for the given `leaf`."""
1314     NO = ""
1315     SPACE = " "
1316     DOUBLESPACE = "  "
1317     t = leaf.type
1318     p = leaf.parent
1319     v = leaf.value
1320     if t in ALWAYS_NO_SPACE:
1321         return NO
1322
1323     if t == token.COMMENT:
1324         return DOUBLESPACE
1325
1326     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1327     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1328         return NO
1329
1330     prev = leaf.prev_sibling
1331     if not prev:
1332         prevp = preceding_leaf(p)
1333         if not prevp or prevp.type in OPENING_BRACKETS:
1334             return NO
1335
1336         if t == token.COLON:
1337             return SPACE if prevp.type == token.COMMA else NO
1338
1339         if prevp.type == token.EQUAL:
1340             if prevp.parent:
1341                 if prevp.parent.type in {
1342                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1343                 }:
1344                     return NO
1345
1346                 elif prevp.parent.type == syms.typedargslist:
1347                     # A bit hacky: if the equal sign has whitespace, it means we
1348                     # previously found it's a typed argument.  So, we're using
1349                     # that, too.
1350                     return prevp.prefix
1351
1352         elif prevp.type in STARS:
1353             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1354                 return NO
1355
1356         elif prevp.type == token.COLON:
1357             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1358                 return NO
1359
1360         elif (
1361             prevp.parent
1362             and prevp.parent.type == syms.factor
1363             and prevp.type in MATH_OPERATORS
1364         ):
1365             return NO
1366
1367         elif (
1368             prevp.type == token.RIGHTSHIFT
1369             and prevp.parent
1370             and prevp.parent.type == syms.shift_expr
1371             and prevp.prev_sibling
1372             and prevp.prev_sibling.type == token.NAME
1373             and prevp.prev_sibling.value == "print"  # type: ignore
1374         ):
1375             # Python 2 print chevron
1376             return NO
1377
1378     elif prev.type in OPENING_BRACKETS:
1379         return NO
1380
1381     if p.type in {syms.parameters, syms.arglist}:
1382         # untyped function signatures or calls
1383         if not prev or prev.type != token.COMMA:
1384             return NO
1385
1386     elif p.type == syms.varargslist:
1387         # lambdas
1388         if prev and prev.type != token.COMMA:
1389             return NO
1390
1391     elif p.type == syms.typedargslist:
1392         # typed function signatures
1393         if not prev:
1394             return NO
1395
1396         if t == token.EQUAL:
1397             if prev.type != syms.tname:
1398                 return NO
1399
1400         elif prev.type == token.EQUAL:
1401             # A bit hacky: if the equal sign has whitespace, it means we
1402             # previously found it's a typed argument.  So, we're using that, too.
1403             return prev.prefix
1404
1405         elif prev.type != token.COMMA:
1406             return NO
1407
1408     elif p.type == syms.tname:
1409         # type names
1410         if not prev:
1411             prevp = preceding_leaf(p)
1412             if not prevp or prevp.type != token.COMMA:
1413                 return NO
1414
1415     elif p.type == syms.trailer:
1416         # attributes and calls
1417         if t == token.LPAR or t == token.RPAR:
1418             return NO
1419
1420         if not prev:
1421             if t == token.DOT:
1422                 prevp = preceding_leaf(p)
1423                 if not prevp or prevp.type != token.NUMBER:
1424                     return NO
1425
1426             elif t == token.LSQB:
1427                 return NO
1428
1429         elif prev.type != token.COMMA:
1430             return NO
1431
1432     elif p.type == syms.argument:
1433         # single argument
1434         if t == token.EQUAL:
1435             return NO
1436
1437         if not prev:
1438             prevp = preceding_leaf(p)
1439             if not prevp or prevp.type == token.LPAR:
1440                 return NO
1441
1442         elif prev.type in {token.EQUAL} | STARS:
1443             return NO
1444
1445     elif p.type == syms.decorator:
1446         # decorators
1447         return NO
1448
1449     elif p.type == syms.dotted_name:
1450         if prev:
1451             return NO
1452
1453         prevp = preceding_leaf(p)
1454         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1455             return NO
1456
1457     elif p.type == syms.classdef:
1458         if t == token.LPAR:
1459             return NO
1460
1461         if prev and prev.type == token.LPAR:
1462             return NO
1463
1464     elif p.type == syms.subscript:
1465         # indexing
1466         if not prev:
1467             assert p.parent is not None, "subscripts are always parented"
1468             if p.parent.type == syms.subscriptlist:
1469                 return SPACE
1470
1471             return NO
1472
1473         else:
1474             return NO
1475
1476     elif p.type == syms.atom:
1477         if prev and t == token.DOT:
1478             # dots, but not the first one.
1479             return NO
1480
1481     elif p.type == syms.dictsetmaker:
1482         # dict unpacking
1483         if prev and prev.type == token.DOUBLESTAR:
1484             return NO
1485
1486     elif p.type in {syms.factor, syms.star_expr}:
1487         # unary ops
1488         if not prev:
1489             prevp = preceding_leaf(p)
1490             if not prevp or prevp.type in OPENING_BRACKETS:
1491                 return NO
1492
1493             prevp_parent = prevp.parent
1494             assert prevp_parent is not None
1495             if (
1496                 prevp.type == token.COLON
1497                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1498             ):
1499                 return NO
1500
1501             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1502                 return NO
1503
1504         elif t == token.NAME or t == token.NUMBER:
1505             return NO
1506
1507     elif p.type == syms.import_from:
1508         if t == token.DOT:
1509             if prev and prev.type == token.DOT:
1510                 return NO
1511
1512         elif t == token.NAME:
1513             if v == "import":
1514                 return SPACE
1515
1516             if prev and prev.type == token.DOT:
1517                 return NO
1518
1519     elif p.type == syms.sliceop:
1520         return NO
1521
1522     return SPACE
1523
1524
1525 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1526     """Return the first leaf that precedes `node`, if any."""
1527     while node:
1528         res = node.prev_sibling
1529         if res:
1530             if isinstance(res, Leaf):
1531                 return res
1532
1533             try:
1534                 return list(res.leaves())[-1]
1535
1536             except IndexError:
1537                 return None
1538
1539         node = node.parent
1540     return None
1541
1542
1543 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1544     """Return the priority of the `leaf` delimiter, given a line break after it.
1545
1546     The delimiter priorities returned here are from those delimiters that would
1547     cause a line break after themselves.
1548
1549     Higher numbers are higher priority.
1550     """
1551     if leaf.type == token.COMMA:
1552         return COMMA_PRIORITY
1553
1554     return 0
1555
1556
1557 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1558     """Return the priority of the `leaf` delimiter, given a line before after it.
1559
1560     The delimiter priorities returned here are from those delimiters that would
1561     cause a line break before themselves.
1562
1563     Higher numbers are higher priority.
1564     """
1565     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1566         # * and ** might also be MATH_OPERATORS but in this case they are not.
1567         # Don't treat them as a delimiter.
1568         return 0
1569
1570     if (
1571         leaf.type in MATH_OPERATORS
1572         and leaf.parent
1573         and leaf.parent.type not in {syms.factor, syms.star_expr}
1574     ):
1575         return MATH_PRIORITY
1576
1577     if leaf.type in COMPARATORS:
1578         return COMPARATOR_PRIORITY
1579
1580     if (
1581         leaf.type == token.STRING
1582         and previous is not None
1583         and previous.type == token.STRING
1584     ):
1585         return STRING_PRIORITY
1586
1587     if (
1588         leaf.type == token.NAME
1589         and leaf.value == "for"
1590         and leaf.parent
1591         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1592     ):
1593         return COMPREHENSION_PRIORITY
1594
1595     if (
1596         leaf.type == token.NAME
1597         and leaf.value == "if"
1598         and leaf.parent
1599         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1600     ):
1601         return COMPREHENSION_PRIORITY
1602
1603     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1604         return LOGIC_PRIORITY
1605
1606     return 0
1607
1608
1609 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1610     """Clean the prefix of the `leaf` and generate comments from it, if any.
1611
1612     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1613     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1614     move because it does away with modifying the grammar to include all the
1615     possible places in which comments can be placed.
1616
1617     The sad consequence for us though is that comments don't "belong" anywhere.
1618     This is why this function generates simple parentless Leaf objects for
1619     comments.  We simply don't know what the correct parent should be.
1620
1621     No matter though, we can live without this.  We really only need to
1622     differentiate between inline and standalone comments.  The latter don't
1623     share the line with any code.
1624
1625     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1626     are emitted with a fake STANDALONE_COMMENT token identifier.
1627     """
1628     p = leaf.prefix
1629     if not p:
1630         return
1631
1632     if "#" not in p:
1633         return
1634
1635     consumed = 0
1636     nlines = 0
1637     for index, line in enumerate(p.split("\n")):
1638         consumed += len(line) + 1  # adding the length of the split '\n'
1639         line = line.lstrip()
1640         if not line:
1641             nlines += 1
1642         if not line.startswith("#"):
1643             continue
1644
1645         if index == 0 and leaf.type != token.ENDMARKER:
1646             comment_type = token.COMMENT  # simple trailing comment
1647         else:
1648             comment_type = STANDALONE_COMMENT
1649         comment = make_comment(line)
1650         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1651
1652         if comment in {"# fmt: on", "# yapf: enable"}:
1653             raise FormatOn(consumed)
1654
1655         if comment in {"# fmt: off", "# yapf: disable"}:
1656             if comment_type == STANDALONE_COMMENT:
1657                 raise FormatOff(consumed)
1658
1659             prev = preceding_leaf(leaf)
1660             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1661                 raise FormatOff(consumed)
1662
1663         nlines = 0
1664
1665
1666 def make_comment(content: str) -> str:
1667     """Return a consistently formatted comment from the given `content` string.
1668
1669     All comments (except for "##", "#!", "#:") should have a single space between
1670     the hash sign and the content.
1671
1672     If `content` didn't start with a hash sign, one is provided.
1673     """
1674     content = content.rstrip()
1675     if not content:
1676         return "#"
1677
1678     if content[0] == "#":
1679         content = content[1:]
1680     if content and content[0] not in " !:#":
1681         content = " " + content
1682     return "#" + content
1683
1684
1685 def split_line(
1686     line: Line, line_length: int, inner: bool = False, py36: bool = False
1687 ) -> Iterator[Line]:
1688     """Split a `line` into potentially many lines.
1689
1690     They should fit in the allotted `line_length` but might not be able to.
1691     `inner` signifies that there were a pair of brackets somewhere around the
1692     current `line`, possibly transitively. This means we can fallback to splitting
1693     by delimiters if the LHS/RHS don't yield any results.
1694
1695     If `py36` is True, splitting may generate syntax that is only compatible
1696     with Python 3.6 and later.
1697     """
1698     if isinstance(line, UnformattedLines) or line.is_comment:
1699         yield line
1700         return
1701
1702     line_str = str(line).strip("\n")
1703     if (
1704         len(line_str) <= line_length
1705         and "\n" not in line_str  # multiline strings
1706         and not line.contains_standalone_comments()
1707     ):
1708         yield line
1709         return
1710
1711     split_funcs: List[SplitFunc]
1712     if line.is_def:
1713         split_funcs = [left_hand_split]
1714     elif line.inside_brackets:
1715         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1716     else:
1717         split_funcs = [right_hand_split]
1718     for split_func in split_funcs:
1719         # We are accumulating lines in `result` because we might want to abort
1720         # mission and return the original line in the end, or attempt a different
1721         # split altogether.
1722         result: List[Line] = []
1723         try:
1724             for l in split_func(line, py36):
1725                 if str(l).strip("\n") == line_str:
1726                     raise CannotSplit("Split function returned an unchanged result")
1727
1728                 result.extend(
1729                     split_line(l, line_length=line_length, inner=True, py36=py36)
1730                 )
1731         except CannotSplit as cs:
1732             continue
1733
1734         else:
1735             yield from result
1736             break
1737
1738     else:
1739         yield line
1740
1741
1742 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1743     """Split line into many lines, starting with the first matching bracket pair.
1744
1745     Note: this usually looks weird, only use this for function definitions.
1746     Prefer RHS otherwise.
1747     """
1748     head = Line(depth=line.depth)
1749     body = Line(depth=line.depth + 1, inside_brackets=True)
1750     tail = Line(depth=line.depth)
1751     tail_leaves: List[Leaf] = []
1752     body_leaves: List[Leaf] = []
1753     head_leaves: List[Leaf] = []
1754     current_leaves = head_leaves
1755     matching_bracket = None
1756     for leaf in line.leaves:
1757         if (
1758             current_leaves is body_leaves
1759             and leaf.type in CLOSING_BRACKETS
1760             and leaf.opening_bracket is matching_bracket
1761         ):
1762             current_leaves = tail_leaves if body_leaves else head_leaves
1763         current_leaves.append(leaf)
1764         if current_leaves is head_leaves:
1765             if leaf.type in OPENING_BRACKETS:
1766                 matching_bracket = leaf
1767                 current_leaves = body_leaves
1768     # Since body is a new indent level, remove spurious leading whitespace.
1769     if body_leaves:
1770         normalize_prefix(body_leaves[0], inside_brackets=True)
1771     # Build the new lines.
1772     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1773         for leaf in leaves:
1774             result.append(leaf, preformatted=True)
1775             for comment_after in line.comments_after(leaf):
1776                 result.append(comment_after, preformatted=True)
1777     bracket_split_succeeded_or_raise(head, body, tail)
1778     for result in (head, body, tail):
1779         if result:
1780             yield result
1781
1782
1783 def right_hand_split(
1784     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1785 ) -> Iterator[Line]:
1786     """Split line into many lines, starting with the last matching bracket pair."""
1787     head = Line(depth=line.depth)
1788     body = Line(depth=line.depth + 1, inside_brackets=True)
1789     tail = Line(depth=line.depth)
1790     tail_leaves: List[Leaf] = []
1791     body_leaves: List[Leaf] = []
1792     head_leaves: List[Leaf] = []
1793     current_leaves = tail_leaves
1794     opening_bracket = None
1795     closing_bracket = None
1796     for leaf in reversed(line.leaves):
1797         if current_leaves is body_leaves:
1798             if leaf is opening_bracket:
1799                 current_leaves = head_leaves if body_leaves else tail_leaves
1800         current_leaves.append(leaf)
1801         if current_leaves is tail_leaves:
1802             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1803                 opening_bracket = leaf.opening_bracket
1804                 closing_bracket = leaf
1805                 current_leaves = body_leaves
1806     tail_leaves.reverse()
1807     body_leaves.reverse()
1808     head_leaves.reverse()
1809     # Since body is a new indent level, remove spurious leading whitespace.
1810     if body_leaves:
1811         normalize_prefix(body_leaves[0], inside_brackets=True)
1812     elif not head_leaves:
1813         # No `head` and no `body` means the split failed. `tail` has all content.
1814         raise CannotSplit("No brackets found")
1815
1816     # Build the new lines.
1817     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1818         for leaf in leaves:
1819             result.append(leaf, preformatted=True)
1820             for comment_after in line.comments_after(leaf):
1821                 result.append(comment_after, preformatted=True)
1822     bracket_split_succeeded_or_raise(head, body, tail)
1823     assert opening_bracket and closing_bracket
1824     if (
1825         opening_bracket.type == token.LPAR
1826         and not opening_bracket.value
1827         and closing_bracket.type == token.RPAR
1828         and not closing_bracket.value
1829     ):
1830         # These parens were optional. If there aren't any delimiters or standalone
1831         # comments in the body, they were unnecessary and another split without
1832         # them should be attempted.
1833         if not (
1834             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1835         ):
1836             omit = {id(closing_bracket), *omit}
1837             yield from right_hand_split(line, py36=py36, omit=omit)
1838             return
1839
1840     ensure_visible(opening_bracket)
1841     ensure_visible(closing_bracket)
1842     for result in (head, body, tail):
1843         if result:
1844             yield result
1845
1846
1847 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1848     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1849
1850     Do nothing otherwise.
1851
1852     A left- or right-hand split is based on a pair of brackets. Content before
1853     (and including) the opening bracket is left on one line, content inside the
1854     brackets is put on a separate line, and finally content starting with and
1855     following the closing bracket is put on a separate line.
1856
1857     Those are called `head`, `body`, and `tail`, respectively. If the split
1858     produced the same line (all content in `head`) or ended up with an empty `body`
1859     and the `tail` is just the closing bracket, then it's considered failed.
1860     """
1861     tail_len = len(str(tail).strip())
1862     if not body:
1863         if tail_len == 0:
1864             raise CannotSplit("Splitting brackets produced the same line")
1865
1866         elif tail_len < 3:
1867             raise CannotSplit(
1868                 f"Splitting brackets on an empty body to save "
1869                 f"{tail_len} characters is not worth it"
1870             )
1871
1872
1873 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1874     """Normalize prefix of the first leaf in every line returned by `split_func`.
1875
1876     This is a decorator over relevant split functions.
1877     """
1878
1879     @wraps(split_func)
1880     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1881         for l in split_func(line, py36):
1882             normalize_prefix(l.leaves[0], inside_brackets=True)
1883             yield l
1884
1885     return split_wrapper
1886
1887
1888 @dont_increase_indentation
1889 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1890     """Split according to delimiters of the highest priority.
1891
1892     If `py36` is True, the split will add trailing commas also in function
1893     signatures that contain `*` and `**`.
1894     """
1895     try:
1896         last_leaf = line.leaves[-1]
1897     except IndexError:
1898         raise CannotSplit("Line empty")
1899
1900     delimiters = line.bracket_tracker.delimiters
1901     try:
1902         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1903             exclude={id(last_leaf)}
1904         )
1905     except ValueError:
1906         raise CannotSplit("No delimiters found")
1907
1908     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1909     lowest_depth = sys.maxsize
1910     trailing_comma_safe = True
1911
1912     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1913         """Append `leaf` to current line or to new line if appending impossible."""
1914         nonlocal current_line
1915         try:
1916             current_line.append_safe(leaf, preformatted=True)
1917         except ValueError as ve:
1918             yield current_line
1919
1920             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1921             current_line.append(leaf)
1922
1923     for leaf in line.leaves:
1924         yield from append_to_line(leaf)
1925
1926         for comment_after in line.comments_after(leaf):
1927             yield from append_to_line(comment_after)
1928
1929         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1930         if (
1931             leaf.bracket_depth == lowest_depth
1932             and is_vararg(leaf, within=VARARGS_PARENTS)
1933         ):
1934             trailing_comma_safe = trailing_comma_safe and py36
1935         leaf_priority = delimiters.get(id(leaf))
1936         if leaf_priority == delimiter_priority:
1937             yield current_line
1938
1939             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1940     if current_line:
1941         if (
1942             trailing_comma_safe
1943             and delimiter_priority == COMMA_PRIORITY
1944             and current_line.leaves[-1].type != token.COMMA
1945             and current_line.leaves[-1].type != STANDALONE_COMMENT
1946         ):
1947             current_line.append(Leaf(token.COMMA, ","))
1948         yield current_line
1949
1950
1951 @dont_increase_indentation
1952 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1953     """Split standalone comments from the rest of the line."""
1954     if not line.contains_standalone_comments(0):
1955         raise CannotSplit("Line does not have any standalone comments")
1956
1957     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1958
1959     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1960         """Append `leaf` to current line or to new line if appending impossible."""
1961         nonlocal current_line
1962         try:
1963             current_line.append_safe(leaf, preformatted=True)
1964         except ValueError as ve:
1965             yield current_line
1966
1967             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1968             current_line.append(leaf)
1969
1970     for leaf in line.leaves:
1971         yield from append_to_line(leaf)
1972
1973         for comment_after in line.comments_after(leaf):
1974             yield from append_to_line(comment_after)
1975
1976     if current_line:
1977         yield current_line
1978
1979
1980 def is_import(leaf: Leaf) -> bool:
1981     """Return True if the given leaf starts an import statement."""
1982     p = leaf.parent
1983     t = leaf.type
1984     v = leaf.value
1985     return bool(
1986         t == token.NAME
1987         and (
1988             (v == "import" and p and p.type == syms.import_name)
1989             or (v == "from" and p and p.type == syms.import_from)
1990         )
1991     )
1992
1993
1994 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1995     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1996     else.
1997
1998     Note: don't use backslashes for formatting or you'll lose your voting rights.
1999     """
2000     if not inside_brackets:
2001         spl = leaf.prefix.split("#")
2002         if "\\" not in spl[0]:
2003             nl_count = spl[-1].count("\n")
2004             if len(spl) > 1:
2005                 nl_count -= 1
2006             leaf.prefix = "\n" * nl_count
2007             return
2008
2009     leaf.prefix = ""
2010
2011
2012 def normalize_string_quotes(leaf: Leaf) -> None:
2013     """Prefer double quotes but only if it doesn't cause more escaping.
2014
2015     Adds or removes backslashes as appropriate. Doesn't parse and fix
2016     strings nested in f-strings (yet).
2017
2018     Note: Mutates its argument.
2019     """
2020     value = leaf.value.lstrip("furbFURB")
2021     if value[:3] == '"""':
2022         return
2023
2024     elif value[:3] == "'''":
2025         orig_quote = "'''"
2026         new_quote = '"""'
2027     elif value[0] == '"':
2028         orig_quote = '"'
2029         new_quote = "'"
2030     else:
2031         orig_quote = "'"
2032         new_quote = '"'
2033     first_quote_pos = leaf.value.find(orig_quote)
2034     if first_quote_pos == -1:
2035         return  # There's an internal error
2036
2037     prefix = leaf.value[:first_quote_pos]
2038     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2039     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2040     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2041     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
2042     if "r" in prefix.casefold():
2043         if unescaped_new_quote.search(body):
2044             # There's at least one unescaped new_quote in this raw string
2045             # so converting is impossible
2046             return
2047
2048         # Do not introduce or remove backslashes in raw strings
2049         new_body = body
2050     else:
2051         # remove unnecessary quotes
2052         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2053         if body != new_body:
2054             # Consider the string without unnecessary quotes as the original
2055             body = new_body
2056             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2057         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2058         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2059     if new_quote == '"""' and new_body[-1] == '"':
2060         # edge case:
2061         new_body = new_body[:-1] + '\\"'
2062     orig_escape_count = body.count("\\")
2063     new_escape_count = new_body.count("\\")
2064     if new_escape_count > orig_escape_count:
2065         return  # Do not introduce more escaping
2066
2067     if new_escape_count == orig_escape_count and orig_quote == '"':
2068         return  # Prefer double quotes
2069
2070     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2071
2072
2073 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2074     """Make existing optional parentheses invisible or create new ones.
2075
2076     Standardizes on visible parentheses for single-element tuples, and keeps
2077     existing visible parentheses for other tuples and generator expressions.
2078     """
2079     check_lpar = False
2080     for child in list(node.children):
2081         if check_lpar:
2082             if child.type == syms.atom:
2083                 if not (
2084                     is_empty_tuple(child)
2085                     or is_one_tuple(child)
2086                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2087                 ):
2088                     first = child.children[0]
2089                     last = child.children[-1]
2090                     if first.type == token.LPAR and last.type == token.RPAR:
2091                         # make parentheses invisible
2092                         first.value = ""  # type: ignore
2093                         last.value = ""  # type: ignore
2094             elif is_one_tuple(child):
2095                 # wrap child in visible parentheses
2096                 lpar = Leaf(token.LPAR, "(")
2097                 rpar = Leaf(token.RPAR, ")")
2098                 index = child.remove() or 0
2099                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2100             else:
2101                 # wrap child in invisible parentheses
2102                 lpar = Leaf(token.LPAR, "")
2103                 rpar = Leaf(token.RPAR, "")
2104                 index = child.remove() or 0
2105                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2106
2107         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2108
2109
2110 def is_empty_tuple(node: LN) -> bool:
2111     """Return True if `node` holds an empty tuple."""
2112     return (
2113         node.type == syms.atom
2114         and len(node.children) == 2
2115         and node.children[0].type == token.LPAR
2116         and node.children[1].type == token.RPAR
2117     )
2118
2119
2120 def is_one_tuple(node: LN) -> bool:
2121     """Return True if `node` holds a tuple with one element, with or without parens."""
2122     if node.type == syms.atom:
2123         if len(node.children) != 3:
2124             return False
2125
2126         lpar, gexp, rpar = node.children
2127         if not (
2128             lpar.type == token.LPAR
2129             and gexp.type == syms.testlist_gexp
2130             and rpar.type == token.RPAR
2131         ):
2132             return False
2133
2134         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2135
2136     return (
2137         node.type in IMPLICIT_TUPLE
2138         and len(node.children) == 2
2139         and node.children[1].type == token.COMMA
2140     )
2141
2142
2143 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2144     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2145
2146     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2147     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2148     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2149     generalizations (PEP 448).
2150     """
2151     if leaf.type not in STARS or not leaf.parent:
2152         return False
2153
2154     p = leaf.parent
2155     if p.type == syms.star_expr:
2156         # Star expressions are also used as assignment targets in extended
2157         # iterable unpacking (PEP 3132).  See what its parent is instead.
2158         if not p.parent:
2159             return False
2160
2161         p = p.parent
2162
2163     return p.type in within
2164
2165
2166 def max_delimiter_priority_in_atom(node: LN) -> int:
2167     """Return maximum delimiter priority inside `node`.
2168
2169     This is specific to atoms with contents contained in a pair of parentheses.
2170     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2171     """
2172     if node.type != syms.atom:
2173         return 0
2174
2175     first = node.children[0]
2176     last = node.children[-1]
2177     if not (first.type == token.LPAR and last.type == token.RPAR):
2178         return 0
2179
2180     bt = BracketTracker()
2181     for c in node.children[1:-1]:
2182         if isinstance(c, Leaf):
2183             bt.mark(c)
2184         else:
2185             for leaf in c.leaves():
2186                 bt.mark(leaf)
2187     try:
2188         return bt.max_delimiter_priority()
2189
2190     except ValueError:
2191         return 0
2192
2193
2194 def ensure_visible(leaf: Leaf) -> None:
2195     """Make sure parentheses are visible.
2196
2197     They could be invisible as part of some statements (see
2198     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2199     """
2200     if leaf.type == token.LPAR:
2201         leaf.value = "("
2202     elif leaf.type == token.RPAR:
2203         leaf.value = ")"
2204
2205
2206 def is_python36(node: Node) -> bool:
2207     """Return True if the current file is using Python 3.6+ features.
2208
2209     Currently looking for:
2210     - f-strings; and
2211     - trailing commas after * or ** in function signatures.
2212     """
2213     for n in node.pre_order():
2214         if n.type == token.STRING:
2215             value_head = n.value[:2]  # type: ignore
2216             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2217                 return True
2218
2219         elif (
2220             n.type == syms.typedargslist
2221             and n.children
2222             and n.children[-1].type == token.COMMA
2223         ):
2224             for ch in n.children:
2225                 if ch.type in STARS:
2226                     return True
2227
2228     return False
2229
2230
2231 PYTHON_EXTENSIONS = {".py"}
2232 BLACKLISTED_DIRECTORIES = {
2233     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2234 }
2235
2236
2237 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2238     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2239     and have one of the PYTHON_EXTENSIONS.
2240     """
2241     for child in path.iterdir():
2242         if child.is_dir():
2243             if child.name in BLACKLISTED_DIRECTORIES:
2244                 continue
2245
2246             yield from gen_python_files_in_dir(child)
2247
2248         elif child.suffix in PYTHON_EXTENSIONS:
2249             yield child
2250
2251
2252 @dataclass
2253 class Report:
2254     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2255     check: bool = False
2256     quiet: bool = False
2257     change_count: int = 0
2258     same_count: int = 0
2259     failure_count: int = 0
2260
2261     def done(self, src: Path, changed: Changed) -> None:
2262         """Increment the counter for successful reformatting. Write out a message."""
2263         if changed is Changed.YES:
2264             reformatted = "would reformat" if self.check else "reformatted"
2265             if not self.quiet:
2266                 out(f"{reformatted} {src}")
2267             self.change_count += 1
2268         else:
2269             if not self.quiet:
2270                 if changed is Changed.NO:
2271                     msg = f"{src} already well formatted, good job."
2272                 else:
2273                     msg = f"{src} wasn't modified on disk since last run."
2274                 out(msg, bold=False)
2275             self.same_count += 1
2276
2277     def failed(self, src: Path, message: str) -> None:
2278         """Increment the counter for failed reformatting. Write out a message."""
2279         err(f"error: cannot format {src}: {message}")
2280         self.failure_count += 1
2281
2282     @property
2283     def return_code(self) -> int:
2284         """Return the exit code that the app should use.
2285
2286         This considers the current state of changed files and failures:
2287         - if there were any failures, return 123;
2288         - if any files were changed and --check is being used, return 1;
2289         - otherwise return 0.
2290         """
2291         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2292         # 126 we have special returncodes reserved by the shell.
2293         if self.failure_count:
2294             return 123
2295
2296         elif self.change_count and self.check:
2297             return 1
2298
2299         return 0
2300
2301     def __str__(self) -> str:
2302         """Render a color report of the current state.
2303
2304         Use `click.unstyle` to remove colors.
2305         """
2306         if self.check:
2307             reformatted = "would be reformatted"
2308             unchanged = "would be left unchanged"
2309             failed = "would fail to reformat"
2310         else:
2311             reformatted = "reformatted"
2312             unchanged = "left unchanged"
2313             failed = "failed to reformat"
2314         report = []
2315         if self.change_count:
2316             s = "s" if self.change_count > 1 else ""
2317             report.append(
2318                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2319             )
2320         if self.same_count:
2321             s = "s" if self.same_count > 1 else ""
2322             report.append(f"{self.same_count} file{s} {unchanged}")
2323         if self.failure_count:
2324             s = "s" if self.failure_count > 1 else ""
2325             report.append(
2326                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2327             )
2328         return ", ".join(report) + "."
2329
2330
2331 def assert_equivalent(src: str, dst: str) -> None:
2332     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2333
2334     import ast
2335     import traceback
2336
2337     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2338         """Simple visitor generating strings to compare ASTs by content."""
2339         yield f"{'  ' * depth}{node.__class__.__name__}("
2340
2341         for field in sorted(node._fields):
2342             try:
2343                 value = getattr(node, field)
2344             except AttributeError:
2345                 continue
2346
2347             yield f"{'  ' * (depth+1)}{field}="
2348
2349             if isinstance(value, list):
2350                 for item in value:
2351                     if isinstance(item, ast.AST):
2352                         yield from _v(item, depth + 2)
2353
2354             elif isinstance(value, ast.AST):
2355                 yield from _v(value, depth + 2)
2356
2357             else:
2358                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2359
2360         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2361
2362     try:
2363         src_ast = ast.parse(src)
2364     except Exception as exc:
2365         major, minor = sys.version_info[:2]
2366         raise AssertionError(
2367             f"cannot use --safe with this file; failed to parse source file "
2368             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2369             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2370         )
2371
2372     try:
2373         dst_ast = ast.parse(dst)
2374     except Exception as exc:
2375         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2376         raise AssertionError(
2377             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2378             f"Please report a bug on https://github.com/ambv/black/issues.  "
2379             f"This invalid output might be helpful: {log}"
2380         ) from None
2381
2382     src_ast_str = "\n".join(_v(src_ast))
2383     dst_ast_str = "\n".join(_v(dst_ast))
2384     if src_ast_str != dst_ast_str:
2385         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2386         raise AssertionError(
2387             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2388             f"the source.  "
2389             f"Please report a bug on https://github.com/ambv/black/issues.  "
2390             f"This diff might be helpful: {log}"
2391         ) from None
2392
2393
2394 def assert_stable(src: str, dst: str, line_length: int) -> None:
2395     """Raise AssertionError if `dst` reformats differently the second time."""
2396     newdst = format_str(dst, line_length=line_length)
2397     if dst != newdst:
2398         log = dump_to_file(
2399             diff(src, dst, "source", "first pass"),
2400             diff(dst, newdst, "first pass", "second pass"),
2401         )
2402         raise AssertionError(
2403             f"INTERNAL ERROR: Black produced different code on the second pass "
2404             f"of the formatter.  "
2405             f"Please report a bug on https://github.com/ambv/black/issues.  "
2406             f"This diff might be helpful: {log}"
2407         ) from None
2408
2409
2410 def dump_to_file(*output: str) -> str:
2411     """Dump `output` to a temporary file. Return path to the file."""
2412     import tempfile
2413
2414     with tempfile.NamedTemporaryFile(
2415         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2416     ) as f:
2417         for lines in output:
2418             f.write(lines)
2419             if lines and lines[-1] != "\n":
2420                 f.write("\n")
2421     return f.name
2422
2423
2424 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2425     """Return a unified diff string between strings `a` and `b`."""
2426     import difflib
2427
2428     a_lines = [line + "\n" for line in a.split("\n")]
2429     b_lines = [line + "\n" for line in b.split("\n")]
2430     return "".join(
2431         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2432     )
2433
2434
2435 def cancel(tasks: List[asyncio.Task]) -> None:
2436     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2437     err("Aborted!")
2438     for task in tasks:
2439         task.cancel()
2440
2441
2442 def shutdown(loop: BaseEventLoop) -> None:
2443     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2444     try:
2445         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2446         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2447         if not to_cancel:
2448             return
2449
2450         for task in to_cancel:
2451             task.cancel()
2452         loop.run_until_complete(
2453             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2454         )
2455     finally:
2456         # `concurrent.futures.Future` objects cannot be cancelled once they
2457         # are already running. There might be some when the `shutdown()` happened.
2458         # Silence their logger's spew about the event loop being closed.
2459         cf_logger = logging.getLogger("concurrent.futures")
2460         cf_logger.setLevel(logging.CRITICAL)
2461         loop.close()
2462
2463
2464 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2465     """Replace `regex` with `replacement` twice on `original`.
2466
2467     This is used by string normalization to perform replaces on
2468     overlapping matches.
2469     """
2470     return regex.sub(replacement, regex.sub(replacement, original))
2471
2472
2473 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2474 CACHE_FILE = CACHE_DIR / "cache.pickle"
2475
2476
2477 def read_cache() -> Cache:
2478     """Read the cache if it exists and is well formed.
2479
2480     If it is not well formed, the call to write_cache later should resolve the issue.
2481     """
2482     if not CACHE_FILE.exists():
2483         return {}
2484
2485     with CACHE_FILE.open("rb") as fobj:
2486         try:
2487             cache: Cache = pickle.load(fobj)
2488         except pickle.UnpicklingError:
2489             return {}
2490
2491     return cache
2492
2493
2494 def get_cache_info(path: Path) -> CacheInfo:
2495     """Return the information used to check if a file is already formatted or not."""
2496     stat = path.stat()
2497     return stat.st_mtime, stat.st_size
2498
2499
2500 def filter_cached(
2501     cache: Cache, sources: Iterable[Path]
2502 ) -> Tuple[List[Path], List[Path]]:
2503     """Split a list of paths into two.
2504
2505     The first list contains paths of files that modified on disk or are not in the
2506     cache. The other list contains paths to non-modified files.
2507     """
2508     todo, done = [], []
2509     for src in sources:
2510         src = src.resolve()
2511         if cache.get(src) != get_cache_info(src):
2512             todo.append(src)
2513         else:
2514             done.append(src)
2515     return todo, done
2516
2517
2518 def write_cache(cache: Cache, sources: List[Path]) -> None:
2519     """Update the cache file."""
2520     try:
2521         if not CACHE_DIR.exists():
2522             CACHE_DIR.mkdir(parents=True)
2523         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2524         with CACHE_FILE.open("wb") as fobj:
2525             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2526     except OSError:
2527         pass
2528
2529
2530 if __name__ == "__main__":
2531     main()