]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

43b9bd17f6f2dcd3b52648908e54239620ddd4b3
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Set,
28     Tuple,
29     Type,
30     TypeVar,
31     Union,
32 )
33
34 from appdirs import user_cache_dir
35 from attr import dataclass, Factory
36 import click
37
38 # lib2to3 fork
39 from blib2to3.pytree import Node, Leaf, type_repr
40 from blib2to3 import pygram, pytree
41 from blib2to3.pgen2 import driver, token
42 from blib2to3.pgen2.parse import ParseError
43
44 __version__ = "18.4a6"
45 DEFAULT_LINE_LENGTH = 88
46
47 # types
48 syms = pygram.python_symbols
49 FileContent = str
50 Encoding = str
51 Depth = int
52 NodeType = int
53 LeafID = int
54 Priority = int
55 Index = int
56 LN = Union[Leaf, Node]
57 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
58 Timestamp = float
59 FileSize = int
60 CacheInfo = Tuple[Timestamp, FileSize]
61 Cache = Dict[Path, CacheInfo]
62 out = partial(click.secho, bold=True, err=True)
63 err = partial(click.secho, fg="red", err=True)
64
65
66 class NothingChanged(UserWarning):
67     """Raised by :func:`format_file` when reformatted code is the same as source."""
68
69
70 class CannotSplit(Exception):
71     """A readable split that fits the allotted line length is impossible.
72
73     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
74     :func:`delimiter_split`.
75     """
76
77
78 class FormatError(Exception):
79     """Base exception for `# fmt: on` and `# fmt: off` handling.
80
81     It holds the number of bytes of the prefix consumed before the format
82     control comment appeared.
83     """
84
85     def __init__(self, consumed: int) -> None:
86         super().__init__(consumed)
87         self.consumed = consumed
88
89     def trim_prefix(self, leaf: Leaf) -> None:
90         leaf.prefix = leaf.prefix[self.consumed :]
91
92     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
93         """Returns a new Leaf from the consumed part of the prefix."""
94         unformatted_prefix = leaf.prefix[: self.consumed]
95         return Leaf(token.NEWLINE, unformatted_prefix)
96
97
98 class FormatOn(FormatError):
99     """Found a comment like `# fmt: on` in the file."""
100
101
102 class FormatOff(FormatError):
103     """Found a comment like `# fmt: off` in the file."""
104
105
106 class WriteBack(Enum):
107     NO = 0
108     YES = 1
109     DIFF = 2
110
111
112 class Changed(Enum):
113     NO = 0
114     CACHED = 1
115     YES = 2
116
117
118 @click.command()
119 @click.option(
120     "-l",
121     "--line-length",
122     type=int,
123     default=DEFAULT_LINE_LENGTH,
124     help="How many character per line to allow.",
125     show_default=True,
126 )
127 @click.option(
128     "--check",
129     is_flag=True,
130     help=(
131         "Don't write the files back, just return the status.  Return code 0 "
132         "means nothing would change.  Return code 1 means some files would be "
133         "reformatted.  Return code 123 means there was an internal error."
134     ),
135 )
136 @click.option(
137     "--diff",
138     is_flag=True,
139     help="Don't write the files back, just output a diff for each file on stdout.",
140 )
141 @click.option(
142     "--fast/--safe",
143     is_flag=True,
144     help="If --fast given, skip temporary sanity checks. [default: --safe]",
145 )
146 @click.option(
147     "-q",
148     "--quiet",
149     is_flag=True,
150     help=(
151         "Don't emit non-error messages to stderr. Errors are still emitted, "
152         "silence those with 2>/dev/null."
153     ),
154 )
155 @click.version_option(version=__version__)
156 @click.argument(
157     "src",
158     nargs=-1,
159     type=click.Path(
160         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
161     ),
162 )
163 @click.pass_context
164 def main(
165     ctx: click.Context,
166     line_length: int,
167     check: bool,
168     diff: bool,
169     fast: bool,
170     quiet: bool,
171     src: List[str],
172 ) -> None:
173     """The uncompromising code formatter."""
174     sources: List[Path] = []
175     for s in src:
176         p = Path(s)
177         if p.is_dir():
178             sources.extend(gen_python_files_in_dir(p))
179         elif p.is_file():
180             # if a file was explicitly given, we don't care about its extension
181             sources.append(p)
182         elif s == "-":
183             sources.append(Path("-"))
184         else:
185             err(f"invalid path: {s}")
186
187     if check and not diff:
188         write_back = WriteBack.NO
189     elif diff:
190         write_back = WriteBack.DIFF
191     else:
192         write_back = WriteBack.YES
193     report = Report(check=check, quiet=quiet)
194     if len(sources) == 0:
195         out("No paths given. Nothing to do 😴")
196         ctx.exit(0)
197         return
198
199     elif len(sources) == 1:
200         reformat_one(sources[0], line_length, fast, write_back, report)
201     else:
202         loop = asyncio.get_event_loop()
203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
204         try:
205             loop.run_until_complete(
206                 schedule_formatting(
207                     sources, line_length, fast, write_back, report, loop, executor
208                 )
209             )
210         finally:
211             shutdown(loop)
212         if not quiet:
213             out("All done! ✨ 🍰 ✨")
214             click.echo(str(report))
215     ctx.exit(report.return_code)
216
217
218 def reformat_one(
219     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
220 ) -> None:
221     """Reformat a single file under `src` without spawning child processes.
222
223     If `quiet` is True, non-error messages are not output. `line_length`,
224     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
225     """
226     try:
227         changed = Changed.NO
228         if not src.is_file() and str(src) == "-":
229             if format_stdin_to_stdout(
230                 line_length=line_length, fast=fast, write_back=write_back
231             ):
232                 changed = Changed.YES
233         else:
234             cache: Cache = {}
235             if write_back != WriteBack.DIFF:
236                 cache = read_cache(line_length)
237                 src = src.resolve()
238                 if src in cache and cache[src] == get_cache_info(src):
239                     changed = Changed.CACHED
240             if (
241                 changed is not Changed.CACHED
242                 and format_file_in_place(
243                     src, line_length=line_length, fast=fast, write_back=write_back
244                 )
245             ):
246                 changed = Changed.YES
247             if write_back == WriteBack.YES and changed is not Changed.NO:
248                 write_cache(cache, [src], line_length)
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache(line_length)
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             src: loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             )
289             for src in sources
290         }
291         _task_values = list(tasks.values())
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
294             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     if write_back == WriteBack.YES and formatted:
315         write_cache(cache, formatted, line_length)
316
317
318 def format_file_in_place(
319     src: Path,
320     line_length: int,
321     fast: bool,
322     write_back: WriteBack = WriteBack.NO,
323     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
324 ) -> bool:
325     """Format file under `src` path. Return True if changed.
326
327     If `write_back` is True, write reformatted code back to stdout.
328     `line_length` and `fast` options are passed to :func:`format_file_contents`.
329     """
330
331     with tokenize.open(src) as src_buffer:
332         src_contents = src_buffer.read()
333     try:
334         dst_contents = format_file_contents(
335             src_contents, line_length=line_length, fast=fast
336         )
337     except NothingChanged:
338         return False
339
340     if write_back == write_back.YES:
341         with open(src, "w", encoding=src_buffer.encoding) as f:
342             f.write(dst_contents)
343     elif write_back == write_back.DIFF:
344         src_name = f"{src}  (original)"
345         dst_name = f"{src}  (formatted)"
346         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347         if lock:
348             lock.acquire()
349         try:
350             sys.stdout.write(diff_contents)
351         finally:
352             if lock:
353                 lock.release()
354     return True
355
356
357 def format_stdin_to_stdout(
358     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
359 ) -> bool:
360     """Format file on stdin. Return True if changed.
361
362     If `write_back` is True, write reformatted code back to stdout.
363     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
364     """
365     src = sys.stdin.read()
366     dst = src
367     try:
368         dst = format_file_contents(src, line_length=line_length, fast=fast)
369         return True
370
371     except NothingChanged:
372         return False
373
374     finally:
375         if write_back == WriteBack.YES:
376             sys.stdout.write(dst)
377         elif write_back == WriteBack.DIFF:
378             src_name = "<stdin>  (original)"
379             dst_name = "<stdin>  (formatted)"
380             sys.stdout.write(diff(src, dst, src_name, dst_name))
381
382
383 def format_file_contents(
384     src_contents: str, line_length: int, fast: bool
385 ) -> FileContent:
386     """Reformat contents a file and return new contents.
387
388     If `fast` is False, additionally confirm that the reformatted code is
389     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
390     `line_length` is passed to :func:`format_str`.
391     """
392     if src_contents.strip() == "":
393         raise NothingChanged
394
395     dst_contents = format_str(src_contents, line_length=line_length)
396     if src_contents == dst_contents:
397         raise NothingChanged
398
399     if not fast:
400         assert_equivalent(src_contents, dst_contents)
401         assert_stable(src_contents, dst_contents, line_length=line_length)
402     return dst_contents
403
404
405 def format_str(src_contents: str, line_length: int) -> FileContent:
406     """Reformat a string and return new contents.
407
408     `line_length` determines how many characters per line are allowed.
409     """
410     src_node = lib2to3_parse(src_contents)
411     dst_contents = ""
412     future_imports = get_future_imports(src_node)
413     py36 = is_python36(src_node)
414     lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports)
415     elt = EmptyLineTracker()
416     empty_line = Line()
417     after = 0
418     for current_line in lines.visit(src_node):
419         for _ in range(after):
420             dst_contents += str(empty_line)
421         before, after = elt.maybe_empty_lines(current_line)
422         for _ in range(before):
423             dst_contents += str(empty_line)
424         for line in split_line(current_line, line_length=line_length, py36=py36):
425             dst_contents += str(line)
426     return dst_contents
427
428
429 GRAMMARS = [
430     pygram.python_grammar_no_print_statement_no_exec_statement,
431     pygram.python_grammar_no_print_statement,
432     pygram.python_grammar,
433 ]
434
435
436 def lib2to3_parse(src_txt: str) -> Node:
437     """Given a string with source, return the lib2to3 Node."""
438     grammar = pygram.python_grammar_no_print_statement
439     if src_txt[-1] != "\n":
440         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
441         src_txt += nl
442     for grammar in GRAMMARS:
443         drv = driver.Driver(grammar, pytree.convert)
444         try:
445             result = drv.parse_string(src_txt, True)
446             break
447
448         except ParseError as pe:
449             lineno, column = pe.context[1]
450             lines = src_txt.splitlines()
451             try:
452                 faulty_line = lines[lineno - 1]
453             except IndexError:
454                 faulty_line = "<line number missing in source>"
455             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
456     else:
457         raise exc from None
458
459     if isinstance(result, Leaf):
460         result = Node(syms.file_input, [result])
461     return result
462
463
464 def lib2to3_unparse(node: Node) -> str:
465     """Given a lib2to3 node, return its string representation."""
466     code = str(node)
467     return code
468
469
470 T = TypeVar("T")
471
472
473 class Visitor(Generic[T]):
474     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
475
476     def visit(self, node: LN) -> Iterator[T]:
477         """Main method to visit `node` and its children.
478
479         It tries to find a `visit_*()` method for the given `node.type`, like
480         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
481         If no dedicated `visit_*()` method is found, chooses `visit_default()`
482         instead.
483
484         Then yields objects of type `T` from the selected visitor.
485         """
486         if node.type < 256:
487             name = token.tok_name[node.type]
488         else:
489             name = type_repr(node.type)
490         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
491
492     def visit_default(self, node: LN) -> Iterator[T]:
493         """Default `visit_*()` implementation. Recurses to children of `node`."""
494         if isinstance(node, Node):
495             for child in node.children:
496                 yield from self.visit(child)
497
498
499 @dataclass
500 class DebugVisitor(Visitor[T]):
501     tree_depth: int = 0
502
503     def visit_default(self, node: LN) -> Iterator[T]:
504         indent = " " * (2 * self.tree_depth)
505         if isinstance(node, Node):
506             _type = type_repr(node.type)
507             out(f"{indent}{_type}", fg="yellow")
508             self.tree_depth += 1
509             for child in node.children:
510                 yield from self.visit(child)
511
512             self.tree_depth -= 1
513             out(f"{indent}/{_type}", fg="yellow", bold=False)
514         else:
515             _type = token.tok_name.get(node.type, str(node.type))
516             out(f"{indent}{_type}", fg="blue", nl=False)
517             if node.prefix:
518                 # We don't have to handle prefixes for `Node` objects since
519                 # that delegates to the first child anyway.
520                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
521             out(f" {node.value!r}", fg="blue", bold=False)
522
523     @classmethod
524     def show(cls, code: str) -> None:
525         """Pretty-print the lib2to3 AST of a given string of `code`.
526
527         Convenience method for debugging.
528         """
529         v: DebugVisitor[None] = DebugVisitor()
530         list(v.visit(lib2to3_parse(code)))
531
532
533 KEYWORDS = set(keyword.kwlist)
534 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
535 FLOW_CONTROL = {"return", "raise", "break", "continue"}
536 STATEMENT = {
537     syms.if_stmt,
538     syms.while_stmt,
539     syms.for_stmt,
540     syms.try_stmt,
541     syms.except_clause,
542     syms.with_stmt,
543     syms.funcdef,
544     syms.classdef,
545 }
546 STANDALONE_COMMENT = 153
547 LOGIC_OPERATORS = {"and", "or"}
548 COMPARATORS = {
549     token.LESS,
550     token.GREATER,
551     token.EQEQUAL,
552     token.NOTEQUAL,
553     token.LESSEQUAL,
554     token.GREATEREQUAL,
555 }
556 MATH_OPERATORS = {
557     token.VBAR,
558     token.CIRCUMFLEX,
559     token.AMPER,
560     token.LEFTSHIFT,
561     token.RIGHTSHIFT,
562     token.PLUS,
563     token.MINUS,
564     token.STAR,
565     token.SLASH,
566     token.DOUBLESLASH,
567     token.PERCENT,
568     token.AT,
569     token.TILDE,
570     token.DOUBLESTAR,
571 }
572 STARS = {token.STAR, token.DOUBLESTAR}
573 VARARGS_PARENTS = {
574     syms.arglist,
575     syms.argument,  # double star in arglist
576     syms.trailer,  # single argument to call
577     syms.typedargslist,
578     syms.varargslist,  # lambdas
579 }
580 UNPACKING_PARENTS = {
581     syms.atom,  # single element of a list or set literal
582     syms.dictsetmaker,
583     syms.listmaker,
584     syms.testlist_gexp,
585 }
586 TEST_DESCENDANTS = {
587     syms.test,
588     syms.lambdef,
589     syms.or_test,
590     syms.and_test,
591     syms.not_test,
592     syms.comparison,
593     syms.star_expr,
594     syms.expr,
595     syms.xor_expr,
596     syms.and_expr,
597     syms.shift_expr,
598     syms.arith_expr,
599     syms.trailer,
600     syms.term,
601     syms.power,
602 }
603 ASSIGNMENTS = {
604     "=",
605     "+=",
606     "-=",
607     "*=",
608     "@=",
609     "/=",
610     "%=",
611     "&=",
612     "|=",
613     "^=",
614     "<<=",
615     ">>=",
616     "**=",
617     "//=",
618 }
619 COMPREHENSION_PRIORITY = 20
620 COMMA_PRIORITY = 18
621 TERNARY_PRIORITY = 16
622 LOGIC_PRIORITY = 14
623 STRING_PRIORITY = 12
624 COMPARATOR_PRIORITY = 10
625 MATH_PRIORITIES = {
626     token.VBAR: 8,
627     token.CIRCUMFLEX: 7,
628     token.AMPER: 6,
629     token.LEFTSHIFT: 5,
630     token.RIGHTSHIFT: 5,
631     token.PLUS: 4,
632     token.MINUS: 4,
633     token.STAR: 3,
634     token.SLASH: 3,
635     token.DOUBLESLASH: 3,
636     token.PERCENT: 3,
637     token.AT: 3,
638     token.TILDE: 2,
639     token.DOUBLESTAR: 1,
640 }
641
642
643 @dataclass
644 class BracketTracker:
645     """Keeps track of brackets on a line."""
646
647     depth: int = 0
648     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
649     delimiters: Dict[LeafID, Priority] = Factory(dict)
650     previous: Optional[Leaf] = None
651     _for_loop_variable: int = 0
652     _lambda_arguments: int = 0
653
654     def mark(self, leaf: Leaf) -> None:
655         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
656
657         All leaves receive an int `bracket_depth` field that stores how deep
658         within brackets a given leaf is. 0 means there are no enclosing brackets
659         that started on this line.
660
661         If a leaf is itself a closing bracket, it receives an `opening_bracket`
662         field that it forms a pair with. This is a one-directional link to
663         avoid reference cycles.
664
665         If a leaf is a delimiter (a token on which Black can split the line if
666         needed) and it's on depth 0, its `id()` is stored in the tracker's
667         `delimiters` field.
668         """
669         if leaf.type == token.COMMENT:
670             return
671
672         self.maybe_decrement_after_for_loop_variable(leaf)
673         self.maybe_decrement_after_lambda_arguments(leaf)
674         if leaf.type in CLOSING_BRACKETS:
675             self.depth -= 1
676             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
677             leaf.opening_bracket = opening_bracket
678         leaf.bracket_depth = self.depth
679         if self.depth == 0:
680             delim = is_split_before_delimiter(leaf, self.previous)
681             if delim and self.previous is not None:
682                 self.delimiters[id(self.previous)] = delim
683             else:
684                 delim = is_split_after_delimiter(leaf, self.previous)
685                 if delim:
686                     self.delimiters[id(leaf)] = delim
687         if leaf.type in OPENING_BRACKETS:
688             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
689             self.depth += 1
690         self.previous = leaf
691         self.maybe_increment_lambda_arguments(leaf)
692         self.maybe_increment_for_loop_variable(leaf)
693
694     def any_open_brackets(self) -> bool:
695         """Return True if there is an yet unmatched open bracket on the line."""
696         return bool(self.bracket_match)
697
698     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
699         """Return the highest priority of a delimiter found on the line.
700
701         Values are consistent with what `is_split_*_delimiter()` return.
702         Raises ValueError on no delimiters.
703         """
704         return max(v for k, v in self.delimiters.items() if k not in exclude)
705
706     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
707         """In a for loop, or comprehension, the variables are often unpacks.
708
709         To avoid splitting on the comma in this situation, increase the depth of
710         tokens between `for` and `in`.
711         """
712         if leaf.type == token.NAME and leaf.value == "for":
713             self.depth += 1
714             self._for_loop_variable += 1
715             return True
716
717         return False
718
719     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
720         """See `maybe_increment_for_loop_variable` above for explanation."""
721         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
722             self.depth -= 1
723             self._for_loop_variable -= 1
724             return True
725
726         return False
727
728     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
729         """In a lambda expression, there might be more than one argument.
730
731         To avoid splitting on the comma in this situation, increase the depth of
732         tokens between `lambda` and `:`.
733         """
734         if leaf.type == token.NAME and leaf.value == "lambda":
735             self.depth += 1
736             self._lambda_arguments += 1
737             return True
738
739         return False
740
741     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
742         """See `maybe_increment_lambda_arguments` above for explanation."""
743         if self._lambda_arguments and leaf.type == token.COLON:
744             self.depth -= 1
745             self._lambda_arguments -= 1
746             return True
747
748         return False
749
750     def get_open_lsqb(self) -> Optional[Leaf]:
751         """Return the most recent opening square bracket (if any)."""
752         return self.bracket_match.get((self.depth - 1, token.RSQB))
753
754
755 @dataclass
756 class Line:
757     """Holds leaves and comments. Can be printed with `str(line)`."""
758
759     depth: int = 0
760     leaves: List[Leaf] = Factory(list)
761     comments: List[Tuple[Index, Leaf]] = Factory(list)
762     bracket_tracker: BracketTracker = Factory(BracketTracker)
763     inside_brackets: bool = False
764
765     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
766         """Add a new `leaf` to the end of the line.
767
768         Unless `preformatted` is True, the `leaf` will receive a new consistent
769         whitespace prefix and metadata applied by :class:`BracketTracker`.
770         Trailing commas are maybe removed, unpacked for loop variables are
771         demoted from being delimiters.
772
773         Inline comments are put aside.
774         """
775         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
776         if not has_value:
777             return
778
779         if token.COLON == leaf.type and self.is_class_paren_empty:
780             del self.leaves[-2:]
781         if self.leaves and not preformatted:
782             # Note: at this point leaf.prefix should be empty except for
783             # imports, for which we only preserve newlines.
784             leaf.prefix += whitespace(
785                 leaf, complex_subscript=self.is_complex_subscript(leaf)
786             )
787         if self.inside_brackets or not preformatted:
788             self.bracket_tracker.mark(leaf)
789             self.maybe_remove_trailing_comma(leaf)
790         if not self.append_comment(leaf):
791             self.leaves.append(leaf)
792
793     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
794         """Like :func:`append()` but disallow invalid standalone comment structure.
795
796         Raises ValueError when any `leaf` is appended after a standalone comment
797         or when a standalone comment is not the first leaf on the line.
798         """
799         if self.bracket_tracker.depth == 0:
800             if self.is_comment:
801                 raise ValueError("cannot append to standalone comments")
802
803             if self.leaves and leaf.type == STANDALONE_COMMENT:
804                 raise ValueError(
805                     "cannot append standalone comments to a populated line"
806                 )
807
808         self.append(leaf, preformatted=preformatted)
809
810     @property
811     def is_comment(self) -> bool:
812         """Is this line a standalone comment?"""
813         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
814
815     @property
816     def is_decorator(self) -> bool:
817         """Is this line a decorator?"""
818         return bool(self) and self.leaves[0].type == token.AT
819
820     @property
821     def is_import(self) -> bool:
822         """Is this an import line?"""
823         return bool(self) and is_import(self.leaves[0])
824
825     @property
826     def is_class(self) -> bool:
827         """Is this line a class definition?"""
828         return (
829             bool(self)
830             and self.leaves[0].type == token.NAME
831             and self.leaves[0].value == "class"
832         )
833
834     @property
835     def is_def(self) -> bool:
836         """Is this a function definition? (Also returns True for async defs.)"""
837         try:
838             first_leaf = self.leaves[0]
839         except IndexError:
840             return False
841
842         try:
843             second_leaf: Optional[Leaf] = self.leaves[1]
844         except IndexError:
845             second_leaf = None
846         return (
847             (first_leaf.type == token.NAME and first_leaf.value == "def")
848             or (
849                 first_leaf.type == token.ASYNC
850                 and second_leaf is not None
851                 and second_leaf.type == token.NAME
852                 and second_leaf.value == "def"
853             )
854         )
855
856     @property
857     def is_flow_control(self) -> bool:
858         """Is this line a flow control statement?
859
860         Those are `return`, `raise`, `break`, and `continue`.
861         """
862         return (
863             bool(self)
864             and self.leaves[0].type == token.NAME
865             and self.leaves[0].value in FLOW_CONTROL
866         )
867
868     @property
869     def is_yield(self) -> bool:
870         """Is this line a yield statement?"""
871         return (
872             bool(self)
873             and self.leaves[0].type == token.NAME
874             and self.leaves[0].value == "yield"
875         )
876
877     @property
878     def is_class_paren_empty(self) -> bool:
879         """Is this a class with no base classes but using parentheses?
880
881         Those are unnecessary and should be removed.
882         """
883         return (
884             bool(self)
885             and len(self.leaves) == 4
886             and self.is_class
887             and self.leaves[2].type == token.LPAR
888             and self.leaves[2].value == "("
889             and self.leaves[3].type == token.RPAR
890             and self.leaves[3].value == ")"
891         )
892
893     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
894         """If so, needs to be split before emitting."""
895         for leaf in self.leaves:
896             if leaf.type == STANDALONE_COMMENT:
897                 if leaf.bracket_depth <= depth_limit:
898                     return True
899
900         return False
901
902     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
903         """Remove trailing comma if there is one and it's safe."""
904         if not (
905             self.leaves
906             and self.leaves[-1].type == token.COMMA
907             and closing.type in CLOSING_BRACKETS
908         ):
909             return False
910
911         if closing.type == token.RBRACE:
912             self.remove_trailing_comma()
913             return True
914
915         if closing.type == token.RSQB:
916             comma = self.leaves[-1]
917             if comma.parent and comma.parent.type == syms.listmaker:
918                 self.remove_trailing_comma()
919                 return True
920
921         # For parens let's check if it's safe to remove the comma.
922         # Imports are always safe.
923         if self.is_import:
924             self.remove_trailing_comma()
925             return True
926
927         # Otheriwsse, if the trailing one is the only one, we might mistakenly
928         # change a tuple into a different type by removing the comma.
929         depth = closing.bracket_depth + 1
930         commas = 0
931         opening = closing.opening_bracket
932         for _opening_index, leaf in enumerate(self.leaves):
933             if leaf is opening:
934                 break
935
936         else:
937             return False
938
939         for leaf in self.leaves[_opening_index + 1 :]:
940             if leaf is closing:
941                 break
942
943             bracket_depth = leaf.bracket_depth
944             if bracket_depth == depth and leaf.type == token.COMMA:
945                 commas += 1
946                 if leaf.parent and leaf.parent.type == syms.arglist:
947                     commas += 1
948                     break
949
950         if commas > 1:
951             self.remove_trailing_comma()
952             return True
953
954         return False
955
956     def append_comment(self, comment: Leaf) -> bool:
957         """Add an inline or standalone comment to the line."""
958         if (
959             comment.type == STANDALONE_COMMENT
960             and self.bracket_tracker.any_open_brackets()
961         ):
962             comment.prefix = ""
963             return False
964
965         if comment.type != token.COMMENT:
966             return False
967
968         after = len(self.leaves) - 1
969         if after == -1:
970             comment.type = STANDALONE_COMMENT
971             comment.prefix = ""
972             return False
973
974         else:
975             self.comments.append((after, comment))
976             return True
977
978     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
979         """Generate comments that should appear directly after `leaf`.
980
981         Provide a non-negative leaf `_index` to speed up the function.
982         """
983         if _index == -1:
984             for _index, _leaf in enumerate(self.leaves):
985                 if leaf is _leaf:
986                     break
987
988             else:
989                 return
990
991         for index, comment_after in self.comments:
992             if _index == index:
993                 yield comment_after
994
995     def remove_trailing_comma(self) -> None:
996         """Remove the trailing comma and moves the comments attached to it."""
997         comma_index = len(self.leaves) - 1
998         for i in range(len(self.comments)):
999             comment_index, comment = self.comments[i]
1000             if comment_index == comma_index:
1001                 self.comments[i] = (comma_index - 1, comment)
1002         self.leaves.pop()
1003
1004     def is_complex_subscript(self, leaf: Leaf) -> bool:
1005         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1006         open_lsqb = (
1007             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1008         )
1009         if open_lsqb is None:
1010             return False
1011
1012         subscript_start = open_lsqb.next_sibling
1013         if (
1014             isinstance(subscript_start, Node)
1015             and subscript_start.type == syms.subscriptlist
1016         ):
1017             subscript_start = child_towards(subscript_start, leaf)
1018         return (
1019             subscript_start is not None
1020             and any(n.type in TEST_DESCENDANTS for n in subscript_start.pre_order())
1021         )
1022
1023     def __str__(self) -> str:
1024         """Render the line."""
1025         if not self:
1026             return "\n"
1027
1028         indent = "    " * self.depth
1029         leaves = iter(self.leaves)
1030         first = next(leaves)
1031         res = f"{first.prefix}{indent}{first.value}"
1032         for leaf in leaves:
1033             res += str(leaf)
1034         for _, comment in self.comments:
1035             res += str(comment)
1036         return res + "\n"
1037
1038     def __bool__(self) -> bool:
1039         """Return True if the line has leaves or comments."""
1040         return bool(self.leaves or self.comments)
1041
1042
1043 class UnformattedLines(Line):
1044     """Just like :class:`Line` but stores lines which aren't reformatted."""
1045
1046     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1047         """Just add a new `leaf` to the end of the lines.
1048
1049         The `preformatted` argument is ignored.
1050
1051         Keeps track of indentation `depth`, which is useful when the user
1052         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1053         """
1054         try:
1055             list(generate_comments(leaf))
1056         except FormatOn as f_on:
1057             self.leaves.append(f_on.leaf_from_consumed(leaf))
1058             raise
1059
1060         self.leaves.append(leaf)
1061         if leaf.type == token.INDENT:
1062             self.depth += 1
1063         elif leaf.type == token.DEDENT:
1064             self.depth -= 1
1065
1066     def __str__(self) -> str:
1067         """Render unformatted lines from leaves which were added with `append()`.
1068
1069         `depth` is not used for indentation in this case.
1070         """
1071         if not self:
1072             return "\n"
1073
1074         res = ""
1075         for leaf in self.leaves:
1076             res += str(leaf)
1077         return res
1078
1079     def append_comment(self, comment: Leaf) -> bool:
1080         """Not implemented in this class. Raises `NotImplementedError`."""
1081         raise NotImplementedError("Unformatted lines don't store comments separately.")
1082
1083     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1084         """Does nothing and returns False."""
1085         return False
1086
1087     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1088         """Does nothing and returns False."""
1089         return False
1090
1091
1092 @dataclass
1093 class EmptyLineTracker:
1094     """Provides a stateful method that returns the number of potential extra
1095     empty lines needed before and after the currently processed line.
1096
1097     Note: this tracker works on lines that haven't been split yet.  It assumes
1098     the prefix of the first leaf consists of optional newlines.  Those newlines
1099     are consumed by `maybe_empty_lines()` and included in the computation.
1100     """
1101     previous_line: Optional[Line] = None
1102     previous_after: int = 0
1103     previous_defs: List[int] = Factory(list)
1104
1105     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1106         """Return the number of extra empty lines before and after the `current_line`.
1107
1108         This is for separating `def`, `async def` and `class` with extra empty
1109         lines (two on module-level), as well as providing an extra empty line
1110         after flow control keywords to make them more prominent.
1111         """
1112         if isinstance(current_line, UnformattedLines):
1113             return 0, 0
1114
1115         before, after = self._maybe_empty_lines(current_line)
1116         before -= self.previous_after
1117         self.previous_after = after
1118         self.previous_line = current_line
1119         return before, after
1120
1121     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1122         max_allowed = 1
1123         if current_line.depth == 0:
1124             max_allowed = 2
1125         if current_line.leaves:
1126             # Consume the first leaf's extra newlines.
1127             first_leaf = current_line.leaves[0]
1128             before = first_leaf.prefix.count("\n")
1129             before = min(before, max_allowed)
1130             first_leaf.prefix = ""
1131         else:
1132             before = 0
1133         depth = current_line.depth
1134         while self.previous_defs and self.previous_defs[-1] >= depth:
1135             self.previous_defs.pop()
1136             before = 1 if depth else 2
1137         is_decorator = current_line.is_decorator
1138         if is_decorator or current_line.is_def or current_line.is_class:
1139             if not is_decorator:
1140                 self.previous_defs.append(depth)
1141             if self.previous_line is None:
1142                 # Don't insert empty lines before the first line in the file.
1143                 return 0, 0
1144
1145             if self.previous_line.is_decorator:
1146                 return 0, 0
1147
1148             if (
1149                 self.previous_line.is_comment
1150                 and self.previous_line.depth == current_line.depth
1151                 and before == 0
1152             ):
1153                 return 0, 0
1154
1155             newlines = 2
1156             if current_line.depth:
1157                 newlines -= 1
1158             return newlines, 0
1159
1160         if (
1161             self.previous_line
1162             and self.previous_line.is_import
1163             and not current_line.is_import
1164             and depth == self.previous_line.depth
1165         ):
1166             return (before or 1), 0
1167
1168         return before, 0
1169
1170
1171 @dataclass
1172 class LineGenerator(Visitor[Line]):
1173     """Generates reformatted Line objects.  Empty lines are not emitted.
1174
1175     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1176     in ways that will no longer stringify to valid Python code on the tree.
1177     """
1178     current_line: Line = Factory(Line)
1179     remove_u_prefix: bool = False
1180
1181     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1182         """Generate a line.
1183
1184         If the line is empty, only emit if it makes sense.
1185         If the line is too long, split it first and then generate.
1186
1187         If any lines were generated, set up a new current_line.
1188         """
1189         if not self.current_line:
1190             if self.current_line.__class__ == type:
1191                 self.current_line.depth += indent
1192             else:
1193                 self.current_line = type(depth=self.current_line.depth + indent)
1194             return  # Line is empty, don't emit. Creating a new one unnecessary.
1195
1196         complete_line = self.current_line
1197         self.current_line = type(depth=complete_line.depth + indent)
1198         yield complete_line
1199
1200     def visit(self, node: LN) -> Iterator[Line]:
1201         """Main method to visit `node` and its children.
1202
1203         Yields :class:`Line` objects.
1204         """
1205         if isinstance(self.current_line, UnformattedLines):
1206             # File contained `# fmt: off`
1207             yield from self.visit_unformatted(node)
1208
1209         else:
1210             yield from super().visit(node)
1211
1212     def visit_default(self, node: LN) -> Iterator[Line]:
1213         """Default `visit_*()` implementation. Recurses to children of `node`."""
1214         if isinstance(node, Leaf):
1215             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1216             try:
1217                 for comment in generate_comments(node):
1218                     if any_open_brackets:
1219                         # any comment within brackets is subject to splitting
1220                         self.current_line.append(comment)
1221                     elif comment.type == token.COMMENT:
1222                         # regular trailing comment
1223                         self.current_line.append(comment)
1224                         yield from self.line()
1225
1226                     else:
1227                         # regular standalone comment
1228                         yield from self.line()
1229
1230                         self.current_line.append(comment)
1231                         yield from self.line()
1232
1233             except FormatOff as f_off:
1234                 f_off.trim_prefix(node)
1235                 yield from self.line(type=UnformattedLines)
1236                 yield from self.visit(node)
1237
1238             except FormatOn as f_on:
1239                 # This only happens here if somebody says "fmt: on" multiple
1240                 # times in a row.
1241                 f_on.trim_prefix(node)
1242                 yield from self.visit_default(node)
1243
1244             else:
1245                 normalize_prefix(node, inside_brackets=any_open_brackets)
1246                 if node.type == token.STRING:
1247                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1248                     normalize_string_quotes(node)
1249                 if node.type not in WHITESPACE:
1250                     self.current_line.append(node)
1251         yield from super().visit_default(node)
1252
1253     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1254         """Increase indentation level, maybe yield a line."""
1255         # In blib2to3 INDENT never holds comments.
1256         yield from self.line(+1)
1257         yield from self.visit_default(node)
1258
1259     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1260         """Decrease indentation level, maybe yield a line."""
1261         # The current line might still wait for trailing comments.  At DEDENT time
1262         # there won't be any (they would be prefixes on the preceding NEWLINE).
1263         # Emit the line then.
1264         yield from self.line()
1265
1266         # While DEDENT has no value, its prefix may contain standalone comments
1267         # that belong to the current indentation level.  Get 'em.
1268         yield from self.visit_default(node)
1269
1270         # Finally, emit the dedent.
1271         yield from self.line(-1)
1272
1273     def visit_stmt(
1274         self, node: Node, keywords: Set[str], parens: Set[str]
1275     ) -> Iterator[Line]:
1276         """Visit a statement.
1277
1278         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1279         `def`, `with`, `class`, `assert` and assignments.
1280
1281         The relevant Python language `keywords` for a given statement will be
1282         NAME leaves within it. This methods puts those on a separate line.
1283
1284         `parens` holds a set of string leaf values immeditely after which
1285         invisible parens should be put.
1286         """
1287         normalize_invisible_parens(node, parens_after=parens)
1288         for child in node.children:
1289             if child.type == token.NAME and child.value in keywords:  # type: ignore
1290                 yield from self.line()
1291
1292             yield from self.visit(child)
1293
1294     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1295         """Visit a statement without nested statements."""
1296         is_suite_like = node.parent and node.parent.type in STATEMENT
1297         if is_suite_like:
1298             yield from self.line(+1)
1299             yield from self.visit_default(node)
1300             yield from self.line(-1)
1301
1302         else:
1303             yield from self.line()
1304             yield from self.visit_default(node)
1305
1306     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1307         """Visit `async def`, `async for`, `async with`."""
1308         yield from self.line()
1309
1310         children = iter(node.children)
1311         for child in children:
1312             yield from self.visit(child)
1313
1314             if child.type == token.ASYNC:
1315                 break
1316
1317         internal_stmt = next(children)
1318         for child in internal_stmt.children:
1319             yield from self.visit(child)
1320
1321     def visit_decorators(self, node: Node) -> Iterator[Line]:
1322         """Visit decorators."""
1323         for child in node.children:
1324             yield from self.line()
1325             yield from self.visit(child)
1326
1327     def visit_import_from(self, node: Node) -> Iterator[Line]:
1328         """Visit import_from and maybe put invisible parentheses.
1329
1330         This is separate from `visit_stmt` because import statements don't
1331         support arbitrary atoms and thus handling of parentheses is custom.
1332         """
1333         check_lpar = False
1334         for index, child in enumerate(node.children):
1335             if check_lpar:
1336                 if child.type == token.LPAR:
1337                     # make parentheses invisible
1338                     child.value = ""  # type: ignore
1339                     node.children[-1].value = ""  # type: ignore
1340                 else:
1341                     # insert invisible parentheses
1342                     node.insert_child(index, Leaf(token.LPAR, ""))
1343                     node.append_child(Leaf(token.RPAR, ""))
1344                 break
1345
1346             check_lpar = (
1347                 child.type == token.NAME and child.value == "import"  # type: ignore
1348             )
1349
1350         for child in node.children:
1351             yield from self.visit(child)
1352
1353     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1354         """Remove a semicolon and put the other statement on a separate line."""
1355         yield from self.line()
1356
1357     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1358         """End of file. Process outstanding comments and end with a newline."""
1359         yield from self.visit_default(leaf)
1360         yield from self.line()
1361
1362     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1363         """Used when file contained a `# fmt: off`."""
1364         if isinstance(node, Node):
1365             for child in node.children:
1366                 yield from self.visit(child)
1367
1368         else:
1369             try:
1370                 self.current_line.append(node)
1371             except FormatOn as f_on:
1372                 f_on.trim_prefix(node)
1373                 yield from self.line()
1374                 yield from self.visit(node)
1375
1376             if node.type == token.ENDMARKER:
1377                 # somebody decided not to put a final `# fmt: on`
1378                 yield from self.line()
1379
1380     def __attrs_post_init__(self) -> None:
1381         """You are in a twisty little maze of passages."""
1382         v = self.visit_stmt
1383         Ø: Set[str] = set()
1384         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1385         self.visit_if_stmt = partial(
1386             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1387         )
1388         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1389         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1390         self.visit_try_stmt = partial(
1391             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1392         )
1393         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1394         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1395         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1396         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1397         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1398         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1399         self.visit_async_funcdef = self.visit_async_stmt
1400         self.visit_decorated = self.visit_decorators
1401
1402
1403 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1404 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1405 OPENING_BRACKETS = set(BRACKET.keys())
1406 CLOSING_BRACKETS = set(BRACKET.values())
1407 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1408 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1409
1410
1411 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1412     """Return whitespace prefix if needed for the given `leaf`.
1413
1414     `complex_subscript` signals whether the given leaf is part of a subscription
1415     which has non-trivial arguments, like arithmetic expressions or function calls.
1416     """
1417     NO = ""
1418     SPACE = " "
1419     DOUBLESPACE = "  "
1420     t = leaf.type
1421     p = leaf.parent
1422     v = leaf.value
1423     if t in ALWAYS_NO_SPACE:
1424         return NO
1425
1426     if t == token.COMMENT:
1427         return DOUBLESPACE
1428
1429     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1430     if (
1431         t == token.COLON
1432         and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop}
1433     ):
1434         return NO
1435
1436     prev = leaf.prev_sibling
1437     if not prev:
1438         prevp = preceding_leaf(p)
1439         if not prevp or prevp.type in OPENING_BRACKETS:
1440             return NO
1441
1442         if t == token.COLON:
1443             if prevp.type == token.COLON:
1444                 return NO
1445
1446             elif prevp.type != token.COMMA and not complex_subscript:
1447                 return NO
1448
1449             return SPACE
1450
1451         if prevp.type == token.EQUAL:
1452             if prevp.parent:
1453                 if prevp.parent.type in {
1454                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1455                 }:
1456                     return NO
1457
1458                 elif prevp.parent.type == syms.typedargslist:
1459                     # A bit hacky: if the equal sign has whitespace, it means we
1460                     # previously found it's a typed argument.  So, we're using
1461                     # that, too.
1462                     return prevp.prefix
1463
1464         elif prevp.type in STARS:
1465             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1466                 return NO
1467
1468         elif prevp.type == token.COLON:
1469             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1470                 return SPACE if complex_subscript else NO
1471
1472         elif (
1473             prevp.parent
1474             and prevp.parent.type == syms.factor
1475             and prevp.type in MATH_OPERATORS
1476         ):
1477             return NO
1478
1479         elif (
1480             prevp.type == token.RIGHTSHIFT
1481             and prevp.parent
1482             and prevp.parent.type == syms.shift_expr
1483             and prevp.prev_sibling
1484             and prevp.prev_sibling.type == token.NAME
1485             and prevp.prev_sibling.value == "print"  # type: ignore
1486         ):
1487             # Python 2 print chevron
1488             return NO
1489
1490     elif prev.type in OPENING_BRACKETS:
1491         return NO
1492
1493     if p.type in {syms.parameters, syms.arglist}:
1494         # untyped function signatures or calls
1495         if not prev or prev.type != token.COMMA:
1496             return NO
1497
1498     elif p.type == syms.varargslist:
1499         # lambdas
1500         if prev and prev.type != token.COMMA:
1501             return NO
1502
1503     elif p.type == syms.typedargslist:
1504         # typed function signatures
1505         if not prev:
1506             return NO
1507
1508         if t == token.EQUAL:
1509             if prev.type != syms.tname:
1510                 return NO
1511
1512         elif prev.type == token.EQUAL:
1513             # A bit hacky: if the equal sign has whitespace, it means we
1514             # previously found it's a typed argument.  So, we're using that, too.
1515             return prev.prefix
1516
1517         elif prev.type != token.COMMA:
1518             return NO
1519
1520     elif p.type == syms.tname:
1521         # type names
1522         if not prev:
1523             prevp = preceding_leaf(p)
1524             if not prevp or prevp.type != token.COMMA:
1525                 return NO
1526
1527     elif p.type == syms.trailer:
1528         # attributes and calls
1529         if t == token.LPAR or t == token.RPAR:
1530             return NO
1531
1532         if not prev:
1533             if t == token.DOT:
1534                 prevp = preceding_leaf(p)
1535                 if not prevp or prevp.type != token.NUMBER:
1536                     return NO
1537
1538             elif t == token.LSQB:
1539                 return NO
1540
1541         elif prev.type != token.COMMA:
1542             return NO
1543
1544     elif p.type == syms.argument:
1545         # single argument
1546         if t == token.EQUAL:
1547             return NO
1548
1549         if not prev:
1550             prevp = preceding_leaf(p)
1551             if not prevp or prevp.type == token.LPAR:
1552                 return NO
1553
1554         elif prev.type in {token.EQUAL} | STARS:
1555             return NO
1556
1557     elif p.type == syms.decorator:
1558         # decorators
1559         return NO
1560
1561     elif p.type == syms.dotted_name:
1562         if prev:
1563             return NO
1564
1565         prevp = preceding_leaf(p)
1566         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1567             return NO
1568
1569     elif p.type == syms.classdef:
1570         if t == token.LPAR:
1571             return NO
1572
1573         if prev and prev.type == token.LPAR:
1574             return NO
1575
1576     elif p.type in {syms.subscript, syms.sliceop}:
1577         # indexing
1578         if not prev:
1579             assert p.parent is not None, "subscripts are always parented"
1580             if p.parent.type == syms.subscriptlist:
1581                 return SPACE
1582
1583             return NO
1584
1585         elif not complex_subscript:
1586             return NO
1587
1588     elif p.type == syms.atom:
1589         if prev and t == token.DOT:
1590             # dots, but not the first one.
1591             return NO
1592
1593     elif p.type == syms.dictsetmaker:
1594         # dict unpacking
1595         if prev and prev.type == token.DOUBLESTAR:
1596             return NO
1597
1598     elif p.type in {syms.factor, syms.star_expr}:
1599         # unary ops
1600         if not prev:
1601             prevp = preceding_leaf(p)
1602             if not prevp or prevp.type in OPENING_BRACKETS:
1603                 return NO
1604
1605             prevp_parent = prevp.parent
1606             assert prevp_parent is not None
1607             if (
1608                 prevp.type == token.COLON
1609                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1610             ):
1611                 return NO
1612
1613             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1614                 return NO
1615
1616         elif t == token.NAME or t == token.NUMBER:
1617             return NO
1618
1619     elif p.type == syms.import_from:
1620         if t == token.DOT:
1621             if prev and prev.type == token.DOT:
1622                 return NO
1623
1624         elif t == token.NAME:
1625             if v == "import":
1626                 return SPACE
1627
1628             if prev and prev.type == token.DOT:
1629                 return NO
1630
1631     elif p.type == syms.sliceop:
1632         return NO
1633
1634     return SPACE
1635
1636
1637 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1638     """Return the first leaf that precedes `node`, if any."""
1639     while node:
1640         res = node.prev_sibling
1641         if res:
1642             if isinstance(res, Leaf):
1643                 return res
1644
1645             try:
1646                 return list(res.leaves())[-1]
1647
1648             except IndexError:
1649                 return None
1650
1651         node = node.parent
1652     return None
1653
1654
1655 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1656     """Return the child of `ancestor` that contains `descendant`."""
1657     node: Optional[LN] = descendant
1658     while node and node.parent != ancestor:
1659         node = node.parent
1660     return node
1661
1662
1663 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1664     """Return the priority of the `leaf` delimiter, given a line break after it.
1665
1666     The delimiter priorities returned here are from those delimiters that would
1667     cause a line break after themselves.
1668
1669     Higher numbers are higher priority.
1670     """
1671     if leaf.type == token.COMMA:
1672         return COMMA_PRIORITY
1673
1674     return 0
1675
1676
1677 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1678     """Return the priority of the `leaf` delimiter, given a line before after it.
1679
1680     The delimiter priorities returned here are from those delimiters that would
1681     cause a line break before themselves.
1682
1683     Higher numbers are higher priority.
1684     """
1685     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1686         # * and ** might also be MATH_OPERATORS but in this case they are not.
1687         # Don't treat them as a delimiter.
1688         return 0
1689
1690     if (
1691         leaf.type in MATH_OPERATORS
1692         and leaf.parent
1693         and leaf.parent.type not in {syms.factor, syms.star_expr}
1694     ):
1695         return MATH_PRIORITIES[leaf.type]
1696
1697     if leaf.type in COMPARATORS:
1698         return COMPARATOR_PRIORITY
1699
1700     if (
1701         leaf.type == token.STRING
1702         and previous is not None
1703         and previous.type == token.STRING
1704     ):
1705         return STRING_PRIORITY
1706
1707     if (
1708         leaf.type == token.NAME
1709         and leaf.value == "for"
1710         and leaf.parent
1711         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1712     ):
1713         return COMPREHENSION_PRIORITY
1714
1715     if (
1716         leaf.type == token.NAME
1717         and leaf.value == "if"
1718         and leaf.parent
1719         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1720     ):
1721         return COMPREHENSION_PRIORITY
1722
1723     if (
1724         leaf.type == token.NAME
1725         and leaf.value in {"if", "else"}
1726         and leaf.parent
1727         and leaf.parent.type == syms.test
1728     ):
1729         return TERNARY_PRIORITY
1730
1731     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1732         return LOGIC_PRIORITY
1733
1734     return 0
1735
1736
1737 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1738     """Clean the prefix of the `leaf` and generate comments from it, if any.
1739
1740     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1741     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1742     move because it does away with modifying the grammar to include all the
1743     possible places in which comments can be placed.
1744
1745     The sad consequence for us though is that comments don't "belong" anywhere.
1746     This is why this function generates simple parentless Leaf objects for
1747     comments.  We simply don't know what the correct parent should be.
1748
1749     No matter though, we can live without this.  We really only need to
1750     differentiate between inline and standalone comments.  The latter don't
1751     share the line with any code.
1752
1753     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1754     are emitted with a fake STANDALONE_COMMENT token identifier.
1755     """
1756     p = leaf.prefix
1757     if not p:
1758         return
1759
1760     if "#" not in p:
1761         return
1762
1763     consumed = 0
1764     nlines = 0
1765     for index, line in enumerate(p.split("\n")):
1766         consumed += len(line) + 1  # adding the length of the split '\n'
1767         line = line.lstrip()
1768         if not line:
1769             nlines += 1
1770         if not line.startswith("#"):
1771             continue
1772
1773         if index == 0 and leaf.type != token.ENDMARKER:
1774             comment_type = token.COMMENT  # simple trailing comment
1775         else:
1776             comment_type = STANDALONE_COMMENT
1777         comment = make_comment(line)
1778         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1779
1780         if comment in {"# fmt: on", "# yapf: enable"}:
1781             raise FormatOn(consumed)
1782
1783         if comment in {"# fmt: off", "# yapf: disable"}:
1784             if comment_type == STANDALONE_COMMENT:
1785                 raise FormatOff(consumed)
1786
1787             prev = preceding_leaf(leaf)
1788             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1789                 raise FormatOff(consumed)
1790
1791         nlines = 0
1792
1793
1794 def make_comment(content: str) -> str:
1795     """Return a consistently formatted comment from the given `content` string.
1796
1797     All comments (except for "##", "#!", "#:") should have a single space between
1798     the hash sign and the content.
1799
1800     If `content` didn't start with a hash sign, one is provided.
1801     """
1802     content = content.rstrip()
1803     if not content:
1804         return "#"
1805
1806     if content[0] == "#":
1807         content = content[1:]
1808     if content and content[0] not in " !:#":
1809         content = " " + content
1810     return "#" + content
1811
1812
1813 def split_line(
1814     line: Line, line_length: int, inner: bool = False, py36: bool = False
1815 ) -> Iterator[Line]:
1816     """Split a `line` into potentially many lines.
1817
1818     They should fit in the allotted `line_length` but might not be able to.
1819     `inner` signifies that there were a pair of brackets somewhere around the
1820     current `line`, possibly transitively. This means we can fallback to splitting
1821     by delimiters if the LHS/RHS don't yield any results.
1822
1823     If `py36` is True, splitting may generate syntax that is only compatible
1824     with Python 3.6 and later.
1825     """
1826     if isinstance(line, UnformattedLines) or line.is_comment:
1827         yield line
1828         return
1829
1830     line_str = str(line).strip("\n")
1831     if (
1832         len(line_str) <= line_length
1833         and "\n" not in line_str  # multiline strings
1834         and not line.contains_standalone_comments()
1835     ):
1836         yield line
1837         return
1838
1839     split_funcs: List[SplitFunc]
1840     if line.is_def:
1841         split_funcs = [left_hand_split]
1842     elif line.is_import:
1843         split_funcs = [explode_split]
1844     elif line.inside_brackets:
1845         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1846     else:
1847         split_funcs = [right_hand_split]
1848     for split_func in split_funcs:
1849         # We are accumulating lines in `result` because we might want to abort
1850         # mission and return the original line in the end, or attempt a different
1851         # split altogether.
1852         result: List[Line] = []
1853         try:
1854             for l in split_func(line, py36):
1855                 if str(l).strip("\n") == line_str:
1856                     raise CannotSplit("Split function returned an unchanged result")
1857
1858                 result.extend(
1859                     split_line(l, line_length=line_length, inner=True, py36=py36)
1860                 )
1861         except CannotSplit as cs:
1862             continue
1863
1864         else:
1865             yield from result
1866             break
1867
1868     else:
1869         yield line
1870
1871
1872 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1873     """Split line into many lines, starting with the first matching bracket pair.
1874
1875     Note: this usually looks weird, only use this for function definitions.
1876     Prefer RHS otherwise.  This is why this function is not symmetrical with
1877     :func:`right_hand_split` which also handles optional parentheses.
1878     """
1879     head = Line(depth=line.depth)
1880     body = Line(depth=line.depth + 1, inside_brackets=True)
1881     tail = Line(depth=line.depth)
1882     tail_leaves: List[Leaf] = []
1883     body_leaves: List[Leaf] = []
1884     head_leaves: List[Leaf] = []
1885     current_leaves = head_leaves
1886     matching_bracket = None
1887     for leaf in line.leaves:
1888         if (
1889             current_leaves is body_leaves
1890             and leaf.type in CLOSING_BRACKETS
1891             and leaf.opening_bracket is matching_bracket
1892         ):
1893             current_leaves = tail_leaves if body_leaves else head_leaves
1894         current_leaves.append(leaf)
1895         if current_leaves is head_leaves:
1896             if leaf.type in OPENING_BRACKETS:
1897                 matching_bracket = leaf
1898                 current_leaves = body_leaves
1899     # Since body is a new indent level, remove spurious leading whitespace.
1900     if body_leaves:
1901         normalize_prefix(body_leaves[0], inside_brackets=True)
1902     # Build the new lines.
1903     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1904         for leaf in leaves:
1905             result.append(leaf, preformatted=True)
1906             for comment_after in line.comments_after(leaf):
1907                 result.append(comment_after, preformatted=True)
1908     bracket_split_succeeded_or_raise(head, body, tail)
1909     for result in (head, body, tail):
1910         if result:
1911             yield result
1912
1913
1914 def right_hand_split(
1915     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1916 ) -> Iterator[Line]:
1917     """Split line into many lines, starting with the last matching bracket pair.
1918
1919     If the split was by optional parentheses, attempt splitting without them, too.
1920     """
1921     head = Line(depth=line.depth)
1922     body = Line(depth=line.depth + 1, inside_brackets=True)
1923     tail = Line(depth=line.depth)
1924     tail_leaves: List[Leaf] = []
1925     body_leaves: List[Leaf] = []
1926     head_leaves: List[Leaf] = []
1927     current_leaves = tail_leaves
1928     opening_bracket = None
1929     closing_bracket = None
1930     for leaf in reversed(line.leaves):
1931         if current_leaves is body_leaves:
1932             if leaf is opening_bracket:
1933                 current_leaves = head_leaves if body_leaves else tail_leaves
1934         current_leaves.append(leaf)
1935         if current_leaves is tail_leaves:
1936             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1937                 opening_bracket = leaf.opening_bracket
1938                 closing_bracket = leaf
1939                 current_leaves = body_leaves
1940     tail_leaves.reverse()
1941     body_leaves.reverse()
1942     head_leaves.reverse()
1943     # Since body is a new indent level, remove spurious leading whitespace.
1944     if body_leaves:
1945         normalize_prefix(body_leaves[0], inside_brackets=True)
1946     elif not head_leaves:
1947         # No `head` and no `body` means the split failed. `tail` has all content.
1948         raise CannotSplit("No brackets found")
1949
1950     # Build the new lines.
1951     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1952         for leaf in leaves:
1953             result.append(leaf, preformatted=True)
1954             for comment_after in line.comments_after(leaf):
1955                 result.append(comment_after, preformatted=True)
1956     bracket_split_succeeded_or_raise(head, body, tail)
1957     assert opening_bracket and closing_bracket
1958     if (
1959         # the opening bracket is an optional paren
1960         opening_bracket.type == token.LPAR
1961         and not opening_bracket.value
1962         # the closing bracket is an optional paren
1963         and closing_bracket.type == token.RPAR
1964         and not closing_bracket.value
1965         # there are no delimiters or standalone comments in the body
1966         and not body.bracket_tracker.delimiters
1967         and not line.contains_standalone_comments(0)
1968         # and it's not an import (optional parens are the only thing we can split
1969         # on in this case; attempting a split without them is a waste of time)
1970         and not line.is_import
1971     ):
1972         omit = {id(closing_bracket), *omit}
1973         try:
1974             yield from right_hand_split(line, py36=py36, omit=omit)
1975             return
1976         except CannotSplit:
1977             pass
1978
1979     ensure_visible(opening_bracket)
1980     ensure_visible(closing_bracket)
1981     for result in (head, body, tail):
1982         if result:
1983             yield result
1984
1985
1986 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1987     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1988
1989     Do nothing otherwise.
1990
1991     A left- or right-hand split is based on a pair of brackets. Content before
1992     (and including) the opening bracket is left on one line, content inside the
1993     brackets is put on a separate line, and finally content starting with and
1994     following the closing bracket is put on a separate line.
1995
1996     Those are called `head`, `body`, and `tail`, respectively. If the split
1997     produced the same line (all content in `head`) or ended up with an empty `body`
1998     and the `tail` is just the closing bracket, then it's considered failed.
1999     """
2000     tail_len = len(str(tail).strip())
2001     if not body:
2002         if tail_len == 0:
2003             raise CannotSplit("Splitting brackets produced the same line")
2004
2005         elif tail_len < 3:
2006             raise CannotSplit(
2007                 f"Splitting brackets on an empty body to save "
2008                 f"{tail_len} characters is not worth it"
2009             )
2010
2011
2012 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2013     """Normalize prefix of the first leaf in every line returned by `split_func`.
2014
2015     This is a decorator over relevant split functions.
2016     """
2017
2018     @wraps(split_func)
2019     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2020         for l in split_func(line, py36):
2021             normalize_prefix(l.leaves[0], inside_brackets=True)
2022             yield l
2023
2024     return split_wrapper
2025
2026
2027 @dont_increase_indentation
2028 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2029     """Split according to delimiters of the highest priority.
2030
2031     If `py36` is True, the split will add trailing commas also in function
2032     signatures that contain `*` and `**`.
2033     """
2034     try:
2035         last_leaf = line.leaves[-1]
2036     except IndexError:
2037         raise CannotSplit("Line empty")
2038
2039     delimiters = line.bracket_tracker.delimiters
2040     try:
2041         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
2042             exclude={id(last_leaf)}
2043         )
2044     except ValueError:
2045         raise CannotSplit("No delimiters found")
2046
2047     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2048     lowest_depth = sys.maxsize
2049     trailing_comma_safe = True
2050
2051     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2052         """Append `leaf` to current line or to new line if appending impossible."""
2053         nonlocal current_line
2054         try:
2055             current_line.append_safe(leaf, preformatted=True)
2056         except ValueError as ve:
2057             yield current_line
2058
2059             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2060             current_line.append(leaf)
2061
2062     for index, leaf in enumerate(line.leaves):
2063         yield from append_to_line(leaf)
2064
2065         for comment_after in line.comments_after(leaf, index):
2066             yield from append_to_line(comment_after)
2067
2068         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2069         if (
2070             leaf.bracket_depth == lowest_depth
2071             and is_vararg(leaf, within=VARARGS_PARENTS)
2072         ):
2073             trailing_comma_safe = trailing_comma_safe and py36
2074         leaf_priority = delimiters.get(id(leaf))
2075         if leaf_priority == delimiter_priority:
2076             yield current_line
2077
2078             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2079     if current_line:
2080         if (
2081             trailing_comma_safe
2082             and delimiter_priority == COMMA_PRIORITY
2083             and current_line.leaves[-1].type != token.COMMA
2084             and current_line.leaves[-1].type != STANDALONE_COMMENT
2085         ):
2086             current_line.append(Leaf(token.COMMA, ","))
2087         yield current_line
2088
2089
2090 @dont_increase_indentation
2091 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2092     """Split standalone comments from the rest of the line."""
2093     if not line.contains_standalone_comments(0):
2094         raise CannotSplit("Line does not have any standalone comments")
2095
2096     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2097
2098     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2099         """Append `leaf` to current line or to new line if appending impossible."""
2100         nonlocal current_line
2101         try:
2102             current_line.append_safe(leaf, preformatted=True)
2103         except ValueError as ve:
2104             yield current_line
2105
2106             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2107             current_line.append(leaf)
2108
2109     for index, leaf in enumerate(line.leaves):
2110         yield from append_to_line(leaf)
2111
2112         for comment_after in line.comments_after(leaf, index):
2113             yield from append_to_line(comment_after)
2114
2115     if current_line:
2116         yield current_line
2117
2118
2119 def explode_split(
2120     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
2121 ) -> Iterator[Line]:
2122     """Split by rightmost bracket and immediately split contents by a delimiter."""
2123     new_lines = list(right_hand_split(line, py36, omit))
2124     if len(new_lines) != 3:
2125         yield from new_lines
2126         return
2127
2128     yield new_lines[0]
2129
2130     try:
2131         yield from delimiter_split(new_lines[1], py36)
2132
2133     except CannotSplit:
2134         yield new_lines[1]
2135
2136     yield new_lines[2]
2137
2138
2139 def is_import(leaf: Leaf) -> bool:
2140     """Return True if the given leaf starts an import statement."""
2141     p = leaf.parent
2142     t = leaf.type
2143     v = leaf.value
2144     return bool(
2145         t == token.NAME
2146         and (
2147             (v == "import" and p and p.type == syms.import_name)
2148             or (v == "from" and p and p.type == syms.import_from)
2149         )
2150     )
2151
2152
2153 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2154     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2155     else.
2156
2157     Note: don't use backslashes for formatting or you'll lose your voting rights.
2158     """
2159     if not inside_brackets:
2160         spl = leaf.prefix.split("#")
2161         if "\\" not in spl[0]:
2162             nl_count = spl[-1].count("\n")
2163             if len(spl) > 1:
2164                 nl_count -= 1
2165             leaf.prefix = "\n" * nl_count
2166             return
2167
2168     leaf.prefix = ""
2169
2170
2171 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2172     """Make all string prefixes lowercase.
2173
2174     If remove_u_prefix is given, also removes any u prefix from the string.
2175
2176     Note: Mutates its argument.
2177     """
2178     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2179     assert match is not None, f"failed to match string {leaf.value!r}"
2180     orig_prefix = match.group(1)
2181     new_prefix = orig_prefix.lower()
2182     if remove_u_prefix:
2183         new_prefix = new_prefix.replace("u", "")
2184     leaf.value = f"{new_prefix}{match.group(2)}"
2185
2186
2187 def normalize_string_quotes(leaf: Leaf) -> None:
2188     """Prefer double quotes but only if it doesn't cause more escaping.
2189
2190     Adds or removes backslashes as appropriate. Doesn't parse and fix
2191     strings nested in f-strings (yet).
2192
2193     Note: Mutates its argument.
2194     """
2195     value = leaf.value.lstrip("furbFURB")
2196     if value[:3] == '"""':
2197         return
2198
2199     elif value[:3] == "'''":
2200         orig_quote = "'''"
2201         new_quote = '"""'
2202     elif value[0] == '"':
2203         orig_quote = '"'
2204         new_quote = "'"
2205     else:
2206         orig_quote = "'"
2207         new_quote = '"'
2208     first_quote_pos = leaf.value.find(orig_quote)
2209     if first_quote_pos == -1:
2210         return  # There's an internal error
2211
2212     prefix = leaf.value[:first_quote_pos]
2213     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2214     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2215     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2216     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2217     if "r" in prefix.casefold():
2218         if unescaped_new_quote.search(body):
2219             # There's at least one unescaped new_quote in this raw string
2220             # so converting is impossible
2221             return
2222
2223         # Do not introduce or remove backslashes in raw strings
2224         new_body = body
2225     else:
2226         # remove unnecessary quotes
2227         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2228         if body != new_body:
2229             # Consider the string without unnecessary quotes as the original
2230             body = new_body
2231             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2232         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2233         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2234     if new_quote == '"""' and new_body[-1] == '"':
2235         # edge case:
2236         new_body = new_body[:-1] + '\\"'
2237     orig_escape_count = body.count("\\")
2238     new_escape_count = new_body.count("\\")
2239     if new_escape_count > orig_escape_count:
2240         return  # Do not introduce more escaping
2241
2242     if new_escape_count == orig_escape_count and orig_quote == '"':
2243         return  # Prefer double quotes
2244
2245     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2246
2247
2248 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2249     """Make existing optional parentheses invisible or create new ones.
2250
2251     `parens_after` is a set of string leaf values immeditely after which parens
2252     should be put.
2253
2254     Standardizes on visible parentheses for single-element tuples, and keeps
2255     existing visible parentheses for other tuples and generator expressions.
2256     """
2257     check_lpar = False
2258     for child in list(node.children):
2259         if check_lpar:
2260             if child.type == syms.atom:
2261                 maybe_make_parens_invisible_in_atom(child)
2262             elif is_one_tuple(child):
2263                 # wrap child in visible parentheses
2264                 lpar = Leaf(token.LPAR, "(")
2265                 rpar = Leaf(token.RPAR, ")")
2266                 index = child.remove() or 0
2267                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2268             else:
2269                 # wrap child in invisible parentheses
2270                 lpar = Leaf(token.LPAR, "")
2271                 rpar = Leaf(token.RPAR, "")
2272                 index = child.remove() or 0
2273                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2274
2275         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2276
2277
2278 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2279     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2280     if (
2281         node.type != syms.atom
2282         or is_empty_tuple(node)
2283         or is_one_tuple(node)
2284         or is_yield(node)
2285         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2286     ):
2287         return False
2288
2289     first = node.children[0]
2290     last = node.children[-1]
2291     if first.type == token.LPAR and last.type == token.RPAR:
2292         # make parentheses invisible
2293         first.value = ""  # type: ignore
2294         last.value = ""  # type: ignore
2295         if len(node.children) > 1:
2296             maybe_make_parens_invisible_in_atom(node.children[1])
2297         return True
2298
2299     return False
2300
2301
2302 def is_empty_tuple(node: LN) -> bool:
2303     """Return True if `node` holds an empty tuple."""
2304     return (
2305         node.type == syms.atom
2306         and len(node.children) == 2
2307         and node.children[0].type == token.LPAR
2308         and node.children[1].type == token.RPAR
2309     )
2310
2311
2312 def is_one_tuple(node: LN) -> bool:
2313     """Return True if `node` holds a tuple with one element, with or without parens."""
2314     if node.type == syms.atom:
2315         if len(node.children) != 3:
2316             return False
2317
2318         lpar, gexp, rpar = node.children
2319         if not (
2320             lpar.type == token.LPAR
2321             and gexp.type == syms.testlist_gexp
2322             and rpar.type == token.RPAR
2323         ):
2324             return False
2325
2326         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2327
2328     return (
2329         node.type in IMPLICIT_TUPLE
2330         and len(node.children) == 2
2331         and node.children[1].type == token.COMMA
2332     )
2333
2334
2335 def is_yield(node: LN) -> bool:
2336     """Return True if `node` holds a `yield` or `yield from` expression."""
2337     if node.type == syms.yield_expr:
2338         return True
2339
2340     if node.type == token.NAME and node.value == "yield":  # type: ignore
2341         return True
2342
2343     if node.type != syms.atom:
2344         return False
2345
2346     if len(node.children) != 3:
2347         return False
2348
2349     lpar, expr, rpar = node.children
2350     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2351         return is_yield(expr)
2352
2353     return False
2354
2355
2356 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2357     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2358
2359     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2360     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2361     extended iterable unpacking (PEP 3132) and additional unpacking
2362     generalizations (PEP 448).
2363     """
2364     if leaf.type not in STARS or not leaf.parent:
2365         return False
2366
2367     p = leaf.parent
2368     if p.type == syms.star_expr:
2369         # Star expressions are also used as assignment targets in extended
2370         # iterable unpacking (PEP 3132).  See what its parent is instead.
2371         if not p.parent:
2372             return False
2373
2374         p = p.parent
2375
2376     return p.type in within
2377
2378
2379 def max_delimiter_priority_in_atom(node: LN) -> int:
2380     """Return maximum delimiter priority inside `node`.
2381
2382     This is specific to atoms with contents contained in a pair of parentheses.
2383     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2384     """
2385     if node.type != syms.atom:
2386         return 0
2387
2388     first = node.children[0]
2389     last = node.children[-1]
2390     if not (first.type == token.LPAR and last.type == token.RPAR):
2391         return 0
2392
2393     bt = BracketTracker()
2394     for c in node.children[1:-1]:
2395         if isinstance(c, Leaf):
2396             bt.mark(c)
2397         else:
2398             for leaf in c.leaves():
2399                 bt.mark(leaf)
2400     try:
2401         return bt.max_delimiter_priority()
2402
2403     except ValueError:
2404         return 0
2405
2406
2407 def ensure_visible(leaf: Leaf) -> None:
2408     """Make sure parentheses are visible.
2409
2410     They could be invisible as part of some statements (see
2411     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2412     """
2413     if leaf.type == token.LPAR:
2414         leaf.value = "("
2415     elif leaf.type == token.RPAR:
2416         leaf.value = ")"
2417
2418
2419 def is_python36(node: Node) -> bool:
2420     """Return True if the current file is using Python 3.6+ features.
2421
2422     Currently looking for:
2423     - f-strings; and
2424     - trailing commas after * or ** in function signatures and calls.
2425     """
2426     for n in node.pre_order():
2427         if n.type == token.STRING:
2428             value_head = n.value[:2]  # type: ignore
2429             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2430                 return True
2431
2432         elif (
2433             n.type in {syms.typedargslist, syms.arglist}
2434             and n.children
2435             and n.children[-1].type == token.COMMA
2436         ):
2437             for ch in n.children:
2438                 if ch.type in STARS:
2439                     return True
2440
2441                 if ch.type == syms.argument:
2442                     for argch in ch.children:
2443                         if argch.type in STARS:
2444                             return True
2445
2446     return False
2447
2448
2449 def get_future_imports(node: Node) -> Set[str]:
2450     """Return a set of __future__ imports in the file."""
2451     imports = set()
2452     for child in node.children:
2453         if child.type != syms.simple_stmt:
2454             break
2455         first_child = child.children[0]
2456         if isinstance(first_child, Leaf):
2457             # Continue looking if we see a docstring; otherwise stop.
2458             if (
2459                 len(child.children) == 2
2460                 and first_child.type == token.STRING
2461                 and child.children[1].type == token.NEWLINE
2462             ):
2463                 continue
2464             else:
2465                 break
2466         elif first_child.type == syms.import_from:
2467             module_name = first_child.children[1]
2468             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2469                 break
2470             for import_from_child in first_child.children[3:]:
2471                 if isinstance(import_from_child, Leaf):
2472                     if import_from_child.type == token.NAME:
2473                         imports.add(import_from_child.value)
2474                 else:
2475                     assert import_from_child.type == syms.import_as_names
2476                     for leaf in import_from_child.children:
2477                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2478                             imports.add(leaf.value)
2479         else:
2480             break
2481     return imports
2482
2483
2484 PYTHON_EXTENSIONS = {".py"}
2485 BLACKLISTED_DIRECTORIES = {
2486     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2487 }
2488
2489
2490 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2491     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2492     and have one of the PYTHON_EXTENSIONS.
2493     """
2494     for child in path.iterdir():
2495         if child.is_dir():
2496             if child.name in BLACKLISTED_DIRECTORIES:
2497                 continue
2498
2499             yield from gen_python_files_in_dir(child)
2500
2501         elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
2502             yield child
2503
2504
2505 @dataclass
2506 class Report:
2507     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2508     check: bool = False
2509     quiet: bool = False
2510     change_count: int = 0
2511     same_count: int = 0
2512     failure_count: int = 0
2513
2514     def done(self, src: Path, changed: Changed) -> None:
2515         """Increment the counter for successful reformatting. Write out a message."""
2516         if changed is Changed.YES:
2517             reformatted = "would reformat" if self.check else "reformatted"
2518             if not self.quiet:
2519                 out(f"{reformatted} {src}")
2520             self.change_count += 1
2521         else:
2522             if not self.quiet:
2523                 if changed is Changed.NO:
2524                     msg = f"{src} already well formatted, good job."
2525                 else:
2526                     msg = f"{src} wasn't modified on disk since last run."
2527                 out(msg, bold=False)
2528             self.same_count += 1
2529
2530     def failed(self, src: Path, message: str) -> None:
2531         """Increment the counter for failed reformatting. Write out a message."""
2532         err(f"error: cannot format {src}: {message}")
2533         self.failure_count += 1
2534
2535     @property
2536     def return_code(self) -> int:
2537         """Return the exit code that the app should use.
2538
2539         This considers the current state of changed files and failures:
2540         - if there were any failures, return 123;
2541         - if any files were changed and --check is being used, return 1;
2542         - otherwise return 0.
2543         """
2544         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2545         # 126 we have special returncodes reserved by the shell.
2546         if self.failure_count:
2547             return 123
2548
2549         elif self.change_count and self.check:
2550             return 1
2551
2552         return 0
2553
2554     def __str__(self) -> str:
2555         """Render a color report of the current state.
2556
2557         Use `click.unstyle` to remove colors.
2558         """
2559         if self.check:
2560             reformatted = "would be reformatted"
2561             unchanged = "would be left unchanged"
2562             failed = "would fail to reformat"
2563         else:
2564             reformatted = "reformatted"
2565             unchanged = "left unchanged"
2566             failed = "failed to reformat"
2567         report = []
2568         if self.change_count:
2569             s = "s" if self.change_count > 1 else ""
2570             report.append(
2571                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2572             )
2573         if self.same_count:
2574             s = "s" if self.same_count > 1 else ""
2575             report.append(f"{self.same_count} file{s} {unchanged}")
2576         if self.failure_count:
2577             s = "s" if self.failure_count > 1 else ""
2578             report.append(
2579                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2580             )
2581         return ", ".join(report) + "."
2582
2583
2584 def assert_equivalent(src: str, dst: str) -> None:
2585     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2586
2587     import ast
2588     import traceback
2589
2590     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2591         """Simple visitor generating strings to compare ASTs by content."""
2592         yield f"{'  ' * depth}{node.__class__.__name__}("
2593
2594         for field in sorted(node._fields):
2595             try:
2596                 value = getattr(node, field)
2597             except AttributeError:
2598                 continue
2599
2600             yield f"{'  ' * (depth+1)}{field}="
2601
2602             if isinstance(value, list):
2603                 for item in value:
2604                     if isinstance(item, ast.AST):
2605                         yield from _v(item, depth + 2)
2606
2607             elif isinstance(value, ast.AST):
2608                 yield from _v(value, depth + 2)
2609
2610             else:
2611                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2612
2613         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2614
2615     try:
2616         src_ast = ast.parse(src)
2617     except Exception as exc:
2618         major, minor = sys.version_info[:2]
2619         raise AssertionError(
2620             f"cannot use --safe with this file; failed to parse source file "
2621             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2622             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2623         )
2624
2625     try:
2626         dst_ast = ast.parse(dst)
2627     except Exception as exc:
2628         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2629         raise AssertionError(
2630             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2631             f"Please report a bug on https://github.com/ambv/black/issues.  "
2632             f"This invalid output might be helpful: {log}"
2633         ) from None
2634
2635     src_ast_str = "\n".join(_v(src_ast))
2636     dst_ast_str = "\n".join(_v(dst_ast))
2637     if src_ast_str != dst_ast_str:
2638         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2639         raise AssertionError(
2640             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2641             f"the source.  "
2642             f"Please report a bug on https://github.com/ambv/black/issues.  "
2643             f"This diff might be helpful: {log}"
2644         ) from None
2645
2646
2647 def assert_stable(src: str, dst: str, line_length: int) -> None:
2648     """Raise AssertionError if `dst` reformats differently the second time."""
2649     newdst = format_str(dst, line_length=line_length)
2650     if dst != newdst:
2651         log = dump_to_file(
2652             diff(src, dst, "source", "first pass"),
2653             diff(dst, newdst, "first pass", "second pass"),
2654         )
2655         raise AssertionError(
2656             f"INTERNAL ERROR: Black produced different code on the second pass "
2657             f"of the formatter.  "
2658             f"Please report a bug on https://github.com/ambv/black/issues.  "
2659             f"This diff might be helpful: {log}"
2660         ) from None
2661
2662
2663 def dump_to_file(*output: str) -> str:
2664     """Dump `output` to a temporary file. Return path to the file."""
2665     import tempfile
2666
2667     with tempfile.NamedTemporaryFile(
2668         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2669     ) as f:
2670         for lines in output:
2671             f.write(lines)
2672             if lines and lines[-1] != "\n":
2673                 f.write("\n")
2674     return f.name
2675
2676
2677 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2678     """Return a unified diff string between strings `a` and `b`."""
2679     import difflib
2680
2681     a_lines = [line + "\n" for line in a.split("\n")]
2682     b_lines = [line + "\n" for line in b.split("\n")]
2683     return "".join(
2684         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2685     )
2686
2687
2688 def cancel(tasks: List[asyncio.Task]) -> None:
2689     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2690     err("Aborted!")
2691     for task in tasks:
2692         task.cancel()
2693
2694
2695 def shutdown(loop: BaseEventLoop) -> None:
2696     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2697     try:
2698         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2699         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2700         if not to_cancel:
2701             return
2702
2703         for task in to_cancel:
2704             task.cancel()
2705         loop.run_until_complete(
2706             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2707         )
2708     finally:
2709         # `concurrent.futures.Future` objects cannot be cancelled once they
2710         # are already running. There might be some when the `shutdown()` happened.
2711         # Silence their logger's spew about the event loop being closed.
2712         cf_logger = logging.getLogger("concurrent.futures")
2713         cf_logger.setLevel(logging.CRITICAL)
2714         loop.close()
2715
2716
2717 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2718     """Replace `regex` with `replacement` twice on `original`.
2719
2720     This is used by string normalization to perform replaces on
2721     overlapping matches.
2722     """
2723     return regex.sub(replacement, regex.sub(replacement, original))
2724
2725
2726 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2727
2728
2729 def get_cache_file(line_length: int) -> Path:
2730     return CACHE_DIR / f"cache.{line_length}.pickle"
2731
2732
2733 def read_cache(line_length: int) -> Cache:
2734     """Read the cache if it exists and is well formed.
2735
2736     If it is not well formed, the call to write_cache later should resolve the issue.
2737     """
2738     cache_file = get_cache_file(line_length)
2739     if not cache_file.exists():
2740         return {}
2741
2742     with cache_file.open("rb") as fobj:
2743         try:
2744             cache: Cache = pickle.load(fobj)
2745         except pickle.UnpicklingError:
2746             return {}
2747
2748     return cache
2749
2750
2751 def get_cache_info(path: Path) -> CacheInfo:
2752     """Return the information used to check if a file is already formatted or not."""
2753     stat = path.stat()
2754     return stat.st_mtime, stat.st_size
2755
2756
2757 def filter_cached(
2758     cache: Cache, sources: Iterable[Path]
2759 ) -> Tuple[List[Path], List[Path]]:
2760     """Split a list of paths into two.
2761
2762     The first list contains paths of files that modified on disk or are not in the
2763     cache. The other list contains paths to non-modified files.
2764     """
2765     todo, done = [], []
2766     for src in sources:
2767         src = src.resolve()
2768         if cache.get(src) != get_cache_info(src):
2769             todo.append(src)
2770         else:
2771             done.append(src)
2772     return todo, done
2773
2774
2775 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2776     """Update the cache file."""
2777     cache_file = get_cache_file(line_length)
2778     try:
2779         if not CACHE_DIR.exists():
2780             CACHE_DIR.mkdir(parents=True)
2781         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2782         with cache_file.open("wb") as fobj:
2783             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2784     except OSError:
2785         pass
2786
2787
2788 if __name__ == "__main__":
2789     main()