]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

e795fa3267e62883bc0811412a673c191d7ebd6b
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import pickle
5 from asyncio.base_events import BaseEventLoop
6 from concurrent.futures import Executor, ProcessPoolExecutor
7 from enum import Enum
8 from functools import partial, wraps
9 import keyword
10 import logging
11 from multiprocessing import Manager
12 import os
13 from pathlib import Path
14 import re
15 import tokenize
16 import signal
17 import sys
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46 __version__ = "18.4a2"
47 DEFAULT_LINE_LENGTH = 88
48 # types
49 syms = pygram.python_symbols
50 FileContent = str
51 Encoding = str
52 Depth = int
53 NodeType = int
54 LeafID = int
55 Priority = int
56 Index = int
57 LN = Union[Leaf, Node]
58 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
59 Timestamp = float
60 FileSize = int
61 CacheInfo = Tuple[Timestamp, FileSize]
62 Cache = Dict[Path, CacheInfo]
63 out = partial(click.secho, bold=True, err=True)
64 err = partial(click.secho, fg="red", err=True)
65
66
67 class NothingChanged(UserWarning):
68     """Raised by :func:`format_file` when reformatted code is the same as source."""
69
70
71 class CannotSplit(Exception):
72     """A readable split that fits the allotted line length is impossible.
73
74     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
75     :func:`delimiter_split`.
76     """
77
78
79 class FormatError(Exception):
80     """Base exception for `# fmt: on` and `# fmt: off` handling.
81
82     It holds the number of bytes of the prefix consumed before the format
83     control comment appeared.
84     """
85
86     def __init__(self, consumed: int) -> None:
87         super().__init__(consumed)
88         self.consumed = consumed
89
90     def trim_prefix(self, leaf: Leaf) -> None:
91         leaf.prefix = leaf.prefix[self.consumed:]
92
93     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
94         """Returns a new Leaf from the consumed part of the prefix."""
95         unformatted_prefix = leaf.prefix[:self.consumed]
96         return Leaf(token.NEWLINE, unformatted_prefix)
97
98
99 class FormatOn(FormatError):
100     """Found a comment like `# fmt: on` in the file."""
101
102
103 class FormatOff(FormatError):
104     """Found a comment like `# fmt: off` in the file."""
105
106
107 class WriteBack(Enum):
108     NO = 0
109     YES = 1
110     DIFF = 2
111
112
113 class Changed(Enum):
114     NO = 0
115     CACHED = 1
116     YES = 2
117
118
119 @click.command()
120 @click.option(
121     "-l",
122     "--line-length",
123     type=int,
124     default=DEFAULT_LINE_LENGTH,
125     help="How many character per line to allow.",
126     show_default=True,
127 )
128 @click.option(
129     "--check",
130     is_flag=True,
131     help=(
132         "Don't write the files back, just return the status.  Return code 0 "
133         "means nothing would change.  Return code 1 means some files would be "
134         "reformatted.  Return code 123 means there was an internal error."
135     ),
136 )
137 @click.option(
138     "--diff",
139     is_flag=True,
140     help="Don't write the files back, just output a diff for each file on stdout.",
141 )
142 @click.option(
143     "--fast/--safe",
144     is_flag=True,
145     help="If --fast given, skip temporary sanity checks. [default: --safe]",
146 )
147 @click.option(
148     "-q",
149     "--quiet",
150     is_flag=True,
151     help=(
152         "Don't emit non-error messages to stderr. Errors are still emitted, "
153         "silence those with 2>/dev/null."
154     ),
155 )
156 @click.version_option(version=__version__)
157 @click.argument(
158     "src",
159     nargs=-1,
160     type=click.Path(
161         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
162     ),
163 )
164 @click.pass_context
165 def main(
166     ctx: click.Context,
167     line_length: int,
168     check: bool,
169     diff: bool,
170     fast: bool,
171     quiet: bool,
172     src: List[str],
173 ) -> None:
174     """The uncompromising code formatter."""
175     sources: List[Path] = []
176     for s in src:
177         p = Path(s)
178         if p.is_dir():
179             sources.extend(gen_python_files_in_dir(p))
180         elif p.is_file():
181             # if a file was explicitly given, we don't care about its extension
182             sources.append(p)
183         elif s == "-":
184             sources.append(Path("-"))
185         else:
186             err(f"invalid path: {s}")
187     if check and diff:
188         exc = click.ClickException("Options --check and --diff are mutually exclusive")
189         exc.exit_code = 2
190         raise exc
191
192     if check:
193         write_back = WriteBack.NO
194     elif diff:
195         write_back = WriteBack.DIFF
196     else:
197         write_back = WriteBack.YES
198     if len(sources) == 0:
199         ctx.exit(0)
200         return
201
202     elif len(sources) == 1:
203         return_code = run_single_file_mode(
204             line_length, check, fast, quiet, write_back, sources[0]
205         )
206     else:
207         return_code = run_multi_file_mode(line_length, fast, quiet, write_back, sources)
208     ctx.exit(return_code)
209
210
211 def run_single_file_mode(
212     line_length: int,
213     check: bool,
214     fast: bool,
215     quiet: bool,
216     write_back: WriteBack,
217     src: Path,
218 ) -> int:
219     report = Report(check=check, quiet=quiet)
220     try:
221         if not src.is_file() and str(src) == "-":
222             changed = format_stdin_to_stdout(
223                 line_length=line_length, fast=fast, write_back=write_back
224             )
225         else:
226             changed = Changed.NO
227             cache: Cache = {}
228             if write_back != WriteBack.DIFF:
229                 cache = read_cache()
230                 src = src.resolve()
231                 if src in cache and cache[src] == get_cache_info(src):
232                     changed = Changed.CACHED
233             if changed is not Changed.CACHED:
234                 changed = format_file_in_place(
235                     src, line_length=line_length, fast=fast, write_back=write_back
236                 )
237             if write_back != WriteBack.DIFF and changed is not Changed.NO:
238                 write_cache(cache, [src])
239         report.done(src, changed)
240     except Exception as exc:
241         report.failed(src, str(exc))
242     return report.return_code
243
244
245 def run_multi_file_mode(
246     line_length: int,
247     fast: bool,
248     quiet: bool,
249     write_back: WriteBack,
250     sources: List[Path],
251 ) -> int:
252     loop = asyncio.get_event_loop()
253     executor = ProcessPoolExecutor(max_workers=os.cpu_count())
254     return_code = 1
255     try:
256         return_code = loop.run_until_complete(
257             schedule_formatting(
258                 sources, line_length, write_back, fast, quiet, loop, executor
259             )
260         )
261     finally:
262         shutdown(loop)
263         return return_code
264
265
266 async def schedule_formatting(
267     sources: List[Path],
268     line_length: int,
269     write_back: WriteBack,
270     fast: bool,
271     quiet: bool,
272     loop: BaseEventLoop,
273     executor: Executor,
274 ) -> int:
275     """Run formatting of `sources` in parallel using the provided `executor`.
276
277     (Use ProcessPoolExecutors for actual parallelism.)
278
279     `line_length`, `write_back`, and `fast` options are passed to
280     :func:`format_file_in_place`.
281     """
282     report = Report(check=write_back is WriteBack.NO, quiet=quiet)
283     cache: Cache = {}
284     if write_back != WriteBack.DIFF:
285         cache = read_cache()
286         sources, cached = filter_cached(cache, sources)
287         for src in cached:
288             report.done(src, Changed.CACHED)
289     cancelled = []
290     formatted = []
291     if sources:
292         lock = None
293         if write_back == WriteBack.DIFF:
294             # For diff output, we need locks to ensure we don't interleave output
295             # from different processes.
296             manager = Manager()
297             lock = manager.Lock()
298         tasks = {
299             src: loop.run_in_executor(
300                 executor, format_file_in_place, src, line_length, fast, write_back, lock
301             )
302             for src in sources
303         }
304         _task_values = list(tasks.values())
305         loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
306         loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
307         await asyncio.wait(_task_values)
308         for src, task in tasks.items():
309             if not task.done():
310                 report.failed(src, "timed out, cancelling")
311                 task.cancel()
312                 cancelled.append(task)
313             elif task.cancelled():
314                 cancelled.append(task)
315             elif task.exception():
316                 report.failed(src, str(task.exception()))
317             else:
318                 formatted.append(src)
319                 report.done(src, task.result())
320
321     if cancelled:
322         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
323     elif not quiet:
324         out("All done! ✨ 🍰 ✨")
325     if not quiet:
326         click.echo(str(report))
327
328     if write_back != WriteBack.DIFF and formatted:
329         write_cache(cache, formatted)
330
331     return report.return_code
332
333
334 def format_file_in_place(
335     src: Path,
336     line_length: int,
337     fast: bool,
338     write_back: WriteBack = WriteBack.NO,
339     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
340 ) -> Changed:
341     """Format file under `src` path. Return True if changed.
342
343     If `write_back` is True, write reformatted code back to stdout.
344     `line_length` and `fast` options are passed to :func:`format_file_contents`.
345     """
346
347     with tokenize.open(src) as src_buffer:
348         src_contents = src_buffer.read()
349     try:
350         dst_contents = format_file_contents(
351             src_contents, line_length=line_length, fast=fast
352         )
353     except NothingChanged:
354         return Changed.NO
355
356     if write_back == write_back.YES:
357         with open(src, "w", encoding=src_buffer.encoding) as f:
358             f.write(dst_contents)
359     elif write_back == write_back.DIFF:
360         src_name = f"{src.name}  (original)"
361         dst_name = f"{src.name}  (formatted)"
362         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
363         if lock:
364             lock.acquire()
365         try:
366             sys.stdout.write(diff_contents)
367         finally:
368             if lock:
369                 lock.release()
370     return Changed.YES
371
372
373 def format_stdin_to_stdout(
374     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
375 ) -> Changed:
376     """Format file on stdin. Return True if changed.
377
378     If `write_back` is True, write reformatted code back to stdout.
379     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
380     """
381     src = sys.stdin.read()
382     dst = src
383     try:
384         dst = format_file_contents(src, line_length=line_length, fast=fast)
385         return Changed.YES
386
387     except NothingChanged:
388         return Changed.NO
389
390     finally:
391         if write_back == WriteBack.YES:
392             sys.stdout.write(dst)
393         elif write_back == WriteBack.DIFF:
394             src_name = "<stdin>  (original)"
395             dst_name = "<stdin>  (formatted)"
396             sys.stdout.write(diff(src, dst, src_name, dst_name))
397
398
399 def format_file_contents(
400     src_contents: str, line_length: int, fast: bool
401 ) -> FileContent:
402     """Reformat contents a file and return new contents.
403
404     If `fast` is False, additionally confirm that the reformatted code is
405     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
406     `line_length` is passed to :func:`format_str`.
407     """
408     if src_contents.strip() == "":
409         raise NothingChanged
410
411     dst_contents = format_str(src_contents, line_length=line_length)
412     if src_contents == dst_contents:
413         raise NothingChanged
414
415     if not fast:
416         assert_equivalent(src_contents, dst_contents)
417         assert_stable(src_contents, dst_contents, line_length=line_length)
418     return dst_contents
419
420
421 def format_str(src_contents: str, line_length: int) -> FileContent:
422     """Reformat a string and return new contents.
423
424     `line_length` determines how many characters per line are allowed.
425     """
426     src_node = lib2to3_parse(src_contents)
427     dst_contents = ""
428     lines = LineGenerator()
429     elt = EmptyLineTracker()
430     py36 = is_python36(src_node)
431     empty_line = Line()
432     after = 0
433     for current_line in lines.visit(src_node):
434         for _ in range(after):
435             dst_contents += str(empty_line)
436         before, after = elt.maybe_empty_lines(current_line)
437         for _ in range(before):
438             dst_contents += str(empty_line)
439         for line in split_line(current_line, line_length=line_length, py36=py36):
440             dst_contents += str(line)
441     return dst_contents
442
443
444 GRAMMARS = [
445     pygram.python_grammar_no_print_statement_no_exec_statement,
446     pygram.python_grammar_no_print_statement,
447     pygram.python_grammar_no_exec_statement,
448     pygram.python_grammar,
449 ]
450
451
452 def lib2to3_parse(src_txt: str) -> Node:
453     """Given a string with source, return the lib2to3 Node."""
454     grammar = pygram.python_grammar_no_print_statement
455     if src_txt[-1] != "\n":
456         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
457         src_txt += nl
458     for grammar in GRAMMARS:
459         drv = driver.Driver(grammar, pytree.convert)
460         try:
461             result = drv.parse_string(src_txt, True)
462             break
463
464         except ParseError as pe:
465             lineno, column = pe.context[1]
466             lines = src_txt.splitlines()
467             try:
468                 faulty_line = lines[lineno - 1]
469             except IndexError:
470                 faulty_line = "<line number missing in source>"
471             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
472     else:
473         raise exc from None
474
475     if isinstance(result, Leaf):
476         result = Node(syms.file_input, [result])
477     return result
478
479
480 def lib2to3_unparse(node: Node) -> str:
481     """Given a lib2to3 node, return its string representation."""
482     code = str(node)
483     return code
484
485
486 T = TypeVar("T")
487
488
489 class Visitor(Generic[T]):
490     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
491
492     def visit(self, node: LN) -> Iterator[T]:
493         """Main method to visit `node` and its children.
494
495         It tries to find a `visit_*()` method for the given `node.type`, like
496         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
497         If no dedicated `visit_*()` method is found, chooses `visit_default()`
498         instead.
499
500         Then yields objects of type `T` from the selected visitor.
501         """
502         if node.type < 256:
503             name = token.tok_name[node.type]
504         else:
505             name = type_repr(node.type)
506         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
507
508     def visit_default(self, node: LN) -> Iterator[T]:
509         """Default `visit_*()` implementation. Recurses to children of `node`."""
510         if isinstance(node, Node):
511             for child in node.children:
512                 yield from self.visit(child)
513
514
515 @dataclass
516 class DebugVisitor(Visitor[T]):
517     tree_depth: int = 0
518
519     def visit_default(self, node: LN) -> Iterator[T]:
520         indent = " " * (2 * self.tree_depth)
521         if isinstance(node, Node):
522             _type = type_repr(node.type)
523             out(f"{indent}{_type}", fg="yellow")
524             self.tree_depth += 1
525             for child in node.children:
526                 yield from self.visit(child)
527
528             self.tree_depth -= 1
529             out(f"{indent}/{_type}", fg="yellow", bold=False)
530         else:
531             _type = token.tok_name.get(node.type, str(node.type))
532             out(f"{indent}{_type}", fg="blue", nl=False)
533             if node.prefix:
534                 # We don't have to handle prefixes for `Node` objects since
535                 # that delegates to the first child anyway.
536                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
537             out(f" {node.value!r}", fg="blue", bold=False)
538
539     @classmethod
540     def show(cls, code: str) -> None:
541         """Pretty-print the lib2to3 AST of a given string of `code`.
542
543         Convenience method for debugging.
544         """
545         v: DebugVisitor[None] = DebugVisitor()
546         list(v.visit(lib2to3_parse(code)))
547
548
549 KEYWORDS = set(keyword.kwlist)
550 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
551 FLOW_CONTROL = {"return", "raise", "break", "continue"}
552 STATEMENT = {
553     syms.if_stmt,
554     syms.while_stmt,
555     syms.for_stmt,
556     syms.try_stmt,
557     syms.except_clause,
558     syms.with_stmt,
559     syms.funcdef,
560     syms.classdef,
561 }
562 STANDALONE_COMMENT = 153
563 LOGIC_OPERATORS = {"and", "or"}
564 COMPARATORS = {
565     token.LESS,
566     token.GREATER,
567     token.EQEQUAL,
568     token.NOTEQUAL,
569     token.LESSEQUAL,
570     token.GREATEREQUAL,
571 }
572 MATH_OPERATORS = {
573     token.PLUS,
574     token.MINUS,
575     token.STAR,
576     token.SLASH,
577     token.VBAR,
578     token.AMPER,
579     token.PERCENT,
580     token.CIRCUMFLEX,
581     token.TILDE,
582     token.LEFTSHIFT,
583     token.RIGHTSHIFT,
584     token.DOUBLESTAR,
585     token.DOUBLESLASH,
586 }
587 STARS = {token.STAR, token.DOUBLESTAR}
588 VARARGS_PARENTS = {
589     syms.arglist,
590     syms.argument,  # double star in arglist
591     syms.trailer,  # single argument to call
592     syms.typedargslist,
593     syms.varargslist,  # lambdas
594 }
595 UNPACKING_PARENTS = {
596     syms.atom,  # single element of a list or set literal
597     syms.dictsetmaker,
598     syms.listmaker,
599     syms.testlist_gexp,
600 }
601 COMPREHENSION_PRIORITY = 20
602 COMMA_PRIORITY = 10
603 LOGIC_PRIORITY = 5
604 STRING_PRIORITY = 4
605 COMPARATOR_PRIORITY = 3
606 MATH_PRIORITY = 1
607
608
609 @dataclass
610 class BracketTracker:
611     """Keeps track of brackets on a line."""
612
613     depth: int = 0
614     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
615     delimiters: Dict[LeafID, Priority] = Factory(dict)
616     previous: Optional[Leaf] = None
617
618     def mark(self, leaf: Leaf) -> None:
619         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
620
621         All leaves receive an int `bracket_depth` field that stores how deep
622         within brackets a given leaf is. 0 means there are no enclosing brackets
623         that started on this line.
624
625         If a leaf is itself a closing bracket, it receives an `opening_bracket`
626         field that it forms a pair with. This is a one-directional link to
627         avoid reference cycles.
628
629         If a leaf is a delimiter (a token on which Black can split the line if
630         needed) and it's on depth 0, its `id()` is stored in the tracker's
631         `delimiters` field.
632         """
633         if leaf.type == token.COMMENT:
634             return
635
636         if leaf.type in CLOSING_BRACKETS:
637             self.depth -= 1
638             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
639             leaf.opening_bracket = opening_bracket
640         leaf.bracket_depth = self.depth
641         if self.depth == 0:
642             delim = is_split_before_delimiter(leaf, self.previous)
643             if delim and self.previous is not None:
644                 self.delimiters[id(self.previous)] = delim
645             else:
646                 delim = is_split_after_delimiter(leaf, self.previous)
647                 if delim:
648                     self.delimiters[id(leaf)] = delim
649         if leaf.type in OPENING_BRACKETS:
650             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
651             self.depth += 1
652         self.previous = leaf
653
654     def any_open_brackets(self) -> bool:
655         """Return True if there is an yet unmatched open bracket on the line."""
656         return bool(self.bracket_match)
657
658     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
659         """Return the highest priority of a delimiter found on the line.
660
661         Values are consistent with what `is_delimiter()` returns.
662         Raises ValueError on no delimiters.
663         """
664         return max(v for k, v in self.delimiters.items() if k not in exclude)
665
666
667 @dataclass
668 class Line:
669     """Holds leaves and comments. Can be printed with `str(line)`."""
670
671     depth: int = 0
672     leaves: List[Leaf] = Factory(list)
673     comments: List[Tuple[Index, Leaf]] = Factory(list)
674     bracket_tracker: BracketTracker = Factory(BracketTracker)
675     inside_brackets: bool = False
676     has_for: bool = False
677     _for_loop_variable: bool = False
678
679     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
680         """Add a new `leaf` to the end of the line.
681
682         Unless `preformatted` is True, the `leaf` will receive a new consistent
683         whitespace prefix and metadata applied by :class:`BracketTracker`.
684         Trailing commas are maybe removed, unpacked for loop variables are
685         demoted from being delimiters.
686
687         Inline comments are put aside.
688         """
689         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
690         if not has_value:
691             return
692
693         if self.leaves and not preformatted:
694             # Note: at this point leaf.prefix should be empty except for
695             # imports, for which we only preserve newlines.
696             leaf.prefix += whitespace(leaf)
697         if self.inside_brackets or not preformatted:
698             self.maybe_decrement_after_for_loop_variable(leaf)
699             self.bracket_tracker.mark(leaf)
700             self.maybe_remove_trailing_comma(leaf)
701             self.maybe_increment_for_loop_variable(leaf)
702
703         if not self.append_comment(leaf):
704             self.leaves.append(leaf)
705
706     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
707         """Like :func:`append()` but disallow invalid standalone comment structure.
708
709         Raises ValueError when any `leaf` is appended after a standalone comment
710         or when a standalone comment is not the first leaf on the line.
711         """
712         if self.bracket_tracker.depth == 0:
713             if self.is_comment:
714                 raise ValueError("cannot append to standalone comments")
715
716             if self.leaves and leaf.type == STANDALONE_COMMENT:
717                 raise ValueError(
718                     "cannot append standalone comments to a populated line"
719                 )
720
721         self.append(leaf, preformatted=preformatted)
722
723     @property
724     def is_comment(self) -> bool:
725         """Is this line a standalone comment?"""
726         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
727
728     @property
729     def is_decorator(self) -> bool:
730         """Is this line a decorator?"""
731         return bool(self) and self.leaves[0].type == token.AT
732
733     @property
734     def is_import(self) -> bool:
735         """Is this an import line?"""
736         return bool(self) and is_import(self.leaves[0])
737
738     @property
739     def is_class(self) -> bool:
740         """Is this line a class definition?"""
741         return (
742             bool(self)
743             and self.leaves[0].type == token.NAME
744             and self.leaves[0].value == "class"
745         )
746
747     @property
748     def is_def(self) -> bool:
749         """Is this a function definition? (Also returns True for async defs.)"""
750         try:
751             first_leaf = self.leaves[0]
752         except IndexError:
753             return False
754
755         try:
756             second_leaf: Optional[Leaf] = self.leaves[1]
757         except IndexError:
758             second_leaf = None
759         return (
760             (first_leaf.type == token.NAME and first_leaf.value == "def")
761             or (
762                 first_leaf.type == token.ASYNC
763                 and second_leaf is not None
764                 and second_leaf.type == token.NAME
765                 and second_leaf.value == "def"
766             )
767         )
768
769     @property
770     def is_flow_control(self) -> bool:
771         """Is this line a flow control statement?
772
773         Those are `return`, `raise`, `break`, and `continue`.
774         """
775         return (
776             bool(self)
777             and self.leaves[0].type == token.NAME
778             and self.leaves[0].value in FLOW_CONTROL
779         )
780
781     @property
782     def is_yield(self) -> bool:
783         """Is this line a yield statement?"""
784         return (
785             bool(self)
786             and self.leaves[0].type == token.NAME
787             and self.leaves[0].value == "yield"
788         )
789
790     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
791         """If so, needs to be split before emitting."""
792         for leaf in self.leaves:
793             if leaf.type == STANDALONE_COMMENT:
794                 if leaf.bracket_depth <= depth_limit:
795                     return True
796
797         return False
798
799     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
800         """Remove trailing comma if there is one and it's safe."""
801         if not (
802             self.leaves
803             and self.leaves[-1].type == token.COMMA
804             and closing.type in CLOSING_BRACKETS
805         ):
806             return False
807
808         if closing.type == token.RBRACE:
809             self.remove_trailing_comma()
810             return True
811
812         if closing.type == token.RSQB:
813             comma = self.leaves[-1]
814             if comma.parent and comma.parent.type == syms.listmaker:
815                 self.remove_trailing_comma()
816                 return True
817
818         # For parens let's check if it's safe to remove the comma.  If the
819         # trailing one is the only one, we might mistakenly change a tuple
820         # into a different type by removing the comma.
821         depth = closing.bracket_depth + 1
822         commas = 0
823         opening = closing.opening_bracket
824         for _opening_index, leaf in enumerate(self.leaves):
825             if leaf is opening:
826                 break
827
828         else:
829             return False
830
831         for leaf in self.leaves[_opening_index + 1:]:
832             if leaf is closing:
833                 break
834
835             bracket_depth = leaf.bracket_depth
836             if bracket_depth == depth and leaf.type == token.COMMA:
837                 commas += 1
838                 if leaf.parent and leaf.parent.type == syms.arglist:
839                     commas += 1
840                     break
841
842         if commas > 1:
843             self.remove_trailing_comma()
844             return True
845
846         return False
847
848     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
849         """In a for loop, or comprehension, the variables are often unpacks.
850
851         To avoid splitting on the comma in this situation, increase the depth of
852         tokens between `for` and `in`.
853         """
854         if leaf.type == token.NAME and leaf.value == "for":
855             self.has_for = True
856             self.bracket_tracker.depth += 1
857             self._for_loop_variable = True
858             return True
859
860         return False
861
862     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
863         """See `maybe_increment_for_loop_variable` above for explanation."""
864         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
865             self.bracket_tracker.depth -= 1
866             self._for_loop_variable = False
867             return True
868
869         return False
870
871     def append_comment(self, comment: Leaf) -> bool:
872         """Add an inline or standalone comment to the line."""
873         if (
874             comment.type == STANDALONE_COMMENT
875             and self.bracket_tracker.any_open_brackets()
876         ):
877             comment.prefix = ""
878             return False
879
880         if comment.type != token.COMMENT:
881             return False
882
883         after = len(self.leaves) - 1
884         if after == -1:
885             comment.type = STANDALONE_COMMENT
886             comment.prefix = ""
887             return False
888
889         else:
890             self.comments.append((after, comment))
891             return True
892
893     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
894         """Generate comments that should appear directly after `leaf`."""
895         for _leaf_index, _leaf in enumerate(self.leaves):
896             if leaf is _leaf:
897                 break
898
899         else:
900             return
901
902         for index, comment_after in self.comments:
903             if _leaf_index == index:
904                 yield comment_after
905
906     def remove_trailing_comma(self) -> None:
907         """Remove the trailing comma and moves the comments attached to it."""
908         comma_index = len(self.leaves) - 1
909         for i in range(len(self.comments)):
910             comment_index, comment = self.comments[i]
911             if comment_index == comma_index:
912                 self.comments[i] = (comma_index - 1, comment)
913         self.leaves.pop()
914
915     def __str__(self) -> str:
916         """Render the line."""
917         if not self:
918             return "\n"
919
920         indent = "    " * self.depth
921         leaves = iter(self.leaves)
922         first = next(leaves)
923         res = f"{first.prefix}{indent}{first.value}"
924         for leaf in leaves:
925             res += str(leaf)
926         for _, comment in self.comments:
927             res += str(comment)
928         return res + "\n"
929
930     def __bool__(self) -> bool:
931         """Return True if the line has leaves or comments."""
932         return bool(self.leaves or self.comments)
933
934
935 class UnformattedLines(Line):
936     """Just like :class:`Line` but stores lines which aren't reformatted."""
937
938     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
939         """Just add a new `leaf` to the end of the lines.
940
941         The `preformatted` argument is ignored.
942
943         Keeps track of indentation `depth`, which is useful when the user
944         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
945         """
946         try:
947             list(generate_comments(leaf))
948         except FormatOn as f_on:
949             self.leaves.append(f_on.leaf_from_consumed(leaf))
950             raise
951
952         self.leaves.append(leaf)
953         if leaf.type == token.INDENT:
954             self.depth += 1
955         elif leaf.type == token.DEDENT:
956             self.depth -= 1
957
958     def __str__(self) -> str:
959         """Render unformatted lines from leaves which were added with `append()`.
960
961         `depth` is not used for indentation in this case.
962         """
963         if not self:
964             return "\n"
965
966         res = ""
967         for leaf in self.leaves:
968             res += str(leaf)
969         return res
970
971     def append_comment(self, comment: Leaf) -> bool:
972         """Not implemented in this class. Raises `NotImplementedError`."""
973         raise NotImplementedError("Unformatted lines don't store comments separately.")
974
975     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
976         """Does nothing and returns False."""
977         return False
978
979     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
980         """Does nothing and returns False."""
981         return False
982
983
984 @dataclass
985 class EmptyLineTracker:
986     """Provides a stateful method that returns the number of potential extra
987     empty lines needed before and after the currently processed line.
988
989     Note: this tracker works on lines that haven't been split yet.  It assumes
990     the prefix of the first leaf consists of optional newlines.  Those newlines
991     are consumed by `maybe_empty_lines()` and included in the computation.
992     """
993     previous_line: Optional[Line] = None
994     previous_after: int = 0
995     previous_defs: List[int] = Factory(list)
996
997     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
998         """Return the number of extra empty lines before and after the `current_line`.
999
1000         This is for separating `def`, `async def` and `class` with extra empty
1001         lines (two on module-level), as well as providing an extra empty line
1002         after flow control keywords to make them more prominent.
1003         """
1004         if isinstance(current_line, UnformattedLines):
1005             return 0, 0
1006
1007         before, after = self._maybe_empty_lines(current_line)
1008         before -= self.previous_after
1009         self.previous_after = after
1010         self.previous_line = current_line
1011         return before, after
1012
1013     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1014         max_allowed = 1
1015         if current_line.depth == 0:
1016             max_allowed = 2
1017         if current_line.leaves:
1018             # Consume the first leaf's extra newlines.
1019             first_leaf = current_line.leaves[0]
1020             before = first_leaf.prefix.count("\n")
1021             before = min(before, max_allowed)
1022             first_leaf.prefix = ""
1023         else:
1024             before = 0
1025         depth = current_line.depth
1026         while self.previous_defs and self.previous_defs[-1] >= depth:
1027             self.previous_defs.pop()
1028             before = 1 if depth else 2
1029         is_decorator = current_line.is_decorator
1030         if is_decorator or current_line.is_def or current_line.is_class:
1031             if not is_decorator:
1032                 self.previous_defs.append(depth)
1033             if self.previous_line is None:
1034                 # Don't insert empty lines before the first line in the file.
1035                 return 0, 0
1036
1037             if self.previous_line and self.previous_line.is_decorator:
1038                 # Don't insert empty lines between decorators.
1039                 return 0, 0
1040
1041             newlines = 2
1042             if current_line.depth:
1043                 newlines -= 1
1044             return newlines, 0
1045
1046         if current_line.is_flow_control:
1047             return before, 1
1048
1049         if (
1050             self.previous_line
1051             and self.previous_line.is_import
1052             and not current_line.is_import
1053             and depth == self.previous_line.depth
1054         ):
1055             return (before or 1), 0
1056
1057         if (
1058             self.previous_line
1059             and self.previous_line.is_yield
1060             and (not current_line.is_yield or depth != self.previous_line.depth)
1061         ):
1062             return (before or 1), 0
1063
1064         return before, 0
1065
1066
1067 @dataclass
1068 class LineGenerator(Visitor[Line]):
1069     """Generates reformatted Line objects.  Empty lines are not emitted.
1070
1071     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1072     in ways that will no longer stringify to valid Python code on the tree.
1073     """
1074     current_line: Line = Factory(Line)
1075
1076     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1077         """Generate a line.
1078
1079         If the line is empty, only emit if it makes sense.
1080         If the line is too long, split it first and then generate.
1081
1082         If any lines were generated, set up a new current_line.
1083         """
1084         if not self.current_line:
1085             if self.current_line.__class__ == type:
1086                 self.current_line.depth += indent
1087             else:
1088                 self.current_line = type(depth=self.current_line.depth + indent)
1089             return  # Line is empty, don't emit. Creating a new one unnecessary.
1090
1091         complete_line = self.current_line
1092         self.current_line = type(depth=complete_line.depth + indent)
1093         yield complete_line
1094
1095     def visit(self, node: LN) -> Iterator[Line]:
1096         """Main method to visit `node` and its children.
1097
1098         Yields :class:`Line` objects.
1099         """
1100         if isinstance(self.current_line, UnformattedLines):
1101             # File contained `# fmt: off`
1102             yield from self.visit_unformatted(node)
1103
1104         else:
1105             yield from super().visit(node)
1106
1107     def visit_default(self, node: LN) -> Iterator[Line]:
1108         """Default `visit_*()` implementation. Recurses to children of `node`."""
1109         if isinstance(node, Leaf):
1110             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1111             try:
1112                 for comment in generate_comments(node):
1113                     if any_open_brackets:
1114                         # any comment within brackets is subject to splitting
1115                         self.current_line.append(comment)
1116                     elif comment.type == token.COMMENT:
1117                         # regular trailing comment
1118                         self.current_line.append(comment)
1119                         yield from self.line()
1120
1121                     else:
1122                         # regular standalone comment
1123                         yield from self.line()
1124
1125                         self.current_line.append(comment)
1126                         yield from self.line()
1127
1128             except FormatOff as f_off:
1129                 f_off.trim_prefix(node)
1130                 yield from self.line(type=UnformattedLines)
1131                 yield from self.visit(node)
1132
1133             except FormatOn as f_on:
1134                 # This only happens here if somebody says "fmt: on" multiple
1135                 # times in a row.
1136                 f_on.trim_prefix(node)
1137                 yield from self.visit_default(node)
1138
1139             else:
1140                 normalize_prefix(node, inside_brackets=any_open_brackets)
1141                 if node.type == token.STRING:
1142                     normalize_string_quotes(node)
1143                 if node.type not in WHITESPACE:
1144                     self.current_line.append(node)
1145         yield from super().visit_default(node)
1146
1147     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1148         """Increase indentation level, maybe yield a line."""
1149         # In blib2to3 INDENT never holds comments.
1150         yield from self.line(+1)
1151         yield from self.visit_default(node)
1152
1153     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1154         """Decrease indentation level, maybe yield a line."""
1155         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1156         yield from self.line(-1)
1157
1158     def visit_stmt(
1159         self, node: Node, keywords: Set[str], parens: Set[str]
1160     ) -> Iterator[Line]:
1161         """Visit a statement.
1162
1163         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1164         `def`, `with`, `class`, and `assert`.
1165
1166         The relevant Python language `keywords` for a given statement will be
1167         NAME leaves within it. This methods puts those on a separate line.
1168
1169         `parens` holds pairs of nodes where invisible parentheses should be put.
1170         Keys hold nodes after which opening parentheses should be put, values
1171         hold nodes before which closing parentheses should be put.
1172         """
1173         normalize_invisible_parens(node, parens_after=parens)
1174         for child in node.children:
1175             if child.type == token.NAME and child.value in keywords:  # type: ignore
1176                 yield from self.line()
1177
1178             yield from self.visit(child)
1179
1180     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1181         """Visit a statement without nested statements."""
1182         is_suite_like = node.parent and node.parent.type in STATEMENT
1183         if is_suite_like:
1184             yield from self.line(+1)
1185             yield from self.visit_default(node)
1186             yield from self.line(-1)
1187
1188         else:
1189             yield from self.line()
1190             yield from self.visit_default(node)
1191
1192     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1193         """Visit `async def`, `async for`, `async with`."""
1194         yield from self.line()
1195
1196         children = iter(node.children)
1197         for child in children:
1198             yield from self.visit(child)
1199
1200             if child.type == token.ASYNC:
1201                 break
1202
1203         internal_stmt = next(children)
1204         for child in internal_stmt.children:
1205             yield from self.visit(child)
1206
1207     def visit_decorators(self, node: Node) -> Iterator[Line]:
1208         """Visit decorators."""
1209         for child in node.children:
1210             yield from self.line()
1211             yield from self.visit(child)
1212
1213     def visit_import_from(self, node: Node) -> Iterator[Line]:
1214         """Visit import_from and maybe put invisible parentheses.
1215
1216         This is separate from `visit_stmt` because import statements don't
1217         support arbitrary atoms and thus handling of parentheses is custom.
1218         """
1219         check_lpar = False
1220         for index, child in enumerate(node.children):
1221             if check_lpar:
1222                 if child.type == token.LPAR:
1223                     # make parentheses invisible
1224                     child.value = ""  # type: ignore
1225                     node.children[-1].value = ""  # type: ignore
1226                 else:
1227                     # insert invisible parentheses
1228                     node.insert_child(index, Leaf(token.LPAR, ""))
1229                     node.append_child(Leaf(token.RPAR, ""))
1230                 break
1231
1232             check_lpar = (
1233                 child.type == token.NAME and child.value == "import"  # type: ignore
1234             )
1235
1236         for child in node.children:
1237             yield from self.visit(child)
1238
1239     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1240         """Remove a semicolon and put the other statement on a separate line."""
1241         yield from self.line()
1242
1243     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1244         """End of file. Process outstanding comments and end with a newline."""
1245         yield from self.visit_default(leaf)
1246         yield from self.line()
1247
1248     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1249         """Used when file contained a `# fmt: off`."""
1250         if isinstance(node, Node):
1251             for child in node.children:
1252                 yield from self.visit(child)
1253
1254         else:
1255             try:
1256                 self.current_line.append(node)
1257             except FormatOn as f_on:
1258                 f_on.trim_prefix(node)
1259                 yield from self.line()
1260                 yield from self.visit(node)
1261
1262             if node.type == token.ENDMARKER:
1263                 # somebody decided not to put a final `# fmt: on`
1264                 yield from self.line()
1265
1266     def __attrs_post_init__(self) -> None:
1267         """You are in a twisty little maze of passages."""
1268         v = self.visit_stmt
1269         Ø: Set[str] = set()
1270         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1271         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1272         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1273         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1274         self.visit_try_stmt = partial(
1275             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1276         )
1277         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1278         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1279         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1280         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1281         self.visit_async_funcdef = self.visit_async_stmt
1282         self.visit_decorated = self.visit_decorators
1283
1284
1285 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1286 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1287 OPENING_BRACKETS = set(BRACKET.keys())
1288 CLOSING_BRACKETS = set(BRACKET.values())
1289 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1290 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1291
1292
1293 def whitespace(leaf: Leaf) -> str:  # noqa C901
1294     """Return whitespace prefix if needed for the given `leaf`."""
1295     NO = ""
1296     SPACE = " "
1297     DOUBLESPACE = "  "
1298     t = leaf.type
1299     p = leaf.parent
1300     v = leaf.value
1301     if t in ALWAYS_NO_SPACE:
1302         return NO
1303
1304     if t == token.COMMENT:
1305         return DOUBLESPACE
1306
1307     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1308     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1309         return NO
1310
1311     prev = leaf.prev_sibling
1312     if not prev:
1313         prevp = preceding_leaf(p)
1314         if not prevp or prevp.type in OPENING_BRACKETS:
1315             return NO
1316
1317         if t == token.COLON:
1318             return SPACE if prevp.type == token.COMMA else NO
1319
1320         if prevp.type == token.EQUAL:
1321             if prevp.parent:
1322                 if prevp.parent.type in {
1323                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1324                 }:
1325                     return NO
1326
1327                 elif prevp.parent.type == syms.typedargslist:
1328                     # A bit hacky: if the equal sign has whitespace, it means we
1329                     # previously found it's a typed argument.  So, we're using
1330                     # that, too.
1331                     return prevp.prefix
1332
1333         elif prevp.type in STARS:
1334             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1335                 return NO
1336
1337         elif prevp.type == token.COLON:
1338             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1339                 return NO
1340
1341         elif (
1342             prevp.parent
1343             and prevp.parent.type == syms.factor
1344             and prevp.type in MATH_OPERATORS
1345         ):
1346             return NO
1347
1348         elif (
1349             prevp.type == token.RIGHTSHIFT
1350             and prevp.parent
1351             and prevp.parent.type == syms.shift_expr
1352             and prevp.prev_sibling
1353             and prevp.prev_sibling.type == token.NAME
1354             and prevp.prev_sibling.value == "print"  # type: ignore
1355         ):
1356             # Python 2 print chevron
1357             return NO
1358
1359     elif prev.type in OPENING_BRACKETS:
1360         return NO
1361
1362     if p.type in {syms.parameters, syms.arglist}:
1363         # untyped function signatures or calls
1364         if t == token.RPAR:
1365             return NO
1366
1367         if not prev or prev.type != token.COMMA:
1368             return NO
1369
1370     elif p.type == syms.varargslist:
1371         # lambdas
1372         if t == token.RPAR:
1373             return NO
1374
1375         if prev and prev.type != token.COMMA:
1376             return NO
1377
1378     elif p.type == syms.typedargslist:
1379         # typed function signatures
1380         if not prev:
1381             return NO
1382
1383         if t == token.EQUAL:
1384             if prev.type != syms.tname:
1385                 return NO
1386
1387         elif prev.type == token.EQUAL:
1388             # A bit hacky: if the equal sign has whitespace, it means we
1389             # previously found it's a typed argument.  So, we're using that, too.
1390             return prev.prefix
1391
1392         elif prev.type != token.COMMA:
1393             return NO
1394
1395     elif p.type == syms.tname:
1396         # type names
1397         if not prev:
1398             prevp = preceding_leaf(p)
1399             if not prevp or prevp.type != token.COMMA:
1400                 return NO
1401
1402     elif p.type == syms.trailer:
1403         # attributes and calls
1404         if t == token.LPAR or t == token.RPAR:
1405             return NO
1406
1407         if not prev:
1408             if t == token.DOT:
1409                 prevp = preceding_leaf(p)
1410                 if not prevp or prevp.type != token.NUMBER:
1411                     return NO
1412
1413             elif t == token.LSQB:
1414                 return NO
1415
1416         elif prev.type != token.COMMA:
1417             return NO
1418
1419     elif p.type == syms.argument:
1420         # single argument
1421         if t == token.EQUAL:
1422             return NO
1423
1424         if not prev:
1425             prevp = preceding_leaf(p)
1426             if not prevp or prevp.type == token.LPAR:
1427                 return NO
1428
1429         elif prev.type in {token.EQUAL} | STARS:
1430             return NO
1431
1432     elif p.type == syms.decorator:
1433         # decorators
1434         return NO
1435
1436     elif p.type == syms.dotted_name:
1437         if prev:
1438             return NO
1439
1440         prevp = preceding_leaf(p)
1441         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1442             return NO
1443
1444     elif p.type == syms.classdef:
1445         if t == token.LPAR:
1446             return NO
1447
1448         if prev and prev.type == token.LPAR:
1449             return NO
1450
1451     elif p.type == syms.subscript:
1452         # indexing
1453         if not prev:
1454             assert p.parent is not None, "subscripts are always parented"
1455             if p.parent.type == syms.subscriptlist:
1456                 return SPACE
1457
1458             return NO
1459
1460         else:
1461             return NO
1462
1463     elif p.type == syms.atom:
1464         if prev and t == token.DOT:
1465             # dots, but not the first one.
1466             return NO
1467
1468     elif (
1469         p.type == syms.listmaker
1470         or p.type == syms.testlist_gexp
1471         or p.type == syms.subscriptlist
1472     ):
1473         # list interior, including unpacking
1474         if not prev:
1475             return NO
1476
1477     elif p.type == syms.dictsetmaker:
1478         # dict and set interior, including unpacking
1479         if not prev:
1480             return NO
1481
1482         if prev.type == token.DOUBLESTAR:
1483             return NO
1484
1485     elif p.type in {syms.factor, syms.star_expr}:
1486         # unary ops
1487         if not prev:
1488             prevp = preceding_leaf(p)
1489             if not prevp or prevp.type in OPENING_BRACKETS:
1490                 return NO
1491
1492             prevp_parent = prevp.parent
1493             assert prevp_parent is not None
1494             if (
1495                 prevp.type == token.COLON
1496                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1497             ):
1498                 return NO
1499
1500             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1501                 return NO
1502
1503         elif t == token.NAME or t == token.NUMBER:
1504             return NO
1505
1506     elif p.type == syms.import_from:
1507         if t == token.DOT:
1508             if prev and prev.type == token.DOT:
1509                 return NO
1510
1511         elif t == token.NAME:
1512             if v == "import":
1513                 return SPACE
1514
1515             if prev and prev.type == token.DOT:
1516                 return NO
1517
1518     elif p.type == syms.sliceop:
1519         return NO
1520
1521     return SPACE
1522
1523
1524 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1525     """Return the first leaf that precedes `node`, if any."""
1526     while node:
1527         res = node.prev_sibling
1528         if res:
1529             if isinstance(res, Leaf):
1530                 return res
1531
1532             try:
1533                 return list(res.leaves())[-1]
1534
1535             except IndexError:
1536                 return None
1537
1538         node = node.parent
1539     return None
1540
1541
1542 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1543     """Return the priority of the `leaf` delimiter, given a line break after it.
1544
1545     The delimiter priorities returned here are from those delimiters that would
1546     cause a line break after themselves.
1547
1548     Higher numbers are higher priority.
1549     """
1550     if leaf.type == token.COMMA:
1551         return COMMA_PRIORITY
1552
1553     return 0
1554
1555
1556 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1557     """Return the priority of the `leaf` delimiter, given a line before after it.
1558
1559     The delimiter priorities returned here are from those delimiters that would
1560     cause a line break before themselves.
1561
1562     Higher numbers are higher priority.
1563     """
1564     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1565         # * and ** might also be MATH_OPERATORS but in this case they are not.
1566         # Don't treat them as a delimiter.
1567         return 0
1568
1569     if (
1570         leaf.type in MATH_OPERATORS
1571         and leaf.parent
1572         and leaf.parent.type not in {syms.factor, syms.star_expr}
1573     ):
1574         return MATH_PRIORITY
1575
1576     if leaf.type in COMPARATORS:
1577         return COMPARATOR_PRIORITY
1578
1579     if (
1580         leaf.type == token.STRING
1581         and previous is not None
1582         and previous.type == token.STRING
1583     ):
1584         return STRING_PRIORITY
1585
1586     if (
1587         leaf.type == token.NAME
1588         and leaf.value == "for"
1589         and leaf.parent
1590         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1591     ):
1592         return COMPREHENSION_PRIORITY
1593
1594     if (
1595         leaf.type == token.NAME
1596         and leaf.value == "if"
1597         and leaf.parent
1598         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1599     ):
1600         return COMPREHENSION_PRIORITY
1601
1602     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1603         return LOGIC_PRIORITY
1604
1605     return 0
1606
1607
1608 def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1609     """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
1610
1611     Higher numbers are higher priority.
1612     """
1613     return max(
1614         is_split_before_delimiter(leaf, previous),
1615         is_split_after_delimiter(leaf, previous),
1616     )
1617
1618
1619 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1620     """Clean the prefix of the `leaf` and generate comments from it, if any.
1621
1622     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1623     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1624     move because it does away with modifying the grammar to include all the
1625     possible places in which comments can be placed.
1626
1627     The sad consequence for us though is that comments don't "belong" anywhere.
1628     This is why this function generates simple parentless Leaf objects for
1629     comments.  We simply don't know what the correct parent should be.
1630
1631     No matter though, we can live without this.  We really only need to
1632     differentiate between inline and standalone comments.  The latter don't
1633     share the line with any code.
1634
1635     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1636     are emitted with a fake STANDALONE_COMMENT token identifier.
1637     """
1638     p = leaf.prefix
1639     if not p:
1640         return
1641
1642     if "#" not in p:
1643         return
1644
1645     consumed = 0
1646     nlines = 0
1647     for index, line in enumerate(p.split("\n")):
1648         consumed += len(line) + 1  # adding the length of the split '\n'
1649         line = line.lstrip()
1650         if not line:
1651             nlines += 1
1652         if not line.startswith("#"):
1653             continue
1654
1655         if index == 0 and leaf.type != token.ENDMARKER:
1656             comment_type = token.COMMENT  # simple trailing comment
1657         else:
1658             comment_type = STANDALONE_COMMENT
1659         comment = make_comment(line)
1660         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1661
1662         if comment in {"# fmt: on", "# yapf: enable"}:
1663             raise FormatOn(consumed)
1664
1665         if comment in {"# fmt: off", "# yapf: disable"}:
1666             if comment_type == STANDALONE_COMMENT:
1667                 raise FormatOff(consumed)
1668
1669             prev = preceding_leaf(leaf)
1670             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1671                 raise FormatOff(consumed)
1672
1673         nlines = 0
1674
1675
1676 def make_comment(content: str) -> str:
1677     """Return a consistently formatted comment from the given `content` string.
1678
1679     All comments (except for "##", "#!", "#:") should have a single space between
1680     the hash sign and the content.
1681
1682     If `content` didn't start with a hash sign, one is provided.
1683     """
1684     content = content.rstrip()
1685     if not content:
1686         return "#"
1687
1688     if content[0] == "#":
1689         content = content[1:]
1690     if content and content[0] not in " !:#":
1691         content = " " + content
1692     return "#" + content
1693
1694
1695 def split_line(
1696     line: Line, line_length: int, inner: bool = False, py36: bool = False
1697 ) -> Iterator[Line]:
1698     """Split a `line` into potentially many lines.
1699
1700     They should fit in the allotted `line_length` but might not be able to.
1701     `inner` signifies that there were a pair of brackets somewhere around the
1702     current `line`, possibly transitively. This means we can fallback to splitting
1703     by delimiters if the LHS/RHS don't yield any results.
1704
1705     If `py36` is True, splitting may generate syntax that is only compatible
1706     with Python 3.6 and later.
1707     """
1708     if isinstance(line, UnformattedLines) or line.is_comment:
1709         yield line
1710         return
1711
1712     line_str = str(line).strip("\n")
1713     if (
1714         len(line_str) <= line_length
1715         and "\n" not in line_str  # multiline strings
1716         and not line.contains_standalone_comments()
1717     ):
1718         yield line
1719         return
1720
1721     split_funcs: List[SplitFunc]
1722     if line.is_def:
1723         split_funcs = [left_hand_split]
1724     elif line.inside_brackets:
1725         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1726     else:
1727         split_funcs = [right_hand_split]
1728     for split_func in split_funcs:
1729         # We are accumulating lines in `result` because we might want to abort
1730         # mission and return the original line in the end, or attempt a different
1731         # split altogether.
1732         result: List[Line] = []
1733         try:
1734             for l in split_func(line, py36):
1735                 if str(l).strip("\n") == line_str:
1736                     raise CannotSplit("Split function returned an unchanged result")
1737
1738                 result.extend(
1739                     split_line(l, line_length=line_length, inner=True, py36=py36)
1740                 )
1741         except CannotSplit as cs:
1742             continue
1743
1744         else:
1745             yield from result
1746             break
1747
1748     else:
1749         yield line
1750
1751
1752 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1753     """Split line into many lines, starting with the first matching bracket pair.
1754
1755     Note: this usually looks weird, only use this for function definitions.
1756     Prefer RHS otherwise.
1757     """
1758     head = Line(depth=line.depth)
1759     body = Line(depth=line.depth + 1, inside_brackets=True)
1760     tail = Line(depth=line.depth)
1761     tail_leaves: List[Leaf] = []
1762     body_leaves: List[Leaf] = []
1763     head_leaves: List[Leaf] = []
1764     current_leaves = head_leaves
1765     matching_bracket = None
1766     for leaf in line.leaves:
1767         if (
1768             current_leaves is body_leaves
1769             and leaf.type in CLOSING_BRACKETS
1770             and leaf.opening_bracket is matching_bracket
1771         ):
1772             current_leaves = tail_leaves if body_leaves else head_leaves
1773         current_leaves.append(leaf)
1774         if current_leaves is head_leaves:
1775             if leaf.type in OPENING_BRACKETS:
1776                 matching_bracket = leaf
1777                 current_leaves = body_leaves
1778     # Since body is a new indent level, remove spurious leading whitespace.
1779     if body_leaves:
1780         normalize_prefix(body_leaves[0], inside_brackets=True)
1781     # Build the new lines.
1782     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1783         for leaf in leaves:
1784             result.append(leaf, preformatted=True)
1785             for comment_after in line.comments_after(leaf):
1786                 result.append(comment_after, preformatted=True)
1787     bracket_split_succeeded_or_raise(head, body, tail)
1788     for result in (head, body, tail):
1789         if result:
1790             yield result
1791
1792
1793 def right_hand_split(
1794     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1795 ) -> Iterator[Line]:
1796     """Split line into many lines, starting with the last matching bracket pair."""
1797     head = Line(depth=line.depth)
1798     body = Line(depth=line.depth + 1, inside_brackets=True)
1799     tail = Line(depth=line.depth)
1800     tail_leaves: List[Leaf] = []
1801     body_leaves: List[Leaf] = []
1802     head_leaves: List[Leaf] = []
1803     current_leaves = tail_leaves
1804     opening_bracket = None
1805     closing_bracket = None
1806     for leaf in reversed(line.leaves):
1807         if current_leaves is body_leaves:
1808             if leaf is opening_bracket:
1809                 current_leaves = head_leaves if body_leaves else tail_leaves
1810         current_leaves.append(leaf)
1811         if current_leaves is tail_leaves:
1812             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1813                 opening_bracket = leaf.opening_bracket
1814                 closing_bracket = leaf
1815                 current_leaves = body_leaves
1816     tail_leaves.reverse()
1817     body_leaves.reverse()
1818     head_leaves.reverse()
1819     # Since body is a new indent level, remove spurious leading whitespace.
1820     if body_leaves:
1821         normalize_prefix(body_leaves[0], inside_brackets=True)
1822     elif not head_leaves:
1823         # No `head` and no `body` means the split failed. `tail` has all content.
1824         raise CannotSplit("No brackets found")
1825
1826     # Build the new lines.
1827     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1828         for leaf in leaves:
1829             result.append(leaf, preformatted=True)
1830             for comment_after in line.comments_after(leaf):
1831                 result.append(comment_after, preformatted=True)
1832     bracket_split_succeeded_or_raise(head, body, tail)
1833     assert opening_bracket and closing_bracket
1834     if (
1835         opening_bracket.type == token.LPAR
1836         and not opening_bracket.value
1837         and closing_bracket.type == token.RPAR
1838         and not closing_bracket.value
1839     ):
1840         # These parens were optional. If there aren't any delimiters or standalone
1841         # comments in the body, they were unnecessary and another split without
1842         # them should be attempted.
1843         if not (
1844             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1845         ):
1846             omit = {id(closing_bracket), *omit}
1847             yield from right_hand_split(line, py36=py36, omit=omit)
1848             return
1849
1850     ensure_visible(opening_bracket)
1851     ensure_visible(closing_bracket)
1852     for result in (head, body, tail):
1853         if result:
1854             yield result
1855
1856
1857 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1858     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1859
1860     Do nothing otherwise.
1861
1862     A left- or right-hand split is based on a pair of brackets. Content before
1863     (and including) the opening bracket is left on one line, content inside the
1864     brackets is put on a separate line, and finally content starting with and
1865     following the closing bracket is put on a separate line.
1866
1867     Those are called `head`, `body`, and `tail`, respectively. If the split
1868     produced the same line (all content in `head`) or ended up with an empty `body`
1869     and the `tail` is just the closing bracket, then it's considered failed.
1870     """
1871     tail_len = len(str(tail).strip())
1872     if not body:
1873         if tail_len == 0:
1874             raise CannotSplit("Splitting brackets produced the same line")
1875
1876         elif tail_len < 3:
1877             raise CannotSplit(
1878                 f"Splitting brackets on an empty body to save "
1879                 f"{tail_len} characters is not worth it"
1880             )
1881
1882
1883 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1884     """Normalize prefix of the first leaf in every line returned by `split_func`.
1885
1886     This is a decorator over relevant split functions.
1887     """
1888
1889     @wraps(split_func)
1890     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1891         for l in split_func(line, py36):
1892             normalize_prefix(l.leaves[0], inside_brackets=True)
1893             yield l
1894
1895     return split_wrapper
1896
1897
1898 @dont_increase_indentation
1899 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1900     """Split according to delimiters of the highest priority.
1901
1902     If `py36` is True, the split will add trailing commas also in function
1903     signatures that contain `*` and `**`.
1904     """
1905     try:
1906         last_leaf = line.leaves[-1]
1907     except IndexError:
1908         raise CannotSplit("Line empty")
1909
1910     delimiters = line.bracket_tracker.delimiters
1911     try:
1912         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1913             exclude={id(last_leaf)}
1914         )
1915     except ValueError:
1916         raise CannotSplit("No delimiters found")
1917
1918     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1919     lowest_depth = sys.maxsize
1920     trailing_comma_safe = True
1921
1922     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1923         """Append `leaf` to current line or to new line if appending impossible."""
1924         nonlocal current_line
1925         try:
1926             current_line.append_safe(leaf, preformatted=True)
1927         except ValueError as ve:
1928             yield current_line
1929
1930             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1931             current_line.append(leaf)
1932
1933     for leaf in line.leaves:
1934         yield from append_to_line(leaf)
1935
1936         for comment_after in line.comments_after(leaf):
1937             yield from append_to_line(comment_after)
1938
1939         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1940         if (
1941             leaf.bracket_depth == lowest_depth
1942             and is_vararg(leaf, within=VARARGS_PARENTS)
1943         ):
1944             trailing_comma_safe = trailing_comma_safe and py36
1945         leaf_priority = delimiters.get(id(leaf))
1946         if leaf_priority == delimiter_priority:
1947             yield current_line
1948
1949             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1950     if current_line:
1951         if (
1952             trailing_comma_safe
1953             and delimiter_priority == COMMA_PRIORITY
1954             and current_line.leaves[-1].type != token.COMMA
1955             and current_line.leaves[-1].type != STANDALONE_COMMENT
1956         ):
1957             current_line.append(Leaf(token.COMMA, ","))
1958         yield current_line
1959
1960
1961 @dont_increase_indentation
1962 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1963     """Split standalone comments from the rest of the line."""
1964     if not line.contains_standalone_comments(0):
1965         raise CannotSplit("Line does not have any standalone comments")
1966
1967     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1968
1969     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1970         """Append `leaf` to current line or to new line if appending impossible."""
1971         nonlocal current_line
1972         try:
1973             current_line.append_safe(leaf, preformatted=True)
1974         except ValueError as ve:
1975             yield current_line
1976
1977             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1978             current_line.append(leaf)
1979
1980     for leaf in line.leaves:
1981         yield from append_to_line(leaf)
1982
1983         for comment_after in line.comments_after(leaf):
1984             yield from append_to_line(comment_after)
1985
1986     if current_line:
1987         yield current_line
1988
1989
1990 def is_import(leaf: Leaf) -> bool:
1991     """Return True if the given leaf starts an import statement."""
1992     p = leaf.parent
1993     t = leaf.type
1994     v = leaf.value
1995     return bool(
1996         t == token.NAME
1997         and (
1998             (v == "import" and p and p.type == syms.import_name)
1999             or (v == "from" and p and p.type == syms.import_from)
2000         )
2001     )
2002
2003
2004 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2005     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2006     else.
2007
2008     Note: don't use backslashes for formatting or you'll lose your voting rights.
2009     """
2010     if not inside_brackets:
2011         spl = leaf.prefix.split("#")
2012         if "\\" not in spl[0]:
2013             nl_count = spl[-1].count("\n")
2014             if len(spl) > 1:
2015                 nl_count -= 1
2016             leaf.prefix = "\n" * nl_count
2017             return
2018
2019     leaf.prefix = ""
2020
2021
2022 def normalize_string_quotes(leaf: Leaf) -> None:
2023     """Prefer double quotes but only if it doesn't cause more escaping.
2024
2025     Adds or removes backslashes as appropriate. Doesn't parse and fix
2026     strings nested in f-strings (yet).
2027
2028     Note: Mutates its argument.
2029     """
2030     value = leaf.value.lstrip("furbFURB")
2031     if value[:3] == '"""':
2032         return
2033
2034     elif value[:3] == "'''":
2035         orig_quote = "'''"
2036         new_quote = '"""'
2037     elif value[0] == '"':
2038         orig_quote = '"'
2039         new_quote = "'"
2040     else:
2041         orig_quote = "'"
2042         new_quote = '"'
2043     first_quote_pos = leaf.value.find(orig_quote)
2044     if first_quote_pos == -1:
2045         return  # There's an internal error
2046
2047     prefix = leaf.value[:first_quote_pos]
2048     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2049     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2050     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2051     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
2052     if "r" in prefix.casefold():
2053         if unescaped_new_quote.search(body):
2054             # There's at least one unescaped new_quote in this raw string
2055             # so converting is impossible
2056             return
2057
2058         # Do not introduce or remove backslashes in raw strings
2059         new_body = body
2060     else:
2061         # remove unnecessary quotes
2062         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2063         if body != new_body:
2064             # Consider the string without unnecessary quotes as the original
2065             body = new_body
2066             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2067         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2068         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2069     if new_quote == '"""' and new_body[-1] == '"':
2070         # edge case:
2071         new_body = new_body[:-1] + '\\"'
2072     orig_escape_count = body.count("\\")
2073     new_escape_count = new_body.count("\\")
2074     if new_escape_count > orig_escape_count:
2075         return  # Do not introduce more escaping
2076
2077     if new_escape_count == orig_escape_count and orig_quote == '"':
2078         return  # Prefer double quotes
2079
2080     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2081
2082
2083 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2084     """Make existing optional parentheses invisible or create new ones.
2085
2086     Standardizes on visible parentheses for single-element tuples, and keeps
2087     existing visible parentheses for other tuples and generator expressions.
2088     """
2089     check_lpar = False
2090     for child in list(node.children):
2091         if check_lpar:
2092             if child.type == syms.atom:
2093                 if not (
2094                     is_empty_tuple(child)
2095                     or is_one_tuple(child)
2096                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2097                 ):
2098                     first = child.children[0]
2099                     last = child.children[-1]
2100                     if first.type == token.LPAR and last.type == token.RPAR:
2101                         # make parentheses invisible
2102                         first.value = ""  # type: ignore
2103                         last.value = ""  # type: ignore
2104             elif is_one_tuple(child):
2105                 # wrap child in visible parentheses
2106                 lpar = Leaf(token.LPAR, "(")
2107                 rpar = Leaf(token.RPAR, ")")
2108                 index = child.remove() or 0
2109                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2110             else:
2111                 # wrap child in invisible parentheses
2112                 lpar = Leaf(token.LPAR, "")
2113                 rpar = Leaf(token.RPAR, "")
2114                 index = child.remove() or 0
2115                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2116
2117         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2118
2119
2120 def is_empty_tuple(node: LN) -> bool:
2121     """Return True if `node` holds an empty tuple."""
2122     return (
2123         node.type == syms.atom
2124         and len(node.children) == 2
2125         and node.children[0].type == token.LPAR
2126         and node.children[1].type == token.RPAR
2127     )
2128
2129
2130 def is_one_tuple(node: LN) -> bool:
2131     """Return True if `node` holds a tuple with one element, with or without parens."""
2132     if node.type == syms.atom:
2133         if len(node.children) != 3:
2134             return False
2135
2136         lpar, gexp, rpar = node.children
2137         if not (
2138             lpar.type == token.LPAR
2139             and gexp.type == syms.testlist_gexp
2140             and rpar.type == token.RPAR
2141         ):
2142             return False
2143
2144         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2145
2146     return (
2147         node.type in IMPLICIT_TUPLE
2148         and len(node.children) == 2
2149         and node.children[1].type == token.COMMA
2150     )
2151
2152
2153 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2154     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2155
2156     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2157     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2158     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2159     generalizations (PEP 448).
2160     """
2161     if leaf.type not in STARS or not leaf.parent:
2162         return False
2163
2164     p = leaf.parent
2165     if p.type == syms.star_expr:
2166         # Star expressions are also used as assignment targets in extended
2167         # iterable unpacking (PEP 3132).  See what its parent is instead.
2168         if not p.parent:
2169             return False
2170
2171         p = p.parent
2172
2173     return p.type in within
2174
2175
2176 def max_delimiter_priority_in_atom(node: LN) -> int:
2177     if node.type != syms.atom:
2178         return 0
2179
2180     first = node.children[0]
2181     last = node.children[-1]
2182     if not (first.type == token.LPAR and last.type == token.RPAR):
2183         return 0
2184
2185     bt = BracketTracker()
2186     for c in node.children[1:-1]:
2187         if isinstance(c, Leaf):
2188             bt.mark(c)
2189         else:
2190             for leaf in c.leaves():
2191                 bt.mark(leaf)
2192     try:
2193         return bt.max_delimiter_priority()
2194
2195     except ValueError:
2196         return 0
2197
2198
2199 def ensure_visible(leaf: Leaf) -> None:
2200     """Make sure parentheses are visible.
2201
2202     They could be invisible as part of some statements (see
2203     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2204     """
2205     if leaf.type == token.LPAR:
2206         leaf.value = "("
2207     elif leaf.type == token.RPAR:
2208         leaf.value = ")"
2209
2210
2211 def is_python36(node: Node) -> bool:
2212     """Return True if the current file is using Python 3.6+ features.
2213
2214     Currently looking for:
2215     - f-strings; and
2216     - trailing commas after * or ** in function signatures.
2217     """
2218     for n in node.pre_order():
2219         if n.type == token.STRING:
2220             value_head = n.value[:2]  # type: ignore
2221             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2222                 return True
2223
2224         elif (
2225             n.type == syms.typedargslist
2226             and n.children
2227             and n.children[-1].type == token.COMMA
2228         ):
2229             for ch in n.children:
2230                 if ch.type in STARS:
2231                     return True
2232
2233     return False
2234
2235
2236 PYTHON_EXTENSIONS = {".py"}
2237 BLACKLISTED_DIRECTORIES = {
2238     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2239 }
2240
2241
2242 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2243     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2244     and have one of the PYTHON_EXTENSIONS.
2245     """
2246     for child in path.iterdir():
2247         if child.is_dir():
2248             if child.name in BLACKLISTED_DIRECTORIES:
2249                 continue
2250
2251             yield from gen_python_files_in_dir(child)
2252
2253         elif child.suffix in PYTHON_EXTENSIONS:
2254             yield child
2255
2256
2257 @dataclass
2258 class Report:
2259     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2260     check: bool = False
2261     quiet: bool = False
2262     change_count: int = 0
2263     same_count: int = 0
2264     failure_count: int = 0
2265
2266     def done(self, src: Path, changed: Changed) -> None:
2267         """Increment the counter for successful reformatting. Write out a message."""
2268         if changed is Changed.YES:
2269             reformatted = "would reformat" if self.check else "reformatted"
2270             if not self.quiet:
2271                 out(f"{reformatted} {src}")
2272             self.change_count += 1
2273         else:
2274             if not self.quiet:
2275                 if changed is Changed.NO:
2276                     msg = f"{src} already well formatted, good job."
2277                 else:
2278                     msg = f"{src} wasn't modified on disk since last run."
2279                 out(msg, bold=False)
2280             self.same_count += 1
2281
2282     def failed(self, src: Path, message: str) -> None:
2283         """Increment the counter for failed reformatting. Write out a message."""
2284         err(f"error: cannot format {src}: {message}")
2285         self.failure_count += 1
2286
2287     @property
2288     def return_code(self) -> int:
2289         """Return the exit code that the app should use.
2290
2291         This considers the current state of changed files and failures:
2292         - if there were any failures, return 123;
2293         - if any files were changed and --check is being used, return 1;
2294         - otherwise return 0.
2295         """
2296         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2297         # 126 we have special returncodes reserved by the shell.
2298         if self.failure_count:
2299             return 123
2300
2301         elif self.change_count and self.check:
2302             return 1
2303
2304         return 0
2305
2306     def __str__(self) -> str:
2307         """Render a color report of the current state.
2308
2309         Use `click.unstyle` to remove colors.
2310         """
2311         if self.check:
2312             reformatted = "would be reformatted"
2313             unchanged = "would be left unchanged"
2314             failed = "would fail to reformat"
2315         else:
2316             reformatted = "reformatted"
2317             unchanged = "left unchanged"
2318             failed = "failed to reformat"
2319         report = []
2320         if self.change_count:
2321             s = "s" if self.change_count > 1 else ""
2322             report.append(
2323                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2324             )
2325         if self.same_count:
2326             s = "s" if self.same_count > 1 else ""
2327             report.append(f"{self.same_count} file{s} {unchanged}")
2328         if self.failure_count:
2329             s = "s" if self.failure_count > 1 else ""
2330             report.append(
2331                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2332             )
2333         return ", ".join(report) + "."
2334
2335
2336 def assert_equivalent(src: str, dst: str) -> None:
2337     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2338
2339     import ast
2340     import traceback
2341
2342     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2343         """Simple visitor generating strings to compare ASTs by content."""
2344         yield f"{'  ' * depth}{node.__class__.__name__}("
2345
2346         for field in sorted(node._fields):
2347             try:
2348                 value = getattr(node, field)
2349             except AttributeError:
2350                 continue
2351
2352             yield f"{'  ' * (depth+1)}{field}="
2353
2354             if isinstance(value, list):
2355                 for item in value:
2356                     if isinstance(item, ast.AST):
2357                         yield from _v(item, depth + 2)
2358
2359             elif isinstance(value, ast.AST):
2360                 yield from _v(value, depth + 2)
2361
2362             else:
2363                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2364
2365         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2366
2367     try:
2368         src_ast = ast.parse(src)
2369     except Exception as exc:
2370         major, minor = sys.version_info[:2]
2371         raise AssertionError(
2372             f"cannot use --safe with this file; failed to parse source file "
2373             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2374             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2375         )
2376
2377     try:
2378         dst_ast = ast.parse(dst)
2379     except Exception as exc:
2380         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2381         raise AssertionError(
2382             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2383             f"Please report a bug on https://github.com/ambv/black/issues.  "
2384             f"This invalid output might be helpful: {log}"
2385         ) from None
2386
2387     src_ast_str = "\n".join(_v(src_ast))
2388     dst_ast_str = "\n".join(_v(dst_ast))
2389     if src_ast_str != dst_ast_str:
2390         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2391         raise AssertionError(
2392             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2393             f"the source.  "
2394             f"Please report a bug on https://github.com/ambv/black/issues.  "
2395             f"This diff might be helpful: {log}"
2396         ) from None
2397
2398
2399 def assert_stable(src: str, dst: str, line_length: int) -> None:
2400     """Raise AssertionError if `dst` reformats differently the second time."""
2401     newdst = format_str(dst, line_length=line_length)
2402     if dst != newdst:
2403         log = dump_to_file(
2404             diff(src, dst, "source", "first pass"),
2405             diff(dst, newdst, "first pass", "second pass"),
2406         )
2407         raise AssertionError(
2408             f"INTERNAL ERROR: Black produced different code on the second pass "
2409             f"of the formatter.  "
2410             f"Please report a bug on https://github.com/ambv/black/issues.  "
2411             f"This diff might be helpful: {log}"
2412         ) from None
2413
2414
2415 def dump_to_file(*output: str) -> str:
2416     """Dump `output` to a temporary file. Return path to the file."""
2417     import tempfile
2418
2419     with tempfile.NamedTemporaryFile(
2420         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2421     ) as f:
2422         for lines in output:
2423             f.write(lines)
2424             if lines and lines[-1] != "\n":
2425                 f.write("\n")
2426     return f.name
2427
2428
2429 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2430     """Return a unified diff string between strings `a` and `b`."""
2431     import difflib
2432
2433     a_lines = [line + "\n" for line in a.split("\n")]
2434     b_lines = [line + "\n" for line in b.split("\n")]
2435     return "".join(
2436         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2437     )
2438
2439
2440 def cancel(tasks: List[asyncio.Task]) -> None:
2441     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2442     err("Aborted!")
2443     for task in tasks:
2444         task.cancel()
2445
2446
2447 def shutdown(loop: BaseEventLoop) -> None:
2448     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2449     try:
2450         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2451         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2452         if not to_cancel:
2453             return
2454
2455         for task in to_cancel:
2456             task.cancel()
2457         loop.run_until_complete(
2458             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2459         )
2460     finally:
2461         # `concurrent.futures.Future` objects cannot be cancelled once they
2462         # are already running. There might be some when the `shutdown()` happened.
2463         # Silence their logger's spew about the event loop being closed.
2464         cf_logger = logging.getLogger("concurrent.futures")
2465         cf_logger.setLevel(logging.CRITICAL)
2466         loop.close()
2467
2468
2469 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2470     """Replace `regex` with `replacement` twice on `original`.
2471
2472     This is used by string normalization to perform replaces on
2473     overlapping matches.
2474     """
2475     return regex.sub(replacement, regex.sub(replacement, original))
2476
2477
2478 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2479 CACHE_FILE = CACHE_DIR / "cache.pickle"
2480
2481
2482 def read_cache() -> Cache:
2483     """Read the cache if it exists and is well formed.
2484
2485     If it is not well formed, the call to write_cache later should resolve the issue.
2486     """
2487     if not CACHE_FILE.exists():
2488         return {}
2489
2490     with CACHE_FILE.open("rb") as fobj:
2491         try:
2492             cache: Cache = pickle.load(fobj)
2493         except pickle.UnpicklingError:
2494             return {}
2495
2496     return cache
2497
2498
2499 def get_cache_info(path: Path) -> CacheInfo:
2500     """Return the information used to check if a file is already formatted or not."""
2501     stat = path.stat()
2502     return stat.st_mtime, stat.st_size
2503
2504
2505 def filter_cached(
2506     cache: Cache, sources: Iterable[Path]
2507 ) -> Tuple[List[Path], List[Path]]:
2508     """Split a list of paths into two.
2509
2510     The first list contains paths of files that modified on disk or are not in the
2511     cache. The other list contains paths to non-modified files.
2512     """
2513     todo, done = [], []
2514     for src in sources:
2515         src = src.resolve()
2516         if cache.get(src) != get_cache_info(src):
2517             todo.append(src)
2518         else:
2519             done.append(src)
2520     return todo, done
2521
2522
2523 def write_cache(cache: Cache, sources: List[Path]) -> None:
2524     """Update the cache file."""
2525     try:
2526         if not CACHE_DIR.exists():
2527             CACHE_DIR.mkdir(parents=True)
2528         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2529         with CACHE_FILE.open("wb") as fobj:
2530             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2531     except OSError:
2532         pass
2533
2534
2535 if __name__ == "__main__":
2536     main()