]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Revert `format_file_in_place()` and `format_stdin_to_stdout()` to return bools
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import pickle
5 from asyncio.base_events import BaseEventLoop
6 from concurrent.futures import Executor, ProcessPoolExecutor
7 from enum import Enum
8 from functools import partial, wraps
9 import keyword
10 import logging
11 from multiprocessing import Manager
12 import os
13 from pathlib import Path
14 import re
15 import tokenize
16 import signal
17 import sys
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46 __version__ = "18.4a2"
47 DEFAULT_LINE_LENGTH = 88
48 # types
49 syms = pygram.python_symbols
50 FileContent = str
51 Encoding = str
52 Depth = int
53 NodeType = int
54 LeafID = int
55 Priority = int
56 Index = int
57 LN = Union[Leaf, Node]
58 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
59 Timestamp = float
60 FileSize = int
61 CacheInfo = Tuple[Timestamp, FileSize]
62 Cache = Dict[Path, CacheInfo]
63 out = partial(click.secho, bold=True, err=True)
64 err = partial(click.secho, fg="red", err=True)
65
66
67 class NothingChanged(UserWarning):
68     """Raised by :func:`format_file` when reformatted code is the same as source."""
69
70
71 class CannotSplit(Exception):
72     """A readable split that fits the allotted line length is impossible.
73
74     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
75     :func:`delimiter_split`.
76     """
77
78
79 class FormatError(Exception):
80     """Base exception for `# fmt: on` and `# fmt: off` handling.
81
82     It holds the number of bytes of the prefix consumed before the format
83     control comment appeared.
84     """
85
86     def __init__(self, consumed: int) -> None:
87         super().__init__(consumed)
88         self.consumed = consumed
89
90     def trim_prefix(self, leaf: Leaf) -> None:
91         leaf.prefix = leaf.prefix[self.consumed:]
92
93     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
94         """Returns a new Leaf from the consumed part of the prefix."""
95         unformatted_prefix = leaf.prefix[:self.consumed]
96         return Leaf(token.NEWLINE, unformatted_prefix)
97
98
99 class FormatOn(FormatError):
100     """Found a comment like `# fmt: on` in the file."""
101
102
103 class FormatOff(FormatError):
104     """Found a comment like `# fmt: off` in the file."""
105
106
107 class WriteBack(Enum):
108     NO = 0
109     YES = 1
110     DIFF = 2
111
112
113 class Changed(Enum):
114     NO = 0
115     CACHED = 1
116     YES = 2
117
118
119 @click.command()
120 @click.option(
121     "-l",
122     "--line-length",
123     type=int,
124     default=DEFAULT_LINE_LENGTH,
125     help="How many character per line to allow.",
126     show_default=True,
127 )
128 @click.option(
129     "--check",
130     is_flag=True,
131     help=(
132         "Don't write the files back, just return the status.  Return code 0 "
133         "means nothing would change.  Return code 1 means some files would be "
134         "reformatted.  Return code 123 means there was an internal error."
135     ),
136 )
137 @click.option(
138     "--diff",
139     is_flag=True,
140     help="Don't write the files back, just output a diff for each file on stdout.",
141 )
142 @click.option(
143     "--fast/--safe",
144     is_flag=True,
145     help="If --fast given, skip temporary sanity checks. [default: --safe]",
146 )
147 @click.option(
148     "-q",
149     "--quiet",
150     is_flag=True,
151     help=(
152         "Don't emit non-error messages to stderr. Errors are still emitted, "
153         "silence those with 2>/dev/null."
154     ),
155 )
156 @click.version_option(version=__version__)
157 @click.argument(
158     "src",
159     nargs=-1,
160     type=click.Path(
161         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
162     ),
163 )
164 @click.pass_context
165 def main(
166     ctx: click.Context,
167     line_length: int,
168     check: bool,
169     diff: bool,
170     fast: bool,
171     quiet: bool,
172     src: List[str],
173 ) -> None:
174     """The uncompromising code formatter."""
175     sources: List[Path] = []
176     for s in src:
177         p = Path(s)
178         if p.is_dir():
179             sources.extend(gen_python_files_in_dir(p))
180         elif p.is_file():
181             # if a file was explicitly given, we don't care about its extension
182             sources.append(p)
183         elif s == "-":
184             sources.append(Path("-"))
185         else:
186             err(f"invalid path: {s}")
187     if check and diff:
188         exc = click.ClickException("Options --check and --diff are mutually exclusive")
189         exc.exit_code = 2
190         raise exc
191
192     if check:
193         write_back = WriteBack.NO
194     elif diff:
195         write_back = WriteBack.DIFF
196     else:
197         write_back = WriteBack.YES
198     if len(sources) == 0:
199         ctx.exit(0)
200         return
201
202     elif len(sources) == 1:
203         return_code = run_single_file_mode(
204             line_length, check, fast, quiet, write_back, sources[0]
205         )
206     else:
207         return_code = run_multi_file_mode(line_length, fast, quiet, write_back, sources)
208     ctx.exit(return_code)
209
210
211 def run_single_file_mode(
212     line_length: int,
213     check: bool,
214     fast: bool,
215     quiet: bool,
216     write_back: WriteBack,
217     src: Path,
218 ) -> int:
219     report = Report(check=check, quiet=quiet)
220     try:
221         changed = Changed.NO
222         if not src.is_file() and str(src) == "-":
223             if format_stdin_to_stdout(
224                 line_length=line_length, fast=fast, write_back=write_back
225             ):
226                 changed = Changed.YES
227         else:
228             cache: Cache = {}
229             if write_back != WriteBack.DIFF:
230                 cache = read_cache()
231                 src = src.resolve()
232                 if src in cache and cache[src] == get_cache_info(src):
233                     changed = Changed.CACHED
234             if (
235                 changed is not Changed.CACHED
236                 and format_file_in_place(
237                     src, line_length=line_length, fast=fast, write_back=write_back
238                 )
239             ):
240                 changed = Changed.YES
241             if write_back != WriteBack.DIFF and changed is not Changed.NO:
242                 write_cache(cache, [src])
243         report.done(src, changed)
244     except Exception as exc:
245         report.failed(src, str(exc))
246     return report.return_code
247
248
249 def run_multi_file_mode(
250     line_length: int,
251     fast: bool,
252     quiet: bool,
253     write_back: WriteBack,
254     sources: List[Path],
255 ) -> int:
256     loop = asyncio.get_event_loop()
257     executor = ProcessPoolExecutor(max_workers=os.cpu_count())
258     return_code = 1
259     try:
260         return_code = loop.run_until_complete(
261             schedule_formatting(
262                 sources, line_length, write_back, fast, quiet, loop, executor
263             )
264         )
265     finally:
266         shutdown(loop)
267         return return_code
268
269
270 async def schedule_formatting(
271     sources: List[Path],
272     line_length: int,
273     write_back: WriteBack,
274     fast: bool,
275     quiet: bool,
276     loop: BaseEventLoop,
277     executor: Executor,
278 ) -> int:
279     """Run formatting of `sources` in parallel using the provided `executor`.
280
281     (Use ProcessPoolExecutors for actual parallelism.)
282
283     `line_length`, `write_back`, and `fast` options are passed to
284     :func:`format_file_in_place`.
285     """
286     report = Report(check=write_back is WriteBack.NO, quiet=quiet)
287     cache: Cache = {}
288     if write_back != WriteBack.DIFF:
289         cache = read_cache()
290         sources, cached = filter_cached(cache, sources)
291         for src in cached:
292             report.done(src, Changed.CACHED)
293     cancelled = []
294     formatted = []
295     if sources:
296         lock = None
297         if write_back == WriteBack.DIFF:
298             # For diff output, we need locks to ensure we don't interleave output
299             # from different processes.
300             manager = Manager()
301             lock = manager.Lock()
302         tasks = {
303             src: loop.run_in_executor(
304                 executor, format_file_in_place, src, line_length, fast, write_back, lock
305             )
306             for src in sources
307         }
308         _task_values = list(tasks.values())
309         loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
310         loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
311         await asyncio.wait(_task_values)
312         for src, task in tasks.items():
313             if not task.done():
314                 report.failed(src, "timed out, cancelling")
315                 task.cancel()
316                 cancelled.append(task)
317             elif task.cancelled():
318                 cancelled.append(task)
319             elif task.exception():
320                 report.failed(src, str(task.exception()))
321             else:
322                 formatted.append(src)
323                 report.done(src, Changed.YES if task.result() else Changed.NO)
324
325     if cancelled:
326         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
327     elif not quiet:
328         out("All done! ✨ 🍰 ✨")
329     if not quiet:
330         click.echo(str(report))
331
332     if write_back != WriteBack.DIFF and formatted:
333         write_cache(cache, formatted)
334
335     return report.return_code
336
337
338 def format_file_in_place(
339     src: Path,
340     line_length: int,
341     fast: bool,
342     write_back: WriteBack = WriteBack.NO,
343     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
344 ) -> bool:
345     """Format file under `src` path. Return True if changed.
346
347     If `write_back` is True, write reformatted code back to stdout.
348     `line_length` and `fast` options are passed to :func:`format_file_contents`.
349     """
350
351     with tokenize.open(src) as src_buffer:
352         src_contents = src_buffer.read()
353     try:
354         dst_contents = format_file_contents(
355             src_contents, line_length=line_length, fast=fast
356         )
357     except NothingChanged:
358         return False
359
360     if write_back == write_back.YES:
361         with open(src, "w", encoding=src_buffer.encoding) as f:
362             f.write(dst_contents)
363     elif write_back == write_back.DIFF:
364         src_name = f"{src.name}  (original)"
365         dst_name = f"{src.name}  (formatted)"
366         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
367         if lock:
368             lock.acquire()
369         try:
370             sys.stdout.write(diff_contents)
371         finally:
372             if lock:
373                 lock.release()
374     return True
375
376
377 def format_stdin_to_stdout(
378     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
379 ) -> bool:
380     """Format file on stdin. Return True if changed.
381
382     If `write_back` is True, write reformatted code back to stdout.
383     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
384     """
385     src = sys.stdin.read()
386     dst = src
387     try:
388         dst = format_file_contents(src, line_length=line_length, fast=fast)
389         return True
390
391     except NothingChanged:
392         return False
393
394     finally:
395         if write_back == WriteBack.YES:
396             sys.stdout.write(dst)
397         elif write_back == WriteBack.DIFF:
398             src_name = "<stdin>  (original)"
399             dst_name = "<stdin>  (formatted)"
400             sys.stdout.write(diff(src, dst, src_name, dst_name))
401
402
403 def format_file_contents(
404     src_contents: str, line_length: int, fast: bool
405 ) -> FileContent:
406     """Reformat contents a file and return new contents.
407
408     If `fast` is False, additionally confirm that the reformatted code is
409     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
410     `line_length` is passed to :func:`format_str`.
411     """
412     if src_contents.strip() == "":
413         raise NothingChanged
414
415     dst_contents = format_str(src_contents, line_length=line_length)
416     if src_contents == dst_contents:
417         raise NothingChanged
418
419     if not fast:
420         assert_equivalent(src_contents, dst_contents)
421         assert_stable(src_contents, dst_contents, line_length=line_length)
422     return dst_contents
423
424
425 def format_str(src_contents: str, line_length: int) -> FileContent:
426     """Reformat a string and return new contents.
427
428     `line_length` determines how many characters per line are allowed.
429     """
430     src_node = lib2to3_parse(src_contents)
431     dst_contents = ""
432     lines = LineGenerator()
433     elt = EmptyLineTracker()
434     py36 = is_python36(src_node)
435     empty_line = Line()
436     after = 0
437     for current_line in lines.visit(src_node):
438         for _ in range(after):
439             dst_contents += str(empty_line)
440         before, after = elt.maybe_empty_lines(current_line)
441         for _ in range(before):
442             dst_contents += str(empty_line)
443         for line in split_line(current_line, line_length=line_length, py36=py36):
444             dst_contents += str(line)
445     return dst_contents
446
447
448 GRAMMARS = [
449     pygram.python_grammar_no_print_statement_no_exec_statement,
450     pygram.python_grammar_no_print_statement,
451     pygram.python_grammar_no_exec_statement,
452     pygram.python_grammar,
453 ]
454
455
456 def lib2to3_parse(src_txt: str) -> Node:
457     """Given a string with source, return the lib2to3 Node."""
458     grammar = pygram.python_grammar_no_print_statement
459     if src_txt[-1] != "\n":
460         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
461         src_txt += nl
462     for grammar in GRAMMARS:
463         drv = driver.Driver(grammar, pytree.convert)
464         try:
465             result = drv.parse_string(src_txt, True)
466             break
467
468         except ParseError as pe:
469             lineno, column = pe.context[1]
470             lines = src_txt.splitlines()
471             try:
472                 faulty_line = lines[lineno - 1]
473             except IndexError:
474                 faulty_line = "<line number missing in source>"
475             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
476     else:
477         raise exc from None
478
479     if isinstance(result, Leaf):
480         result = Node(syms.file_input, [result])
481     return result
482
483
484 def lib2to3_unparse(node: Node) -> str:
485     """Given a lib2to3 node, return its string representation."""
486     code = str(node)
487     return code
488
489
490 T = TypeVar("T")
491
492
493 class Visitor(Generic[T]):
494     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
495
496     def visit(self, node: LN) -> Iterator[T]:
497         """Main method to visit `node` and its children.
498
499         It tries to find a `visit_*()` method for the given `node.type`, like
500         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
501         If no dedicated `visit_*()` method is found, chooses `visit_default()`
502         instead.
503
504         Then yields objects of type `T` from the selected visitor.
505         """
506         if node.type < 256:
507             name = token.tok_name[node.type]
508         else:
509             name = type_repr(node.type)
510         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
511
512     def visit_default(self, node: LN) -> Iterator[T]:
513         """Default `visit_*()` implementation. Recurses to children of `node`."""
514         if isinstance(node, Node):
515             for child in node.children:
516                 yield from self.visit(child)
517
518
519 @dataclass
520 class DebugVisitor(Visitor[T]):
521     tree_depth: int = 0
522
523     def visit_default(self, node: LN) -> Iterator[T]:
524         indent = " " * (2 * self.tree_depth)
525         if isinstance(node, Node):
526             _type = type_repr(node.type)
527             out(f"{indent}{_type}", fg="yellow")
528             self.tree_depth += 1
529             for child in node.children:
530                 yield from self.visit(child)
531
532             self.tree_depth -= 1
533             out(f"{indent}/{_type}", fg="yellow", bold=False)
534         else:
535             _type = token.tok_name.get(node.type, str(node.type))
536             out(f"{indent}{_type}", fg="blue", nl=False)
537             if node.prefix:
538                 # We don't have to handle prefixes for `Node` objects since
539                 # that delegates to the first child anyway.
540                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
541             out(f" {node.value!r}", fg="blue", bold=False)
542
543     @classmethod
544     def show(cls, code: str) -> None:
545         """Pretty-print the lib2to3 AST of a given string of `code`.
546
547         Convenience method for debugging.
548         """
549         v: DebugVisitor[None] = DebugVisitor()
550         list(v.visit(lib2to3_parse(code)))
551
552
553 KEYWORDS = set(keyword.kwlist)
554 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
555 FLOW_CONTROL = {"return", "raise", "break", "continue"}
556 STATEMENT = {
557     syms.if_stmt,
558     syms.while_stmt,
559     syms.for_stmt,
560     syms.try_stmt,
561     syms.except_clause,
562     syms.with_stmt,
563     syms.funcdef,
564     syms.classdef,
565 }
566 STANDALONE_COMMENT = 153
567 LOGIC_OPERATORS = {"and", "or"}
568 COMPARATORS = {
569     token.LESS,
570     token.GREATER,
571     token.EQEQUAL,
572     token.NOTEQUAL,
573     token.LESSEQUAL,
574     token.GREATEREQUAL,
575 }
576 MATH_OPERATORS = {
577     token.PLUS,
578     token.MINUS,
579     token.STAR,
580     token.SLASH,
581     token.VBAR,
582     token.AMPER,
583     token.PERCENT,
584     token.CIRCUMFLEX,
585     token.TILDE,
586     token.LEFTSHIFT,
587     token.RIGHTSHIFT,
588     token.DOUBLESTAR,
589     token.DOUBLESLASH,
590 }
591 STARS = {token.STAR, token.DOUBLESTAR}
592 VARARGS_PARENTS = {
593     syms.arglist,
594     syms.argument,  # double star in arglist
595     syms.trailer,  # single argument to call
596     syms.typedargslist,
597     syms.varargslist,  # lambdas
598 }
599 UNPACKING_PARENTS = {
600     syms.atom,  # single element of a list or set literal
601     syms.dictsetmaker,
602     syms.listmaker,
603     syms.testlist_gexp,
604 }
605 COMPREHENSION_PRIORITY = 20
606 COMMA_PRIORITY = 10
607 LOGIC_PRIORITY = 5
608 STRING_PRIORITY = 4
609 COMPARATOR_PRIORITY = 3
610 MATH_PRIORITY = 1
611
612
613 @dataclass
614 class BracketTracker:
615     """Keeps track of brackets on a line."""
616
617     depth: int = 0
618     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
619     delimiters: Dict[LeafID, Priority] = Factory(dict)
620     previous: Optional[Leaf] = None
621
622     def mark(self, leaf: Leaf) -> None:
623         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
624
625         All leaves receive an int `bracket_depth` field that stores how deep
626         within brackets a given leaf is. 0 means there are no enclosing brackets
627         that started on this line.
628
629         If a leaf is itself a closing bracket, it receives an `opening_bracket`
630         field that it forms a pair with. This is a one-directional link to
631         avoid reference cycles.
632
633         If a leaf is a delimiter (a token on which Black can split the line if
634         needed) and it's on depth 0, its `id()` is stored in the tracker's
635         `delimiters` field.
636         """
637         if leaf.type == token.COMMENT:
638             return
639
640         if leaf.type in CLOSING_BRACKETS:
641             self.depth -= 1
642             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
643             leaf.opening_bracket = opening_bracket
644         leaf.bracket_depth = self.depth
645         if self.depth == 0:
646             delim = is_split_before_delimiter(leaf, self.previous)
647             if delim and self.previous is not None:
648                 self.delimiters[id(self.previous)] = delim
649             else:
650                 delim = is_split_after_delimiter(leaf, self.previous)
651                 if delim:
652                     self.delimiters[id(leaf)] = delim
653         if leaf.type in OPENING_BRACKETS:
654             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
655             self.depth += 1
656         self.previous = leaf
657
658     def any_open_brackets(self) -> bool:
659         """Return True if there is an yet unmatched open bracket on the line."""
660         return bool(self.bracket_match)
661
662     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
663         """Return the highest priority of a delimiter found on the line.
664
665         Values are consistent with what `is_delimiter()` returns.
666         Raises ValueError on no delimiters.
667         """
668         return max(v for k, v in self.delimiters.items() if k not in exclude)
669
670
671 @dataclass
672 class Line:
673     """Holds leaves and comments. Can be printed with `str(line)`."""
674
675     depth: int = 0
676     leaves: List[Leaf] = Factory(list)
677     comments: List[Tuple[Index, Leaf]] = Factory(list)
678     bracket_tracker: BracketTracker = Factory(BracketTracker)
679     inside_brackets: bool = False
680     has_for: bool = False
681     _for_loop_variable: bool = False
682
683     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
684         """Add a new `leaf` to the end of the line.
685
686         Unless `preformatted` is True, the `leaf` will receive a new consistent
687         whitespace prefix and metadata applied by :class:`BracketTracker`.
688         Trailing commas are maybe removed, unpacked for loop variables are
689         demoted from being delimiters.
690
691         Inline comments are put aside.
692         """
693         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
694         if not has_value:
695             return
696
697         if self.leaves and not preformatted:
698             # Note: at this point leaf.prefix should be empty except for
699             # imports, for which we only preserve newlines.
700             leaf.prefix += whitespace(leaf)
701         if self.inside_brackets or not preformatted:
702             self.maybe_decrement_after_for_loop_variable(leaf)
703             self.bracket_tracker.mark(leaf)
704             self.maybe_remove_trailing_comma(leaf)
705             self.maybe_increment_for_loop_variable(leaf)
706
707         if not self.append_comment(leaf):
708             self.leaves.append(leaf)
709
710     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
711         """Like :func:`append()` but disallow invalid standalone comment structure.
712
713         Raises ValueError when any `leaf` is appended after a standalone comment
714         or when a standalone comment is not the first leaf on the line.
715         """
716         if self.bracket_tracker.depth == 0:
717             if self.is_comment:
718                 raise ValueError("cannot append to standalone comments")
719
720             if self.leaves and leaf.type == STANDALONE_COMMENT:
721                 raise ValueError(
722                     "cannot append standalone comments to a populated line"
723                 )
724
725         self.append(leaf, preformatted=preformatted)
726
727     @property
728     def is_comment(self) -> bool:
729         """Is this line a standalone comment?"""
730         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
731
732     @property
733     def is_decorator(self) -> bool:
734         """Is this line a decorator?"""
735         return bool(self) and self.leaves[0].type == token.AT
736
737     @property
738     def is_import(self) -> bool:
739         """Is this an import line?"""
740         return bool(self) and is_import(self.leaves[0])
741
742     @property
743     def is_class(self) -> bool:
744         """Is this line a class definition?"""
745         return (
746             bool(self)
747             and self.leaves[0].type == token.NAME
748             and self.leaves[0].value == "class"
749         )
750
751     @property
752     def is_def(self) -> bool:
753         """Is this a function definition? (Also returns True for async defs.)"""
754         try:
755             first_leaf = self.leaves[0]
756         except IndexError:
757             return False
758
759         try:
760             second_leaf: Optional[Leaf] = self.leaves[1]
761         except IndexError:
762             second_leaf = None
763         return (
764             (first_leaf.type == token.NAME and first_leaf.value == "def")
765             or (
766                 first_leaf.type == token.ASYNC
767                 and second_leaf is not None
768                 and second_leaf.type == token.NAME
769                 and second_leaf.value == "def"
770             )
771         )
772
773     @property
774     def is_flow_control(self) -> bool:
775         """Is this line a flow control statement?
776
777         Those are `return`, `raise`, `break`, and `continue`.
778         """
779         return (
780             bool(self)
781             and self.leaves[0].type == token.NAME
782             and self.leaves[0].value in FLOW_CONTROL
783         )
784
785     @property
786     def is_yield(self) -> bool:
787         """Is this line a yield statement?"""
788         return (
789             bool(self)
790             and self.leaves[0].type == token.NAME
791             and self.leaves[0].value == "yield"
792         )
793
794     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
795         """If so, needs to be split before emitting."""
796         for leaf in self.leaves:
797             if leaf.type == STANDALONE_COMMENT:
798                 if leaf.bracket_depth <= depth_limit:
799                     return True
800
801         return False
802
803     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
804         """Remove trailing comma if there is one and it's safe."""
805         if not (
806             self.leaves
807             and self.leaves[-1].type == token.COMMA
808             and closing.type in CLOSING_BRACKETS
809         ):
810             return False
811
812         if closing.type == token.RBRACE:
813             self.remove_trailing_comma()
814             return True
815
816         if closing.type == token.RSQB:
817             comma = self.leaves[-1]
818             if comma.parent and comma.parent.type == syms.listmaker:
819                 self.remove_trailing_comma()
820                 return True
821
822         # For parens let's check if it's safe to remove the comma.  If the
823         # trailing one is the only one, we might mistakenly change a tuple
824         # into a different type by removing the comma.
825         depth = closing.bracket_depth + 1
826         commas = 0
827         opening = closing.opening_bracket
828         for _opening_index, leaf in enumerate(self.leaves):
829             if leaf is opening:
830                 break
831
832         else:
833             return False
834
835         for leaf in self.leaves[_opening_index + 1:]:
836             if leaf is closing:
837                 break
838
839             bracket_depth = leaf.bracket_depth
840             if bracket_depth == depth and leaf.type == token.COMMA:
841                 commas += 1
842                 if leaf.parent and leaf.parent.type == syms.arglist:
843                     commas += 1
844                     break
845
846         if commas > 1:
847             self.remove_trailing_comma()
848             return True
849
850         return False
851
852     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
853         """In a for loop, or comprehension, the variables are often unpacks.
854
855         To avoid splitting on the comma in this situation, increase the depth of
856         tokens between `for` and `in`.
857         """
858         if leaf.type == token.NAME and leaf.value == "for":
859             self.has_for = True
860             self.bracket_tracker.depth += 1
861             self._for_loop_variable = True
862             return True
863
864         return False
865
866     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
867         """See `maybe_increment_for_loop_variable` above for explanation."""
868         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
869             self.bracket_tracker.depth -= 1
870             self._for_loop_variable = False
871             return True
872
873         return False
874
875     def append_comment(self, comment: Leaf) -> bool:
876         """Add an inline or standalone comment to the line."""
877         if (
878             comment.type == STANDALONE_COMMENT
879             and self.bracket_tracker.any_open_brackets()
880         ):
881             comment.prefix = ""
882             return False
883
884         if comment.type != token.COMMENT:
885             return False
886
887         after = len(self.leaves) - 1
888         if after == -1:
889             comment.type = STANDALONE_COMMENT
890             comment.prefix = ""
891             return False
892
893         else:
894             self.comments.append((after, comment))
895             return True
896
897     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
898         """Generate comments that should appear directly after `leaf`."""
899         for _leaf_index, _leaf in enumerate(self.leaves):
900             if leaf is _leaf:
901                 break
902
903         else:
904             return
905
906         for index, comment_after in self.comments:
907             if _leaf_index == index:
908                 yield comment_after
909
910     def remove_trailing_comma(self) -> None:
911         """Remove the trailing comma and moves the comments attached to it."""
912         comma_index = len(self.leaves) - 1
913         for i in range(len(self.comments)):
914             comment_index, comment = self.comments[i]
915             if comment_index == comma_index:
916                 self.comments[i] = (comma_index - 1, comment)
917         self.leaves.pop()
918
919     def __str__(self) -> str:
920         """Render the line."""
921         if not self:
922             return "\n"
923
924         indent = "    " * self.depth
925         leaves = iter(self.leaves)
926         first = next(leaves)
927         res = f"{first.prefix}{indent}{first.value}"
928         for leaf in leaves:
929             res += str(leaf)
930         for _, comment in self.comments:
931             res += str(comment)
932         return res + "\n"
933
934     def __bool__(self) -> bool:
935         """Return True if the line has leaves or comments."""
936         return bool(self.leaves or self.comments)
937
938
939 class UnformattedLines(Line):
940     """Just like :class:`Line` but stores lines which aren't reformatted."""
941
942     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
943         """Just add a new `leaf` to the end of the lines.
944
945         The `preformatted` argument is ignored.
946
947         Keeps track of indentation `depth`, which is useful when the user
948         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
949         """
950         try:
951             list(generate_comments(leaf))
952         except FormatOn as f_on:
953             self.leaves.append(f_on.leaf_from_consumed(leaf))
954             raise
955
956         self.leaves.append(leaf)
957         if leaf.type == token.INDENT:
958             self.depth += 1
959         elif leaf.type == token.DEDENT:
960             self.depth -= 1
961
962     def __str__(self) -> str:
963         """Render unformatted lines from leaves which were added with `append()`.
964
965         `depth` is not used for indentation in this case.
966         """
967         if not self:
968             return "\n"
969
970         res = ""
971         for leaf in self.leaves:
972             res += str(leaf)
973         return res
974
975     def append_comment(self, comment: Leaf) -> bool:
976         """Not implemented in this class. Raises `NotImplementedError`."""
977         raise NotImplementedError("Unformatted lines don't store comments separately.")
978
979     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
980         """Does nothing and returns False."""
981         return False
982
983     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
984         """Does nothing and returns False."""
985         return False
986
987
988 @dataclass
989 class EmptyLineTracker:
990     """Provides a stateful method that returns the number of potential extra
991     empty lines needed before and after the currently processed line.
992
993     Note: this tracker works on lines that haven't been split yet.  It assumes
994     the prefix of the first leaf consists of optional newlines.  Those newlines
995     are consumed by `maybe_empty_lines()` and included in the computation.
996     """
997     previous_line: Optional[Line] = None
998     previous_after: int = 0
999     previous_defs: List[int] = Factory(list)
1000
1001     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1002         """Return the number of extra empty lines before and after the `current_line`.
1003
1004         This is for separating `def`, `async def` and `class` with extra empty
1005         lines (two on module-level), as well as providing an extra empty line
1006         after flow control keywords to make them more prominent.
1007         """
1008         if isinstance(current_line, UnformattedLines):
1009             return 0, 0
1010
1011         before, after = self._maybe_empty_lines(current_line)
1012         before -= self.previous_after
1013         self.previous_after = after
1014         self.previous_line = current_line
1015         return before, after
1016
1017     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1018         max_allowed = 1
1019         if current_line.depth == 0:
1020             max_allowed = 2
1021         if current_line.leaves:
1022             # Consume the first leaf's extra newlines.
1023             first_leaf = current_line.leaves[0]
1024             before = first_leaf.prefix.count("\n")
1025             before = min(before, max_allowed)
1026             first_leaf.prefix = ""
1027         else:
1028             before = 0
1029         depth = current_line.depth
1030         while self.previous_defs and self.previous_defs[-1] >= depth:
1031             self.previous_defs.pop()
1032             before = 1 if depth else 2
1033         is_decorator = current_line.is_decorator
1034         if is_decorator or current_line.is_def or current_line.is_class:
1035             if not is_decorator:
1036                 self.previous_defs.append(depth)
1037             if self.previous_line is None:
1038                 # Don't insert empty lines before the first line in the file.
1039                 return 0, 0
1040
1041             if self.previous_line and self.previous_line.is_decorator:
1042                 # Don't insert empty lines between decorators.
1043                 return 0, 0
1044
1045             newlines = 2
1046             if current_line.depth:
1047                 newlines -= 1
1048             return newlines, 0
1049
1050         if current_line.is_flow_control:
1051             return before, 1
1052
1053         if (
1054             self.previous_line
1055             and self.previous_line.is_import
1056             and not current_line.is_import
1057             and depth == self.previous_line.depth
1058         ):
1059             return (before or 1), 0
1060
1061         if (
1062             self.previous_line
1063             and self.previous_line.is_yield
1064             and (not current_line.is_yield or depth != self.previous_line.depth)
1065         ):
1066             return (before or 1), 0
1067
1068         return before, 0
1069
1070
1071 @dataclass
1072 class LineGenerator(Visitor[Line]):
1073     """Generates reformatted Line objects.  Empty lines are not emitted.
1074
1075     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1076     in ways that will no longer stringify to valid Python code on the tree.
1077     """
1078     current_line: Line = Factory(Line)
1079
1080     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1081         """Generate a line.
1082
1083         If the line is empty, only emit if it makes sense.
1084         If the line is too long, split it first and then generate.
1085
1086         If any lines were generated, set up a new current_line.
1087         """
1088         if not self.current_line:
1089             if self.current_line.__class__ == type:
1090                 self.current_line.depth += indent
1091             else:
1092                 self.current_line = type(depth=self.current_line.depth + indent)
1093             return  # Line is empty, don't emit. Creating a new one unnecessary.
1094
1095         complete_line = self.current_line
1096         self.current_line = type(depth=complete_line.depth + indent)
1097         yield complete_line
1098
1099     def visit(self, node: LN) -> Iterator[Line]:
1100         """Main method to visit `node` and its children.
1101
1102         Yields :class:`Line` objects.
1103         """
1104         if isinstance(self.current_line, UnformattedLines):
1105             # File contained `# fmt: off`
1106             yield from self.visit_unformatted(node)
1107
1108         else:
1109             yield from super().visit(node)
1110
1111     def visit_default(self, node: LN) -> Iterator[Line]:
1112         """Default `visit_*()` implementation. Recurses to children of `node`."""
1113         if isinstance(node, Leaf):
1114             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1115             try:
1116                 for comment in generate_comments(node):
1117                     if any_open_brackets:
1118                         # any comment within brackets is subject to splitting
1119                         self.current_line.append(comment)
1120                     elif comment.type == token.COMMENT:
1121                         # regular trailing comment
1122                         self.current_line.append(comment)
1123                         yield from self.line()
1124
1125                     else:
1126                         # regular standalone comment
1127                         yield from self.line()
1128
1129                         self.current_line.append(comment)
1130                         yield from self.line()
1131
1132             except FormatOff as f_off:
1133                 f_off.trim_prefix(node)
1134                 yield from self.line(type=UnformattedLines)
1135                 yield from self.visit(node)
1136
1137             except FormatOn as f_on:
1138                 # This only happens here if somebody says "fmt: on" multiple
1139                 # times in a row.
1140                 f_on.trim_prefix(node)
1141                 yield from self.visit_default(node)
1142
1143             else:
1144                 normalize_prefix(node, inside_brackets=any_open_brackets)
1145                 if node.type == token.STRING:
1146                     normalize_string_quotes(node)
1147                 if node.type not in WHITESPACE:
1148                     self.current_line.append(node)
1149         yield from super().visit_default(node)
1150
1151     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1152         """Increase indentation level, maybe yield a line."""
1153         # In blib2to3 INDENT never holds comments.
1154         yield from self.line(+1)
1155         yield from self.visit_default(node)
1156
1157     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1158         """Decrease indentation level, maybe yield a line."""
1159         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1160         yield from self.line(-1)
1161
1162     def visit_stmt(
1163         self, node: Node, keywords: Set[str], parens: Set[str]
1164     ) -> Iterator[Line]:
1165         """Visit a statement.
1166
1167         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1168         `def`, `with`, `class`, and `assert`.
1169
1170         The relevant Python language `keywords` for a given statement will be
1171         NAME leaves within it. This methods puts those on a separate line.
1172
1173         `parens` holds pairs of nodes where invisible parentheses should be put.
1174         Keys hold nodes after which opening parentheses should be put, values
1175         hold nodes before which closing parentheses should be put.
1176         """
1177         normalize_invisible_parens(node, parens_after=parens)
1178         for child in node.children:
1179             if child.type == token.NAME and child.value in keywords:  # type: ignore
1180                 yield from self.line()
1181
1182             yield from self.visit(child)
1183
1184     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1185         """Visit a statement without nested statements."""
1186         is_suite_like = node.parent and node.parent.type in STATEMENT
1187         if is_suite_like:
1188             yield from self.line(+1)
1189             yield from self.visit_default(node)
1190             yield from self.line(-1)
1191
1192         else:
1193             yield from self.line()
1194             yield from self.visit_default(node)
1195
1196     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1197         """Visit `async def`, `async for`, `async with`."""
1198         yield from self.line()
1199
1200         children = iter(node.children)
1201         for child in children:
1202             yield from self.visit(child)
1203
1204             if child.type == token.ASYNC:
1205                 break
1206
1207         internal_stmt = next(children)
1208         for child in internal_stmt.children:
1209             yield from self.visit(child)
1210
1211     def visit_decorators(self, node: Node) -> Iterator[Line]:
1212         """Visit decorators."""
1213         for child in node.children:
1214             yield from self.line()
1215             yield from self.visit(child)
1216
1217     def visit_import_from(self, node: Node) -> Iterator[Line]:
1218         """Visit import_from and maybe put invisible parentheses.
1219
1220         This is separate from `visit_stmt` because import statements don't
1221         support arbitrary atoms and thus handling of parentheses is custom.
1222         """
1223         check_lpar = False
1224         for index, child in enumerate(node.children):
1225             if check_lpar:
1226                 if child.type == token.LPAR:
1227                     # make parentheses invisible
1228                     child.value = ""  # type: ignore
1229                     node.children[-1].value = ""  # type: ignore
1230                 else:
1231                     # insert invisible parentheses
1232                     node.insert_child(index, Leaf(token.LPAR, ""))
1233                     node.append_child(Leaf(token.RPAR, ""))
1234                 break
1235
1236             check_lpar = (
1237                 child.type == token.NAME and child.value == "import"  # type: ignore
1238             )
1239
1240         for child in node.children:
1241             yield from self.visit(child)
1242
1243     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1244         """Remove a semicolon and put the other statement on a separate line."""
1245         yield from self.line()
1246
1247     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1248         """End of file. Process outstanding comments and end with a newline."""
1249         yield from self.visit_default(leaf)
1250         yield from self.line()
1251
1252     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1253         """Used when file contained a `# fmt: off`."""
1254         if isinstance(node, Node):
1255             for child in node.children:
1256                 yield from self.visit(child)
1257
1258         else:
1259             try:
1260                 self.current_line.append(node)
1261             except FormatOn as f_on:
1262                 f_on.trim_prefix(node)
1263                 yield from self.line()
1264                 yield from self.visit(node)
1265
1266             if node.type == token.ENDMARKER:
1267                 # somebody decided not to put a final `# fmt: on`
1268                 yield from self.line()
1269
1270     def __attrs_post_init__(self) -> None:
1271         """You are in a twisty little maze of passages."""
1272         v = self.visit_stmt
1273         Ø: Set[str] = set()
1274         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1275         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1276         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1277         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1278         self.visit_try_stmt = partial(
1279             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1280         )
1281         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1282         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1283         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1284         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1285         self.visit_async_funcdef = self.visit_async_stmt
1286         self.visit_decorated = self.visit_decorators
1287
1288
1289 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1290 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1291 OPENING_BRACKETS = set(BRACKET.keys())
1292 CLOSING_BRACKETS = set(BRACKET.values())
1293 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1294 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1295
1296
1297 def whitespace(leaf: Leaf) -> str:  # noqa C901
1298     """Return whitespace prefix if needed for the given `leaf`."""
1299     NO = ""
1300     SPACE = " "
1301     DOUBLESPACE = "  "
1302     t = leaf.type
1303     p = leaf.parent
1304     v = leaf.value
1305     if t in ALWAYS_NO_SPACE:
1306         return NO
1307
1308     if t == token.COMMENT:
1309         return DOUBLESPACE
1310
1311     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1312     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1313         return NO
1314
1315     prev = leaf.prev_sibling
1316     if not prev:
1317         prevp = preceding_leaf(p)
1318         if not prevp or prevp.type in OPENING_BRACKETS:
1319             return NO
1320
1321         if t == token.COLON:
1322             return SPACE if prevp.type == token.COMMA else NO
1323
1324         if prevp.type == token.EQUAL:
1325             if prevp.parent:
1326                 if prevp.parent.type in {
1327                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1328                 }:
1329                     return NO
1330
1331                 elif prevp.parent.type == syms.typedargslist:
1332                     # A bit hacky: if the equal sign has whitespace, it means we
1333                     # previously found it's a typed argument.  So, we're using
1334                     # that, too.
1335                     return prevp.prefix
1336
1337         elif prevp.type in STARS:
1338             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1339                 return NO
1340
1341         elif prevp.type == token.COLON:
1342             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1343                 return NO
1344
1345         elif (
1346             prevp.parent
1347             and prevp.parent.type == syms.factor
1348             and prevp.type in MATH_OPERATORS
1349         ):
1350             return NO
1351
1352         elif (
1353             prevp.type == token.RIGHTSHIFT
1354             and prevp.parent
1355             and prevp.parent.type == syms.shift_expr
1356             and prevp.prev_sibling
1357             and prevp.prev_sibling.type == token.NAME
1358             and prevp.prev_sibling.value == "print"  # type: ignore
1359         ):
1360             # Python 2 print chevron
1361             return NO
1362
1363     elif prev.type in OPENING_BRACKETS:
1364         return NO
1365
1366     if p.type in {syms.parameters, syms.arglist}:
1367         # untyped function signatures or calls
1368         if t == token.RPAR:
1369             return NO
1370
1371         if not prev or prev.type != token.COMMA:
1372             return NO
1373
1374     elif p.type == syms.varargslist:
1375         # lambdas
1376         if t == token.RPAR:
1377             return NO
1378
1379         if prev and prev.type != token.COMMA:
1380             return NO
1381
1382     elif p.type == syms.typedargslist:
1383         # typed function signatures
1384         if not prev:
1385             return NO
1386
1387         if t == token.EQUAL:
1388             if prev.type != syms.tname:
1389                 return NO
1390
1391         elif prev.type == token.EQUAL:
1392             # A bit hacky: if the equal sign has whitespace, it means we
1393             # previously found it's a typed argument.  So, we're using that, too.
1394             return prev.prefix
1395
1396         elif prev.type != token.COMMA:
1397             return NO
1398
1399     elif p.type == syms.tname:
1400         # type names
1401         if not prev:
1402             prevp = preceding_leaf(p)
1403             if not prevp or prevp.type != token.COMMA:
1404                 return NO
1405
1406     elif p.type == syms.trailer:
1407         # attributes and calls
1408         if t == token.LPAR or t == token.RPAR:
1409             return NO
1410
1411         if not prev:
1412             if t == token.DOT:
1413                 prevp = preceding_leaf(p)
1414                 if not prevp or prevp.type != token.NUMBER:
1415                     return NO
1416
1417             elif t == token.LSQB:
1418                 return NO
1419
1420         elif prev.type != token.COMMA:
1421             return NO
1422
1423     elif p.type == syms.argument:
1424         # single argument
1425         if t == token.EQUAL:
1426             return NO
1427
1428         if not prev:
1429             prevp = preceding_leaf(p)
1430             if not prevp or prevp.type == token.LPAR:
1431                 return NO
1432
1433         elif prev.type in {token.EQUAL} | STARS:
1434             return NO
1435
1436     elif p.type == syms.decorator:
1437         # decorators
1438         return NO
1439
1440     elif p.type == syms.dotted_name:
1441         if prev:
1442             return NO
1443
1444         prevp = preceding_leaf(p)
1445         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1446             return NO
1447
1448     elif p.type == syms.classdef:
1449         if t == token.LPAR:
1450             return NO
1451
1452         if prev and prev.type == token.LPAR:
1453             return NO
1454
1455     elif p.type == syms.subscript:
1456         # indexing
1457         if not prev:
1458             assert p.parent is not None, "subscripts are always parented"
1459             if p.parent.type == syms.subscriptlist:
1460                 return SPACE
1461
1462             return NO
1463
1464         else:
1465             return NO
1466
1467     elif p.type == syms.atom:
1468         if prev and t == token.DOT:
1469             # dots, but not the first one.
1470             return NO
1471
1472     elif (
1473         p.type == syms.listmaker
1474         or p.type == syms.testlist_gexp
1475         or p.type == syms.subscriptlist
1476     ):
1477         # list interior, including unpacking
1478         if not prev:
1479             return NO
1480
1481     elif p.type == syms.dictsetmaker:
1482         # dict and set interior, including unpacking
1483         if not prev:
1484             return NO
1485
1486         if prev.type == token.DOUBLESTAR:
1487             return NO
1488
1489     elif p.type in {syms.factor, syms.star_expr}:
1490         # unary ops
1491         if not prev:
1492             prevp = preceding_leaf(p)
1493             if not prevp or prevp.type in OPENING_BRACKETS:
1494                 return NO
1495
1496             prevp_parent = prevp.parent
1497             assert prevp_parent is not None
1498             if (
1499                 prevp.type == token.COLON
1500                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1501             ):
1502                 return NO
1503
1504             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1505                 return NO
1506
1507         elif t == token.NAME or t == token.NUMBER:
1508             return NO
1509
1510     elif p.type == syms.import_from:
1511         if t == token.DOT:
1512             if prev and prev.type == token.DOT:
1513                 return NO
1514
1515         elif t == token.NAME:
1516             if v == "import":
1517                 return SPACE
1518
1519             if prev and prev.type == token.DOT:
1520                 return NO
1521
1522     elif p.type == syms.sliceop:
1523         return NO
1524
1525     return SPACE
1526
1527
1528 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1529     """Return the first leaf that precedes `node`, if any."""
1530     while node:
1531         res = node.prev_sibling
1532         if res:
1533             if isinstance(res, Leaf):
1534                 return res
1535
1536             try:
1537                 return list(res.leaves())[-1]
1538
1539             except IndexError:
1540                 return None
1541
1542         node = node.parent
1543     return None
1544
1545
1546 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1547     """Return the priority of the `leaf` delimiter, given a line break after it.
1548
1549     The delimiter priorities returned here are from those delimiters that would
1550     cause a line break after themselves.
1551
1552     Higher numbers are higher priority.
1553     """
1554     if leaf.type == token.COMMA:
1555         return COMMA_PRIORITY
1556
1557     return 0
1558
1559
1560 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1561     """Return the priority of the `leaf` delimiter, given a line before after it.
1562
1563     The delimiter priorities returned here are from those delimiters that would
1564     cause a line break before themselves.
1565
1566     Higher numbers are higher priority.
1567     """
1568     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1569         # * and ** might also be MATH_OPERATORS but in this case they are not.
1570         # Don't treat them as a delimiter.
1571         return 0
1572
1573     if (
1574         leaf.type in MATH_OPERATORS
1575         and leaf.parent
1576         and leaf.parent.type not in {syms.factor, syms.star_expr}
1577     ):
1578         return MATH_PRIORITY
1579
1580     if leaf.type in COMPARATORS:
1581         return COMPARATOR_PRIORITY
1582
1583     if (
1584         leaf.type == token.STRING
1585         and previous is not None
1586         and previous.type == token.STRING
1587     ):
1588         return STRING_PRIORITY
1589
1590     if (
1591         leaf.type == token.NAME
1592         and leaf.value == "for"
1593         and leaf.parent
1594         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1595     ):
1596         return COMPREHENSION_PRIORITY
1597
1598     if (
1599         leaf.type == token.NAME
1600         and leaf.value == "if"
1601         and leaf.parent
1602         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1603     ):
1604         return COMPREHENSION_PRIORITY
1605
1606     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1607         return LOGIC_PRIORITY
1608
1609     return 0
1610
1611
1612 def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1613     """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
1614
1615     Higher numbers are higher priority.
1616     """
1617     return max(
1618         is_split_before_delimiter(leaf, previous),
1619         is_split_after_delimiter(leaf, previous),
1620     )
1621
1622
1623 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1624     """Clean the prefix of the `leaf` and generate comments from it, if any.
1625
1626     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1627     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1628     move because it does away with modifying the grammar to include all the
1629     possible places in which comments can be placed.
1630
1631     The sad consequence for us though is that comments don't "belong" anywhere.
1632     This is why this function generates simple parentless Leaf objects for
1633     comments.  We simply don't know what the correct parent should be.
1634
1635     No matter though, we can live without this.  We really only need to
1636     differentiate between inline and standalone comments.  The latter don't
1637     share the line with any code.
1638
1639     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1640     are emitted with a fake STANDALONE_COMMENT token identifier.
1641     """
1642     p = leaf.prefix
1643     if not p:
1644         return
1645
1646     if "#" not in p:
1647         return
1648
1649     consumed = 0
1650     nlines = 0
1651     for index, line in enumerate(p.split("\n")):
1652         consumed += len(line) + 1  # adding the length of the split '\n'
1653         line = line.lstrip()
1654         if not line:
1655             nlines += 1
1656         if not line.startswith("#"):
1657             continue
1658
1659         if index == 0 and leaf.type != token.ENDMARKER:
1660             comment_type = token.COMMENT  # simple trailing comment
1661         else:
1662             comment_type = STANDALONE_COMMENT
1663         comment = make_comment(line)
1664         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1665
1666         if comment in {"# fmt: on", "# yapf: enable"}:
1667             raise FormatOn(consumed)
1668
1669         if comment in {"# fmt: off", "# yapf: disable"}:
1670             if comment_type == STANDALONE_COMMENT:
1671                 raise FormatOff(consumed)
1672
1673             prev = preceding_leaf(leaf)
1674             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1675                 raise FormatOff(consumed)
1676
1677         nlines = 0
1678
1679
1680 def make_comment(content: str) -> str:
1681     """Return a consistently formatted comment from the given `content` string.
1682
1683     All comments (except for "##", "#!", "#:") should have a single space between
1684     the hash sign and the content.
1685
1686     If `content` didn't start with a hash sign, one is provided.
1687     """
1688     content = content.rstrip()
1689     if not content:
1690         return "#"
1691
1692     if content[0] == "#":
1693         content = content[1:]
1694     if content and content[0] not in " !:#":
1695         content = " " + content
1696     return "#" + content
1697
1698
1699 def split_line(
1700     line: Line, line_length: int, inner: bool = False, py36: bool = False
1701 ) -> Iterator[Line]:
1702     """Split a `line` into potentially many lines.
1703
1704     They should fit in the allotted `line_length` but might not be able to.
1705     `inner` signifies that there were a pair of brackets somewhere around the
1706     current `line`, possibly transitively. This means we can fallback to splitting
1707     by delimiters if the LHS/RHS don't yield any results.
1708
1709     If `py36` is True, splitting may generate syntax that is only compatible
1710     with Python 3.6 and later.
1711     """
1712     if isinstance(line, UnformattedLines) or line.is_comment:
1713         yield line
1714         return
1715
1716     line_str = str(line).strip("\n")
1717     if (
1718         len(line_str) <= line_length
1719         and "\n" not in line_str  # multiline strings
1720         and not line.contains_standalone_comments()
1721     ):
1722         yield line
1723         return
1724
1725     split_funcs: List[SplitFunc]
1726     if line.is_def:
1727         split_funcs = [left_hand_split]
1728     elif line.inside_brackets:
1729         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1730     else:
1731         split_funcs = [right_hand_split]
1732     for split_func in split_funcs:
1733         # We are accumulating lines in `result` because we might want to abort
1734         # mission and return the original line in the end, or attempt a different
1735         # split altogether.
1736         result: List[Line] = []
1737         try:
1738             for l in split_func(line, py36):
1739                 if str(l).strip("\n") == line_str:
1740                     raise CannotSplit("Split function returned an unchanged result")
1741
1742                 result.extend(
1743                     split_line(l, line_length=line_length, inner=True, py36=py36)
1744                 )
1745         except CannotSplit as cs:
1746             continue
1747
1748         else:
1749             yield from result
1750             break
1751
1752     else:
1753         yield line
1754
1755
1756 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1757     """Split line into many lines, starting with the first matching bracket pair.
1758
1759     Note: this usually looks weird, only use this for function definitions.
1760     Prefer RHS otherwise.
1761     """
1762     head = Line(depth=line.depth)
1763     body = Line(depth=line.depth + 1, inside_brackets=True)
1764     tail = Line(depth=line.depth)
1765     tail_leaves: List[Leaf] = []
1766     body_leaves: List[Leaf] = []
1767     head_leaves: List[Leaf] = []
1768     current_leaves = head_leaves
1769     matching_bracket = None
1770     for leaf in line.leaves:
1771         if (
1772             current_leaves is body_leaves
1773             and leaf.type in CLOSING_BRACKETS
1774             and leaf.opening_bracket is matching_bracket
1775         ):
1776             current_leaves = tail_leaves if body_leaves else head_leaves
1777         current_leaves.append(leaf)
1778         if current_leaves is head_leaves:
1779             if leaf.type in OPENING_BRACKETS:
1780                 matching_bracket = leaf
1781                 current_leaves = body_leaves
1782     # Since body is a new indent level, remove spurious leading whitespace.
1783     if body_leaves:
1784         normalize_prefix(body_leaves[0], inside_brackets=True)
1785     # Build the new lines.
1786     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1787         for leaf in leaves:
1788             result.append(leaf, preformatted=True)
1789             for comment_after in line.comments_after(leaf):
1790                 result.append(comment_after, preformatted=True)
1791     bracket_split_succeeded_or_raise(head, body, tail)
1792     for result in (head, body, tail):
1793         if result:
1794             yield result
1795
1796
1797 def right_hand_split(
1798     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1799 ) -> Iterator[Line]:
1800     """Split line into many lines, starting with the last matching bracket pair."""
1801     head = Line(depth=line.depth)
1802     body = Line(depth=line.depth + 1, inside_brackets=True)
1803     tail = Line(depth=line.depth)
1804     tail_leaves: List[Leaf] = []
1805     body_leaves: List[Leaf] = []
1806     head_leaves: List[Leaf] = []
1807     current_leaves = tail_leaves
1808     opening_bracket = None
1809     closing_bracket = None
1810     for leaf in reversed(line.leaves):
1811         if current_leaves is body_leaves:
1812             if leaf is opening_bracket:
1813                 current_leaves = head_leaves if body_leaves else tail_leaves
1814         current_leaves.append(leaf)
1815         if current_leaves is tail_leaves:
1816             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1817                 opening_bracket = leaf.opening_bracket
1818                 closing_bracket = leaf
1819                 current_leaves = body_leaves
1820     tail_leaves.reverse()
1821     body_leaves.reverse()
1822     head_leaves.reverse()
1823     # Since body is a new indent level, remove spurious leading whitespace.
1824     if body_leaves:
1825         normalize_prefix(body_leaves[0], inside_brackets=True)
1826     elif not head_leaves:
1827         # No `head` and no `body` means the split failed. `tail` has all content.
1828         raise CannotSplit("No brackets found")
1829
1830     # Build the new lines.
1831     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1832         for leaf in leaves:
1833             result.append(leaf, preformatted=True)
1834             for comment_after in line.comments_after(leaf):
1835                 result.append(comment_after, preformatted=True)
1836     bracket_split_succeeded_or_raise(head, body, tail)
1837     assert opening_bracket and closing_bracket
1838     if (
1839         opening_bracket.type == token.LPAR
1840         and not opening_bracket.value
1841         and closing_bracket.type == token.RPAR
1842         and not closing_bracket.value
1843     ):
1844         # These parens were optional. If there aren't any delimiters or standalone
1845         # comments in the body, they were unnecessary and another split without
1846         # them should be attempted.
1847         if not (
1848             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1849         ):
1850             omit = {id(closing_bracket), *omit}
1851             yield from right_hand_split(line, py36=py36, omit=omit)
1852             return
1853
1854     ensure_visible(opening_bracket)
1855     ensure_visible(closing_bracket)
1856     for result in (head, body, tail):
1857         if result:
1858             yield result
1859
1860
1861 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1862     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1863
1864     Do nothing otherwise.
1865
1866     A left- or right-hand split is based on a pair of brackets. Content before
1867     (and including) the opening bracket is left on one line, content inside the
1868     brackets is put on a separate line, and finally content starting with and
1869     following the closing bracket is put on a separate line.
1870
1871     Those are called `head`, `body`, and `tail`, respectively. If the split
1872     produced the same line (all content in `head`) or ended up with an empty `body`
1873     and the `tail` is just the closing bracket, then it's considered failed.
1874     """
1875     tail_len = len(str(tail).strip())
1876     if not body:
1877         if tail_len == 0:
1878             raise CannotSplit("Splitting brackets produced the same line")
1879
1880         elif tail_len < 3:
1881             raise CannotSplit(
1882                 f"Splitting brackets on an empty body to save "
1883                 f"{tail_len} characters is not worth it"
1884             )
1885
1886
1887 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1888     """Normalize prefix of the first leaf in every line returned by `split_func`.
1889
1890     This is a decorator over relevant split functions.
1891     """
1892
1893     @wraps(split_func)
1894     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1895         for l in split_func(line, py36):
1896             normalize_prefix(l.leaves[0], inside_brackets=True)
1897             yield l
1898
1899     return split_wrapper
1900
1901
1902 @dont_increase_indentation
1903 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1904     """Split according to delimiters of the highest priority.
1905
1906     If `py36` is True, the split will add trailing commas also in function
1907     signatures that contain `*` and `**`.
1908     """
1909     try:
1910         last_leaf = line.leaves[-1]
1911     except IndexError:
1912         raise CannotSplit("Line empty")
1913
1914     delimiters = line.bracket_tracker.delimiters
1915     try:
1916         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1917             exclude={id(last_leaf)}
1918         )
1919     except ValueError:
1920         raise CannotSplit("No delimiters found")
1921
1922     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1923     lowest_depth = sys.maxsize
1924     trailing_comma_safe = True
1925
1926     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1927         """Append `leaf` to current line or to new line if appending impossible."""
1928         nonlocal current_line
1929         try:
1930             current_line.append_safe(leaf, preformatted=True)
1931         except ValueError as ve:
1932             yield current_line
1933
1934             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1935             current_line.append(leaf)
1936
1937     for leaf in line.leaves:
1938         yield from append_to_line(leaf)
1939
1940         for comment_after in line.comments_after(leaf):
1941             yield from append_to_line(comment_after)
1942
1943         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1944         if (
1945             leaf.bracket_depth == lowest_depth
1946             and is_vararg(leaf, within=VARARGS_PARENTS)
1947         ):
1948             trailing_comma_safe = trailing_comma_safe and py36
1949         leaf_priority = delimiters.get(id(leaf))
1950         if leaf_priority == delimiter_priority:
1951             yield current_line
1952
1953             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1954     if current_line:
1955         if (
1956             trailing_comma_safe
1957             and delimiter_priority == COMMA_PRIORITY
1958             and current_line.leaves[-1].type != token.COMMA
1959             and current_line.leaves[-1].type != STANDALONE_COMMENT
1960         ):
1961             current_line.append(Leaf(token.COMMA, ","))
1962         yield current_line
1963
1964
1965 @dont_increase_indentation
1966 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1967     """Split standalone comments from the rest of the line."""
1968     if not line.contains_standalone_comments(0):
1969         raise CannotSplit("Line does not have any standalone comments")
1970
1971     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1972
1973     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1974         """Append `leaf` to current line or to new line if appending impossible."""
1975         nonlocal current_line
1976         try:
1977             current_line.append_safe(leaf, preformatted=True)
1978         except ValueError as ve:
1979             yield current_line
1980
1981             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1982             current_line.append(leaf)
1983
1984     for leaf in line.leaves:
1985         yield from append_to_line(leaf)
1986
1987         for comment_after in line.comments_after(leaf):
1988             yield from append_to_line(comment_after)
1989
1990     if current_line:
1991         yield current_line
1992
1993
1994 def is_import(leaf: Leaf) -> bool:
1995     """Return True if the given leaf starts an import statement."""
1996     p = leaf.parent
1997     t = leaf.type
1998     v = leaf.value
1999     return bool(
2000         t == token.NAME
2001         and (
2002             (v == "import" and p and p.type == syms.import_name)
2003             or (v == "from" and p and p.type == syms.import_from)
2004         )
2005     )
2006
2007
2008 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2009     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2010     else.
2011
2012     Note: don't use backslashes for formatting or you'll lose your voting rights.
2013     """
2014     if not inside_brackets:
2015         spl = leaf.prefix.split("#")
2016         if "\\" not in spl[0]:
2017             nl_count = spl[-1].count("\n")
2018             if len(spl) > 1:
2019                 nl_count -= 1
2020             leaf.prefix = "\n" * nl_count
2021             return
2022
2023     leaf.prefix = ""
2024
2025
2026 def normalize_string_quotes(leaf: Leaf) -> None:
2027     """Prefer double quotes but only if it doesn't cause more escaping.
2028
2029     Adds or removes backslashes as appropriate. Doesn't parse and fix
2030     strings nested in f-strings (yet).
2031
2032     Note: Mutates its argument.
2033     """
2034     value = leaf.value.lstrip("furbFURB")
2035     if value[:3] == '"""':
2036         return
2037
2038     elif value[:3] == "'''":
2039         orig_quote = "'''"
2040         new_quote = '"""'
2041     elif value[0] == '"':
2042         orig_quote = '"'
2043         new_quote = "'"
2044     else:
2045         orig_quote = "'"
2046         new_quote = '"'
2047     first_quote_pos = leaf.value.find(orig_quote)
2048     if first_quote_pos == -1:
2049         return  # There's an internal error
2050
2051     prefix = leaf.value[:first_quote_pos]
2052     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2053     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2054     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2055     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
2056     if "r" in prefix.casefold():
2057         if unescaped_new_quote.search(body):
2058             # There's at least one unescaped new_quote in this raw string
2059             # so converting is impossible
2060             return
2061
2062         # Do not introduce or remove backslashes in raw strings
2063         new_body = body
2064     else:
2065         # remove unnecessary quotes
2066         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2067         if body != new_body:
2068             # Consider the string without unnecessary quotes as the original
2069             body = new_body
2070             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2071         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2072         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2073     if new_quote == '"""' and new_body[-1] == '"':
2074         # edge case:
2075         new_body = new_body[:-1] + '\\"'
2076     orig_escape_count = body.count("\\")
2077     new_escape_count = new_body.count("\\")
2078     if new_escape_count > orig_escape_count:
2079         return  # Do not introduce more escaping
2080
2081     if new_escape_count == orig_escape_count and orig_quote == '"':
2082         return  # Prefer double quotes
2083
2084     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2085
2086
2087 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2088     """Make existing optional parentheses invisible or create new ones.
2089
2090     Standardizes on visible parentheses for single-element tuples, and keeps
2091     existing visible parentheses for other tuples and generator expressions.
2092     """
2093     check_lpar = False
2094     for child in list(node.children):
2095         if check_lpar:
2096             if child.type == syms.atom:
2097                 if not (
2098                     is_empty_tuple(child)
2099                     or is_one_tuple(child)
2100                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2101                 ):
2102                     first = child.children[0]
2103                     last = child.children[-1]
2104                     if first.type == token.LPAR and last.type == token.RPAR:
2105                         # make parentheses invisible
2106                         first.value = ""  # type: ignore
2107                         last.value = ""  # type: ignore
2108             elif is_one_tuple(child):
2109                 # wrap child in visible parentheses
2110                 lpar = Leaf(token.LPAR, "(")
2111                 rpar = Leaf(token.RPAR, ")")
2112                 index = child.remove() or 0
2113                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2114             else:
2115                 # wrap child in invisible parentheses
2116                 lpar = Leaf(token.LPAR, "")
2117                 rpar = Leaf(token.RPAR, "")
2118                 index = child.remove() or 0
2119                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2120
2121         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2122
2123
2124 def is_empty_tuple(node: LN) -> bool:
2125     """Return True if `node` holds an empty tuple."""
2126     return (
2127         node.type == syms.atom
2128         and len(node.children) == 2
2129         and node.children[0].type == token.LPAR
2130         and node.children[1].type == token.RPAR
2131     )
2132
2133
2134 def is_one_tuple(node: LN) -> bool:
2135     """Return True if `node` holds a tuple with one element, with or without parens."""
2136     if node.type == syms.atom:
2137         if len(node.children) != 3:
2138             return False
2139
2140         lpar, gexp, rpar = node.children
2141         if not (
2142             lpar.type == token.LPAR
2143             and gexp.type == syms.testlist_gexp
2144             and rpar.type == token.RPAR
2145         ):
2146             return False
2147
2148         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2149
2150     return (
2151         node.type in IMPLICIT_TUPLE
2152         and len(node.children) == 2
2153         and node.children[1].type == token.COMMA
2154     )
2155
2156
2157 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2158     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2159
2160     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2161     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2162     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2163     generalizations (PEP 448).
2164     """
2165     if leaf.type not in STARS or not leaf.parent:
2166         return False
2167
2168     p = leaf.parent
2169     if p.type == syms.star_expr:
2170         # Star expressions are also used as assignment targets in extended
2171         # iterable unpacking (PEP 3132).  See what its parent is instead.
2172         if not p.parent:
2173             return False
2174
2175         p = p.parent
2176
2177     return p.type in within
2178
2179
2180 def max_delimiter_priority_in_atom(node: LN) -> int:
2181     if node.type != syms.atom:
2182         return 0
2183
2184     first = node.children[0]
2185     last = node.children[-1]
2186     if not (first.type == token.LPAR and last.type == token.RPAR):
2187         return 0
2188
2189     bt = BracketTracker()
2190     for c in node.children[1:-1]:
2191         if isinstance(c, Leaf):
2192             bt.mark(c)
2193         else:
2194             for leaf in c.leaves():
2195                 bt.mark(leaf)
2196     try:
2197         return bt.max_delimiter_priority()
2198
2199     except ValueError:
2200         return 0
2201
2202
2203 def ensure_visible(leaf: Leaf) -> None:
2204     """Make sure parentheses are visible.
2205
2206     They could be invisible as part of some statements (see
2207     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2208     """
2209     if leaf.type == token.LPAR:
2210         leaf.value = "("
2211     elif leaf.type == token.RPAR:
2212         leaf.value = ")"
2213
2214
2215 def is_python36(node: Node) -> bool:
2216     """Return True if the current file is using Python 3.6+ features.
2217
2218     Currently looking for:
2219     - f-strings; and
2220     - trailing commas after * or ** in function signatures.
2221     """
2222     for n in node.pre_order():
2223         if n.type == token.STRING:
2224             value_head = n.value[:2]  # type: ignore
2225             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2226                 return True
2227
2228         elif (
2229             n.type == syms.typedargslist
2230             and n.children
2231             and n.children[-1].type == token.COMMA
2232         ):
2233             for ch in n.children:
2234                 if ch.type in STARS:
2235                     return True
2236
2237     return False
2238
2239
2240 PYTHON_EXTENSIONS = {".py"}
2241 BLACKLISTED_DIRECTORIES = {
2242     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2243 }
2244
2245
2246 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2247     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2248     and have one of the PYTHON_EXTENSIONS.
2249     """
2250     for child in path.iterdir():
2251         if child.is_dir():
2252             if child.name in BLACKLISTED_DIRECTORIES:
2253                 continue
2254
2255             yield from gen_python_files_in_dir(child)
2256
2257         elif child.suffix in PYTHON_EXTENSIONS:
2258             yield child
2259
2260
2261 @dataclass
2262 class Report:
2263     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2264     check: bool = False
2265     quiet: bool = False
2266     change_count: int = 0
2267     same_count: int = 0
2268     failure_count: int = 0
2269
2270     def done(self, src: Path, changed: Changed) -> None:
2271         """Increment the counter for successful reformatting. Write out a message."""
2272         if changed is Changed.YES:
2273             reformatted = "would reformat" if self.check else "reformatted"
2274             if not self.quiet:
2275                 out(f"{reformatted} {src}")
2276             self.change_count += 1
2277         else:
2278             if not self.quiet:
2279                 if changed is Changed.NO:
2280                     msg = f"{src} already well formatted, good job."
2281                 else:
2282                     msg = f"{src} wasn't modified on disk since last run."
2283                 out(msg, bold=False)
2284             self.same_count += 1
2285
2286     def failed(self, src: Path, message: str) -> None:
2287         """Increment the counter for failed reformatting. Write out a message."""
2288         err(f"error: cannot format {src}: {message}")
2289         self.failure_count += 1
2290
2291     @property
2292     def return_code(self) -> int:
2293         """Return the exit code that the app should use.
2294
2295         This considers the current state of changed files and failures:
2296         - if there were any failures, return 123;
2297         - if any files were changed and --check is being used, return 1;
2298         - otherwise return 0.
2299         """
2300         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2301         # 126 we have special returncodes reserved by the shell.
2302         if self.failure_count:
2303             return 123
2304
2305         elif self.change_count and self.check:
2306             return 1
2307
2308         return 0
2309
2310     def __str__(self) -> str:
2311         """Render a color report of the current state.
2312
2313         Use `click.unstyle` to remove colors.
2314         """
2315         if self.check:
2316             reformatted = "would be reformatted"
2317             unchanged = "would be left unchanged"
2318             failed = "would fail to reformat"
2319         else:
2320             reformatted = "reformatted"
2321             unchanged = "left unchanged"
2322             failed = "failed to reformat"
2323         report = []
2324         if self.change_count:
2325             s = "s" if self.change_count > 1 else ""
2326             report.append(
2327                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2328             )
2329         if self.same_count:
2330             s = "s" if self.same_count > 1 else ""
2331             report.append(f"{self.same_count} file{s} {unchanged}")
2332         if self.failure_count:
2333             s = "s" if self.failure_count > 1 else ""
2334             report.append(
2335                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2336             )
2337         return ", ".join(report) + "."
2338
2339
2340 def assert_equivalent(src: str, dst: str) -> None:
2341     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2342
2343     import ast
2344     import traceback
2345
2346     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2347         """Simple visitor generating strings to compare ASTs by content."""
2348         yield f"{'  ' * depth}{node.__class__.__name__}("
2349
2350         for field in sorted(node._fields):
2351             try:
2352                 value = getattr(node, field)
2353             except AttributeError:
2354                 continue
2355
2356             yield f"{'  ' * (depth+1)}{field}="
2357
2358             if isinstance(value, list):
2359                 for item in value:
2360                     if isinstance(item, ast.AST):
2361                         yield from _v(item, depth + 2)
2362
2363             elif isinstance(value, ast.AST):
2364                 yield from _v(value, depth + 2)
2365
2366             else:
2367                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2368
2369         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2370
2371     try:
2372         src_ast = ast.parse(src)
2373     except Exception as exc:
2374         major, minor = sys.version_info[:2]
2375         raise AssertionError(
2376             f"cannot use --safe with this file; failed to parse source file "
2377             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2378             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2379         )
2380
2381     try:
2382         dst_ast = ast.parse(dst)
2383     except Exception as exc:
2384         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2385         raise AssertionError(
2386             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2387             f"Please report a bug on https://github.com/ambv/black/issues.  "
2388             f"This invalid output might be helpful: {log}"
2389         ) from None
2390
2391     src_ast_str = "\n".join(_v(src_ast))
2392     dst_ast_str = "\n".join(_v(dst_ast))
2393     if src_ast_str != dst_ast_str:
2394         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2395         raise AssertionError(
2396             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2397             f"the source.  "
2398             f"Please report a bug on https://github.com/ambv/black/issues.  "
2399             f"This diff might be helpful: {log}"
2400         ) from None
2401
2402
2403 def assert_stable(src: str, dst: str, line_length: int) -> None:
2404     """Raise AssertionError if `dst` reformats differently the second time."""
2405     newdst = format_str(dst, line_length=line_length)
2406     if dst != newdst:
2407         log = dump_to_file(
2408             diff(src, dst, "source", "first pass"),
2409             diff(dst, newdst, "first pass", "second pass"),
2410         )
2411         raise AssertionError(
2412             f"INTERNAL ERROR: Black produced different code on the second pass "
2413             f"of the formatter.  "
2414             f"Please report a bug on https://github.com/ambv/black/issues.  "
2415             f"This diff might be helpful: {log}"
2416         ) from None
2417
2418
2419 def dump_to_file(*output: str) -> str:
2420     """Dump `output` to a temporary file. Return path to the file."""
2421     import tempfile
2422
2423     with tempfile.NamedTemporaryFile(
2424         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2425     ) as f:
2426         for lines in output:
2427             f.write(lines)
2428             if lines and lines[-1] != "\n":
2429                 f.write("\n")
2430     return f.name
2431
2432
2433 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2434     """Return a unified diff string between strings `a` and `b`."""
2435     import difflib
2436
2437     a_lines = [line + "\n" for line in a.split("\n")]
2438     b_lines = [line + "\n" for line in b.split("\n")]
2439     return "".join(
2440         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2441     )
2442
2443
2444 def cancel(tasks: List[asyncio.Task]) -> None:
2445     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2446     err("Aborted!")
2447     for task in tasks:
2448         task.cancel()
2449
2450
2451 def shutdown(loop: BaseEventLoop) -> None:
2452     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2453     try:
2454         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2455         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2456         if not to_cancel:
2457             return
2458
2459         for task in to_cancel:
2460             task.cancel()
2461         loop.run_until_complete(
2462             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2463         )
2464     finally:
2465         # `concurrent.futures.Future` objects cannot be cancelled once they
2466         # are already running. There might be some when the `shutdown()` happened.
2467         # Silence their logger's spew about the event loop being closed.
2468         cf_logger = logging.getLogger("concurrent.futures")
2469         cf_logger.setLevel(logging.CRITICAL)
2470         loop.close()
2471
2472
2473 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2474     """Replace `regex` with `replacement` twice on `original`.
2475
2476     This is used by string normalization to perform replaces on
2477     overlapping matches.
2478     """
2479     return regex.sub(replacement, regex.sub(replacement, original))
2480
2481
2482 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2483 CACHE_FILE = CACHE_DIR / "cache.pickle"
2484
2485
2486 def read_cache() -> Cache:
2487     """Read the cache if it exists and is well formed.
2488
2489     If it is not well formed, the call to write_cache later should resolve the issue.
2490     """
2491     if not CACHE_FILE.exists():
2492         return {}
2493
2494     with CACHE_FILE.open("rb") as fobj:
2495         try:
2496             cache: Cache = pickle.load(fobj)
2497         except pickle.UnpicklingError:
2498             return {}
2499
2500     return cache
2501
2502
2503 def get_cache_info(path: Path) -> CacheInfo:
2504     """Return the information used to check if a file is already formatted or not."""
2505     stat = path.stat()
2506     return stat.st_mtime, stat.st_size
2507
2508
2509 def filter_cached(
2510     cache: Cache, sources: Iterable[Path]
2511 ) -> Tuple[List[Path], List[Path]]:
2512     """Split a list of paths into two.
2513
2514     The first list contains paths of files that modified on disk or are not in the
2515     cache. The other list contains paths to non-modified files.
2516     """
2517     todo, done = [], []
2518     for src in sources:
2519         src = src.resolve()
2520         if cache.get(src) != get_cache_info(src):
2521             todo.append(src)
2522         else:
2523             done.append(src)
2524     return todo, done
2525
2526
2527 def write_cache(cache: Cache, sources: List[Path]) -> None:
2528     """Update the cache file."""
2529     try:
2530         if not CACHE_DIR.exists():
2531             CACHE_DIR.mkdir(parents=True)
2532         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2533         with CACHE_FILE.open("wb") as fobj:
2534             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2535     except OSError:
2536         pass
2537
2538
2539 if __name__ == "__main__":
2540     main()