]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

[#149] Make check and diff not mutually exclusive (#161)
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import pickle
5 from asyncio.base_events import BaseEventLoop
6 from concurrent.futures import Executor, ProcessPoolExecutor
7 from enum import Enum
8 from functools import partial, wraps
9 import keyword
10 import logging
11 from multiprocessing import Manager
12 import os
13 from pathlib import Path
14 import re
15 import tokenize
16 import signal
17 import sys
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46 __version__ = "18.4a2"
47 DEFAULT_LINE_LENGTH = 88
48 # types
49 syms = pygram.python_symbols
50 FileContent = str
51 Encoding = str
52 Depth = int
53 NodeType = int
54 LeafID = int
55 Priority = int
56 Index = int
57 LN = Union[Leaf, Node]
58 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
59 Timestamp = float
60 FileSize = int
61 CacheInfo = Tuple[Timestamp, FileSize]
62 Cache = Dict[Path, CacheInfo]
63 out = partial(click.secho, bold=True, err=True)
64 err = partial(click.secho, fg="red", err=True)
65
66
67 class NothingChanged(UserWarning):
68     """Raised by :func:`format_file` when reformatted code is the same as source."""
69
70
71 class CannotSplit(Exception):
72     """A readable split that fits the allotted line length is impossible.
73
74     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
75     :func:`delimiter_split`.
76     """
77
78
79 class FormatError(Exception):
80     """Base exception for `# fmt: on` and `# fmt: off` handling.
81
82     It holds the number of bytes of the prefix consumed before the format
83     control comment appeared.
84     """
85
86     def __init__(self, consumed: int) -> None:
87         super().__init__(consumed)
88         self.consumed = consumed
89
90     def trim_prefix(self, leaf: Leaf) -> None:
91         leaf.prefix = leaf.prefix[self.consumed:]
92
93     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
94         """Returns a new Leaf from the consumed part of the prefix."""
95         unformatted_prefix = leaf.prefix[:self.consumed]
96         return Leaf(token.NEWLINE, unformatted_prefix)
97
98
99 class FormatOn(FormatError):
100     """Found a comment like `# fmt: on` in the file."""
101
102
103 class FormatOff(FormatError):
104     """Found a comment like `# fmt: off` in the file."""
105
106
107 class WriteBack(Enum):
108     NO = 0
109     YES = 1
110     DIFF = 2
111
112
113 class Changed(Enum):
114     NO = 0
115     CACHED = 1
116     YES = 2
117
118
119 @click.command()
120 @click.option(
121     "-l",
122     "--line-length",
123     type=int,
124     default=DEFAULT_LINE_LENGTH,
125     help="How many character per line to allow.",
126     show_default=True,
127 )
128 @click.option(
129     "--check",
130     is_flag=True,
131     help=(
132         "Don't write the files back, just return the status.  Return code 0 "
133         "means nothing would change.  Return code 1 means some files would be "
134         "reformatted.  Return code 123 means there was an internal error."
135     ),
136 )
137 @click.option(
138     "--diff",
139     is_flag=True,
140     help="Don't write the files back, just output a diff for each file on stdout.",
141 )
142 @click.option(
143     "--fast/--safe",
144     is_flag=True,
145     help="If --fast given, skip temporary sanity checks. [default: --safe]",
146 )
147 @click.option(
148     "-q",
149     "--quiet",
150     is_flag=True,
151     help=(
152         "Don't emit non-error messages to stderr. Errors are still emitted, "
153         "silence those with 2>/dev/null."
154     ),
155 )
156 @click.version_option(version=__version__)
157 @click.argument(
158     "src",
159     nargs=-1,
160     type=click.Path(
161         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
162     ),
163 )
164 @click.pass_context
165 def main(
166     ctx: click.Context,
167     line_length: int,
168     check: bool,
169     diff: bool,
170     fast: bool,
171     quiet: bool,
172     src: List[str],
173 ) -> None:
174     """The uncompromising code formatter."""
175     sources: List[Path] = []
176     for s in src:
177         p = Path(s)
178         if p.is_dir():
179             sources.extend(gen_python_files_in_dir(p))
180         elif p.is_file():
181             # if a file was explicitly given, we don't care about its extension
182             sources.append(p)
183         elif s == "-":
184             sources.append(Path("-"))
185         else:
186             err(f"invalid path: {s}")
187
188     if check and not diff:
189         write_back = WriteBack.NO
190     elif diff:
191         write_back = WriteBack.DIFF
192     else:
193         write_back = WriteBack.YES
194     if len(sources) == 0:
195         ctx.exit(0)
196         return
197
198     elif len(sources) == 1:
199         return_code = reformat_one(
200             sources[0], line_length, fast, quiet, write_back, check
201         )
202     else:
203         loop = asyncio.get_event_loop()
204         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
205         return_code = 1
206         try:
207             return_code = loop.run_until_complete(
208                 schedule_formatting(
209                     sources, line_length, write_back, fast, quiet, loop, executor, check
210                 )
211             )
212         finally:
213             shutdown(loop)
214     ctx.exit(return_code)
215
216
217 def reformat_one(
218     src: Path,
219     line_length: int,
220     fast: bool,
221     quiet: bool,
222     write_back: WriteBack,
223     check: bool,
224 ) -> int:
225     """Reformat a single file under `src` without spawning child processes.
226
227     If `quiet` is True, non-error messages are not output. `line_length`,
228     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
229     """
230     report = Report(check=check, quiet=quiet)
231     try:
232         changed = Changed.NO
233         if not src.is_file() and str(src) == "-":
234             if format_stdin_to_stdout(
235                 line_length=line_length, fast=fast, write_back=write_back
236             ):
237                 changed = Changed.YES
238         else:
239             cache: Cache = {}
240             if write_back != WriteBack.DIFF:
241                 cache = read_cache()
242                 src = src.resolve()
243                 if src in cache and cache[src] == get_cache_info(src):
244                     changed = Changed.CACHED
245             if (
246                 changed is not Changed.CACHED
247                 and format_file_in_place(
248                     src, line_length=line_length, fast=fast, write_back=write_back
249                 )
250             ):
251                 changed = Changed.YES
252             if write_back != WriteBack.DIFF and changed is not Changed.NO:
253                 write_cache(cache, [src])
254         report.done(src, changed)
255     except Exception as exc:
256         report.failed(src, str(exc))
257     return report.return_code
258
259
260 async def schedule_formatting(
261     sources: List[Path],
262     line_length: int,
263     write_back: WriteBack,
264     fast: bool,
265     quiet: bool,
266     loop: BaseEventLoop,
267     executor: Executor,
268     check: bool,
269 ) -> int:
270     """Run formatting of `sources` in parallel using the provided `executor`.
271
272     (Use ProcessPoolExecutors for actual parallelism.)
273
274     `line_length`, `write_back`, and `fast` options are passed to
275     :func:`format_file_in_place`.
276     """
277     report = Report(check=check, quiet=quiet)
278     cache: Cache = {}
279     if write_back != WriteBack.DIFF:
280         cache = read_cache()
281         sources, cached = filter_cached(cache, sources)
282         for src in cached:
283             report.done(src, Changed.CACHED)
284     cancelled = []
285     formatted = []
286     if sources:
287         lock = None
288         if write_back == WriteBack.DIFF:
289             # For diff output, we need locks to ensure we don't interleave output
290             # from different processes.
291             manager = Manager()
292             lock = manager.Lock()
293         tasks = {
294             src: loop.run_in_executor(
295                 executor, format_file_in_place, src, line_length, fast, write_back, lock
296             )
297             for src in sources
298         }
299         _task_values = list(tasks.values())
300         try:
301             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
302             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
303         except NotImplementedError:
304             # There are no good alternatives for these on Windows
305             pass
306         await asyncio.wait(_task_values)
307         for src, task in tasks.items():
308             if not task.done():
309                 report.failed(src, "timed out, cancelling")
310                 task.cancel()
311                 cancelled.append(task)
312             elif task.cancelled():
313                 cancelled.append(task)
314             elif task.exception():
315                 report.failed(src, str(task.exception()))
316             else:
317                 formatted.append(src)
318                 report.done(src, Changed.YES if task.result() else Changed.NO)
319
320     if cancelled:
321         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
322     elif not quiet:
323         out("All done! ✨ 🍰 ✨")
324     if not quiet:
325         click.echo(str(report))
326
327     if write_back != WriteBack.DIFF and formatted:
328         write_cache(cache, formatted)
329
330     return report.return_code
331
332
333 def format_file_in_place(
334     src: Path,
335     line_length: int,
336     fast: bool,
337     write_back: WriteBack = WriteBack.NO,
338     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
339 ) -> bool:
340     """Format file under `src` path. Return True if changed.
341
342     If `write_back` is True, write reformatted code back to stdout.
343     `line_length` and `fast` options are passed to :func:`format_file_contents`.
344     """
345
346     with tokenize.open(src) as src_buffer:
347         src_contents = src_buffer.read()
348     try:
349         dst_contents = format_file_contents(
350             src_contents, line_length=line_length, fast=fast
351         )
352     except NothingChanged:
353         return False
354
355     if write_back == write_back.YES:
356         with open(src, "w", encoding=src_buffer.encoding) as f:
357             f.write(dst_contents)
358     elif write_back == write_back.DIFF:
359         src_name = f"{src.name}  (original)"
360         dst_name = f"{src.name}  (formatted)"
361         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
362         if lock:
363             lock.acquire()
364         try:
365             sys.stdout.write(diff_contents)
366         finally:
367             if lock:
368                 lock.release()
369     return True
370
371
372 def format_stdin_to_stdout(
373     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
374 ) -> bool:
375     """Format file on stdin. Return True if changed.
376
377     If `write_back` is True, write reformatted code back to stdout.
378     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
379     """
380     src = sys.stdin.read()
381     dst = src
382     try:
383         dst = format_file_contents(src, line_length=line_length, fast=fast)
384         return True
385
386     except NothingChanged:
387         return False
388
389     finally:
390         if write_back == WriteBack.YES:
391             sys.stdout.write(dst)
392         elif write_back == WriteBack.DIFF:
393             src_name = "<stdin>  (original)"
394             dst_name = "<stdin>  (formatted)"
395             sys.stdout.write(diff(src, dst, src_name, dst_name))
396
397
398 def format_file_contents(
399     src_contents: str, line_length: int, fast: bool
400 ) -> FileContent:
401     """Reformat contents a file and return new contents.
402
403     If `fast` is False, additionally confirm that the reformatted code is
404     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
405     `line_length` is passed to :func:`format_str`.
406     """
407     if src_contents.strip() == "":
408         raise NothingChanged
409
410     dst_contents = format_str(src_contents, line_length=line_length)
411     if src_contents == dst_contents:
412         raise NothingChanged
413
414     if not fast:
415         assert_equivalent(src_contents, dst_contents)
416         assert_stable(src_contents, dst_contents, line_length=line_length)
417     return dst_contents
418
419
420 def format_str(src_contents: str, line_length: int) -> FileContent:
421     """Reformat a string and return new contents.
422
423     `line_length` determines how many characters per line are allowed.
424     """
425     src_node = lib2to3_parse(src_contents)
426     dst_contents = ""
427     lines = LineGenerator()
428     elt = EmptyLineTracker()
429     py36 = is_python36(src_node)
430     empty_line = Line()
431     after = 0
432     for current_line in lines.visit(src_node):
433         for _ in range(after):
434             dst_contents += str(empty_line)
435         before, after = elt.maybe_empty_lines(current_line)
436         for _ in range(before):
437             dst_contents += str(empty_line)
438         for line in split_line(current_line, line_length=line_length, py36=py36):
439             dst_contents += str(line)
440     return dst_contents
441
442
443 GRAMMARS = [
444     pygram.python_grammar_no_print_statement_no_exec_statement,
445     pygram.python_grammar_no_print_statement,
446     pygram.python_grammar_no_exec_statement,
447     pygram.python_grammar,
448 ]
449
450
451 def lib2to3_parse(src_txt: str) -> Node:
452     """Given a string with source, return the lib2to3 Node."""
453     grammar = pygram.python_grammar_no_print_statement
454     if src_txt[-1] != "\n":
455         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
456         src_txt += nl
457     for grammar in GRAMMARS:
458         drv = driver.Driver(grammar, pytree.convert)
459         try:
460             result = drv.parse_string(src_txt, True)
461             break
462
463         except ParseError as pe:
464             lineno, column = pe.context[1]
465             lines = src_txt.splitlines()
466             try:
467                 faulty_line = lines[lineno - 1]
468             except IndexError:
469                 faulty_line = "<line number missing in source>"
470             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
471     else:
472         raise exc from None
473
474     if isinstance(result, Leaf):
475         result = Node(syms.file_input, [result])
476     return result
477
478
479 def lib2to3_unparse(node: Node) -> str:
480     """Given a lib2to3 node, return its string representation."""
481     code = str(node)
482     return code
483
484
485 T = TypeVar("T")
486
487
488 class Visitor(Generic[T]):
489     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
490
491     def visit(self, node: LN) -> Iterator[T]:
492         """Main method to visit `node` and its children.
493
494         It tries to find a `visit_*()` method for the given `node.type`, like
495         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
496         If no dedicated `visit_*()` method is found, chooses `visit_default()`
497         instead.
498
499         Then yields objects of type `T` from the selected visitor.
500         """
501         if node.type < 256:
502             name = token.tok_name[node.type]
503         else:
504             name = type_repr(node.type)
505         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
506
507     def visit_default(self, node: LN) -> Iterator[T]:
508         """Default `visit_*()` implementation. Recurses to children of `node`."""
509         if isinstance(node, Node):
510             for child in node.children:
511                 yield from self.visit(child)
512
513
514 @dataclass
515 class DebugVisitor(Visitor[T]):
516     tree_depth: int = 0
517
518     def visit_default(self, node: LN) -> Iterator[T]:
519         indent = " " * (2 * self.tree_depth)
520         if isinstance(node, Node):
521             _type = type_repr(node.type)
522             out(f"{indent}{_type}", fg="yellow")
523             self.tree_depth += 1
524             for child in node.children:
525                 yield from self.visit(child)
526
527             self.tree_depth -= 1
528             out(f"{indent}/{_type}", fg="yellow", bold=False)
529         else:
530             _type = token.tok_name.get(node.type, str(node.type))
531             out(f"{indent}{_type}", fg="blue", nl=False)
532             if node.prefix:
533                 # We don't have to handle prefixes for `Node` objects since
534                 # that delegates to the first child anyway.
535                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
536             out(f" {node.value!r}", fg="blue", bold=False)
537
538     @classmethod
539     def show(cls, code: str) -> None:
540         """Pretty-print the lib2to3 AST of a given string of `code`.
541
542         Convenience method for debugging.
543         """
544         v: DebugVisitor[None] = DebugVisitor()
545         list(v.visit(lib2to3_parse(code)))
546
547
548 KEYWORDS = set(keyword.kwlist)
549 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
550 FLOW_CONTROL = {"return", "raise", "break", "continue"}
551 STATEMENT = {
552     syms.if_stmt,
553     syms.while_stmt,
554     syms.for_stmt,
555     syms.try_stmt,
556     syms.except_clause,
557     syms.with_stmt,
558     syms.funcdef,
559     syms.classdef,
560 }
561 STANDALONE_COMMENT = 153
562 LOGIC_OPERATORS = {"and", "or"}
563 COMPARATORS = {
564     token.LESS,
565     token.GREATER,
566     token.EQEQUAL,
567     token.NOTEQUAL,
568     token.LESSEQUAL,
569     token.GREATEREQUAL,
570 }
571 MATH_OPERATORS = {
572     token.PLUS,
573     token.MINUS,
574     token.STAR,
575     token.SLASH,
576     token.VBAR,
577     token.AMPER,
578     token.PERCENT,
579     token.CIRCUMFLEX,
580     token.TILDE,
581     token.LEFTSHIFT,
582     token.RIGHTSHIFT,
583     token.DOUBLESTAR,
584     token.DOUBLESLASH,
585 }
586 STARS = {token.STAR, token.DOUBLESTAR}
587 VARARGS_PARENTS = {
588     syms.arglist,
589     syms.argument,  # double star in arglist
590     syms.trailer,  # single argument to call
591     syms.typedargslist,
592     syms.varargslist,  # lambdas
593 }
594 UNPACKING_PARENTS = {
595     syms.atom,  # single element of a list or set literal
596     syms.dictsetmaker,
597     syms.listmaker,
598     syms.testlist_gexp,
599 }
600 COMPREHENSION_PRIORITY = 20
601 COMMA_PRIORITY = 10
602 LOGIC_PRIORITY = 5
603 STRING_PRIORITY = 4
604 COMPARATOR_PRIORITY = 3
605 MATH_PRIORITY = 1
606
607
608 @dataclass
609 class BracketTracker:
610     """Keeps track of brackets on a line."""
611
612     depth: int = 0
613     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
614     delimiters: Dict[LeafID, Priority] = Factory(dict)
615     previous: Optional[Leaf] = None
616     _for_loop_variable: bool = False
617     _lambda_arguments: bool = False
618
619     def mark(self, leaf: Leaf) -> None:
620         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
621
622         All leaves receive an int `bracket_depth` field that stores how deep
623         within brackets a given leaf is. 0 means there are no enclosing brackets
624         that started on this line.
625
626         If a leaf is itself a closing bracket, it receives an `opening_bracket`
627         field that it forms a pair with. This is a one-directional link to
628         avoid reference cycles.
629
630         If a leaf is a delimiter (a token on which Black can split the line if
631         needed) and it's on depth 0, its `id()` is stored in the tracker's
632         `delimiters` field.
633         """
634         if leaf.type == token.COMMENT:
635             return
636
637         self.maybe_decrement_after_for_loop_variable(leaf)
638         self.maybe_decrement_after_lambda_arguments(leaf)
639         if leaf.type in CLOSING_BRACKETS:
640             self.depth -= 1
641             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
642             leaf.opening_bracket = opening_bracket
643         leaf.bracket_depth = self.depth
644         if self.depth == 0:
645             delim = is_split_before_delimiter(leaf, self.previous)
646             if delim and self.previous is not None:
647                 self.delimiters[id(self.previous)] = delim
648             else:
649                 delim = is_split_after_delimiter(leaf, self.previous)
650                 if delim:
651                     self.delimiters[id(leaf)] = delim
652         if leaf.type in OPENING_BRACKETS:
653             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
654             self.depth += 1
655         self.previous = leaf
656         self.maybe_increment_lambda_arguments(leaf)
657         self.maybe_increment_for_loop_variable(leaf)
658
659     def any_open_brackets(self) -> bool:
660         """Return True if there is an yet unmatched open bracket on the line."""
661         return bool(self.bracket_match)
662
663     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
664         """Return the highest priority of a delimiter found on the line.
665
666         Values are consistent with what `is_split_*_delimiter()` return.
667         Raises ValueError on no delimiters.
668         """
669         return max(v for k, v in self.delimiters.items() if k not in exclude)
670
671     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
672         """In a for loop, or comprehension, the variables are often unpacks.
673
674         To avoid splitting on the comma in this situation, increase the depth of
675         tokens between `for` and `in`.
676         """
677         if leaf.type == token.NAME and leaf.value == "for":
678             self.depth += 1
679             self._for_loop_variable = True
680             return True
681
682         return False
683
684     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
685         """See `maybe_increment_for_loop_variable` above for explanation."""
686         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
687             self.depth -= 1
688             self._for_loop_variable = False
689             return True
690
691         return False
692
693     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
694         """In a lambda expression, there might be more than one argument.
695
696         To avoid splitting on the comma in this situation, increase the depth of
697         tokens between `lambda` and `:`.
698         """
699         if leaf.type == token.NAME and leaf.value == "lambda":
700             self.depth += 1
701             self._lambda_arguments = True
702             return True
703
704         return False
705
706     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
707         """See `maybe_increment_lambda_arguments` above for explanation."""
708         if self._lambda_arguments and leaf.type == token.COLON:
709             self.depth -= 1
710             self._lambda_arguments = False
711             return True
712
713         return False
714
715
716 @dataclass
717 class Line:
718     """Holds leaves and comments. Can be printed with `str(line)`."""
719
720     depth: int = 0
721     leaves: List[Leaf] = Factory(list)
722     comments: List[Tuple[Index, Leaf]] = Factory(list)
723     bracket_tracker: BracketTracker = Factory(BracketTracker)
724     inside_brackets: bool = False
725
726     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
727         """Add a new `leaf` to the end of the line.
728
729         Unless `preformatted` is True, the `leaf` will receive a new consistent
730         whitespace prefix and metadata applied by :class:`BracketTracker`.
731         Trailing commas are maybe removed, unpacked for loop variables are
732         demoted from being delimiters.
733
734         Inline comments are put aside.
735         """
736         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
737         if not has_value:
738             return
739
740         if self.leaves and not preformatted:
741             # Note: at this point leaf.prefix should be empty except for
742             # imports, for which we only preserve newlines.
743             leaf.prefix += whitespace(leaf)
744         if self.inside_brackets or not preformatted:
745             self.bracket_tracker.mark(leaf)
746             self.maybe_remove_trailing_comma(leaf)
747
748         if not self.append_comment(leaf):
749             self.leaves.append(leaf)
750
751     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
752         """Like :func:`append()` but disallow invalid standalone comment structure.
753
754         Raises ValueError when any `leaf` is appended after a standalone comment
755         or when a standalone comment is not the first leaf on the line.
756         """
757         if self.bracket_tracker.depth == 0:
758             if self.is_comment:
759                 raise ValueError("cannot append to standalone comments")
760
761             if self.leaves and leaf.type == STANDALONE_COMMENT:
762                 raise ValueError(
763                     "cannot append standalone comments to a populated line"
764                 )
765
766         self.append(leaf, preformatted=preformatted)
767
768     @property
769     def is_comment(self) -> bool:
770         """Is this line a standalone comment?"""
771         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
772
773     @property
774     def is_decorator(self) -> bool:
775         """Is this line a decorator?"""
776         return bool(self) and self.leaves[0].type == token.AT
777
778     @property
779     def is_import(self) -> bool:
780         """Is this an import line?"""
781         return bool(self) and is_import(self.leaves[0])
782
783     @property
784     def is_class(self) -> bool:
785         """Is this line a class definition?"""
786         return (
787             bool(self)
788             and self.leaves[0].type == token.NAME
789             and self.leaves[0].value == "class"
790         )
791
792     @property
793     def is_def(self) -> bool:
794         """Is this a function definition? (Also returns True for async defs.)"""
795         try:
796             first_leaf = self.leaves[0]
797         except IndexError:
798             return False
799
800         try:
801             second_leaf: Optional[Leaf] = self.leaves[1]
802         except IndexError:
803             second_leaf = None
804         return (
805             (first_leaf.type == token.NAME and first_leaf.value == "def")
806             or (
807                 first_leaf.type == token.ASYNC
808                 and second_leaf is not None
809                 and second_leaf.type == token.NAME
810                 and second_leaf.value == "def"
811             )
812         )
813
814     @property
815     def is_flow_control(self) -> bool:
816         """Is this line a flow control statement?
817
818         Those are `return`, `raise`, `break`, and `continue`.
819         """
820         return (
821             bool(self)
822             and self.leaves[0].type == token.NAME
823             and self.leaves[0].value in FLOW_CONTROL
824         )
825
826     @property
827     def is_yield(self) -> bool:
828         """Is this line a yield statement?"""
829         return (
830             bool(self)
831             and self.leaves[0].type == token.NAME
832             and self.leaves[0].value == "yield"
833         )
834
835     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
836         """If so, needs to be split before emitting."""
837         for leaf in self.leaves:
838             if leaf.type == STANDALONE_COMMENT:
839                 if leaf.bracket_depth <= depth_limit:
840                     return True
841
842         return False
843
844     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
845         """Remove trailing comma if there is one and it's safe."""
846         if not (
847             self.leaves
848             and self.leaves[-1].type == token.COMMA
849             and closing.type in CLOSING_BRACKETS
850         ):
851             return False
852
853         if closing.type == token.RBRACE:
854             self.remove_trailing_comma()
855             return True
856
857         if closing.type == token.RSQB:
858             comma = self.leaves[-1]
859             if comma.parent and comma.parent.type == syms.listmaker:
860                 self.remove_trailing_comma()
861                 return True
862
863         # For parens let's check if it's safe to remove the comma.  If the
864         # trailing one is the only one, we might mistakenly change a tuple
865         # into a different type by removing the comma.
866         depth = closing.bracket_depth + 1
867         commas = 0
868         opening = closing.opening_bracket
869         for _opening_index, leaf in enumerate(self.leaves):
870             if leaf is opening:
871                 break
872
873         else:
874             return False
875
876         for leaf in self.leaves[_opening_index + 1:]:
877             if leaf is closing:
878                 break
879
880             bracket_depth = leaf.bracket_depth
881             if bracket_depth == depth and leaf.type == token.COMMA:
882                 commas += 1
883                 if leaf.parent and leaf.parent.type == syms.arglist:
884                     commas += 1
885                     break
886
887         if commas > 1:
888             self.remove_trailing_comma()
889             return True
890
891         return False
892
893     def append_comment(self, comment: Leaf) -> bool:
894         """Add an inline or standalone comment to the line."""
895         if (
896             comment.type == STANDALONE_COMMENT
897             and self.bracket_tracker.any_open_brackets()
898         ):
899             comment.prefix = ""
900             return False
901
902         if comment.type != token.COMMENT:
903             return False
904
905         after = len(self.leaves) - 1
906         if after == -1:
907             comment.type = STANDALONE_COMMENT
908             comment.prefix = ""
909             return False
910
911         else:
912             self.comments.append((after, comment))
913             return True
914
915     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
916         """Generate comments that should appear directly after `leaf`."""
917         for _leaf_index, _leaf in enumerate(self.leaves):
918             if leaf is _leaf:
919                 break
920
921         else:
922             return
923
924         for index, comment_after in self.comments:
925             if _leaf_index == index:
926                 yield comment_after
927
928     def remove_trailing_comma(self) -> None:
929         """Remove the trailing comma and moves the comments attached to it."""
930         comma_index = len(self.leaves) - 1
931         for i in range(len(self.comments)):
932             comment_index, comment = self.comments[i]
933             if comment_index == comma_index:
934                 self.comments[i] = (comma_index - 1, comment)
935         self.leaves.pop()
936
937     def __str__(self) -> str:
938         """Render the line."""
939         if not self:
940             return "\n"
941
942         indent = "    " * self.depth
943         leaves = iter(self.leaves)
944         first = next(leaves)
945         res = f"{first.prefix}{indent}{first.value}"
946         for leaf in leaves:
947             res += str(leaf)
948         for _, comment in self.comments:
949             res += str(comment)
950         return res + "\n"
951
952     def __bool__(self) -> bool:
953         """Return True if the line has leaves or comments."""
954         return bool(self.leaves or self.comments)
955
956
957 class UnformattedLines(Line):
958     """Just like :class:`Line` but stores lines which aren't reformatted."""
959
960     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
961         """Just add a new `leaf` to the end of the lines.
962
963         The `preformatted` argument is ignored.
964
965         Keeps track of indentation `depth`, which is useful when the user
966         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
967         """
968         try:
969             list(generate_comments(leaf))
970         except FormatOn as f_on:
971             self.leaves.append(f_on.leaf_from_consumed(leaf))
972             raise
973
974         self.leaves.append(leaf)
975         if leaf.type == token.INDENT:
976             self.depth += 1
977         elif leaf.type == token.DEDENT:
978             self.depth -= 1
979
980     def __str__(self) -> str:
981         """Render unformatted lines from leaves which were added with `append()`.
982
983         `depth` is not used for indentation in this case.
984         """
985         if not self:
986             return "\n"
987
988         res = ""
989         for leaf in self.leaves:
990             res += str(leaf)
991         return res
992
993     def append_comment(self, comment: Leaf) -> bool:
994         """Not implemented in this class. Raises `NotImplementedError`."""
995         raise NotImplementedError("Unformatted lines don't store comments separately.")
996
997     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
998         """Does nothing and returns False."""
999         return False
1000
1001     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1002         """Does nothing and returns False."""
1003         return False
1004
1005
1006 @dataclass
1007 class EmptyLineTracker:
1008     """Provides a stateful method that returns the number of potential extra
1009     empty lines needed before and after the currently processed line.
1010
1011     Note: this tracker works on lines that haven't been split yet.  It assumes
1012     the prefix of the first leaf consists of optional newlines.  Those newlines
1013     are consumed by `maybe_empty_lines()` and included in the computation.
1014     """
1015     previous_line: Optional[Line] = None
1016     previous_after: int = 0
1017     previous_defs: List[int] = Factory(list)
1018
1019     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1020         """Return the number of extra empty lines before and after the `current_line`.
1021
1022         This is for separating `def`, `async def` and `class` with extra empty
1023         lines (two on module-level), as well as providing an extra empty line
1024         after flow control keywords to make them more prominent.
1025         """
1026         if isinstance(current_line, UnformattedLines):
1027             return 0, 0
1028
1029         before, after = self._maybe_empty_lines(current_line)
1030         before -= self.previous_after
1031         self.previous_after = after
1032         self.previous_line = current_line
1033         return before, after
1034
1035     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1036         max_allowed = 1
1037         if current_line.depth == 0:
1038             max_allowed = 2
1039         if current_line.leaves:
1040             # Consume the first leaf's extra newlines.
1041             first_leaf = current_line.leaves[0]
1042             before = first_leaf.prefix.count("\n")
1043             before = min(before, max_allowed)
1044             first_leaf.prefix = ""
1045         else:
1046             before = 0
1047         depth = current_line.depth
1048         while self.previous_defs and self.previous_defs[-1] >= depth:
1049             self.previous_defs.pop()
1050             before = 1 if depth else 2
1051         is_decorator = current_line.is_decorator
1052         if is_decorator or current_line.is_def or current_line.is_class:
1053             if not is_decorator:
1054                 self.previous_defs.append(depth)
1055             if self.previous_line is None:
1056                 # Don't insert empty lines before the first line in the file.
1057                 return 0, 0
1058
1059             if self.previous_line and self.previous_line.is_decorator:
1060                 # Don't insert empty lines between decorators.
1061                 return 0, 0
1062
1063             newlines = 2
1064             if current_line.depth:
1065                 newlines -= 1
1066             return newlines, 0
1067
1068         if current_line.is_flow_control:
1069             return before, 1
1070
1071         if (
1072             self.previous_line
1073             and self.previous_line.is_import
1074             and not current_line.is_import
1075             and depth == self.previous_line.depth
1076         ):
1077             return (before or 1), 0
1078
1079         if (
1080             self.previous_line
1081             and self.previous_line.is_yield
1082             and (not current_line.is_yield or depth != self.previous_line.depth)
1083         ):
1084             return (before or 1), 0
1085
1086         return before, 0
1087
1088
1089 @dataclass
1090 class LineGenerator(Visitor[Line]):
1091     """Generates reformatted Line objects.  Empty lines are not emitted.
1092
1093     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1094     in ways that will no longer stringify to valid Python code on the tree.
1095     """
1096     current_line: Line = Factory(Line)
1097
1098     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1099         """Generate a line.
1100
1101         If the line is empty, only emit if it makes sense.
1102         If the line is too long, split it first and then generate.
1103
1104         If any lines were generated, set up a new current_line.
1105         """
1106         if not self.current_line:
1107             if self.current_line.__class__ == type:
1108                 self.current_line.depth += indent
1109             else:
1110                 self.current_line = type(depth=self.current_line.depth + indent)
1111             return  # Line is empty, don't emit. Creating a new one unnecessary.
1112
1113         complete_line = self.current_line
1114         self.current_line = type(depth=complete_line.depth + indent)
1115         yield complete_line
1116
1117     def visit(self, node: LN) -> Iterator[Line]:
1118         """Main method to visit `node` and its children.
1119
1120         Yields :class:`Line` objects.
1121         """
1122         if isinstance(self.current_line, UnformattedLines):
1123             # File contained `# fmt: off`
1124             yield from self.visit_unformatted(node)
1125
1126         else:
1127             yield from super().visit(node)
1128
1129     def visit_default(self, node: LN) -> Iterator[Line]:
1130         """Default `visit_*()` implementation. Recurses to children of `node`."""
1131         if isinstance(node, Leaf):
1132             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1133             try:
1134                 for comment in generate_comments(node):
1135                     if any_open_brackets:
1136                         # any comment within brackets is subject to splitting
1137                         self.current_line.append(comment)
1138                     elif comment.type == token.COMMENT:
1139                         # regular trailing comment
1140                         self.current_line.append(comment)
1141                         yield from self.line()
1142
1143                     else:
1144                         # regular standalone comment
1145                         yield from self.line()
1146
1147                         self.current_line.append(comment)
1148                         yield from self.line()
1149
1150             except FormatOff as f_off:
1151                 f_off.trim_prefix(node)
1152                 yield from self.line(type=UnformattedLines)
1153                 yield from self.visit(node)
1154
1155             except FormatOn as f_on:
1156                 # This only happens here if somebody says "fmt: on" multiple
1157                 # times in a row.
1158                 f_on.trim_prefix(node)
1159                 yield from self.visit_default(node)
1160
1161             else:
1162                 normalize_prefix(node, inside_brackets=any_open_brackets)
1163                 if node.type == token.STRING:
1164                     normalize_string_quotes(node)
1165                 if node.type not in WHITESPACE:
1166                     self.current_line.append(node)
1167         yield from super().visit_default(node)
1168
1169     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1170         """Increase indentation level, maybe yield a line."""
1171         # In blib2to3 INDENT never holds comments.
1172         yield from self.line(+1)
1173         yield from self.visit_default(node)
1174
1175     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1176         """Decrease indentation level, maybe yield a line."""
1177         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1178         yield from self.line(-1)
1179
1180     def visit_stmt(
1181         self, node: Node, keywords: Set[str], parens: Set[str]
1182     ) -> Iterator[Line]:
1183         """Visit a statement.
1184
1185         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1186         `def`, `with`, `class`, and `assert`.
1187
1188         The relevant Python language `keywords` for a given statement will be
1189         NAME leaves within it. This methods puts those on a separate line.
1190
1191         `parens` holds pairs of nodes where invisible parentheses should be put.
1192         Keys hold nodes after which opening parentheses should be put, values
1193         hold nodes before which closing parentheses should be put.
1194         """
1195         normalize_invisible_parens(node, parens_after=parens)
1196         for child in node.children:
1197             if child.type == token.NAME and child.value in keywords:  # type: ignore
1198                 yield from self.line()
1199
1200             yield from self.visit(child)
1201
1202     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1203         """Visit a statement without nested statements."""
1204         is_suite_like = node.parent and node.parent.type in STATEMENT
1205         if is_suite_like:
1206             yield from self.line(+1)
1207             yield from self.visit_default(node)
1208             yield from self.line(-1)
1209
1210         else:
1211             yield from self.line()
1212             yield from self.visit_default(node)
1213
1214     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1215         """Visit `async def`, `async for`, `async with`."""
1216         yield from self.line()
1217
1218         children = iter(node.children)
1219         for child in children:
1220             yield from self.visit(child)
1221
1222             if child.type == token.ASYNC:
1223                 break
1224
1225         internal_stmt = next(children)
1226         for child in internal_stmt.children:
1227             yield from self.visit(child)
1228
1229     def visit_decorators(self, node: Node) -> Iterator[Line]:
1230         """Visit decorators."""
1231         for child in node.children:
1232             yield from self.line()
1233             yield from self.visit(child)
1234
1235     def visit_import_from(self, node: Node) -> Iterator[Line]:
1236         """Visit import_from and maybe put invisible parentheses.
1237
1238         This is separate from `visit_stmt` because import statements don't
1239         support arbitrary atoms and thus handling of parentheses is custom.
1240         """
1241         check_lpar = False
1242         for index, child in enumerate(node.children):
1243             if check_lpar:
1244                 if child.type == token.LPAR:
1245                     # make parentheses invisible
1246                     child.value = ""  # type: ignore
1247                     node.children[-1].value = ""  # type: ignore
1248                 else:
1249                     # insert invisible parentheses
1250                     node.insert_child(index, Leaf(token.LPAR, ""))
1251                     node.append_child(Leaf(token.RPAR, ""))
1252                 break
1253
1254             check_lpar = (
1255                 child.type == token.NAME and child.value == "import"  # type: ignore
1256             )
1257
1258         for child in node.children:
1259             yield from self.visit(child)
1260
1261     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1262         """Remove a semicolon and put the other statement on a separate line."""
1263         yield from self.line()
1264
1265     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1266         """End of file. Process outstanding comments and end with a newline."""
1267         yield from self.visit_default(leaf)
1268         yield from self.line()
1269
1270     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1271         """Used when file contained a `# fmt: off`."""
1272         if isinstance(node, Node):
1273             for child in node.children:
1274                 yield from self.visit(child)
1275
1276         else:
1277             try:
1278                 self.current_line.append(node)
1279             except FormatOn as f_on:
1280                 f_on.trim_prefix(node)
1281                 yield from self.line()
1282                 yield from self.visit(node)
1283
1284             if node.type == token.ENDMARKER:
1285                 # somebody decided not to put a final `# fmt: on`
1286                 yield from self.line()
1287
1288     def __attrs_post_init__(self) -> None:
1289         """You are in a twisty little maze of passages."""
1290         v = self.visit_stmt
1291         Ø: Set[str] = set()
1292         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1293         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1294         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1295         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1296         self.visit_try_stmt = partial(
1297             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1298         )
1299         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1300         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1301         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1302         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1303         self.visit_async_funcdef = self.visit_async_stmt
1304         self.visit_decorated = self.visit_decorators
1305
1306
1307 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1308 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1309 OPENING_BRACKETS = set(BRACKET.keys())
1310 CLOSING_BRACKETS = set(BRACKET.values())
1311 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1312 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1313
1314
1315 def whitespace(leaf: Leaf) -> str:  # noqa C901
1316     """Return whitespace prefix if needed for the given `leaf`."""
1317     NO = ""
1318     SPACE = " "
1319     DOUBLESPACE = "  "
1320     t = leaf.type
1321     p = leaf.parent
1322     v = leaf.value
1323     if t in ALWAYS_NO_SPACE:
1324         return NO
1325
1326     if t == token.COMMENT:
1327         return DOUBLESPACE
1328
1329     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1330     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1331         return NO
1332
1333     prev = leaf.prev_sibling
1334     if not prev:
1335         prevp = preceding_leaf(p)
1336         if not prevp or prevp.type in OPENING_BRACKETS:
1337             return NO
1338
1339         if t == token.COLON:
1340             return SPACE if prevp.type == token.COMMA else NO
1341
1342         if prevp.type == token.EQUAL:
1343             if prevp.parent:
1344                 if prevp.parent.type in {
1345                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1346                 }:
1347                     return NO
1348
1349                 elif prevp.parent.type == syms.typedargslist:
1350                     # A bit hacky: if the equal sign has whitespace, it means we
1351                     # previously found it's a typed argument.  So, we're using
1352                     # that, too.
1353                     return prevp.prefix
1354
1355         elif prevp.type in STARS:
1356             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1357                 return NO
1358
1359         elif prevp.type == token.COLON:
1360             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1361                 return NO
1362
1363         elif (
1364             prevp.parent
1365             and prevp.parent.type == syms.factor
1366             and prevp.type in MATH_OPERATORS
1367         ):
1368             return NO
1369
1370         elif (
1371             prevp.type == token.RIGHTSHIFT
1372             and prevp.parent
1373             and prevp.parent.type == syms.shift_expr
1374             and prevp.prev_sibling
1375             and prevp.prev_sibling.type == token.NAME
1376             and prevp.prev_sibling.value == "print"  # type: ignore
1377         ):
1378             # Python 2 print chevron
1379             return NO
1380
1381     elif prev.type in OPENING_BRACKETS:
1382         return NO
1383
1384     if p.type in {syms.parameters, syms.arglist}:
1385         # untyped function signatures or calls
1386         if not prev or prev.type != token.COMMA:
1387             return NO
1388
1389     elif p.type == syms.varargslist:
1390         # lambdas
1391         if prev and prev.type != token.COMMA:
1392             return NO
1393
1394     elif p.type == syms.typedargslist:
1395         # typed function signatures
1396         if not prev:
1397             return NO
1398
1399         if t == token.EQUAL:
1400             if prev.type != syms.tname:
1401                 return NO
1402
1403         elif prev.type == token.EQUAL:
1404             # A bit hacky: if the equal sign has whitespace, it means we
1405             # previously found it's a typed argument.  So, we're using that, too.
1406             return prev.prefix
1407
1408         elif prev.type != token.COMMA:
1409             return NO
1410
1411     elif p.type == syms.tname:
1412         # type names
1413         if not prev:
1414             prevp = preceding_leaf(p)
1415             if not prevp or prevp.type != token.COMMA:
1416                 return NO
1417
1418     elif p.type == syms.trailer:
1419         # attributes and calls
1420         if t == token.LPAR or t == token.RPAR:
1421             return NO
1422
1423         if not prev:
1424             if t == token.DOT:
1425                 prevp = preceding_leaf(p)
1426                 if not prevp or prevp.type != token.NUMBER:
1427                     return NO
1428
1429             elif t == token.LSQB:
1430                 return NO
1431
1432         elif prev.type != token.COMMA:
1433             return NO
1434
1435     elif p.type == syms.argument:
1436         # single argument
1437         if t == token.EQUAL:
1438             return NO
1439
1440         if not prev:
1441             prevp = preceding_leaf(p)
1442             if not prevp or prevp.type == token.LPAR:
1443                 return NO
1444
1445         elif prev.type in {token.EQUAL} | STARS:
1446             return NO
1447
1448     elif p.type == syms.decorator:
1449         # decorators
1450         return NO
1451
1452     elif p.type == syms.dotted_name:
1453         if prev:
1454             return NO
1455
1456         prevp = preceding_leaf(p)
1457         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1458             return NO
1459
1460     elif p.type == syms.classdef:
1461         if t == token.LPAR:
1462             return NO
1463
1464         if prev and prev.type == token.LPAR:
1465             return NO
1466
1467     elif p.type == syms.subscript:
1468         # indexing
1469         if not prev:
1470             assert p.parent is not None, "subscripts are always parented"
1471             if p.parent.type == syms.subscriptlist:
1472                 return SPACE
1473
1474             return NO
1475
1476         else:
1477             return NO
1478
1479     elif p.type == syms.atom:
1480         if prev and t == token.DOT:
1481             # dots, but not the first one.
1482             return NO
1483
1484     elif p.type == syms.dictsetmaker:
1485         # dict unpacking
1486         if prev and prev.type == token.DOUBLESTAR:
1487             return NO
1488
1489     elif p.type in {syms.factor, syms.star_expr}:
1490         # unary ops
1491         if not prev:
1492             prevp = preceding_leaf(p)
1493             if not prevp or prevp.type in OPENING_BRACKETS:
1494                 return NO
1495
1496             prevp_parent = prevp.parent
1497             assert prevp_parent is not None
1498             if (
1499                 prevp.type == token.COLON
1500                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1501             ):
1502                 return NO
1503
1504             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1505                 return NO
1506
1507         elif t == token.NAME or t == token.NUMBER:
1508             return NO
1509
1510     elif p.type == syms.import_from:
1511         if t == token.DOT:
1512             if prev and prev.type == token.DOT:
1513                 return NO
1514
1515         elif t == token.NAME:
1516             if v == "import":
1517                 return SPACE
1518
1519             if prev and prev.type == token.DOT:
1520                 return NO
1521
1522     elif p.type == syms.sliceop:
1523         return NO
1524
1525     return SPACE
1526
1527
1528 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1529     """Return the first leaf that precedes `node`, if any."""
1530     while node:
1531         res = node.prev_sibling
1532         if res:
1533             if isinstance(res, Leaf):
1534                 return res
1535
1536             try:
1537                 return list(res.leaves())[-1]
1538
1539             except IndexError:
1540                 return None
1541
1542         node = node.parent
1543     return None
1544
1545
1546 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1547     """Return the priority of the `leaf` delimiter, given a line break after it.
1548
1549     The delimiter priorities returned here are from those delimiters that would
1550     cause a line break after themselves.
1551
1552     Higher numbers are higher priority.
1553     """
1554     if leaf.type == token.COMMA:
1555         return COMMA_PRIORITY
1556
1557     return 0
1558
1559
1560 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1561     """Return the priority of the `leaf` delimiter, given a line before after it.
1562
1563     The delimiter priorities returned here are from those delimiters that would
1564     cause a line break before themselves.
1565
1566     Higher numbers are higher priority.
1567     """
1568     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1569         # * and ** might also be MATH_OPERATORS but in this case they are not.
1570         # Don't treat them as a delimiter.
1571         return 0
1572
1573     if (
1574         leaf.type in MATH_OPERATORS
1575         and leaf.parent
1576         and leaf.parent.type not in {syms.factor, syms.star_expr}
1577     ):
1578         return MATH_PRIORITY
1579
1580     if leaf.type in COMPARATORS:
1581         return COMPARATOR_PRIORITY
1582
1583     if (
1584         leaf.type == token.STRING
1585         and previous is not None
1586         and previous.type == token.STRING
1587     ):
1588         return STRING_PRIORITY
1589
1590     if (
1591         leaf.type == token.NAME
1592         and leaf.value == "for"
1593         and leaf.parent
1594         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1595     ):
1596         return COMPREHENSION_PRIORITY
1597
1598     if (
1599         leaf.type == token.NAME
1600         and leaf.value == "if"
1601         and leaf.parent
1602         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1603     ):
1604         return COMPREHENSION_PRIORITY
1605
1606     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1607         return LOGIC_PRIORITY
1608
1609     return 0
1610
1611
1612 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1613     """Clean the prefix of the `leaf` and generate comments from it, if any.
1614
1615     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1616     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1617     move because it does away with modifying the grammar to include all the
1618     possible places in which comments can be placed.
1619
1620     The sad consequence for us though is that comments don't "belong" anywhere.
1621     This is why this function generates simple parentless Leaf objects for
1622     comments.  We simply don't know what the correct parent should be.
1623
1624     No matter though, we can live without this.  We really only need to
1625     differentiate between inline and standalone comments.  The latter don't
1626     share the line with any code.
1627
1628     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1629     are emitted with a fake STANDALONE_COMMENT token identifier.
1630     """
1631     p = leaf.prefix
1632     if not p:
1633         return
1634
1635     if "#" not in p:
1636         return
1637
1638     consumed = 0
1639     nlines = 0
1640     for index, line in enumerate(p.split("\n")):
1641         consumed += len(line) + 1  # adding the length of the split '\n'
1642         line = line.lstrip()
1643         if not line:
1644             nlines += 1
1645         if not line.startswith("#"):
1646             continue
1647
1648         if index == 0 and leaf.type != token.ENDMARKER:
1649             comment_type = token.COMMENT  # simple trailing comment
1650         else:
1651             comment_type = STANDALONE_COMMENT
1652         comment = make_comment(line)
1653         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1654
1655         if comment in {"# fmt: on", "# yapf: enable"}:
1656             raise FormatOn(consumed)
1657
1658         if comment in {"# fmt: off", "# yapf: disable"}:
1659             if comment_type == STANDALONE_COMMENT:
1660                 raise FormatOff(consumed)
1661
1662             prev = preceding_leaf(leaf)
1663             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1664                 raise FormatOff(consumed)
1665
1666         nlines = 0
1667
1668
1669 def make_comment(content: str) -> str:
1670     """Return a consistently formatted comment from the given `content` string.
1671
1672     All comments (except for "##", "#!", "#:") should have a single space between
1673     the hash sign and the content.
1674
1675     If `content` didn't start with a hash sign, one is provided.
1676     """
1677     content = content.rstrip()
1678     if not content:
1679         return "#"
1680
1681     if content[0] == "#":
1682         content = content[1:]
1683     if content and content[0] not in " !:#":
1684         content = " " + content
1685     return "#" + content
1686
1687
1688 def split_line(
1689     line: Line, line_length: int, inner: bool = False, py36: bool = False
1690 ) -> Iterator[Line]:
1691     """Split a `line` into potentially many lines.
1692
1693     They should fit in the allotted `line_length` but might not be able to.
1694     `inner` signifies that there were a pair of brackets somewhere around the
1695     current `line`, possibly transitively. This means we can fallback to splitting
1696     by delimiters if the LHS/RHS don't yield any results.
1697
1698     If `py36` is True, splitting may generate syntax that is only compatible
1699     with Python 3.6 and later.
1700     """
1701     if isinstance(line, UnformattedLines) or line.is_comment:
1702         yield line
1703         return
1704
1705     line_str = str(line).strip("\n")
1706     if (
1707         len(line_str) <= line_length
1708         and "\n" not in line_str  # multiline strings
1709         and not line.contains_standalone_comments()
1710     ):
1711         yield line
1712         return
1713
1714     split_funcs: List[SplitFunc]
1715     if line.is_def:
1716         split_funcs = [left_hand_split]
1717     elif line.inside_brackets:
1718         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1719     else:
1720         split_funcs = [right_hand_split]
1721     for split_func in split_funcs:
1722         # We are accumulating lines in `result` because we might want to abort
1723         # mission and return the original line in the end, or attempt a different
1724         # split altogether.
1725         result: List[Line] = []
1726         try:
1727             for l in split_func(line, py36):
1728                 if str(l).strip("\n") == line_str:
1729                     raise CannotSplit("Split function returned an unchanged result")
1730
1731                 result.extend(
1732                     split_line(l, line_length=line_length, inner=True, py36=py36)
1733                 )
1734         except CannotSplit as cs:
1735             continue
1736
1737         else:
1738             yield from result
1739             break
1740
1741     else:
1742         yield line
1743
1744
1745 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1746     """Split line into many lines, starting with the first matching bracket pair.
1747
1748     Note: this usually looks weird, only use this for function definitions.
1749     Prefer RHS otherwise.
1750     """
1751     head = Line(depth=line.depth)
1752     body = Line(depth=line.depth + 1, inside_brackets=True)
1753     tail = Line(depth=line.depth)
1754     tail_leaves: List[Leaf] = []
1755     body_leaves: List[Leaf] = []
1756     head_leaves: List[Leaf] = []
1757     current_leaves = head_leaves
1758     matching_bracket = None
1759     for leaf in line.leaves:
1760         if (
1761             current_leaves is body_leaves
1762             and leaf.type in CLOSING_BRACKETS
1763             and leaf.opening_bracket is matching_bracket
1764         ):
1765             current_leaves = tail_leaves if body_leaves else head_leaves
1766         current_leaves.append(leaf)
1767         if current_leaves is head_leaves:
1768             if leaf.type in OPENING_BRACKETS:
1769                 matching_bracket = leaf
1770                 current_leaves = body_leaves
1771     # Since body is a new indent level, remove spurious leading whitespace.
1772     if body_leaves:
1773         normalize_prefix(body_leaves[0], inside_brackets=True)
1774     # Build the new lines.
1775     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1776         for leaf in leaves:
1777             result.append(leaf, preformatted=True)
1778             for comment_after in line.comments_after(leaf):
1779                 result.append(comment_after, preformatted=True)
1780     bracket_split_succeeded_or_raise(head, body, tail)
1781     for result in (head, body, tail):
1782         if result:
1783             yield result
1784
1785
1786 def right_hand_split(
1787     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1788 ) -> Iterator[Line]:
1789     """Split line into many lines, starting with the last matching bracket pair."""
1790     head = Line(depth=line.depth)
1791     body = Line(depth=line.depth + 1, inside_brackets=True)
1792     tail = Line(depth=line.depth)
1793     tail_leaves: List[Leaf] = []
1794     body_leaves: List[Leaf] = []
1795     head_leaves: List[Leaf] = []
1796     current_leaves = tail_leaves
1797     opening_bracket = None
1798     closing_bracket = None
1799     for leaf in reversed(line.leaves):
1800         if current_leaves is body_leaves:
1801             if leaf is opening_bracket:
1802                 current_leaves = head_leaves if body_leaves else tail_leaves
1803         current_leaves.append(leaf)
1804         if current_leaves is tail_leaves:
1805             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1806                 opening_bracket = leaf.opening_bracket
1807                 closing_bracket = leaf
1808                 current_leaves = body_leaves
1809     tail_leaves.reverse()
1810     body_leaves.reverse()
1811     head_leaves.reverse()
1812     # Since body is a new indent level, remove spurious leading whitespace.
1813     if body_leaves:
1814         normalize_prefix(body_leaves[0], inside_brackets=True)
1815     elif not head_leaves:
1816         # No `head` and no `body` means the split failed. `tail` has all content.
1817         raise CannotSplit("No brackets found")
1818
1819     # Build the new lines.
1820     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1821         for leaf in leaves:
1822             result.append(leaf, preformatted=True)
1823             for comment_after in line.comments_after(leaf):
1824                 result.append(comment_after, preformatted=True)
1825     bracket_split_succeeded_or_raise(head, body, tail)
1826     assert opening_bracket and closing_bracket
1827     if (
1828         opening_bracket.type == token.LPAR
1829         and not opening_bracket.value
1830         and closing_bracket.type == token.RPAR
1831         and not closing_bracket.value
1832     ):
1833         # These parens were optional. If there aren't any delimiters or standalone
1834         # comments in the body, they were unnecessary and another split without
1835         # them should be attempted.
1836         if not (
1837             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1838         ):
1839             omit = {id(closing_bracket), *omit}
1840             yield from right_hand_split(line, py36=py36, omit=omit)
1841             return
1842
1843     ensure_visible(opening_bracket)
1844     ensure_visible(closing_bracket)
1845     for result in (head, body, tail):
1846         if result:
1847             yield result
1848
1849
1850 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1851     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1852
1853     Do nothing otherwise.
1854
1855     A left- or right-hand split is based on a pair of brackets. Content before
1856     (and including) the opening bracket is left on one line, content inside the
1857     brackets is put on a separate line, and finally content starting with and
1858     following the closing bracket is put on a separate line.
1859
1860     Those are called `head`, `body`, and `tail`, respectively. If the split
1861     produced the same line (all content in `head`) or ended up with an empty `body`
1862     and the `tail` is just the closing bracket, then it's considered failed.
1863     """
1864     tail_len = len(str(tail).strip())
1865     if not body:
1866         if tail_len == 0:
1867             raise CannotSplit("Splitting brackets produced the same line")
1868
1869         elif tail_len < 3:
1870             raise CannotSplit(
1871                 f"Splitting brackets on an empty body to save "
1872                 f"{tail_len} characters is not worth it"
1873             )
1874
1875
1876 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1877     """Normalize prefix of the first leaf in every line returned by `split_func`.
1878
1879     This is a decorator over relevant split functions.
1880     """
1881
1882     @wraps(split_func)
1883     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1884         for l in split_func(line, py36):
1885             normalize_prefix(l.leaves[0], inside_brackets=True)
1886             yield l
1887
1888     return split_wrapper
1889
1890
1891 @dont_increase_indentation
1892 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1893     """Split according to delimiters of the highest priority.
1894
1895     If `py36` is True, the split will add trailing commas also in function
1896     signatures that contain `*` and `**`.
1897     """
1898     try:
1899         last_leaf = line.leaves[-1]
1900     except IndexError:
1901         raise CannotSplit("Line empty")
1902
1903     delimiters = line.bracket_tracker.delimiters
1904     try:
1905         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1906             exclude={id(last_leaf)}
1907         )
1908     except ValueError:
1909         raise CannotSplit("No delimiters found")
1910
1911     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1912     lowest_depth = sys.maxsize
1913     trailing_comma_safe = True
1914
1915     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1916         """Append `leaf` to current line or to new line if appending impossible."""
1917         nonlocal current_line
1918         try:
1919             current_line.append_safe(leaf, preformatted=True)
1920         except ValueError as ve:
1921             yield current_line
1922
1923             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1924             current_line.append(leaf)
1925
1926     for leaf in line.leaves:
1927         yield from append_to_line(leaf)
1928
1929         for comment_after in line.comments_after(leaf):
1930             yield from append_to_line(comment_after)
1931
1932         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1933         if (
1934             leaf.bracket_depth == lowest_depth
1935             and is_vararg(leaf, within=VARARGS_PARENTS)
1936         ):
1937             trailing_comma_safe = trailing_comma_safe and py36
1938         leaf_priority = delimiters.get(id(leaf))
1939         if leaf_priority == delimiter_priority:
1940             yield current_line
1941
1942             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1943     if current_line:
1944         if (
1945             trailing_comma_safe
1946             and delimiter_priority == COMMA_PRIORITY
1947             and current_line.leaves[-1].type != token.COMMA
1948             and current_line.leaves[-1].type != STANDALONE_COMMENT
1949         ):
1950             current_line.append(Leaf(token.COMMA, ","))
1951         yield current_line
1952
1953
1954 @dont_increase_indentation
1955 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1956     """Split standalone comments from the rest of the line."""
1957     if not line.contains_standalone_comments(0):
1958         raise CannotSplit("Line does not have any standalone comments")
1959
1960     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1961
1962     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1963         """Append `leaf` to current line or to new line if appending impossible."""
1964         nonlocal current_line
1965         try:
1966             current_line.append_safe(leaf, preformatted=True)
1967         except ValueError as ve:
1968             yield current_line
1969
1970             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1971             current_line.append(leaf)
1972
1973     for leaf in line.leaves:
1974         yield from append_to_line(leaf)
1975
1976         for comment_after in line.comments_after(leaf):
1977             yield from append_to_line(comment_after)
1978
1979     if current_line:
1980         yield current_line
1981
1982
1983 def is_import(leaf: Leaf) -> bool:
1984     """Return True if the given leaf starts an import statement."""
1985     p = leaf.parent
1986     t = leaf.type
1987     v = leaf.value
1988     return bool(
1989         t == token.NAME
1990         and (
1991             (v == "import" and p and p.type == syms.import_name)
1992             or (v == "from" and p and p.type == syms.import_from)
1993         )
1994     )
1995
1996
1997 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1998     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1999     else.
2000
2001     Note: don't use backslashes for formatting or you'll lose your voting rights.
2002     """
2003     if not inside_brackets:
2004         spl = leaf.prefix.split("#")
2005         if "\\" not in spl[0]:
2006             nl_count = spl[-1].count("\n")
2007             if len(spl) > 1:
2008                 nl_count -= 1
2009             leaf.prefix = "\n" * nl_count
2010             return
2011
2012     leaf.prefix = ""
2013
2014
2015 def normalize_string_quotes(leaf: Leaf) -> None:
2016     """Prefer double quotes but only if it doesn't cause more escaping.
2017
2018     Adds or removes backslashes as appropriate. Doesn't parse and fix
2019     strings nested in f-strings (yet).
2020
2021     Note: Mutates its argument.
2022     """
2023     value = leaf.value.lstrip("furbFURB")
2024     if value[:3] == '"""':
2025         return
2026
2027     elif value[:3] == "'''":
2028         orig_quote = "'''"
2029         new_quote = '"""'
2030     elif value[0] == '"':
2031         orig_quote = '"'
2032         new_quote = "'"
2033     else:
2034         orig_quote = "'"
2035         new_quote = '"'
2036     first_quote_pos = leaf.value.find(orig_quote)
2037     if first_quote_pos == -1:
2038         return  # There's an internal error
2039
2040     prefix = leaf.value[:first_quote_pos]
2041     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2042     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2043     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2044     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
2045     if "r" in prefix.casefold():
2046         if unescaped_new_quote.search(body):
2047             # There's at least one unescaped new_quote in this raw string
2048             # so converting is impossible
2049             return
2050
2051         # Do not introduce or remove backslashes in raw strings
2052         new_body = body
2053     else:
2054         # remove unnecessary quotes
2055         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2056         if body != new_body:
2057             # Consider the string without unnecessary quotes as the original
2058             body = new_body
2059             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2060         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2061         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2062     if new_quote == '"""' and new_body[-1] == '"':
2063         # edge case:
2064         new_body = new_body[:-1] + '\\"'
2065     orig_escape_count = body.count("\\")
2066     new_escape_count = new_body.count("\\")
2067     if new_escape_count > orig_escape_count:
2068         return  # Do not introduce more escaping
2069
2070     if new_escape_count == orig_escape_count and orig_quote == '"':
2071         return  # Prefer double quotes
2072
2073     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2074
2075
2076 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2077     """Make existing optional parentheses invisible or create new ones.
2078
2079     Standardizes on visible parentheses for single-element tuples, and keeps
2080     existing visible parentheses for other tuples and generator expressions.
2081     """
2082     check_lpar = False
2083     for child in list(node.children):
2084         if check_lpar:
2085             if child.type == syms.atom:
2086                 if not (
2087                     is_empty_tuple(child)
2088                     or is_one_tuple(child)
2089                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2090                 ):
2091                     first = child.children[0]
2092                     last = child.children[-1]
2093                     if first.type == token.LPAR and last.type == token.RPAR:
2094                         # make parentheses invisible
2095                         first.value = ""  # type: ignore
2096                         last.value = ""  # type: ignore
2097             elif is_one_tuple(child):
2098                 # wrap child in visible parentheses
2099                 lpar = Leaf(token.LPAR, "(")
2100                 rpar = Leaf(token.RPAR, ")")
2101                 index = child.remove() or 0
2102                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2103             else:
2104                 # wrap child in invisible parentheses
2105                 lpar = Leaf(token.LPAR, "")
2106                 rpar = Leaf(token.RPAR, "")
2107                 index = child.remove() or 0
2108                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2109
2110         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2111
2112
2113 def is_empty_tuple(node: LN) -> bool:
2114     """Return True if `node` holds an empty tuple."""
2115     return (
2116         node.type == syms.atom
2117         and len(node.children) == 2
2118         and node.children[0].type == token.LPAR
2119         and node.children[1].type == token.RPAR
2120     )
2121
2122
2123 def is_one_tuple(node: LN) -> bool:
2124     """Return True if `node` holds a tuple with one element, with or without parens."""
2125     if node.type == syms.atom:
2126         if len(node.children) != 3:
2127             return False
2128
2129         lpar, gexp, rpar = node.children
2130         if not (
2131             lpar.type == token.LPAR
2132             and gexp.type == syms.testlist_gexp
2133             and rpar.type == token.RPAR
2134         ):
2135             return False
2136
2137         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2138
2139     return (
2140         node.type in IMPLICIT_TUPLE
2141         and len(node.children) == 2
2142         and node.children[1].type == token.COMMA
2143     )
2144
2145
2146 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2147     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2148
2149     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2150     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2151     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2152     generalizations (PEP 448).
2153     """
2154     if leaf.type not in STARS or not leaf.parent:
2155         return False
2156
2157     p = leaf.parent
2158     if p.type == syms.star_expr:
2159         # Star expressions are also used as assignment targets in extended
2160         # iterable unpacking (PEP 3132).  See what its parent is instead.
2161         if not p.parent:
2162             return False
2163
2164         p = p.parent
2165
2166     return p.type in within
2167
2168
2169 def max_delimiter_priority_in_atom(node: LN) -> int:
2170     """Return maximum delimiter priority inside `node`.
2171
2172     This is specific to atoms with contents contained in a pair of parentheses.
2173     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2174     """
2175     if node.type != syms.atom:
2176         return 0
2177
2178     first = node.children[0]
2179     last = node.children[-1]
2180     if not (first.type == token.LPAR and last.type == token.RPAR):
2181         return 0
2182
2183     bt = BracketTracker()
2184     for c in node.children[1:-1]:
2185         if isinstance(c, Leaf):
2186             bt.mark(c)
2187         else:
2188             for leaf in c.leaves():
2189                 bt.mark(leaf)
2190     try:
2191         return bt.max_delimiter_priority()
2192
2193     except ValueError:
2194         return 0
2195
2196
2197 def ensure_visible(leaf: Leaf) -> None:
2198     """Make sure parentheses are visible.
2199
2200     They could be invisible as part of some statements (see
2201     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2202     """
2203     if leaf.type == token.LPAR:
2204         leaf.value = "("
2205     elif leaf.type == token.RPAR:
2206         leaf.value = ")"
2207
2208
2209 def is_python36(node: Node) -> bool:
2210     """Return True if the current file is using Python 3.6+ features.
2211
2212     Currently looking for:
2213     - f-strings; and
2214     - trailing commas after * or ** in function signatures.
2215     """
2216     for n in node.pre_order():
2217         if n.type == token.STRING:
2218             value_head = n.value[:2]  # type: ignore
2219             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2220                 return True
2221
2222         elif (
2223             n.type == syms.typedargslist
2224             and n.children
2225             and n.children[-1].type == token.COMMA
2226         ):
2227             for ch in n.children:
2228                 if ch.type in STARS:
2229                     return True
2230
2231     return False
2232
2233
2234 PYTHON_EXTENSIONS = {".py"}
2235 BLACKLISTED_DIRECTORIES = {
2236     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2237 }
2238
2239
2240 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2241     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2242     and have one of the PYTHON_EXTENSIONS.
2243     """
2244     for child in path.iterdir():
2245         if child.is_dir():
2246             if child.name in BLACKLISTED_DIRECTORIES:
2247                 continue
2248
2249             yield from gen_python_files_in_dir(child)
2250
2251         elif child.suffix in PYTHON_EXTENSIONS:
2252             yield child
2253
2254
2255 @dataclass
2256 class Report:
2257     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2258     check: bool = False
2259     quiet: bool = False
2260     change_count: int = 0
2261     same_count: int = 0
2262     failure_count: int = 0
2263
2264     def done(self, src: Path, changed: Changed) -> None:
2265         """Increment the counter for successful reformatting. Write out a message."""
2266         if changed is Changed.YES:
2267             reformatted = "would reformat" if self.check else "reformatted"
2268             if not self.quiet:
2269                 out(f"{reformatted} {src}")
2270             self.change_count += 1
2271         else:
2272             if not self.quiet:
2273                 if changed is Changed.NO:
2274                     msg = f"{src} already well formatted, good job."
2275                 else:
2276                     msg = f"{src} wasn't modified on disk since last run."
2277                 out(msg, bold=False)
2278             self.same_count += 1
2279
2280     def failed(self, src: Path, message: str) -> None:
2281         """Increment the counter for failed reformatting. Write out a message."""
2282         err(f"error: cannot format {src}: {message}")
2283         self.failure_count += 1
2284
2285     @property
2286     def return_code(self) -> int:
2287         """Return the exit code that the app should use.
2288
2289         This considers the current state of changed files and failures:
2290         - if there were any failures, return 123;
2291         - if any files were changed and --check is being used, return 1;
2292         - otherwise return 0.
2293         """
2294         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2295         # 126 we have special returncodes reserved by the shell.
2296         if self.failure_count:
2297             return 123
2298
2299         elif self.change_count and self.check:
2300             return 1
2301
2302         return 0
2303
2304     def __str__(self) -> str:
2305         """Render a color report of the current state.
2306
2307         Use `click.unstyle` to remove colors.
2308         """
2309         if self.check:
2310             reformatted = "would be reformatted"
2311             unchanged = "would be left unchanged"
2312             failed = "would fail to reformat"
2313         else:
2314             reformatted = "reformatted"
2315             unchanged = "left unchanged"
2316             failed = "failed to reformat"
2317         report = []
2318         if self.change_count:
2319             s = "s" if self.change_count > 1 else ""
2320             report.append(
2321                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2322             )
2323         if self.same_count:
2324             s = "s" if self.same_count > 1 else ""
2325             report.append(f"{self.same_count} file{s} {unchanged}")
2326         if self.failure_count:
2327             s = "s" if self.failure_count > 1 else ""
2328             report.append(
2329                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2330             )
2331         return ", ".join(report) + "."
2332
2333
2334 def assert_equivalent(src: str, dst: str) -> None:
2335     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2336
2337     import ast
2338     import traceback
2339
2340     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2341         """Simple visitor generating strings to compare ASTs by content."""
2342         yield f"{'  ' * depth}{node.__class__.__name__}("
2343
2344         for field in sorted(node._fields):
2345             try:
2346                 value = getattr(node, field)
2347             except AttributeError:
2348                 continue
2349
2350             yield f"{'  ' * (depth+1)}{field}="
2351
2352             if isinstance(value, list):
2353                 for item in value:
2354                     if isinstance(item, ast.AST):
2355                         yield from _v(item, depth + 2)
2356
2357             elif isinstance(value, ast.AST):
2358                 yield from _v(value, depth + 2)
2359
2360             else:
2361                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2362
2363         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2364
2365     try:
2366         src_ast = ast.parse(src)
2367     except Exception as exc:
2368         major, minor = sys.version_info[:2]
2369         raise AssertionError(
2370             f"cannot use --safe with this file; failed to parse source file "
2371             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2372             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2373         )
2374
2375     try:
2376         dst_ast = ast.parse(dst)
2377     except Exception as exc:
2378         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2379         raise AssertionError(
2380             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2381             f"Please report a bug on https://github.com/ambv/black/issues.  "
2382             f"This invalid output might be helpful: {log}"
2383         ) from None
2384
2385     src_ast_str = "\n".join(_v(src_ast))
2386     dst_ast_str = "\n".join(_v(dst_ast))
2387     if src_ast_str != dst_ast_str:
2388         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2389         raise AssertionError(
2390             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2391             f"the source.  "
2392             f"Please report a bug on https://github.com/ambv/black/issues.  "
2393             f"This diff might be helpful: {log}"
2394         ) from None
2395
2396
2397 def assert_stable(src: str, dst: str, line_length: int) -> None:
2398     """Raise AssertionError if `dst` reformats differently the second time."""
2399     newdst = format_str(dst, line_length=line_length)
2400     if dst != newdst:
2401         log = dump_to_file(
2402             diff(src, dst, "source", "first pass"),
2403             diff(dst, newdst, "first pass", "second pass"),
2404         )
2405         raise AssertionError(
2406             f"INTERNAL ERROR: Black produced different code on the second pass "
2407             f"of the formatter.  "
2408             f"Please report a bug on https://github.com/ambv/black/issues.  "
2409             f"This diff might be helpful: {log}"
2410         ) from None
2411
2412
2413 def dump_to_file(*output: str) -> str:
2414     """Dump `output` to a temporary file. Return path to the file."""
2415     import tempfile
2416
2417     with tempfile.NamedTemporaryFile(
2418         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2419     ) as f:
2420         for lines in output:
2421             f.write(lines)
2422             if lines and lines[-1] != "\n":
2423                 f.write("\n")
2424     return f.name
2425
2426
2427 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2428     """Return a unified diff string between strings `a` and `b`."""
2429     import difflib
2430
2431     a_lines = [line + "\n" for line in a.split("\n")]
2432     b_lines = [line + "\n" for line in b.split("\n")]
2433     return "".join(
2434         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2435     )
2436
2437
2438 def cancel(tasks: List[asyncio.Task]) -> None:
2439     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2440     err("Aborted!")
2441     for task in tasks:
2442         task.cancel()
2443
2444
2445 def shutdown(loop: BaseEventLoop) -> None:
2446     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2447     try:
2448         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2449         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2450         if not to_cancel:
2451             return
2452
2453         for task in to_cancel:
2454             task.cancel()
2455         loop.run_until_complete(
2456             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2457         )
2458     finally:
2459         # `concurrent.futures.Future` objects cannot be cancelled once they
2460         # are already running. There might be some when the `shutdown()` happened.
2461         # Silence their logger's spew about the event loop being closed.
2462         cf_logger = logging.getLogger("concurrent.futures")
2463         cf_logger.setLevel(logging.CRITICAL)
2464         loop.close()
2465
2466
2467 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2468     """Replace `regex` with `replacement` twice on `original`.
2469
2470     This is used by string normalization to perform replaces on
2471     overlapping matches.
2472     """
2473     return regex.sub(replacement, regex.sub(replacement, original))
2474
2475
2476 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2477 CACHE_FILE = CACHE_DIR / "cache.pickle"
2478
2479
2480 def read_cache() -> Cache:
2481     """Read the cache if it exists and is well formed.
2482
2483     If it is not well formed, the call to write_cache later should resolve the issue.
2484     """
2485     if not CACHE_FILE.exists():
2486         return {}
2487
2488     with CACHE_FILE.open("rb") as fobj:
2489         try:
2490             cache: Cache = pickle.load(fobj)
2491         except pickle.UnpicklingError:
2492             return {}
2493
2494     return cache
2495
2496
2497 def get_cache_info(path: Path) -> CacheInfo:
2498     """Return the information used to check if a file is already formatted or not."""
2499     stat = path.stat()
2500     return stat.st_mtime, stat.st_size
2501
2502
2503 def filter_cached(
2504     cache: Cache, sources: Iterable[Path]
2505 ) -> Tuple[List[Path], List[Path]]:
2506     """Split a list of paths into two.
2507
2508     The first list contains paths of files that modified on disk or are not in the
2509     cache. The other list contains paths to non-modified files.
2510     """
2511     todo, done = [], []
2512     for src in sources:
2513         src = src.resolve()
2514         if cache.get(src) != get_cache_info(src):
2515             todo.append(src)
2516         else:
2517             done.append(src)
2518     return todo, done
2519
2520
2521 def write_cache(cache: Cache, sources: List[Path]) -> None:
2522     """Update the cache file."""
2523     try:
2524         if not CACHE_DIR.exists():
2525             CACHE_DIR.mkdir(parents=True)
2526         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2527         with CACHE_FILE.open("wb") as fobj:
2528             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2529     except OSError:
2530         pass
2531
2532
2533 if __name__ == "__main__":
2534     main()