]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

acks += JelleZijlstra
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Set,
28     Tuple,
29     Type,
30     TypeVar,
31     Union,
32 )
33
34 from appdirs import user_cache_dir
35 from attr import dataclass, Factory
36 import click
37
38 # lib2to3 fork
39 from blib2to3.pytree import Node, Leaf, type_repr
40 from blib2to3 import pygram, pytree
41 from blib2to3.pgen2 import driver, token
42 from blib2to3.pgen2.parse import ParseError
43
44 __version__ = "18.4a6"
45 DEFAULT_LINE_LENGTH = 88
46
47 # types
48 syms = pygram.python_symbols
49 FileContent = str
50 Encoding = str
51 Depth = int
52 NodeType = int
53 LeafID = int
54 Priority = int
55 Index = int
56 LN = Union[Leaf, Node]
57 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
58 Timestamp = float
59 FileSize = int
60 CacheInfo = Tuple[Timestamp, FileSize]
61 Cache = Dict[Path, CacheInfo]
62 out = partial(click.secho, bold=True, err=True)
63 err = partial(click.secho, fg="red", err=True)
64
65
66 class NothingChanged(UserWarning):
67     """Raised by :func:`format_file` when reformatted code is the same as source."""
68
69
70 class CannotSplit(Exception):
71     """A readable split that fits the allotted line length is impossible.
72
73     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
74     :func:`delimiter_split`.
75     """
76
77
78 class FormatError(Exception):
79     """Base exception for `# fmt: on` and `# fmt: off` handling.
80
81     It holds the number of bytes of the prefix consumed before the format
82     control comment appeared.
83     """
84
85     def __init__(self, consumed: int) -> None:
86         super().__init__(consumed)
87         self.consumed = consumed
88
89     def trim_prefix(self, leaf: Leaf) -> None:
90         leaf.prefix = leaf.prefix[self.consumed :]
91
92     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
93         """Returns a new Leaf from the consumed part of the prefix."""
94         unformatted_prefix = leaf.prefix[: self.consumed]
95         return Leaf(token.NEWLINE, unformatted_prefix)
96
97
98 class FormatOn(FormatError):
99     """Found a comment like `# fmt: on` in the file."""
100
101
102 class FormatOff(FormatError):
103     """Found a comment like `# fmt: off` in the file."""
104
105
106 class WriteBack(Enum):
107     NO = 0
108     YES = 1
109     DIFF = 2
110
111
112 class Changed(Enum):
113     NO = 0
114     CACHED = 1
115     YES = 2
116
117
118 @click.command()
119 @click.option(
120     "-l",
121     "--line-length",
122     type=int,
123     default=DEFAULT_LINE_LENGTH,
124     help="How many character per line to allow.",
125     show_default=True,
126 )
127 @click.option(
128     "--check",
129     is_flag=True,
130     help=(
131         "Don't write the files back, just return the status.  Return code 0 "
132         "means nothing would change.  Return code 1 means some files would be "
133         "reformatted.  Return code 123 means there was an internal error."
134     ),
135 )
136 @click.option(
137     "--diff",
138     is_flag=True,
139     help="Don't write the files back, just output a diff for each file on stdout.",
140 )
141 @click.option(
142     "--fast/--safe",
143     is_flag=True,
144     help="If --fast given, skip temporary sanity checks. [default: --safe]",
145 )
146 @click.option(
147     "-q",
148     "--quiet",
149     is_flag=True,
150     help=(
151         "Don't emit non-error messages to stderr. Errors are still emitted, "
152         "silence those with 2>/dev/null."
153     ),
154 )
155 @click.version_option(version=__version__)
156 @click.argument(
157     "src",
158     nargs=-1,
159     type=click.Path(
160         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
161     ),
162 )
163 @click.pass_context
164 def main(
165     ctx: click.Context,
166     line_length: int,
167     check: bool,
168     diff: bool,
169     fast: bool,
170     quiet: bool,
171     src: List[str],
172 ) -> None:
173     """The uncompromising code formatter."""
174     sources: List[Path] = []
175     for s in src:
176         p = Path(s)
177         if p.is_dir():
178             sources.extend(gen_python_files_in_dir(p))
179         elif p.is_file():
180             # if a file was explicitly given, we don't care about its extension
181             sources.append(p)
182         elif s == "-":
183             sources.append(Path("-"))
184         else:
185             err(f"invalid path: {s}")
186
187     if check and not diff:
188         write_back = WriteBack.NO
189     elif diff:
190         write_back = WriteBack.DIFF
191     else:
192         write_back = WriteBack.YES
193     report = Report(check=check, quiet=quiet)
194     if len(sources) == 0:
195         out("No paths given. Nothing to do 😴")
196         ctx.exit(0)
197         return
198
199     elif len(sources) == 1:
200         reformat_one(sources[0], line_length, fast, write_back, report)
201     else:
202         loop = asyncio.get_event_loop()
203         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
204         try:
205             loop.run_until_complete(
206                 schedule_formatting(
207                     sources, line_length, fast, write_back, report, loop, executor
208                 )
209             )
210         finally:
211             shutdown(loop)
212         if not quiet:
213             out("All done! ✨ 🍰 ✨")
214             click.echo(str(report))
215     ctx.exit(report.return_code)
216
217
218 def reformat_one(
219     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
220 ) -> None:
221     """Reformat a single file under `src` without spawning child processes.
222
223     If `quiet` is True, non-error messages are not output. `line_length`,
224     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
225     """
226     try:
227         changed = Changed.NO
228         if not src.is_file() and str(src) == "-":
229             if format_stdin_to_stdout(
230                 line_length=line_length, fast=fast, write_back=write_back
231             ):
232                 changed = Changed.YES
233         else:
234             cache: Cache = {}
235             if write_back != WriteBack.DIFF:
236                 cache = read_cache(line_length)
237                 src = src.resolve()
238                 if src in cache and cache[src] == get_cache_info(src):
239                     changed = Changed.CACHED
240             if (
241                 changed is not Changed.CACHED
242                 and format_file_in_place(
243                     src, line_length=line_length, fast=fast, write_back=write_back
244                 )
245             ):
246                 changed = Changed.YES
247             if write_back == WriteBack.YES and changed is not Changed.NO:
248                 write_cache(cache, [src], line_length)
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache(line_length)
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             src: loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             )
289             for src in sources
290         }
291         _task_values = list(tasks.values())
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
294             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     if write_back == WriteBack.YES and formatted:
315         write_cache(cache, formatted, line_length)
316
317
318 def format_file_in_place(
319     src: Path,
320     line_length: int,
321     fast: bool,
322     write_back: WriteBack = WriteBack.NO,
323     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
324 ) -> bool:
325     """Format file under `src` path. Return True if changed.
326
327     If `write_back` is True, write reformatted code back to stdout.
328     `line_length` and `fast` options are passed to :func:`format_file_contents`.
329     """
330
331     with tokenize.open(src) as src_buffer:
332         src_contents = src_buffer.read()
333     try:
334         dst_contents = format_file_contents(
335             src_contents, line_length=line_length, fast=fast
336         )
337     except NothingChanged:
338         return False
339
340     if write_back == write_back.YES:
341         with open(src, "w", encoding=src_buffer.encoding) as f:
342             f.write(dst_contents)
343     elif write_back == write_back.DIFF:
344         src_name = f"{src}  (original)"
345         dst_name = f"{src}  (formatted)"
346         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347         if lock:
348             lock.acquire()
349         try:
350             sys.stdout.write(diff_contents)
351         finally:
352             if lock:
353                 lock.release()
354     return True
355
356
357 def format_stdin_to_stdout(
358     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
359 ) -> bool:
360     """Format file on stdin. Return True if changed.
361
362     If `write_back` is True, write reformatted code back to stdout.
363     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
364     """
365     src = sys.stdin.read()
366     dst = src
367     try:
368         dst = format_file_contents(src, line_length=line_length, fast=fast)
369         return True
370
371     except NothingChanged:
372         return False
373
374     finally:
375         if write_back == WriteBack.YES:
376             sys.stdout.write(dst)
377         elif write_back == WriteBack.DIFF:
378             src_name = "<stdin>  (original)"
379             dst_name = "<stdin>  (formatted)"
380             sys.stdout.write(diff(src, dst, src_name, dst_name))
381
382
383 def format_file_contents(
384     src_contents: str, line_length: int, fast: bool
385 ) -> FileContent:
386     """Reformat contents a file and return new contents.
387
388     If `fast` is False, additionally confirm that the reformatted code is
389     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
390     `line_length` is passed to :func:`format_str`.
391     """
392     if src_contents.strip() == "":
393         raise NothingChanged
394
395     dst_contents = format_str(src_contents, line_length=line_length)
396     if src_contents == dst_contents:
397         raise NothingChanged
398
399     if not fast:
400         assert_equivalent(src_contents, dst_contents)
401         assert_stable(src_contents, dst_contents, line_length=line_length)
402     return dst_contents
403
404
405 def format_str(src_contents: str, line_length: int) -> FileContent:
406     """Reformat a string and return new contents.
407
408     `line_length` determines how many characters per line are allowed.
409     """
410     src_node = lib2to3_parse(src_contents)
411     dst_contents = ""
412     future_imports = get_future_imports(src_node)
413     py36 = is_python36(src_node)
414     lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports)
415     elt = EmptyLineTracker()
416     empty_line = Line()
417     after = 0
418     for current_line in lines.visit(src_node):
419         for _ in range(after):
420             dst_contents += str(empty_line)
421         before, after = elt.maybe_empty_lines(current_line)
422         for _ in range(before):
423             dst_contents += str(empty_line)
424         for line in split_line(current_line, line_length=line_length, py36=py36):
425             dst_contents += str(line)
426     return dst_contents
427
428
429 GRAMMARS = [
430     pygram.python_grammar_no_print_statement_no_exec_statement,
431     pygram.python_grammar_no_print_statement,
432     pygram.python_grammar,
433 ]
434
435
436 def lib2to3_parse(src_txt: str) -> Node:
437     """Given a string with source, return the lib2to3 Node."""
438     grammar = pygram.python_grammar_no_print_statement
439     if src_txt[-1] != "\n":
440         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
441         src_txt += nl
442     for grammar in GRAMMARS:
443         drv = driver.Driver(grammar, pytree.convert)
444         try:
445             result = drv.parse_string(src_txt, True)
446             break
447
448         except ParseError as pe:
449             lineno, column = pe.context[1]
450             lines = src_txt.splitlines()
451             try:
452                 faulty_line = lines[lineno - 1]
453             except IndexError:
454                 faulty_line = "<line number missing in source>"
455             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
456     else:
457         raise exc from None
458
459     if isinstance(result, Leaf):
460         result = Node(syms.file_input, [result])
461     return result
462
463
464 def lib2to3_unparse(node: Node) -> str:
465     """Given a lib2to3 node, return its string representation."""
466     code = str(node)
467     return code
468
469
470 T = TypeVar("T")
471
472
473 class Visitor(Generic[T]):
474     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
475
476     def visit(self, node: LN) -> Iterator[T]:
477         """Main method to visit `node` and its children.
478
479         It tries to find a `visit_*()` method for the given `node.type`, like
480         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
481         If no dedicated `visit_*()` method is found, chooses `visit_default()`
482         instead.
483
484         Then yields objects of type `T` from the selected visitor.
485         """
486         if node.type < 256:
487             name = token.tok_name[node.type]
488         else:
489             name = type_repr(node.type)
490         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
491
492     def visit_default(self, node: LN) -> Iterator[T]:
493         """Default `visit_*()` implementation. Recurses to children of `node`."""
494         if isinstance(node, Node):
495             for child in node.children:
496                 yield from self.visit(child)
497
498
499 @dataclass
500 class DebugVisitor(Visitor[T]):
501     tree_depth: int = 0
502
503     def visit_default(self, node: LN) -> Iterator[T]:
504         indent = " " * (2 * self.tree_depth)
505         if isinstance(node, Node):
506             _type = type_repr(node.type)
507             out(f"{indent}{_type}", fg="yellow")
508             self.tree_depth += 1
509             for child in node.children:
510                 yield from self.visit(child)
511
512             self.tree_depth -= 1
513             out(f"{indent}/{_type}", fg="yellow", bold=False)
514         else:
515             _type = token.tok_name.get(node.type, str(node.type))
516             out(f"{indent}{_type}", fg="blue", nl=False)
517             if node.prefix:
518                 # We don't have to handle prefixes for `Node` objects since
519                 # that delegates to the first child anyway.
520                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
521             out(f" {node.value!r}", fg="blue", bold=False)
522
523     @classmethod
524     def show(cls, code: str) -> None:
525         """Pretty-print the lib2to3 AST of a given string of `code`.
526
527         Convenience method for debugging.
528         """
529         v: DebugVisitor[None] = DebugVisitor()
530         list(v.visit(lib2to3_parse(code)))
531
532
533 KEYWORDS = set(keyword.kwlist)
534 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
535 FLOW_CONTROL = {"return", "raise", "break", "continue"}
536 STATEMENT = {
537     syms.if_stmt,
538     syms.while_stmt,
539     syms.for_stmt,
540     syms.try_stmt,
541     syms.except_clause,
542     syms.with_stmt,
543     syms.funcdef,
544     syms.classdef,
545 }
546 STANDALONE_COMMENT = 153
547 LOGIC_OPERATORS = {"and", "or"}
548 COMPARATORS = {
549     token.LESS,
550     token.GREATER,
551     token.EQEQUAL,
552     token.NOTEQUAL,
553     token.LESSEQUAL,
554     token.GREATEREQUAL,
555 }
556 MATH_OPERATORS = {
557     token.VBAR,
558     token.CIRCUMFLEX,
559     token.AMPER,
560     token.LEFTSHIFT,
561     token.RIGHTSHIFT,
562     token.PLUS,
563     token.MINUS,
564     token.STAR,
565     token.SLASH,
566     token.DOUBLESLASH,
567     token.PERCENT,
568     token.AT,
569     token.TILDE,
570     token.DOUBLESTAR,
571 }
572 STARS = {token.STAR, token.DOUBLESTAR}
573 VARARGS_PARENTS = {
574     syms.arglist,
575     syms.argument,  # double star in arglist
576     syms.trailer,  # single argument to call
577     syms.typedargslist,
578     syms.varargslist,  # lambdas
579 }
580 UNPACKING_PARENTS = {
581     syms.atom,  # single element of a list or set literal
582     syms.dictsetmaker,
583     syms.listmaker,
584     syms.testlist_gexp,
585 }
586 TEST_DESCENDANTS = {
587     syms.test,
588     syms.lambdef,
589     syms.or_test,
590     syms.and_test,
591     syms.not_test,
592     syms.comparison,
593     syms.star_expr,
594     syms.expr,
595     syms.xor_expr,
596     syms.and_expr,
597     syms.shift_expr,
598     syms.arith_expr,
599     syms.trailer,
600     syms.term,
601     syms.power,
602 }
603 ASSIGNMENTS = {
604     "=",
605     "+=",
606     "-=",
607     "*=",
608     "@=",
609     "/=",
610     "%=",
611     "&=",
612     "|=",
613     "^=",
614     "<<=",
615     ">>=",
616     "**=",
617     "//=",
618 }
619 COMPREHENSION_PRIORITY = 20
620 COMMA_PRIORITY = 18
621 TERNARY_PRIORITY = 16
622 LOGIC_PRIORITY = 14
623 STRING_PRIORITY = 12
624 COMPARATOR_PRIORITY = 10
625 MATH_PRIORITIES = {
626     token.VBAR: 8,
627     token.CIRCUMFLEX: 7,
628     token.AMPER: 6,
629     token.LEFTSHIFT: 5,
630     token.RIGHTSHIFT: 5,
631     token.PLUS: 4,
632     token.MINUS: 4,
633     token.STAR: 3,
634     token.SLASH: 3,
635     token.DOUBLESLASH: 3,
636     token.PERCENT: 3,
637     token.AT: 3,
638     token.TILDE: 2,
639     token.DOUBLESTAR: 1,
640 }
641
642
643 @dataclass
644 class BracketTracker:
645     """Keeps track of brackets on a line."""
646
647     depth: int = 0
648     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
649     delimiters: Dict[LeafID, Priority] = Factory(dict)
650     previous: Optional[Leaf] = None
651     _for_loop_variable: int = 0
652     _lambda_arguments: int = 0
653
654     def mark(self, leaf: Leaf) -> None:
655         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
656
657         All leaves receive an int `bracket_depth` field that stores how deep
658         within brackets a given leaf is. 0 means there are no enclosing brackets
659         that started on this line.
660
661         If a leaf is itself a closing bracket, it receives an `opening_bracket`
662         field that it forms a pair with. This is a one-directional link to
663         avoid reference cycles.
664
665         If a leaf is a delimiter (a token on which Black can split the line if
666         needed) and it's on depth 0, its `id()` is stored in the tracker's
667         `delimiters` field.
668         """
669         if leaf.type == token.COMMENT:
670             return
671
672         self.maybe_decrement_after_for_loop_variable(leaf)
673         self.maybe_decrement_after_lambda_arguments(leaf)
674         if leaf.type in CLOSING_BRACKETS:
675             self.depth -= 1
676             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
677             leaf.opening_bracket = opening_bracket
678         leaf.bracket_depth = self.depth
679         if self.depth == 0:
680             delim = is_split_before_delimiter(leaf, self.previous)
681             if delim and self.previous is not None:
682                 self.delimiters[id(self.previous)] = delim
683             else:
684                 delim = is_split_after_delimiter(leaf, self.previous)
685                 if delim:
686                     self.delimiters[id(leaf)] = delim
687         if leaf.type in OPENING_BRACKETS:
688             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
689             self.depth += 1
690         self.previous = leaf
691         self.maybe_increment_lambda_arguments(leaf)
692         self.maybe_increment_for_loop_variable(leaf)
693
694     def any_open_brackets(self) -> bool:
695         """Return True if there is an yet unmatched open bracket on the line."""
696         return bool(self.bracket_match)
697
698     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
699         """Return the highest priority of a delimiter found on the line.
700
701         Values are consistent with what `is_split_*_delimiter()` return.
702         Raises ValueError on no delimiters.
703         """
704         return max(v for k, v in self.delimiters.items() if k not in exclude)
705
706     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
707         """In a for loop, or comprehension, the variables are often unpacks.
708
709         To avoid splitting on the comma in this situation, increase the depth of
710         tokens between `for` and `in`.
711         """
712         if leaf.type == token.NAME and leaf.value == "for":
713             self.depth += 1
714             self._for_loop_variable += 1
715             return True
716
717         return False
718
719     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
720         """See `maybe_increment_for_loop_variable` above for explanation."""
721         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
722             self.depth -= 1
723             self._for_loop_variable -= 1
724             return True
725
726         return False
727
728     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
729         """In a lambda expression, there might be more than one argument.
730
731         To avoid splitting on the comma in this situation, increase the depth of
732         tokens between `lambda` and `:`.
733         """
734         if leaf.type == token.NAME and leaf.value == "lambda":
735             self.depth += 1
736             self._lambda_arguments += 1
737             return True
738
739         return False
740
741     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
742         """See `maybe_increment_lambda_arguments` above for explanation."""
743         if self._lambda_arguments and leaf.type == token.COLON:
744             self.depth -= 1
745             self._lambda_arguments -= 1
746             return True
747
748         return False
749
750     def get_open_lsqb(self) -> Optional[Leaf]:
751         """Return the most recent opening square bracket (if any)."""
752         return self.bracket_match.get((self.depth - 1, token.RSQB))
753
754
755 @dataclass
756 class Line:
757     """Holds leaves and comments. Can be printed with `str(line)`."""
758
759     depth: int = 0
760     leaves: List[Leaf] = Factory(list)
761     comments: List[Tuple[Index, Leaf]] = Factory(list)
762     bracket_tracker: BracketTracker = Factory(BracketTracker)
763     inside_brackets: bool = False
764
765     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
766         """Add a new `leaf` to the end of the line.
767
768         Unless `preformatted` is True, the `leaf` will receive a new consistent
769         whitespace prefix and metadata applied by :class:`BracketTracker`.
770         Trailing commas are maybe removed, unpacked for loop variables are
771         demoted from being delimiters.
772
773         Inline comments are put aside.
774         """
775         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
776         if not has_value:
777             return
778
779         if token.COLON == leaf.type and self.is_class_paren_empty:
780             del self.leaves[-2:]
781         if self.leaves and not preformatted:
782             # Note: at this point leaf.prefix should be empty except for
783             # imports, for which we only preserve newlines.
784             leaf.prefix += whitespace(
785                 leaf, complex_subscript=self.is_complex_subscript(leaf)
786             )
787         if self.inside_brackets or not preformatted:
788             self.bracket_tracker.mark(leaf)
789             self.maybe_remove_trailing_comma(leaf)
790         if not self.append_comment(leaf):
791             self.leaves.append(leaf)
792
793     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
794         """Like :func:`append()` but disallow invalid standalone comment structure.
795
796         Raises ValueError when any `leaf` is appended after a standalone comment
797         or when a standalone comment is not the first leaf on the line.
798         """
799         if self.bracket_tracker.depth == 0:
800             if self.is_comment:
801                 raise ValueError("cannot append to standalone comments")
802
803             if self.leaves and leaf.type == STANDALONE_COMMENT:
804                 raise ValueError(
805                     "cannot append standalone comments to a populated line"
806                 )
807
808         self.append(leaf, preformatted=preformatted)
809
810     @property
811     def is_comment(self) -> bool:
812         """Is this line a standalone comment?"""
813         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
814
815     @property
816     def is_decorator(self) -> bool:
817         """Is this line a decorator?"""
818         return bool(self) and self.leaves[0].type == token.AT
819
820     @property
821     def is_import(self) -> bool:
822         """Is this an import line?"""
823         return bool(self) and is_import(self.leaves[0])
824
825     @property
826     def is_class(self) -> bool:
827         """Is this line a class definition?"""
828         return (
829             bool(self)
830             and self.leaves[0].type == token.NAME
831             and self.leaves[0].value == "class"
832         )
833
834     @property
835     def is_def(self) -> bool:
836         """Is this a function definition? (Also returns True for async defs.)"""
837         try:
838             first_leaf = self.leaves[0]
839         except IndexError:
840             return False
841
842         try:
843             second_leaf: Optional[Leaf] = self.leaves[1]
844         except IndexError:
845             second_leaf = None
846         return (
847             (first_leaf.type == token.NAME and first_leaf.value == "def")
848             or (
849                 first_leaf.type == token.ASYNC
850                 and second_leaf is not None
851                 and second_leaf.type == token.NAME
852                 and second_leaf.value == "def"
853             )
854         )
855
856     @property
857     def is_flow_control(self) -> bool:
858         """Is this line a flow control statement?
859
860         Those are `return`, `raise`, `break`, and `continue`.
861         """
862         return (
863             bool(self)
864             and self.leaves[0].type == token.NAME
865             and self.leaves[0].value in FLOW_CONTROL
866         )
867
868     @property
869     def is_yield(self) -> bool:
870         """Is this line a yield statement?"""
871         return (
872             bool(self)
873             and self.leaves[0].type == token.NAME
874             and self.leaves[0].value == "yield"
875         )
876
877     @property
878     def is_class_paren_empty(self) -> bool:
879         """Is this a class with no base classes but using parentheses?
880
881         Those are unnecessary and should be removed.
882         """
883         return (
884             bool(self)
885             and len(self.leaves) == 4
886             and self.is_class
887             and self.leaves[2].type == token.LPAR
888             and self.leaves[2].value == "("
889             and self.leaves[3].type == token.RPAR
890             and self.leaves[3].value == ")"
891         )
892
893     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
894         """If so, needs to be split before emitting."""
895         for leaf in self.leaves:
896             if leaf.type == STANDALONE_COMMENT:
897                 if leaf.bracket_depth <= depth_limit:
898                     return True
899
900         return False
901
902     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
903         """Remove trailing comma if there is one and it's safe."""
904         if not (
905             self.leaves
906             and self.leaves[-1].type == token.COMMA
907             and closing.type in CLOSING_BRACKETS
908         ):
909             return False
910
911         if closing.type == token.RBRACE:
912             self.remove_trailing_comma()
913             return True
914
915         if closing.type == token.RSQB:
916             comma = self.leaves[-1]
917             if comma.parent and comma.parent.type == syms.listmaker:
918                 self.remove_trailing_comma()
919                 return True
920
921         # For parens let's check if it's safe to remove the comma.
922         # Imports are always safe.
923         if self.is_import:
924             self.remove_trailing_comma()
925             return True
926
927         # Otheriwsse, if the trailing one is the only one, we might mistakenly
928         # change a tuple into a different type by removing the comma.
929         depth = closing.bracket_depth + 1
930         commas = 0
931         opening = closing.opening_bracket
932         for _opening_index, leaf in enumerate(self.leaves):
933             if leaf is opening:
934                 break
935
936         else:
937             return False
938
939         for leaf in self.leaves[_opening_index + 1 :]:
940             if leaf is closing:
941                 break
942
943             bracket_depth = leaf.bracket_depth
944             if bracket_depth == depth and leaf.type == token.COMMA:
945                 commas += 1
946                 if leaf.parent and leaf.parent.type == syms.arglist:
947                     commas += 1
948                     break
949
950         if commas > 1:
951             self.remove_trailing_comma()
952             return True
953
954         return False
955
956     def append_comment(self, comment: Leaf) -> bool:
957         """Add an inline or standalone comment to the line."""
958         if (
959             comment.type == STANDALONE_COMMENT
960             and self.bracket_tracker.any_open_brackets()
961         ):
962             comment.prefix = ""
963             return False
964
965         if comment.type != token.COMMENT:
966             return False
967
968         after = len(self.leaves) - 1
969         if after == -1:
970             comment.type = STANDALONE_COMMENT
971             comment.prefix = ""
972             return False
973
974         else:
975             self.comments.append((after, comment))
976             return True
977
978     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
979         """Generate comments that should appear directly after `leaf`."""
980         for _leaf_index, _leaf in enumerate(self.leaves):
981             if leaf is _leaf:
982                 break
983
984         else:
985             return
986
987         for index, comment_after in self.comments:
988             if _leaf_index == index:
989                 yield comment_after
990
991     def remove_trailing_comma(self) -> None:
992         """Remove the trailing comma and moves the comments attached to it."""
993         comma_index = len(self.leaves) - 1
994         for i in range(len(self.comments)):
995             comment_index, comment = self.comments[i]
996             if comment_index == comma_index:
997                 self.comments[i] = (comma_index - 1, comment)
998         self.leaves.pop()
999
1000     def is_complex_subscript(self, leaf: Leaf) -> bool:
1001         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1002         open_lsqb = (
1003             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1004         )
1005         if open_lsqb is None:
1006             return False
1007
1008         subscript_start = open_lsqb.next_sibling
1009         if (
1010             isinstance(subscript_start, Node)
1011             and subscript_start.type == syms.subscriptlist
1012         ):
1013             subscript_start = child_towards(subscript_start, leaf)
1014         return (
1015             subscript_start is not None
1016             and any(n.type in TEST_DESCENDANTS for n in subscript_start.pre_order())
1017         )
1018
1019     def __str__(self) -> str:
1020         """Render the line."""
1021         if not self:
1022             return "\n"
1023
1024         indent = "    " * self.depth
1025         leaves = iter(self.leaves)
1026         first = next(leaves)
1027         res = f"{first.prefix}{indent}{first.value}"
1028         for leaf in leaves:
1029             res += str(leaf)
1030         for _, comment in self.comments:
1031             res += str(comment)
1032         return res + "\n"
1033
1034     def __bool__(self) -> bool:
1035         """Return True if the line has leaves or comments."""
1036         return bool(self.leaves or self.comments)
1037
1038
1039 class UnformattedLines(Line):
1040     """Just like :class:`Line` but stores lines which aren't reformatted."""
1041
1042     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1043         """Just add a new `leaf` to the end of the lines.
1044
1045         The `preformatted` argument is ignored.
1046
1047         Keeps track of indentation `depth`, which is useful when the user
1048         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1049         """
1050         try:
1051             list(generate_comments(leaf))
1052         except FormatOn as f_on:
1053             self.leaves.append(f_on.leaf_from_consumed(leaf))
1054             raise
1055
1056         self.leaves.append(leaf)
1057         if leaf.type == token.INDENT:
1058             self.depth += 1
1059         elif leaf.type == token.DEDENT:
1060             self.depth -= 1
1061
1062     def __str__(self) -> str:
1063         """Render unformatted lines from leaves which were added with `append()`.
1064
1065         `depth` is not used for indentation in this case.
1066         """
1067         if not self:
1068             return "\n"
1069
1070         res = ""
1071         for leaf in self.leaves:
1072             res += str(leaf)
1073         return res
1074
1075     def append_comment(self, comment: Leaf) -> bool:
1076         """Not implemented in this class. Raises `NotImplementedError`."""
1077         raise NotImplementedError("Unformatted lines don't store comments separately.")
1078
1079     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1080         """Does nothing and returns False."""
1081         return False
1082
1083     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1084         """Does nothing and returns False."""
1085         return False
1086
1087
1088 @dataclass
1089 class EmptyLineTracker:
1090     """Provides a stateful method that returns the number of potential extra
1091     empty lines needed before and after the currently processed line.
1092
1093     Note: this tracker works on lines that haven't been split yet.  It assumes
1094     the prefix of the first leaf consists of optional newlines.  Those newlines
1095     are consumed by `maybe_empty_lines()` and included in the computation.
1096     """
1097     previous_line: Optional[Line] = None
1098     previous_after: int = 0
1099     previous_defs: List[int] = Factory(list)
1100
1101     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1102         """Return the number of extra empty lines before and after the `current_line`.
1103
1104         This is for separating `def`, `async def` and `class` with extra empty
1105         lines (two on module-level), as well as providing an extra empty line
1106         after flow control keywords to make them more prominent.
1107         """
1108         if isinstance(current_line, UnformattedLines):
1109             return 0, 0
1110
1111         before, after = self._maybe_empty_lines(current_line)
1112         before -= self.previous_after
1113         self.previous_after = after
1114         self.previous_line = current_line
1115         return before, after
1116
1117     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1118         max_allowed = 1
1119         if current_line.depth == 0:
1120             max_allowed = 2
1121         if current_line.leaves:
1122             # Consume the first leaf's extra newlines.
1123             first_leaf = current_line.leaves[0]
1124             before = first_leaf.prefix.count("\n")
1125             before = min(before, max_allowed)
1126             first_leaf.prefix = ""
1127         else:
1128             before = 0
1129         depth = current_line.depth
1130         while self.previous_defs and self.previous_defs[-1] >= depth:
1131             self.previous_defs.pop()
1132             before = 1 if depth else 2
1133         is_decorator = current_line.is_decorator
1134         if is_decorator or current_line.is_def or current_line.is_class:
1135             if not is_decorator:
1136                 self.previous_defs.append(depth)
1137             if self.previous_line is None:
1138                 # Don't insert empty lines before the first line in the file.
1139                 return 0, 0
1140
1141             if self.previous_line.is_decorator:
1142                 return 0, 0
1143
1144             if (
1145                 self.previous_line.is_comment
1146                 and self.previous_line.depth == current_line.depth
1147                 and before == 0
1148             ):
1149                 return 0, 0
1150
1151             newlines = 2
1152             if current_line.depth:
1153                 newlines -= 1
1154             return newlines, 0
1155
1156         if (
1157             self.previous_line
1158             and self.previous_line.is_import
1159             and not current_line.is_import
1160             and depth == self.previous_line.depth
1161         ):
1162             return (before or 1), 0
1163
1164         return before, 0
1165
1166
1167 @dataclass
1168 class LineGenerator(Visitor[Line]):
1169     """Generates reformatted Line objects.  Empty lines are not emitted.
1170
1171     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1172     in ways that will no longer stringify to valid Python code on the tree.
1173     """
1174     current_line: Line = Factory(Line)
1175     remove_u_prefix: bool = False
1176
1177     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1178         """Generate a line.
1179
1180         If the line is empty, only emit if it makes sense.
1181         If the line is too long, split it first and then generate.
1182
1183         If any lines were generated, set up a new current_line.
1184         """
1185         if not self.current_line:
1186             if self.current_line.__class__ == type:
1187                 self.current_line.depth += indent
1188             else:
1189                 self.current_line = type(depth=self.current_line.depth + indent)
1190             return  # Line is empty, don't emit. Creating a new one unnecessary.
1191
1192         complete_line = self.current_line
1193         self.current_line = type(depth=complete_line.depth + indent)
1194         yield complete_line
1195
1196     def visit(self, node: LN) -> Iterator[Line]:
1197         """Main method to visit `node` and its children.
1198
1199         Yields :class:`Line` objects.
1200         """
1201         if isinstance(self.current_line, UnformattedLines):
1202             # File contained `# fmt: off`
1203             yield from self.visit_unformatted(node)
1204
1205         else:
1206             yield from super().visit(node)
1207
1208     def visit_default(self, node: LN) -> Iterator[Line]:
1209         """Default `visit_*()` implementation. Recurses to children of `node`."""
1210         if isinstance(node, Leaf):
1211             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1212             try:
1213                 for comment in generate_comments(node):
1214                     if any_open_brackets:
1215                         # any comment within brackets is subject to splitting
1216                         self.current_line.append(comment)
1217                     elif comment.type == token.COMMENT:
1218                         # regular trailing comment
1219                         self.current_line.append(comment)
1220                         yield from self.line()
1221
1222                     else:
1223                         # regular standalone comment
1224                         yield from self.line()
1225
1226                         self.current_line.append(comment)
1227                         yield from self.line()
1228
1229             except FormatOff as f_off:
1230                 f_off.trim_prefix(node)
1231                 yield from self.line(type=UnformattedLines)
1232                 yield from self.visit(node)
1233
1234             except FormatOn as f_on:
1235                 # This only happens here if somebody says "fmt: on" multiple
1236                 # times in a row.
1237                 f_on.trim_prefix(node)
1238                 yield from self.visit_default(node)
1239
1240             else:
1241                 normalize_prefix(node, inside_brackets=any_open_brackets)
1242                 if node.type == token.STRING:
1243                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1244                     normalize_string_quotes(node)
1245                 if node.type not in WHITESPACE:
1246                     self.current_line.append(node)
1247         yield from super().visit_default(node)
1248
1249     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1250         """Increase indentation level, maybe yield a line."""
1251         # In blib2to3 INDENT never holds comments.
1252         yield from self.line(+1)
1253         yield from self.visit_default(node)
1254
1255     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1256         """Decrease indentation level, maybe yield a line."""
1257         # The current line might still wait for trailing comments.  At DEDENT time
1258         # there won't be any (they would be prefixes on the preceding NEWLINE).
1259         # Emit the line then.
1260         yield from self.line()
1261
1262         # While DEDENT has no value, its prefix may contain standalone comments
1263         # that belong to the current indentation level.  Get 'em.
1264         yield from self.visit_default(node)
1265
1266         # Finally, emit the dedent.
1267         yield from self.line(-1)
1268
1269     def visit_stmt(
1270         self, node: Node, keywords: Set[str], parens: Set[str]
1271     ) -> Iterator[Line]:
1272         """Visit a statement.
1273
1274         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1275         `def`, `with`, `class`, `assert` and assignments.
1276
1277         The relevant Python language `keywords` for a given statement will be
1278         NAME leaves within it. This methods puts those on a separate line.
1279
1280         `parens` holds a set of string leaf values immeditely after which
1281         invisible parens should be put.
1282         """
1283         normalize_invisible_parens(node, parens_after=parens)
1284         for child in node.children:
1285             if child.type == token.NAME and child.value in keywords:  # type: ignore
1286                 yield from self.line()
1287
1288             yield from self.visit(child)
1289
1290     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1291         """Visit a statement without nested statements."""
1292         is_suite_like = node.parent and node.parent.type in STATEMENT
1293         if is_suite_like:
1294             yield from self.line(+1)
1295             yield from self.visit_default(node)
1296             yield from self.line(-1)
1297
1298         else:
1299             yield from self.line()
1300             yield from self.visit_default(node)
1301
1302     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1303         """Visit `async def`, `async for`, `async with`."""
1304         yield from self.line()
1305
1306         children = iter(node.children)
1307         for child in children:
1308             yield from self.visit(child)
1309
1310             if child.type == token.ASYNC:
1311                 break
1312
1313         internal_stmt = next(children)
1314         for child in internal_stmt.children:
1315             yield from self.visit(child)
1316
1317     def visit_decorators(self, node: Node) -> Iterator[Line]:
1318         """Visit decorators."""
1319         for child in node.children:
1320             yield from self.line()
1321             yield from self.visit(child)
1322
1323     def visit_import_from(self, node: Node) -> Iterator[Line]:
1324         """Visit import_from and maybe put invisible parentheses.
1325
1326         This is separate from `visit_stmt` because import statements don't
1327         support arbitrary atoms and thus handling of parentheses is custom.
1328         """
1329         check_lpar = False
1330         for index, child in enumerate(node.children):
1331             if check_lpar:
1332                 if child.type == token.LPAR:
1333                     # make parentheses invisible
1334                     child.value = ""  # type: ignore
1335                     node.children[-1].value = ""  # type: ignore
1336                 else:
1337                     # insert invisible parentheses
1338                     node.insert_child(index, Leaf(token.LPAR, ""))
1339                     node.append_child(Leaf(token.RPAR, ""))
1340                 break
1341
1342             check_lpar = (
1343                 child.type == token.NAME and child.value == "import"  # type: ignore
1344             )
1345
1346         for child in node.children:
1347             yield from self.visit(child)
1348
1349     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1350         """Remove a semicolon and put the other statement on a separate line."""
1351         yield from self.line()
1352
1353     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1354         """End of file. Process outstanding comments and end with a newline."""
1355         yield from self.visit_default(leaf)
1356         yield from self.line()
1357
1358     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1359         """Used when file contained a `# fmt: off`."""
1360         if isinstance(node, Node):
1361             for child in node.children:
1362                 yield from self.visit(child)
1363
1364         else:
1365             try:
1366                 self.current_line.append(node)
1367             except FormatOn as f_on:
1368                 f_on.trim_prefix(node)
1369                 yield from self.line()
1370                 yield from self.visit(node)
1371
1372             if node.type == token.ENDMARKER:
1373                 # somebody decided not to put a final `# fmt: on`
1374                 yield from self.line()
1375
1376     def __attrs_post_init__(self) -> None:
1377         """You are in a twisty little maze of passages."""
1378         v = self.visit_stmt
1379         Ø: Set[str] = set()
1380         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1381         self.visit_if_stmt = partial(
1382             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1383         )
1384         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1385         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1386         self.visit_try_stmt = partial(
1387             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1388         )
1389         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1390         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1391         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1392         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1393         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1394         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1395         self.visit_async_funcdef = self.visit_async_stmt
1396         self.visit_decorated = self.visit_decorators
1397
1398
1399 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1400 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1401 OPENING_BRACKETS = set(BRACKET.keys())
1402 CLOSING_BRACKETS = set(BRACKET.values())
1403 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1404 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1405
1406
1407 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1408     """Return whitespace prefix if needed for the given `leaf`.
1409
1410     `complex_subscript` signals whether the given leaf is part of a subscription
1411     which has non-trivial arguments, like arithmetic expressions or function calls.
1412     """
1413     NO = ""
1414     SPACE = " "
1415     DOUBLESPACE = "  "
1416     t = leaf.type
1417     p = leaf.parent
1418     v = leaf.value
1419     if t in ALWAYS_NO_SPACE:
1420         return NO
1421
1422     if t == token.COMMENT:
1423         return DOUBLESPACE
1424
1425     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1426     if (
1427         t == token.COLON
1428         and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop}
1429     ):
1430         return NO
1431
1432     prev = leaf.prev_sibling
1433     if not prev:
1434         prevp = preceding_leaf(p)
1435         if not prevp or prevp.type in OPENING_BRACKETS:
1436             return NO
1437
1438         if t == token.COLON:
1439             if prevp.type == token.COLON:
1440                 return NO
1441
1442             elif prevp.type != token.COMMA and not complex_subscript:
1443                 return NO
1444
1445             return SPACE
1446
1447         if prevp.type == token.EQUAL:
1448             if prevp.parent:
1449                 if prevp.parent.type in {
1450                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1451                 }:
1452                     return NO
1453
1454                 elif prevp.parent.type == syms.typedargslist:
1455                     # A bit hacky: if the equal sign has whitespace, it means we
1456                     # previously found it's a typed argument.  So, we're using
1457                     # that, too.
1458                     return prevp.prefix
1459
1460         elif prevp.type in STARS:
1461             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1462                 return NO
1463
1464         elif prevp.type == token.COLON:
1465             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1466                 return SPACE if complex_subscript else NO
1467
1468         elif (
1469             prevp.parent
1470             and prevp.parent.type == syms.factor
1471             and prevp.type in MATH_OPERATORS
1472         ):
1473             return NO
1474
1475         elif (
1476             prevp.type == token.RIGHTSHIFT
1477             and prevp.parent
1478             and prevp.parent.type == syms.shift_expr
1479             and prevp.prev_sibling
1480             and prevp.prev_sibling.type == token.NAME
1481             and prevp.prev_sibling.value == "print"  # type: ignore
1482         ):
1483             # Python 2 print chevron
1484             return NO
1485
1486     elif prev.type in OPENING_BRACKETS:
1487         return NO
1488
1489     if p.type in {syms.parameters, syms.arglist}:
1490         # untyped function signatures or calls
1491         if not prev or prev.type != token.COMMA:
1492             return NO
1493
1494     elif p.type == syms.varargslist:
1495         # lambdas
1496         if prev and prev.type != token.COMMA:
1497             return NO
1498
1499     elif p.type == syms.typedargslist:
1500         # typed function signatures
1501         if not prev:
1502             return NO
1503
1504         if t == token.EQUAL:
1505             if prev.type != syms.tname:
1506                 return NO
1507
1508         elif prev.type == token.EQUAL:
1509             # A bit hacky: if the equal sign has whitespace, it means we
1510             # previously found it's a typed argument.  So, we're using that, too.
1511             return prev.prefix
1512
1513         elif prev.type != token.COMMA:
1514             return NO
1515
1516     elif p.type == syms.tname:
1517         # type names
1518         if not prev:
1519             prevp = preceding_leaf(p)
1520             if not prevp or prevp.type != token.COMMA:
1521                 return NO
1522
1523     elif p.type == syms.trailer:
1524         # attributes and calls
1525         if t == token.LPAR or t == token.RPAR:
1526             return NO
1527
1528         if not prev:
1529             if t == token.DOT:
1530                 prevp = preceding_leaf(p)
1531                 if not prevp or prevp.type != token.NUMBER:
1532                     return NO
1533
1534             elif t == token.LSQB:
1535                 return NO
1536
1537         elif prev.type != token.COMMA:
1538             return NO
1539
1540     elif p.type == syms.argument:
1541         # single argument
1542         if t == token.EQUAL:
1543             return NO
1544
1545         if not prev:
1546             prevp = preceding_leaf(p)
1547             if not prevp or prevp.type == token.LPAR:
1548                 return NO
1549
1550         elif prev.type in {token.EQUAL} | STARS:
1551             return NO
1552
1553     elif p.type == syms.decorator:
1554         # decorators
1555         return NO
1556
1557     elif p.type == syms.dotted_name:
1558         if prev:
1559             return NO
1560
1561         prevp = preceding_leaf(p)
1562         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1563             return NO
1564
1565     elif p.type == syms.classdef:
1566         if t == token.LPAR:
1567             return NO
1568
1569         if prev and prev.type == token.LPAR:
1570             return NO
1571
1572     elif p.type in {syms.subscript, syms.sliceop}:
1573         # indexing
1574         if not prev:
1575             assert p.parent is not None, "subscripts are always parented"
1576             if p.parent.type == syms.subscriptlist:
1577                 return SPACE
1578
1579             return NO
1580
1581         elif not complex_subscript:
1582             return NO
1583
1584     elif p.type == syms.atom:
1585         if prev and t == token.DOT:
1586             # dots, but not the first one.
1587             return NO
1588
1589     elif p.type == syms.dictsetmaker:
1590         # dict unpacking
1591         if prev and prev.type == token.DOUBLESTAR:
1592             return NO
1593
1594     elif p.type in {syms.factor, syms.star_expr}:
1595         # unary ops
1596         if not prev:
1597             prevp = preceding_leaf(p)
1598             if not prevp or prevp.type in OPENING_BRACKETS:
1599                 return NO
1600
1601             prevp_parent = prevp.parent
1602             assert prevp_parent is not None
1603             if (
1604                 prevp.type == token.COLON
1605                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1606             ):
1607                 return NO
1608
1609             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1610                 return NO
1611
1612         elif t == token.NAME or t == token.NUMBER:
1613             return NO
1614
1615     elif p.type == syms.import_from:
1616         if t == token.DOT:
1617             if prev and prev.type == token.DOT:
1618                 return NO
1619
1620         elif t == token.NAME:
1621             if v == "import":
1622                 return SPACE
1623
1624             if prev and prev.type == token.DOT:
1625                 return NO
1626
1627     elif p.type == syms.sliceop:
1628         return NO
1629
1630     return SPACE
1631
1632
1633 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1634     """Return the first leaf that precedes `node`, if any."""
1635     while node:
1636         res = node.prev_sibling
1637         if res:
1638             if isinstance(res, Leaf):
1639                 return res
1640
1641             try:
1642                 return list(res.leaves())[-1]
1643
1644             except IndexError:
1645                 return None
1646
1647         node = node.parent
1648     return None
1649
1650
1651 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1652     """Return the child of `ancestor` that contains `descendant`."""
1653     node: Optional[LN] = descendant
1654     while node and node.parent != ancestor:
1655         node = node.parent
1656     return node
1657
1658
1659 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1660     """Return the priority of the `leaf` delimiter, given a line break after it.
1661
1662     The delimiter priorities returned here are from those delimiters that would
1663     cause a line break after themselves.
1664
1665     Higher numbers are higher priority.
1666     """
1667     if leaf.type == token.COMMA:
1668         return COMMA_PRIORITY
1669
1670     return 0
1671
1672
1673 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1674     """Return the priority of the `leaf` delimiter, given a line before after it.
1675
1676     The delimiter priorities returned here are from those delimiters that would
1677     cause a line break before themselves.
1678
1679     Higher numbers are higher priority.
1680     """
1681     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1682         # * and ** might also be MATH_OPERATORS but in this case they are not.
1683         # Don't treat them as a delimiter.
1684         return 0
1685
1686     if (
1687         leaf.type in MATH_OPERATORS
1688         and leaf.parent
1689         and leaf.parent.type not in {syms.factor, syms.star_expr}
1690     ):
1691         return MATH_PRIORITIES[leaf.type]
1692
1693     if leaf.type in COMPARATORS:
1694         return COMPARATOR_PRIORITY
1695
1696     if (
1697         leaf.type == token.STRING
1698         and previous is not None
1699         and previous.type == token.STRING
1700     ):
1701         return STRING_PRIORITY
1702
1703     if (
1704         leaf.type == token.NAME
1705         and leaf.value == "for"
1706         and leaf.parent
1707         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1708     ):
1709         return COMPREHENSION_PRIORITY
1710
1711     if (
1712         leaf.type == token.NAME
1713         and leaf.value == "if"
1714         and leaf.parent
1715         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1716     ):
1717         return COMPREHENSION_PRIORITY
1718
1719     if (
1720         leaf.type == token.NAME
1721         and leaf.value in {"if", "else"}
1722         and leaf.parent
1723         and leaf.parent.type == syms.test
1724     ):
1725         return TERNARY_PRIORITY
1726
1727     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1728         return LOGIC_PRIORITY
1729
1730     return 0
1731
1732
1733 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1734     """Clean the prefix of the `leaf` and generate comments from it, if any.
1735
1736     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1737     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1738     move because it does away with modifying the grammar to include all the
1739     possible places in which comments can be placed.
1740
1741     The sad consequence for us though is that comments don't "belong" anywhere.
1742     This is why this function generates simple parentless Leaf objects for
1743     comments.  We simply don't know what the correct parent should be.
1744
1745     No matter though, we can live without this.  We really only need to
1746     differentiate between inline and standalone comments.  The latter don't
1747     share the line with any code.
1748
1749     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1750     are emitted with a fake STANDALONE_COMMENT token identifier.
1751     """
1752     p = leaf.prefix
1753     if not p:
1754         return
1755
1756     if "#" not in p:
1757         return
1758
1759     consumed = 0
1760     nlines = 0
1761     for index, line in enumerate(p.split("\n")):
1762         consumed += len(line) + 1  # adding the length of the split '\n'
1763         line = line.lstrip()
1764         if not line:
1765             nlines += 1
1766         if not line.startswith("#"):
1767             continue
1768
1769         if index == 0 and leaf.type != token.ENDMARKER:
1770             comment_type = token.COMMENT  # simple trailing comment
1771         else:
1772             comment_type = STANDALONE_COMMENT
1773         comment = make_comment(line)
1774         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1775
1776         if comment in {"# fmt: on", "# yapf: enable"}:
1777             raise FormatOn(consumed)
1778
1779         if comment in {"# fmt: off", "# yapf: disable"}:
1780             if comment_type == STANDALONE_COMMENT:
1781                 raise FormatOff(consumed)
1782
1783             prev = preceding_leaf(leaf)
1784             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1785                 raise FormatOff(consumed)
1786
1787         nlines = 0
1788
1789
1790 def make_comment(content: str) -> str:
1791     """Return a consistently formatted comment from the given `content` string.
1792
1793     All comments (except for "##", "#!", "#:") should have a single space between
1794     the hash sign and the content.
1795
1796     If `content` didn't start with a hash sign, one is provided.
1797     """
1798     content = content.rstrip()
1799     if not content:
1800         return "#"
1801
1802     if content[0] == "#":
1803         content = content[1:]
1804     if content and content[0] not in " !:#":
1805         content = " " + content
1806     return "#" + content
1807
1808
1809 def split_line(
1810     line: Line, line_length: int, inner: bool = False, py36: bool = False
1811 ) -> Iterator[Line]:
1812     """Split a `line` into potentially many lines.
1813
1814     They should fit in the allotted `line_length` but might not be able to.
1815     `inner` signifies that there were a pair of brackets somewhere around the
1816     current `line`, possibly transitively. This means we can fallback to splitting
1817     by delimiters if the LHS/RHS don't yield any results.
1818
1819     If `py36` is True, splitting may generate syntax that is only compatible
1820     with Python 3.6 and later.
1821     """
1822     if isinstance(line, UnformattedLines) or line.is_comment:
1823         yield line
1824         return
1825
1826     line_str = str(line).strip("\n")
1827     if (
1828         len(line_str) <= line_length
1829         and "\n" not in line_str  # multiline strings
1830         and not line.contains_standalone_comments()
1831     ):
1832         yield line
1833         return
1834
1835     split_funcs: List[SplitFunc]
1836     if line.is_def:
1837         split_funcs = [left_hand_split]
1838     elif line.is_import:
1839         split_funcs = [explode_split]
1840     elif line.inside_brackets:
1841         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1842     else:
1843         split_funcs = [right_hand_split]
1844     for split_func in split_funcs:
1845         # We are accumulating lines in `result` because we might want to abort
1846         # mission and return the original line in the end, or attempt a different
1847         # split altogether.
1848         result: List[Line] = []
1849         try:
1850             for l in split_func(line, py36):
1851                 if str(l).strip("\n") == line_str:
1852                     raise CannotSplit("Split function returned an unchanged result")
1853
1854                 result.extend(
1855                     split_line(l, line_length=line_length, inner=True, py36=py36)
1856                 )
1857         except CannotSplit as cs:
1858             continue
1859
1860         else:
1861             yield from result
1862             break
1863
1864     else:
1865         yield line
1866
1867
1868 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1869     """Split line into many lines, starting with the first matching bracket pair.
1870
1871     Note: this usually looks weird, only use this for function definitions.
1872     Prefer RHS otherwise.  This is why this function is not symmetrical with
1873     :func:`right_hand_split` which also handles optional parentheses.
1874     """
1875     head = Line(depth=line.depth)
1876     body = Line(depth=line.depth + 1, inside_brackets=True)
1877     tail = Line(depth=line.depth)
1878     tail_leaves: List[Leaf] = []
1879     body_leaves: List[Leaf] = []
1880     head_leaves: List[Leaf] = []
1881     current_leaves = head_leaves
1882     matching_bracket = None
1883     for leaf in line.leaves:
1884         if (
1885             current_leaves is body_leaves
1886             and leaf.type in CLOSING_BRACKETS
1887             and leaf.opening_bracket is matching_bracket
1888         ):
1889             current_leaves = tail_leaves if body_leaves else head_leaves
1890         current_leaves.append(leaf)
1891         if current_leaves is head_leaves:
1892             if leaf.type in OPENING_BRACKETS:
1893                 matching_bracket = leaf
1894                 current_leaves = body_leaves
1895     # Since body is a new indent level, remove spurious leading whitespace.
1896     if body_leaves:
1897         normalize_prefix(body_leaves[0], inside_brackets=True)
1898     # Build the new lines.
1899     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1900         for leaf in leaves:
1901             result.append(leaf, preformatted=True)
1902             for comment_after in line.comments_after(leaf):
1903                 result.append(comment_after, preformatted=True)
1904     bracket_split_succeeded_or_raise(head, body, tail)
1905     for result in (head, body, tail):
1906         if result:
1907             yield result
1908
1909
1910 def right_hand_split(
1911     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1912 ) -> Iterator[Line]:
1913     """Split line into many lines, starting with the last matching bracket pair.
1914
1915     If the split was by optional parentheses, attempt splitting without them, too.
1916     """
1917     head = Line(depth=line.depth)
1918     body = Line(depth=line.depth + 1, inside_brackets=True)
1919     tail = Line(depth=line.depth)
1920     tail_leaves: List[Leaf] = []
1921     body_leaves: List[Leaf] = []
1922     head_leaves: List[Leaf] = []
1923     current_leaves = tail_leaves
1924     opening_bracket = None
1925     closing_bracket = None
1926     for leaf in reversed(line.leaves):
1927         if current_leaves is body_leaves:
1928             if leaf is opening_bracket:
1929                 current_leaves = head_leaves if body_leaves else tail_leaves
1930         current_leaves.append(leaf)
1931         if current_leaves is tail_leaves:
1932             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1933                 opening_bracket = leaf.opening_bracket
1934                 closing_bracket = leaf
1935                 current_leaves = body_leaves
1936     tail_leaves.reverse()
1937     body_leaves.reverse()
1938     head_leaves.reverse()
1939     # Since body is a new indent level, remove spurious leading whitespace.
1940     if body_leaves:
1941         normalize_prefix(body_leaves[0], inside_brackets=True)
1942     elif not head_leaves:
1943         # No `head` and no `body` means the split failed. `tail` has all content.
1944         raise CannotSplit("No brackets found")
1945
1946     # Build the new lines.
1947     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1948         for leaf in leaves:
1949             result.append(leaf, preformatted=True)
1950             for comment_after in line.comments_after(leaf):
1951                 result.append(comment_after, preformatted=True)
1952     bracket_split_succeeded_or_raise(head, body, tail)
1953     assert opening_bracket and closing_bracket
1954     if (
1955         # the opening bracket is an optional paren
1956         opening_bracket.type == token.LPAR
1957         and not opening_bracket.value
1958         # the closing bracket is an optional paren
1959         and closing_bracket.type == token.RPAR
1960         and not closing_bracket.value
1961         # there are no delimiters or standalone comments in the body
1962         and not body.bracket_tracker.delimiters
1963         and not line.contains_standalone_comments(0)
1964         # and it's not an import (optional parens are the only thing we can split
1965         # on in this case; attempting a split without them is a waste of time)
1966         and not line.is_import
1967     ):
1968         omit = {id(closing_bracket), *omit}
1969         try:
1970             yield from right_hand_split(line, py36=py36, omit=omit)
1971             return
1972         except CannotSplit:
1973             pass
1974
1975     ensure_visible(opening_bracket)
1976     ensure_visible(closing_bracket)
1977     for result in (head, body, tail):
1978         if result:
1979             yield result
1980
1981
1982 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1983     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1984
1985     Do nothing otherwise.
1986
1987     A left- or right-hand split is based on a pair of brackets. Content before
1988     (and including) the opening bracket is left on one line, content inside the
1989     brackets is put on a separate line, and finally content starting with and
1990     following the closing bracket is put on a separate line.
1991
1992     Those are called `head`, `body`, and `tail`, respectively. If the split
1993     produced the same line (all content in `head`) or ended up with an empty `body`
1994     and the `tail` is just the closing bracket, then it's considered failed.
1995     """
1996     tail_len = len(str(tail).strip())
1997     if not body:
1998         if tail_len == 0:
1999             raise CannotSplit("Splitting brackets produced the same line")
2000
2001         elif tail_len < 3:
2002             raise CannotSplit(
2003                 f"Splitting brackets on an empty body to save "
2004                 f"{tail_len} characters is not worth it"
2005             )
2006
2007
2008 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2009     """Normalize prefix of the first leaf in every line returned by `split_func`.
2010
2011     This is a decorator over relevant split functions.
2012     """
2013
2014     @wraps(split_func)
2015     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2016         for l in split_func(line, py36):
2017             normalize_prefix(l.leaves[0], inside_brackets=True)
2018             yield l
2019
2020     return split_wrapper
2021
2022
2023 @dont_increase_indentation
2024 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2025     """Split according to delimiters of the highest priority.
2026
2027     If `py36` is True, the split will add trailing commas also in function
2028     signatures that contain `*` and `**`.
2029     """
2030     try:
2031         last_leaf = line.leaves[-1]
2032     except IndexError:
2033         raise CannotSplit("Line empty")
2034
2035     delimiters = line.bracket_tracker.delimiters
2036     try:
2037         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
2038             exclude={id(last_leaf)}
2039         )
2040     except ValueError:
2041         raise CannotSplit("No delimiters found")
2042
2043     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2044     lowest_depth = sys.maxsize
2045     trailing_comma_safe = True
2046
2047     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2048         """Append `leaf` to current line or to new line if appending impossible."""
2049         nonlocal current_line
2050         try:
2051             current_line.append_safe(leaf, preformatted=True)
2052         except ValueError as ve:
2053             yield current_line
2054
2055             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2056             current_line.append(leaf)
2057
2058     for leaf in line.leaves:
2059         yield from append_to_line(leaf)
2060
2061         for comment_after in line.comments_after(leaf):
2062             yield from append_to_line(comment_after)
2063
2064         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2065         if (
2066             leaf.bracket_depth == lowest_depth
2067             and is_vararg(leaf, within=VARARGS_PARENTS)
2068         ):
2069             trailing_comma_safe = trailing_comma_safe and py36
2070         leaf_priority = delimiters.get(id(leaf))
2071         if leaf_priority == delimiter_priority:
2072             yield current_line
2073
2074             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2075     if current_line:
2076         if (
2077             trailing_comma_safe
2078             and delimiter_priority == COMMA_PRIORITY
2079             and current_line.leaves[-1].type != token.COMMA
2080             and current_line.leaves[-1].type != STANDALONE_COMMENT
2081         ):
2082             current_line.append(Leaf(token.COMMA, ","))
2083         yield current_line
2084
2085
2086 @dont_increase_indentation
2087 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2088     """Split standalone comments from the rest of the line."""
2089     if not line.contains_standalone_comments(0):
2090         raise CannotSplit("Line does not have any standalone comments")
2091
2092     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2093
2094     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2095         """Append `leaf` to current line or to new line if appending impossible."""
2096         nonlocal current_line
2097         try:
2098             current_line.append_safe(leaf, preformatted=True)
2099         except ValueError as ve:
2100             yield current_line
2101
2102             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2103             current_line.append(leaf)
2104
2105     for leaf in line.leaves:
2106         yield from append_to_line(leaf)
2107
2108         for comment_after in line.comments_after(leaf):
2109             yield from append_to_line(comment_after)
2110
2111     if current_line:
2112         yield current_line
2113
2114
2115 def explode_split(
2116     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
2117 ) -> Iterator[Line]:
2118     """Split by rightmost bracket and immediately split contents by a delimiter."""
2119     new_lines = list(right_hand_split(line, py36, omit))
2120     if len(new_lines) != 3:
2121         yield from new_lines
2122         return
2123
2124     yield new_lines[0]
2125
2126     try:
2127         yield from delimiter_split(new_lines[1], py36)
2128
2129     except CannotSplit:
2130         yield new_lines[1]
2131
2132     yield new_lines[2]
2133
2134
2135 def is_import(leaf: Leaf) -> bool:
2136     """Return True if the given leaf starts an import statement."""
2137     p = leaf.parent
2138     t = leaf.type
2139     v = leaf.value
2140     return bool(
2141         t == token.NAME
2142         and (
2143             (v == "import" and p and p.type == syms.import_name)
2144             or (v == "from" and p and p.type == syms.import_from)
2145         )
2146     )
2147
2148
2149 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2150     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2151     else.
2152
2153     Note: don't use backslashes for formatting or you'll lose your voting rights.
2154     """
2155     if not inside_brackets:
2156         spl = leaf.prefix.split("#")
2157         if "\\" not in spl[0]:
2158             nl_count = spl[-1].count("\n")
2159             if len(spl) > 1:
2160                 nl_count -= 1
2161             leaf.prefix = "\n" * nl_count
2162             return
2163
2164     leaf.prefix = ""
2165
2166
2167 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2168     """Make all string prefixes lowercase.
2169
2170     If remove_u_prefix is given, also removes any u prefix from the string.
2171
2172     Note: Mutates its argument.
2173     """
2174     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2175     assert match is not None, f"failed to match string {leaf.value!r}"
2176     orig_prefix = match.group(1)
2177     new_prefix = orig_prefix.lower()
2178     if remove_u_prefix:
2179         new_prefix = new_prefix.replace("u", "")
2180     leaf.value = f"{new_prefix}{match.group(2)}"
2181
2182
2183 def normalize_string_quotes(leaf: Leaf) -> None:
2184     """Prefer double quotes but only if it doesn't cause more escaping.
2185
2186     Adds or removes backslashes as appropriate. Doesn't parse and fix
2187     strings nested in f-strings (yet).
2188
2189     Note: Mutates its argument.
2190     """
2191     value = leaf.value.lstrip("furbFURB")
2192     if value[:3] == '"""':
2193         return
2194
2195     elif value[:3] == "'''":
2196         orig_quote = "'''"
2197         new_quote = '"""'
2198     elif value[0] == '"':
2199         orig_quote = '"'
2200         new_quote = "'"
2201     else:
2202         orig_quote = "'"
2203         new_quote = '"'
2204     first_quote_pos = leaf.value.find(orig_quote)
2205     if first_quote_pos == -1:
2206         return  # There's an internal error
2207
2208     prefix = leaf.value[:first_quote_pos]
2209     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2210     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2211     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2212     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2213     if "r" in prefix.casefold():
2214         if unescaped_new_quote.search(body):
2215             # There's at least one unescaped new_quote in this raw string
2216             # so converting is impossible
2217             return
2218
2219         # Do not introduce or remove backslashes in raw strings
2220         new_body = body
2221     else:
2222         # remove unnecessary quotes
2223         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2224         if body != new_body:
2225             # Consider the string without unnecessary quotes as the original
2226             body = new_body
2227             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2228         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2229         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2230     if new_quote == '"""' and new_body[-1] == '"':
2231         # edge case:
2232         new_body = new_body[:-1] + '\\"'
2233     orig_escape_count = body.count("\\")
2234     new_escape_count = new_body.count("\\")
2235     if new_escape_count > orig_escape_count:
2236         return  # Do not introduce more escaping
2237
2238     if new_escape_count == orig_escape_count and orig_quote == '"':
2239         return  # Prefer double quotes
2240
2241     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2242
2243
2244 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2245     """Make existing optional parentheses invisible or create new ones.
2246
2247     `parens_after` is a set of string leaf values immeditely after which parens
2248     should be put.
2249
2250     Standardizes on visible parentheses for single-element tuples, and keeps
2251     existing visible parentheses for other tuples and generator expressions.
2252     """
2253     check_lpar = False
2254     for child in list(node.children):
2255         if check_lpar:
2256             if child.type == syms.atom:
2257                 maybe_make_parens_invisible_in_atom(child)
2258             elif is_one_tuple(child):
2259                 # wrap child in visible parentheses
2260                 lpar = Leaf(token.LPAR, "(")
2261                 rpar = Leaf(token.RPAR, ")")
2262                 index = child.remove() or 0
2263                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2264             else:
2265                 # wrap child in invisible parentheses
2266                 lpar = Leaf(token.LPAR, "")
2267                 rpar = Leaf(token.RPAR, "")
2268                 index = child.remove() or 0
2269                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2270
2271         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2272
2273
2274 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2275     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2276     if (
2277         node.type != syms.atom
2278         or is_empty_tuple(node)
2279         or is_one_tuple(node)
2280         or is_yield(node)
2281         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2282     ):
2283         return False
2284
2285     first = node.children[0]
2286     last = node.children[-1]
2287     if first.type == token.LPAR and last.type == token.RPAR:
2288         # make parentheses invisible
2289         first.value = ""  # type: ignore
2290         last.value = ""  # type: ignore
2291         if len(node.children) > 1:
2292             maybe_make_parens_invisible_in_atom(node.children[1])
2293         return True
2294
2295     return False
2296
2297
2298 def is_empty_tuple(node: LN) -> bool:
2299     """Return True if `node` holds an empty tuple."""
2300     return (
2301         node.type == syms.atom
2302         and len(node.children) == 2
2303         and node.children[0].type == token.LPAR
2304         and node.children[1].type == token.RPAR
2305     )
2306
2307
2308 def is_one_tuple(node: LN) -> bool:
2309     """Return True if `node` holds a tuple with one element, with or without parens."""
2310     if node.type == syms.atom:
2311         if len(node.children) != 3:
2312             return False
2313
2314         lpar, gexp, rpar = node.children
2315         if not (
2316             lpar.type == token.LPAR
2317             and gexp.type == syms.testlist_gexp
2318             and rpar.type == token.RPAR
2319         ):
2320             return False
2321
2322         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2323
2324     return (
2325         node.type in IMPLICIT_TUPLE
2326         and len(node.children) == 2
2327         and node.children[1].type == token.COMMA
2328     )
2329
2330
2331 def is_yield(node: LN) -> bool:
2332     """Return True if `node` holds a `yield` or `yield from` expression."""
2333     if node.type == syms.yield_expr:
2334         return True
2335
2336     if node.type == token.NAME and node.value == "yield":  # type: ignore
2337         return True
2338
2339     if node.type != syms.atom:
2340         return False
2341
2342     if len(node.children) != 3:
2343         return False
2344
2345     lpar, expr, rpar = node.children
2346     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2347         return is_yield(expr)
2348
2349     return False
2350
2351
2352 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2353     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2354
2355     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2356     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2357     extended iterable unpacking (PEP 3132) and additional unpacking
2358     generalizations (PEP 448).
2359     """
2360     if leaf.type not in STARS or not leaf.parent:
2361         return False
2362
2363     p = leaf.parent
2364     if p.type == syms.star_expr:
2365         # Star expressions are also used as assignment targets in extended
2366         # iterable unpacking (PEP 3132).  See what its parent is instead.
2367         if not p.parent:
2368             return False
2369
2370         p = p.parent
2371
2372     return p.type in within
2373
2374
2375 def max_delimiter_priority_in_atom(node: LN) -> int:
2376     """Return maximum delimiter priority inside `node`.
2377
2378     This is specific to atoms with contents contained in a pair of parentheses.
2379     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2380     """
2381     if node.type != syms.atom:
2382         return 0
2383
2384     first = node.children[0]
2385     last = node.children[-1]
2386     if not (first.type == token.LPAR and last.type == token.RPAR):
2387         return 0
2388
2389     bt = BracketTracker()
2390     for c in node.children[1:-1]:
2391         if isinstance(c, Leaf):
2392             bt.mark(c)
2393         else:
2394             for leaf in c.leaves():
2395                 bt.mark(leaf)
2396     try:
2397         return bt.max_delimiter_priority()
2398
2399     except ValueError:
2400         return 0
2401
2402
2403 def ensure_visible(leaf: Leaf) -> None:
2404     """Make sure parentheses are visible.
2405
2406     They could be invisible as part of some statements (see
2407     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2408     """
2409     if leaf.type == token.LPAR:
2410         leaf.value = "("
2411     elif leaf.type == token.RPAR:
2412         leaf.value = ")"
2413
2414
2415 def is_python36(node: Node) -> bool:
2416     """Return True if the current file is using Python 3.6+ features.
2417
2418     Currently looking for:
2419     - f-strings; and
2420     - trailing commas after * or ** in function signatures and calls.
2421     """
2422     for n in node.pre_order():
2423         if n.type == token.STRING:
2424             value_head = n.value[:2]  # type: ignore
2425             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2426                 return True
2427
2428         elif (
2429             n.type in {syms.typedargslist, syms.arglist}
2430             and n.children
2431             and n.children[-1].type == token.COMMA
2432         ):
2433             for ch in n.children:
2434                 if ch.type in STARS:
2435                     return True
2436
2437                 if ch.type == syms.argument:
2438                     for argch in ch.children:
2439                         if argch.type in STARS:
2440                             return True
2441
2442     return False
2443
2444
2445 def get_future_imports(node: Node) -> Set[str]:
2446     """Return a set of __future__ imports in the file."""
2447     imports = set()
2448     for child in node.children:
2449         if child.type != syms.simple_stmt:
2450             break
2451         first_child = child.children[0]
2452         if isinstance(first_child, Leaf):
2453             # Continue looking if we see a docstring; otherwise stop.
2454             if (
2455                 len(child.children) == 2
2456                 and first_child.type == token.STRING
2457                 and child.children[1].type == token.NEWLINE
2458             ):
2459                 continue
2460             else:
2461                 break
2462         elif first_child.type == syms.import_from:
2463             module_name = first_child.children[1]
2464             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2465                 break
2466             for import_from_child in first_child.children[3:]:
2467                 if isinstance(import_from_child, Leaf):
2468                     if import_from_child.type == token.NAME:
2469                         imports.add(import_from_child.value)
2470                 else:
2471                     assert import_from_child.type == syms.import_as_names
2472                     for leaf in import_from_child.children:
2473                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2474                             imports.add(leaf.value)
2475         else:
2476             break
2477     return imports
2478
2479
2480 PYTHON_EXTENSIONS = {".py"}
2481 BLACKLISTED_DIRECTORIES = {
2482     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2483 }
2484
2485
2486 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2487     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2488     and have one of the PYTHON_EXTENSIONS.
2489     """
2490     for child in path.iterdir():
2491         if child.is_dir():
2492             if child.name in BLACKLISTED_DIRECTORIES:
2493                 continue
2494
2495             yield from gen_python_files_in_dir(child)
2496
2497         elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
2498             yield child
2499
2500
2501 @dataclass
2502 class Report:
2503     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2504     check: bool = False
2505     quiet: bool = False
2506     change_count: int = 0
2507     same_count: int = 0
2508     failure_count: int = 0
2509
2510     def done(self, src: Path, changed: Changed) -> None:
2511         """Increment the counter for successful reformatting. Write out a message."""
2512         if changed is Changed.YES:
2513             reformatted = "would reformat" if self.check else "reformatted"
2514             if not self.quiet:
2515                 out(f"{reformatted} {src}")
2516             self.change_count += 1
2517         else:
2518             if not self.quiet:
2519                 if changed is Changed.NO:
2520                     msg = f"{src} already well formatted, good job."
2521                 else:
2522                     msg = f"{src} wasn't modified on disk since last run."
2523                 out(msg, bold=False)
2524             self.same_count += 1
2525
2526     def failed(self, src: Path, message: str) -> None:
2527         """Increment the counter for failed reformatting. Write out a message."""
2528         err(f"error: cannot format {src}: {message}")
2529         self.failure_count += 1
2530
2531     @property
2532     def return_code(self) -> int:
2533         """Return the exit code that the app should use.
2534
2535         This considers the current state of changed files and failures:
2536         - if there were any failures, return 123;
2537         - if any files were changed and --check is being used, return 1;
2538         - otherwise return 0.
2539         """
2540         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2541         # 126 we have special returncodes reserved by the shell.
2542         if self.failure_count:
2543             return 123
2544
2545         elif self.change_count and self.check:
2546             return 1
2547
2548         return 0
2549
2550     def __str__(self) -> str:
2551         """Render a color report of the current state.
2552
2553         Use `click.unstyle` to remove colors.
2554         """
2555         if self.check:
2556             reformatted = "would be reformatted"
2557             unchanged = "would be left unchanged"
2558             failed = "would fail to reformat"
2559         else:
2560             reformatted = "reformatted"
2561             unchanged = "left unchanged"
2562             failed = "failed to reformat"
2563         report = []
2564         if self.change_count:
2565             s = "s" if self.change_count > 1 else ""
2566             report.append(
2567                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2568             )
2569         if self.same_count:
2570             s = "s" if self.same_count > 1 else ""
2571             report.append(f"{self.same_count} file{s} {unchanged}")
2572         if self.failure_count:
2573             s = "s" if self.failure_count > 1 else ""
2574             report.append(
2575                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2576             )
2577         return ", ".join(report) + "."
2578
2579
2580 def assert_equivalent(src: str, dst: str) -> None:
2581     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2582
2583     import ast
2584     import traceback
2585
2586     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2587         """Simple visitor generating strings to compare ASTs by content."""
2588         yield f"{'  ' * depth}{node.__class__.__name__}("
2589
2590         for field in sorted(node._fields):
2591             try:
2592                 value = getattr(node, field)
2593             except AttributeError:
2594                 continue
2595
2596             yield f"{'  ' * (depth+1)}{field}="
2597
2598             if isinstance(value, list):
2599                 for item in value:
2600                     if isinstance(item, ast.AST):
2601                         yield from _v(item, depth + 2)
2602
2603             elif isinstance(value, ast.AST):
2604                 yield from _v(value, depth + 2)
2605
2606             else:
2607                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2608
2609         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2610
2611     try:
2612         src_ast = ast.parse(src)
2613     except Exception as exc:
2614         major, minor = sys.version_info[:2]
2615         raise AssertionError(
2616             f"cannot use --safe with this file; failed to parse source file "
2617             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2618             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2619         )
2620
2621     try:
2622         dst_ast = ast.parse(dst)
2623     except Exception as exc:
2624         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2625         raise AssertionError(
2626             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2627             f"Please report a bug on https://github.com/ambv/black/issues.  "
2628             f"This invalid output might be helpful: {log}"
2629         ) from None
2630
2631     src_ast_str = "\n".join(_v(src_ast))
2632     dst_ast_str = "\n".join(_v(dst_ast))
2633     if src_ast_str != dst_ast_str:
2634         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2635         raise AssertionError(
2636             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2637             f"the source.  "
2638             f"Please report a bug on https://github.com/ambv/black/issues.  "
2639             f"This diff might be helpful: {log}"
2640         ) from None
2641
2642
2643 def assert_stable(src: str, dst: str, line_length: int) -> None:
2644     """Raise AssertionError if `dst` reformats differently the second time."""
2645     newdst = format_str(dst, line_length=line_length)
2646     if dst != newdst:
2647         log = dump_to_file(
2648             diff(src, dst, "source", "first pass"),
2649             diff(dst, newdst, "first pass", "second pass"),
2650         )
2651         raise AssertionError(
2652             f"INTERNAL ERROR: Black produced different code on the second pass "
2653             f"of the formatter.  "
2654             f"Please report a bug on https://github.com/ambv/black/issues.  "
2655             f"This diff might be helpful: {log}"
2656         ) from None
2657
2658
2659 def dump_to_file(*output: str) -> str:
2660     """Dump `output` to a temporary file. Return path to the file."""
2661     import tempfile
2662
2663     with tempfile.NamedTemporaryFile(
2664         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2665     ) as f:
2666         for lines in output:
2667             f.write(lines)
2668             if lines and lines[-1] != "\n":
2669                 f.write("\n")
2670     return f.name
2671
2672
2673 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2674     """Return a unified diff string between strings `a` and `b`."""
2675     import difflib
2676
2677     a_lines = [line + "\n" for line in a.split("\n")]
2678     b_lines = [line + "\n" for line in b.split("\n")]
2679     return "".join(
2680         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2681     )
2682
2683
2684 def cancel(tasks: List[asyncio.Task]) -> None:
2685     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2686     err("Aborted!")
2687     for task in tasks:
2688         task.cancel()
2689
2690
2691 def shutdown(loop: BaseEventLoop) -> None:
2692     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2693     try:
2694         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2695         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2696         if not to_cancel:
2697             return
2698
2699         for task in to_cancel:
2700             task.cancel()
2701         loop.run_until_complete(
2702             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2703         )
2704     finally:
2705         # `concurrent.futures.Future` objects cannot be cancelled once they
2706         # are already running. There might be some when the `shutdown()` happened.
2707         # Silence their logger's spew about the event loop being closed.
2708         cf_logger = logging.getLogger("concurrent.futures")
2709         cf_logger.setLevel(logging.CRITICAL)
2710         loop.close()
2711
2712
2713 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2714     """Replace `regex` with `replacement` twice on `original`.
2715
2716     This is used by string normalization to perform replaces on
2717     overlapping matches.
2718     """
2719     return regex.sub(replacement, regex.sub(replacement, original))
2720
2721
2722 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2723
2724
2725 def get_cache_file(line_length: int) -> Path:
2726     return CACHE_DIR / f"cache.{line_length}.pickle"
2727
2728
2729 def read_cache(line_length: int) -> Cache:
2730     """Read the cache if it exists and is well formed.
2731
2732     If it is not well formed, the call to write_cache later should resolve the issue.
2733     """
2734     cache_file = get_cache_file(line_length)
2735     if not cache_file.exists():
2736         return {}
2737
2738     with cache_file.open("rb") as fobj:
2739         try:
2740             cache: Cache = pickle.load(fobj)
2741         except pickle.UnpicklingError:
2742             return {}
2743
2744     return cache
2745
2746
2747 def get_cache_info(path: Path) -> CacheInfo:
2748     """Return the information used to check if a file is already formatted or not."""
2749     stat = path.stat()
2750     return stat.st_mtime, stat.st_size
2751
2752
2753 def filter_cached(
2754     cache: Cache, sources: Iterable[Path]
2755 ) -> Tuple[List[Path], List[Path]]:
2756     """Split a list of paths into two.
2757
2758     The first list contains paths of files that modified on disk or are not in the
2759     cache. The other list contains paths to non-modified files.
2760     """
2761     todo, done = [], []
2762     for src in sources:
2763         src = src.resolve()
2764         if cache.get(src) != get_cache_info(src):
2765             todo.append(src)
2766         else:
2767             done.append(src)
2768     return todo, done
2769
2770
2771 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2772     """Update the cache file."""
2773     cache_file = get_cache_file(line_length)
2774     try:
2775         if not CACHE_DIR.exists():
2776             CACHE_DIR.mkdir(parents=True)
2777         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2778         with cache_file.open("wb") as fobj:
2779             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2780     except OSError:
2781         pass
2782
2783
2784 if __name__ == "__main__":
2785     main()