]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Simplify single-file vs. multi-file modes
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import pickle
5 from asyncio.base_events import BaseEventLoop
6 from concurrent.futures import Executor, ProcessPoolExecutor
7 from enum import Enum
8 from functools import partial, wraps
9 import keyword
10 import logging
11 from multiprocessing import Manager
12 import os
13 from pathlib import Path
14 import re
15 import tokenize
16 import signal
17 import sys
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46 __version__ = "18.4a2"
47 DEFAULT_LINE_LENGTH = 88
48 # types
49 syms = pygram.python_symbols
50 FileContent = str
51 Encoding = str
52 Depth = int
53 NodeType = int
54 LeafID = int
55 Priority = int
56 Index = int
57 LN = Union[Leaf, Node]
58 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
59 Timestamp = float
60 FileSize = int
61 CacheInfo = Tuple[Timestamp, FileSize]
62 Cache = Dict[Path, CacheInfo]
63 out = partial(click.secho, bold=True, err=True)
64 err = partial(click.secho, fg="red", err=True)
65
66
67 class NothingChanged(UserWarning):
68     """Raised by :func:`format_file` when reformatted code is the same as source."""
69
70
71 class CannotSplit(Exception):
72     """A readable split that fits the allotted line length is impossible.
73
74     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
75     :func:`delimiter_split`.
76     """
77
78
79 class FormatError(Exception):
80     """Base exception for `# fmt: on` and `# fmt: off` handling.
81
82     It holds the number of bytes of the prefix consumed before the format
83     control comment appeared.
84     """
85
86     def __init__(self, consumed: int) -> None:
87         super().__init__(consumed)
88         self.consumed = consumed
89
90     def trim_prefix(self, leaf: Leaf) -> None:
91         leaf.prefix = leaf.prefix[self.consumed:]
92
93     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
94         """Returns a new Leaf from the consumed part of the prefix."""
95         unformatted_prefix = leaf.prefix[:self.consumed]
96         return Leaf(token.NEWLINE, unformatted_prefix)
97
98
99 class FormatOn(FormatError):
100     """Found a comment like `# fmt: on` in the file."""
101
102
103 class FormatOff(FormatError):
104     """Found a comment like `# fmt: off` in the file."""
105
106
107 class WriteBack(Enum):
108     NO = 0
109     YES = 1
110     DIFF = 2
111
112
113 class Changed(Enum):
114     NO = 0
115     CACHED = 1
116     YES = 2
117
118
119 @click.command()
120 @click.option(
121     "-l",
122     "--line-length",
123     type=int,
124     default=DEFAULT_LINE_LENGTH,
125     help="How many character per line to allow.",
126     show_default=True,
127 )
128 @click.option(
129     "--check",
130     is_flag=True,
131     help=(
132         "Don't write the files back, just return the status.  Return code 0 "
133         "means nothing would change.  Return code 1 means some files would be "
134         "reformatted.  Return code 123 means there was an internal error."
135     ),
136 )
137 @click.option(
138     "--diff",
139     is_flag=True,
140     help="Don't write the files back, just output a diff for each file on stdout.",
141 )
142 @click.option(
143     "--fast/--safe",
144     is_flag=True,
145     help="If --fast given, skip temporary sanity checks. [default: --safe]",
146 )
147 @click.option(
148     "-q",
149     "--quiet",
150     is_flag=True,
151     help=(
152         "Don't emit non-error messages to stderr. Errors are still emitted, "
153         "silence those with 2>/dev/null."
154     ),
155 )
156 @click.version_option(version=__version__)
157 @click.argument(
158     "src",
159     nargs=-1,
160     type=click.Path(
161         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
162     ),
163 )
164 @click.pass_context
165 def main(
166     ctx: click.Context,
167     line_length: int,
168     check: bool,
169     diff: bool,
170     fast: bool,
171     quiet: bool,
172     src: List[str],
173 ) -> None:
174     """The uncompromising code formatter."""
175     sources: List[Path] = []
176     for s in src:
177         p = Path(s)
178         if p.is_dir():
179             sources.extend(gen_python_files_in_dir(p))
180         elif p.is_file():
181             # if a file was explicitly given, we don't care about its extension
182             sources.append(p)
183         elif s == "-":
184             sources.append(Path("-"))
185         else:
186             err(f"invalid path: {s}")
187     if check and diff:
188         exc = click.ClickException("Options --check and --diff are mutually exclusive")
189         exc.exit_code = 2
190         raise exc
191
192     if check:
193         write_back = WriteBack.NO
194     elif diff:
195         write_back = WriteBack.DIFF
196     else:
197         write_back = WriteBack.YES
198     if len(sources) == 0:
199         ctx.exit(0)
200         return
201
202     elif len(sources) == 1:
203         return_code = reformat_one(sources[0], line_length, fast, quiet, write_back)
204     else:
205         loop = asyncio.get_event_loop()
206         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
207         return_code = 1
208         try:
209             return_code = loop.run_until_complete(
210                 schedule_formatting(
211                     sources, line_length, write_back, fast, quiet, loop, executor
212                 )
213             )
214         finally:
215             shutdown(loop)
216     ctx.exit(return_code)
217
218
219 def reformat_one(
220     src: Path, line_length: int, fast: bool, quiet: bool, write_back: WriteBack
221 ) -> int:
222     """Reformat a single file under `src` without spawning child processes.
223
224     If `quiet` is True, non-error messages are not output. `line_length`,
225     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
226     """
227     report = Report(check=write_back is WriteBack.NO, quiet=quiet)
228     try:
229         changed = Changed.NO
230         if not src.is_file() and str(src) == "-":
231             if format_stdin_to_stdout(
232                 line_length=line_length, fast=fast, write_back=write_back
233             ):
234                 changed = Changed.YES
235         else:
236             cache: Cache = {}
237             if write_back != WriteBack.DIFF:
238                 cache = read_cache()
239                 src = src.resolve()
240                 if src in cache and cache[src] == get_cache_info(src):
241                     changed = Changed.CACHED
242             if (
243                 changed is not Changed.CACHED
244                 and format_file_in_place(
245                     src, line_length=line_length, fast=fast, write_back=write_back
246                 )
247             ):
248                 changed = Changed.YES
249             if write_back != WriteBack.DIFF and changed is not Changed.NO:
250                 write_cache(cache, [src])
251         report.done(src, changed)
252     except Exception as exc:
253         report.failed(src, str(exc))
254     return report.return_code
255
256
257 async def schedule_formatting(
258     sources: List[Path],
259     line_length: int,
260     write_back: WriteBack,
261     fast: bool,
262     quiet: bool,
263     loop: BaseEventLoop,
264     executor: Executor,
265 ) -> int:
266     """Run formatting of `sources` in parallel using the provided `executor`.
267
268     (Use ProcessPoolExecutors for actual parallelism.)
269
270     `line_length`, `write_back`, and `fast` options are passed to
271     :func:`format_file_in_place`.
272     """
273     report = Report(check=write_back is WriteBack.NO, quiet=quiet)
274     cache: Cache = {}
275     if write_back != WriteBack.DIFF:
276         cache = read_cache()
277         sources, cached = filter_cached(cache, sources)
278         for src in cached:
279             report.done(src, Changed.CACHED)
280     cancelled = []
281     formatted = []
282     if sources:
283         lock = None
284         if write_back == WriteBack.DIFF:
285             # For diff output, we need locks to ensure we don't interleave output
286             # from different processes.
287             manager = Manager()
288             lock = manager.Lock()
289         tasks = {
290             src: loop.run_in_executor(
291                 executor, format_file_in_place, src, line_length, fast, write_back, lock
292             )
293             for src in sources
294         }
295         _task_values = list(tasks.values())
296         loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
297         loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     elif not quiet:
315         out("All done! ✨ 🍰 ✨")
316     if not quiet:
317         click.echo(str(report))
318
319     if write_back != WriteBack.DIFF and formatted:
320         write_cache(cache, formatted)
321
322     return report.return_code
323
324
325 def format_file_in_place(
326     src: Path,
327     line_length: int,
328     fast: bool,
329     write_back: WriteBack = WriteBack.NO,
330     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
331 ) -> bool:
332     """Format file under `src` path. Return True if changed.
333
334     If `write_back` is True, write reformatted code back to stdout.
335     `line_length` and `fast` options are passed to :func:`format_file_contents`.
336     """
337
338     with tokenize.open(src) as src_buffer:
339         src_contents = src_buffer.read()
340     try:
341         dst_contents = format_file_contents(
342             src_contents, line_length=line_length, fast=fast
343         )
344     except NothingChanged:
345         return False
346
347     if write_back == write_back.YES:
348         with open(src, "w", encoding=src_buffer.encoding) as f:
349             f.write(dst_contents)
350     elif write_back == write_back.DIFF:
351         src_name = f"{src.name}  (original)"
352         dst_name = f"{src.name}  (formatted)"
353         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
354         if lock:
355             lock.acquire()
356         try:
357             sys.stdout.write(diff_contents)
358         finally:
359             if lock:
360                 lock.release()
361     return True
362
363
364 def format_stdin_to_stdout(
365     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
366 ) -> bool:
367     """Format file on stdin. Return True if changed.
368
369     If `write_back` is True, write reformatted code back to stdout.
370     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
371     """
372     src = sys.stdin.read()
373     dst = src
374     try:
375         dst = format_file_contents(src, line_length=line_length, fast=fast)
376         return True
377
378     except NothingChanged:
379         return False
380
381     finally:
382         if write_back == WriteBack.YES:
383             sys.stdout.write(dst)
384         elif write_back == WriteBack.DIFF:
385             src_name = "<stdin>  (original)"
386             dst_name = "<stdin>  (formatted)"
387             sys.stdout.write(diff(src, dst, src_name, dst_name))
388
389
390 def format_file_contents(
391     src_contents: str, line_length: int, fast: bool
392 ) -> FileContent:
393     """Reformat contents a file and return new contents.
394
395     If `fast` is False, additionally confirm that the reformatted code is
396     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
397     `line_length` is passed to :func:`format_str`.
398     """
399     if src_contents.strip() == "":
400         raise NothingChanged
401
402     dst_contents = format_str(src_contents, line_length=line_length)
403     if src_contents == dst_contents:
404         raise NothingChanged
405
406     if not fast:
407         assert_equivalent(src_contents, dst_contents)
408         assert_stable(src_contents, dst_contents, line_length=line_length)
409     return dst_contents
410
411
412 def format_str(src_contents: str, line_length: int) -> FileContent:
413     """Reformat a string and return new contents.
414
415     `line_length` determines how many characters per line are allowed.
416     """
417     src_node = lib2to3_parse(src_contents)
418     dst_contents = ""
419     lines = LineGenerator()
420     elt = EmptyLineTracker()
421     py36 = is_python36(src_node)
422     empty_line = Line()
423     after = 0
424     for current_line in lines.visit(src_node):
425         for _ in range(after):
426             dst_contents += str(empty_line)
427         before, after = elt.maybe_empty_lines(current_line)
428         for _ in range(before):
429             dst_contents += str(empty_line)
430         for line in split_line(current_line, line_length=line_length, py36=py36):
431             dst_contents += str(line)
432     return dst_contents
433
434
435 GRAMMARS = [
436     pygram.python_grammar_no_print_statement_no_exec_statement,
437     pygram.python_grammar_no_print_statement,
438     pygram.python_grammar_no_exec_statement,
439     pygram.python_grammar,
440 ]
441
442
443 def lib2to3_parse(src_txt: str) -> Node:
444     """Given a string with source, return the lib2to3 Node."""
445     grammar = pygram.python_grammar_no_print_statement
446     if src_txt[-1] != "\n":
447         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
448         src_txt += nl
449     for grammar in GRAMMARS:
450         drv = driver.Driver(grammar, pytree.convert)
451         try:
452             result = drv.parse_string(src_txt, True)
453             break
454
455         except ParseError as pe:
456             lineno, column = pe.context[1]
457             lines = src_txt.splitlines()
458             try:
459                 faulty_line = lines[lineno - 1]
460             except IndexError:
461                 faulty_line = "<line number missing in source>"
462             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
463     else:
464         raise exc from None
465
466     if isinstance(result, Leaf):
467         result = Node(syms.file_input, [result])
468     return result
469
470
471 def lib2to3_unparse(node: Node) -> str:
472     """Given a lib2to3 node, return its string representation."""
473     code = str(node)
474     return code
475
476
477 T = TypeVar("T")
478
479
480 class Visitor(Generic[T]):
481     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
482
483     def visit(self, node: LN) -> Iterator[T]:
484         """Main method to visit `node` and its children.
485
486         It tries to find a `visit_*()` method for the given `node.type`, like
487         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
488         If no dedicated `visit_*()` method is found, chooses `visit_default()`
489         instead.
490
491         Then yields objects of type `T` from the selected visitor.
492         """
493         if node.type < 256:
494             name = token.tok_name[node.type]
495         else:
496             name = type_repr(node.type)
497         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
498
499     def visit_default(self, node: LN) -> Iterator[T]:
500         """Default `visit_*()` implementation. Recurses to children of `node`."""
501         if isinstance(node, Node):
502             for child in node.children:
503                 yield from self.visit(child)
504
505
506 @dataclass
507 class DebugVisitor(Visitor[T]):
508     tree_depth: int = 0
509
510     def visit_default(self, node: LN) -> Iterator[T]:
511         indent = " " * (2 * self.tree_depth)
512         if isinstance(node, Node):
513             _type = type_repr(node.type)
514             out(f"{indent}{_type}", fg="yellow")
515             self.tree_depth += 1
516             for child in node.children:
517                 yield from self.visit(child)
518
519             self.tree_depth -= 1
520             out(f"{indent}/{_type}", fg="yellow", bold=False)
521         else:
522             _type = token.tok_name.get(node.type, str(node.type))
523             out(f"{indent}{_type}", fg="blue", nl=False)
524             if node.prefix:
525                 # We don't have to handle prefixes for `Node` objects since
526                 # that delegates to the first child anyway.
527                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
528             out(f" {node.value!r}", fg="blue", bold=False)
529
530     @classmethod
531     def show(cls, code: str) -> None:
532         """Pretty-print the lib2to3 AST of a given string of `code`.
533
534         Convenience method for debugging.
535         """
536         v: DebugVisitor[None] = DebugVisitor()
537         list(v.visit(lib2to3_parse(code)))
538
539
540 KEYWORDS = set(keyword.kwlist)
541 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
542 FLOW_CONTROL = {"return", "raise", "break", "continue"}
543 STATEMENT = {
544     syms.if_stmt,
545     syms.while_stmt,
546     syms.for_stmt,
547     syms.try_stmt,
548     syms.except_clause,
549     syms.with_stmt,
550     syms.funcdef,
551     syms.classdef,
552 }
553 STANDALONE_COMMENT = 153
554 LOGIC_OPERATORS = {"and", "or"}
555 COMPARATORS = {
556     token.LESS,
557     token.GREATER,
558     token.EQEQUAL,
559     token.NOTEQUAL,
560     token.LESSEQUAL,
561     token.GREATEREQUAL,
562 }
563 MATH_OPERATORS = {
564     token.PLUS,
565     token.MINUS,
566     token.STAR,
567     token.SLASH,
568     token.VBAR,
569     token.AMPER,
570     token.PERCENT,
571     token.CIRCUMFLEX,
572     token.TILDE,
573     token.LEFTSHIFT,
574     token.RIGHTSHIFT,
575     token.DOUBLESTAR,
576     token.DOUBLESLASH,
577 }
578 STARS = {token.STAR, token.DOUBLESTAR}
579 VARARGS_PARENTS = {
580     syms.arglist,
581     syms.argument,  # double star in arglist
582     syms.trailer,  # single argument to call
583     syms.typedargslist,
584     syms.varargslist,  # lambdas
585 }
586 UNPACKING_PARENTS = {
587     syms.atom,  # single element of a list or set literal
588     syms.dictsetmaker,
589     syms.listmaker,
590     syms.testlist_gexp,
591 }
592 COMPREHENSION_PRIORITY = 20
593 COMMA_PRIORITY = 10
594 LOGIC_PRIORITY = 5
595 STRING_PRIORITY = 4
596 COMPARATOR_PRIORITY = 3
597 MATH_PRIORITY = 1
598
599
600 @dataclass
601 class BracketTracker:
602     """Keeps track of brackets on a line."""
603
604     depth: int = 0
605     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
606     delimiters: Dict[LeafID, Priority] = Factory(dict)
607     previous: Optional[Leaf] = None
608
609     def mark(self, leaf: Leaf) -> None:
610         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
611
612         All leaves receive an int `bracket_depth` field that stores how deep
613         within brackets a given leaf is. 0 means there are no enclosing brackets
614         that started on this line.
615
616         If a leaf is itself a closing bracket, it receives an `opening_bracket`
617         field that it forms a pair with. This is a one-directional link to
618         avoid reference cycles.
619
620         If a leaf is a delimiter (a token on which Black can split the line if
621         needed) and it's on depth 0, its `id()` is stored in the tracker's
622         `delimiters` field.
623         """
624         if leaf.type == token.COMMENT:
625             return
626
627         if leaf.type in CLOSING_BRACKETS:
628             self.depth -= 1
629             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
630             leaf.opening_bracket = opening_bracket
631         leaf.bracket_depth = self.depth
632         if self.depth == 0:
633             delim = is_split_before_delimiter(leaf, self.previous)
634             if delim and self.previous is not None:
635                 self.delimiters[id(self.previous)] = delim
636             else:
637                 delim = is_split_after_delimiter(leaf, self.previous)
638                 if delim:
639                     self.delimiters[id(leaf)] = delim
640         if leaf.type in OPENING_BRACKETS:
641             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
642             self.depth += 1
643         self.previous = leaf
644
645     def any_open_brackets(self) -> bool:
646         """Return True if there is an yet unmatched open bracket on the line."""
647         return bool(self.bracket_match)
648
649     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
650         """Return the highest priority of a delimiter found on the line.
651
652         Values are consistent with what `is_delimiter()` returns.
653         Raises ValueError on no delimiters.
654         """
655         return max(v for k, v in self.delimiters.items() if k not in exclude)
656
657
658 @dataclass
659 class Line:
660     """Holds leaves and comments. Can be printed with `str(line)`."""
661
662     depth: int = 0
663     leaves: List[Leaf] = Factory(list)
664     comments: List[Tuple[Index, Leaf]] = Factory(list)
665     bracket_tracker: BracketTracker = Factory(BracketTracker)
666     inside_brackets: bool = False
667     has_for: bool = False
668     _for_loop_variable: bool = False
669
670     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
671         """Add a new `leaf` to the end of the line.
672
673         Unless `preformatted` is True, the `leaf` will receive a new consistent
674         whitespace prefix and metadata applied by :class:`BracketTracker`.
675         Trailing commas are maybe removed, unpacked for loop variables are
676         demoted from being delimiters.
677
678         Inline comments are put aside.
679         """
680         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
681         if not has_value:
682             return
683
684         if self.leaves and not preformatted:
685             # Note: at this point leaf.prefix should be empty except for
686             # imports, for which we only preserve newlines.
687             leaf.prefix += whitespace(leaf)
688         if self.inside_brackets or not preformatted:
689             self.maybe_decrement_after_for_loop_variable(leaf)
690             self.bracket_tracker.mark(leaf)
691             self.maybe_remove_trailing_comma(leaf)
692             self.maybe_increment_for_loop_variable(leaf)
693
694         if not self.append_comment(leaf):
695             self.leaves.append(leaf)
696
697     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
698         """Like :func:`append()` but disallow invalid standalone comment structure.
699
700         Raises ValueError when any `leaf` is appended after a standalone comment
701         or when a standalone comment is not the first leaf on the line.
702         """
703         if self.bracket_tracker.depth == 0:
704             if self.is_comment:
705                 raise ValueError("cannot append to standalone comments")
706
707             if self.leaves and leaf.type == STANDALONE_COMMENT:
708                 raise ValueError(
709                     "cannot append standalone comments to a populated line"
710                 )
711
712         self.append(leaf, preformatted=preformatted)
713
714     @property
715     def is_comment(self) -> bool:
716         """Is this line a standalone comment?"""
717         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
718
719     @property
720     def is_decorator(self) -> bool:
721         """Is this line a decorator?"""
722         return bool(self) and self.leaves[0].type == token.AT
723
724     @property
725     def is_import(self) -> bool:
726         """Is this an import line?"""
727         return bool(self) and is_import(self.leaves[0])
728
729     @property
730     def is_class(self) -> bool:
731         """Is this line a class definition?"""
732         return (
733             bool(self)
734             and self.leaves[0].type == token.NAME
735             and self.leaves[0].value == "class"
736         )
737
738     @property
739     def is_def(self) -> bool:
740         """Is this a function definition? (Also returns True for async defs.)"""
741         try:
742             first_leaf = self.leaves[0]
743         except IndexError:
744             return False
745
746         try:
747             second_leaf: Optional[Leaf] = self.leaves[1]
748         except IndexError:
749             second_leaf = None
750         return (
751             (first_leaf.type == token.NAME and first_leaf.value == "def")
752             or (
753                 first_leaf.type == token.ASYNC
754                 and second_leaf is not None
755                 and second_leaf.type == token.NAME
756                 and second_leaf.value == "def"
757             )
758         )
759
760     @property
761     def is_flow_control(self) -> bool:
762         """Is this line a flow control statement?
763
764         Those are `return`, `raise`, `break`, and `continue`.
765         """
766         return (
767             bool(self)
768             and self.leaves[0].type == token.NAME
769             and self.leaves[0].value in FLOW_CONTROL
770         )
771
772     @property
773     def is_yield(self) -> bool:
774         """Is this line a yield statement?"""
775         return (
776             bool(self)
777             and self.leaves[0].type == token.NAME
778             and self.leaves[0].value == "yield"
779         )
780
781     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
782         """If so, needs to be split before emitting."""
783         for leaf in self.leaves:
784             if leaf.type == STANDALONE_COMMENT:
785                 if leaf.bracket_depth <= depth_limit:
786                     return True
787
788         return False
789
790     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
791         """Remove trailing comma if there is one and it's safe."""
792         if not (
793             self.leaves
794             and self.leaves[-1].type == token.COMMA
795             and closing.type in CLOSING_BRACKETS
796         ):
797             return False
798
799         if closing.type == token.RBRACE:
800             self.remove_trailing_comma()
801             return True
802
803         if closing.type == token.RSQB:
804             comma = self.leaves[-1]
805             if comma.parent and comma.parent.type == syms.listmaker:
806                 self.remove_trailing_comma()
807                 return True
808
809         # For parens let's check if it's safe to remove the comma.  If the
810         # trailing one is the only one, we might mistakenly change a tuple
811         # into a different type by removing the comma.
812         depth = closing.bracket_depth + 1
813         commas = 0
814         opening = closing.opening_bracket
815         for _opening_index, leaf in enumerate(self.leaves):
816             if leaf is opening:
817                 break
818
819         else:
820             return False
821
822         for leaf in self.leaves[_opening_index + 1:]:
823             if leaf is closing:
824                 break
825
826             bracket_depth = leaf.bracket_depth
827             if bracket_depth == depth and leaf.type == token.COMMA:
828                 commas += 1
829                 if leaf.parent and leaf.parent.type == syms.arglist:
830                     commas += 1
831                     break
832
833         if commas > 1:
834             self.remove_trailing_comma()
835             return True
836
837         return False
838
839     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
840         """In a for loop, or comprehension, the variables are often unpacks.
841
842         To avoid splitting on the comma in this situation, increase the depth of
843         tokens between `for` and `in`.
844         """
845         if leaf.type == token.NAME and leaf.value == "for":
846             self.has_for = True
847             self.bracket_tracker.depth += 1
848             self._for_loop_variable = True
849             return True
850
851         return False
852
853     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
854         """See `maybe_increment_for_loop_variable` above for explanation."""
855         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
856             self.bracket_tracker.depth -= 1
857             self._for_loop_variable = False
858             return True
859
860         return False
861
862     def append_comment(self, comment: Leaf) -> bool:
863         """Add an inline or standalone comment to the line."""
864         if (
865             comment.type == STANDALONE_COMMENT
866             and self.bracket_tracker.any_open_brackets()
867         ):
868             comment.prefix = ""
869             return False
870
871         if comment.type != token.COMMENT:
872             return False
873
874         after = len(self.leaves) - 1
875         if after == -1:
876             comment.type = STANDALONE_COMMENT
877             comment.prefix = ""
878             return False
879
880         else:
881             self.comments.append((after, comment))
882             return True
883
884     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
885         """Generate comments that should appear directly after `leaf`."""
886         for _leaf_index, _leaf in enumerate(self.leaves):
887             if leaf is _leaf:
888                 break
889
890         else:
891             return
892
893         for index, comment_after in self.comments:
894             if _leaf_index == index:
895                 yield comment_after
896
897     def remove_trailing_comma(self) -> None:
898         """Remove the trailing comma and moves the comments attached to it."""
899         comma_index = len(self.leaves) - 1
900         for i in range(len(self.comments)):
901             comment_index, comment = self.comments[i]
902             if comment_index == comma_index:
903                 self.comments[i] = (comma_index - 1, comment)
904         self.leaves.pop()
905
906     def __str__(self) -> str:
907         """Render the line."""
908         if not self:
909             return "\n"
910
911         indent = "    " * self.depth
912         leaves = iter(self.leaves)
913         first = next(leaves)
914         res = f"{first.prefix}{indent}{first.value}"
915         for leaf in leaves:
916             res += str(leaf)
917         for _, comment in self.comments:
918             res += str(comment)
919         return res + "\n"
920
921     def __bool__(self) -> bool:
922         """Return True if the line has leaves or comments."""
923         return bool(self.leaves or self.comments)
924
925
926 class UnformattedLines(Line):
927     """Just like :class:`Line` but stores lines which aren't reformatted."""
928
929     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
930         """Just add a new `leaf` to the end of the lines.
931
932         The `preformatted` argument is ignored.
933
934         Keeps track of indentation `depth`, which is useful when the user
935         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
936         """
937         try:
938             list(generate_comments(leaf))
939         except FormatOn as f_on:
940             self.leaves.append(f_on.leaf_from_consumed(leaf))
941             raise
942
943         self.leaves.append(leaf)
944         if leaf.type == token.INDENT:
945             self.depth += 1
946         elif leaf.type == token.DEDENT:
947             self.depth -= 1
948
949     def __str__(self) -> str:
950         """Render unformatted lines from leaves which were added with `append()`.
951
952         `depth` is not used for indentation in this case.
953         """
954         if not self:
955             return "\n"
956
957         res = ""
958         for leaf in self.leaves:
959             res += str(leaf)
960         return res
961
962     def append_comment(self, comment: Leaf) -> bool:
963         """Not implemented in this class. Raises `NotImplementedError`."""
964         raise NotImplementedError("Unformatted lines don't store comments separately.")
965
966     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
967         """Does nothing and returns False."""
968         return False
969
970     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
971         """Does nothing and returns False."""
972         return False
973
974
975 @dataclass
976 class EmptyLineTracker:
977     """Provides a stateful method that returns the number of potential extra
978     empty lines needed before and after the currently processed line.
979
980     Note: this tracker works on lines that haven't been split yet.  It assumes
981     the prefix of the first leaf consists of optional newlines.  Those newlines
982     are consumed by `maybe_empty_lines()` and included in the computation.
983     """
984     previous_line: Optional[Line] = None
985     previous_after: int = 0
986     previous_defs: List[int] = Factory(list)
987
988     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
989         """Return the number of extra empty lines before and after the `current_line`.
990
991         This is for separating `def`, `async def` and `class` with extra empty
992         lines (two on module-level), as well as providing an extra empty line
993         after flow control keywords to make them more prominent.
994         """
995         if isinstance(current_line, UnformattedLines):
996             return 0, 0
997
998         before, after = self._maybe_empty_lines(current_line)
999         before -= self.previous_after
1000         self.previous_after = after
1001         self.previous_line = current_line
1002         return before, after
1003
1004     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1005         max_allowed = 1
1006         if current_line.depth == 0:
1007             max_allowed = 2
1008         if current_line.leaves:
1009             # Consume the first leaf's extra newlines.
1010             first_leaf = current_line.leaves[0]
1011             before = first_leaf.prefix.count("\n")
1012             before = min(before, max_allowed)
1013             first_leaf.prefix = ""
1014         else:
1015             before = 0
1016         depth = current_line.depth
1017         while self.previous_defs and self.previous_defs[-1] >= depth:
1018             self.previous_defs.pop()
1019             before = 1 if depth else 2
1020         is_decorator = current_line.is_decorator
1021         if is_decorator or current_line.is_def or current_line.is_class:
1022             if not is_decorator:
1023                 self.previous_defs.append(depth)
1024             if self.previous_line is None:
1025                 # Don't insert empty lines before the first line in the file.
1026                 return 0, 0
1027
1028             if self.previous_line and self.previous_line.is_decorator:
1029                 # Don't insert empty lines between decorators.
1030                 return 0, 0
1031
1032             newlines = 2
1033             if current_line.depth:
1034                 newlines -= 1
1035             return newlines, 0
1036
1037         if current_line.is_flow_control:
1038             return before, 1
1039
1040         if (
1041             self.previous_line
1042             and self.previous_line.is_import
1043             and not current_line.is_import
1044             and depth == self.previous_line.depth
1045         ):
1046             return (before or 1), 0
1047
1048         if (
1049             self.previous_line
1050             and self.previous_line.is_yield
1051             and (not current_line.is_yield or depth != self.previous_line.depth)
1052         ):
1053             return (before or 1), 0
1054
1055         return before, 0
1056
1057
1058 @dataclass
1059 class LineGenerator(Visitor[Line]):
1060     """Generates reformatted Line objects.  Empty lines are not emitted.
1061
1062     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1063     in ways that will no longer stringify to valid Python code on the tree.
1064     """
1065     current_line: Line = Factory(Line)
1066
1067     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1068         """Generate a line.
1069
1070         If the line is empty, only emit if it makes sense.
1071         If the line is too long, split it first and then generate.
1072
1073         If any lines were generated, set up a new current_line.
1074         """
1075         if not self.current_line:
1076             if self.current_line.__class__ == type:
1077                 self.current_line.depth += indent
1078             else:
1079                 self.current_line = type(depth=self.current_line.depth + indent)
1080             return  # Line is empty, don't emit. Creating a new one unnecessary.
1081
1082         complete_line = self.current_line
1083         self.current_line = type(depth=complete_line.depth + indent)
1084         yield complete_line
1085
1086     def visit(self, node: LN) -> Iterator[Line]:
1087         """Main method to visit `node` and its children.
1088
1089         Yields :class:`Line` objects.
1090         """
1091         if isinstance(self.current_line, UnformattedLines):
1092             # File contained `# fmt: off`
1093             yield from self.visit_unformatted(node)
1094
1095         else:
1096             yield from super().visit(node)
1097
1098     def visit_default(self, node: LN) -> Iterator[Line]:
1099         """Default `visit_*()` implementation. Recurses to children of `node`."""
1100         if isinstance(node, Leaf):
1101             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1102             try:
1103                 for comment in generate_comments(node):
1104                     if any_open_brackets:
1105                         # any comment within brackets is subject to splitting
1106                         self.current_line.append(comment)
1107                     elif comment.type == token.COMMENT:
1108                         # regular trailing comment
1109                         self.current_line.append(comment)
1110                         yield from self.line()
1111
1112                     else:
1113                         # regular standalone comment
1114                         yield from self.line()
1115
1116                         self.current_line.append(comment)
1117                         yield from self.line()
1118
1119             except FormatOff as f_off:
1120                 f_off.trim_prefix(node)
1121                 yield from self.line(type=UnformattedLines)
1122                 yield from self.visit(node)
1123
1124             except FormatOn as f_on:
1125                 # This only happens here if somebody says "fmt: on" multiple
1126                 # times in a row.
1127                 f_on.trim_prefix(node)
1128                 yield from self.visit_default(node)
1129
1130             else:
1131                 normalize_prefix(node, inside_brackets=any_open_brackets)
1132                 if node.type == token.STRING:
1133                     normalize_string_quotes(node)
1134                 if node.type not in WHITESPACE:
1135                     self.current_line.append(node)
1136         yield from super().visit_default(node)
1137
1138     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1139         """Increase indentation level, maybe yield a line."""
1140         # In blib2to3 INDENT never holds comments.
1141         yield from self.line(+1)
1142         yield from self.visit_default(node)
1143
1144     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1145         """Decrease indentation level, maybe yield a line."""
1146         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1147         yield from self.line(-1)
1148
1149     def visit_stmt(
1150         self, node: Node, keywords: Set[str], parens: Set[str]
1151     ) -> Iterator[Line]:
1152         """Visit a statement.
1153
1154         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1155         `def`, `with`, `class`, and `assert`.
1156
1157         The relevant Python language `keywords` for a given statement will be
1158         NAME leaves within it. This methods puts those on a separate line.
1159
1160         `parens` holds pairs of nodes where invisible parentheses should be put.
1161         Keys hold nodes after which opening parentheses should be put, values
1162         hold nodes before which closing parentheses should be put.
1163         """
1164         normalize_invisible_parens(node, parens_after=parens)
1165         for child in node.children:
1166             if child.type == token.NAME and child.value in keywords:  # type: ignore
1167                 yield from self.line()
1168
1169             yield from self.visit(child)
1170
1171     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1172         """Visit a statement without nested statements."""
1173         is_suite_like = node.parent and node.parent.type in STATEMENT
1174         if is_suite_like:
1175             yield from self.line(+1)
1176             yield from self.visit_default(node)
1177             yield from self.line(-1)
1178
1179         else:
1180             yield from self.line()
1181             yield from self.visit_default(node)
1182
1183     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1184         """Visit `async def`, `async for`, `async with`."""
1185         yield from self.line()
1186
1187         children = iter(node.children)
1188         for child in children:
1189             yield from self.visit(child)
1190
1191             if child.type == token.ASYNC:
1192                 break
1193
1194         internal_stmt = next(children)
1195         for child in internal_stmt.children:
1196             yield from self.visit(child)
1197
1198     def visit_decorators(self, node: Node) -> Iterator[Line]:
1199         """Visit decorators."""
1200         for child in node.children:
1201             yield from self.line()
1202             yield from self.visit(child)
1203
1204     def visit_import_from(self, node: Node) -> Iterator[Line]:
1205         """Visit import_from and maybe put invisible parentheses.
1206
1207         This is separate from `visit_stmt` because import statements don't
1208         support arbitrary atoms and thus handling of parentheses is custom.
1209         """
1210         check_lpar = False
1211         for index, child in enumerate(node.children):
1212             if check_lpar:
1213                 if child.type == token.LPAR:
1214                     # make parentheses invisible
1215                     child.value = ""  # type: ignore
1216                     node.children[-1].value = ""  # type: ignore
1217                 else:
1218                     # insert invisible parentheses
1219                     node.insert_child(index, Leaf(token.LPAR, ""))
1220                     node.append_child(Leaf(token.RPAR, ""))
1221                 break
1222
1223             check_lpar = (
1224                 child.type == token.NAME and child.value == "import"  # type: ignore
1225             )
1226
1227         for child in node.children:
1228             yield from self.visit(child)
1229
1230     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1231         """Remove a semicolon and put the other statement on a separate line."""
1232         yield from self.line()
1233
1234     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1235         """End of file. Process outstanding comments and end with a newline."""
1236         yield from self.visit_default(leaf)
1237         yield from self.line()
1238
1239     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1240         """Used when file contained a `# fmt: off`."""
1241         if isinstance(node, Node):
1242             for child in node.children:
1243                 yield from self.visit(child)
1244
1245         else:
1246             try:
1247                 self.current_line.append(node)
1248             except FormatOn as f_on:
1249                 f_on.trim_prefix(node)
1250                 yield from self.line()
1251                 yield from self.visit(node)
1252
1253             if node.type == token.ENDMARKER:
1254                 # somebody decided not to put a final `# fmt: on`
1255                 yield from self.line()
1256
1257     def __attrs_post_init__(self) -> None:
1258         """You are in a twisty little maze of passages."""
1259         v = self.visit_stmt
1260         Ø: Set[str] = set()
1261         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1262         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1263         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1264         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1265         self.visit_try_stmt = partial(
1266             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1267         )
1268         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1269         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1270         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1271         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1272         self.visit_async_funcdef = self.visit_async_stmt
1273         self.visit_decorated = self.visit_decorators
1274
1275
1276 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1277 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1278 OPENING_BRACKETS = set(BRACKET.keys())
1279 CLOSING_BRACKETS = set(BRACKET.values())
1280 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1281 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1282
1283
1284 def whitespace(leaf: Leaf) -> str:  # noqa C901
1285     """Return whitespace prefix if needed for the given `leaf`."""
1286     NO = ""
1287     SPACE = " "
1288     DOUBLESPACE = "  "
1289     t = leaf.type
1290     p = leaf.parent
1291     v = leaf.value
1292     if t in ALWAYS_NO_SPACE:
1293         return NO
1294
1295     if t == token.COMMENT:
1296         return DOUBLESPACE
1297
1298     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1299     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1300         return NO
1301
1302     prev = leaf.prev_sibling
1303     if not prev:
1304         prevp = preceding_leaf(p)
1305         if not prevp or prevp.type in OPENING_BRACKETS:
1306             return NO
1307
1308         if t == token.COLON:
1309             return SPACE if prevp.type == token.COMMA else NO
1310
1311         if prevp.type == token.EQUAL:
1312             if prevp.parent:
1313                 if prevp.parent.type in {
1314                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1315                 }:
1316                     return NO
1317
1318                 elif prevp.parent.type == syms.typedargslist:
1319                     # A bit hacky: if the equal sign has whitespace, it means we
1320                     # previously found it's a typed argument.  So, we're using
1321                     # that, too.
1322                     return prevp.prefix
1323
1324         elif prevp.type in STARS:
1325             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1326                 return NO
1327
1328         elif prevp.type == token.COLON:
1329             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1330                 return NO
1331
1332         elif (
1333             prevp.parent
1334             and prevp.parent.type == syms.factor
1335             and prevp.type in MATH_OPERATORS
1336         ):
1337             return NO
1338
1339         elif (
1340             prevp.type == token.RIGHTSHIFT
1341             and prevp.parent
1342             and prevp.parent.type == syms.shift_expr
1343             and prevp.prev_sibling
1344             and prevp.prev_sibling.type == token.NAME
1345             and prevp.prev_sibling.value == "print"  # type: ignore
1346         ):
1347             # Python 2 print chevron
1348             return NO
1349
1350     elif prev.type in OPENING_BRACKETS:
1351         return NO
1352
1353     if p.type in {syms.parameters, syms.arglist}:
1354         # untyped function signatures or calls
1355         if t == token.RPAR:
1356             return NO
1357
1358         if not prev or prev.type != token.COMMA:
1359             return NO
1360
1361     elif p.type == syms.varargslist:
1362         # lambdas
1363         if t == token.RPAR:
1364             return NO
1365
1366         if prev and prev.type != token.COMMA:
1367             return NO
1368
1369     elif p.type == syms.typedargslist:
1370         # typed function signatures
1371         if not prev:
1372             return NO
1373
1374         if t == token.EQUAL:
1375             if prev.type != syms.tname:
1376                 return NO
1377
1378         elif prev.type == token.EQUAL:
1379             # A bit hacky: if the equal sign has whitespace, it means we
1380             # previously found it's a typed argument.  So, we're using that, too.
1381             return prev.prefix
1382
1383         elif prev.type != token.COMMA:
1384             return NO
1385
1386     elif p.type == syms.tname:
1387         # type names
1388         if not prev:
1389             prevp = preceding_leaf(p)
1390             if not prevp or prevp.type != token.COMMA:
1391                 return NO
1392
1393     elif p.type == syms.trailer:
1394         # attributes and calls
1395         if t == token.LPAR or t == token.RPAR:
1396             return NO
1397
1398         if not prev:
1399             if t == token.DOT:
1400                 prevp = preceding_leaf(p)
1401                 if not prevp or prevp.type != token.NUMBER:
1402                     return NO
1403
1404             elif t == token.LSQB:
1405                 return NO
1406
1407         elif prev.type != token.COMMA:
1408             return NO
1409
1410     elif p.type == syms.argument:
1411         # single argument
1412         if t == token.EQUAL:
1413             return NO
1414
1415         if not prev:
1416             prevp = preceding_leaf(p)
1417             if not prevp or prevp.type == token.LPAR:
1418                 return NO
1419
1420         elif prev.type in {token.EQUAL} | STARS:
1421             return NO
1422
1423     elif p.type == syms.decorator:
1424         # decorators
1425         return NO
1426
1427     elif p.type == syms.dotted_name:
1428         if prev:
1429             return NO
1430
1431         prevp = preceding_leaf(p)
1432         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1433             return NO
1434
1435     elif p.type == syms.classdef:
1436         if t == token.LPAR:
1437             return NO
1438
1439         if prev and prev.type == token.LPAR:
1440             return NO
1441
1442     elif p.type == syms.subscript:
1443         # indexing
1444         if not prev:
1445             assert p.parent is not None, "subscripts are always parented"
1446             if p.parent.type == syms.subscriptlist:
1447                 return SPACE
1448
1449             return NO
1450
1451         else:
1452             return NO
1453
1454     elif p.type == syms.atom:
1455         if prev and t == token.DOT:
1456             # dots, but not the first one.
1457             return NO
1458
1459     elif (
1460         p.type == syms.listmaker
1461         or p.type == syms.testlist_gexp
1462         or p.type == syms.subscriptlist
1463     ):
1464         # list interior, including unpacking
1465         if not prev:
1466             return NO
1467
1468     elif p.type == syms.dictsetmaker:
1469         # dict and set interior, including unpacking
1470         if not prev:
1471             return NO
1472
1473         if prev.type == token.DOUBLESTAR:
1474             return NO
1475
1476     elif p.type in {syms.factor, syms.star_expr}:
1477         # unary ops
1478         if not prev:
1479             prevp = preceding_leaf(p)
1480             if not prevp or prevp.type in OPENING_BRACKETS:
1481                 return NO
1482
1483             prevp_parent = prevp.parent
1484             assert prevp_parent is not None
1485             if (
1486                 prevp.type == token.COLON
1487                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1488             ):
1489                 return NO
1490
1491             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1492                 return NO
1493
1494         elif t == token.NAME or t == token.NUMBER:
1495             return NO
1496
1497     elif p.type == syms.import_from:
1498         if t == token.DOT:
1499             if prev and prev.type == token.DOT:
1500                 return NO
1501
1502         elif t == token.NAME:
1503             if v == "import":
1504                 return SPACE
1505
1506             if prev and prev.type == token.DOT:
1507                 return NO
1508
1509     elif p.type == syms.sliceop:
1510         return NO
1511
1512     return SPACE
1513
1514
1515 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1516     """Return the first leaf that precedes `node`, if any."""
1517     while node:
1518         res = node.prev_sibling
1519         if res:
1520             if isinstance(res, Leaf):
1521                 return res
1522
1523             try:
1524                 return list(res.leaves())[-1]
1525
1526             except IndexError:
1527                 return None
1528
1529         node = node.parent
1530     return None
1531
1532
1533 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1534     """Return the priority of the `leaf` delimiter, given a line break after it.
1535
1536     The delimiter priorities returned here are from those delimiters that would
1537     cause a line break after themselves.
1538
1539     Higher numbers are higher priority.
1540     """
1541     if leaf.type == token.COMMA:
1542         return COMMA_PRIORITY
1543
1544     return 0
1545
1546
1547 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1548     """Return the priority of the `leaf` delimiter, given a line before after it.
1549
1550     The delimiter priorities returned here are from those delimiters that would
1551     cause a line break before themselves.
1552
1553     Higher numbers are higher priority.
1554     """
1555     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1556         # * and ** might also be MATH_OPERATORS but in this case they are not.
1557         # Don't treat them as a delimiter.
1558         return 0
1559
1560     if (
1561         leaf.type in MATH_OPERATORS
1562         and leaf.parent
1563         and leaf.parent.type not in {syms.factor, syms.star_expr}
1564     ):
1565         return MATH_PRIORITY
1566
1567     if leaf.type in COMPARATORS:
1568         return COMPARATOR_PRIORITY
1569
1570     if (
1571         leaf.type == token.STRING
1572         and previous is not None
1573         and previous.type == token.STRING
1574     ):
1575         return STRING_PRIORITY
1576
1577     if (
1578         leaf.type == token.NAME
1579         and leaf.value == "for"
1580         and leaf.parent
1581         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1582     ):
1583         return COMPREHENSION_PRIORITY
1584
1585     if (
1586         leaf.type == token.NAME
1587         and leaf.value == "if"
1588         and leaf.parent
1589         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1590     ):
1591         return COMPREHENSION_PRIORITY
1592
1593     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1594         return LOGIC_PRIORITY
1595
1596     return 0
1597
1598
1599 def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1600     """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
1601
1602     Higher numbers are higher priority.
1603     """
1604     return max(
1605         is_split_before_delimiter(leaf, previous),
1606         is_split_after_delimiter(leaf, previous),
1607     )
1608
1609
1610 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1611     """Clean the prefix of the `leaf` and generate comments from it, if any.
1612
1613     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1614     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1615     move because it does away with modifying the grammar to include all the
1616     possible places in which comments can be placed.
1617
1618     The sad consequence for us though is that comments don't "belong" anywhere.
1619     This is why this function generates simple parentless Leaf objects for
1620     comments.  We simply don't know what the correct parent should be.
1621
1622     No matter though, we can live without this.  We really only need to
1623     differentiate between inline and standalone comments.  The latter don't
1624     share the line with any code.
1625
1626     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1627     are emitted with a fake STANDALONE_COMMENT token identifier.
1628     """
1629     p = leaf.prefix
1630     if not p:
1631         return
1632
1633     if "#" not in p:
1634         return
1635
1636     consumed = 0
1637     nlines = 0
1638     for index, line in enumerate(p.split("\n")):
1639         consumed += len(line) + 1  # adding the length of the split '\n'
1640         line = line.lstrip()
1641         if not line:
1642             nlines += 1
1643         if not line.startswith("#"):
1644             continue
1645
1646         if index == 0 and leaf.type != token.ENDMARKER:
1647             comment_type = token.COMMENT  # simple trailing comment
1648         else:
1649             comment_type = STANDALONE_COMMENT
1650         comment = make_comment(line)
1651         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1652
1653         if comment in {"# fmt: on", "# yapf: enable"}:
1654             raise FormatOn(consumed)
1655
1656         if comment in {"# fmt: off", "# yapf: disable"}:
1657             if comment_type == STANDALONE_COMMENT:
1658                 raise FormatOff(consumed)
1659
1660             prev = preceding_leaf(leaf)
1661             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1662                 raise FormatOff(consumed)
1663
1664         nlines = 0
1665
1666
1667 def make_comment(content: str) -> str:
1668     """Return a consistently formatted comment from the given `content` string.
1669
1670     All comments (except for "##", "#!", "#:") should have a single space between
1671     the hash sign and the content.
1672
1673     If `content` didn't start with a hash sign, one is provided.
1674     """
1675     content = content.rstrip()
1676     if not content:
1677         return "#"
1678
1679     if content[0] == "#":
1680         content = content[1:]
1681     if content and content[0] not in " !:#":
1682         content = " " + content
1683     return "#" + content
1684
1685
1686 def split_line(
1687     line: Line, line_length: int, inner: bool = False, py36: bool = False
1688 ) -> Iterator[Line]:
1689     """Split a `line` into potentially many lines.
1690
1691     They should fit in the allotted `line_length` but might not be able to.
1692     `inner` signifies that there were a pair of brackets somewhere around the
1693     current `line`, possibly transitively. This means we can fallback to splitting
1694     by delimiters if the LHS/RHS don't yield any results.
1695
1696     If `py36` is True, splitting may generate syntax that is only compatible
1697     with Python 3.6 and later.
1698     """
1699     if isinstance(line, UnformattedLines) or line.is_comment:
1700         yield line
1701         return
1702
1703     line_str = str(line).strip("\n")
1704     if (
1705         len(line_str) <= line_length
1706         and "\n" not in line_str  # multiline strings
1707         and not line.contains_standalone_comments()
1708     ):
1709         yield line
1710         return
1711
1712     split_funcs: List[SplitFunc]
1713     if line.is_def:
1714         split_funcs = [left_hand_split]
1715     elif line.inside_brackets:
1716         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1717     else:
1718         split_funcs = [right_hand_split]
1719     for split_func in split_funcs:
1720         # We are accumulating lines in `result` because we might want to abort
1721         # mission and return the original line in the end, or attempt a different
1722         # split altogether.
1723         result: List[Line] = []
1724         try:
1725             for l in split_func(line, py36):
1726                 if str(l).strip("\n") == line_str:
1727                     raise CannotSplit("Split function returned an unchanged result")
1728
1729                 result.extend(
1730                     split_line(l, line_length=line_length, inner=True, py36=py36)
1731                 )
1732         except CannotSplit as cs:
1733             continue
1734
1735         else:
1736             yield from result
1737             break
1738
1739     else:
1740         yield line
1741
1742
1743 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1744     """Split line into many lines, starting with the first matching bracket pair.
1745
1746     Note: this usually looks weird, only use this for function definitions.
1747     Prefer RHS otherwise.
1748     """
1749     head = Line(depth=line.depth)
1750     body = Line(depth=line.depth + 1, inside_brackets=True)
1751     tail = Line(depth=line.depth)
1752     tail_leaves: List[Leaf] = []
1753     body_leaves: List[Leaf] = []
1754     head_leaves: List[Leaf] = []
1755     current_leaves = head_leaves
1756     matching_bracket = None
1757     for leaf in line.leaves:
1758         if (
1759             current_leaves is body_leaves
1760             and leaf.type in CLOSING_BRACKETS
1761             and leaf.opening_bracket is matching_bracket
1762         ):
1763             current_leaves = tail_leaves if body_leaves else head_leaves
1764         current_leaves.append(leaf)
1765         if current_leaves is head_leaves:
1766             if leaf.type in OPENING_BRACKETS:
1767                 matching_bracket = leaf
1768                 current_leaves = body_leaves
1769     # Since body is a new indent level, remove spurious leading whitespace.
1770     if body_leaves:
1771         normalize_prefix(body_leaves[0], inside_brackets=True)
1772     # Build the new lines.
1773     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1774         for leaf in leaves:
1775             result.append(leaf, preformatted=True)
1776             for comment_after in line.comments_after(leaf):
1777                 result.append(comment_after, preformatted=True)
1778     bracket_split_succeeded_or_raise(head, body, tail)
1779     for result in (head, body, tail):
1780         if result:
1781             yield result
1782
1783
1784 def right_hand_split(
1785     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1786 ) -> Iterator[Line]:
1787     """Split line into many lines, starting with the last matching bracket pair."""
1788     head = Line(depth=line.depth)
1789     body = Line(depth=line.depth + 1, inside_brackets=True)
1790     tail = Line(depth=line.depth)
1791     tail_leaves: List[Leaf] = []
1792     body_leaves: List[Leaf] = []
1793     head_leaves: List[Leaf] = []
1794     current_leaves = tail_leaves
1795     opening_bracket = None
1796     closing_bracket = None
1797     for leaf in reversed(line.leaves):
1798         if current_leaves is body_leaves:
1799             if leaf is opening_bracket:
1800                 current_leaves = head_leaves if body_leaves else tail_leaves
1801         current_leaves.append(leaf)
1802         if current_leaves is tail_leaves:
1803             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1804                 opening_bracket = leaf.opening_bracket
1805                 closing_bracket = leaf
1806                 current_leaves = body_leaves
1807     tail_leaves.reverse()
1808     body_leaves.reverse()
1809     head_leaves.reverse()
1810     # Since body is a new indent level, remove spurious leading whitespace.
1811     if body_leaves:
1812         normalize_prefix(body_leaves[0], inside_brackets=True)
1813     elif not head_leaves:
1814         # No `head` and no `body` means the split failed. `tail` has all content.
1815         raise CannotSplit("No brackets found")
1816
1817     # Build the new lines.
1818     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1819         for leaf in leaves:
1820             result.append(leaf, preformatted=True)
1821             for comment_after in line.comments_after(leaf):
1822                 result.append(comment_after, preformatted=True)
1823     bracket_split_succeeded_or_raise(head, body, tail)
1824     assert opening_bracket and closing_bracket
1825     if (
1826         opening_bracket.type == token.LPAR
1827         and not opening_bracket.value
1828         and closing_bracket.type == token.RPAR
1829         and not closing_bracket.value
1830     ):
1831         # These parens were optional. If there aren't any delimiters or standalone
1832         # comments in the body, they were unnecessary and another split without
1833         # them should be attempted.
1834         if not (
1835             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1836         ):
1837             omit = {id(closing_bracket), *omit}
1838             yield from right_hand_split(line, py36=py36, omit=omit)
1839             return
1840
1841     ensure_visible(opening_bracket)
1842     ensure_visible(closing_bracket)
1843     for result in (head, body, tail):
1844         if result:
1845             yield result
1846
1847
1848 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1849     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1850
1851     Do nothing otherwise.
1852
1853     A left- or right-hand split is based on a pair of brackets. Content before
1854     (and including) the opening bracket is left on one line, content inside the
1855     brackets is put on a separate line, and finally content starting with and
1856     following the closing bracket is put on a separate line.
1857
1858     Those are called `head`, `body`, and `tail`, respectively. If the split
1859     produced the same line (all content in `head`) or ended up with an empty `body`
1860     and the `tail` is just the closing bracket, then it's considered failed.
1861     """
1862     tail_len = len(str(tail).strip())
1863     if not body:
1864         if tail_len == 0:
1865             raise CannotSplit("Splitting brackets produced the same line")
1866
1867         elif tail_len < 3:
1868             raise CannotSplit(
1869                 f"Splitting brackets on an empty body to save "
1870                 f"{tail_len} characters is not worth it"
1871             )
1872
1873
1874 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1875     """Normalize prefix of the first leaf in every line returned by `split_func`.
1876
1877     This is a decorator over relevant split functions.
1878     """
1879
1880     @wraps(split_func)
1881     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1882         for l in split_func(line, py36):
1883             normalize_prefix(l.leaves[0], inside_brackets=True)
1884             yield l
1885
1886     return split_wrapper
1887
1888
1889 @dont_increase_indentation
1890 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1891     """Split according to delimiters of the highest priority.
1892
1893     If `py36` is True, the split will add trailing commas also in function
1894     signatures that contain `*` and `**`.
1895     """
1896     try:
1897         last_leaf = line.leaves[-1]
1898     except IndexError:
1899         raise CannotSplit("Line empty")
1900
1901     delimiters = line.bracket_tracker.delimiters
1902     try:
1903         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1904             exclude={id(last_leaf)}
1905         )
1906     except ValueError:
1907         raise CannotSplit("No delimiters found")
1908
1909     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1910     lowest_depth = sys.maxsize
1911     trailing_comma_safe = True
1912
1913     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1914         """Append `leaf` to current line or to new line if appending impossible."""
1915         nonlocal current_line
1916         try:
1917             current_line.append_safe(leaf, preformatted=True)
1918         except ValueError as ve:
1919             yield current_line
1920
1921             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1922             current_line.append(leaf)
1923
1924     for leaf in line.leaves:
1925         yield from append_to_line(leaf)
1926
1927         for comment_after in line.comments_after(leaf):
1928             yield from append_to_line(comment_after)
1929
1930         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1931         if (
1932             leaf.bracket_depth == lowest_depth
1933             and is_vararg(leaf, within=VARARGS_PARENTS)
1934         ):
1935             trailing_comma_safe = trailing_comma_safe and py36
1936         leaf_priority = delimiters.get(id(leaf))
1937         if leaf_priority == delimiter_priority:
1938             yield current_line
1939
1940             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1941     if current_line:
1942         if (
1943             trailing_comma_safe
1944             and delimiter_priority == COMMA_PRIORITY
1945             and current_line.leaves[-1].type != token.COMMA
1946             and current_line.leaves[-1].type != STANDALONE_COMMENT
1947         ):
1948             current_line.append(Leaf(token.COMMA, ","))
1949         yield current_line
1950
1951
1952 @dont_increase_indentation
1953 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1954     """Split standalone comments from the rest of the line."""
1955     if not line.contains_standalone_comments(0):
1956         raise CannotSplit("Line does not have any standalone comments")
1957
1958     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1959
1960     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1961         """Append `leaf` to current line or to new line if appending impossible."""
1962         nonlocal current_line
1963         try:
1964             current_line.append_safe(leaf, preformatted=True)
1965         except ValueError as ve:
1966             yield current_line
1967
1968             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1969             current_line.append(leaf)
1970
1971     for leaf in line.leaves:
1972         yield from append_to_line(leaf)
1973
1974         for comment_after in line.comments_after(leaf):
1975             yield from append_to_line(comment_after)
1976
1977     if current_line:
1978         yield current_line
1979
1980
1981 def is_import(leaf: Leaf) -> bool:
1982     """Return True if the given leaf starts an import statement."""
1983     p = leaf.parent
1984     t = leaf.type
1985     v = leaf.value
1986     return bool(
1987         t == token.NAME
1988         and (
1989             (v == "import" and p and p.type == syms.import_name)
1990             or (v == "from" and p and p.type == syms.import_from)
1991         )
1992     )
1993
1994
1995 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1996     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1997     else.
1998
1999     Note: don't use backslashes for formatting or you'll lose your voting rights.
2000     """
2001     if not inside_brackets:
2002         spl = leaf.prefix.split("#")
2003         if "\\" not in spl[0]:
2004             nl_count = spl[-1].count("\n")
2005             if len(spl) > 1:
2006                 nl_count -= 1
2007             leaf.prefix = "\n" * nl_count
2008             return
2009
2010     leaf.prefix = ""
2011
2012
2013 def normalize_string_quotes(leaf: Leaf) -> None:
2014     """Prefer double quotes but only if it doesn't cause more escaping.
2015
2016     Adds or removes backslashes as appropriate. Doesn't parse and fix
2017     strings nested in f-strings (yet).
2018
2019     Note: Mutates its argument.
2020     """
2021     value = leaf.value.lstrip("furbFURB")
2022     if value[:3] == '"""':
2023         return
2024
2025     elif value[:3] == "'''":
2026         orig_quote = "'''"
2027         new_quote = '"""'
2028     elif value[0] == '"':
2029         orig_quote = '"'
2030         new_quote = "'"
2031     else:
2032         orig_quote = "'"
2033         new_quote = '"'
2034     first_quote_pos = leaf.value.find(orig_quote)
2035     if first_quote_pos == -1:
2036         return  # There's an internal error
2037
2038     prefix = leaf.value[:first_quote_pos]
2039     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2040     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2041     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2042     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
2043     if "r" in prefix.casefold():
2044         if unescaped_new_quote.search(body):
2045             # There's at least one unescaped new_quote in this raw string
2046             # so converting is impossible
2047             return
2048
2049         # Do not introduce or remove backslashes in raw strings
2050         new_body = body
2051     else:
2052         # remove unnecessary quotes
2053         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2054         if body != new_body:
2055             # Consider the string without unnecessary quotes as the original
2056             body = new_body
2057             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2058         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2059         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2060     if new_quote == '"""' and new_body[-1] == '"':
2061         # edge case:
2062         new_body = new_body[:-1] + '\\"'
2063     orig_escape_count = body.count("\\")
2064     new_escape_count = new_body.count("\\")
2065     if new_escape_count > orig_escape_count:
2066         return  # Do not introduce more escaping
2067
2068     if new_escape_count == orig_escape_count and orig_quote == '"':
2069         return  # Prefer double quotes
2070
2071     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2072
2073
2074 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2075     """Make existing optional parentheses invisible or create new ones.
2076
2077     Standardizes on visible parentheses for single-element tuples, and keeps
2078     existing visible parentheses for other tuples and generator expressions.
2079     """
2080     check_lpar = False
2081     for child in list(node.children):
2082         if check_lpar:
2083             if child.type == syms.atom:
2084                 if not (
2085                     is_empty_tuple(child)
2086                     or is_one_tuple(child)
2087                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2088                 ):
2089                     first = child.children[0]
2090                     last = child.children[-1]
2091                     if first.type == token.LPAR and last.type == token.RPAR:
2092                         # make parentheses invisible
2093                         first.value = ""  # type: ignore
2094                         last.value = ""  # type: ignore
2095             elif is_one_tuple(child):
2096                 # wrap child in visible parentheses
2097                 lpar = Leaf(token.LPAR, "(")
2098                 rpar = Leaf(token.RPAR, ")")
2099                 index = child.remove() or 0
2100                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2101             else:
2102                 # wrap child in invisible parentheses
2103                 lpar = Leaf(token.LPAR, "")
2104                 rpar = Leaf(token.RPAR, "")
2105                 index = child.remove() or 0
2106                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2107
2108         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2109
2110
2111 def is_empty_tuple(node: LN) -> bool:
2112     """Return True if `node` holds an empty tuple."""
2113     return (
2114         node.type == syms.atom
2115         and len(node.children) == 2
2116         and node.children[0].type == token.LPAR
2117         and node.children[1].type == token.RPAR
2118     )
2119
2120
2121 def is_one_tuple(node: LN) -> bool:
2122     """Return True if `node` holds a tuple with one element, with or without parens."""
2123     if node.type == syms.atom:
2124         if len(node.children) != 3:
2125             return False
2126
2127         lpar, gexp, rpar = node.children
2128         if not (
2129             lpar.type == token.LPAR
2130             and gexp.type == syms.testlist_gexp
2131             and rpar.type == token.RPAR
2132         ):
2133             return False
2134
2135         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2136
2137     return (
2138         node.type in IMPLICIT_TUPLE
2139         and len(node.children) == 2
2140         and node.children[1].type == token.COMMA
2141     )
2142
2143
2144 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2145     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2146
2147     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2148     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2149     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2150     generalizations (PEP 448).
2151     """
2152     if leaf.type not in STARS or not leaf.parent:
2153         return False
2154
2155     p = leaf.parent
2156     if p.type == syms.star_expr:
2157         # Star expressions are also used as assignment targets in extended
2158         # iterable unpacking (PEP 3132).  See what its parent is instead.
2159         if not p.parent:
2160             return False
2161
2162         p = p.parent
2163
2164     return p.type in within
2165
2166
2167 def max_delimiter_priority_in_atom(node: LN) -> int:
2168     if node.type != syms.atom:
2169         return 0
2170
2171     first = node.children[0]
2172     last = node.children[-1]
2173     if not (first.type == token.LPAR and last.type == token.RPAR):
2174         return 0
2175
2176     bt = BracketTracker()
2177     for c in node.children[1:-1]:
2178         if isinstance(c, Leaf):
2179             bt.mark(c)
2180         else:
2181             for leaf in c.leaves():
2182                 bt.mark(leaf)
2183     try:
2184         return bt.max_delimiter_priority()
2185
2186     except ValueError:
2187         return 0
2188
2189
2190 def ensure_visible(leaf: Leaf) -> None:
2191     """Make sure parentheses are visible.
2192
2193     They could be invisible as part of some statements (see
2194     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2195     """
2196     if leaf.type == token.LPAR:
2197         leaf.value = "("
2198     elif leaf.type == token.RPAR:
2199         leaf.value = ")"
2200
2201
2202 def is_python36(node: Node) -> bool:
2203     """Return True if the current file is using Python 3.6+ features.
2204
2205     Currently looking for:
2206     - f-strings; and
2207     - trailing commas after * or ** in function signatures.
2208     """
2209     for n in node.pre_order():
2210         if n.type == token.STRING:
2211             value_head = n.value[:2]  # type: ignore
2212             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2213                 return True
2214
2215         elif (
2216             n.type == syms.typedargslist
2217             and n.children
2218             and n.children[-1].type == token.COMMA
2219         ):
2220             for ch in n.children:
2221                 if ch.type in STARS:
2222                     return True
2223
2224     return False
2225
2226
2227 PYTHON_EXTENSIONS = {".py"}
2228 BLACKLISTED_DIRECTORIES = {
2229     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2230 }
2231
2232
2233 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2234     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2235     and have one of the PYTHON_EXTENSIONS.
2236     """
2237     for child in path.iterdir():
2238         if child.is_dir():
2239             if child.name in BLACKLISTED_DIRECTORIES:
2240                 continue
2241
2242             yield from gen_python_files_in_dir(child)
2243
2244         elif child.suffix in PYTHON_EXTENSIONS:
2245             yield child
2246
2247
2248 @dataclass
2249 class Report:
2250     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2251     check: bool = False
2252     quiet: bool = False
2253     change_count: int = 0
2254     same_count: int = 0
2255     failure_count: int = 0
2256
2257     def done(self, src: Path, changed: Changed) -> None:
2258         """Increment the counter for successful reformatting. Write out a message."""
2259         if changed is Changed.YES:
2260             reformatted = "would reformat" if self.check else "reformatted"
2261             if not self.quiet:
2262                 out(f"{reformatted} {src}")
2263             self.change_count += 1
2264         else:
2265             if not self.quiet:
2266                 if changed is Changed.NO:
2267                     msg = f"{src} already well formatted, good job."
2268                 else:
2269                     msg = f"{src} wasn't modified on disk since last run."
2270                 out(msg, bold=False)
2271             self.same_count += 1
2272
2273     def failed(self, src: Path, message: str) -> None:
2274         """Increment the counter for failed reformatting. Write out a message."""
2275         err(f"error: cannot format {src}: {message}")
2276         self.failure_count += 1
2277
2278     @property
2279     def return_code(self) -> int:
2280         """Return the exit code that the app should use.
2281
2282         This considers the current state of changed files and failures:
2283         - if there were any failures, return 123;
2284         - if any files were changed and --check is being used, return 1;
2285         - otherwise return 0.
2286         """
2287         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2288         # 126 we have special returncodes reserved by the shell.
2289         if self.failure_count:
2290             return 123
2291
2292         elif self.change_count and self.check:
2293             return 1
2294
2295         return 0
2296
2297     def __str__(self) -> str:
2298         """Render a color report of the current state.
2299
2300         Use `click.unstyle` to remove colors.
2301         """
2302         if self.check:
2303             reformatted = "would be reformatted"
2304             unchanged = "would be left unchanged"
2305             failed = "would fail to reformat"
2306         else:
2307             reformatted = "reformatted"
2308             unchanged = "left unchanged"
2309             failed = "failed to reformat"
2310         report = []
2311         if self.change_count:
2312             s = "s" if self.change_count > 1 else ""
2313             report.append(
2314                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2315             )
2316         if self.same_count:
2317             s = "s" if self.same_count > 1 else ""
2318             report.append(f"{self.same_count} file{s} {unchanged}")
2319         if self.failure_count:
2320             s = "s" if self.failure_count > 1 else ""
2321             report.append(
2322                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2323             )
2324         return ", ".join(report) + "."
2325
2326
2327 def assert_equivalent(src: str, dst: str) -> None:
2328     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2329
2330     import ast
2331     import traceback
2332
2333     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2334         """Simple visitor generating strings to compare ASTs by content."""
2335         yield f"{'  ' * depth}{node.__class__.__name__}("
2336
2337         for field in sorted(node._fields):
2338             try:
2339                 value = getattr(node, field)
2340             except AttributeError:
2341                 continue
2342
2343             yield f"{'  ' * (depth+1)}{field}="
2344
2345             if isinstance(value, list):
2346                 for item in value:
2347                     if isinstance(item, ast.AST):
2348                         yield from _v(item, depth + 2)
2349
2350             elif isinstance(value, ast.AST):
2351                 yield from _v(value, depth + 2)
2352
2353             else:
2354                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2355
2356         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2357
2358     try:
2359         src_ast = ast.parse(src)
2360     except Exception as exc:
2361         major, minor = sys.version_info[:2]
2362         raise AssertionError(
2363             f"cannot use --safe with this file; failed to parse source file "
2364             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2365             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2366         )
2367
2368     try:
2369         dst_ast = ast.parse(dst)
2370     except Exception as exc:
2371         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2372         raise AssertionError(
2373             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2374             f"Please report a bug on https://github.com/ambv/black/issues.  "
2375             f"This invalid output might be helpful: {log}"
2376         ) from None
2377
2378     src_ast_str = "\n".join(_v(src_ast))
2379     dst_ast_str = "\n".join(_v(dst_ast))
2380     if src_ast_str != dst_ast_str:
2381         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2382         raise AssertionError(
2383             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2384             f"the source.  "
2385             f"Please report a bug on https://github.com/ambv/black/issues.  "
2386             f"This diff might be helpful: {log}"
2387         ) from None
2388
2389
2390 def assert_stable(src: str, dst: str, line_length: int) -> None:
2391     """Raise AssertionError if `dst` reformats differently the second time."""
2392     newdst = format_str(dst, line_length=line_length)
2393     if dst != newdst:
2394         log = dump_to_file(
2395             diff(src, dst, "source", "first pass"),
2396             diff(dst, newdst, "first pass", "second pass"),
2397         )
2398         raise AssertionError(
2399             f"INTERNAL ERROR: Black produced different code on the second pass "
2400             f"of the formatter.  "
2401             f"Please report a bug on https://github.com/ambv/black/issues.  "
2402             f"This diff might be helpful: {log}"
2403         ) from None
2404
2405
2406 def dump_to_file(*output: str) -> str:
2407     """Dump `output` to a temporary file. Return path to the file."""
2408     import tempfile
2409
2410     with tempfile.NamedTemporaryFile(
2411         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2412     ) as f:
2413         for lines in output:
2414             f.write(lines)
2415             if lines and lines[-1] != "\n":
2416                 f.write("\n")
2417     return f.name
2418
2419
2420 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2421     """Return a unified diff string between strings `a` and `b`."""
2422     import difflib
2423
2424     a_lines = [line + "\n" for line in a.split("\n")]
2425     b_lines = [line + "\n" for line in b.split("\n")]
2426     return "".join(
2427         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2428     )
2429
2430
2431 def cancel(tasks: List[asyncio.Task]) -> None:
2432     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2433     err("Aborted!")
2434     for task in tasks:
2435         task.cancel()
2436
2437
2438 def shutdown(loop: BaseEventLoop) -> None:
2439     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2440     try:
2441         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2442         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2443         if not to_cancel:
2444             return
2445
2446         for task in to_cancel:
2447             task.cancel()
2448         loop.run_until_complete(
2449             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2450         )
2451     finally:
2452         # `concurrent.futures.Future` objects cannot be cancelled once they
2453         # are already running. There might be some when the `shutdown()` happened.
2454         # Silence their logger's spew about the event loop being closed.
2455         cf_logger = logging.getLogger("concurrent.futures")
2456         cf_logger.setLevel(logging.CRITICAL)
2457         loop.close()
2458
2459
2460 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2461     """Replace `regex` with `replacement` twice on `original`.
2462
2463     This is used by string normalization to perform replaces on
2464     overlapping matches.
2465     """
2466     return regex.sub(replacement, regex.sub(replacement, original))
2467
2468
2469 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2470 CACHE_FILE = CACHE_DIR / "cache.pickle"
2471
2472
2473 def read_cache() -> Cache:
2474     """Read the cache if it exists and is well formed.
2475
2476     If it is not well formed, the call to write_cache later should resolve the issue.
2477     """
2478     if not CACHE_FILE.exists():
2479         return {}
2480
2481     with CACHE_FILE.open("rb") as fobj:
2482         try:
2483             cache: Cache = pickle.load(fobj)
2484         except pickle.UnpicklingError:
2485             return {}
2486
2487     return cache
2488
2489
2490 def get_cache_info(path: Path) -> CacheInfo:
2491     """Return the information used to check if a file is already formatted or not."""
2492     stat = path.stat()
2493     return stat.st_mtime, stat.st_size
2494
2495
2496 def filter_cached(
2497     cache: Cache, sources: Iterable[Path]
2498 ) -> Tuple[List[Path], List[Path]]:
2499     """Split a list of paths into two.
2500
2501     The first list contains paths of files that modified on disk or are not in the
2502     cache. The other list contains paths to non-modified files.
2503     """
2504     todo, done = [], []
2505     for src in sources:
2506         src = src.resolve()
2507         if cache.get(src) != get_cache_info(src):
2508             todo.append(src)
2509         else:
2510             done.append(src)
2511     return todo, done
2512
2513
2514 def write_cache(cache: Cache, sources: List[Path]) -> None:
2515     """Update the cache file."""
2516     try:
2517         if not CACHE_DIR.exists():
2518             CACHE_DIR.mkdir(parents=True)
2519         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2520         with CACHE_FILE.open("wb") as fobj:
2521             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2522     except OSError:
2523         pass
2524
2525
2526 if __name__ == "__main__":
2527     main()