]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Support nested lambdas in BracketTracker
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Set,
28     Tuple,
29     Type,
30     TypeVar,
31     Union,
32 )
33
34 from appdirs import user_cache_dir
35 from attr import dataclass, Factory
36 import click
37
38 # lib2to3 fork
39 from blib2to3.pytree import Node, Leaf, type_repr
40 from blib2to3 import pygram, pytree
41 from blib2to3.pgen2 import driver, token
42 from blib2to3.pgen2.parse import ParseError
43
44 __version__ = "18.4a6"
45 DEFAULT_LINE_LENGTH = 88
46
47 # types
48 syms = pygram.python_symbols
49 FileContent = str
50 Encoding = str
51 Depth = int
52 NodeType = int
53 LeafID = int
54 Priority = int
55 Index = int
56 LN = Union[Leaf, Node]
57 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
58 Timestamp = float
59 FileSize = int
60 CacheInfo = Tuple[Timestamp, FileSize]
61 Cache = Dict[Path, CacheInfo]
62 out = partial(click.secho, bold=True, err=True)
63 err = partial(click.secho, fg="red", err=True)
64
65
66 class NothingChanged(UserWarning):
67     """Raised by :func:`format_file` when reformatted code is the same as source."""
68
69
70 class CannotSplit(Exception):
71     """A readable split that fits the allotted line length is impossible.
72
73     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
74     :func:`delimiter_split`.
75     """
76
77
78 class FormatError(Exception):
79     """Base exception for `# fmt: on` and `# fmt: off` handling.
80
81     It holds the number of bytes of the prefix consumed before the format
82     control comment appeared.
83     """
84
85     def __init__(self, consumed: int) -> None:
86         super().__init__(consumed)
87         self.consumed = consumed
88
89     def trim_prefix(self, leaf: Leaf) -> None:
90         leaf.prefix = leaf.prefix[self.consumed :]
91
92     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
93         """Returns a new Leaf from the consumed part of the prefix."""
94         unformatted_prefix = leaf.prefix[: self.consumed]
95         return Leaf(token.NEWLINE, unformatted_prefix)
96
97
98 class FormatOn(FormatError):
99     """Found a comment like `# fmt: on` in the file."""
100
101
102 class FormatOff(FormatError):
103     """Found a comment like `# fmt: off` in the file."""
104
105
106 class WriteBack(Enum):
107     NO = 0
108     YES = 1
109     DIFF = 2
110
111
112 class Changed(Enum):
113     NO = 0
114     CACHED = 1
115     YES = 2
116
117
118 @click.command()
119 @click.option(
120     "-l",
121     "--line-length",
122     type=int,
123     default=DEFAULT_LINE_LENGTH,
124     help="How many character per line to allow.",
125     show_default=True,
126 )
127 @click.option(
128     "--check",
129     is_flag=True,
130     help=(
131         "Don't write the files back, just return the status.  Return code 0 "
132         "means nothing would change.  Return code 1 means some files would be "
133         "reformatted.  Return code 123 means there was an internal error."
134     ),
135 )
136 @click.option(
137     "--diff",
138     is_flag=True,
139     help="Don't write the files back, just output a diff for each file on stdout.",
140 )
141 @click.option(
142     "--fast/--safe",
143     is_flag=True,
144     help="If --fast given, skip temporary sanity checks. [default: --safe]",
145 )
146 @click.option(
147     "-q",
148     "--quiet",
149     is_flag=True,
150     help=(
151         "Don't emit non-error messages to stderr. Errors are still emitted, "
152         "silence those with 2>/dev/null."
153     ),
154 )
155 @click.version_option(version=__version__)
156 @click.argument(
157     "src",
158     nargs=-1,
159     type=click.Path(
160         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
161     ),
162 )
163 @click.pass_context
164 def main(
165     ctx: click.Context,
166     line_length: int,
167     check: bool,
168     diff: bool,
169     fast: bool,
170     quiet: bool,
171     src: List[str],
172 ) -> None:
173     """The uncompromising code formatter."""
174     sources: List[Path] = []
175     for s in src:
176         p = Path(s)
177         if p.is_dir():
178             sources.extend(gen_python_files_in_dir(p))
179         elif p.is_file():
180             # if a file was explicitly given, we don't care about its extension
181             sources.append(p)
182         elif s == "-":
183             sources.append(Path("-"))
184         else:
185             err(f"invalid path: {s}")
186
187     if check and not diff:
188         write_back = WriteBack.NO
189     elif diff:
190         write_back = WriteBack.DIFF
191     else:
192         write_back = WriteBack.YES
193     report = Report(check=check, quiet=quiet)
194     if len(sources) == 0:
195         out("No paths given. Nothing to do 😴")
196         ctx.exit(0)
197         return
198
199     elif len(sources) == 1:
200         reformat_one(sources[0], line_length, fast, write_back, report)
201     else:
202         loop = asyncio.get_event_loop()
203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
204         try:
205             loop.run_until_complete(
206                 schedule_formatting(
207                     sources, line_length, fast, write_back, report, loop, executor
208                 )
209             )
210         finally:
211             shutdown(loop)
212         if not quiet:
213             out("All done! ✨ 🍰 ✨")
214             click.echo(str(report))
215     ctx.exit(report.return_code)
216
217
218 def reformat_one(
219     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
220 ) -> None:
221     """Reformat a single file under `src` without spawning child processes.
222
223     If `quiet` is True, non-error messages are not output. `line_length`,
224     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
225     """
226     try:
227         changed = Changed.NO
228         if not src.is_file() and str(src) == "-":
229             if format_stdin_to_stdout(
230                 line_length=line_length, fast=fast, write_back=write_back
231             ):
232                 changed = Changed.YES
233         else:
234             cache: Cache = {}
235             if write_back != WriteBack.DIFF:
236                 cache = read_cache(line_length)
237                 src = src.resolve()
238                 if src in cache and cache[src] == get_cache_info(src):
239                     changed = Changed.CACHED
240             if (
241                 changed is not Changed.CACHED
242                 and format_file_in_place(
243                     src, line_length=line_length, fast=fast, write_back=write_back
244                 )
245             ):
246                 changed = Changed.YES
247             if write_back == WriteBack.YES and changed is not Changed.NO:
248                 write_cache(cache, [src], line_length)
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache(line_length)
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             src: loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             )
289             for src in sources
290         }
291         _task_values = list(tasks.values())
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
294             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     if write_back == WriteBack.YES and formatted:
315         write_cache(cache, formatted, line_length)
316
317
318 def format_file_in_place(
319     src: Path,
320     line_length: int,
321     fast: bool,
322     write_back: WriteBack = WriteBack.NO,
323     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
324 ) -> bool:
325     """Format file under `src` path. Return True if changed.
326
327     If `write_back` is True, write reformatted code back to stdout.
328     `line_length` and `fast` options are passed to :func:`format_file_contents`.
329     """
330
331     with tokenize.open(src) as src_buffer:
332         src_contents = src_buffer.read()
333     try:
334         dst_contents = format_file_contents(
335             src_contents, line_length=line_length, fast=fast
336         )
337     except NothingChanged:
338         return False
339
340     if write_back == write_back.YES:
341         with open(src, "w", encoding=src_buffer.encoding) as f:
342             f.write(dst_contents)
343     elif write_back == write_back.DIFF:
344         src_name = f"{src}  (original)"
345         dst_name = f"{src}  (formatted)"
346         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347         if lock:
348             lock.acquire()
349         try:
350             sys.stdout.write(diff_contents)
351         finally:
352             if lock:
353                 lock.release()
354     return True
355
356
357 def format_stdin_to_stdout(
358     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
359 ) -> bool:
360     """Format file on stdin. Return True if changed.
361
362     If `write_back` is True, write reformatted code back to stdout.
363     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
364     """
365     src = sys.stdin.read()
366     dst = src
367     try:
368         dst = format_file_contents(src, line_length=line_length, fast=fast)
369         return True
370
371     except NothingChanged:
372         return False
373
374     finally:
375         if write_back == WriteBack.YES:
376             sys.stdout.write(dst)
377         elif write_back == WriteBack.DIFF:
378             src_name = "<stdin>  (original)"
379             dst_name = "<stdin>  (formatted)"
380             sys.stdout.write(diff(src, dst, src_name, dst_name))
381
382
383 def format_file_contents(
384     src_contents: str, line_length: int, fast: bool
385 ) -> FileContent:
386     """Reformat contents a file and return new contents.
387
388     If `fast` is False, additionally confirm that the reformatted code is
389     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
390     `line_length` is passed to :func:`format_str`.
391     """
392     if src_contents.strip() == "":
393         raise NothingChanged
394
395     dst_contents = format_str(src_contents, line_length=line_length)
396     if src_contents == dst_contents:
397         raise NothingChanged
398
399     if not fast:
400         assert_equivalent(src_contents, dst_contents)
401         assert_stable(src_contents, dst_contents, line_length=line_length)
402     return dst_contents
403
404
405 def format_str(src_contents: str, line_length: int) -> FileContent:
406     """Reformat a string and return new contents.
407
408     `line_length` determines how many characters per line are allowed.
409     """
410     src_node = lib2to3_parse(src_contents)
411     dst_contents = ""
412     lines = LineGenerator()
413     elt = EmptyLineTracker()
414     py36 = is_python36(src_node)
415     empty_line = Line()
416     after = 0
417     for current_line in lines.visit(src_node):
418         for _ in range(after):
419             dst_contents += str(empty_line)
420         before, after = elt.maybe_empty_lines(current_line)
421         for _ in range(before):
422             dst_contents += str(empty_line)
423         for line in split_line(current_line, line_length=line_length, py36=py36):
424             dst_contents += str(line)
425     return dst_contents
426
427
428 GRAMMARS = [
429     pygram.python_grammar_no_print_statement_no_exec_statement,
430     pygram.python_grammar_no_print_statement,
431     pygram.python_grammar,
432 ]
433
434
435 def lib2to3_parse(src_txt: str) -> Node:
436     """Given a string with source, return the lib2to3 Node."""
437     grammar = pygram.python_grammar_no_print_statement
438     if src_txt[-1] != "\n":
439         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
440         src_txt += nl
441     for grammar in GRAMMARS:
442         drv = driver.Driver(grammar, pytree.convert)
443         try:
444             result = drv.parse_string(src_txt, True)
445             break
446
447         except ParseError as pe:
448             lineno, column = pe.context[1]
449             lines = src_txt.splitlines()
450             try:
451                 faulty_line = lines[lineno - 1]
452             except IndexError:
453                 faulty_line = "<line number missing in source>"
454             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
455     else:
456         raise exc from None
457
458     if isinstance(result, Leaf):
459         result = Node(syms.file_input, [result])
460     return result
461
462
463 def lib2to3_unparse(node: Node) -> str:
464     """Given a lib2to3 node, return its string representation."""
465     code = str(node)
466     return code
467
468
469 T = TypeVar("T")
470
471
472 class Visitor(Generic[T]):
473     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
474
475     def visit(self, node: LN) -> Iterator[T]:
476         """Main method to visit `node` and its children.
477
478         It tries to find a `visit_*()` method for the given `node.type`, like
479         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
480         If no dedicated `visit_*()` method is found, chooses `visit_default()`
481         instead.
482
483         Then yields objects of type `T` from the selected visitor.
484         """
485         if node.type < 256:
486             name = token.tok_name[node.type]
487         else:
488             name = type_repr(node.type)
489         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
490
491     def visit_default(self, node: LN) -> Iterator[T]:
492         """Default `visit_*()` implementation. Recurses to children of `node`."""
493         if isinstance(node, Node):
494             for child in node.children:
495                 yield from self.visit(child)
496
497
498 @dataclass
499 class DebugVisitor(Visitor[T]):
500     tree_depth: int = 0
501
502     def visit_default(self, node: LN) -> Iterator[T]:
503         indent = " " * (2 * self.tree_depth)
504         if isinstance(node, Node):
505             _type = type_repr(node.type)
506             out(f"{indent}{_type}", fg="yellow")
507             self.tree_depth += 1
508             for child in node.children:
509                 yield from self.visit(child)
510
511             self.tree_depth -= 1
512             out(f"{indent}/{_type}", fg="yellow", bold=False)
513         else:
514             _type = token.tok_name.get(node.type, str(node.type))
515             out(f"{indent}{_type}", fg="blue", nl=False)
516             if node.prefix:
517                 # We don't have to handle prefixes for `Node` objects since
518                 # that delegates to the first child anyway.
519                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
520             out(f" {node.value!r}", fg="blue", bold=False)
521
522     @classmethod
523     def show(cls, code: str) -> None:
524         """Pretty-print the lib2to3 AST of a given string of `code`.
525
526         Convenience method for debugging.
527         """
528         v: DebugVisitor[None] = DebugVisitor()
529         list(v.visit(lib2to3_parse(code)))
530
531
532 KEYWORDS = set(keyword.kwlist)
533 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
534 FLOW_CONTROL = {"return", "raise", "break", "continue"}
535 STATEMENT = {
536     syms.if_stmt,
537     syms.while_stmt,
538     syms.for_stmt,
539     syms.try_stmt,
540     syms.except_clause,
541     syms.with_stmt,
542     syms.funcdef,
543     syms.classdef,
544 }
545 STANDALONE_COMMENT = 153
546 LOGIC_OPERATORS = {"and", "or"}
547 COMPARATORS = {
548     token.LESS,
549     token.GREATER,
550     token.EQEQUAL,
551     token.NOTEQUAL,
552     token.LESSEQUAL,
553     token.GREATEREQUAL,
554 }
555 MATH_OPERATORS = {
556     token.VBAR,
557     token.CIRCUMFLEX,
558     token.AMPER,
559     token.LEFTSHIFT,
560     token.RIGHTSHIFT,
561     token.PLUS,
562     token.MINUS,
563     token.STAR,
564     token.SLASH,
565     token.DOUBLESLASH,
566     token.PERCENT,
567     token.AT,
568     token.TILDE,
569     token.DOUBLESTAR,
570 }
571 STARS = {token.STAR, token.DOUBLESTAR}
572 VARARGS_PARENTS = {
573     syms.arglist,
574     syms.argument,  # double star in arglist
575     syms.trailer,  # single argument to call
576     syms.typedargslist,
577     syms.varargslist,  # lambdas
578 }
579 UNPACKING_PARENTS = {
580     syms.atom,  # single element of a list or set literal
581     syms.dictsetmaker,
582     syms.listmaker,
583     syms.testlist_gexp,
584 }
585 TEST_DESCENDANTS = {
586     syms.test,
587     syms.lambdef,
588     syms.or_test,
589     syms.and_test,
590     syms.not_test,
591     syms.comparison,
592     syms.star_expr,
593     syms.expr,
594     syms.xor_expr,
595     syms.and_expr,
596     syms.shift_expr,
597     syms.arith_expr,
598     syms.trailer,
599     syms.term,
600     syms.power,
601 }
602 ASSIGNMENTS = {
603     "=",
604     "+=",
605     "-=",
606     "*=",
607     "@=",
608     "/=",
609     "%=",
610     "&=",
611     "|=",
612     "^=",
613     "<<=",
614     ">>=",
615     "**=",
616     "//=",
617 }
618 COMPREHENSION_PRIORITY = 20
619 COMMA_PRIORITY = 18
620 TERNARY_PRIORITY = 16
621 LOGIC_PRIORITY = 14
622 STRING_PRIORITY = 12
623 COMPARATOR_PRIORITY = 10
624 MATH_PRIORITIES = {
625     token.VBAR: 8,
626     token.CIRCUMFLEX: 7,
627     token.AMPER: 6,
628     token.LEFTSHIFT: 5,
629     token.RIGHTSHIFT: 5,
630     token.PLUS: 4,
631     token.MINUS: 4,
632     token.STAR: 3,
633     token.SLASH: 3,
634     token.DOUBLESLASH: 3,
635     token.PERCENT: 3,
636     token.AT: 3,
637     token.TILDE: 2,
638     token.DOUBLESTAR: 1,
639 }
640
641
642 @dataclass
643 class BracketTracker:
644     """Keeps track of brackets on a line."""
645
646     depth: int = 0
647     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
648     delimiters: Dict[LeafID, Priority] = Factory(dict)
649     previous: Optional[Leaf] = None
650     _for_loop_variable: int = 0
651     _lambda_arguments: int = 0
652
653     def mark(self, leaf: Leaf) -> None:
654         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
655
656         All leaves receive an int `bracket_depth` field that stores how deep
657         within brackets a given leaf is. 0 means there are no enclosing brackets
658         that started on this line.
659
660         If a leaf is itself a closing bracket, it receives an `opening_bracket`
661         field that it forms a pair with. This is a one-directional link to
662         avoid reference cycles.
663
664         If a leaf is a delimiter (a token on which Black can split the line if
665         needed) and it's on depth 0, its `id()` is stored in the tracker's
666         `delimiters` field.
667         """
668         if leaf.type == token.COMMENT:
669             return
670
671         self.maybe_decrement_after_for_loop_variable(leaf)
672         self.maybe_decrement_after_lambda_arguments(leaf)
673         if leaf.type in CLOSING_BRACKETS:
674             self.depth -= 1
675             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
676             leaf.opening_bracket = opening_bracket
677         leaf.bracket_depth = self.depth
678         if self.depth == 0:
679             delim = is_split_before_delimiter(leaf, self.previous)
680             if delim and self.previous is not None:
681                 self.delimiters[id(self.previous)] = delim
682             else:
683                 delim = is_split_after_delimiter(leaf, self.previous)
684                 if delim:
685                     self.delimiters[id(leaf)] = delim
686         if leaf.type in OPENING_BRACKETS:
687             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
688             self.depth += 1
689         self.previous = leaf
690         self.maybe_increment_lambda_arguments(leaf)
691         self.maybe_increment_for_loop_variable(leaf)
692
693     def any_open_brackets(self) -> bool:
694         """Return True if there is an yet unmatched open bracket on the line."""
695         return bool(self.bracket_match)
696
697     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
698         """Return the highest priority of a delimiter found on the line.
699
700         Values are consistent with what `is_split_*_delimiter()` return.
701         Raises ValueError on no delimiters.
702         """
703         return max(v for k, v in self.delimiters.items() if k not in exclude)
704
705     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
706         """In a for loop, or comprehension, the variables are often unpacks.
707
708         To avoid splitting on the comma in this situation, increase the depth of
709         tokens between `for` and `in`.
710         """
711         if leaf.type == token.NAME and leaf.value == "for":
712             self.depth += 1
713             self._for_loop_variable += 1
714             return True
715
716         return False
717
718     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
719         """See `maybe_increment_for_loop_variable` above for explanation."""
720         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
721             self.depth -= 1
722             self._for_loop_variable -= 1
723             return True
724
725         return False
726
727     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
728         """In a lambda expression, there might be more than one argument.
729
730         To avoid splitting on the comma in this situation, increase the depth of
731         tokens between `lambda` and `:`.
732         """
733         if leaf.type == token.NAME and leaf.value == "lambda":
734             self.depth += 1
735             self._lambda_arguments += 1
736             return True
737
738         return False
739
740     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
741         """See `maybe_increment_lambda_arguments` above for explanation."""
742         if self._lambda_arguments and leaf.type == token.COLON:
743             self.depth -= 1
744             self._lambda_arguments -= 1
745             return True
746
747         return False
748
749     def get_open_lsqb(self) -> Optional[Leaf]:
750         """Return the most recent opening square bracket (if any)."""
751         return self.bracket_match.get((self.depth - 1, token.RSQB))
752
753
754 @dataclass
755 class Line:
756     """Holds leaves and comments. Can be printed with `str(line)`."""
757
758     depth: int = 0
759     leaves: List[Leaf] = Factory(list)
760     comments: List[Tuple[Index, Leaf]] = Factory(list)
761     bracket_tracker: BracketTracker = Factory(BracketTracker)
762     inside_brackets: bool = False
763
764     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
765         """Add a new `leaf` to the end of the line.
766
767         Unless `preformatted` is True, the `leaf` will receive a new consistent
768         whitespace prefix and metadata applied by :class:`BracketTracker`.
769         Trailing commas are maybe removed, unpacked for loop variables are
770         demoted from being delimiters.
771
772         Inline comments are put aside.
773         """
774         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
775         if not has_value:
776             return
777
778         if token.COLON == leaf.type and self.is_class_paren_empty:
779             del self.leaves[-2:]
780         if self.leaves and not preformatted:
781             # Note: at this point leaf.prefix should be empty except for
782             # imports, for which we only preserve newlines.
783             leaf.prefix += whitespace(
784                 leaf, complex_subscript=self.is_complex_subscript(leaf)
785             )
786         if self.inside_brackets or not preformatted:
787             self.bracket_tracker.mark(leaf)
788             self.maybe_remove_trailing_comma(leaf)
789         if not self.append_comment(leaf):
790             self.leaves.append(leaf)
791
792     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
793         """Like :func:`append()` but disallow invalid standalone comment structure.
794
795         Raises ValueError when any `leaf` is appended after a standalone comment
796         or when a standalone comment is not the first leaf on the line.
797         """
798         if self.bracket_tracker.depth == 0:
799             if self.is_comment:
800                 raise ValueError("cannot append to standalone comments")
801
802             if self.leaves and leaf.type == STANDALONE_COMMENT:
803                 raise ValueError(
804                     "cannot append standalone comments to a populated line"
805                 )
806
807         self.append(leaf, preformatted=preformatted)
808
809     @property
810     def is_comment(self) -> bool:
811         """Is this line a standalone comment?"""
812         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
813
814     @property
815     def is_decorator(self) -> bool:
816         """Is this line a decorator?"""
817         return bool(self) and self.leaves[0].type == token.AT
818
819     @property
820     def is_import(self) -> bool:
821         """Is this an import line?"""
822         return bool(self) and is_import(self.leaves[0])
823
824     @property
825     def is_class(self) -> bool:
826         """Is this line a class definition?"""
827         return (
828             bool(self)
829             and self.leaves[0].type == token.NAME
830             and self.leaves[0].value == "class"
831         )
832
833     @property
834     def is_def(self) -> bool:
835         """Is this a function definition? (Also returns True for async defs.)"""
836         try:
837             first_leaf = self.leaves[0]
838         except IndexError:
839             return False
840
841         try:
842             second_leaf: Optional[Leaf] = self.leaves[1]
843         except IndexError:
844             second_leaf = None
845         return (
846             (first_leaf.type == token.NAME and first_leaf.value == "def")
847             or (
848                 first_leaf.type == token.ASYNC
849                 and second_leaf is not None
850                 and second_leaf.type == token.NAME
851                 and second_leaf.value == "def"
852             )
853         )
854
855     @property
856     def is_flow_control(self) -> bool:
857         """Is this line a flow control statement?
858
859         Those are `return`, `raise`, `break`, and `continue`.
860         """
861         return (
862             bool(self)
863             and self.leaves[0].type == token.NAME
864             and self.leaves[0].value in FLOW_CONTROL
865         )
866
867     @property
868     def is_yield(self) -> bool:
869         """Is this line a yield statement?"""
870         return (
871             bool(self)
872             and self.leaves[0].type == token.NAME
873             and self.leaves[0].value == "yield"
874         )
875
876     @property
877     def is_class_paren_empty(self) -> bool:
878         """Is this a class with no base classes but using parentheses?
879
880         Those are unnecessary and should be removed.
881         """
882         return (
883             bool(self)
884             and len(self.leaves) == 4
885             and self.is_class
886             and self.leaves[2].type == token.LPAR
887             and self.leaves[2].value == "("
888             and self.leaves[3].type == token.RPAR
889             and self.leaves[3].value == ")"
890         )
891
892     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
893         """If so, needs to be split before emitting."""
894         for leaf in self.leaves:
895             if leaf.type == STANDALONE_COMMENT:
896                 if leaf.bracket_depth <= depth_limit:
897                     return True
898
899         return False
900
901     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
902         """Remove trailing comma if there is one and it's safe."""
903         if not (
904             self.leaves
905             and self.leaves[-1].type == token.COMMA
906             and closing.type in CLOSING_BRACKETS
907         ):
908             return False
909
910         if closing.type == token.RBRACE:
911             self.remove_trailing_comma()
912             return True
913
914         if closing.type == token.RSQB:
915             comma = self.leaves[-1]
916             if comma.parent and comma.parent.type == syms.listmaker:
917                 self.remove_trailing_comma()
918                 return True
919
920         # For parens let's check if it's safe to remove the comma.
921         # Imports are always safe.
922         if self.is_import:
923             self.remove_trailing_comma()
924             return True
925
926         # Otheriwsse, if the trailing one is the only one, we might mistakenly
927         # change a tuple into a different type by removing the comma.
928         depth = closing.bracket_depth + 1
929         commas = 0
930         opening = closing.opening_bracket
931         for _opening_index, leaf in enumerate(self.leaves):
932             if leaf is opening:
933                 break
934
935         else:
936             return False
937
938         for leaf in self.leaves[_opening_index + 1 :]:
939             if leaf is closing:
940                 break
941
942             bracket_depth = leaf.bracket_depth
943             if bracket_depth == depth and leaf.type == token.COMMA:
944                 commas += 1
945                 if leaf.parent and leaf.parent.type == syms.arglist:
946                     commas += 1
947                     break
948
949         if commas > 1:
950             self.remove_trailing_comma()
951             return True
952
953         return False
954
955     def append_comment(self, comment: Leaf) -> bool:
956         """Add an inline or standalone comment to the line."""
957         if (
958             comment.type == STANDALONE_COMMENT
959             and self.bracket_tracker.any_open_brackets()
960         ):
961             comment.prefix = ""
962             return False
963
964         if comment.type != token.COMMENT:
965             return False
966
967         after = len(self.leaves) - 1
968         if after == -1:
969             comment.type = STANDALONE_COMMENT
970             comment.prefix = ""
971             return False
972
973         else:
974             self.comments.append((after, comment))
975             return True
976
977     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
978         """Generate comments that should appear directly after `leaf`."""
979         for _leaf_index, _leaf in enumerate(self.leaves):
980             if leaf is _leaf:
981                 break
982
983         else:
984             return
985
986         for index, comment_after in self.comments:
987             if _leaf_index == index:
988                 yield comment_after
989
990     def remove_trailing_comma(self) -> None:
991         """Remove the trailing comma and moves the comments attached to it."""
992         comma_index = len(self.leaves) - 1
993         for i in range(len(self.comments)):
994             comment_index, comment = self.comments[i]
995             if comment_index == comma_index:
996                 self.comments[i] = (comma_index - 1, comment)
997         self.leaves.pop()
998
999     def is_complex_subscript(self, leaf: Leaf) -> bool:
1000         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1001         open_lsqb = (
1002             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1003         )
1004         if open_lsqb is None:
1005             return False
1006
1007         subscript_start = open_lsqb.next_sibling
1008         if (
1009             isinstance(subscript_start, Node)
1010             and subscript_start.type == syms.subscriptlist
1011         ):
1012             subscript_start = child_towards(subscript_start, leaf)
1013         return (
1014             subscript_start is not None
1015             and any(n.type in TEST_DESCENDANTS for n in subscript_start.pre_order())
1016         )
1017
1018     def __str__(self) -> str:
1019         """Render the line."""
1020         if not self:
1021             return "\n"
1022
1023         indent = "    " * self.depth
1024         leaves = iter(self.leaves)
1025         first = next(leaves)
1026         res = f"{first.prefix}{indent}{first.value}"
1027         for leaf in leaves:
1028             res += str(leaf)
1029         for _, comment in self.comments:
1030             res += str(comment)
1031         return res + "\n"
1032
1033     def __bool__(self) -> bool:
1034         """Return True if the line has leaves or comments."""
1035         return bool(self.leaves or self.comments)
1036
1037
1038 class UnformattedLines(Line):
1039     """Just like :class:`Line` but stores lines which aren't reformatted."""
1040
1041     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1042         """Just add a new `leaf` to the end of the lines.
1043
1044         The `preformatted` argument is ignored.
1045
1046         Keeps track of indentation `depth`, which is useful when the user
1047         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1048         """
1049         try:
1050             list(generate_comments(leaf))
1051         except FormatOn as f_on:
1052             self.leaves.append(f_on.leaf_from_consumed(leaf))
1053             raise
1054
1055         self.leaves.append(leaf)
1056         if leaf.type == token.INDENT:
1057             self.depth += 1
1058         elif leaf.type == token.DEDENT:
1059             self.depth -= 1
1060
1061     def __str__(self) -> str:
1062         """Render unformatted lines from leaves which were added with `append()`.
1063
1064         `depth` is not used for indentation in this case.
1065         """
1066         if not self:
1067             return "\n"
1068
1069         res = ""
1070         for leaf in self.leaves:
1071             res += str(leaf)
1072         return res
1073
1074     def append_comment(self, comment: Leaf) -> bool:
1075         """Not implemented in this class. Raises `NotImplementedError`."""
1076         raise NotImplementedError("Unformatted lines don't store comments separately.")
1077
1078     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1079         """Does nothing and returns False."""
1080         return False
1081
1082     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1083         """Does nothing and returns False."""
1084         return False
1085
1086
1087 @dataclass
1088 class EmptyLineTracker:
1089     """Provides a stateful method that returns the number of potential extra
1090     empty lines needed before and after the currently processed line.
1091
1092     Note: this tracker works on lines that haven't been split yet.  It assumes
1093     the prefix of the first leaf consists of optional newlines.  Those newlines
1094     are consumed by `maybe_empty_lines()` and included in the computation.
1095     """
1096     previous_line: Optional[Line] = None
1097     previous_after: int = 0
1098     previous_defs: List[int] = Factory(list)
1099
1100     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1101         """Return the number of extra empty lines before and after the `current_line`.
1102
1103         This is for separating `def`, `async def` and `class` with extra empty
1104         lines (two on module-level), as well as providing an extra empty line
1105         after flow control keywords to make them more prominent.
1106         """
1107         if isinstance(current_line, UnformattedLines):
1108             return 0, 0
1109
1110         before, after = self._maybe_empty_lines(current_line)
1111         before -= self.previous_after
1112         self.previous_after = after
1113         self.previous_line = current_line
1114         return before, after
1115
1116     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1117         max_allowed = 1
1118         if current_line.depth == 0:
1119             max_allowed = 2
1120         if current_line.leaves:
1121             # Consume the first leaf's extra newlines.
1122             first_leaf = current_line.leaves[0]
1123             before = first_leaf.prefix.count("\n")
1124             before = min(before, max_allowed)
1125             first_leaf.prefix = ""
1126         else:
1127             before = 0
1128         depth = current_line.depth
1129         while self.previous_defs and self.previous_defs[-1] >= depth:
1130             self.previous_defs.pop()
1131             before = 1 if depth else 2
1132         is_decorator = current_line.is_decorator
1133         if is_decorator or current_line.is_def or current_line.is_class:
1134             if not is_decorator:
1135                 self.previous_defs.append(depth)
1136             if self.previous_line is None:
1137                 # Don't insert empty lines before the first line in the file.
1138                 return 0, 0
1139
1140             if self.previous_line.is_decorator:
1141                 return 0, 0
1142
1143             if (
1144                 self.previous_line.is_comment
1145                 and self.previous_line.depth == current_line.depth
1146                 and before == 0
1147             ):
1148                 return 0, 0
1149
1150             newlines = 2
1151             if current_line.depth:
1152                 newlines -= 1
1153             return newlines, 0
1154
1155         if (
1156             self.previous_line
1157             and self.previous_line.is_import
1158             and not current_line.is_import
1159             and depth == self.previous_line.depth
1160         ):
1161             return (before or 1), 0
1162
1163         return before, 0
1164
1165
1166 @dataclass
1167 class LineGenerator(Visitor[Line]):
1168     """Generates reformatted Line objects.  Empty lines are not emitted.
1169
1170     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1171     in ways that will no longer stringify to valid Python code on the tree.
1172     """
1173     current_line: Line = Factory(Line)
1174
1175     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1176         """Generate a line.
1177
1178         If the line is empty, only emit if it makes sense.
1179         If the line is too long, split it first and then generate.
1180
1181         If any lines were generated, set up a new current_line.
1182         """
1183         if not self.current_line:
1184             if self.current_line.__class__ == type:
1185                 self.current_line.depth += indent
1186             else:
1187                 self.current_line = type(depth=self.current_line.depth + indent)
1188             return  # Line is empty, don't emit. Creating a new one unnecessary.
1189
1190         complete_line = self.current_line
1191         self.current_line = type(depth=complete_line.depth + indent)
1192         yield complete_line
1193
1194     def visit(self, node: LN) -> Iterator[Line]:
1195         """Main method to visit `node` and its children.
1196
1197         Yields :class:`Line` objects.
1198         """
1199         if isinstance(self.current_line, UnformattedLines):
1200             # File contained `# fmt: off`
1201             yield from self.visit_unformatted(node)
1202
1203         else:
1204             yield from super().visit(node)
1205
1206     def visit_default(self, node: LN) -> Iterator[Line]:
1207         """Default `visit_*()` implementation. Recurses to children of `node`."""
1208         if isinstance(node, Leaf):
1209             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1210             try:
1211                 for comment in generate_comments(node):
1212                     if any_open_brackets:
1213                         # any comment within brackets is subject to splitting
1214                         self.current_line.append(comment)
1215                     elif comment.type == token.COMMENT:
1216                         # regular trailing comment
1217                         self.current_line.append(comment)
1218                         yield from self.line()
1219
1220                     else:
1221                         # regular standalone comment
1222                         yield from self.line()
1223
1224                         self.current_line.append(comment)
1225                         yield from self.line()
1226
1227             except FormatOff as f_off:
1228                 f_off.trim_prefix(node)
1229                 yield from self.line(type=UnformattedLines)
1230                 yield from self.visit(node)
1231
1232             except FormatOn as f_on:
1233                 # This only happens here if somebody says "fmt: on" multiple
1234                 # times in a row.
1235                 f_on.trim_prefix(node)
1236                 yield from self.visit_default(node)
1237
1238             else:
1239                 normalize_prefix(node, inside_brackets=any_open_brackets)
1240                 if node.type == token.STRING:
1241                     normalize_string_quotes(node)
1242                 if node.type not in WHITESPACE:
1243                     self.current_line.append(node)
1244         yield from super().visit_default(node)
1245
1246     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1247         """Increase indentation level, maybe yield a line."""
1248         # In blib2to3 INDENT never holds comments.
1249         yield from self.line(+1)
1250         yield from self.visit_default(node)
1251
1252     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1253         """Decrease indentation level, maybe yield a line."""
1254         # The current line might still wait for trailing comments.  At DEDENT time
1255         # there won't be any (they would be prefixes on the preceding NEWLINE).
1256         # Emit the line then.
1257         yield from self.line()
1258
1259         # While DEDENT has no value, its prefix may contain standalone comments
1260         # that belong to the current indentation level.  Get 'em.
1261         yield from self.visit_default(node)
1262
1263         # Finally, emit the dedent.
1264         yield from self.line(-1)
1265
1266     def visit_stmt(
1267         self, node: Node, keywords: Set[str], parens: Set[str]
1268     ) -> Iterator[Line]:
1269         """Visit a statement.
1270
1271         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1272         `def`, `with`, `class`, `assert` and assignments.
1273
1274         The relevant Python language `keywords` for a given statement will be
1275         NAME leaves within it. This methods puts those on a separate line.
1276
1277         `parens` holds a set of string leaf values immeditely after which
1278         invisible parens should be put.
1279         """
1280         normalize_invisible_parens(node, parens_after=parens)
1281         for child in node.children:
1282             if child.type == token.NAME and child.value in keywords:  # type: ignore
1283                 yield from self.line()
1284
1285             yield from self.visit(child)
1286
1287     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1288         """Visit a statement without nested statements."""
1289         is_suite_like = node.parent and node.parent.type in STATEMENT
1290         if is_suite_like:
1291             yield from self.line(+1)
1292             yield from self.visit_default(node)
1293             yield from self.line(-1)
1294
1295         else:
1296             yield from self.line()
1297             yield from self.visit_default(node)
1298
1299     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1300         """Visit `async def`, `async for`, `async with`."""
1301         yield from self.line()
1302
1303         children = iter(node.children)
1304         for child in children:
1305             yield from self.visit(child)
1306
1307             if child.type == token.ASYNC:
1308                 break
1309
1310         internal_stmt = next(children)
1311         for child in internal_stmt.children:
1312             yield from self.visit(child)
1313
1314     def visit_decorators(self, node: Node) -> Iterator[Line]:
1315         """Visit decorators."""
1316         for child in node.children:
1317             yield from self.line()
1318             yield from self.visit(child)
1319
1320     def visit_import_from(self, node: Node) -> Iterator[Line]:
1321         """Visit import_from and maybe put invisible parentheses.
1322
1323         This is separate from `visit_stmt` because import statements don't
1324         support arbitrary atoms and thus handling of parentheses is custom.
1325         """
1326         check_lpar = False
1327         for index, child in enumerate(node.children):
1328             if check_lpar:
1329                 if child.type == token.LPAR:
1330                     # make parentheses invisible
1331                     child.value = ""  # type: ignore
1332                     node.children[-1].value = ""  # type: ignore
1333                 else:
1334                     # insert invisible parentheses
1335                     node.insert_child(index, Leaf(token.LPAR, ""))
1336                     node.append_child(Leaf(token.RPAR, ""))
1337                 break
1338
1339             check_lpar = (
1340                 child.type == token.NAME and child.value == "import"  # type: ignore
1341             )
1342
1343         for child in node.children:
1344             yield from self.visit(child)
1345
1346     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1347         """Remove a semicolon and put the other statement on a separate line."""
1348         yield from self.line()
1349
1350     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1351         """End of file. Process outstanding comments and end with a newline."""
1352         yield from self.visit_default(leaf)
1353         yield from self.line()
1354
1355     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1356         """Used when file contained a `# fmt: off`."""
1357         if isinstance(node, Node):
1358             for child in node.children:
1359                 yield from self.visit(child)
1360
1361         else:
1362             try:
1363                 self.current_line.append(node)
1364             except FormatOn as f_on:
1365                 f_on.trim_prefix(node)
1366                 yield from self.line()
1367                 yield from self.visit(node)
1368
1369             if node.type == token.ENDMARKER:
1370                 # somebody decided not to put a final `# fmt: on`
1371                 yield from self.line()
1372
1373     def __attrs_post_init__(self) -> None:
1374         """You are in a twisty little maze of passages."""
1375         v = self.visit_stmt
1376         Ø: Set[str] = set()
1377         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1378         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1379         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1380         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1381         self.visit_try_stmt = partial(
1382             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1383         )
1384         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1385         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1386         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1387         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1388         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1389         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1390         self.visit_async_funcdef = self.visit_async_stmt
1391         self.visit_decorated = self.visit_decorators
1392
1393
1394 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1395 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1396 OPENING_BRACKETS = set(BRACKET.keys())
1397 CLOSING_BRACKETS = set(BRACKET.values())
1398 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1399 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1400
1401
1402 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1403     """Return whitespace prefix if needed for the given `leaf`.
1404
1405     `complex_subscript` signals whether the given leaf is part of a subscription
1406     which has non-trivial arguments, like arithmetic expressions or function calls.
1407     """
1408     NO = ""
1409     SPACE = " "
1410     DOUBLESPACE = "  "
1411     t = leaf.type
1412     p = leaf.parent
1413     v = leaf.value
1414     if t in ALWAYS_NO_SPACE:
1415         return NO
1416
1417     if t == token.COMMENT:
1418         return DOUBLESPACE
1419
1420     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1421     if (
1422         t == token.COLON
1423         and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop}
1424     ):
1425         return NO
1426
1427     prev = leaf.prev_sibling
1428     if not prev:
1429         prevp = preceding_leaf(p)
1430         if not prevp or prevp.type in OPENING_BRACKETS:
1431             return NO
1432
1433         if t == token.COLON:
1434             if prevp.type == token.COLON:
1435                 return NO
1436
1437             elif prevp.type != token.COMMA and not complex_subscript:
1438                 return NO
1439
1440             return SPACE
1441
1442         if prevp.type == token.EQUAL:
1443             if prevp.parent:
1444                 if prevp.parent.type in {
1445                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1446                 }:
1447                     return NO
1448
1449                 elif prevp.parent.type == syms.typedargslist:
1450                     # A bit hacky: if the equal sign has whitespace, it means we
1451                     # previously found it's a typed argument.  So, we're using
1452                     # that, too.
1453                     return prevp.prefix
1454
1455         elif prevp.type in STARS:
1456             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1457                 return NO
1458
1459         elif prevp.type == token.COLON:
1460             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1461                 return SPACE if complex_subscript else NO
1462
1463         elif (
1464             prevp.parent
1465             and prevp.parent.type == syms.factor
1466             and prevp.type in MATH_OPERATORS
1467         ):
1468             return NO
1469
1470         elif (
1471             prevp.type == token.RIGHTSHIFT
1472             and prevp.parent
1473             and prevp.parent.type == syms.shift_expr
1474             and prevp.prev_sibling
1475             and prevp.prev_sibling.type == token.NAME
1476             and prevp.prev_sibling.value == "print"  # type: ignore
1477         ):
1478             # Python 2 print chevron
1479             return NO
1480
1481     elif prev.type in OPENING_BRACKETS:
1482         return NO
1483
1484     if p.type in {syms.parameters, syms.arglist}:
1485         # untyped function signatures or calls
1486         if not prev or prev.type != token.COMMA:
1487             return NO
1488
1489     elif p.type == syms.varargslist:
1490         # lambdas
1491         if prev and prev.type != token.COMMA:
1492             return NO
1493
1494     elif p.type == syms.typedargslist:
1495         # typed function signatures
1496         if not prev:
1497             return NO
1498
1499         if t == token.EQUAL:
1500             if prev.type != syms.tname:
1501                 return NO
1502
1503         elif prev.type == token.EQUAL:
1504             # A bit hacky: if the equal sign has whitespace, it means we
1505             # previously found it's a typed argument.  So, we're using that, too.
1506             return prev.prefix
1507
1508         elif prev.type != token.COMMA:
1509             return NO
1510
1511     elif p.type == syms.tname:
1512         # type names
1513         if not prev:
1514             prevp = preceding_leaf(p)
1515             if not prevp or prevp.type != token.COMMA:
1516                 return NO
1517
1518     elif p.type == syms.trailer:
1519         # attributes and calls
1520         if t == token.LPAR or t == token.RPAR:
1521             return NO
1522
1523         if not prev:
1524             if t == token.DOT:
1525                 prevp = preceding_leaf(p)
1526                 if not prevp or prevp.type != token.NUMBER:
1527                     return NO
1528
1529             elif t == token.LSQB:
1530                 return NO
1531
1532         elif prev.type != token.COMMA:
1533             return NO
1534
1535     elif p.type == syms.argument:
1536         # single argument
1537         if t == token.EQUAL:
1538             return NO
1539
1540         if not prev:
1541             prevp = preceding_leaf(p)
1542             if not prevp or prevp.type == token.LPAR:
1543                 return NO
1544
1545         elif prev.type in {token.EQUAL} | STARS:
1546             return NO
1547
1548     elif p.type == syms.decorator:
1549         # decorators
1550         return NO
1551
1552     elif p.type == syms.dotted_name:
1553         if prev:
1554             return NO
1555
1556         prevp = preceding_leaf(p)
1557         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1558             return NO
1559
1560     elif p.type == syms.classdef:
1561         if t == token.LPAR:
1562             return NO
1563
1564         if prev and prev.type == token.LPAR:
1565             return NO
1566
1567     elif p.type in {syms.subscript, syms.sliceop}:
1568         # indexing
1569         if not prev:
1570             assert p.parent is not None, "subscripts are always parented"
1571             if p.parent.type == syms.subscriptlist:
1572                 return SPACE
1573
1574             return NO
1575
1576         elif not complex_subscript:
1577             return NO
1578
1579     elif p.type == syms.atom:
1580         if prev and t == token.DOT:
1581             # dots, but not the first one.
1582             return NO
1583
1584     elif p.type == syms.dictsetmaker:
1585         # dict unpacking
1586         if prev and prev.type == token.DOUBLESTAR:
1587             return NO
1588
1589     elif p.type in {syms.factor, syms.star_expr}:
1590         # unary ops
1591         if not prev:
1592             prevp = preceding_leaf(p)
1593             if not prevp or prevp.type in OPENING_BRACKETS:
1594                 return NO
1595
1596             prevp_parent = prevp.parent
1597             assert prevp_parent is not None
1598             if (
1599                 prevp.type == token.COLON
1600                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1601             ):
1602                 return NO
1603
1604             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1605                 return NO
1606
1607         elif t == token.NAME or t == token.NUMBER:
1608             return NO
1609
1610     elif p.type == syms.import_from:
1611         if t == token.DOT:
1612             if prev and prev.type == token.DOT:
1613                 return NO
1614
1615         elif t == token.NAME:
1616             if v == "import":
1617                 return SPACE
1618
1619             if prev and prev.type == token.DOT:
1620                 return NO
1621
1622     elif p.type == syms.sliceop:
1623         return NO
1624
1625     return SPACE
1626
1627
1628 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1629     """Return the first leaf that precedes `node`, if any."""
1630     while node:
1631         res = node.prev_sibling
1632         if res:
1633             if isinstance(res, Leaf):
1634                 return res
1635
1636             try:
1637                 return list(res.leaves())[-1]
1638
1639             except IndexError:
1640                 return None
1641
1642         node = node.parent
1643     return None
1644
1645
1646 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1647     """Return the child of `ancestor` that contains `descendant`."""
1648     node: Optional[LN] = descendant
1649     while node and node.parent != ancestor:
1650         node = node.parent
1651     return node
1652
1653
1654 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1655     """Return the priority of the `leaf` delimiter, given a line break after it.
1656
1657     The delimiter priorities returned here are from those delimiters that would
1658     cause a line break after themselves.
1659
1660     Higher numbers are higher priority.
1661     """
1662     if leaf.type == token.COMMA:
1663         return COMMA_PRIORITY
1664
1665     return 0
1666
1667
1668 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1669     """Return the priority of the `leaf` delimiter, given a line before after it.
1670
1671     The delimiter priorities returned here are from those delimiters that would
1672     cause a line break before themselves.
1673
1674     Higher numbers are higher priority.
1675     """
1676     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1677         # * and ** might also be MATH_OPERATORS but in this case they are not.
1678         # Don't treat them as a delimiter.
1679         return 0
1680
1681     if (
1682         leaf.type in MATH_OPERATORS
1683         and leaf.parent
1684         and leaf.parent.type not in {syms.factor, syms.star_expr}
1685     ):
1686         return MATH_PRIORITIES[leaf.type]
1687
1688     if leaf.type in COMPARATORS:
1689         return COMPARATOR_PRIORITY
1690
1691     if (
1692         leaf.type == token.STRING
1693         and previous is not None
1694         and previous.type == token.STRING
1695     ):
1696         return STRING_PRIORITY
1697
1698     if (
1699         leaf.type == token.NAME
1700         and leaf.value == "for"
1701         and leaf.parent
1702         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1703     ):
1704         return COMPREHENSION_PRIORITY
1705
1706     if (
1707         leaf.type == token.NAME
1708         and leaf.value == "if"
1709         and leaf.parent
1710         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1711     ):
1712         return COMPREHENSION_PRIORITY
1713
1714     if (
1715         leaf.type == token.NAME
1716         and leaf.value in {"if", "else"}
1717         and leaf.parent
1718         and leaf.parent.type == syms.test
1719     ):
1720         return TERNARY_PRIORITY
1721
1722     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1723         return LOGIC_PRIORITY
1724
1725     return 0
1726
1727
1728 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1729     """Clean the prefix of the `leaf` and generate comments from it, if any.
1730
1731     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1732     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1733     move because it does away with modifying the grammar to include all the
1734     possible places in which comments can be placed.
1735
1736     The sad consequence for us though is that comments don't "belong" anywhere.
1737     This is why this function generates simple parentless Leaf objects for
1738     comments.  We simply don't know what the correct parent should be.
1739
1740     No matter though, we can live without this.  We really only need to
1741     differentiate between inline and standalone comments.  The latter don't
1742     share the line with any code.
1743
1744     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1745     are emitted with a fake STANDALONE_COMMENT token identifier.
1746     """
1747     p = leaf.prefix
1748     if not p:
1749         return
1750
1751     if "#" not in p:
1752         return
1753
1754     consumed = 0
1755     nlines = 0
1756     for index, line in enumerate(p.split("\n")):
1757         consumed += len(line) + 1  # adding the length of the split '\n'
1758         line = line.lstrip()
1759         if not line:
1760             nlines += 1
1761         if not line.startswith("#"):
1762             continue
1763
1764         if index == 0 and leaf.type != token.ENDMARKER:
1765             comment_type = token.COMMENT  # simple trailing comment
1766         else:
1767             comment_type = STANDALONE_COMMENT
1768         comment = make_comment(line)
1769         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1770
1771         if comment in {"# fmt: on", "# yapf: enable"}:
1772             raise FormatOn(consumed)
1773
1774         if comment in {"# fmt: off", "# yapf: disable"}:
1775             if comment_type == STANDALONE_COMMENT:
1776                 raise FormatOff(consumed)
1777
1778             prev = preceding_leaf(leaf)
1779             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1780                 raise FormatOff(consumed)
1781
1782         nlines = 0
1783
1784
1785 def make_comment(content: str) -> str:
1786     """Return a consistently formatted comment from the given `content` string.
1787
1788     All comments (except for "##", "#!", "#:") should have a single space between
1789     the hash sign and the content.
1790
1791     If `content` didn't start with a hash sign, one is provided.
1792     """
1793     content = content.rstrip()
1794     if not content:
1795         return "#"
1796
1797     if content[0] == "#":
1798         content = content[1:]
1799     if content and content[0] not in " !:#":
1800         content = " " + content
1801     return "#" + content
1802
1803
1804 def split_line(
1805     line: Line, line_length: int, inner: bool = False, py36: bool = False
1806 ) -> Iterator[Line]:
1807     """Split a `line` into potentially many lines.
1808
1809     They should fit in the allotted `line_length` but might not be able to.
1810     `inner` signifies that there were a pair of brackets somewhere around the
1811     current `line`, possibly transitively. This means we can fallback to splitting
1812     by delimiters if the LHS/RHS don't yield any results.
1813
1814     If `py36` is True, splitting may generate syntax that is only compatible
1815     with Python 3.6 and later.
1816     """
1817     if isinstance(line, UnformattedLines) or line.is_comment:
1818         yield line
1819         return
1820
1821     line_str = str(line).strip("\n")
1822     if (
1823         len(line_str) <= line_length
1824         and "\n" not in line_str  # multiline strings
1825         and not line.contains_standalone_comments()
1826     ):
1827         yield line
1828         return
1829
1830     split_funcs: List[SplitFunc]
1831     if line.is_def:
1832         split_funcs = [left_hand_split]
1833     elif line.is_import:
1834         split_funcs = [explode_split]
1835     elif line.inside_brackets:
1836         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1837     else:
1838         split_funcs = [right_hand_split]
1839     for split_func in split_funcs:
1840         # We are accumulating lines in `result` because we might want to abort
1841         # mission and return the original line in the end, or attempt a different
1842         # split altogether.
1843         result: List[Line] = []
1844         try:
1845             for l in split_func(line, py36):
1846                 if str(l).strip("\n") == line_str:
1847                     raise CannotSplit("Split function returned an unchanged result")
1848
1849                 result.extend(
1850                     split_line(l, line_length=line_length, inner=True, py36=py36)
1851                 )
1852         except CannotSplit as cs:
1853             continue
1854
1855         else:
1856             yield from result
1857             break
1858
1859     else:
1860         yield line
1861
1862
1863 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1864     """Split line into many lines, starting with the first matching bracket pair.
1865
1866     Note: this usually looks weird, only use this for function definitions.
1867     Prefer RHS otherwise.  This is why this function is not symmetrical with
1868     :func:`right_hand_split` which also handles optional parentheses.
1869     """
1870     head = Line(depth=line.depth)
1871     body = Line(depth=line.depth + 1, inside_brackets=True)
1872     tail = Line(depth=line.depth)
1873     tail_leaves: List[Leaf] = []
1874     body_leaves: List[Leaf] = []
1875     head_leaves: List[Leaf] = []
1876     current_leaves = head_leaves
1877     matching_bracket = None
1878     for leaf in line.leaves:
1879         if (
1880             current_leaves is body_leaves
1881             and leaf.type in CLOSING_BRACKETS
1882             and leaf.opening_bracket is matching_bracket
1883         ):
1884             current_leaves = tail_leaves if body_leaves else head_leaves
1885         current_leaves.append(leaf)
1886         if current_leaves is head_leaves:
1887             if leaf.type in OPENING_BRACKETS:
1888                 matching_bracket = leaf
1889                 current_leaves = body_leaves
1890     # Since body is a new indent level, remove spurious leading whitespace.
1891     if body_leaves:
1892         normalize_prefix(body_leaves[0], inside_brackets=True)
1893     # Build the new lines.
1894     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1895         for leaf in leaves:
1896             result.append(leaf, preformatted=True)
1897             for comment_after in line.comments_after(leaf):
1898                 result.append(comment_after, preformatted=True)
1899     bracket_split_succeeded_or_raise(head, body, tail)
1900     for result in (head, body, tail):
1901         if result:
1902             yield result
1903
1904
1905 def right_hand_split(
1906     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1907 ) -> Iterator[Line]:
1908     """Split line into many lines, starting with the last matching bracket pair.
1909
1910     If the split was by optional parentheses, attempt splitting without them, too.
1911     """
1912     head = Line(depth=line.depth)
1913     body = Line(depth=line.depth + 1, inside_brackets=True)
1914     tail = Line(depth=line.depth)
1915     tail_leaves: List[Leaf] = []
1916     body_leaves: List[Leaf] = []
1917     head_leaves: List[Leaf] = []
1918     current_leaves = tail_leaves
1919     opening_bracket = None
1920     closing_bracket = None
1921     for leaf in reversed(line.leaves):
1922         if current_leaves is body_leaves:
1923             if leaf is opening_bracket:
1924                 current_leaves = head_leaves if body_leaves else tail_leaves
1925         current_leaves.append(leaf)
1926         if current_leaves is tail_leaves:
1927             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1928                 opening_bracket = leaf.opening_bracket
1929                 closing_bracket = leaf
1930                 current_leaves = body_leaves
1931     tail_leaves.reverse()
1932     body_leaves.reverse()
1933     head_leaves.reverse()
1934     # Since body is a new indent level, remove spurious leading whitespace.
1935     if body_leaves:
1936         normalize_prefix(body_leaves[0], inside_brackets=True)
1937     elif not head_leaves:
1938         # No `head` and no `body` means the split failed. `tail` has all content.
1939         raise CannotSplit("No brackets found")
1940
1941     # Build the new lines.
1942     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1943         for leaf in leaves:
1944             result.append(leaf, preformatted=True)
1945             for comment_after in line.comments_after(leaf):
1946                 result.append(comment_after, preformatted=True)
1947     bracket_split_succeeded_or_raise(head, body, tail)
1948     assert opening_bracket and closing_bracket
1949     if (
1950         # the opening bracket is an optional paren
1951         opening_bracket.type == token.LPAR
1952         and not opening_bracket.value
1953         # the closing bracket is an optional paren
1954         and closing_bracket.type == token.RPAR
1955         and not closing_bracket.value
1956         # there are no delimiters or standalone comments in the body
1957         and not body.bracket_tracker.delimiters
1958         and not line.contains_standalone_comments(0)
1959         # and it's not an import (optional parens are the only thing we can split
1960         # on in this case; attempting a split without them is a waste of time)
1961         and not line.is_import
1962     ):
1963         omit = {id(closing_bracket), *omit}
1964         try:
1965             yield from right_hand_split(line, py36=py36, omit=omit)
1966             return
1967         except CannotSplit:
1968             pass
1969
1970     ensure_visible(opening_bracket)
1971     ensure_visible(closing_bracket)
1972     for result in (head, body, tail):
1973         if result:
1974             yield result
1975
1976
1977 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1978     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1979
1980     Do nothing otherwise.
1981
1982     A left- or right-hand split is based on a pair of brackets. Content before
1983     (and including) the opening bracket is left on one line, content inside the
1984     brackets is put on a separate line, and finally content starting with and
1985     following the closing bracket is put on a separate line.
1986
1987     Those are called `head`, `body`, and `tail`, respectively. If the split
1988     produced the same line (all content in `head`) or ended up with an empty `body`
1989     and the `tail` is just the closing bracket, then it's considered failed.
1990     """
1991     tail_len = len(str(tail).strip())
1992     if not body:
1993         if tail_len == 0:
1994             raise CannotSplit("Splitting brackets produced the same line")
1995
1996         elif tail_len < 3:
1997             raise CannotSplit(
1998                 f"Splitting brackets on an empty body to save "
1999                 f"{tail_len} characters is not worth it"
2000             )
2001
2002
2003 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2004     """Normalize prefix of the first leaf in every line returned by `split_func`.
2005
2006     This is a decorator over relevant split functions.
2007     """
2008
2009     @wraps(split_func)
2010     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2011         for l in split_func(line, py36):
2012             normalize_prefix(l.leaves[0], inside_brackets=True)
2013             yield l
2014
2015     return split_wrapper
2016
2017
2018 @dont_increase_indentation
2019 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2020     """Split according to delimiters of the highest priority.
2021
2022     If `py36` is True, the split will add trailing commas also in function
2023     signatures that contain `*` and `**`.
2024     """
2025     try:
2026         last_leaf = line.leaves[-1]
2027     except IndexError:
2028         raise CannotSplit("Line empty")
2029
2030     delimiters = line.bracket_tracker.delimiters
2031     try:
2032         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
2033             exclude={id(last_leaf)}
2034         )
2035     except ValueError:
2036         raise CannotSplit("No delimiters found")
2037
2038     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2039     lowest_depth = sys.maxsize
2040     trailing_comma_safe = True
2041
2042     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2043         """Append `leaf` to current line or to new line if appending impossible."""
2044         nonlocal current_line
2045         try:
2046             current_line.append_safe(leaf, preformatted=True)
2047         except ValueError as ve:
2048             yield current_line
2049
2050             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2051             current_line.append(leaf)
2052
2053     for leaf in line.leaves:
2054         yield from append_to_line(leaf)
2055
2056         for comment_after in line.comments_after(leaf):
2057             yield from append_to_line(comment_after)
2058
2059         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2060         if (
2061             leaf.bracket_depth == lowest_depth
2062             and is_vararg(leaf, within=VARARGS_PARENTS)
2063         ):
2064             trailing_comma_safe = trailing_comma_safe and py36
2065         leaf_priority = delimiters.get(id(leaf))
2066         if leaf_priority == delimiter_priority:
2067             yield current_line
2068
2069             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2070     if current_line:
2071         if (
2072             trailing_comma_safe
2073             and delimiter_priority == COMMA_PRIORITY
2074             and current_line.leaves[-1].type != token.COMMA
2075             and current_line.leaves[-1].type != STANDALONE_COMMENT
2076         ):
2077             current_line.append(Leaf(token.COMMA, ","))
2078         yield current_line
2079
2080
2081 @dont_increase_indentation
2082 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2083     """Split standalone comments from the rest of the line."""
2084     if not line.contains_standalone_comments(0):
2085         raise CannotSplit("Line does not have any standalone comments")
2086
2087     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2088
2089     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2090         """Append `leaf` to current line or to new line if appending impossible."""
2091         nonlocal current_line
2092         try:
2093             current_line.append_safe(leaf, preformatted=True)
2094         except ValueError as ve:
2095             yield current_line
2096
2097             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2098             current_line.append(leaf)
2099
2100     for leaf in line.leaves:
2101         yield from append_to_line(leaf)
2102
2103         for comment_after in line.comments_after(leaf):
2104             yield from append_to_line(comment_after)
2105
2106     if current_line:
2107         yield current_line
2108
2109
2110 def explode_split(
2111     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
2112 ) -> Iterator[Line]:
2113     """Split by rightmost bracket and immediately split contents by a delimiter."""
2114     new_lines = list(right_hand_split(line, py36, omit))
2115     if len(new_lines) != 3:
2116         yield from new_lines
2117         return
2118
2119     yield new_lines[0]
2120
2121     try:
2122         yield from delimiter_split(new_lines[1], py36)
2123
2124     except CannotSplit:
2125         yield new_lines[1]
2126
2127     yield new_lines[2]
2128
2129
2130 def is_import(leaf: Leaf) -> bool:
2131     """Return True if the given leaf starts an import statement."""
2132     p = leaf.parent
2133     t = leaf.type
2134     v = leaf.value
2135     return bool(
2136         t == token.NAME
2137         and (
2138             (v == "import" and p and p.type == syms.import_name)
2139             or (v == "from" and p and p.type == syms.import_from)
2140         )
2141     )
2142
2143
2144 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2145     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2146     else.
2147
2148     Note: don't use backslashes for formatting or you'll lose your voting rights.
2149     """
2150     if not inside_brackets:
2151         spl = leaf.prefix.split("#")
2152         if "\\" not in spl[0]:
2153             nl_count = spl[-1].count("\n")
2154             if len(spl) > 1:
2155                 nl_count -= 1
2156             leaf.prefix = "\n" * nl_count
2157             return
2158
2159     leaf.prefix = ""
2160
2161
2162 def normalize_string_quotes(leaf: Leaf) -> None:
2163     """Prefer double quotes but only if it doesn't cause more escaping.
2164
2165     Adds or removes backslashes as appropriate. Doesn't parse and fix
2166     strings nested in f-strings (yet).
2167
2168     Note: Mutates its argument.
2169     """
2170     value = leaf.value.lstrip("furbFURB")
2171     if value[:3] == '"""':
2172         return
2173
2174     elif value[:3] == "'''":
2175         orig_quote = "'''"
2176         new_quote = '"""'
2177     elif value[0] == '"':
2178         orig_quote = '"'
2179         new_quote = "'"
2180     else:
2181         orig_quote = "'"
2182         new_quote = '"'
2183     first_quote_pos = leaf.value.find(orig_quote)
2184     if first_quote_pos == -1:
2185         return  # There's an internal error
2186
2187     prefix = leaf.value[:first_quote_pos]
2188     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2189     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2190     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2191     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2192     if "r" in prefix.casefold():
2193         if unescaped_new_quote.search(body):
2194             # There's at least one unescaped new_quote in this raw string
2195             # so converting is impossible
2196             return
2197
2198         # Do not introduce or remove backslashes in raw strings
2199         new_body = body
2200     else:
2201         # remove unnecessary quotes
2202         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2203         if body != new_body:
2204             # Consider the string without unnecessary quotes as the original
2205             body = new_body
2206             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2207         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2208         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2209     if new_quote == '"""' and new_body[-1] == '"':
2210         # edge case:
2211         new_body = new_body[:-1] + '\\"'
2212     orig_escape_count = body.count("\\")
2213     new_escape_count = new_body.count("\\")
2214     if new_escape_count > orig_escape_count:
2215         return  # Do not introduce more escaping
2216
2217     if new_escape_count == orig_escape_count and orig_quote == '"':
2218         return  # Prefer double quotes
2219
2220     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2221
2222
2223 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2224     """Make existing optional parentheses invisible or create new ones.
2225
2226     `parens_after` is a set of string leaf values immeditely after which parens
2227     should be put.
2228
2229     Standardizes on visible parentheses for single-element tuples, and keeps
2230     existing visible parentheses for other tuples and generator expressions.
2231     """
2232     check_lpar = False
2233     for child in list(node.children):
2234         if check_lpar:
2235             if child.type == syms.atom:
2236                 maybe_make_parens_invisible_in_atom(child)
2237             elif is_one_tuple(child):
2238                 # wrap child in visible parentheses
2239                 lpar = Leaf(token.LPAR, "(")
2240                 rpar = Leaf(token.RPAR, ")")
2241                 index = child.remove() or 0
2242                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2243             else:
2244                 # wrap child in invisible parentheses
2245                 lpar = Leaf(token.LPAR, "")
2246                 rpar = Leaf(token.RPAR, "")
2247                 index = child.remove() or 0
2248                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2249
2250         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2251
2252
2253 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2254     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2255     if (
2256         node.type != syms.atom
2257         or is_empty_tuple(node)
2258         or is_one_tuple(node)
2259         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2260     ):
2261         return False
2262
2263     first = node.children[0]
2264     last = node.children[-1]
2265     if first.type == token.LPAR and last.type == token.RPAR:
2266         # make parentheses invisible
2267         first.value = ""  # type: ignore
2268         last.value = ""  # type: ignore
2269         if len(node.children) > 1:
2270             maybe_make_parens_invisible_in_atom(node.children[1])
2271         return True
2272
2273     return False
2274
2275
2276 def is_empty_tuple(node: LN) -> bool:
2277     """Return True if `node` holds an empty tuple."""
2278     return (
2279         node.type == syms.atom
2280         and len(node.children) == 2
2281         and node.children[0].type == token.LPAR
2282         and node.children[1].type == token.RPAR
2283     )
2284
2285
2286 def is_one_tuple(node: LN) -> bool:
2287     """Return True if `node` holds a tuple with one element, with or without parens."""
2288     if node.type == syms.atom:
2289         if len(node.children) != 3:
2290             return False
2291
2292         lpar, gexp, rpar = node.children
2293         if not (
2294             lpar.type == token.LPAR
2295             and gexp.type == syms.testlist_gexp
2296             and rpar.type == token.RPAR
2297         ):
2298             return False
2299
2300         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2301
2302     return (
2303         node.type in IMPLICIT_TUPLE
2304         and len(node.children) == 2
2305         and node.children[1].type == token.COMMA
2306     )
2307
2308
2309 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2310     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2311
2312     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2313     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2314     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2315     generalizations (PEP 448).
2316     """
2317     if leaf.type not in STARS or not leaf.parent:
2318         return False
2319
2320     p = leaf.parent
2321     if p.type == syms.star_expr:
2322         # Star expressions are also used as assignment targets in extended
2323         # iterable unpacking (PEP 3132).  See what its parent is instead.
2324         if not p.parent:
2325             return False
2326
2327         p = p.parent
2328
2329     return p.type in within
2330
2331
2332 def max_delimiter_priority_in_atom(node: LN) -> int:
2333     """Return maximum delimiter priority inside `node`.
2334
2335     This is specific to atoms with contents contained in a pair of parentheses.
2336     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2337     """
2338     if node.type != syms.atom:
2339         return 0
2340
2341     first = node.children[0]
2342     last = node.children[-1]
2343     if not (first.type == token.LPAR and last.type == token.RPAR):
2344         return 0
2345
2346     bt = BracketTracker()
2347     for c in node.children[1:-1]:
2348         if isinstance(c, Leaf):
2349             bt.mark(c)
2350         else:
2351             for leaf in c.leaves():
2352                 bt.mark(leaf)
2353     try:
2354         return bt.max_delimiter_priority()
2355
2356     except ValueError:
2357         return 0
2358
2359
2360 def ensure_visible(leaf: Leaf) -> None:
2361     """Make sure parentheses are visible.
2362
2363     They could be invisible as part of some statements (see
2364     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2365     """
2366     if leaf.type == token.LPAR:
2367         leaf.value = "("
2368     elif leaf.type == token.RPAR:
2369         leaf.value = ")"
2370
2371
2372 def is_python36(node: Node) -> bool:
2373     """Return True if the current file is using Python 3.6+ features.
2374
2375     Currently looking for:
2376     - f-strings; and
2377     - trailing commas after * or ** in function signatures and calls.
2378     """
2379     for n in node.pre_order():
2380         if n.type == token.STRING:
2381             value_head = n.value[:2]  # type: ignore
2382             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2383                 return True
2384
2385         elif (
2386             n.type in {syms.typedargslist, syms.arglist}
2387             and n.children
2388             and n.children[-1].type == token.COMMA
2389         ):
2390             for ch in n.children:
2391                 if ch.type in STARS:
2392                     return True
2393
2394                 if ch.type == syms.argument:
2395                     for argch in ch.children:
2396                         if argch.type in STARS:
2397                             return True
2398
2399     return False
2400
2401
2402 PYTHON_EXTENSIONS = {".py"}
2403 BLACKLISTED_DIRECTORIES = {
2404     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2405 }
2406
2407
2408 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2409     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2410     and have one of the PYTHON_EXTENSIONS.
2411     """
2412     for child in path.iterdir():
2413         if child.is_dir():
2414             if child.name in BLACKLISTED_DIRECTORIES:
2415                 continue
2416
2417             yield from gen_python_files_in_dir(child)
2418
2419         elif child.suffix in PYTHON_EXTENSIONS:
2420             yield child
2421
2422
2423 @dataclass
2424 class Report:
2425     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2426     check: bool = False
2427     quiet: bool = False
2428     change_count: int = 0
2429     same_count: int = 0
2430     failure_count: int = 0
2431
2432     def done(self, src: Path, changed: Changed) -> None:
2433         """Increment the counter for successful reformatting. Write out a message."""
2434         if changed is Changed.YES:
2435             reformatted = "would reformat" if self.check else "reformatted"
2436             if not self.quiet:
2437                 out(f"{reformatted} {src}")
2438             self.change_count += 1
2439         else:
2440             if not self.quiet:
2441                 if changed is Changed.NO:
2442                     msg = f"{src} already well formatted, good job."
2443                 else:
2444                     msg = f"{src} wasn't modified on disk since last run."
2445                 out(msg, bold=False)
2446             self.same_count += 1
2447
2448     def failed(self, src: Path, message: str) -> None:
2449         """Increment the counter for failed reformatting. Write out a message."""
2450         err(f"error: cannot format {src}: {message}")
2451         self.failure_count += 1
2452
2453     @property
2454     def return_code(self) -> int:
2455         """Return the exit code that the app should use.
2456
2457         This considers the current state of changed files and failures:
2458         - if there were any failures, return 123;
2459         - if any files were changed and --check is being used, return 1;
2460         - otherwise return 0.
2461         """
2462         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2463         # 126 we have special returncodes reserved by the shell.
2464         if self.failure_count:
2465             return 123
2466
2467         elif self.change_count and self.check:
2468             return 1
2469
2470         return 0
2471
2472     def __str__(self) -> str:
2473         """Render a color report of the current state.
2474
2475         Use `click.unstyle` to remove colors.
2476         """
2477         if self.check:
2478             reformatted = "would be reformatted"
2479             unchanged = "would be left unchanged"
2480             failed = "would fail to reformat"
2481         else:
2482             reformatted = "reformatted"
2483             unchanged = "left unchanged"
2484             failed = "failed to reformat"
2485         report = []
2486         if self.change_count:
2487             s = "s" if self.change_count > 1 else ""
2488             report.append(
2489                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2490             )
2491         if self.same_count:
2492             s = "s" if self.same_count > 1 else ""
2493             report.append(f"{self.same_count} file{s} {unchanged}")
2494         if self.failure_count:
2495             s = "s" if self.failure_count > 1 else ""
2496             report.append(
2497                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2498             )
2499         return ", ".join(report) + "."
2500
2501
2502 def assert_equivalent(src: str, dst: str) -> None:
2503     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2504
2505     import ast
2506     import traceback
2507
2508     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2509         """Simple visitor generating strings to compare ASTs by content."""
2510         yield f"{'  ' * depth}{node.__class__.__name__}("
2511
2512         for field in sorted(node._fields):
2513             try:
2514                 value = getattr(node, field)
2515             except AttributeError:
2516                 continue
2517
2518             yield f"{'  ' * (depth+1)}{field}="
2519
2520             if isinstance(value, list):
2521                 for item in value:
2522                     if isinstance(item, ast.AST):
2523                         yield from _v(item, depth + 2)
2524
2525             elif isinstance(value, ast.AST):
2526                 yield from _v(value, depth + 2)
2527
2528             else:
2529                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2530
2531         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2532
2533     try:
2534         src_ast = ast.parse(src)
2535     except Exception as exc:
2536         major, minor = sys.version_info[:2]
2537         raise AssertionError(
2538             f"cannot use --safe with this file; failed to parse source file "
2539             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2540             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2541         )
2542
2543     try:
2544         dst_ast = ast.parse(dst)
2545     except Exception as exc:
2546         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2547         raise AssertionError(
2548             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2549             f"Please report a bug on https://github.com/ambv/black/issues.  "
2550             f"This invalid output might be helpful: {log}"
2551         ) from None
2552
2553     src_ast_str = "\n".join(_v(src_ast))
2554     dst_ast_str = "\n".join(_v(dst_ast))
2555     if src_ast_str != dst_ast_str:
2556         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2557         raise AssertionError(
2558             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2559             f"the source.  "
2560             f"Please report a bug on https://github.com/ambv/black/issues.  "
2561             f"This diff might be helpful: {log}"
2562         ) from None
2563
2564
2565 def assert_stable(src: str, dst: str, line_length: int) -> None:
2566     """Raise AssertionError if `dst` reformats differently the second time."""
2567     newdst = format_str(dst, line_length=line_length)
2568     if dst != newdst:
2569         log = dump_to_file(
2570             diff(src, dst, "source", "first pass"),
2571             diff(dst, newdst, "first pass", "second pass"),
2572         )
2573         raise AssertionError(
2574             f"INTERNAL ERROR: Black produced different code on the second pass "
2575             f"of the formatter.  "
2576             f"Please report a bug on https://github.com/ambv/black/issues.  "
2577             f"This diff might be helpful: {log}"
2578         ) from None
2579
2580
2581 def dump_to_file(*output: str) -> str:
2582     """Dump `output` to a temporary file. Return path to the file."""
2583     import tempfile
2584
2585     with tempfile.NamedTemporaryFile(
2586         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2587     ) as f:
2588         for lines in output:
2589             f.write(lines)
2590             if lines and lines[-1] != "\n":
2591                 f.write("\n")
2592     return f.name
2593
2594
2595 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2596     """Return a unified diff string between strings `a` and `b`."""
2597     import difflib
2598
2599     a_lines = [line + "\n" for line in a.split("\n")]
2600     b_lines = [line + "\n" for line in b.split("\n")]
2601     return "".join(
2602         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2603     )
2604
2605
2606 def cancel(tasks: List[asyncio.Task]) -> None:
2607     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2608     err("Aborted!")
2609     for task in tasks:
2610         task.cancel()
2611
2612
2613 def shutdown(loop: BaseEventLoop) -> None:
2614     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2615     try:
2616         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2617         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2618         if not to_cancel:
2619             return
2620
2621         for task in to_cancel:
2622             task.cancel()
2623         loop.run_until_complete(
2624             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2625         )
2626     finally:
2627         # `concurrent.futures.Future` objects cannot be cancelled once they
2628         # are already running. There might be some when the `shutdown()` happened.
2629         # Silence their logger's spew about the event loop being closed.
2630         cf_logger = logging.getLogger("concurrent.futures")
2631         cf_logger.setLevel(logging.CRITICAL)
2632         loop.close()
2633
2634
2635 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2636     """Replace `regex` with `replacement` twice on `original`.
2637
2638     This is used by string normalization to perform replaces on
2639     overlapping matches.
2640     """
2641     return regex.sub(replacement, regex.sub(replacement, original))
2642
2643
2644 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2645
2646
2647 def get_cache_file(line_length: int) -> Path:
2648     return CACHE_DIR / f"cache.{line_length}.pickle"
2649
2650
2651 def read_cache(line_length: int) -> Cache:
2652     """Read the cache if it exists and is well formed.
2653
2654     If it is not well formed, the call to write_cache later should resolve the issue.
2655     """
2656     cache_file = get_cache_file(line_length)
2657     if not cache_file.exists():
2658         return {}
2659
2660     with cache_file.open("rb") as fobj:
2661         try:
2662             cache: Cache = pickle.load(fobj)
2663         except pickle.UnpicklingError:
2664             return {}
2665
2666     return cache
2667
2668
2669 def get_cache_info(path: Path) -> CacheInfo:
2670     """Return the information used to check if a file is already formatted or not."""
2671     stat = path.stat()
2672     return stat.st_mtime, stat.st_size
2673
2674
2675 def filter_cached(
2676     cache: Cache, sources: Iterable[Path]
2677 ) -> Tuple[List[Path], List[Path]]:
2678     """Split a list of paths into two.
2679
2680     The first list contains paths of files that modified on disk or are not in the
2681     cache. The other list contains paths to non-modified files.
2682     """
2683     todo, done = [], []
2684     for src in sources:
2685         src = src.resolve()
2686         if cache.get(src) != get_cache_info(src):
2687             todo.append(src)
2688         else:
2689             done.append(src)
2690     return todo, done
2691
2692
2693 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2694     """Update the cache file."""
2695     cache_file = get_cache_file(line_length)
2696     try:
2697         if not CACHE_DIR.exists():
2698             CACHE_DIR.mkdir(parents=True)
2699         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2700         with cache_file.open("wb") as fobj:
2701             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2702     except OSError:
2703         pass
2704
2705
2706 if __name__ == "__main__":
2707     main()