]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Automatic management of parentheses in `elif`, too
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Set,
28     Tuple,
29     Type,
30     TypeVar,
31     Union,
32 )
33
34 from appdirs import user_cache_dir
35 from attr import dataclass, Factory
36 import click
37
38 # lib2to3 fork
39 from blib2to3.pytree import Node, Leaf, type_repr
40 from blib2to3 import pygram, pytree
41 from blib2to3.pgen2 import driver, token
42 from blib2to3.pgen2.parse import ParseError
43
44 __version__ = "18.4a6"
45 DEFAULT_LINE_LENGTH = 88
46
47 # types
48 syms = pygram.python_symbols
49 FileContent = str
50 Encoding = str
51 Depth = int
52 NodeType = int
53 LeafID = int
54 Priority = int
55 Index = int
56 LN = Union[Leaf, Node]
57 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
58 Timestamp = float
59 FileSize = int
60 CacheInfo = Tuple[Timestamp, FileSize]
61 Cache = Dict[Path, CacheInfo]
62 out = partial(click.secho, bold=True, err=True)
63 err = partial(click.secho, fg="red", err=True)
64
65
66 class NothingChanged(UserWarning):
67     """Raised by :func:`format_file` when reformatted code is the same as source."""
68
69
70 class CannotSplit(Exception):
71     """A readable split that fits the allotted line length is impossible.
72
73     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
74     :func:`delimiter_split`.
75     """
76
77
78 class FormatError(Exception):
79     """Base exception for `# fmt: on` and `# fmt: off` handling.
80
81     It holds the number of bytes of the prefix consumed before the format
82     control comment appeared.
83     """
84
85     def __init__(self, consumed: int) -> None:
86         super().__init__(consumed)
87         self.consumed = consumed
88
89     def trim_prefix(self, leaf: Leaf) -> None:
90         leaf.prefix = leaf.prefix[self.consumed :]
91
92     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
93         """Returns a new Leaf from the consumed part of the prefix."""
94         unformatted_prefix = leaf.prefix[: self.consumed]
95         return Leaf(token.NEWLINE, unformatted_prefix)
96
97
98 class FormatOn(FormatError):
99     """Found a comment like `# fmt: on` in the file."""
100
101
102 class FormatOff(FormatError):
103     """Found a comment like `# fmt: off` in the file."""
104
105
106 class WriteBack(Enum):
107     NO = 0
108     YES = 1
109     DIFF = 2
110
111
112 class Changed(Enum):
113     NO = 0
114     CACHED = 1
115     YES = 2
116
117
118 @click.command()
119 @click.option(
120     "-l",
121     "--line-length",
122     type=int,
123     default=DEFAULT_LINE_LENGTH,
124     help="How many character per line to allow.",
125     show_default=True,
126 )
127 @click.option(
128     "--check",
129     is_flag=True,
130     help=(
131         "Don't write the files back, just return the status.  Return code 0 "
132         "means nothing would change.  Return code 1 means some files would be "
133         "reformatted.  Return code 123 means there was an internal error."
134     ),
135 )
136 @click.option(
137     "--diff",
138     is_flag=True,
139     help="Don't write the files back, just output a diff for each file on stdout.",
140 )
141 @click.option(
142     "--fast/--safe",
143     is_flag=True,
144     help="If --fast given, skip temporary sanity checks. [default: --safe]",
145 )
146 @click.option(
147     "-q",
148     "--quiet",
149     is_flag=True,
150     help=(
151         "Don't emit non-error messages to stderr. Errors are still emitted, "
152         "silence those with 2>/dev/null."
153     ),
154 )
155 @click.version_option(version=__version__)
156 @click.argument(
157     "src",
158     nargs=-1,
159     type=click.Path(
160         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
161     ),
162 )
163 @click.pass_context
164 def main(
165     ctx: click.Context,
166     line_length: int,
167     check: bool,
168     diff: bool,
169     fast: bool,
170     quiet: bool,
171     src: List[str],
172 ) -> None:
173     """The uncompromising code formatter."""
174     sources: List[Path] = []
175     for s in src:
176         p = Path(s)
177         if p.is_dir():
178             sources.extend(gen_python_files_in_dir(p))
179         elif p.is_file():
180             # if a file was explicitly given, we don't care about its extension
181             sources.append(p)
182         elif s == "-":
183             sources.append(Path("-"))
184         else:
185             err(f"invalid path: {s}")
186
187     if check and not diff:
188         write_back = WriteBack.NO
189     elif diff:
190         write_back = WriteBack.DIFF
191     else:
192         write_back = WriteBack.YES
193     report = Report(check=check, quiet=quiet)
194     if len(sources) == 0:
195         out("No paths given. Nothing to do 😴")
196         ctx.exit(0)
197         return
198
199     elif len(sources) == 1:
200         reformat_one(sources[0], line_length, fast, write_back, report)
201     else:
202         loop = asyncio.get_event_loop()
203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
204         try:
205             loop.run_until_complete(
206                 schedule_formatting(
207                     sources, line_length, fast, write_back, report, loop, executor
208                 )
209             )
210         finally:
211             shutdown(loop)
212         if not quiet:
213             out("All done! ✨ 🍰 ✨")
214             click.echo(str(report))
215     ctx.exit(report.return_code)
216
217
218 def reformat_one(
219     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
220 ) -> None:
221     """Reformat a single file under `src` without spawning child processes.
222
223     If `quiet` is True, non-error messages are not output. `line_length`,
224     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
225     """
226     try:
227         changed = Changed.NO
228         if not src.is_file() and str(src) == "-":
229             if format_stdin_to_stdout(
230                 line_length=line_length, fast=fast, write_back=write_back
231             ):
232                 changed = Changed.YES
233         else:
234             cache: Cache = {}
235             if write_back != WriteBack.DIFF:
236                 cache = read_cache(line_length)
237                 src = src.resolve()
238                 if src in cache and cache[src] == get_cache_info(src):
239                     changed = Changed.CACHED
240             if (
241                 changed is not Changed.CACHED
242                 and format_file_in_place(
243                     src, line_length=line_length, fast=fast, write_back=write_back
244                 )
245             ):
246                 changed = Changed.YES
247             if write_back == WriteBack.YES and changed is not Changed.NO:
248                 write_cache(cache, [src], line_length)
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache(line_length)
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             src: loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             )
289             for src in sources
290         }
291         _task_values = list(tasks.values())
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
294             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     if write_back == WriteBack.YES and formatted:
315         write_cache(cache, formatted, line_length)
316
317
318 def format_file_in_place(
319     src: Path,
320     line_length: int,
321     fast: bool,
322     write_back: WriteBack = WriteBack.NO,
323     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
324 ) -> bool:
325     """Format file under `src` path. Return True if changed.
326
327     If `write_back` is True, write reformatted code back to stdout.
328     `line_length` and `fast` options are passed to :func:`format_file_contents`.
329     """
330
331     with tokenize.open(src) as src_buffer:
332         src_contents = src_buffer.read()
333     try:
334         dst_contents = format_file_contents(
335             src_contents, line_length=line_length, fast=fast
336         )
337     except NothingChanged:
338         return False
339
340     if write_back == write_back.YES:
341         with open(src, "w", encoding=src_buffer.encoding) as f:
342             f.write(dst_contents)
343     elif write_back == write_back.DIFF:
344         src_name = f"{src}  (original)"
345         dst_name = f"{src}  (formatted)"
346         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347         if lock:
348             lock.acquire()
349         try:
350             sys.stdout.write(diff_contents)
351         finally:
352             if lock:
353                 lock.release()
354     return True
355
356
357 def format_stdin_to_stdout(
358     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
359 ) -> bool:
360     """Format file on stdin. Return True if changed.
361
362     If `write_back` is True, write reformatted code back to stdout.
363     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
364     """
365     src = sys.stdin.read()
366     dst = src
367     try:
368         dst = format_file_contents(src, line_length=line_length, fast=fast)
369         return True
370
371     except NothingChanged:
372         return False
373
374     finally:
375         if write_back == WriteBack.YES:
376             sys.stdout.write(dst)
377         elif write_back == WriteBack.DIFF:
378             src_name = "<stdin>  (original)"
379             dst_name = "<stdin>  (formatted)"
380             sys.stdout.write(diff(src, dst, src_name, dst_name))
381
382
383 def format_file_contents(
384     src_contents: str, line_length: int, fast: bool
385 ) -> FileContent:
386     """Reformat contents a file and return new contents.
387
388     If `fast` is False, additionally confirm that the reformatted code is
389     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
390     `line_length` is passed to :func:`format_str`.
391     """
392     if src_contents.strip() == "":
393         raise NothingChanged
394
395     dst_contents = format_str(src_contents, line_length=line_length)
396     if src_contents == dst_contents:
397         raise NothingChanged
398
399     if not fast:
400         assert_equivalent(src_contents, dst_contents)
401         assert_stable(src_contents, dst_contents, line_length=line_length)
402     return dst_contents
403
404
405 def format_str(src_contents: str, line_length: int) -> FileContent:
406     """Reformat a string and return new contents.
407
408     `line_length` determines how many characters per line are allowed.
409     """
410     src_node = lib2to3_parse(src_contents)
411     dst_contents = ""
412     lines = LineGenerator()
413     elt = EmptyLineTracker()
414     py36 = is_python36(src_node)
415     empty_line = Line()
416     after = 0
417     for current_line in lines.visit(src_node):
418         for _ in range(after):
419             dst_contents += str(empty_line)
420         before, after = elt.maybe_empty_lines(current_line)
421         for _ in range(before):
422             dst_contents += str(empty_line)
423         for line in split_line(current_line, line_length=line_length, py36=py36):
424             dst_contents += str(line)
425     return dst_contents
426
427
428 GRAMMARS = [
429     pygram.python_grammar_no_print_statement_no_exec_statement,
430     pygram.python_grammar_no_print_statement,
431     pygram.python_grammar,
432 ]
433
434
435 def lib2to3_parse(src_txt: str) -> Node:
436     """Given a string with source, return the lib2to3 Node."""
437     grammar = pygram.python_grammar_no_print_statement
438     if src_txt[-1] != "\n":
439         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
440         src_txt += nl
441     for grammar in GRAMMARS:
442         drv = driver.Driver(grammar, pytree.convert)
443         try:
444             result = drv.parse_string(src_txt, True)
445             break
446
447         except ParseError as pe:
448             lineno, column = pe.context[1]
449             lines = src_txt.splitlines()
450             try:
451                 faulty_line = lines[lineno - 1]
452             except IndexError:
453                 faulty_line = "<line number missing in source>"
454             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
455     else:
456         raise exc from None
457
458     if isinstance(result, Leaf):
459         result = Node(syms.file_input, [result])
460     return result
461
462
463 def lib2to3_unparse(node: Node) -> str:
464     """Given a lib2to3 node, return its string representation."""
465     code = str(node)
466     return code
467
468
469 T = TypeVar("T")
470
471
472 class Visitor(Generic[T]):
473     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
474
475     def visit(self, node: LN) -> Iterator[T]:
476         """Main method to visit `node` and its children.
477
478         It tries to find a `visit_*()` method for the given `node.type`, like
479         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
480         If no dedicated `visit_*()` method is found, chooses `visit_default()`
481         instead.
482
483         Then yields objects of type `T` from the selected visitor.
484         """
485         if node.type < 256:
486             name = token.tok_name[node.type]
487         else:
488             name = type_repr(node.type)
489         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
490
491     def visit_default(self, node: LN) -> Iterator[T]:
492         """Default `visit_*()` implementation. Recurses to children of `node`."""
493         if isinstance(node, Node):
494             for child in node.children:
495                 yield from self.visit(child)
496
497
498 @dataclass
499 class DebugVisitor(Visitor[T]):
500     tree_depth: int = 0
501
502     def visit_default(self, node: LN) -> Iterator[T]:
503         indent = " " * (2 * self.tree_depth)
504         if isinstance(node, Node):
505             _type = type_repr(node.type)
506             out(f"{indent}{_type}", fg="yellow")
507             self.tree_depth += 1
508             for child in node.children:
509                 yield from self.visit(child)
510
511             self.tree_depth -= 1
512             out(f"{indent}/{_type}", fg="yellow", bold=False)
513         else:
514             _type = token.tok_name.get(node.type, str(node.type))
515             out(f"{indent}{_type}", fg="blue", nl=False)
516             if node.prefix:
517                 # We don't have to handle prefixes for `Node` objects since
518                 # that delegates to the first child anyway.
519                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
520             out(f" {node.value!r}", fg="blue", bold=False)
521
522     @classmethod
523     def show(cls, code: str) -> None:
524         """Pretty-print the lib2to3 AST of a given string of `code`.
525
526         Convenience method for debugging.
527         """
528         v: DebugVisitor[None] = DebugVisitor()
529         list(v.visit(lib2to3_parse(code)))
530
531
532 KEYWORDS = set(keyword.kwlist)
533 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
534 FLOW_CONTROL = {"return", "raise", "break", "continue"}
535 STATEMENT = {
536     syms.if_stmt,
537     syms.while_stmt,
538     syms.for_stmt,
539     syms.try_stmt,
540     syms.except_clause,
541     syms.with_stmt,
542     syms.funcdef,
543     syms.classdef,
544 }
545 STANDALONE_COMMENT = 153
546 LOGIC_OPERATORS = {"and", "or"}
547 COMPARATORS = {
548     token.LESS,
549     token.GREATER,
550     token.EQEQUAL,
551     token.NOTEQUAL,
552     token.LESSEQUAL,
553     token.GREATEREQUAL,
554 }
555 MATH_OPERATORS = {
556     token.VBAR,
557     token.CIRCUMFLEX,
558     token.AMPER,
559     token.LEFTSHIFT,
560     token.RIGHTSHIFT,
561     token.PLUS,
562     token.MINUS,
563     token.STAR,
564     token.SLASH,
565     token.DOUBLESLASH,
566     token.PERCENT,
567     token.AT,
568     token.TILDE,
569     token.DOUBLESTAR,
570 }
571 STARS = {token.STAR, token.DOUBLESTAR}
572 VARARGS_PARENTS = {
573     syms.arglist,
574     syms.argument,  # double star in arglist
575     syms.trailer,  # single argument to call
576     syms.typedargslist,
577     syms.varargslist,  # lambdas
578 }
579 UNPACKING_PARENTS = {
580     syms.atom,  # single element of a list or set literal
581     syms.dictsetmaker,
582     syms.listmaker,
583     syms.testlist_gexp,
584 }
585 TEST_DESCENDANTS = {
586     syms.test,
587     syms.lambdef,
588     syms.or_test,
589     syms.and_test,
590     syms.not_test,
591     syms.comparison,
592     syms.star_expr,
593     syms.expr,
594     syms.xor_expr,
595     syms.and_expr,
596     syms.shift_expr,
597     syms.arith_expr,
598     syms.trailer,
599     syms.term,
600     syms.power,
601 }
602 ASSIGNMENTS = {
603     "=",
604     "+=",
605     "-=",
606     "*=",
607     "@=",
608     "/=",
609     "%=",
610     "&=",
611     "|=",
612     "^=",
613     "<<=",
614     ">>=",
615     "**=",
616     "//=",
617 }
618 COMPREHENSION_PRIORITY = 20
619 COMMA_PRIORITY = 18
620 TERNARY_PRIORITY = 16
621 LOGIC_PRIORITY = 14
622 STRING_PRIORITY = 12
623 COMPARATOR_PRIORITY = 10
624 MATH_PRIORITIES = {
625     token.VBAR: 8,
626     token.CIRCUMFLEX: 7,
627     token.AMPER: 6,
628     token.LEFTSHIFT: 5,
629     token.RIGHTSHIFT: 5,
630     token.PLUS: 4,
631     token.MINUS: 4,
632     token.STAR: 3,
633     token.SLASH: 3,
634     token.DOUBLESLASH: 3,
635     token.PERCENT: 3,
636     token.AT: 3,
637     token.TILDE: 2,
638     token.DOUBLESTAR: 1,
639 }
640
641
642 @dataclass
643 class BracketTracker:
644     """Keeps track of brackets on a line."""
645
646     depth: int = 0
647     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
648     delimiters: Dict[LeafID, Priority] = Factory(dict)
649     previous: Optional[Leaf] = None
650     _for_loop_variable: int = 0
651     _lambda_arguments: int = 0
652
653     def mark(self, leaf: Leaf) -> None:
654         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
655
656         All leaves receive an int `bracket_depth` field that stores how deep
657         within brackets a given leaf is. 0 means there are no enclosing brackets
658         that started on this line.
659
660         If a leaf is itself a closing bracket, it receives an `opening_bracket`
661         field that it forms a pair with. This is a one-directional link to
662         avoid reference cycles.
663
664         If a leaf is a delimiter (a token on which Black can split the line if
665         needed) and it's on depth 0, its `id()` is stored in the tracker's
666         `delimiters` field.
667         """
668         if leaf.type == token.COMMENT:
669             return
670
671         self.maybe_decrement_after_for_loop_variable(leaf)
672         self.maybe_decrement_after_lambda_arguments(leaf)
673         if leaf.type in CLOSING_BRACKETS:
674             self.depth -= 1
675             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
676             leaf.opening_bracket = opening_bracket
677         leaf.bracket_depth = self.depth
678         if self.depth == 0:
679             delim = is_split_before_delimiter(leaf, self.previous)
680             if delim and self.previous is not None:
681                 self.delimiters[id(self.previous)] = delim
682             else:
683                 delim = is_split_after_delimiter(leaf, self.previous)
684                 if delim:
685                     self.delimiters[id(leaf)] = delim
686         if leaf.type in OPENING_BRACKETS:
687             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
688             self.depth += 1
689         self.previous = leaf
690         self.maybe_increment_lambda_arguments(leaf)
691         self.maybe_increment_for_loop_variable(leaf)
692
693     def any_open_brackets(self) -> bool:
694         """Return True if there is an yet unmatched open bracket on the line."""
695         return bool(self.bracket_match)
696
697     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
698         """Return the highest priority of a delimiter found on the line.
699
700         Values are consistent with what `is_split_*_delimiter()` return.
701         Raises ValueError on no delimiters.
702         """
703         return max(v for k, v in self.delimiters.items() if k not in exclude)
704
705     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
706         """In a for loop, or comprehension, the variables are often unpacks.
707
708         To avoid splitting on the comma in this situation, increase the depth of
709         tokens between `for` and `in`.
710         """
711         if leaf.type == token.NAME and leaf.value == "for":
712             self.depth += 1
713             self._for_loop_variable += 1
714             return True
715
716         return False
717
718     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
719         """See `maybe_increment_for_loop_variable` above for explanation."""
720         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
721             self.depth -= 1
722             self._for_loop_variable -= 1
723             return True
724
725         return False
726
727     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
728         """In a lambda expression, there might be more than one argument.
729
730         To avoid splitting on the comma in this situation, increase the depth of
731         tokens between `lambda` and `:`.
732         """
733         if leaf.type == token.NAME and leaf.value == "lambda":
734             self.depth += 1
735             self._lambda_arguments += 1
736             return True
737
738         return False
739
740     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
741         """See `maybe_increment_lambda_arguments` above for explanation."""
742         if self._lambda_arguments and leaf.type == token.COLON:
743             self.depth -= 1
744             self._lambda_arguments -= 1
745             return True
746
747         return False
748
749     def get_open_lsqb(self) -> Optional[Leaf]:
750         """Return the most recent opening square bracket (if any)."""
751         return self.bracket_match.get((self.depth - 1, token.RSQB))
752
753
754 @dataclass
755 class Line:
756     """Holds leaves and comments. Can be printed with `str(line)`."""
757
758     depth: int = 0
759     leaves: List[Leaf] = Factory(list)
760     comments: List[Tuple[Index, Leaf]] = Factory(list)
761     bracket_tracker: BracketTracker = Factory(BracketTracker)
762     inside_brackets: bool = False
763
764     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
765         """Add a new `leaf` to the end of the line.
766
767         Unless `preformatted` is True, the `leaf` will receive a new consistent
768         whitespace prefix and metadata applied by :class:`BracketTracker`.
769         Trailing commas are maybe removed, unpacked for loop variables are
770         demoted from being delimiters.
771
772         Inline comments are put aside.
773         """
774         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
775         if not has_value:
776             return
777
778         if token.COLON == leaf.type and self.is_class_paren_empty:
779             del self.leaves[-2:]
780         if self.leaves and not preformatted:
781             # Note: at this point leaf.prefix should be empty except for
782             # imports, for which we only preserve newlines.
783             leaf.prefix += whitespace(
784                 leaf, complex_subscript=self.is_complex_subscript(leaf)
785             )
786         if self.inside_brackets or not preformatted:
787             self.bracket_tracker.mark(leaf)
788             self.maybe_remove_trailing_comma(leaf)
789         if not self.append_comment(leaf):
790             self.leaves.append(leaf)
791
792     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
793         """Like :func:`append()` but disallow invalid standalone comment structure.
794
795         Raises ValueError when any `leaf` is appended after a standalone comment
796         or when a standalone comment is not the first leaf on the line.
797         """
798         if self.bracket_tracker.depth == 0:
799             if self.is_comment:
800                 raise ValueError("cannot append to standalone comments")
801
802             if self.leaves and leaf.type == STANDALONE_COMMENT:
803                 raise ValueError(
804                     "cannot append standalone comments to a populated line"
805                 )
806
807         self.append(leaf, preformatted=preformatted)
808
809     @property
810     def is_comment(self) -> bool:
811         """Is this line a standalone comment?"""
812         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
813
814     @property
815     def is_decorator(self) -> bool:
816         """Is this line a decorator?"""
817         return bool(self) and self.leaves[0].type == token.AT
818
819     @property
820     def is_import(self) -> bool:
821         """Is this an import line?"""
822         return bool(self) and is_import(self.leaves[0])
823
824     @property
825     def is_class(self) -> bool:
826         """Is this line a class definition?"""
827         return (
828             bool(self)
829             and self.leaves[0].type == token.NAME
830             and self.leaves[0].value == "class"
831         )
832
833     @property
834     def is_def(self) -> bool:
835         """Is this a function definition? (Also returns True for async defs.)"""
836         try:
837             first_leaf = self.leaves[0]
838         except IndexError:
839             return False
840
841         try:
842             second_leaf: Optional[Leaf] = self.leaves[1]
843         except IndexError:
844             second_leaf = None
845         return (
846             (first_leaf.type == token.NAME and first_leaf.value == "def")
847             or (
848                 first_leaf.type == token.ASYNC
849                 and second_leaf is not None
850                 and second_leaf.type == token.NAME
851                 and second_leaf.value == "def"
852             )
853         )
854
855     @property
856     def is_flow_control(self) -> bool:
857         """Is this line a flow control statement?
858
859         Those are `return`, `raise`, `break`, and `continue`.
860         """
861         return (
862             bool(self)
863             and self.leaves[0].type == token.NAME
864             and self.leaves[0].value in FLOW_CONTROL
865         )
866
867     @property
868     def is_yield(self) -> bool:
869         """Is this line a yield statement?"""
870         return (
871             bool(self)
872             and self.leaves[0].type == token.NAME
873             and self.leaves[0].value == "yield"
874         )
875
876     @property
877     def is_class_paren_empty(self) -> bool:
878         """Is this a class with no base classes but using parentheses?
879
880         Those are unnecessary and should be removed.
881         """
882         return (
883             bool(self)
884             and len(self.leaves) == 4
885             and self.is_class
886             and self.leaves[2].type == token.LPAR
887             and self.leaves[2].value == "("
888             and self.leaves[3].type == token.RPAR
889             and self.leaves[3].value == ")"
890         )
891
892     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
893         """If so, needs to be split before emitting."""
894         for leaf in self.leaves:
895             if leaf.type == STANDALONE_COMMENT:
896                 if leaf.bracket_depth <= depth_limit:
897                     return True
898
899         return False
900
901     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
902         """Remove trailing comma if there is one and it's safe."""
903         if not (
904             self.leaves
905             and self.leaves[-1].type == token.COMMA
906             and closing.type in CLOSING_BRACKETS
907         ):
908             return False
909
910         if closing.type == token.RBRACE:
911             self.remove_trailing_comma()
912             return True
913
914         if closing.type == token.RSQB:
915             comma = self.leaves[-1]
916             if comma.parent and comma.parent.type == syms.listmaker:
917                 self.remove_trailing_comma()
918                 return True
919
920         # For parens let's check if it's safe to remove the comma.
921         # Imports are always safe.
922         if self.is_import:
923             self.remove_trailing_comma()
924             return True
925
926         # Otheriwsse, if the trailing one is the only one, we might mistakenly
927         # change a tuple into a different type by removing the comma.
928         depth = closing.bracket_depth + 1
929         commas = 0
930         opening = closing.opening_bracket
931         for _opening_index, leaf in enumerate(self.leaves):
932             if leaf is opening:
933                 break
934
935         else:
936             return False
937
938         for leaf in self.leaves[_opening_index + 1 :]:
939             if leaf is closing:
940                 break
941
942             bracket_depth = leaf.bracket_depth
943             if bracket_depth == depth and leaf.type == token.COMMA:
944                 commas += 1
945                 if leaf.parent and leaf.parent.type == syms.arglist:
946                     commas += 1
947                     break
948
949         if commas > 1:
950             self.remove_trailing_comma()
951             return True
952
953         return False
954
955     def append_comment(self, comment: Leaf) -> bool:
956         """Add an inline or standalone comment to the line."""
957         if (
958             comment.type == STANDALONE_COMMENT
959             and self.bracket_tracker.any_open_brackets()
960         ):
961             comment.prefix = ""
962             return False
963
964         if comment.type != token.COMMENT:
965             return False
966
967         after = len(self.leaves) - 1
968         if after == -1:
969             comment.type = STANDALONE_COMMENT
970             comment.prefix = ""
971             return False
972
973         else:
974             self.comments.append((after, comment))
975             return True
976
977     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
978         """Generate comments that should appear directly after `leaf`."""
979         for _leaf_index, _leaf in enumerate(self.leaves):
980             if leaf is _leaf:
981                 break
982
983         else:
984             return
985
986         for index, comment_after in self.comments:
987             if _leaf_index == index:
988                 yield comment_after
989
990     def remove_trailing_comma(self) -> None:
991         """Remove the trailing comma and moves the comments attached to it."""
992         comma_index = len(self.leaves) - 1
993         for i in range(len(self.comments)):
994             comment_index, comment = self.comments[i]
995             if comment_index == comma_index:
996                 self.comments[i] = (comma_index - 1, comment)
997         self.leaves.pop()
998
999     def is_complex_subscript(self, leaf: Leaf) -> bool:
1000         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1001         open_lsqb = (
1002             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1003         )
1004         if open_lsqb is None:
1005             return False
1006
1007         subscript_start = open_lsqb.next_sibling
1008         if (
1009             isinstance(subscript_start, Node)
1010             and subscript_start.type == syms.subscriptlist
1011         ):
1012             subscript_start = child_towards(subscript_start, leaf)
1013         return (
1014             subscript_start is not None
1015             and any(n.type in TEST_DESCENDANTS for n in subscript_start.pre_order())
1016         )
1017
1018     def __str__(self) -> str:
1019         """Render the line."""
1020         if not self:
1021             return "\n"
1022
1023         indent = "    " * self.depth
1024         leaves = iter(self.leaves)
1025         first = next(leaves)
1026         res = f"{first.prefix}{indent}{first.value}"
1027         for leaf in leaves:
1028             res += str(leaf)
1029         for _, comment in self.comments:
1030             res += str(comment)
1031         return res + "\n"
1032
1033     def __bool__(self) -> bool:
1034         """Return True if the line has leaves or comments."""
1035         return bool(self.leaves or self.comments)
1036
1037
1038 class UnformattedLines(Line):
1039     """Just like :class:`Line` but stores lines which aren't reformatted."""
1040
1041     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1042         """Just add a new `leaf` to the end of the lines.
1043
1044         The `preformatted` argument is ignored.
1045
1046         Keeps track of indentation `depth`, which is useful when the user
1047         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1048         """
1049         try:
1050             list(generate_comments(leaf))
1051         except FormatOn as f_on:
1052             self.leaves.append(f_on.leaf_from_consumed(leaf))
1053             raise
1054
1055         self.leaves.append(leaf)
1056         if leaf.type == token.INDENT:
1057             self.depth += 1
1058         elif leaf.type == token.DEDENT:
1059             self.depth -= 1
1060
1061     def __str__(self) -> str:
1062         """Render unformatted lines from leaves which were added with `append()`.
1063
1064         `depth` is not used for indentation in this case.
1065         """
1066         if not self:
1067             return "\n"
1068
1069         res = ""
1070         for leaf in self.leaves:
1071             res += str(leaf)
1072         return res
1073
1074     def append_comment(self, comment: Leaf) -> bool:
1075         """Not implemented in this class. Raises `NotImplementedError`."""
1076         raise NotImplementedError("Unformatted lines don't store comments separately.")
1077
1078     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1079         """Does nothing and returns False."""
1080         return False
1081
1082     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1083         """Does nothing and returns False."""
1084         return False
1085
1086
1087 @dataclass
1088 class EmptyLineTracker:
1089     """Provides a stateful method that returns the number of potential extra
1090     empty lines needed before and after the currently processed line.
1091
1092     Note: this tracker works on lines that haven't been split yet.  It assumes
1093     the prefix of the first leaf consists of optional newlines.  Those newlines
1094     are consumed by `maybe_empty_lines()` and included in the computation.
1095     """
1096     previous_line: Optional[Line] = None
1097     previous_after: int = 0
1098     previous_defs: List[int] = Factory(list)
1099
1100     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1101         """Return the number of extra empty lines before and after the `current_line`.
1102
1103         This is for separating `def`, `async def` and `class` with extra empty
1104         lines (two on module-level), as well as providing an extra empty line
1105         after flow control keywords to make them more prominent.
1106         """
1107         if isinstance(current_line, UnformattedLines):
1108             return 0, 0
1109
1110         before, after = self._maybe_empty_lines(current_line)
1111         before -= self.previous_after
1112         self.previous_after = after
1113         self.previous_line = current_line
1114         return before, after
1115
1116     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1117         max_allowed = 1
1118         if current_line.depth == 0:
1119             max_allowed = 2
1120         if current_line.leaves:
1121             # Consume the first leaf's extra newlines.
1122             first_leaf = current_line.leaves[0]
1123             before = first_leaf.prefix.count("\n")
1124             before = min(before, max_allowed)
1125             first_leaf.prefix = ""
1126         else:
1127             before = 0
1128         depth = current_line.depth
1129         while self.previous_defs and self.previous_defs[-1] >= depth:
1130             self.previous_defs.pop()
1131             before = 1 if depth else 2
1132         is_decorator = current_line.is_decorator
1133         if is_decorator or current_line.is_def or current_line.is_class:
1134             if not is_decorator:
1135                 self.previous_defs.append(depth)
1136             if self.previous_line is None:
1137                 # Don't insert empty lines before the first line in the file.
1138                 return 0, 0
1139
1140             if self.previous_line.is_decorator:
1141                 return 0, 0
1142
1143             if (
1144                 self.previous_line.is_comment
1145                 and self.previous_line.depth == current_line.depth
1146                 and before == 0
1147             ):
1148                 return 0, 0
1149
1150             newlines = 2
1151             if current_line.depth:
1152                 newlines -= 1
1153             return newlines, 0
1154
1155         if (
1156             self.previous_line
1157             and self.previous_line.is_import
1158             and not current_line.is_import
1159             and depth == self.previous_line.depth
1160         ):
1161             return (before or 1), 0
1162
1163         return before, 0
1164
1165
1166 @dataclass
1167 class LineGenerator(Visitor[Line]):
1168     """Generates reformatted Line objects.  Empty lines are not emitted.
1169
1170     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1171     in ways that will no longer stringify to valid Python code on the tree.
1172     """
1173     current_line: Line = Factory(Line)
1174
1175     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1176         """Generate a line.
1177
1178         If the line is empty, only emit if it makes sense.
1179         If the line is too long, split it first and then generate.
1180
1181         If any lines were generated, set up a new current_line.
1182         """
1183         if not self.current_line:
1184             if self.current_line.__class__ == type:
1185                 self.current_line.depth += indent
1186             else:
1187                 self.current_line = type(depth=self.current_line.depth + indent)
1188             return  # Line is empty, don't emit. Creating a new one unnecessary.
1189
1190         complete_line = self.current_line
1191         self.current_line = type(depth=complete_line.depth + indent)
1192         yield complete_line
1193
1194     def visit(self, node: LN) -> Iterator[Line]:
1195         """Main method to visit `node` and its children.
1196
1197         Yields :class:`Line` objects.
1198         """
1199         if isinstance(self.current_line, UnformattedLines):
1200             # File contained `# fmt: off`
1201             yield from self.visit_unformatted(node)
1202
1203         else:
1204             yield from super().visit(node)
1205
1206     def visit_default(self, node: LN) -> Iterator[Line]:
1207         """Default `visit_*()` implementation. Recurses to children of `node`."""
1208         if isinstance(node, Leaf):
1209             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1210             try:
1211                 for comment in generate_comments(node):
1212                     if any_open_brackets:
1213                         # any comment within brackets is subject to splitting
1214                         self.current_line.append(comment)
1215                     elif comment.type == token.COMMENT:
1216                         # regular trailing comment
1217                         self.current_line.append(comment)
1218                         yield from self.line()
1219
1220                     else:
1221                         # regular standalone comment
1222                         yield from self.line()
1223
1224                         self.current_line.append(comment)
1225                         yield from self.line()
1226
1227             except FormatOff as f_off:
1228                 f_off.trim_prefix(node)
1229                 yield from self.line(type=UnformattedLines)
1230                 yield from self.visit(node)
1231
1232             except FormatOn as f_on:
1233                 # This only happens here if somebody says "fmt: on" multiple
1234                 # times in a row.
1235                 f_on.trim_prefix(node)
1236                 yield from self.visit_default(node)
1237
1238             else:
1239                 normalize_prefix(node, inside_brackets=any_open_brackets)
1240                 if node.type == token.STRING:
1241                     normalize_string_quotes(node)
1242                 if node.type not in WHITESPACE:
1243                     self.current_line.append(node)
1244         yield from super().visit_default(node)
1245
1246     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1247         """Increase indentation level, maybe yield a line."""
1248         # In blib2to3 INDENT never holds comments.
1249         yield from self.line(+1)
1250         yield from self.visit_default(node)
1251
1252     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1253         """Decrease indentation level, maybe yield a line."""
1254         # The current line might still wait for trailing comments.  At DEDENT time
1255         # there won't be any (they would be prefixes on the preceding NEWLINE).
1256         # Emit the line then.
1257         yield from self.line()
1258
1259         # While DEDENT has no value, its prefix may contain standalone comments
1260         # that belong to the current indentation level.  Get 'em.
1261         yield from self.visit_default(node)
1262
1263         # Finally, emit the dedent.
1264         yield from self.line(-1)
1265
1266     def visit_stmt(
1267         self, node: Node, keywords: Set[str], parens: Set[str]
1268     ) -> Iterator[Line]:
1269         """Visit a statement.
1270
1271         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1272         `def`, `with`, `class`, `assert` and assignments.
1273
1274         The relevant Python language `keywords` for a given statement will be
1275         NAME leaves within it. This methods puts those on a separate line.
1276
1277         `parens` holds a set of string leaf values immeditely after which
1278         invisible parens should be put.
1279         """
1280         normalize_invisible_parens(node, parens_after=parens)
1281         for child in node.children:
1282             if child.type == token.NAME and child.value in keywords:  # type: ignore
1283                 yield from self.line()
1284
1285             yield from self.visit(child)
1286
1287     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1288         """Visit a statement without nested statements."""
1289         is_suite_like = node.parent and node.parent.type in STATEMENT
1290         if is_suite_like:
1291             yield from self.line(+1)
1292             yield from self.visit_default(node)
1293             yield from self.line(-1)
1294
1295         else:
1296             yield from self.line()
1297             yield from self.visit_default(node)
1298
1299     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1300         """Visit `async def`, `async for`, `async with`."""
1301         yield from self.line()
1302
1303         children = iter(node.children)
1304         for child in children:
1305             yield from self.visit(child)
1306
1307             if child.type == token.ASYNC:
1308                 break
1309
1310         internal_stmt = next(children)
1311         for child in internal_stmt.children:
1312             yield from self.visit(child)
1313
1314     def visit_decorators(self, node: Node) -> Iterator[Line]:
1315         """Visit decorators."""
1316         for child in node.children:
1317             yield from self.line()
1318             yield from self.visit(child)
1319
1320     def visit_import_from(self, node: Node) -> Iterator[Line]:
1321         """Visit import_from and maybe put invisible parentheses.
1322
1323         This is separate from `visit_stmt` because import statements don't
1324         support arbitrary atoms and thus handling of parentheses is custom.
1325         """
1326         check_lpar = False
1327         for index, child in enumerate(node.children):
1328             if check_lpar:
1329                 if child.type == token.LPAR:
1330                     # make parentheses invisible
1331                     child.value = ""  # type: ignore
1332                     node.children[-1].value = ""  # type: ignore
1333                 else:
1334                     # insert invisible parentheses
1335                     node.insert_child(index, Leaf(token.LPAR, ""))
1336                     node.append_child(Leaf(token.RPAR, ""))
1337                 break
1338
1339             check_lpar = (
1340                 child.type == token.NAME and child.value == "import"  # type: ignore
1341             )
1342
1343         for child in node.children:
1344             yield from self.visit(child)
1345
1346     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1347         """Remove a semicolon and put the other statement on a separate line."""
1348         yield from self.line()
1349
1350     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1351         """End of file. Process outstanding comments and end with a newline."""
1352         yield from self.visit_default(leaf)
1353         yield from self.line()
1354
1355     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1356         """Used when file contained a `# fmt: off`."""
1357         if isinstance(node, Node):
1358             for child in node.children:
1359                 yield from self.visit(child)
1360
1361         else:
1362             try:
1363                 self.current_line.append(node)
1364             except FormatOn as f_on:
1365                 f_on.trim_prefix(node)
1366                 yield from self.line()
1367                 yield from self.visit(node)
1368
1369             if node.type == token.ENDMARKER:
1370                 # somebody decided not to put a final `# fmt: on`
1371                 yield from self.line()
1372
1373     def __attrs_post_init__(self) -> None:
1374         """You are in a twisty little maze of passages."""
1375         v = self.visit_stmt
1376         Ø: Set[str] = set()
1377         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1378         self.visit_if_stmt = partial(
1379             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1380         )
1381         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1382         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1383         self.visit_try_stmt = partial(
1384             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1385         )
1386         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1387         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1388         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1389         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1390         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1391         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1392         self.visit_async_funcdef = self.visit_async_stmt
1393         self.visit_decorated = self.visit_decorators
1394
1395
1396 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1397 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1398 OPENING_BRACKETS = set(BRACKET.keys())
1399 CLOSING_BRACKETS = set(BRACKET.values())
1400 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1401 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1402
1403
1404 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1405     """Return whitespace prefix if needed for the given `leaf`.
1406
1407     `complex_subscript` signals whether the given leaf is part of a subscription
1408     which has non-trivial arguments, like arithmetic expressions or function calls.
1409     """
1410     NO = ""
1411     SPACE = " "
1412     DOUBLESPACE = "  "
1413     t = leaf.type
1414     p = leaf.parent
1415     v = leaf.value
1416     if t in ALWAYS_NO_SPACE:
1417         return NO
1418
1419     if t == token.COMMENT:
1420         return DOUBLESPACE
1421
1422     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1423     if (
1424         t == token.COLON
1425         and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop}
1426     ):
1427         return NO
1428
1429     prev = leaf.prev_sibling
1430     if not prev:
1431         prevp = preceding_leaf(p)
1432         if not prevp or prevp.type in OPENING_BRACKETS:
1433             return NO
1434
1435         if t == token.COLON:
1436             if prevp.type == token.COLON:
1437                 return NO
1438
1439             elif prevp.type != token.COMMA and not complex_subscript:
1440                 return NO
1441
1442             return SPACE
1443
1444         if prevp.type == token.EQUAL:
1445             if prevp.parent:
1446                 if prevp.parent.type in {
1447                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1448                 }:
1449                     return NO
1450
1451                 elif prevp.parent.type == syms.typedargslist:
1452                     # A bit hacky: if the equal sign has whitespace, it means we
1453                     # previously found it's a typed argument.  So, we're using
1454                     # that, too.
1455                     return prevp.prefix
1456
1457         elif prevp.type in STARS:
1458             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1459                 return NO
1460
1461         elif prevp.type == token.COLON:
1462             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1463                 return SPACE if complex_subscript else NO
1464
1465         elif (
1466             prevp.parent
1467             and prevp.parent.type == syms.factor
1468             and prevp.type in MATH_OPERATORS
1469         ):
1470             return NO
1471
1472         elif (
1473             prevp.type == token.RIGHTSHIFT
1474             and prevp.parent
1475             and prevp.parent.type == syms.shift_expr
1476             and prevp.prev_sibling
1477             and prevp.prev_sibling.type == token.NAME
1478             and prevp.prev_sibling.value == "print"  # type: ignore
1479         ):
1480             # Python 2 print chevron
1481             return NO
1482
1483     elif prev.type in OPENING_BRACKETS:
1484         return NO
1485
1486     if p.type in {syms.parameters, syms.arglist}:
1487         # untyped function signatures or calls
1488         if not prev or prev.type != token.COMMA:
1489             return NO
1490
1491     elif p.type == syms.varargslist:
1492         # lambdas
1493         if prev and prev.type != token.COMMA:
1494             return NO
1495
1496     elif p.type == syms.typedargslist:
1497         # typed function signatures
1498         if not prev:
1499             return NO
1500
1501         if t == token.EQUAL:
1502             if prev.type != syms.tname:
1503                 return NO
1504
1505         elif prev.type == token.EQUAL:
1506             # A bit hacky: if the equal sign has whitespace, it means we
1507             # previously found it's a typed argument.  So, we're using that, too.
1508             return prev.prefix
1509
1510         elif prev.type != token.COMMA:
1511             return NO
1512
1513     elif p.type == syms.tname:
1514         # type names
1515         if not prev:
1516             prevp = preceding_leaf(p)
1517             if not prevp or prevp.type != token.COMMA:
1518                 return NO
1519
1520     elif p.type == syms.trailer:
1521         # attributes and calls
1522         if t == token.LPAR or t == token.RPAR:
1523             return NO
1524
1525         if not prev:
1526             if t == token.DOT:
1527                 prevp = preceding_leaf(p)
1528                 if not prevp or prevp.type != token.NUMBER:
1529                     return NO
1530
1531             elif t == token.LSQB:
1532                 return NO
1533
1534         elif prev.type != token.COMMA:
1535             return NO
1536
1537     elif p.type == syms.argument:
1538         # single argument
1539         if t == token.EQUAL:
1540             return NO
1541
1542         if not prev:
1543             prevp = preceding_leaf(p)
1544             if not prevp or prevp.type == token.LPAR:
1545                 return NO
1546
1547         elif prev.type in {token.EQUAL} | STARS:
1548             return NO
1549
1550     elif p.type == syms.decorator:
1551         # decorators
1552         return NO
1553
1554     elif p.type == syms.dotted_name:
1555         if prev:
1556             return NO
1557
1558         prevp = preceding_leaf(p)
1559         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1560             return NO
1561
1562     elif p.type == syms.classdef:
1563         if t == token.LPAR:
1564             return NO
1565
1566         if prev and prev.type == token.LPAR:
1567             return NO
1568
1569     elif p.type in {syms.subscript, syms.sliceop}:
1570         # indexing
1571         if not prev:
1572             assert p.parent is not None, "subscripts are always parented"
1573             if p.parent.type == syms.subscriptlist:
1574                 return SPACE
1575
1576             return NO
1577
1578         elif not complex_subscript:
1579             return NO
1580
1581     elif p.type == syms.atom:
1582         if prev and t == token.DOT:
1583             # dots, but not the first one.
1584             return NO
1585
1586     elif p.type == syms.dictsetmaker:
1587         # dict unpacking
1588         if prev and prev.type == token.DOUBLESTAR:
1589             return NO
1590
1591     elif p.type in {syms.factor, syms.star_expr}:
1592         # unary ops
1593         if not prev:
1594             prevp = preceding_leaf(p)
1595             if not prevp or prevp.type in OPENING_BRACKETS:
1596                 return NO
1597
1598             prevp_parent = prevp.parent
1599             assert prevp_parent is not None
1600             if (
1601                 prevp.type == token.COLON
1602                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1603             ):
1604                 return NO
1605
1606             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1607                 return NO
1608
1609         elif t == token.NAME or t == token.NUMBER:
1610             return NO
1611
1612     elif p.type == syms.import_from:
1613         if t == token.DOT:
1614             if prev and prev.type == token.DOT:
1615                 return NO
1616
1617         elif t == token.NAME:
1618             if v == "import":
1619                 return SPACE
1620
1621             if prev and prev.type == token.DOT:
1622                 return NO
1623
1624     elif p.type == syms.sliceop:
1625         return NO
1626
1627     return SPACE
1628
1629
1630 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1631     """Return the first leaf that precedes `node`, if any."""
1632     while node:
1633         res = node.prev_sibling
1634         if res:
1635             if isinstance(res, Leaf):
1636                 return res
1637
1638             try:
1639                 return list(res.leaves())[-1]
1640
1641             except IndexError:
1642                 return None
1643
1644         node = node.parent
1645     return None
1646
1647
1648 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1649     """Return the child of `ancestor` that contains `descendant`."""
1650     node: Optional[LN] = descendant
1651     while node and node.parent != ancestor:
1652         node = node.parent
1653     return node
1654
1655
1656 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1657     """Return the priority of the `leaf` delimiter, given a line break after it.
1658
1659     The delimiter priorities returned here are from those delimiters that would
1660     cause a line break after themselves.
1661
1662     Higher numbers are higher priority.
1663     """
1664     if leaf.type == token.COMMA:
1665         return COMMA_PRIORITY
1666
1667     return 0
1668
1669
1670 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1671     """Return the priority of the `leaf` delimiter, given a line before after it.
1672
1673     The delimiter priorities returned here are from those delimiters that would
1674     cause a line break before themselves.
1675
1676     Higher numbers are higher priority.
1677     """
1678     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1679         # * and ** might also be MATH_OPERATORS but in this case they are not.
1680         # Don't treat them as a delimiter.
1681         return 0
1682
1683     if (
1684         leaf.type in MATH_OPERATORS
1685         and leaf.parent
1686         and leaf.parent.type not in {syms.factor, syms.star_expr}
1687     ):
1688         return MATH_PRIORITIES[leaf.type]
1689
1690     if leaf.type in COMPARATORS:
1691         return COMPARATOR_PRIORITY
1692
1693     if (
1694         leaf.type == token.STRING
1695         and previous is not None
1696         and previous.type == token.STRING
1697     ):
1698         return STRING_PRIORITY
1699
1700     if (
1701         leaf.type == token.NAME
1702         and leaf.value == "for"
1703         and leaf.parent
1704         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1705     ):
1706         return COMPREHENSION_PRIORITY
1707
1708     if (
1709         leaf.type == token.NAME
1710         and leaf.value == "if"
1711         and leaf.parent
1712         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1713     ):
1714         return COMPREHENSION_PRIORITY
1715
1716     if (
1717         leaf.type == token.NAME
1718         and leaf.value in {"if", "else"}
1719         and leaf.parent
1720         and leaf.parent.type == syms.test
1721     ):
1722         return TERNARY_PRIORITY
1723
1724     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1725         return LOGIC_PRIORITY
1726
1727     return 0
1728
1729
1730 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1731     """Clean the prefix of the `leaf` and generate comments from it, if any.
1732
1733     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1734     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1735     move because it does away with modifying the grammar to include all the
1736     possible places in which comments can be placed.
1737
1738     The sad consequence for us though is that comments don't "belong" anywhere.
1739     This is why this function generates simple parentless Leaf objects for
1740     comments.  We simply don't know what the correct parent should be.
1741
1742     No matter though, we can live without this.  We really only need to
1743     differentiate between inline and standalone comments.  The latter don't
1744     share the line with any code.
1745
1746     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1747     are emitted with a fake STANDALONE_COMMENT token identifier.
1748     """
1749     p = leaf.prefix
1750     if not p:
1751         return
1752
1753     if "#" not in p:
1754         return
1755
1756     consumed = 0
1757     nlines = 0
1758     for index, line in enumerate(p.split("\n")):
1759         consumed += len(line) + 1  # adding the length of the split '\n'
1760         line = line.lstrip()
1761         if not line:
1762             nlines += 1
1763         if not line.startswith("#"):
1764             continue
1765
1766         if index == 0 and leaf.type != token.ENDMARKER:
1767             comment_type = token.COMMENT  # simple trailing comment
1768         else:
1769             comment_type = STANDALONE_COMMENT
1770         comment = make_comment(line)
1771         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1772
1773         if comment in {"# fmt: on", "# yapf: enable"}:
1774             raise FormatOn(consumed)
1775
1776         if comment in {"# fmt: off", "# yapf: disable"}:
1777             if comment_type == STANDALONE_COMMENT:
1778                 raise FormatOff(consumed)
1779
1780             prev = preceding_leaf(leaf)
1781             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1782                 raise FormatOff(consumed)
1783
1784         nlines = 0
1785
1786
1787 def make_comment(content: str) -> str:
1788     """Return a consistently formatted comment from the given `content` string.
1789
1790     All comments (except for "##", "#!", "#:") should have a single space between
1791     the hash sign and the content.
1792
1793     If `content` didn't start with a hash sign, one is provided.
1794     """
1795     content = content.rstrip()
1796     if not content:
1797         return "#"
1798
1799     if content[0] == "#":
1800         content = content[1:]
1801     if content and content[0] not in " !:#":
1802         content = " " + content
1803     return "#" + content
1804
1805
1806 def split_line(
1807     line: Line, line_length: int, inner: bool = False, py36: bool = False
1808 ) -> Iterator[Line]:
1809     """Split a `line` into potentially many lines.
1810
1811     They should fit in the allotted `line_length` but might not be able to.
1812     `inner` signifies that there were a pair of brackets somewhere around the
1813     current `line`, possibly transitively. This means we can fallback to splitting
1814     by delimiters if the LHS/RHS don't yield any results.
1815
1816     If `py36` is True, splitting may generate syntax that is only compatible
1817     with Python 3.6 and later.
1818     """
1819     if isinstance(line, UnformattedLines) or line.is_comment:
1820         yield line
1821         return
1822
1823     line_str = str(line).strip("\n")
1824     if (
1825         len(line_str) <= line_length
1826         and "\n" not in line_str  # multiline strings
1827         and not line.contains_standalone_comments()
1828     ):
1829         yield line
1830         return
1831
1832     split_funcs: List[SplitFunc]
1833     if line.is_def:
1834         split_funcs = [left_hand_split]
1835     elif line.is_import:
1836         split_funcs = [explode_split]
1837     elif line.inside_brackets:
1838         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1839     else:
1840         split_funcs = [right_hand_split]
1841     for split_func in split_funcs:
1842         # We are accumulating lines in `result` because we might want to abort
1843         # mission and return the original line in the end, or attempt a different
1844         # split altogether.
1845         result: List[Line] = []
1846         try:
1847             for l in split_func(line, py36):
1848                 if str(l).strip("\n") == line_str:
1849                     raise CannotSplit("Split function returned an unchanged result")
1850
1851                 result.extend(
1852                     split_line(l, line_length=line_length, inner=True, py36=py36)
1853                 )
1854         except CannotSplit as cs:
1855             continue
1856
1857         else:
1858             yield from result
1859             break
1860
1861     else:
1862         yield line
1863
1864
1865 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1866     """Split line into many lines, starting with the first matching bracket pair.
1867
1868     Note: this usually looks weird, only use this for function definitions.
1869     Prefer RHS otherwise.  This is why this function is not symmetrical with
1870     :func:`right_hand_split` which also handles optional parentheses.
1871     """
1872     head = Line(depth=line.depth)
1873     body = Line(depth=line.depth + 1, inside_brackets=True)
1874     tail = Line(depth=line.depth)
1875     tail_leaves: List[Leaf] = []
1876     body_leaves: List[Leaf] = []
1877     head_leaves: List[Leaf] = []
1878     current_leaves = head_leaves
1879     matching_bracket = None
1880     for leaf in line.leaves:
1881         if (
1882             current_leaves is body_leaves
1883             and leaf.type in CLOSING_BRACKETS
1884             and leaf.opening_bracket is matching_bracket
1885         ):
1886             current_leaves = tail_leaves if body_leaves else head_leaves
1887         current_leaves.append(leaf)
1888         if current_leaves is head_leaves:
1889             if leaf.type in OPENING_BRACKETS:
1890                 matching_bracket = leaf
1891                 current_leaves = body_leaves
1892     # Since body is a new indent level, remove spurious leading whitespace.
1893     if body_leaves:
1894         normalize_prefix(body_leaves[0], inside_brackets=True)
1895     # Build the new lines.
1896     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1897         for leaf in leaves:
1898             result.append(leaf, preformatted=True)
1899             for comment_after in line.comments_after(leaf):
1900                 result.append(comment_after, preformatted=True)
1901     bracket_split_succeeded_or_raise(head, body, tail)
1902     for result in (head, body, tail):
1903         if result:
1904             yield result
1905
1906
1907 def right_hand_split(
1908     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1909 ) -> Iterator[Line]:
1910     """Split line into many lines, starting with the last matching bracket pair.
1911
1912     If the split was by optional parentheses, attempt splitting without them, too.
1913     """
1914     head = Line(depth=line.depth)
1915     body = Line(depth=line.depth + 1, inside_brackets=True)
1916     tail = Line(depth=line.depth)
1917     tail_leaves: List[Leaf] = []
1918     body_leaves: List[Leaf] = []
1919     head_leaves: List[Leaf] = []
1920     current_leaves = tail_leaves
1921     opening_bracket = None
1922     closing_bracket = None
1923     for leaf in reversed(line.leaves):
1924         if current_leaves is body_leaves:
1925             if leaf is opening_bracket:
1926                 current_leaves = head_leaves if body_leaves else tail_leaves
1927         current_leaves.append(leaf)
1928         if current_leaves is tail_leaves:
1929             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1930                 opening_bracket = leaf.opening_bracket
1931                 closing_bracket = leaf
1932                 current_leaves = body_leaves
1933     tail_leaves.reverse()
1934     body_leaves.reverse()
1935     head_leaves.reverse()
1936     # Since body is a new indent level, remove spurious leading whitespace.
1937     if body_leaves:
1938         normalize_prefix(body_leaves[0], inside_brackets=True)
1939     elif not head_leaves:
1940         # No `head` and no `body` means the split failed. `tail` has all content.
1941         raise CannotSplit("No brackets found")
1942
1943     # Build the new lines.
1944     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1945         for leaf in leaves:
1946             result.append(leaf, preformatted=True)
1947             for comment_after in line.comments_after(leaf):
1948                 result.append(comment_after, preformatted=True)
1949     bracket_split_succeeded_or_raise(head, body, tail)
1950     assert opening_bracket and closing_bracket
1951     if (
1952         # the opening bracket is an optional paren
1953         opening_bracket.type == token.LPAR
1954         and not opening_bracket.value
1955         # the closing bracket is an optional paren
1956         and closing_bracket.type == token.RPAR
1957         and not closing_bracket.value
1958         # there are no delimiters or standalone comments in the body
1959         and not body.bracket_tracker.delimiters
1960         and not line.contains_standalone_comments(0)
1961         # and it's not an import (optional parens are the only thing we can split
1962         # on in this case; attempting a split without them is a waste of time)
1963         and not line.is_import
1964     ):
1965         omit = {id(closing_bracket), *omit}
1966         try:
1967             yield from right_hand_split(line, py36=py36, omit=omit)
1968             return
1969         except CannotSplit:
1970             pass
1971
1972     ensure_visible(opening_bracket)
1973     ensure_visible(closing_bracket)
1974     for result in (head, body, tail):
1975         if result:
1976             yield result
1977
1978
1979 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1980     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1981
1982     Do nothing otherwise.
1983
1984     A left- or right-hand split is based on a pair of brackets. Content before
1985     (and including) the opening bracket is left on one line, content inside the
1986     brackets is put on a separate line, and finally content starting with and
1987     following the closing bracket is put on a separate line.
1988
1989     Those are called `head`, `body`, and `tail`, respectively. If the split
1990     produced the same line (all content in `head`) or ended up with an empty `body`
1991     and the `tail` is just the closing bracket, then it's considered failed.
1992     """
1993     tail_len = len(str(tail).strip())
1994     if not body:
1995         if tail_len == 0:
1996             raise CannotSplit("Splitting brackets produced the same line")
1997
1998         elif tail_len < 3:
1999             raise CannotSplit(
2000                 f"Splitting brackets on an empty body to save "
2001                 f"{tail_len} characters is not worth it"
2002             )
2003
2004
2005 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2006     """Normalize prefix of the first leaf in every line returned by `split_func`.
2007
2008     This is a decorator over relevant split functions.
2009     """
2010
2011     @wraps(split_func)
2012     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2013         for l in split_func(line, py36):
2014             normalize_prefix(l.leaves[0], inside_brackets=True)
2015             yield l
2016
2017     return split_wrapper
2018
2019
2020 @dont_increase_indentation
2021 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2022     """Split according to delimiters of the highest priority.
2023
2024     If `py36` is True, the split will add trailing commas also in function
2025     signatures that contain `*` and `**`.
2026     """
2027     try:
2028         last_leaf = line.leaves[-1]
2029     except IndexError:
2030         raise CannotSplit("Line empty")
2031
2032     delimiters = line.bracket_tracker.delimiters
2033     try:
2034         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
2035             exclude={id(last_leaf)}
2036         )
2037     except ValueError:
2038         raise CannotSplit("No delimiters found")
2039
2040     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2041     lowest_depth = sys.maxsize
2042     trailing_comma_safe = True
2043
2044     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2045         """Append `leaf` to current line or to new line if appending impossible."""
2046         nonlocal current_line
2047         try:
2048             current_line.append_safe(leaf, preformatted=True)
2049         except ValueError as ve:
2050             yield current_line
2051
2052             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2053             current_line.append(leaf)
2054
2055     for leaf in line.leaves:
2056         yield from append_to_line(leaf)
2057
2058         for comment_after in line.comments_after(leaf):
2059             yield from append_to_line(comment_after)
2060
2061         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2062         if (
2063             leaf.bracket_depth == lowest_depth
2064             and is_vararg(leaf, within=VARARGS_PARENTS)
2065         ):
2066             trailing_comma_safe = trailing_comma_safe and py36
2067         leaf_priority = delimiters.get(id(leaf))
2068         if leaf_priority == delimiter_priority:
2069             yield current_line
2070
2071             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2072     if current_line:
2073         if (
2074             trailing_comma_safe
2075             and delimiter_priority == COMMA_PRIORITY
2076             and current_line.leaves[-1].type != token.COMMA
2077             and current_line.leaves[-1].type != STANDALONE_COMMENT
2078         ):
2079             current_line.append(Leaf(token.COMMA, ","))
2080         yield current_line
2081
2082
2083 @dont_increase_indentation
2084 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2085     """Split standalone comments from the rest of the line."""
2086     if not line.contains_standalone_comments(0):
2087         raise CannotSplit("Line does not have any standalone comments")
2088
2089     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2090
2091     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2092         """Append `leaf` to current line or to new line if appending impossible."""
2093         nonlocal current_line
2094         try:
2095             current_line.append_safe(leaf, preformatted=True)
2096         except ValueError as ve:
2097             yield current_line
2098
2099             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2100             current_line.append(leaf)
2101
2102     for leaf in line.leaves:
2103         yield from append_to_line(leaf)
2104
2105         for comment_after in line.comments_after(leaf):
2106             yield from append_to_line(comment_after)
2107
2108     if current_line:
2109         yield current_line
2110
2111
2112 def explode_split(
2113     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
2114 ) -> Iterator[Line]:
2115     """Split by rightmost bracket and immediately split contents by a delimiter."""
2116     new_lines = list(right_hand_split(line, py36, omit))
2117     if len(new_lines) != 3:
2118         yield from new_lines
2119         return
2120
2121     yield new_lines[0]
2122
2123     try:
2124         yield from delimiter_split(new_lines[1], py36)
2125
2126     except CannotSplit:
2127         yield new_lines[1]
2128
2129     yield new_lines[2]
2130
2131
2132 def is_import(leaf: Leaf) -> bool:
2133     """Return True if the given leaf starts an import statement."""
2134     p = leaf.parent
2135     t = leaf.type
2136     v = leaf.value
2137     return bool(
2138         t == token.NAME
2139         and (
2140             (v == "import" and p and p.type == syms.import_name)
2141             or (v == "from" and p and p.type == syms.import_from)
2142         )
2143     )
2144
2145
2146 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2147     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2148     else.
2149
2150     Note: don't use backslashes for formatting or you'll lose your voting rights.
2151     """
2152     if not inside_brackets:
2153         spl = leaf.prefix.split("#")
2154         if "\\" not in spl[0]:
2155             nl_count = spl[-1].count("\n")
2156             if len(spl) > 1:
2157                 nl_count -= 1
2158             leaf.prefix = "\n" * nl_count
2159             return
2160
2161     leaf.prefix = ""
2162
2163
2164 def normalize_string_quotes(leaf: Leaf) -> None:
2165     """Prefer double quotes but only if it doesn't cause more escaping.
2166
2167     Adds or removes backslashes as appropriate. Doesn't parse and fix
2168     strings nested in f-strings (yet).
2169
2170     Note: Mutates its argument.
2171     """
2172     value = leaf.value.lstrip("furbFURB")
2173     if value[:3] == '"""':
2174         return
2175
2176     elif value[:3] == "'''":
2177         orig_quote = "'''"
2178         new_quote = '"""'
2179     elif value[0] == '"':
2180         orig_quote = '"'
2181         new_quote = "'"
2182     else:
2183         orig_quote = "'"
2184         new_quote = '"'
2185     first_quote_pos = leaf.value.find(orig_quote)
2186     if first_quote_pos == -1:
2187         return  # There's an internal error
2188
2189     prefix = leaf.value[:first_quote_pos]
2190     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2191     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2192     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2193     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2194     if "r" in prefix.casefold():
2195         if unescaped_new_quote.search(body):
2196             # There's at least one unescaped new_quote in this raw string
2197             # so converting is impossible
2198             return
2199
2200         # Do not introduce or remove backslashes in raw strings
2201         new_body = body
2202     else:
2203         # remove unnecessary quotes
2204         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2205         if body != new_body:
2206             # Consider the string without unnecessary quotes as the original
2207             body = new_body
2208             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2209         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2210         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2211     if new_quote == '"""' and new_body[-1] == '"':
2212         # edge case:
2213         new_body = new_body[:-1] + '\\"'
2214     orig_escape_count = body.count("\\")
2215     new_escape_count = new_body.count("\\")
2216     if new_escape_count > orig_escape_count:
2217         return  # Do not introduce more escaping
2218
2219     if new_escape_count == orig_escape_count and orig_quote == '"':
2220         return  # Prefer double quotes
2221
2222     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2223
2224
2225 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2226     """Make existing optional parentheses invisible or create new ones.
2227
2228     `parens_after` is a set of string leaf values immeditely after which parens
2229     should be put.
2230
2231     Standardizes on visible parentheses for single-element tuples, and keeps
2232     existing visible parentheses for other tuples and generator expressions.
2233     """
2234     check_lpar = False
2235     for child in list(node.children):
2236         if check_lpar:
2237             if child.type == syms.atom:
2238                 maybe_make_parens_invisible_in_atom(child)
2239             elif is_one_tuple(child):
2240                 # wrap child in visible parentheses
2241                 lpar = Leaf(token.LPAR, "(")
2242                 rpar = Leaf(token.RPAR, ")")
2243                 index = child.remove() or 0
2244                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2245             else:
2246                 # wrap child in invisible parentheses
2247                 lpar = Leaf(token.LPAR, "")
2248                 rpar = Leaf(token.RPAR, "")
2249                 index = child.remove() or 0
2250                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2251
2252         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2253
2254
2255 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2256     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2257     if (
2258         node.type != syms.atom
2259         or is_empty_tuple(node)
2260         or is_one_tuple(node)
2261         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2262     ):
2263         return False
2264
2265     first = node.children[0]
2266     last = node.children[-1]
2267     if first.type == token.LPAR and last.type == token.RPAR:
2268         # make parentheses invisible
2269         first.value = ""  # type: ignore
2270         last.value = ""  # type: ignore
2271         if len(node.children) > 1:
2272             maybe_make_parens_invisible_in_atom(node.children[1])
2273         return True
2274
2275     return False
2276
2277
2278 def is_empty_tuple(node: LN) -> bool:
2279     """Return True if `node` holds an empty tuple."""
2280     return (
2281         node.type == syms.atom
2282         and len(node.children) == 2
2283         and node.children[0].type == token.LPAR
2284         and node.children[1].type == token.RPAR
2285     )
2286
2287
2288 def is_one_tuple(node: LN) -> bool:
2289     """Return True if `node` holds a tuple with one element, with or without parens."""
2290     if node.type == syms.atom:
2291         if len(node.children) != 3:
2292             return False
2293
2294         lpar, gexp, rpar = node.children
2295         if not (
2296             lpar.type == token.LPAR
2297             and gexp.type == syms.testlist_gexp
2298             and rpar.type == token.RPAR
2299         ):
2300             return False
2301
2302         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2303
2304     return (
2305         node.type in IMPLICIT_TUPLE
2306         and len(node.children) == 2
2307         and node.children[1].type == token.COMMA
2308     )
2309
2310
2311 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2312     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2313
2314     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2315     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2316     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2317     generalizations (PEP 448).
2318     """
2319     if leaf.type not in STARS or not leaf.parent:
2320         return False
2321
2322     p = leaf.parent
2323     if p.type == syms.star_expr:
2324         # Star expressions are also used as assignment targets in extended
2325         # iterable unpacking (PEP 3132).  See what its parent is instead.
2326         if not p.parent:
2327             return False
2328
2329         p = p.parent
2330
2331     return p.type in within
2332
2333
2334 def max_delimiter_priority_in_atom(node: LN) -> int:
2335     """Return maximum delimiter priority inside `node`.
2336
2337     This is specific to atoms with contents contained in a pair of parentheses.
2338     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2339     """
2340     if node.type != syms.atom:
2341         return 0
2342
2343     first = node.children[0]
2344     last = node.children[-1]
2345     if not (first.type == token.LPAR and last.type == token.RPAR):
2346         return 0
2347
2348     bt = BracketTracker()
2349     for c in node.children[1:-1]:
2350         if isinstance(c, Leaf):
2351             bt.mark(c)
2352         else:
2353             for leaf in c.leaves():
2354                 bt.mark(leaf)
2355     try:
2356         return bt.max_delimiter_priority()
2357
2358     except ValueError:
2359         return 0
2360
2361
2362 def ensure_visible(leaf: Leaf) -> None:
2363     """Make sure parentheses are visible.
2364
2365     They could be invisible as part of some statements (see
2366     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2367     """
2368     if leaf.type == token.LPAR:
2369         leaf.value = "("
2370     elif leaf.type == token.RPAR:
2371         leaf.value = ")"
2372
2373
2374 def is_python36(node: Node) -> bool:
2375     """Return True if the current file is using Python 3.6+ features.
2376
2377     Currently looking for:
2378     - f-strings; and
2379     - trailing commas after * or ** in function signatures and calls.
2380     """
2381     for n in node.pre_order():
2382         if n.type == token.STRING:
2383             value_head = n.value[:2]  # type: ignore
2384             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2385                 return True
2386
2387         elif (
2388             n.type in {syms.typedargslist, syms.arglist}
2389             and n.children
2390             and n.children[-1].type == token.COMMA
2391         ):
2392             for ch in n.children:
2393                 if ch.type in STARS:
2394                     return True
2395
2396                 if ch.type == syms.argument:
2397                     for argch in ch.children:
2398                         if argch.type in STARS:
2399                             return True
2400
2401     return False
2402
2403
2404 PYTHON_EXTENSIONS = {".py"}
2405 BLACKLISTED_DIRECTORIES = {
2406     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2407 }
2408
2409
2410 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2411     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2412     and have one of the PYTHON_EXTENSIONS.
2413     """
2414     for child in path.iterdir():
2415         if child.is_dir():
2416             if child.name in BLACKLISTED_DIRECTORIES:
2417                 continue
2418
2419             yield from gen_python_files_in_dir(child)
2420
2421         elif child.suffix in PYTHON_EXTENSIONS:
2422             yield child
2423
2424
2425 @dataclass
2426 class Report:
2427     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2428     check: bool = False
2429     quiet: bool = False
2430     change_count: int = 0
2431     same_count: int = 0
2432     failure_count: int = 0
2433
2434     def done(self, src: Path, changed: Changed) -> None:
2435         """Increment the counter for successful reformatting. Write out a message."""
2436         if changed is Changed.YES:
2437             reformatted = "would reformat" if self.check else "reformatted"
2438             if not self.quiet:
2439                 out(f"{reformatted} {src}")
2440             self.change_count += 1
2441         else:
2442             if not self.quiet:
2443                 if changed is Changed.NO:
2444                     msg = f"{src} already well formatted, good job."
2445                 else:
2446                     msg = f"{src} wasn't modified on disk since last run."
2447                 out(msg, bold=False)
2448             self.same_count += 1
2449
2450     def failed(self, src: Path, message: str) -> None:
2451         """Increment the counter for failed reformatting. Write out a message."""
2452         err(f"error: cannot format {src}: {message}")
2453         self.failure_count += 1
2454
2455     @property
2456     def return_code(self) -> int:
2457         """Return the exit code that the app should use.
2458
2459         This considers the current state of changed files and failures:
2460         - if there were any failures, return 123;
2461         - if any files were changed and --check is being used, return 1;
2462         - otherwise return 0.
2463         """
2464         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2465         # 126 we have special returncodes reserved by the shell.
2466         if self.failure_count:
2467             return 123
2468
2469         elif self.change_count and self.check:
2470             return 1
2471
2472         return 0
2473
2474     def __str__(self) -> str:
2475         """Render a color report of the current state.
2476
2477         Use `click.unstyle` to remove colors.
2478         """
2479         if self.check:
2480             reformatted = "would be reformatted"
2481             unchanged = "would be left unchanged"
2482             failed = "would fail to reformat"
2483         else:
2484             reformatted = "reformatted"
2485             unchanged = "left unchanged"
2486             failed = "failed to reformat"
2487         report = []
2488         if self.change_count:
2489             s = "s" if self.change_count > 1 else ""
2490             report.append(
2491                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2492             )
2493         if self.same_count:
2494             s = "s" if self.same_count > 1 else ""
2495             report.append(f"{self.same_count} file{s} {unchanged}")
2496         if self.failure_count:
2497             s = "s" if self.failure_count > 1 else ""
2498             report.append(
2499                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2500             )
2501         return ", ".join(report) + "."
2502
2503
2504 def assert_equivalent(src: str, dst: str) -> None:
2505     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2506
2507     import ast
2508     import traceback
2509
2510     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2511         """Simple visitor generating strings to compare ASTs by content."""
2512         yield f"{'  ' * depth}{node.__class__.__name__}("
2513
2514         for field in sorted(node._fields):
2515             try:
2516                 value = getattr(node, field)
2517             except AttributeError:
2518                 continue
2519
2520             yield f"{'  ' * (depth+1)}{field}="
2521
2522             if isinstance(value, list):
2523                 for item in value:
2524                     if isinstance(item, ast.AST):
2525                         yield from _v(item, depth + 2)
2526
2527             elif isinstance(value, ast.AST):
2528                 yield from _v(value, depth + 2)
2529
2530             else:
2531                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2532
2533         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2534
2535     try:
2536         src_ast = ast.parse(src)
2537     except Exception as exc:
2538         major, minor = sys.version_info[:2]
2539         raise AssertionError(
2540             f"cannot use --safe with this file; failed to parse source file "
2541             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2542             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2543         )
2544
2545     try:
2546         dst_ast = ast.parse(dst)
2547     except Exception as exc:
2548         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2549         raise AssertionError(
2550             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2551             f"Please report a bug on https://github.com/ambv/black/issues.  "
2552             f"This invalid output might be helpful: {log}"
2553         ) from None
2554
2555     src_ast_str = "\n".join(_v(src_ast))
2556     dst_ast_str = "\n".join(_v(dst_ast))
2557     if src_ast_str != dst_ast_str:
2558         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2559         raise AssertionError(
2560             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2561             f"the source.  "
2562             f"Please report a bug on https://github.com/ambv/black/issues.  "
2563             f"This diff might be helpful: {log}"
2564         ) from None
2565
2566
2567 def assert_stable(src: str, dst: str, line_length: int) -> None:
2568     """Raise AssertionError if `dst` reformats differently the second time."""
2569     newdst = format_str(dst, line_length=line_length)
2570     if dst != newdst:
2571         log = dump_to_file(
2572             diff(src, dst, "source", "first pass"),
2573             diff(dst, newdst, "first pass", "second pass"),
2574         )
2575         raise AssertionError(
2576             f"INTERNAL ERROR: Black produced different code on the second pass "
2577             f"of the formatter.  "
2578             f"Please report a bug on https://github.com/ambv/black/issues.  "
2579             f"This diff might be helpful: {log}"
2580         ) from None
2581
2582
2583 def dump_to_file(*output: str) -> str:
2584     """Dump `output` to a temporary file. Return path to the file."""
2585     import tempfile
2586
2587     with tempfile.NamedTemporaryFile(
2588         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2589     ) as f:
2590         for lines in output:
2591             f.write(lines)
2592             if lines and lines[-1] != "\n":
2593                 f.write("\n")
2594     return f.name
2595
2596
2597 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2598     """Return a unified diff string between strings `a` and `b`."""
2599     import difflib
2600
2601     a_lines = [line + "\n" for line in a.split("\n")]
2602     b_lines = [line + "\n" for line in b.split("\n")]
2603     return "".join(
2604         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2605     )
2606
2607
2608 def cancel(tasks: List[asyncio.Task]) -> None:
2609     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2610     err("Aborted!")
2611     for task in tasks:
2612         task.cancel()
2613
2614
2615 def shutdown(loop: BaseEventLoop) -> None:
2616     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2617     try:
2618         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2619         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2620         if not to_cancel:
2621             return
2622
2623         for task in to_cancel:
2624             task.cancel()
2625         loop.run_until_complete(
2626             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2627         )
2628     finally:
2629         # `concurrent.futures.Future` objects cannot be cancelled once they
2630         # are already running. There might be some when the `shutdown()` happened.
2631         # Silence their logger's spew about the event loop being closed.
2632         cf_logger = logging.getLogger("concurrent.futures")
2633         cf_logger.setLevel(logging.CRITICAL)
2634         loop.close()
2635
2636
2637 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2638     """Replace `regex` with `replacement` twice on `original`.
2639
2640     This is used by string normalization to perform replaces on
2641     overlapping matches.
2642     """
2643     return regex.sub(replacement, regex.sub(replacement, original))
2644
2645
2646 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2647
2648
2649 def get_cache_file(line_length: int) -> Path:
2650     return CACHE_DIR / f"cache.{line_length}.pickle"
2651
2652
2653 def read_cache(line_length: int) -> Cache:
2654     """Read the cache if it exists and is well formed.
2655
2656     If it is not well formed, the call to write_cache later should resolve the issue.
2657     """
2658     cache_file = get_cache_file(line_length)
2659     if not cache_file.exists():
2660         return {}
2661
2662     with cache_file.open("rb") as fobj:
2663         try:
2664             cache: Cache = pickle.load(fobj)
2665         except pickle.UnpicklingError:
2666             return {}
2667
2668     return cache
2669
2670
2671 def get_cache_info(path: Path) -> CacheInfo:
2672     """Return the information used to check if a file is already formatted or not."""
2673     stat = path.stat()
2674     return stat.st_mtime, stat.st_size
2675
2676
2677 def filter_cached(
2678     cache: Cache, sources: Iterable[Path]
2679 ) -> Tuple[List[Path], List[Path]]:
2680     """Split a list of paths into two.
2681
2682     The first list contains paths of files that modified on disk or are not in the
2683     cache. The other list contains paths to non-modified files.
2684     """
2685     todo, done = [], []
2686     for src in sources:
2687         src = src.resolve()
2688         if cache.get(src) != get_cache_info(src):
2689             todo.append(src)
2690         else:
2691             done.append(src)
2692     return todo, done
2693
2694
2695 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2696     """Update the cache file."""
2697     cache_file = get_cache_file(line_length)
2698     try:
2699         if not CACHE_DIR.exists():
2700             CACHE_DIR.mkdir(parents=True)
2701         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2702         with cache_file.open("wb") as fobj:
2703             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2704     except OSError:
2705         pass
2706
2707
2708 if __name__ == "__main__":
2709     main()