]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Skip handling signals on event loops that don't support it (#156)
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import pickle
5 from asyncio.base_events import BaseEventLoop
6 from concurrent.futures import Executor, ProcessPoolExecutor
7 from enum import Enum
8 from functools import partial, wraps
9 import keyword
10 import logging
11 from multiprocessing import Manager
12 import os
13 from pathlib import Path
14 import re
15 import tokenize
16 import signal
17 import sys
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46 __version__ = "18.4a2"
47 DEFAULT_LINE_LENGTH = 88
48 # types
49 syms = pygram.python_symbols
50 FileContent = str
51 Encoding = str
52 Depth = int
53 NodeType = int
54 LeafID = int
55 Priority = int
56 Index = int
57 LN = Union[Leaf, Node]
58 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
59 Timestamp = float
60 FileSize = int
61 CacheInfo = Tuple[Timestamp, FileSize]
62 Cache = Dict[Path, CacheInfo]
63 out = partial(click.secho, bold=True, err=True)
64 err = partial(click.secho, fg="red", err=True)
65
66
67 class NothingChanged(UserWarning):
68     """Raised by :func:`format_file` when reformatted code is the same as source."""
69
70
71 class CannotSplit(Exception):
72     """A readable split that fits the allotted line length is impossible.
73
74     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
75     :func:`delimiter_split`.
76     """
77
78
79 class FormatError(Exception):
80     """Base exception for `# fmt: on` and `# fmt: off` handling.
81
82     It holds the number of bytes of the prefix consumed before the format
83     control comment appeared.
84     """
85
86     def __init__(self, consumed: int) -> None:
87         super().__init__(consumed)
88         self.consumed = consumed
89
90     def trim_prefix(self, leaf: Leaf) -> None:
91         leaf.prefix = leaf.prefix[self.consumed:]
92
93     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
94         """Returns a new Leaf from the consumed part of the prefix."""
95         unformatted_prefix = leaf.prefix[:self.consumed]
96         return Leaf(token.NEWLINE, unformatted_prefix)
97
98
99 class FormatOn(FormatError):
100     """Found a comment like `# fmt: on` in the file."""
101
102
103 class FormatOff(FormatError):
104     """Found a comment like `# fmt: off` in the file."""
105
106
107 class WriteBack(Enum):
108     NO = 0
109     YES = 1
110     DIFF = 2
111
112
113 class Changed(Enum):
114     NO = 0
115     CACHED = 1
116     YES = 2
117
118
119 @click.command()
120 @click.option(
121     "-l",
122     "--line-length",
123     type=int,
124     default=DEFAULT_LINE_LENGTH,
125     help="How many character per line to allow.",
126     show_default=True,
127 )
128 @click.option(
129     "--check",
130     is_flag=True,
131     help=(
132         "Don't write the files back, just return the status.  Return code 0 "
133         "means nothing would change.  Return code 1 means some files would be "
134         "reformatted.  Return code 123 means there was an internal error."
135     ),
136 )
137 @click.option(
138     "--diff",
139     is_flag=True,
140     help="Don't write the files back, just output a diff for each file on stdout.",
141 )
142 @click.option(
143     "--fast/--safe",
144     is_flag=True,
145     help="If --fast given, skip temporary sanity checks. [default: --safe]",
146 )
147 @click.option(
148     "-q",
149     "--quiet",
150     is_flag=True,
151     help=(
152         "Don't emit non-error messages to stderr. Errors are still emitted, "
153         "silence those with 2>/dev/null."
154     ),
155 )
156 @click.version_option(version=__version__)
157 @click.argument(
158     "src",
159     nargs=-1,
160     type=click.Path(
161         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
162     ),
163 )
164 @click.pass_context
165 def main(
166     ctx: click.Context,
167     line_length: int,
168     check: bool,
169     diff: bool,
170     fast: bool,
171     quiet: bool,
172     src: List[str],
173 ) -> None:
174     """The uncompromising code formatter."""
175     sources: List[Path] = []
176     for s in src:
177         p = Path(s)
178         if p.is_dir():
179             sources.extend(gen_python_files_in_dir(p))
180         elif p.is_file():
181             # if a file was explicitly given, we don't care about its extension
182             sources.append(p)
183         elif s == "-":
184             sources.append(Path("-"))
185         else:
186             err(f"invalid path: {s}")
187     if check and diff:
188         exc = click.ClickException("Options --check and --diff are mutually exclusive")
189         exc.exit_code = 2
190         raise exc
191
192     if check:
193         write_back = WriteBack.NO
194     elif diff:
195         write_back = WriteBack.DIFF
196     else:
197         write_back = WriteBack.YES
198     if len(sources) == 0:
199         ctx.exit(0)
200         return
201
202     elif len(sources) == 1:
203         return_code = reformat_one(sources[0], line_length, fast, quiet, write_back)
204     else:
205         loop = asyncio.get_event_loop()
206         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
207         return_code = 1
208         try:
209             return_code = loop.run_until_complete(
210                 schedule_formatting(
211                     sources, line_length, write_back, fast, quiet, loop, executor
212                 )
213             )
214         finally:
215             shutdown(loop)
216     ctx.exit(return_code)
217
218
219 def reformat_one(
220     src: Path, line_length: int, fast: bool, quiet: bool, write_back: WriteBack
221 ) -> int:
222     """Reformat a single file under `src` without spawning child processes.
223
224     If `quiet` is True, non-error messages are not output. `line_length`,
225     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
226     """
227     report = Report(check=write_back is WriteBack.NO, quiet=quiet)
228     try:
229         changed = Changed.NO
230         if not src.is_file() and str(src) == "-":
231             if format_stdin_to_stdout(
232                 line_length=line_length, fast=fast, write_back=write_back
233             ):
234                 changed = Changed.YES
235         else:
236             cache: Cache = {}
237             if write_back != WriteBack.DIFF:
238                 cache = read_cache()
239                 src = src.resolve()
240                 if src in cache and cache[src] == get_cache_info(src):
241                     changed = Changed.CACHED
242             if (
243                 changed is not Changed.CACHED
244                 and format_file_in_place(
245                     src, line_length=line_length, fast=fast, write_back=write_back
246                 )
247             ):
248                 changed = Changed.YES
249             if write_back != WriteBack.DIFF and changed is not Changed.NO:
250                 write_cache(cache, [src])
251         report.done(src, changed)
252     except Exception as exc:
253         report.failed(src, str(exc))
254     return report.return_code
255
256
257 async def schedule_formatting(
258     sources: List[Path],
259     line_length: int,
260     write_back: WriteBack,
261     fast: bool,
262     quiet: bool,
263     loop: BaseEventLoop,
264     executor: Executor,
265 ) -> int:
266     """Run formatting of `sources` in parallel using the provided `executor`.
267
268     (Use ProcessPoolExecutors for actual parallelism.)
269
270     `line_length`, `write_back`, and `fast` options are passed to
271     :func:`format_file_in_place`.
272     """
273     report = Report(check=write_back is WriteBack.NO, quiet=quiet)
274     cache: Cache = {}
275     if write_back != WriteBack.DIFF:
276         cache = read_cache()
277         sources, cached = filter_cached(cache, sources)
278         for src in cached:
279             report.done(src, Changed.CACHED)
280     cancelled = []
281     formatted = []
282     if sources:
283         lock = None
284         if write_back == WriteBack.DIFF:
285             # For diff output, we need locks to ensure we don't interleave output
286             # from different processes.
287             manager = Manager()
288             lock = manager.Lock()
289         tasks = {
290             src: loop.run_in_executor(
291                 executor, format_file_in_place, src, line_length, fast, write_back, lock
292             )
293             for src in sources
294         }
295         _task_values = list(tasks.values())
296         try:
297             loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
298             loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
299         except NotImplementedError:
300             # There are no good alternatives for these on Windows
301             pass
302         await asyncio.wait(_task_values)
303         for src, task in tasks.items():
304             if not task.done():
305                 report.failed(src, "timed out, cancelling")
306                 task.cancel()
307                 cancelled.append(task)
308             elif task.cancelled():
309                 cancelled.append(task)
310             elif task.exception():
311                 report.failed(src, str(task.exception()))
312             else:
313                 formatted.append(src)
314                 report.done(src, Changed.YES if task.result() else Changed.NO)
315
316     if cancelled:
317         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
318     elif not quiet:
319         out("All done! ✨ 🍰 ✨")
320     if not quiet:
321         click.echo(str(report))
322
323     if write_back != WriteBack.DIFF and formatted:
324         write_cache(cache, formatted)
325
326     return report.return_code
327
328
329 def format_file_in_place(
330     src: Path,
331     line_length: int,
332     fast: bool,
333     write_back: WriteBack = WriteBack.NO,
334     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
335 ) -> bool:
336     """Format file under `src` path. Return True if changed.
337
338     If `write_back` is True, write reformatted code back to stdout.
339     `line_length` and `fast` options are passed to :func:`format_file_contents`.
340     """
341
342     with tokenize.open(src) as src_buffer:
343         src_contents = src_buffer.read()
344     try:
345         dst_contents = format_file_contents(
346             src_contents, line_length=line_length, fast=fast
347         )
348     except NothingChanged:
349         return False
350
351     if write_back == write_back.YES:
352         with open(src, "w", encoding=src_buffer.encoding) as f:
353             f.write(dst_contents)
354     elif write_back == write_back.DIFF:
355         src_name = f"{src.name}  (original)"
356         dst_name = f"{src.name}  (formatted)"
357         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
358         if lock:
359             lock.acquire()
360         try:
361             sys.stdout.write(diff_contents)
362         finally:
363             if lock:
364                 lock.release()
365     return True
366
367
368 def format_stdin_to_stdout(
369     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
370 ) -> bool:
371     """Format file on stdin. Return True if changed.
372
373     If `write_back` is True, write reformatted code back to stdout.
374     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
375     """
376     src = sys.stdin.read()
377     dst = src
378     try:
379         dst = format_file_contents(src, line_length=line_length, fast=fast)
380         return True
381
382     except NothingChanged:
383         return False
384
385     finally:
386         if write_back == WriteBack.YES:
387             sys.stdout.write(dst)
388         elif write_back == WriteBack.DIFF:
389             src_name = "<stdin>  (original)"
390             dst_name = "<stdin>  (formatted)"
391             sys.stdout.write(diff(src, dst, src_name, dst_name))
392
393
394 def format_file_contents(
395     src_contents: str, line_length: int, fast: bool
396 ) -> FileContent:
397     """Reformat contents a file and return new contents.
398
399     If `fast` is False, additionally confirm that the reformatted code is
400     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
401     `line_length` is passed to :func:`format_str`.
402     """
403     if src_contents.strip() == "":
404         raise NothingChanged
405
406     dst_contents = format_str(src_contents, line_length=line_length)
407     if src_contents == dst_contents:
408         raise NothingChanged
409
410     if not fast:
411         assert_equivalent(src_contents, dst_contents)
412         assert_stable(src_contents, dst_contents, line_length=line_length)
413     return dst_contents
414
415
416 def format_str(src_contents: str, line_length: int) -> FileContent:
417     """Reformat a string and return new contents.
418
419     `line_length` determines how many characters per line are allowed.
420     """
421     src_node = lib2to3_parse(src_contents)
422     dst_contents = ""
423     lines = LineGenerator()
424     elt = EmptyLineTracker()
425     py36 = is_python36(src_node)
426     empty_line = Line()
427     after = 0
428     for current_line in lines.visit(src_node):
429         for _ in range(after):
430             dst_contents += str(empty_line)
431         before, after = elt.maybe_empty_lines(current_line)
432         for _ in range(before):
433             dst_contents += str(empty_line)
434         for line in split_line(current_line, line_length=line_length, py36=py36):
435             dst_contents += str(line)
436     return dst_contents
437
438
439 GRAMMARS = [
440     pygram.python_grammar_no_print_statement_no_exec_statement,
441     pygram.python_grammar_no_print_statement,
442     pygram.python_grammar_no_exec_statement,
443     pygram.python_grammar,
444 ]
445
446
447 def lib2to3_parse(src_txt: str) -> Node:
448     """Given a string with source, return the lib2to3 Node."""
449     grammar = pygram.python_grammar_no_print_statement
450     if src_txt[-1] != "\n":
451         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
452         src_txt += nl
453     for grammar in GRAMMARS:
454         drv = driver.Driver(grammar, pytree.convert)
455         try:
456             result = drv.parse_string(src_txt, True)
457             break
458
459         except ParseError as pe:
460             lineno, column = pe.context[1]
461             lines = src_txt.splitlines()
462             try:
463                 faulty_line = lines[lineno - 1]
464             except IndexError:
465                 faulty_line = "<line number missing in source>"
466             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
467     else:
468         raise exc from None
469
470     if isinstance(result, Leaf):
471         result = Node(syms.file_input, [result])
472     return result
473
474
475 def lib2to3_unparse(node: Node) -> str:
476     """Given a lib2to3 node, return its string representation."""
477     code = str(node)
478     return code
479
480
481 T = TypeVar("T")
482
483
484 class Visitor(Generic[T]):
485     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
486
487     def visit(self, node: LN) -> Iterator[T]:
488         """Main method to visit `node` and its children.
489
490         It tries to find a `visit_*()` method for the given `node.type`, like
491         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
492         If no dedicated `visit_*()` method is found, chooses `visit_default()`
493         instead.
494
495         Then yields objects of type `T` from the selected visitor.
496         """
497         if node.type < 256:
498             name = token.tok_name[node.type]
499         else:
500             name = type_repr(node.type)
501         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
502
503     def visit_default(self, node: LN) -> Iterator[T]:
504         """Default `visit_*()` implementation. Recurses to children of `node`."""
505         if isinstance(node, Node):
506             for child in node.children:
507                 yield from self.visit(child)
508
509
510 @dataclass
511 class DebugVisitor(Visitor[T]):
512     tree_depth: int = 0
513
514     def visit_default(self, node: LN) -> Iterator[T]:
515         indent = " " * (2 * self.tree_depth)
516         if isinstance(node, Node):
517             _type = type_repr(node.type)
518             out(f"{indent}{_type}", fg="yellow")
519             self.tree_depth += 1
520             for child in node.children:
521                 yield from self.visit(child)
522
523             self.tree_depth -= 1
524             out(f"{indent}/{_type}", fg="yellow", bold=False)
525         else:
526             _type = token.tok_name.get(node.type, str(node.type))
527             out(f"{indent}{_type}", fg="blue", nl=False)
528             if node.prefix:
529                 # We don't have to handle prefixes for `Node` objects since
530                 # that delegates to the first child anyway.
531                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
532             out(f" {node.value!r}", fg="blue", bold=False)
533
534     @classmethod
535     def show(cls, code: str) -> None:
536         """Pretty-print the lib2to3 AST of a given string of `code`.
537
538         Convenience method for debugging.
539         """
540         v: DebugVisitor[None] = DebugVisitor()
541         list(v.visit(lib2to3_parse(code)))
542
543
544 KEYWORDS = set(keyword.kwlist)
545 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
546 FLOW_CONTROL = {"return", "raise", "break", "continue"}
547 STATEMENT = {
548     syms.if_stmt,
549     syms.while_stmt,
550     syms.for_stmt,
551     syms.try_stmt,
552     syms.except_clause,
553     syms.with_stmt,
554     syms.funcdef,
555     syms.classdef,
556 }
557 STANDALONE_COMMENT = 153
558 LOGIC_OPERATORS = {"and", "or"}
559 COMPARATORS = {
560     token.LESS,
561     token.GREATER,
562     token.EQEQUAL,
563     token.NOTEQUAL,
564     token.LESSEQUAL,
565     token.GREATEREQUAL,
566 }
567 MATH_OPERATORS = {
568     token.PLUS,
569     token.MINUS,
570     token.STAR,
571     token.SLASH,
572     token.VBAR,
573     token.AMPER,
574     token.PERCENT,
575     token.CIRCUMFLEX,
576     token.TILDE,
577     token.LEFTSHIFT,
578     token.RIGHTSHIFT,
579     token.DOUBLESTAR,
580     token.DOUBLESLASH,
581 }
582 STARS = {token.STAR, token.DOUBLESTAR}
583 VARARGS_PARENTS = {
584     syms.arglist,
585     syms.argument,  # double star in arglist
586     syms.trailer,  # single argument to call
587     syms.typedargslist,
588     syms.varargslist,  # lambdas
589 }
590 UNPACKING_PARENTS = {
591     syms.atom,  # single element of a list or set literal
592     syms.dictsetmaker,
593     syms.listmaker,
594     syms.testlist_gexp,
595 }
596 COMPREHENSION_PRIORITY = 20
597 COMMA_PRIORITY = 10
598 LOGIC_PRIORITY = 5
599 STRING_PRIORITY = 4
600 COMPARATOR_PRIORITY = 3
601 MATH_PRIORITY = 1
602
603
604 @dataclass
605 class BracketTracker:
606     """Keeps track of brackets on a line."""
607
608     depth: int = 0
609     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
610     delimiters: Dict[LeafID, Priority] = Factory(dict)
611     previous: Optional[Leaf] = None
612
613     def mark(self, leaf: Leaf) -> None:
614         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
615
616         All leaves receive an int `bracket_depth` field that stores how deep
617         within brackets a given leaf is. 0 means there are no enclosing brackets
618         that started on this line.
619
620         If a leaf is itself a closing bracket, it receives an `opening_bracket`
621         field that it forms a pair with. This is a one-directional link to
622         avoid reference cycles.
623
624         If a leaf is a delimiter (a token on which Black can split the line if
625         needed) and it's on depth 0, its `id()` is stored in the tracker's
626         `delimiters` field.
627         """
628         if leaf.type == token.COMMENT:
629             return
630
631         if leaf.type in CLOSING_BRACKETS:
632             self.depth -= 1
633             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
634             leaf.opening_bracket = opening_bracket
635         leaf.bracket_depth = self.depth
636         if self.depth == 0:
637             delim = is_split_before_delimiter(leaf, self.previous)
638             if delim and self.previous is not None:
639                 self.delimiters[id(self.previous)] = delim
640             else:
641                 delim = is_split_after_delimiter(leaf, self.previous)
642                 if delim:
643                     self.delimiters[id(leaf)] = delim
644         if leaf.type in OPENING_BRACKETS:
645             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
646             self.depth += 1
647         self.previous = leaf
648
649     def any_open_brackets(self) -> bool:
650         """Return True if there is an yet unmatched open bracket on the line."""
651         return bool(self.bracket_match)
652
653     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
654         """Return the highest priority of a delimiter found on the line.
655
656         Values are consistent with what `is_split_*_delimiter()` return.
657         Raises ValueError on no delimiters.
658         """
659         return max(v for k, v in self.delimiters.items() if k not in exclude)
660
661
662 @dataclass
663 class Line:
664     """Holds leaves and comments. Can be printed with `str(line)`."""
665
666     depth: int = 0
667     leaves: List[Leaf] = Factory(list)
668     comments: List[Tuple[Index, Leaf]] = Factory(list)
669     bracket_tracker: BracketTracker = Factory(BracketTracker)
670     inside_brackets: bool = False
671     has_for: bool = False
672     _for_loop_variable: bool = False
673
674     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
675         """Add a new `leaf` to the end of the line.
676
677         Unless `preformatted` is True, the `leaf` will receive a new consistent
678         whitespace prefix and metadata applied by :class:`BracketTracker`.
679         Trailing commas are maybe removed, unpacked for loop variables are
680         demoted from being delimiters.
681
682         Inline comments are put aside.
683         """
684         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
685         if not has_value:
686             return
687
688         if self.leaves and not preformatted:
689             # Note: at this point leaf.prefix should be empty except for
690             # imports, for which we only preserve newlines.
691             leaf.prefix += whitespace(leaf)
692         if self.inside_brackets or not preformatted:
693             self.maybe_decrement_after_for_loop_variable(leaf)
694             self.bracket_tracker.mark(leaf)
695             self.maybe_remove_trailing_comma(leaf)
696             self.maybe_increment_for_loop_variable(leaf)
697
698         if not self.append_comment(leaf):
699             self.leaves.append(leaf)
700
701     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
702         """Like :func:`append()` but disallow invalid standalone comment structure.
703
704         Raises ValueError when any `leaf` is appended after a standalone comment
705         or when a standalone comment is not the first leaf on the line.
706         """
707         if self.bracket_tracker.depth == 0:
708             if self.is_comment:
709                 raise ValueError("cannot append to standalone comments")
710
711             if self.leaves and leaf.type == STANDALONE_COMMENT:
712                 raise ValueError(
713                     "cannot append standalone comments to a populated line"
714                 )
715
716         self.append(leaf, preformatted=preformatted)
717
718     @property
719     def is_comment(self) -> bool:
720         """Is this line a standalone comment?"""
721         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
722
723     @property
724     def is_decorator(self) -> bool:
725         """Is this line a decorator?"""
726         return bool(self) and self.leaves[0].type == token.AT
727
728     @property
729     def is_import(self) -> bool:
730         """Is this an import line?"""
731         return bool(self) and is_import(self.leaves[0])
732
733     @property
734     def is_class(self) -> bool:
735         """Is this line a class definition?"""
736         return (
737             bool(self)
738             and self.leaves[0].type == token.NAME
739             and self.leaves[0].value == "class"
740         )
741
742     @property
743     def is_def(self) -> bool:
744         """Is this a function definition? (Also returns True for async defs.)"""
745         try:
746             first_leaf = self.leaves[0]
747         except IndexError:
748             return False
749
750         try:
751             second_leaf: Optional[Leaf] = self.leaves[1]
752         except IndexError:
753             second_leaf = None
754         return (
755             (first_leaf.type == token.NAME and first_leaf.value == "def")
756             or (
757                 first_leaf.type == token.ASYNC
758                 and second_leaf is not None
759                 and second_leaf.type == token.NAME
760                 and second_leaf.value == "def"
761             )
762         )
763
764     @property
765     def is_flow_control(self) -> bool:
766         """Is this line a flow control statement?
767
768         Those are `return`, `raise`, `break`, and `continue`.
769         """
770         return (
771             bool(self)
772             and self.leaves[0].type == token.NAME
773             and self.leaves[0].value in FLOW_CONTROL
774         )
775
776     @property
777     def is_yield(self) -> bool:
778         """Is this line a yield statement?"""
779         return (
780             bool(self)
781             and self.leaves[0].type == token.NAME
782             and self.leaves[0].value == "yield"
783         )
784
785     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
786         """If so, needs to be split before emitting."""
787         for leaf in self.leaves:
788             if leaf.type == STANDALONE_COMMENT:
789                 if leaf.bracket_depth <= depth_limit:
790                     return True
791
792         return False
793
794     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
795         """Remove trailing comma if there is one and it's safe."""
796         if not (
797             self.leaves
798             and self.leaves[-1].type == token.COMMA
799             and closing.type in CLOSING_BRACKETS
800         ):
801             return False
802
803         if closing.type == token.RBRACE:
804             self.remove_trailing_comma()
805             return True
806
807         if closing.type == token.RSQB:
808             comma = self.leaves[-1]
809             if comma.parent and comma.parent.type == syms.listmaker:
810                 self.remove_trailing_comma()
811                 return True
812
813         # For parens let's check if it's safe to remove the comma.  If the
814         # trailing one is the only one, we might mistakenly change a tuple
815         # into a different type by removing the comma.
816         depth = closing.bracket_depth + 1
817         commas = 0
818         opening = closing.opening_bracket
819         for _opening_index, leaf in enumerate(self.leaves):
820             if leaf is opening:
821                 break
822
823         else:
824             return False
825
826         for leaf in self.leaves[_opening_index + 1:]:
827             if leaf is closing:
828                 break
829
830             bracket_depth = leaf.bracket_depth
831             if bracket_depth == depth and leaf.type == token.COMMA:
832                 commas += 1
833                 if leaf.parent and leaf.parent.type == syms.arglist:
834                     commas += 1
835                     break
836
837         if commas > 1:
838             self.remove_trailing_comma()
839             return True
840
841         return False
842
843     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
844         """In a for loop, or comprehension, the variables are often unpacks.
845
846         To avoid splitting on the comma in this situation, increase the depth of
847         tokens between `for` and `in`.
848         """
849         if leaf.type == token.NAME and leaf.value == "for":
850             self.has_for = True
851             self.bracket_tracker.depth += 1
852             self._for_loop_variable = True
853             return True
854
855         return False
856
857     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
858         """See `maybe_increment_for_loop_variable` above for explanation."""
859         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
860             self.bracket_tracker.depth -= 1
861             self._for_loop_variable = False
862             return True
863
864         return False
865
866     def append_comment(self, comment: Leaf) -> bool:
867         """Add an inline or standalone comment to the line."""
868         if (
869             comment.type == STANDALONE_COMMENT
870             and self.bracket_tracker.any_open_brackets()
871         ):
872             comment.prefix = ""
873             return False
874
875         if comment.type != token.COMMENT:
876             return False
877
878         after = len(self.leaves) - 1
879         if after == -1:
880             comment.type = STANDALONE_COMMENT
881             comment.prefix = ""
882             return False
883
884         else:
885             self.comments.append((after, comment))
886             return True
887
888     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
889         """Generate comments that should appear directly after `leaf`."""
890         for _leaf_index, _leaf in enumerate(self.leaves):
891             if leaf is _leaf:
892                 break
893
894         else:
895             return
896
897         for index, comment_after in self.comments:
898             if _leaf_index == index:
899                 yield comment_after
900
901     def remove_trailing_comma(self) -> None:
902         """Remove the trailing comma and moves the comments attached to it."""
903         comma_index = len(self.leaves) - 1
904         for i in range(len(self.comments)):
905             comment_index, comment = self.comments[i]
906             if comment_index == comma_index:
907                 self.comments[i] = (comma_index - 1, comment)
908         self.leaves.pop()
909
910     def __str__(self) -> str:
911         """Render the line."""
912         if not self:
913             return "\n"
914
915         indent = "    " * self.depth
916         leaves = iter(self.leaves)
917         first = next(leaves)
918         res = f"{first.prefix}{indent}{first.value}"
919         for leaf in leaves:
920             res += str(leaf)
921         for _, comment in self.comments:
922             res += str(comment)
923         return res + "\n"
924
925     def __bool__(self) -> bool:
926         """Return True if the line has leaves or comments."""
927         return bool(self.leaves or self.comments)
928
929
930 class UnformattedLines(Line):
931     """Just like :class:`Line` but stores lines which aren't reformatted."""
932
933     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
934         """Just add a new `leaf` to the end of the lines.
935
936         The `preformatted` argument is ignored.
937
938         Keeps track of indentation `depth`, which is useful when the user
939         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
940         """
941         try:
942             list(generate_comments(leaf))
943         except FormatOn as f_on:
944             self.leaves.append(f_on.leaf_from_consumed(leaf))
945             raise
946
947         self.leaves.append(leaf)
948         if leaf.type == token.INDENT:
949             self.depth += 1
950         elif leaf.type == token.DEDENT:
951             self.depth -= 1
952
953     def __str__(self) -> str:
954         """Render unformatted lines from leaves which were added with `append()`.
955
956         `depth` is not used for indentation in this case.
957         """
958         if not self:
959             return "\n"
960
961         res = ""
962         for leaf in self.leaves:
963             res += str(leaf)
964         return res
965
966     def append_comment(self, comment: Leaf) -> bool:
967         """Not implemented in this class. Raises `NotImplementedError`."""
968         raise NotImplementedError("Unformatted lines don't store comments separately.")
969
970     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
971         """Does nothing and returns False."""
972         return False
973
974     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
975         """Does nothing and returns False."""
976         return False
977
978
979 @dataclass
980 class EmptyLineTracker:
981     """Provides a stateful method that returns the number of potential extra
982     empty lines needed before and after the currently processed line.
983
984     Note: this tracker works on lines that haven't been split yet.  It assumes
985     the prefix of the first leaf consists of optional newlines.  Those newlines
986     are consumed by `maybe_empty_lines()` and included in the computation.
987     """
988     previous_line: Optional[Line] = None
989     previous_after: int = 0
990     previous_defs: List[int] = Factory(list)
991
992     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
993         """Return the number of extra empty lines before and after the `current_line`.
994
995         This is for separating `def`, `async def` and `class` with extra empty
996         lines (two on module-level), as well as providing an extra empty line
997         after flow control keywords to make them more prominent.
998         """
999         if isinstance(current_line, UnformattedLines):
1000             return 0, 0
1001
1002         before, after = self._maybe_empty_lines(current_line)
1003         before -= self.previous_after
1004         self.previous_after = after
1005         self.previous_line = current_line
1006         return before, after
1007
1008     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1009         max_allowed = 1
1010         if current_line.depth == 0:
1011             max_allowed = 2
1012         if current_line.leaves:
1013             # Consume the first leaf's extra newlines.
1014             first_leaf = current_line.leaves[0]
1015             before = first_leaf.prefix.count("\n")
1016             before = min(before, max_allowed)
1017             first_leaf.prefix = ""
1018         else:
1019             before = 0
1020         depth = current_line.depth
1021         while self.previous_defs and self.previous_defs[-1] >= depth:
1022             self.previous_defs.pop()
1023             before = 1 if depth else 2
1024         is_decorator = current_line.is_decorator
1025         if is_decorator or current_line.is_def or current_line.is_class:
1026             if not is_decorator:
1027                 self.previous_defs.append(depth)
1028             if self.previous_line is None:
1029                 # Don't insert empty lines before the first line in the file.
1030                 return 0, 0
1031
1032             if self.previous_line and self.previous_line.is_decorator:
1033                 # Don't insert empty lines between decorators.
1034                 return 0, 0
1035
1036             newlines = 2
1037             if current_line.depth:
1038                 newlines -= 1
1039             return newlines, 0
1040
1041         if current_line.is_flow_control:
1042             return before, 1
1043
1044         if (
1045             self.previous_line
1046             and self.previous_line.is_import
1047             and not current_line.is_import
1048             and depth == self.previous_line.depth
1049         ):
1050             return (before or 1), 0
1051
1052         if (
1053             self.previous_line
1054             and self.previous_line.is_yield
1055             and (not current_line.is_yield or depth != self.previous_line.depth)
1056         ):
1057             return (before or 1), 0
1058
1059         return before, 0
1060
1061
1062 @dataclass
1063 class LineGenerator(Visitor[Line]):
1064     """Generates reformatted Line objects.  Empty lines are not emitted.
1065
1066     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1067     in ways that will no longer stringify to valid Python code on the tree.
1068     """
1069     current_line: Line = Factory(Line)
1070
1071     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1072         """Generate a line.
1073
1074         If the line is empty, only emit if it makes sense.
1075         If the line is too long, split it first and then generate.
1076
1077         If any lines were generated, set up a new current_line.
1078         """
1079         if not self.current_line:
1080             if self.current_line.__class__ == type:
1081                 self.current_line.depth += indent
1082             else:
1083                 self.current_line = type(depth=self.current_line.depth + indent)
1084             return  # Line is empty, don't emit. Creating a new one unnecessary.
1085
1086         complete_line = self.current_line
1087         self.current_line = type(depth=complete_line.depth + indent)
1088         yield complete_line
1089
1090     def visit(self, node: LN) -> Iterator[Line]:
1091         """Main method to visit `node` and its children.
1092
1093         Yields :class:`Line` objects.
1094         """
1095         if isinstance(self.current_line, UnformattedLines):
1096             # File contained `# fmt: off`
1097             yield from self.visit_unformatted(node)
1098
1099         else:
1100             yield from super().visit(node)
1101
1102     def visit_default(self, node: LN) -> Iterator[Line]:
1103         """Default `visit_*()` implementation. Recurses to children of `node`."""
1104         if isinstance(node, Leaf):
1105             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1106             try:
1107                 for comment in generate_comments(node):
1108                     if any_open_brackets:
1109                         # any comment within brackets is subject to splitting
1110                         self.current_line.append(comment)
1111                     elif comment.type == token.COMMENT:
1112                         # regular trailing comment
1113                         self.current_line.append(comment)
1114                         yield from self.line()
1115
1116                     else:
1117                         # regular standalone comment
1118                         yield from self.line()
1119
1120                         self.current_line.append(comment)
1121                         yield from self.line()
1122
1123             except FormatOff as f_off:
1124                 f_off.trim_prefix(node)
1125                 yield from self.line(type=UnformattedLines)
1126                 yield from self.visit(node)
1127
1128             except FormatOn as f_on:
1129                 # This only happens here if somebody says "fmt: on" multiple
1130                 # times in a row.
1131                 f_on.trim_prefix(node)
1132                 yield from self.visit_default(node)
1133
1134             else:
1135                 normalize_prefix(node, inside_brackets=any_open_brackets)
1136                 if node.type == token.STRING:
1137                     normalize_string_quotes(node)
1138                 if node.type not in WHITESPACE:
1139                     self.current_line.append(node)
1140         yield from super().visit_default(node)
1141
1142     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1143         """Increase indentation level, maybe yield a line."""
1144         # In blib2to3 INDENT never holds comments.
1145         yield from self.line(+1)
1146         yield from self.visit_default(node)
1147
1148     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1149         """Decrease indentation level, maybe yield a line."""
1150         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1151         yield from self.line(-1)
1152
1153     def visit_stmt(
1154         self, node: Node, keywords: Set[str], parens: Set[str]
1155     ) -> Iterator[Line]:
1156         """Visit a statement.
1157
1158         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1159         `def`, `with`, `class`, and `assert`.
1160
1161         The relevant Python language `keywords` for a given statement will be
1162         NAME leaves within it. This methods puts those on a separate line.
1163
1164         `parens` holds pairs of nodes where invisible parentheses should be put.
1165         Keys hold nodes after which opening parentheses should be put, values
1166         hold nodes before which closing parentheses should be put.
1167         """
1168         normalize_invisible_parens(node, parens_after=parens)
1169         for child in node.children:
1170             if child.type == token.NAME and child.value in keywords:  # type: ignore
1171                 yield from self.line()
1172
1173             yield from self.visit(child)
1174
1175     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1176         """Visit a statement without nested statements."""
1177         is_suite_like = node.parent and node.parent.type in STATEMENT
1178         if is_suite_like:
1179             yield from self.line(+1)
1180             yield from self.visit_default(node)
1181             yield from self.line(-1)
1182
1183         else:
1184             yield from self.line()
1185             yield from self.visit_default(node)
1186
1187     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1188         """Visit `async def`, `async for`, `async with`."""
1189         yield from self.line()
1190
1191         children = iter(node.children)
1192         for child in children:
1193             yield from self.visit(child)
1194
1195             if child.type == token.ASYNC:
1196                 break
1197
1198         internal_stmt = next(children)
1199         for child in internal_stmt.children:
1200             yield from self.visit(child)
1201
1202     def visit_decorators(self, node: Node) -> Iterator[Line]:
1203         """Visit decorators."""
1204         for child in node.children:
1205             yield from self.line()
1206             yield from self.visit(child)
1207
1208     def visit_import_from(self, node: Node) -> Iterator[Line]:
1209         """Visit import_from and maybe put invisible parentheses.
1210
1211         This is separate from `visit_stmt` because import statements don't
1212         support arbitrary atoms and thus handling of parentheses is custom.
1213         """
1214         check_lpar = False
1215         for index, child in enumerate(node.children):
1216             if check_lpar:
1217                 if child.type == token.LPAR:
1218                     # make parentheses invisible
1219                     child.value = ""  # type: ignore
1220                     node.children[-1].value = ""  # type: ignore
1221                 else:
1222                     # insert invisible parentheses
1223                     node.insert_child(index, Leaf(token.LPAR, ""))
1224                     node.append_child(Leaf(token.RPAR, ""))
1225                 break
1226
1227             check_lpar = (
1228                 child.type == token.NAME and child.value == "import"  # type: ignore
1229             )
1230
1231         for child in node.children:
1232             yield from self.visit(child)
1233
1234     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1235         """Remove a semicolon and put the other statement on a separate line."""
1236         yield from self.line()
1237
1238     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1239         """End of file. Process outstanding comments and end with a newline."""
1240         yield from self.visit_default(leaf)
1241         yield from self.line()
1242
1243     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1244         """Used when file contained a `# fmt: off`."""
1245         if isinstance(node, Node):
1246             for child in node.children:
1247                 yield from self.visit(child)
1248
1249         else:
1250             try:
1251                 self.current_line.append(node)
1252             except FormatOn as f_on:
1253                 f_on.trim_prefix(node)
1254                 yield from self.line()
1255                 yield from self.visit(node)
1256
1257             if node.type == token.ENDMARKER:
1258                 # somebody decided not to put a final `# fmt: on`
1259                 yield from self.line()
1260
1261     def __attrs_post_init__(self) -> None:
1262         """You are in a twisty little maze of passages."""
1263         v = self.visit_stmt
1264         Ø: Set[str] = set()
1265         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1266         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1267         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1268         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1269         self.visit_try_stmt = partial(
1270             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1271         )
1272         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1273         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1274         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1275         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1276         self.visit_async_funcdef = self.visit_async_stmt
1277         self.visit_decorated = self.visit_decorators
1278
1279
1280 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1281 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1282 OPENING_BRACKETS = set(BRACKET.keys())
1283 CLOSING_BRACKETS = set(BRACKET.values())
1284 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1285 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1286
1287
1288 def whitespace(leaf: Leaf) -> str:  # noqa C901
1289     """Return whitespace prefix if needed for the given `leaf`."""
1290     NO = ""
1291     SPACE = " "
1292     DOUBLESPACE = "  "
1293     t = leaf.type
1294     p = leaf.parent
1295     v = leaf.value
1296     if t in ALWAYS_NO_SPACE:
1297         return NO
1298
1299     if t == token.COMMENT:
1300         return DOUBLESPACE
1301
1302     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1303     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1304         return NO
1305
1306     prev = leaf.prev_sibling
1307     if not prev:
1308         prevp = preceding_leaf(p)
1309         if not prevp or prevp.type in OPENING_BRACKETS:
1310             return NO
1311
1312         if t == token.COLON:
1313             return SPACE if prevp.type == token.COMMA else NO
1314
1315         if prevp.type == token.EQUAL:
1316             if prevp.parent:
1317                 if prevp.parent.type in {
1318                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1319                 }:
1320                     return NO
1321
1322                 elif prevp.parent.type == syms.typedargslist:
1323                     # A bit hacky: if the equal sign has whitespace, it means we
1324                     # previously found it's a typed argument.  So, we're using
1325                     # that, too.
1326                     return prevp.prefix
1327
1328         elif prevp.type in STARS:
1329             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1330                 return NO
1331
1332         elif prevp.type == token.COLON:
1333             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1334                 return NO
1335
1336         elif (
1337             prevp.parent
1338             and prevp.parent.type == syms.factor
1339             and prevp.type in MATH_OPERATORS
1340         ):
1341             return NO
1342
1343         elif (
1344             prevp.type == token.RIGHTSHIFT
1345             and prevp.parent
1346             and prevp.parent.type == syms.shift_expr
1347             and prevp.prev_sibling
1348             and prevp.prev_sibling.type == token.NAME
1349             and prevp.prev_sibling.value == "print"  # type: ignore
1350         ):
1351             # Python 2 print chevron
1352             return NO
1353
1354     elif prev.type in OPENING_BRACKETS:
1355         return NO
1356
1357     if p.type in {syms.parameters, syms.arglist}:
1358         # untyped function signatures or calls
1359         if not prev or prev.type != token.COMMA:
1360             return NO
1361
1362     elif p.type == syms.varargslist:
1363         # lambdas
1364         if prev and prev.type != token.COMMA:
1365             return NO
1366
1367     elif p.type == syms.typedargslist:
1368         # typed function signatures
1369         if not prev:
1370             return NO
1371
1372         if t == token.EQUAL:
1373             if prev.type != syms.tname:
1374                 return NO
1375
1376         elif prev.type == token.EQUAL:
1377             # A bit hacky: if the equal sign has whitespace, it means we
1378             # previously found it's a typed argument.  So, we're using that, too.
1379             return prev.prefix
1380
1381         elif prev.type != token.COMMA:
1382             return NO
1383
1384     elif p.type == syms.tname:
1385         # type names
1386         if not prev:
1387             prevp = preceding_leaf(p)
1388             if not prevp or prevp.type != token.COMMA:
1389                 return NO
1390
1391     elif p.type == syms.trailer:
1392         # attributes and calls
1393         if t == token.LPAR or t == token.RPAR:
1394             return NO
1395
1396         if not prev:
1397             if t == token.DOT:
1398                 prevp = preceding_leaf(p)
1399                 if not prevp or prevp.type != token.NUMBER:
1400                     return NO
1401
1402             elif t == token.LSQB:
1403                 return NO
1404
1405         elif prev.type != token.COMMA:
1406             return NO
1407
1408     elif p.type == syms.argument:
1409         # single argument
1410         if t == token.EQUAL:
1411             return NO
1412
1413         if not prev:
1414             prevp = preceding_leaf(p)
1415             if not prevp or prevp.type == token.LPAR:
1416                 return NO
1417
1418         elif prev.type in {token.EQUAL} | STARS:
1419             return NO
1420
1421     elif p.type == syms.decorator:
1422         # decorators
1423         return NO
1424
1425     elif p.type == syms.dotted_name:
1426         if prev:
1427             return NO
1428
1429         prevp = preceding_leaf(p)
1430         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1431             return NO
1432
1433     elif p.type == syms.classdef:
1434         if t == token.LPAR:
1435             return NO
1436
1437         if prev and prev.type == token.LPAR:
1438             return NO
1439
1440     elif p.type == syms.subscript:
1441         # indexing
1442         if not prev:
1443             assert p.parent is not None, "subscripts are always parented"
1444             if p.parent.type == syms.subscriptlist:
1445                 return SPACE
1446
1447             return NO
1448
1449         else:
1450             return NO
1451
1452     elif p.type == syms.atom:
1453         if prev and t == token.DOT:
1454             # dots, but not the first one.
1455             return NO
1456
1457     elif p.type == syms.dictsetmaker:
1458         # dict unpacking
1459         if prev and prev.type == token.DOUBLESTAR:
1460             return NO
1461
1462     elif p.type in {syms.factor, syms.star_expr}:
1463         # unary ops
1464         if not prev:
1465             prevp = preceding_leaf(p)
1466             if not prevp or prevp.type in OPENING_BRACKETS:
1467                 return NO
1468
1469             prevp_parent = prevp.parent
1470             assert prevp_parent is not None
1471             if (
1472                 prevp.type == token.COLON
1473                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1474             ):
1475                 return NO
1476
1477             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1478                 return NO
1479
1480         elif t == token.NAME or t == token.NUMBER:
1481             return NO
1482
1483     elif p.type == syms.import_from:
1484         if t == token.DOT:
1485             if prev and prev.type == token.DOT:
1486                 return NO
1487
1488         elif t == token.NAME:
1489             if v == "import":
1490                 return SPACE
1491
1492             if prev and prev.type == token.DOT:
1493                 return NO
1494
1495     elif p.type == syms.sliceop:
1496         return NO
1497
1498     return SPACE
1499
1500
1501 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1502     """Return the first leaf that precedes `node`, if any."""
1503     while node:
1504         res = node.prev_sibling
1505         if res:
1506             if isinstance(res, Leaf):
1507                 return res
1508
1509             try:
1510                 return list(res.leaves())[-1]
1511
1512             except IndexError:
1513                 return None
1514
1515         node = node.parent
1516     return None
1517
1518
1519 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1520     """Return the priority of the `leaf` delimiter, given a line break after it.
1521
1522     The delimiter priorities returned here are from those delimiters that would
1523     cause a line break after themselves.
1524
1525     Higher numbers are higher priority.
1526     """
1527     if leaf.type == token.COMMA:
1528         return COMMA_PRIORITY
1529
1530     return 0
1531
1532
1533 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1534     """Return the priority of the `leaf` delimiter, given a line before after it.
1535
1536     The delimiter priorities returned here are from those delimiters that would
1537     cause a line break before themselves.
1538
1539     Higher numbers are higher priority.
1540     """
1541     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1542         # * and ** might also be MATH_OPERATORS but in this case they are not.
1543         # Don't treat them as a delimiter.
1544         return 0
1545
1546     if (
1547         leaf.type in MATH_OPERATORS
1548         and leaf.parent
1549         and leaf.parent.type not in {syms.factor, syms.star_expr}
1550     ):
1551         return MATH_PRIORITY
1552
1553     if leaf.type in COMPARATORS:
1554         return COMPARATOR_PRIORITY
1555
1556     if (
1557         leaf.type == token.STRING
1558         and previous is not None
1559         and previous.type == token.STRING
1560     ):
1561         return STRING_PRIORITY
1562
1563     if (
1564         leaf.type == token.NAME
1565         and leaf.value == "for"
1566         and leaf.parent
1567         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1568     ):
1569         return COMPREHENSION_PRIORITY
1570
1571     if (
1572         leaf.type == token.NAME
1573         and leaf.value == "if"
1574         and leaf.parent
1575         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1576     ):
1577         return COMPREHENSION_PRIORITY
1578
1579     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1580         return LOGIC_PRIORITY
1581
1582     return 0
1583
1584
1585 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1586     """Clean the prefix of the `leaf` and generate comments from it, if any.
1587
1588     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1589     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1590     move because it does away with modifying the grammar to include all the
1591     possible places in which comments can be placed.
1592
1593     The sad consequence for us though is that comments don't "belong" anywhere.
1594     This is why this function generates simple parentless Leaf objects for
1595     comments.  We simply don't know what the correct parent should be.
1596
1597     No matter though, we can live without this.  We really only need to
1598     differentiate between inline and standalone comments.  The latter don't
1599     share the line with any code.
1600
1601     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1602     are emitted with a fake STANDALONE_COMMENT token identifier.
1603     """
1604     p = leaf.prefix
1605     if not p:
1606         return
1607
1608     if "#" not in p:
1609         return
1610
1611     consumed = 0
1612     nlines = 0
1613     for index, line in enumerate(p.split("\n")):
1614         consumed += len(line) + 1  # adding the length of the split '\n'
1615         line = line.lstrip()
1616         if not line:
1617             nlines += 1
1618         if not line.startswith("#"):
1619             continue
1620
1621         if index == 0 and leaf.type != token.ENDMARKER:
1622             comment_type = token.COMMENT  # simple trailing comment
1623         else:
1624             comment_type = STANDALONE_COMMENT
1625         comment = make_comment(line)
1626         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1627
1628         if comment in {"# fmt: on", "# yapf: enable"}:
1629             raise FormatOn(consumed)
1630
1631         if comment in {"# fmt: off", "# yapf: disable"}:
1632             if comment_type == STANDALONE_COMMENT:
1633                 raise FormatOff(consumed)
1634
1635             prev = preceding_leaf(leaf)
1636             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1637                 raise FormatOff(consumed)
1638
1639         nlines = 0
1640
1641
1642 def make_comment(content: str) -> str:
1643     """Return a consistently formatted comment from the given `content` string.
1644
1645     All comments (except for "##", "#!", "#:") should have a single space between
1646     the hash sign and the content.
1647
1648     If `content` didn't start with a hash sign, one is provided.
1649     """
1650     content = content.rstrip()
1651     if not content:
1652         return "#"
1653
1654     if content[0] == "#":
1655         content = content[1:]
1656     if content and content[0] not in " !:#":
1657         content = " " + content
1658     return "#" + content
1659
1660
1661 def split_line(
1662     line: Line, line_length: int, inner: bool = False, py36: bool = False
1663 ) -> Iterator[Line]:
1664     """Split a `line` into potentially many lines.
1665
1666     They should fit in the allotted `line_length` but might not be able to.
1667     `inner` signifies that there were a pair of brackets somewhere around the
1668     current `line`, possibly transitively. This means we can fallback to splitting
1669     by delimiters if the LHS/RHS don't yield any results.
1670
1671     If `py36` is True, splitting may generate syntax that is only compatible
1672     with Python 3.6 and later.
1673     """
1674     if isinstance(line, UnformattedLines) or line.is_comment:
1675         yield line
1676         return
1677
1678     line_str = str(line).strip("\n")
1679     if (
1680         len(line_str) <= line_length
1681         and "\n" not in line_str  # multiline strings
1682         and not line.contains_standalone_comments()
1683     ):
1684         yield line
1685         return
1686
1687     split_funcs: List[SplitFunc]
1688     if line.is_def:
1689         split_funcs = [left_hand_split]
1690     elif line.inside_brackets:
1691         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1692     else:
1693         split_funcs = [right_hand_split]
1694     for split_func in split_funcs:
1695         # We are accumulating lines in `result` because we might want to abort
1696         # mission and return the original line in the end, or attempt a different
1697         # split altogether.
1698         result: List[Line] = []
1699         try:
1700             for l in split_func(line, py36):
1701                 if str(l).strip("\n") == line_str:
1702                     raise CannotSplit("Split function returned an unchanged result")
1703
1704                 result.extend(
1705                     split_line(l, line_length=line_length, inner=True, py36=py36)
1706                 )
1707         except CannotSplit as cs:
1708             continue
1709
1710         else:
1711             yield from result
1712             break
1713
1714     else:
1715         yield line
1716
1717
1718 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1719     """Split line into many lines, starting with the first matching bracket pair.
1720
1721     Note: this usually looks weird, only use this for function definitions.
1722     Prefer RHS otherwise.
1723     """
1724     head = Line(depth=line.depth)
1725     body = Line(depth=line.depth + 1, inside_brackets=True)
1726     tail = Line(depth=line.depth)
1727     tail_leaves: List[Leaf] = []
1728     body_leaves: List[Leaf] = []
1729     head_leaves: List[Leaf] = []
1730     current_leaves = head_leaves
1731     matching_bracket = None
1732     for leaf in line.leaves:
1733         if (
1734             current_leaves is body_leaves
1735             and leaf.type in CLOSING_BRACKETS
1736             and leaf.opening_bracket is matching_bracket
1737         ):
1738             current_leaves = tail_leaves if body_leaves else head_leaves
1739         current_leaves.append(leaf)
1740         if current_leaves is head_leaves:
1741             if leaf.type in OPENING_BRACKETS:
1742                 matching_bracket = leaf
1743                 current_leaves = body_leaves
1744     # Since body is a new indent level, remove spurious leading whitespace.
1745     if body_leaves:
1746         normalize_prefix(body_leaves[0], inside_brackets=True)
1747     # Build the new lines.
1748     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1749         for leaf in leaves:
1750             result.append(leaf, preformatted=True)
1751             for comment_after in line.comments_after(leaf):
1752                 result.append(comment_after, preformatted=True)
1753     bracket_split_succeeded_or_raise(head, body, tail)
1754     for result in (head, body, tail):
1755         if result:
1756             yield result
1757
1758
1759 def right_hand_split(
1760     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1761 ) -> Iterator[Line]:
1762     """Split line into many lines, starting with the last matching bracket pair."""
1763     head = Line(depth=line.depth)
1764     body = Line(depth=line.depth + 1, inside_brackets=True)
1765     tail = Line(depth=line.depth)
1766     tail_leaves: List[Leaf] = []
1767     body_leaves: List[Leaf] = []
1768     head_leaves: List[Leaf] = []
1769     current_leaves = tail_leaves
1770     opening_bracket = None
1771     closing_bracket = None
1772     for leaf in reversed(line.leaves):
1773         if current_leaves is body_leaves:
1774             if leaf is opening_bracket:
1775                 current_leaves = head_leaves if body_leaves else tail_leaves
1776         current_leaves.append(leaf)
1777         if current_leaves is tail_leaves:
1778             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1779                 opening_bracket = leaf.opening_bracket
1780                 closing_bracket = leaf
1781                 current_leaves = body_leaves
1782     tail_leaves.reverse()
1783     body_leaves.reverse()
1784     head_leaves.reverse()
1785     # Since body is a new indent level, remove spurious leading whitespace.
1786     if body_leaves:
1787         normalize_prefix(body_leaves[0], inside_brackets=True)
1788     elif not head_leaves:
1789         # No `head` and no `body` means the split failed. `tail` has all content.
1790         raise CannotSplit("No brackets found")
1791
1792     # Build the new lines.
1793     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1794         for leaf in leaves:
1795             result.append(leaf, preformatted=True)
1796             for comment_after in line.comments_after(leaf):
1797                 result.append(comment_after, preformatted=True)
1798     bracket_split_succeeded_or_raise(head, body, tail)
1799     assert opening_bracket and closing_bracket
1800     if (
1801         opening_bracket.type == token.LPAR
1802         and not opening_bracket.value
1803         and closing_bracket.type == token.RPAR
1804         and not closing_bracket.value
1805     ):
1806         # These parens were optional. If there aren't any delimiters or standalone
1807         # comments in the body, they were unnecessary and another split without
1808         # them should be attempted.
1809         if not (
1810             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1811         ):
1812             omit = {id(closing_bracket), *omit}
1813             yield from right_hand_split(line, py36=py36, omit=omit)
1814             return
1815
1816     ensure_visible(opening_bracket)
1817     ensure_visible(closing_bracket)
1818     for result in (head, body, tail):
1819         if result:
1820             yield result
1821
1822
1823 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1824     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1825
1826     Do nothing otherwise.
1827
1828     A left- or right-hand split is based on a pair of brackets. Content before
1829     (and including) the opening bracket is left on one line, content inside the
1830     brackets is put on a separate line, and finally content starting with and
1831     following the closing bracket is put on a separate line.
1832
1833     Those are called `head`, `body`, and `tail`, respectively. If the split
1834     produced the same line (all content in `head`) or ended up with an empty `body`
1835     and the `tail` is just the closing bracket, then it's considered failed.
1836     """
1837     tail_len = len(str(tail).strip())
1838     if not body:
1839         if tail_len == 0:
1840             raise CannotSplit("Splitting brackets produced the same line")
1841
1842         elif tail_len < 3:
1843             raise CannotSplit(
1844                 f"Splitting brackets on an empty body to save "
1845                 f"{tail_len} characters is not worth it"
1846             )
1847
1848
1849 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1850     """Normalize prefix of the first leaf in every line returned by `split_func`.
1851
1852     This is a decorator over relevant split functions.
1853     """
1854
1855     @wraps(split_func)
1856     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1857         for l in split_func(line, py36):
1858             normalize_prefix(l.leaves[0], inside_brackets=True)
1859             yield l
1860
1861     return split_wrapper
1862
1863
1864 @dont_increase_indentation
1865 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1866     """Split according to delimiters of the highest priority.
1867
1868     If `py36` is True, the split will add trailing commas also in function
1869     signatures that contain `*` and `**`.
1870     """
1871     try:
1872         last_leaf = line.leaves[-1]
1873     except IndexError:
1874         raise CannotSplit("Line empty")
1875
1876     delimiters = line.bracket_tracker.delimiters
1877     try:
1878         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1879             exclude={id(last_leaf)}
1880         )
1881     except ValueError:
1882         raise CannotSplit("No delimiters found")
1883
1884     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1885     lowest_depth = sys.maxsize
1886     trailing_comma_safe = True
1887
1888     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1889         """Append `leaf` to current line or to new line if appending impossible."""
1890         nonlocal current_line
1891         try:
1892             current_line.append_safe(leaf, preformatted=True)
1893         except ValueError as ve:
1894             yield current_line
1895
1896             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1897             current_line.append(leaf)
1898
1899     for leaf in line.leaves:
1900         yield from append_to_line(leaf)
1901
1902         for comment_after in line.comments_after(leaf):
1903             yield from append_to_line(comment_after)
1904
1905         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1906         if (
1907             leaf.bracket_depth == lowest_depth
1908             and is_vararg(leaf, within=VARARGS_PARENTS)
1909         ):
1910             trailing_comma_safe = trailing_comma_safe and py36
1911         leaf_priority = delimiters.get(id(leaf))
1912         if leaf_priority == delimiter_priority:
1913             yield current_line
1914
1915             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1916     if current_line:
1917         if (
1918             trailing_comma_safe
1919             and delimiter_priority == COMMA_PRIORITY
1920             and current_line.leaves[-1].type != token.COMMA
1921             and current_line.leaves[-1].type != STANDALONE_COMMENT
1922         ):
1923             current_line.append(Leaf(token.COMMA, ","))
1924         yield current_line
1925
1926
1927 @dont_increase_indentation
1928 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1929     """Split standalone comments from the rest of the line."""
1930     if not line.contains_standalone_comments(0):
1931         raise CannotSplit("Line does not have any standalone comments")
1932
1933     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1934
1935     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1936         """Append `leaf` to current line or to new line if appending impossible."""
1937         nonlocal current_line
1938         try:
1939             current_line.append_safe(leaf, preformatted=True)
1940         except ValueError as ve:
1941             yield current_line
1942
1943             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1944             current_line.append(leaf)
1945
1946     for leaf in line.leaves:
1947         yield from append_to_line(leaf)
1948
1949         for comment_after in line.comments_after(leaf):
1950             yield from append_to_line(comment_after)
1951
1952     if current_line:
1953         yield current_line
1954
1955
1956 def is_import(leaf: Leaf) -> bool:
1957     """Return True if the given leaf starts an import statement."""
1958     p = leaf.parent
1959     t = leaf.type
1960     v = leaf.value
1961     return bool(
1962         t == token.NAME
1963         and (
1964             (v == "import" and p and p.type == syms.import_name)
1965             or (v == "from" and p and p.type == syms.import_from)
1966         )
1967     )
1968
1969
1970 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1971     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1972     else.
1973
1974     Note: don't use backslashes for formatting or you'll lose your voting rights.
1975     """
1976     if not inside_brackets:
1977         spl = leaf.prefix.split("#")
1978         if "\\" not in spl[0]:
1979             nl_count = spl[-1].count("\n")
1980             if len(spl) > 1:
1981                 nl_count -= 1
1982             leaf.prefix = "\n" * nl_count
1983             return
1984
1985     leaf.prefix = ""
1986
1987
1988 def normalize_string_quotes(leaf: Leaf) -> None:
1989     """Prefer double quotes but only if it doesn't cause more escaping.
1990
1991     Adds or removes backslashes as appropriate. Doesn't parse and fix
1992     strings nested in f-strings (yet).
1993
1994     Note: Mutates its argument.
1995     """
1996     value = leaf.value.lstrip("furbFURB")
1997     if value[:3] == '"""':
1998         return
1999
2000     elif value[:3] == "'''":
2001         orig_quote = "'''"
2002         new_quote = '"""'
2003     elif value[0] == '"':
2004         orig_quote = '"'
2005         new_quote = "'"
2006     else:
2007         orig_quote = "'"
2008         new_quote = '"'
2009     first_quote_pos = leaf.value.find(orig_quote)
2010     if first_quote_pos == -1:
2011         return  # There's an internal error
2012
2013     prefix = leaf.value[:first_quote_pos]
2014     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2015     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2016     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2017     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
2018     if "r" in prefix.casefold():
2019         if unescaped_new_quote.search(body):
2020             # There's at least one unescaped new_quote in this raw string
2021             # so converting is impossible
2022             return
2023
2024         # Do not introduce or remove backslashes in raw strings
2025         new_body = body
2026     else:
2027         # remove unnecessary quotes
2028         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2029         if body != new_body:
2030             # Consider the string without unnecessary quotes as the original
2031             body = new_body
2032             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2033         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2034         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2035     if new_quote == '"""' and new_body[-1] == '"':
2036         # edge case:
2037         new_body = new_body[:-1] + '\\"'
2038     orig_escape_count = body.count("\\")
2039     new_escape_count = new_body.count("\\")
2040     if new_escape_count > orig_escape_count:
2041         return  # Do not introduce more escaping
2042
2043     if new_escape_count == orig_escape_count and orig_quote == '"':
2044         return  # Prefer double quotes
2045
2046     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2047
2048
2049 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2050     """Make existing optional parentheses invisible or create new ones.
2051
2052     Standardizes on visible parentheses for single-element tuples, and keeps
2053     existing visible parentheses for other tuples and generator expressions.
2054     """
2055     check_lpar = False
2056     for child in list(node.children):
2057         if check_lpar:
2058             if child.type == syms.atom:
2059                 if not (
2060                     is_empty_tuple(child)
2061                     or is_one_tuple(child)
2062                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2063                 ):
2064                     first = child.children[0]
2065                     last = child.children[-1]
2066                     if first.type == token.LPAR and last.type == token.RPAR:
2067                         # make parentheses invisible
2068                         first.value = ""  # type: ignore
2069                         last.value = ""  # type: ignore
2070             elif is_one_tuple(child):
2071                 # wrap child in visible parentheses
2072                 lpar = Leaf(token.LPAR, "(")
2073                 rpar = Leaf(token.RPAR, ")")
2074                 index = child.remove() or 0
2075                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2076             else:
2077                 # wrap child in invisible parentheses
2078                 lpar = Leaf(token.LPAR, "")
2079                 rpar = Leaf(token.RPAR, "")
2080                 index = child.remove() or 0
2081                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2082
2083         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2084
2085
2086 def is_empty_tuple(node: LN) -> bool:
2087     """Return True if `node` holds an empty tuple."""
2088     return (
2089         node.type == syms.atom
2090         and len(node.children) == 2
2091         and node.children[0].type == token.LPAR
2092         and node.children[1].type == token.RPAR
2093     )
2094
2095
2096 def is_one_tuple(node: LN) -> bool:
2097     """Return True if `node` holds a tuple with one element, with or without parens."""
2098     if node.type == syms.atom:
2099         if len(node.children) != 3:
2100             return False
2101
2102         lpar, gexp, rpar = node.children
2103         if not (
2104             lpar.type == token.LPAR
2105             and gexp.type == syms.testlist_gexp
2106             and rpar.type == token.RPAR
2107         ):
2108             return False
2109
2110         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2111
2112     return (
2113         node.type in IMPLICIT_TUPLE
2114         and len(node.children) == 2
2115         and node.children[1].type == token.COMMA
2116     )
2117
2118
2119 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2120     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2121
2122     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2123     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2124     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2125     generalizations (PEP 448).
2126     """
2127     if leaf.type not in STARS or not leaf.parent:
2128         return False
2129
2130     p = leaf.parent
2131     if p.type == syms.star_expr:
2132         # Star expressions are also used as assignment targets in extended
2133         # iterable unpacking (PEP 3132).  See what its parent is instead.
2134         if not p.parent:
2135             return False
2136
2137         p = p.parent
2138
2139     return p.type in within
2140
2141
2142 def max_delimiter_priority_in_atom(node: LN) -> int:
2143     """Return maximum delimiter priority inside `node`.
2144
2145     This is specific to atoms with contents contained in a pair of parentheses.
2146     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2147     """
2148     if node.type != syms.atom:
2149         return 0
2150
2151     first = node.children[0]
2152     last = node.children[-1]
2153     if not (first.type == token.LPAR and last.type == token.RPAR):
2154         return 0
2155
2156     bt = BracketTracker()
2157     for c in node.children[1:-1]:
2158         if isinstance(c, Leaf):
2159             bt.mark(c)
2160         else:
2161             for leaf in c.leaves():
2162                 bt.mark(leaf)
2163     try:
2164         return bt.max_delimiter_priority()
2165
2166     except ValueError:
2167         return 0
2168
2169
2170 def ensure_visible(leaf: Leaf) -> None:
2171     """Make sure parentheses are visible.
2172
2173     They could be invisible as part of some statements (see
2174     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2175     """
2176     if leaf.type == token.LPAR:
2177         leaf.value = "("
2178     elif leaf.type == token.RPAR:
2179         leaf.value = ")"
2180
2181
2182 def is_python36(node: Node) -> bool:
2183     """Return True if the current file is using Python 3.6+ features.
2184
2185     Currently looking for:
2186     - f-strings; and
2187     - trailing commas after * or ** in function signatures.
2188     """
2189     for n in node.pre_order():
2190         if n.type == token.STRING:
2191             value_head = n.value[:2]  # type: ignore
2192             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2193                 return True
2194
2195         elif (
2196             n.type == syms.typedargslist
2197             and n.children
2198             and n.children[-1].type == token.COMMA
2199         ):
2200             for ch in n.children:
2201                 if ch.type in STARS:
2202                     return True
2203
2204     return False
2205
2206
2207 PYTHON_EXTENSIONS = {".py"}
2208 BLACKLISTED_DIRECTORIES = {
2209     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2210 }
2211
2212
2213 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2214     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2215     and have one of the PYTHON_EXTENSIONS.
2216     """
2217     for child in path.iterdir():
2218         if child.is_dir():
2219             if child.name in BLACKLISTED_DIRECTORIES:
2220                 continue
2221
2222             yield from gen_python_files_in_dir(child)
2223
2224         elif child.suffix in PYTHON_EXTENSIONS:
2225             yield child
2226
2227
2228 @dataclass
2229 class Report:
2230     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2231     check: bool = False
2232     quiet: bool = False
2233     change_count: int = 0
2234     same_count: int = 0
2235     failure_count: int = 0
2236
2237     def done(self, src: Path, changed: Changed) -> None:
2238         """Increment the counter for successful reformatting. Write out a message."""
2239         if changed is Changed.YES:
2240             reformatted = "would reformat" if self.check else "reformatted"
2241             if not self.quiet:
2242                 out(f"{reformatted} {src}")
2243             self.change_count += 1
2244         else:
2245             if not self.quiet:
2246                 if changed is Changed.NO:
2247                     msg = f"{src} already well formatted, good job."
2248                 else:
2249                     msg = f"{src} wasn't modified on disk since last run."
2250                 out(msg, bold=False)
2251             self.same_count += 1
2252
2253     def failed(self, src: Path, message: str) -> None:
2254         """Increment the counter for failed reformatting. Write out a message."""
2255         err(f"error: cannot format {src}: {message}")
2256         self.failure_count += 1
2257
2258     @property
2259     def return_code(self) -> int:
2260         """Return the exit code that the app should use.
2261
2262         This considers the current state of changed files and failures:
2263         - if there were any failures, return 123;
2264         - if any files were changed and --check is being used, return 1;
2265         - otherwise return 0.
2266         """
2267         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2268         # 126 we have special returncodes reserved by the shell.
2269         if self.failure_count:
2270             return 123
2271
2272         elif self.change_count and self.check:
2273             return 1
2274
2275         return 0
2276
2277     def __str__(self) -> str:
2278         """Render a color report of the current state.
2279
2280         Use `click.unstyle` to remove colors.
2281         """
2282         if self.check:
2283             reformatted = "would be reformatted"
2284             unchanged = "would be left unchanged"
2285             failed = "would fail to reformat"
2286         else:
2287             reformatted = "reformatted"
2288             unchanged = "left unchanged"
2289             failed = "failed to reformat"
2290         report = []
2291         if self.change_count:
2292             s = "s" if self.change_count > 1 else ""
2293             report.append(
2294                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2295             )
2296         if self.same_count:
2297             s = "s" if self.same_count > 1 else ""
2298             report.append(f"{self.same_count} file{s} {unchanged}")
2299         if self.failure_count:
2300             s = "s" if self.failure_count > 1 else ""
2301             report.append(
2302                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2303             )
2304         return ", ".join(report) + "."
2305
2306
2307 def assert_equivalent(src: str, dst: str) -> None:
2308     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2309
2310     import ast
2311     import traceback
2312
2313     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2314         """Simple visitor generating strings to compare ASTs by content."""
2315         yield f"{'  ' * depth}{node.__class__.__name__}("
2316
2317         for field in sorted(node._fields):
2318             try:
2319                 value = getattr(node, field)
2320             except AttributeError:
2321                 continue
2322
2323             yield f"{'  ' * (depth+1)}{field}="
2324
2325             if isinstance(value, list):
2326                 for item in value:
2327                     if isinstance(item, ast.AST):
2328                         yield from _v(item, depth + 2)
2329
2330             elif isinstance(value, ast.AST):
2331                 yield from _v(value, depth + 2)
2332
2333             else:
2334                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2335
2336         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2337
2338     try:
2339         src_ast = ast.parse(src)
2340     except Exception as exc:
2341         major, minor = sys.version_info[:2]
2342         raise AssertionError(
2343             f"cannot use --safe with this file; failed to parse source file "
2344             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2345             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2346         )
2347
2348     try:
2349         dst_ast = ast.parse(dst)
2350     except Exception as exc:
2351         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2352         raise AssertionError(
2353             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2354             f"Please report a bug on https://github.com/ambv/black/issues.  "
2355             f"This invalid output might be helpful: {log}"
2356         ) from None
2357
2358     src_ast_str = "\n".join(_v(src_ast))
2359     dst_ast_str = "\n".join(_v(dst_ast))
2360     if src_ast_str != dst_ast_str:
2361         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2362         raise AssertionError(
2363             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2364             f"the source.  "
2365             f"Please report a bug on https://github.com/ambv/black/issues.  "
2366             f"This diff might be helpful: {log}"
2367         ) from None
2368
2369
2370 def assert_stable(src: str, dst: str, line_length: int) -> None:
2371     """Raise AssertionError if `dst` reformats differently the second time."""
2372     newdst = format_str(dst, line_length=line_length)
2373     if dst != newdst:
2374         log = dump_to_file(
2375             diff(src, dst, "source", "first pass"),
2376             diff(dst, newdst, "first pass", "second pass"),
2377         )
2378         raise AssertionError(
2379             f"INTERNAL ERROR: Black produced different code on the second pass "
2380             f"of the formatter.  "
2381             f"Please report a bug on https://github.com/ambv/black/issues.  "
2382             f"This diff might be helpful: {log}"
2383         ) from None
2384
2385
2386 def dump_to_file(*output: str) -> str:
2387     """Dump `output` to a temporary file. Return path to the file."""
2388     import tempfile
2389
2390     with tempfile.NamedTemporaryFile(
2391         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2392     ) as f:
2393         for lines in output:
2394             f.write(lines)
2395             if lines and lines[-1] != "\n":
2396                 f.write("\n")
2397     return f.name
2398
2399
2400 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2401     """Return a unified diff string between strings `a` and `b`."""
2402     import difflib
2403
2404     a_lines = [line + "\n" for line in a.split("\n")]
2405     b_lines = [line + "\n" for line in b.split("\n")]
2406     return "".join(
2407         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2408     )
2409
2410
2411 def cancel(tasks: List[asyncio.Task]) -> None:
2412     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2413     err("Aborted!")
2414     for task in tasks:
2415         task.cancel()
2416
2417
2418 def shutdown(loop: BaseEventLoop) -> None:
2419     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2420     try:
2421         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2422         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2423         if not to_cancel:
2424             return
2425
2426         for task in to_cancel:
2427             task.cancel()
2428         loop.run_until_complete(
2429             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2430         )
2431     finally:
2432         # `concurrent.futures.Future` objects cannot be cancelled once they
2433         # are already running. There might be some when the `shutdown()` happened.
2434         # Silence their logger's spew about the event loop being closed.
2435         cf_logger = logging.getLogger("concurrent.futures")
2436         cf_logger.setLevel(logging.CRITICAL)
2437         loop.close()
2438
2439
2440 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2441     """Replace `regex` with `replacement` twice on `original`.
2442
2443     This is used by string normalization to perform replaces on
2444     overlapping matches.
2445     """
2446     return regex.sub(replacement, regex.sub(replacement, original))
2447
2448
2449 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2450 CACHE_FILE = CACHE_DIR / "cache.pickle"
2451
2452
2453 def read_cache() -> Cache:
2454     """Read the cache if it exists and is well formed.
2455
2456     If it is not well formed, the call to write_cache later should resolve the issue.
2457     """
2458     if not CACHE_FILE.exists():
2459         return {}
2460
2461     with CACHE_FILE.open("rb") as fobj:
2462         try:
2463             cache: Cache = pickle.load(fobj)
2464         except pickle.UnpicklingError:
2465             return {}
2466
2467     return cache
2468
2469
2470 def get_cache_info(path: Path) -> CacheInfo:
2471     """Return the information used to check if a file is already formatted or not."""
2472     stat = path.stat()
2473     return stat.st_mtime, stat.st_size
2474
2475
2476 def filter_cached(
2477     cache: Cache, sources: Iterable[Path]
2478 ) -> Tuple[List[Path], List[Path]]:
2479     """Split a list of paths into two.
2480
2481     The first list contains paths of files that modified on disk or are not in the
2482     cache. The other list contains paths to non-modified files.
2483     """
2484     todo, done = [], []
2485     for src in sources:
2486         src = src.resolve()
2487         if cache.get(src) != get_cache_info(src):
2488             todo.append(src)
2489         else:
2490             done.append(src)
2491     return todo, done
2492
2493
2494 def write_cache(cache: Cache, sources: List[Path]) -> None:
2495     """Update the cache file."""
2496     try:
2497         if not CACHE_DIR.exists():
2498             CACHE_DIR.mkdir(parents=True)
2499         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2500         with CACHE_FILE.open("wb") as fobj:
2501             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2502     except OSError:
2503         pass
2504
2505
2506 if __name__ == "__main__":
2507     main()