]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

fixed cache file location in readme (#150)
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 import pickle
5 from asyncio.base_events import BaseEventLoop
6 from concurrent.futures import Executor, ProcessPoolExecutor
7 from enum import Enum
8 from functools import partial, wraps
9 import keyword
10 import logging
11 from multiprocessing import Manager
12 import os
13 from pathlib import Path
14 import re
15 import tokenize
16 import signal
17 import sys
18 from typing import (
19     Any,
20     Callable,
21     Collection,
22     Dict,
23     Generic,
24     Iterable,
25     Iterator,
26     List,
27     Optional,
28     Pattern,
29     Set,
30     Tuple,
31     Type,
32     TypeVar,
33     Union,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46 __version__ = "18.4a2"
47 DEFAULT_LINE_LENGTH = 88
48 # types
49 syms = pygram.python_symbols
50 FileContent = str
51 Encoding = str
52 Depth = int
53 NodeType = int
54 LeafID = int
55 Priority = int
56 Index = int
57 LN = Union[Leaf, Node]
58 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
59 Timestamp = float
60 FileSize = int
61 CacheInfo = Tuple[Timestamp, FileSize]
62 Cache = Dict[Path, CacheInfo]
63 out = partial(click.secho, bold=True, err=True)
64 err = partial(click.secho, fg="red", err=True)
65
66
67 class NothingChanged(UserWarning):
68     """Raised by :func:`format_file` when reformatted code is the same as source."""
69
70
71 class CannotSplit(Exception):
72     """A readable split that fits the allotted line length is impossible.
73
74     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
75     :func:`delimiter_split`.
76     """
77
78
79 class FormatError(Exception):
80     """Base exception for `# fmt: on` and `# fmt: off` handling.
81
82     It holds the number of bytes of the prefix consumed before the format
83     control comment appeared.
84     """
85
86     def __init__(self, consumed: int) -> None:
87         super().__init__(consumed)
88         self.consumed = consumed
89
90     def trim_prefix(self, leaf: Leaf) -> None:
91         leaf.prefix = leaf.prefix[self.consumed:]
92
93     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
94         """Returns a new Leaf from the consumed part of the prefix."""
95         unformatted_prefix = leaf.prefix[:self.consumed]
96         return Leaf(token.NEWLINE, unformatted_prefix)
97
98
99 class FormatOn(FormatError):
100     """Found a comment like `# fmt: on` in the file."""
101
102
103 class FormatOff(FormatError):
104     """Found a comment like `# fmt: off` in the file."""
105
106
107 class WriteBack(Enum):
108     NO = 0
109     YES = 1
110     DIFF = 2
111
112
113 class Changed(Enum):
114     NO = 0
115     CACHED = 1
116     YES = 2
117
118
119 @click.command()
120 @click.option(
121     "-l",
122     "--line-length",
123     type=int,
124     default=DEFAULT_LINE_LENGTH,
125     help="How many character per line to allow.",
126     show_default=True,
127 )
128 @click.option(
129     "--check",
130     is_flag=True,
131     help=(
132         "Don't write the files back, just return the status.  Return code 0 "
133         "means nothing would change.  Return code 1 means some files would be "
134         "reformatted.  Return code 123 means there was an internal error."
135     ),
136 )
137 @click.option(
138     "--diff",
139     is_flag=True,
140     help="Don't write the files back, just output a diff for each file on stdout.",
141 )
142 @click.option(
143     "--fast/--safe",
144     is_flag=True,
145     help="If --fast given, skip temporary sanity checks. [default: --safe]",
146 )
147 @click.option(
148     "-q",
149     "--quiet",
150     is_flag=True,
151     help=(
152         "Don't emit non-error messages to stderr. Errors are still emitted, "
153         "silence those with 2>/dev/null."
154     ),
155 )
156 @click.version_option(version=__version__)
157 @click.argument(
158     "src",
159     nargs=-1,
160     type=click.Path(
161         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
162     ),
163 )
164 @click.pass_context
165 def main(
166     ctx: click.Context,
167     line_length: int,
168     check: bool,
169     diff: bool,
170     fast: bool,
171     quiet: bool,
172     src: List[str],
173 ) -> None:
174     """The uncompromising code formatter."""
175     sources: List[Path] = []
176     for s in src:
177         p = Path(s)
178         if p.is_dir():
179             sources.extend(gen_python_files_in_dir(p))
180         elif p.is_file():
181             # if a file was explicitly given, we don't care about its extension
182             sources.append(p)
183         elif s == "-":
184             sources.append(Path("-"))
185         else:
186             err(f"invalid path: {s}")
187     if check and diff:
188         exc = click.ClickException("Options --check and --diff are mutually exclusive")
189         exc.exit_code = 2
190         raise exc
191
192     if check:
193         write_back = WriteBack.NO
194     elif diff:
195         write_back = WriteBack.DIFF
196     else:
197         write_back = WriteBack.YES
198     if len(sources) == 0:
199         ctx.exit(0)
200         return
201
202     elif len(sources) == 1:
203         return_code = reformat_one(sources[0], line_length, fast, quiet, write_back)
204     else:
205         loop = asyncio.get_event_loop()
206         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
207         return_code = 1
208         try:
209             return_code = loop.run_until_complete(
210                 schedule_formatting(
211                     sources, line_length, write_back, fast, quiet, loop, executor
212                 )
213             )
214         finally:
215             shutdown(loop)
216     ctx.exit(return_code)
217
218
219 def reformat_one(
220     src: Path, line_length: int, fast: bool, quiet: bool, write_back: WriteBack
221 ) -> int:
222     """Reformat a single file under `src` without spawning child processes.
223
224     If `quiet` is True, non-error messages are not output. `line_length`,
225     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
226     """
227     report = Report(check=write_back is WriteBack.NO, quiet=quiet)
228     try:
229         changed = Changed.NO
230         if not src.is_file() and str(src) == "-":
231             if format_stdin_to_stdout(
232                 line_length=line_length, fast=fast, write_back=write_back
233             ):
234                 changed = Changed.YES
235         else:
236             cache: Cache = {}
237             if write_back != WriteBack.DIFF:
238                 cache = read_cache()
239                 src = src.resolve()
240                 if src in cache and cache[src] == get_cache_info(src):
241                     changed = Changed.CACHED
242             if (
243                 changed is not Changed.CACHED
244                 and format_file_in_place(
245                     src, line_length=line_length, fast=fast, write_back=write_back
246                 )
247             ):
248                 changed = Changed.YES
249             if write_back != WriteBack.DIFF and changed is not Changed.NO:
250                 write_cache(cache, [src])
251         report.done(src, changed)
252     except Exception as exc:
253         report.failed(src, str(exc))
254     return report.return_code
255
256
257 async def schedule_formatting(
258     sources: List[Path],
259     line_length: int,
260     write_back: WriteBack,
261     fast: bool,
262     quiet: bool,
263     loop: BaseEventLoop,
264     executor: Executor,
265 ) -> int:
266     """Run formatting of `sources` in parallel using the provided `executor`.
267
268     (Use ProcessPoolExecutors for actual parallelism.)
269
270     `line_length`, `write_back`, and `fast` options are passed to
271     :func:`format_file_in_place`.
272     """
273     report = Report(check=write_back is WriteBack.NO, quiet=quiet)
274     cache: Cache = {}
275     if write_back != WriteBack.DIFF:
276         cache = read_cache()
277         sources, cached = filter_cached(cache, sources)
278         for src in cached:
279             report.done(src, Changed.CACHED)
280     cancelled = []
281     formatted = []
282     if sources:
283         lock = None
284         if write_back == WriteBack.DIFF:
285             # For diff output, we need locks to ensure we don't interleave output
286             # from different processes.
287             manager = Manager()
288             lock = manager.Lock()
289         tasks = {
290             src: loop.run_in_executor(
291                 executor, format_file_in_place, src, line_length, fast, write_back, lock
292             )
293             for src in sources
294         }
295         _task_values = list(tasks.values())
296         loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
297         loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
298         await asyncio.wait(_task_values)
299         for src, task in tasks.items():
300             if not task.done():
301                 report.failed(src, "timed out, cancelling")
302                 task.cancel()
303                 cancelled.append(task)
304             elif task.cancelled():
305                 cancelled.append(task)
306             elif task.exception():
307                 report.failed(src, str(task.exception()))
308             else:
309                 formatted.append(src)
310                 report.done(src, Changed.YES if task.result() else Changed.NO)
311
312     if cancelled:
313         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
314     elif not quiet:
315         out("All done! ✨ 🍰 ✨")
316     if not quiet:
317         click.echo(str(report))
318
319     if write_back != WriteBack.DIFF and formatted:
320         write_cache(cache, formatted)
321
322     return report.return_code
323
324
325 def format_file_in_place(
326     src: Path,
327     line_length: int,
328     fast: bool,
329     write_back: WriteBack = WriteBack.NO,
330     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
331 ) -> bool:
332     """Format file under `src` path. Return True if changed.
333
334     If `write_back` is True, write reformatted code back to stdout.
335     `line_length` and `fast` options are passed to :func:`format_file_contents`.
336     """
337
338     with tokenize.open(src) as src_buffer:
339         src_contents = src_buffer.read()
340     try:
341         dst_contents = format_file_contents(
342             src_contents, line_length=line_length, fast=fast
343         )
344     except NothingChanged:
345         return False
346
347     if write_back == write_back.YES:
348         with open(src, "w", encoding=src_buffer.encoding) as f:
349             f.write(dst_contents)
350     elif write_back == write_back.DIFF:
351         src_name = f"{src.name}  (original)"
352         dst_name = f"{src.name}  (formatted)"
353         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
354         if lock:
355             lock.acquire()
356         try:
357             sys.stdout.write(diff_contents)
358         finally:
359             if lock:
360                 lock.release()
361     return True
362
363
364 def format_stdin_to_stdout(
365     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
366 ) -> bool:
367     """Format file on stdin. Return True if changed.
368
369     If `write_back` is True, write reformatted code back to stdout.
370     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
371     """
372     src = sys.stdin.read()
373     dst = src
374     try:
375         dst = format_file_contents(src, line_length=line_length, fast=fast)
376         return True
377
378     except NothingChanged:
379         return False
380
381     finally:
382         if write_back == WriteBack.YES:
383             sys.stdout.write(dst)
384         elif write_back == WriteBack.DIFF:
385             src_name = "<stdin>  (original)"
386             dst_name = "<stdin>  (formatted)"
387             sys.stdout.write(diff(src, dst, src_name, dst_name))
388
389
390 def format_file_contents(
391     src_contents: str, line_length: int, fast: bool
392 ) -> FileContent:
393     """Reformat contents a file and return new contents.
394
395     If `fast` is False, additionally confirm that the reformatted code is
396     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
397     `line_length` is passed to :func:`format_str`.
398     """
399     if src_contents.strip() == "":
400         raise NothingChanged
401
402     dst_contents = format_str(src_contents, line_length=line_length)
403     if src_contents == dst_contents:
404         raise NothingChanged
405
406     if not fast:
407         assert_equivalent(src_contents, dst_contents)
408         assert_stable(src_contents, dst_contents, line_length=line_length)
409     return dst_contents
410
411
412 def format_str(src_contents: str, line_length: int) -> FileContent:
413     """Reformat a string and return new contents.
414
415     `line_length` determines how many characters per line are allowed.
416     """
417     src_node = lib2to3_parse(src_contents)
418     dst_contents = ""
419     lines = LineGenerator()
420     elt = EmptyLineTracker()
421     py36 = is_python36(src_node)
422     empty_line = Line()
423     after = 0
424     for current_line in lines.visit(src_node):
425         for _ in range(after):
426             dst_contents += str(empty_line)
427         before, after = elt.maybe_empty_lines(current_line)
428         for _ in range(before):
429             dst_contents += str(empty_line)
430         for line in split_line(current_line, line_length=line_length, py36=py36):
431             dst_contents += str(line)
432     return dst_contents
433
434
435 GRAMMARS = [
436     pygram.python_grammar_no_print_statement_no_exec_statement,
437     pygram.python_grammar_no_print_statement,
438     pygram.python_grammar_no_exec_statement,
439     pygram.python_grammar,
440 ]
441
442
443 def lib2to3_parse(src_txt: str) -> Node:
444     """Given a string with source, return the lib2to3 Node."""
445     grammar = pygram.python_grammar_no_print_statement
446     if src_txt[-1] != "\n":
447         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
448         src_txt += nl
449     for grammar in GRAMMARS:
450         drv = driver.Driver(grammar, pytree.convert)
451         try:
452             result = drv.parse_string(src_txt, True)
453             break
454
455         except ParseError as pe:
456             lineno, column = pe.context[1]
457             lines = src_txt.splitlines()
458             try:
459                 faulty_line = lines[lineno - 1]
460             except IndexError:
461                 faulty_line = "<line number missing in source>"
462             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
463     else:
464         raise exc from None
465
466     if isinstance(result, Leaf):
467         result = Node(syms.file_input, [result])
468     return result
469
470
471 def lib2to3_unparse(node: Node) -> str:
472     """Given a lib2to3 node, return its string representation."""
473     code = str(node)
474     return code
475
476
477 T = TypeVar("T")
478
479
480 class Visitor(Generic[T]):
481     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
482
483     def visit(self, node: LN) -> Iterator[T]:
484         """Main method to visit `node` and its children.
485
486         It tries to find a `visit_*()` method for the given `node.type`, like
487         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
488         If no dedicated `visit_*()` method is found, chooses `visit_default()`
489         instead.
490
491         Then yields objects of type `T` from the selected visitor.
492         """
493         if node.type < 256:
494             name = token.tok_name[node.type]
495         else:
496             name = type_repr(node.type)
497         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
498
499     def visit_default(self, node: LN) -> Iterator[T]:
500         """Default `visit_*()` implementation. Recurses to children of `node`."""
501         if isinstance(node, Node):
502             for child in node.children:
503                 yield from self.visit(child)
504
505
506 @dataclass
507 class DebugVisitor(Visitor[T]):
508     tree_depth: int = 0
509
510     def visit_default(self, node: LN) -> Iterator[T]:
511         indent = " " * (2 * self.tree_depth)
512         if isinstance(node, Node):
513             _type = type_repr(node.type)
514             out(f"{indent}{_type}", fg="yellow")
515             self.tree_depth += 1
516             for child in node.children:
517                 yield from self.visit(child)
518
519             self.tree_depth -= 1
520             out(f"{indent}/{_type}", fg="yellow", bold=False)
521         else:
522             _type = token.tok_name.get(node.type, str(node.type))
523             out(f"{indent}{_type}", fg="blue", nl=False)
524             if node.prefix:
525                 # We don't have to handle prefixes for `Node` objects since
526                 # that delegates to the first child anyway.
527                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
528             out(f" {node.value!r}", fg="blue", bold=False)
529
530     @classmethod
531     def show(cls, code: str) -> None:
532         """Pretty-print the lib2to3 AST of a given string of `code`.
533
534         Convenience method for debugging.
535         """
536         v: DebugVisitor[None] = DebugVisitor()
537         list(v.visit(lib2to3_parse(code)))
538
539
540 KEYWORDS = set(keyword.kwlist)
541 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
542 FLOW_CONTROL = {"return", "raise", "break", "continue"}
543 STATEMENT = {
544     syms.if_stmt,
545     syms.while_stmt,
546     syms.for_stmt,
547     syms.try_stmt,
548     syms.except_clause,
549     syms.with_stmt,
550     syms.funcdef,
551     syms.classdef,
552 }
553 STANDALONE_COMMENT = 153
554 LOGIC_OPERATORS = {"and", "or"}
555 COMPARATORS = {
556     token.LESS,
557     token.GREATER,
558     token.EQEQUAL,
559     token.NOTEQUAL,
560     token.LESSEQUAL,
561     token.GREATEREQUAL,
562 }
563 MATH_OPERATORS = {
564     token.PLUS,
565     token.MINUS,
566     token.STAR,
567     token.SLASH,
568     token.VBAR,
569     token.AMPER,
570     token.PERCENT,
571     token.CIRCUMFLEX,
572     token.TILDE,
573     token.LEFTSHIFT,
574     token.RIGHTSHIFT,
575     token.DOUBLESTAR,
576     token.DOUBLESLASH,
577 }
578 STARS = {token.STAR, token.DOUBLESTAR}
579 VARARGS_PARENTS = {
580     syms.arglist,
581     syms.argument,  # double star in arglist
582     syms.trailer,  # single argument to call
583     syms.typedargslist,
584     syms.varargslist,  # lambdas
585 }
586 UNPACKING_PARENTS = {
587     syms.atom,  # single element of a list or set literal
588     syms.dictsetmaker,
589     syms.listmaker,
590     syms.testlist_gexp,
591 }
592 COMPREHENSION_PRIORITY = 20
593 COMMA_PRIORITY = 10
594 LOGIC_PRIORITY = 5
595 STRING_PRIORITY = 4
596 COMPARATOR_PRIORITY = 3
597 MATH_PRIORITY = 1
598
599
600 @dataclass
601 class BracketTracker:
602     """Keeps track of brackets on a line."""
603
604     depth: int = 0
605     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
606     delimiters: Dict[LeafID, Priority] = Factory(dict)
607     previous: Optional[Leaf] = None
608
609     def mark(self, leaf: Leaf) -> None:
610         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
611
612         All leaves receive an int `bracket_depth` field that stores how deep
613         within brackets a given leaf is. 0 means there are no enclosing brackets
614         that started on this line.
615
616         If a leaf is itself a closing bracket, it receives an `opening_bracket`
617         field that it forms a pair with. This is a one-directional link to
618         avoid reference cycles.
619
620         If a leaf is a delimiter (a token on which Black can split the line if
621         needed) and it's on depth 0, its `id()` is stored in the tracker's
622         `delimiters` field.
623         """
624         if leaf.type == token.COMMENT:
625             return
626
627         if leaf.type in CLOSING_BRACKETS:
628             self.depth -= 1
629             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
630             leaf.opening_bracket = opening_bracket
631         leaf.bracket_depth = self.depth
632         if self.depth == 0:
633             delim = is_split_before_delimiter(leaf, self.previous)
634             if delim and self.previous is not None:
635                 self.delimiters[id(self.previous)] = delim
636             else:
637                 delim = is_split_after_delimiter(leaf, self.previous)
638                 if delim:
639                     self.delimiters[id(leaf)] = delim
640         if leaf.type in OPENING_BRACKETS:
641             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
642             self.depth += 1
643         self.previous = leaf
644
645     def any_open_brackets(self) -> bool:
646         """Return True if there is an yet unmatched open bracket on the line."""
647         return bool(self.bracket_match)
648
649     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
650         """Return the highest priority of a delimiter found on the line.
651
652         Values are consistent with what `is_split_*_delimiter()` return.
653         Raises ValueError on no delimiters.
654         """
655         return max(v for k, v in self.delimiters.items() if k not in exclude)
656
657
658 @dataclass
659 class Line:
660     """Holds leaves and comments. Can be printed with `str(line)`."""
661
662     depth: int = 0
663     leaves: List[Leaf] = Factory(list)
664     comments: List[Tuple[Index, Leaf]] = Factory(list)
665     bracket_tracker: BracketTracker = Factory(BracketTracker)
666     inside_brackets: bool = False
667     has_for: bool = False
668     _for_loop_variable: bool = False
669
670     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
671         """Add a new `leaf` to the end of the line.
672
673         Unless `preformatted` is True, the `leaf` will receive a new consistent
674         whitespace prefix and metadata applied by :class:`BracketTracker`.
675         Trailing commas are maybe removed, unpacked for loop variables are
676         demoted from being delimiters.
677
678         Inline comments are put aside.
679         """
680         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
681         if not has_value:
682             return
683
684         if self.leaves and not preformatted:
685             # Note: at this point leaf.prefix should be empty except for
686             # imports, for which we only preserve newlines.
687             leaf.prefix += whitespace(leaf)
688         if self.inside_brackets or not preformatted:
689             self.maybe_decrement_after_for_loop_variable(leaf)
690             self.bracket_tracker.mark(leaf)
691             self.maybe_remove_trailing_comma(leaf)
692             self.maybe_increment_for_loop_variable(leaf)
693
694         if not self.append_comment(leaf):
695             self.leaves.append(leaf)
696
697     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
698         """Like :func:`append()` but disallow invalid standalone comment structure.
699
700         Raises ValueError when any `leaf` is appended after a standalone comment
701         or when a standalone comment is not the first leaf on the line.
702         """
703         if self.bracket_tracker.depth == 0:
704             if self.is_comment:
705                 raise ValueError("cannot append to standalone comments")
706
707             if self.leaves and leaf.type == STANDALONE_COMMENT:
708                 raise ValueError(
709                     "cannot append standalone comments to a populated line"
710                 )
711
712         self.append(leaf, preformatted=preformatted)
713
714     @property
715     def is_comment(self) -> bool:
716         """Is this line a standalone comment?"""
717         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
718
719     @property
720     def is_decorator(self) -> bool:
721         """Is this line a decorator?"""
722         return bool(self) and self.leaves[0].type == token.AT
723
724     @property
725     def is_import(self) -> bool:
726         """Is this an import line?"""
727         return bool(self) and is_import(self.leaves[0])
728
729     @property
730     def is_class(self) -> bool:
731         """Is this line a class definition?"""
732         return (
733             bool(self)
734             and self.leaves[0].type == token.NAME
735             and self.leaves[0].value == "class"
736         )
737
738     @property
739     def is_def(self) -> bool:
740         """Is this a function definition? (Also returns True for async defs.)"""
741         try:
742             first_leaf = self.leaves[0]
743         except IndexError:
744             return False
745
746         try:
747             second_leaf: Optional[Leaf] = self.leaves[1]
748         except IndexError:
749             second_leaf = None
750         return (
751             (first_leaf.type == token.NAME and first_leaf.value == "def")
752             or (
753                 first_leaf.type == token.ASYNC
754                 and second_leaf is not None
755                 and second_leaf.type == token.NAME
756                 and second_leaf.value == "def"
757             )
758         )
759
760     @property
761     def is_flow_control(self) -> bool:
762         """Is this line a flow control statement?
763
764         Those are `return`, `raise`, `break`, and `continue`.
765         """
766         return (
767             bool(self)
768             and self.leaves[0].type == token.NAME
769             and self.leaves[0].value in FLOW_CONTROL
770         )
771
772     @property
773     def is_yield(self) -> bool:
774         """Is this line a yield statement?"""
775         return (
776             bool(self)
777             and self.leaves[0].type == token.NAME
778             and self.leaves[0].value == "yield"
779         )
780
781     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
782         """If so, needs to be split before emitting."""
783         for leaf in self.leaves:
784             if leaf.type == STANDALONE_COMMENT:
785                 if leaf.bracket_depth <= depth_limit:
786                     return True
787
788         return False
789
790     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
791         """Remove trailing comma if there is one and it's safe."""
792         if not (
793             self.leaves
794             and self.leaves[-1].type == token.COMMA
795             and closing.type in CLOSING_BRACKETS
796         ):
797             return False
798
799         if closing.type == token.RBRACE:
800             self.remove_trailing_comma()
801             return True
802
803         if closing.type == token.RSQB:
804             comma = self.leaves[-1]
805             if comma.parent and comma.parent.type == syms.listmaker:
806                 self.remove_trailing_comma()
807                 return True
808
809         # For parens let's check if it's safe to remove the comma.  If the
810         # trailing one is the only one, we might mistakenly change a tuple
811         # into a different type by removing the comma.
812         depth = closing.bracket_depth + 1
813         commas = 0
814         opening = closing.opening_bracket
815         for _opening_index, leaf in enumerate(self.leaves):
816             if leaf is opening:
817                 break
818
819         else:
820             return False
821
822         for leaf in self.leaves[_opening_index + 1:]:
823             if leaf is closing:
824                 break
825
826             bracket_depth = leaf.bracket_depth
827             if bracket_depth == depth and leaf.type == token.COMMA:
828                 commas += 1
829                 if leaf.parent and leaf.parent.type == syms.arglist:
830                     commas += 1
831                     break
832
833         if commas > 1:
834             self.remove_trailing_comma()
835             return True
836
837         return False
838
839     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
840         """In a for loop, or comprehension, the variables are often unpacks.
841
842         To avoid splitting on the comma in this situation, increase the depth of
843         tokens between `for` and `in`.
844         """
845         if leaf.type == token.NAME and leaf.value == "for":
846             self.has_for = True
847             self.bracket_tracker.depth += 1
848             self._for_loop_variable = True
849             return True
850
851         return False
852
853     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
854         """See `maybe_increment_for_loop_variable` above for explanation."""
855         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
856             self.bracket_tracker.depth -= 1
857             self._for_loop_variable = False
858             return True
859
860         return False
861
862     def append_comment(self, comment: Leaf) -> bool:
863         """Add an inline or standalone comment to the line."""
864         if (
865             comment.type == STANDALONE_COMMENT
866             and self.bracket_tracker.any_open_brackets()
867         ):
868             comment.prefix = ""
869             return False
870
871         if comment.type != token.COMMENT:
872             return False
873
874         after = len(self.leaves) - 1
875         if after == -1:
876             comment.type = STANDALONE_COMMENT
877             comment.prefix = ""
878             return False
879
880         else:
881             self.comments.append((after, comment))
882             return True
883
884     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
885         """Generate comments that should appear directly after `leaf`."""
886         for _leaf_index, _leaf in enumerate(self.leaves):
887             if leaf is _leaf:
888                 break
889
890         else:
891             return
892
893         for index, comment_after in self.comments:
894             if _leaf_index == index:
895                 yield comment_after
896
897     def remove_trailing_comma(self) -> None:
898         """Remove the trailing comma and moves the comments attached to it."""
899         comma_index = len(self.leaves) - 1
900         for i in range(len(self.comments)):
901             comment_index, comment = self.comments[i]
902             if comment_index == comma_index:
903                 self.comments[i] = (comma_index - 1, comment)
904         self.leaves.pop()
905
906     def __str__(self) -> str:
907         """Render the line."""
908         if not self:
909             return "\n"
910
911         indent = "    " * self.depth
912         leaves = iter(self.leaves)
913         first = next(leaves)
914         res = f"{first.prefix}{indent}{first.value}"
915         for leaf in leaves:
916             res += str(leaf)
917         for _, comment in self.comments:
918             res += str(comment)
919         return res + "\n"
920
921     def __bool__(self) -> bool:
922         """Return True if the line has leaves or comments."""
923         return bool(self.leaves or self.comments)
924
925
926 class UnformattedLines(Line):
927     """Just like :class:`Line` but stores lines which aren't reformatted."""
928
929     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
930         """Just add a new `leaf` to the end of the lines.
931
932         The `preformatted` argument is ignored.
933
934         Keeps track of indentation `depth`, which is useful when the user
935         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
936         """
937         try:
938             list(generate_comments(leaf))
939         except FormatOn as f_on:
940             self.leaves.append(f_on.leaf_from_consumed(leaf))
941             raise
942
943         self.leaves.append(leaf)
944         if leaf.type == token.INDENT:
945             self.depth += 1
946         elif leaf.type == token.DEDENT:
947             self.depth -= 1
948
949     def __str__(self) -> str:
950         """Render unformatted lines from leaves which were added with `append()`.
951
952         `depth` is not used for indentation in this case.
953         """
954         if not self:
955             return "\n"
956
957         res = ""
958         for leaf in self.leaves:
959             res += str(leaf)
960         return res
961
962     def append_comment(self, comment: Leaf) -> bool:
963         """Not implemented in this class. Raises `NotImplementedError`."""
964         raise NotImplementedError("Unformatted lines don't store comments separately.")
965
966     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
967         """Does nothing and returns False."""
968         return False
969
970     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
971         """Does nothing and returns False."""
972         return False
973
974
975 @dataclass
976 class EmptyLineTracker:
977     """Provides a stateful method that returns the number of potential extra
978     empty lines needed before and after the currently processed line.
979
980     Note: this tracker works on lines that haven't been split yet.  It assumes
981     the prefix of the first leaf consists of optional newlines.  Those newlines
982     are consumed by `maybe_empty_lines()` and included in the computation.
983     """
984     previous_line: Optional[Line] = None
985     previous_after: int = 0
986     previous_defs: List[int] = Factory(list)
987
988     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
989         """Return the number of extra empty lines before and after the `current_line`.
990
991         This is for separating `def`, `async def` and `class` with extra empty
992         lines (two on module-level), as well as providing an extra empty line
993         after flow control keywords to make them more prominent.
994         """
995         if isinstance(current_line, UnformattedLines):
996             return 0, 0
997
998         before, after = self._maybe_empty_lines(current_line)
999         before -= self.previous_after
1000         self.previous_after = after
1001         self.previous_line = current_line
1002         return before, after
1003
1004     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1005         max_allowed = 1
1006         if current_line.depth == 0:
1007             max_allowed = 2
1008         if current_line.leaves:
1009             # Consume the first leaf's extra newlines.
1010             first_leaf = current_line.leaves[0]
1011             before = first_leaf.prefix.count("\n")
1012             before = min(before, max_allowed)
1013             first_leaf.prefix = ""
1014         else:
1015             before = 0
1016         depth = current_line.depth
1017         while self.previous_defs and self.previous_defs[-1] >= depth:
1018             self.previous_defs.pop()
1019             before = 1 if depth else 2
1020         is_decorator = current_line.is_decorator
1021         if is_decorator or current_line.is_def or current_line.is_class:
1022             if not is_decorator:
1023                 self.previous_defs.append(depth)
1024             if self.previous_line is None:
1025                 # Don't insert empty lines before the first line in the file.
1026                 return 0, 0
1027
1028             if self.previous_line and self.previous_line.is_decorator:
1029                 # Don't insert empty lines between decorators.
1030                 return 0, 0
1031
1032             newlines = 2
1033             if current_line.depth:
1034                 newlines -= 1
1035             return newlines, 0
1036
1037         if current_line.is_flow_control:
1038             return before, 1
1039
1040         if (
1041             self.previous_line
1042             and self.previous_line.is_import
1043             and not current_line.is_import
1044             and depth == self.previous_line.depth
1045         ):
1046             return (before or 1), 0
1047
1048         if (
1049             self.previous_line
1050             and self.previous_line.is_yield
1051             and (not current_line.is_yield or depth != self.previous_line.depth)
1052         ):
1053             return (before or 1), 0
1054
1055         return before, 0
1056
1057
1058 @dataclass
1059 class LineGenerator(Visitor[Line]):
1060     """Generates reformatted Line objects.  Empty lines are not emitted.
1061
1062     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1063     in ways that will no longer stringify to valid Python code on the tree.
1064     """
1065     current_line: Line = Factory(Line)
1066
1067     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1068         """Generate a line.
1069
1070         If the line is empty, only emit if it makes sense.
1071         If the line is too long, split it first and then generate.
1072
1073         If any lines were generated, set up a new current_line.
1074         """
1075         if not self.current_line:
1076             if self.current_line.__class__ == type:
1077                 self.current_line.depth += indent
1078             else:
1079                 self.current_line = type(depth=self.current_line.depth + indent)
1080             return  # Line is empty, don't emit. Creating a new one unnecessary.
1081
1082         complete_line = self.current_line
1083         self.current_line = type(depth=complete_line.depth + indent)
1084         yield complete_line
1085
1086     def visit(self, node: LN) -> Iterator[Line]:
1087         """Main method to visit `node` and its children.
1088
1089         Yields :class:`Line` objects.
1090         """
1091         if isinstance(self.current_line, UnformattedLines):
1092             # File contained `# fmt: off`
1093             yield from self.visit_unformatted(node)
1094
1095         else:
1096             yield from super().visit(node)
1097
1098     def visit_default(self, node: LN) -> Iterator[Line]:
1099         """Default `visit_*()` implementation. Recurses to children of `node`."""
1100         if isinstance(node, Leaf):
1101             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1102             try:
1103                 for comment in generate_comments(node):
1104                     if any_open_brackets:
1105                         # any comment within brackets is subject to splitting
1106                         self.current_line.append(comment)
1107                     elif comment.type == token.COMMENT:
1108                         # regular trailing comment
1109                         self.current_line.append(comment)
1110                         yield from self.line()
1111
1112                     else:
1113                         # regular standalone comment
1114                         yield from self.line()
1115
1116                         self.current_line.append(comment)
1117                         yield from self.line()
1118
1119             except FormatOff as f_off:
1120                 f_off.trim_prefix(node)
1121                 yield from self.line(type=UnformattedLines)
1122                 yield from self.visit(node)
1123
1124             except FormatOn as f_on:
1125                 # This only happens here if somebody says "fmt: on" multiple
1126                 # times in a row.
1127                 f_on.trim_prefix(node)
1128                 yield from self.visit_default(node)
1129
1130             else:
1131                 normalize_prefix(node, inside_brackets=any_open_brackets)
1132                 if node.type == token.STRING:
1133                     normalize_string_quotes(node)
1134                 if node.type not in WHITESPACE:
1135                     self.current_line.append(node)
1136         yield from super().visit_default(node)
1137
1138     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1139         """Increase indentation level, maybe yield a line."""
1140         # In blib2to3 INDENT never holds comments.
1141         yield from self.line(+1)
1142         yield from self.visit_default(node)
1143
1144     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1145         """Decrease indentation level, maybe yield a line."""
1146         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1147         yield from self.line(-1)
1148
1149     def visit_stmt(
1150         self, node: Node, keywords: Set[str], parens: Set[str]
1151     ) -> Iterator[Line]:
1152         """Visit a statement.
1153
1154         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1155         `def`, `with`, `class`, and `assert`.
1156
1157         The relevant Python language `keywords` for a given statement will be
1158         NAME leaves within it. This methods puts those on a separate line.
1159
1160         `parens` holds pairs of nodes where invisible parentheses should be put.
1161         Keys hold nodes after which opening parentheses should be put, values
1162         hold nodes before which closing parentheses should be put.
1163         """
1164         normalize_invisible_parens(node, parens_after=parens)
1165         for child in node.children:
1166             if child.type == token.NAME and child.value in keywords:  # type: ignore
1167                 yield from self.line()
1168
1169             yield from self.visit(child)
1170
1171     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1172         """Visit a statement without nested statements."""
1173         is_suite_like = node.parent and node.parent.type in STATEMENT
1174         if is_suite_like:
1175             yield from self.line(+1)
1176             yield from self.visit_default(node)
1177             yield from self.line(-1)
1178
1179         else:
1180             yield from self.line()
1181             yield from self.visit_default(node)
1182
1183     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1184         """Visit `async def`, `async for`, `async with`."""
1185         yield from self.line()
1186
1187         children = iter(node.children)
1188         for child in children:
1189             yield from self.visit(child)
1190
1191             if child.type == token.ASYNC:
1192                 break
1193
1194         internal_stmt = next(children)
1195         for child in internal_stmt.children:
1196             yield from self.visit(child)
1197
1198     def visit_decorators(self, node: Node) -> Iterator[Line]:
1199         """Visit decorators."""
1200         for child in node.children:
1201             yield from self.line()
1202             yield from self.visit(child)
1203
1204     def visit_import_from(self, node: Node) -> Iterator[Line]:
1205         """Visit import_from and maybe put invisible parentheses.
1206
1207         This is separate from `visit_stmt` because import statements don't
1208         support arbitrary atoms and thus handling of parentheses is custom.
1209         """
1210         check_lpar = False
1211         for index, child in enumerate(node.children):
1212             if check_lpar:
1213                 if child.type == token.LPAR:
1214                     # make parentheses invisible
1215                     child.value = ""  # type: ignore
1216                     node.children[-1].value = ""  # type: ignore
1217                 else:
1218                     # insert invisible parentheses
1219                     node.insert_child(index, Leaf(token.LPAR, ""))
1220                     node.append_child(Leaf(token.RPAR, ""))
1221                 break
1222
1223             check_lpar = (
1224                 child.type == token.NAME and child.value == "import"  # type: ignore
1225             )
1226
1227         for child in node.children:
1228             yield from self.visit(child)
1229
1230     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1231         """Remove a semicolon and put the other statement on a separate line."""
1232         yield from self.line()
1233
1234     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1235         """End of file. Process outstanding comments and end with a newline."""
1236         yield from self.visit_default(leaf)
1237         yield from self.line()
1238
1239     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1240         """Used when file contained a `# fmt: off`."""
1241         if isinstance(node, Node):
1242             for child in node.children:
1243                 yield from self.visit(child)
1244
1245         else:
1246             try:
1247                 self.current_line.append(node)
1248             except FormatOn as f_on:
1249                 f_on.trim_prefix(node)
1250                 yield from self.line()
1251                 yield from self.visit(node)
1252
1253             if node.type == token.ENDMARKER:
1254                 # somebody decided not to put a final `# fmt: on`
1255                 yield from self.line()
1256
1257     def __attrs_post_init__(self) -> None:
1258         """You are in a twisty little maze of passages."""
1259         v = self.visit_stmt
1260         Ø: Set[str] = set()
1261         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1262         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1263         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1264         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1265         self.visit_try_stmt = partial(
1266             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1267         )
1268         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1269         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1270         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1271         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1272         self.visit_async_funcdef = self.visit_async_stmt
1273         self.visit_decorated = self.visit_decorators
1274
1275
1276 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1277 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1278 OPENING_BRACKETS = set(BRACKET.keys())
1279 CLOSING_BRACKETS = set(BRACKET.values())
1280 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1281 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1282
1283
1284 def whitespace(leaf: Leaf) -> str:  # noqa C901
1285     """Return whitespace prefix if needed for the given `leaf`."""
1286     NO = ""
1287     SPACE = " "
1288     DOUBLESPACE = "  "
1289     t = leaf.type
1290     p = leaf.parent
1291     v = leaf.value
1292     if t in ALWAYS_NO_SPACE:
1293         return NO
1294
1295     if t == token.COMMENT:
1296         return DOUBLESPACE
1297
1298     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1299     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1300         return NO
1301
1302     prev = leaf.prev_sibling
1303     if not prev:
1304         prevp = preceding_leaf(p)
1305         if not prevp or prevp.type in OPENING_BRACKETS:
1306             return NO
1307
1308         if t == token.COLON:
1309             return SPACE if prevp.type == token.COMMA else NO
1310
1311         if prevp.type == token.EQUAL:
1312             if prevp.parent:
1313                 if prevp.parent.type in {
1314                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1315                 }:
1316                     return NO
1317
1318                 elif prevp.parent.type == syms.typedargslist:
1319                     # A bit hacky: if the equal sign has whitespace, it means we
1320                     # previously found it's a typed argument.  So, we're using
1321                     # that, too.
1322                     return prevp.prefix
1323
1324         elif prevp.type in STARS:
1325             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1326                 return NO
1327
1328         elif prevp.type == token.COLON:
1329             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1330                 return NO
1331
1332         elif (
1333             prevp.parent
1334             and prevp.parent.type == syms.factor
1335             and prevp.type in MATH_OPERATORS
1336         ):
1337             return NO
1338
1339         elif (
1340             prevp.type == token.RIGHTSHIFT
1341             and prevp.parent
1342             and prevp.parent.type == syms.shift_expr
1343             and prevp.prev_sibling
1344             and prevp.prev_sibling.type == token.NAME
1345             and prevp.prev_sibling.value == "print"  # type: ignore
1346         ):
1347             # Python 2 print chevron
1348             return NO
1349
1350     elif prev.type in OPENING_BRACKETS:
1351         return NO
1352
1353     if p.type in {syms.parameters, syms.arglist}:
1354         # untyped function signatures or calls
1355         if not prev or prev.type != token.COMMA:
1356             return NO
1357
1358     elif p.type == syms.varargslist:
1359         # lambdas
1360         if prev and prev.type != token.COMMA:
1361             return NO
1362
1363     elif p.type == syms.typedargslist:
1364         # typed function signatures
1365         if not prev:
1366             return NO
1367
1368         if t == token.EQUAL:
1369             if prev.type != syms.tname:
1370                 return NO
1371
1372         elif prev.type == token.EQUAL:
1373             # A bit hacky: if the equal sign has whitespace, it means we
1374             # previously found it's a typed argument.  So, we're using that, too.
1375             return prev.prefix
1376
1377         elif prev.type != token.COMMA:
1378             return NO
1379
1380     elif p.type == syms.tname:
1381         # type names
1382         if not prev:
1383             prevp = preceding_leaf(p)
1384             if not prevp or prevp.type != token.COMMA:
1385                 return NO
1386
1387     elif p.type == syms.trailer:
1388         # attributes and calls
1389         if t == token.LPAR or t == token.RPAR:
1390             return NO
1391
1392         if not prev:
1393             if t == token.DOT:
1394                 prevp = preceding_leaf(p)
1395                 if not prevp or prevp.type != token.NUMBER:
1396                     return NO
1397
1398             elif t == token.LSQB:
1399                 return NO
1400
1401         elif prev.type != token.COMMA:
1402             return NO
1403
1404     elif p.type == syms.argument:
1405         # single argument
1406         if t == token.EQUAL:
1407             return NO
1408
1409         if not prev:
1410             prevp = preceding_leaf(p)
1411             if not prevp or prevp.type == token.LPAR:
1412                 return NO
1413
1414         elif prev.type in {token.EQUAL} | STARS:
1415             return NO
1416
1417     elif p.type == syms.decorator:
1418         # decorators
1419         return NO
1420
1421     elif p.type == syms.dotted_name:
1422         if prev:
1423             return NO
1424
1425         prevp = preceding_leaf(p)
1426         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1427             return NO
1428
1429     elif p.type == syms.classdef:
1430         if t == token.LPAR:
1431             return NO
1432
1433         if prev and prev.type == token.LPAR:
1434             return NO
1435
1436     elif p.type == syms.subscript:
1437         # indexing
1438         if not prev:
1439             assert p.parent is not None, "subscripts are always parented"
1440             if p.parent.type == syms.subscriptlist:
1441                 return SPACE
1442
1443             return NO
1444
1445         else:
1446             return NO
1447
1448     elif p.type == syms.atom:
1449         if prev and t == token.DOT:
1450             # dots, but not the first one.
1451             return NO
1452
1453     elif p.type == syms.dictsetmaker:
1454         # dict unpacking
1455         if prev and prev.type == token.DOUBLESTAR:
1456             return NO
1457
1458     elif p.type in {syms.factor, syms.star_expr}:
1459         # unary ops
1460         if not prev:
1461             prevp = preceding_leaf(p)
1462             if not prevp or prevp.type in OPENING_BRACKETS:
1463                 return NO
1464
1465             prevp_parent = prevp.parent
1466             assert prevp_parent is not None
1467             if (
1468                 prevp.type == token.COLON
1469                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1470             ):
1471                 return NO
1472
1473             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1474                 return NO
1475
1476         elif t == token.NAME or t == token.NUMBER:
1477             return NO
1478
1479     elif p.type == syms.import_from:
1480         if t == token.DOT:
1481             if prev and prev.type == token.DOT:
1482                 return NO
1483
1484         elif t == token.NAME:
1485             if v == "import":
1486                 return SPACE
1487
1488             if prev and prev.type == token.DOT:
1489                 return NO
1490
1491     elif p.type == syms.sliceop:
1492         return NO
1493
1494     return SPACE
1495
1496
1497 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1498     """Return the first leaf that precedes `node`, if any."""
1499     while node:
1500         res = node.prev_sibling
1501         if res:
1502             if isinstance(res, Leaf):
1503                 return res
1504
1505             try:
1506                 return list(res.leaves())[-1]
1507
1508             except IndexError:
1509                 return None
1510
1511         node = node.parent
1512     return None
1513
1514
1515 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1516     """Return the priority of the `leaf` delimiter, given a line break after it.
1517
1518     The delimiter priorities returned here are from those delimiters that would
1519     cause a line break after themselves.
1520
1521     Higher numbers are higher priority.
1522     """
1523     if leaf.type == token.COMMA:
1524         return COMMA_PRIORITY
1525
1526     return 0
1527
1528
1529 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1530     """Return the priority of the `leaf` delimiter, given a line before after it.
1531
1532     The delimiter priorities returned here are from those delimiters that would
1533     cause a line break before themselves.
1534
1535     Higher numbers are higher priority.
1536     """
1537     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1538         # * and ** might also be MATH_OPERATORS but in this case they are not.
1539         # Don't treat them as a delimiter.
1540         return 0
1541
1542     if (
1543         leaf.type in MATH_OPERATORS
1544         and leaf.parent
1545         and leaf.parent.type not in {syms.factor, syms.star_expr}
1546     ):
1547         return MATH_PRIORITY
1548
1549     if leaf.type in COMPARATORS:
1550         return COMPARATOR_PRIORITY
1551
1552     if (
1553         leaf.type == token.STRING
1554         and previous is not None
1555         and previous.type == token.STRING
1556     ):
1557         return STRING_PRIORITY
1558
1559     if (
1560         leaf.type == token.NAME
1561         and leaf.value == "for"
1562         and leaf.parent
1563         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1564     ):
1565         return COMPREHENSION_PRIORITY
1566
1567     if (
1568         leaf.type == token.NAME
1569         and leaf.value == "if"
1570         and leaf.parent
1571         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1572     ):
1573         return COMPREHENSION_PRIORITY
1574
1575     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1576         return LOGIC_PRIORITY
1577
1578     return 0
1579
1580
1581 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1582     """Clean the prefix of the `leaf` and generate comments from it, if any.
1583
1584     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1585     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1586     move because it does away with modifying the grammar to include all the
1587     possible places in which comments can be placed.
1588
1589     The sad consequence for us though is that comments don't "belong" anywhere.
1590     This is why this function generates simple parentless Leaf objects for
1591     comments.  We simply don't know what the correct parent should be.
1592
1593     No matter though, we can live without this.  We really only need to
1594     differentiate between inline and standalone comments.  The latter don't
1595     share the line with any code.
1596
1597     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1598     are emitted with a fake STANDALONE_COMMENT token identifier.
1599     """
1600     p = leaf.prefix
1601     if not p:
1602         return
1603
1604     if "#" not in p:
1605         return
1606
1607     consumed = 0
1608     nlines = 0
1609     for index, line in enumerate(p.split("\n")):
1610         consumed += len(line) + 1  # adding the length of the split '\n'
1611         line = line.lstrip()
1612         if not line:
1613             nlines += 1
1614         if not line.startswith("#"):
1615             continue
1616
1617         if index == 0 and leaf.type != token.ENDMARKER:
1618             comment_type = token.COMMENT  # simple trailing comment
1619         else:
1620             comment_type = STANDALONE_COMMENT
1621         comment = make_comment(line)
1622         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1623
1624         if comment in {"# fmt: on", "# yapf: enable"}:
1625             raise FormatOn(consumed)
1626
1627         if comment in {"# fmt: off", "# yapf: disable"}:
1628             if comment_type == STANDALONE_COMMENT:
1629                 raise FormatOff(consumed)
1630
1631             prev = preceding_leaf(leaf)
1632             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1633                 raise FormatOff(consumed)
1634
1635         nlines = 0
1636
1637
1638 def make_comment(content: str) -> str:
1639     """Return a consistently formatted comment from the given `content` string.
1640
1641     All comments (except for "##", "#!", "#:") should have a single space between
1642     the hash sign and the content.
1643
1644     If `content` didn't start with a hash sign, one is provided.
1645     """
1646     content = content.rstrip()
1647     if not content:
1648         return "#"
1649
1650     if content[0] == "#":
1651         content = content[1:]
1652     if content and content[0] not in " !:#":
1653         content = " " + content
1654     return "#" + content
1655
1656
1657 def split_line(
1658     line: Line, line_length: int, inner: bool = False, py36: bool = False
1659 ) -> Iterator[Line]:
1660     """Split a `line` into potentially many lines.
1661
1662     They should fit in the allotted `line_length` but might not be able to.
1663     `inner` signifies that there were a pair of brackets somewhere around the
1664     current `line`, possibly transitively. This means we can fallback to splitting
1665     by delimiters if the LHS/RHS don't yield any results.
1666
1667     If `py36` is True, splitting may generate syntax that is only compatible
1668     with Python 3.6 and later.
1669     """
1670     if isinstance(line, UnformattedLines) or line.is_comment:
1671         yield line
1672         return
1673
1674     line_str = str(line).strip("\n")
1675     if (
1676         len(line_str) <= line_length
1677         and "\n" not in line_str  # multiline strings
1678         and not line.contains_standalone_comments()
1679     ):
1680         yield line
1681         return
1682
1683     split_funcs: List[SplitFunc]
1684     if line.is_def:
1685         split_funcs = [left_hand_split]
1686     elif line.inside_brackets:
1687         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1688     else:
1689         split_funcs = [right_hand_split]
1690     for split_func in split_funcs:
1691         # We are accumulating lines in `result` because we might want to abort
1692         # mission and return the original line in the end, or attempt a different
1693         # split altogether.
1694         result: List[Line] = []
1695         try:
1696             for l in split_func(line, py36):
1697                 if str(l).strip("\n") == line_str:
1698                     raise CannotSplit("Split function returned an unchanged result")
1699
1700                 result.extend(
1701                     split_line(l, line_length=line_length, inner=True, py36=py36)
1702                 )
1703         except CannotSplit as cs:
1704             continue
1705
1706         else:
1707             yield from result
1708             break
1709
1710     else:
1711         yield line
1712
1713
1714 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1715     """Split line into many lines, starting with the first matching bracket pair.
1716
1717     Note: this usually looks weird, only use this for function definitions.
1718     Prefer RHS otherwise.
1719     """
1720     head = Line(depth=line.depth)
1721     body = Line(depth=line.depth + 1, inside_brackets=True)
1722     tail = Line(depth=line.depth)
1723     tail_leaves: List[Leaf] = []
1724     body_leaves: List[Leaf] = []
1725     head_leaves: List[Leaf] = []
1726     current_leaves = head_leaves
1727     matching_bracket = None
1728     for leaf in line.leaves:
1729         if (
1730             current_leaves is body_leaves
1731             and leaf.type in CLOSING_BRACKETS
1732             and leaf.opening_bracket is matching_bracket
1733         ):
1734             current_leaves = tail_leaves if body_leaves else head_leaves
1735         current_leaves.append(leaf)
1736         if current_leaves is head_leaves:
1737             if leaf.type in OPENING_BRACKETS:
1738                 matching_bracket = leaf
1739                 current_leaves = body_leaves
1740     # Since body is a new indent level, remove spurious leading whitespace.
1741     if body_leaves:
1742         normalize_prefix(body_leaves[0], inside_brackets=True)
1743     # Build the new lines.
1744     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1745         for leaf in leaves:
1746             result.append(leaf, preformatted=True)
1747             for comment_after in line.comments_after(leaf):
1748                 result.append(comment_after, preformatted=True)
1749     bracket_split_succeeded_or_raise(head, body, tail)
1750     for result in (head, body, tail):
1751         if result:
1752             yield result
1753
1754
1755 def right_hand_split(
1756     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1757 ) -> Iterator[Line]:
1758     """Split line into many lines, starting with the last matching bracket pair."""
1759     head = Line(depth=line.depth)
1760     body = Line(depth=line.depth + 1, inside_brackets=True)
1761     tail = Line(depth=line.depth)
1762     tail_leaves: List[Leaf] = []
1763     body_leaves: List[Leaf] = []
1764     head_leaves: List[Leaf] = []
1765     current_leaves = tail_leaves
1766     opening_bracket = None
1767     closing_bracket = None
1768     for leaf in reversed(line.leaves):
1769         if current_leaves is body_leaves:
1770             if leaf is opening_bracket:
1771                 current_leaves = head_leaves if body_leaves else tail_leaves
1772         current_leaves.append(leaf)
1773         if current_leaves is tail_leaves:
1774             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1775                 opening_bracket = leaf.opening_bracket
1776                 closing_bracket = leaf
1777                 current_leaves = body_leaves
1778     tail_leaves.reverse()
1779     body_leaves.reverse()
1780     head_leaves.reverse()
1781     # Since body is a new indent level, remove spurious leading whitespace.
1782     if body_leaves:
1783         normalize_prefix(body_leaves[0], inside_brackets=True)
1784     elif not head_leaves:
1785         # No `head` and no `body` means the split failed. `tail` has all content.
1786         raise CannotSplit("No brackets found")
1787
1788     # Build the new lines.
1789     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1790         for leaf in leaves:
1791             result.append(leaf, preformatted=True)
1792             for comment_after in line.comments_after(leaf):
1793                 result.append(comment_after, preformatted=True)
1794     bracket_split_succeeded_or_raise(head, body, tail)
1795     assert opening_bracket and closing_bracket
1796     if (
1797         opening_bracket.type == token.LPAR
1798         and not opening_bracket.value
1799         and closing_bracket.type == token.RPAR
1800         and not closing_bracket.value
1801     ):
1802         # These parens were optional. If there aren't any delimiters or standalone
1803         # comments in the body, they were unnecessary and another split without
1804         # them should be attempted.
1805         if not (
1806             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1807         ):
1808             omit = {id(closing_bracket), *omit}
1809             yield from right_hand_split(line, py36=py36, omit=omit)
1810             return
1811
1812     ensure_visible(opening_bracket)
1813     ensure_visible(closing_bracket)
1814     for result in (head, body, tail):
1815         if result:
1816             yield result
1817
1818
1819 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1820     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1821
1822     Do nothing otherwise.
1823
1824     A left- or right-hand split is based on a pair of brackets. Content before
1825     (and including) the opening bracket is left on one line, content inside the
1826     brackets is put on a separate line, and finally content starting with and
1827     following the closing bracket is put on a separate line.
1828
1829     Those are called `head`, `body`, and `tail`, respectively. If the split
1830     produced the same line (all content in `head`) or ended up with an empty `body`
1831     and the `tail` is just the closing bracket, then it's considered failed.
1832     """
1833     tail_len = len(str(tail).strip())
1834     if not body:
1835         if tail_len == 0:
1836             raise CannotSplit("Splitting brackets produced the same line")
1837
1838         elif tail_len < 3:
1839             raise CannotSplit(
1840                 f"Splitting brackets on an empty body to save "
1841                 f"{tail_len} characters is not worth it"
1842             )
1843
1844
1845 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1846     """Normalize prefix of the first leaf in every line returned by `split_func`.
1847
1848     This is a decorator over relevant split functions.
1849     """
1850
1851     @wraps(split_func)
1852     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1853         for l in split_func(line, py36):
1854             normalize_prefix(l.leaves[0], inside_brackets=True)
1855             yield l
1856
1857     return split_wrapper
1858
1859
1860 @dont_increase_indentation
1861 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1862     """Split according to delimiters of the highest priority.
1863
1864     If `py36` is True, the split will add trailing commas also in function
1865     signatures that contain `*` and `**`.
1866     """
1867     try:
1868         last_leaf = line.leaves[-1]
1869     except IndexError:
1870         raise CannotSplit("Line empty")
1871
1872     delimiters = line.bracket_tracker.delimiters
1873     try:
1874         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1875             exclude={id(last_leaf)}
1876         )
1877     except ValueError:
1878         raise CannotSplit("No delimiters found")
1879
1880     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1881     lowest_depth = sys.maxsize
1882     trailing_comma_safe = True
1883
1884     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1885         """Append `leaf` to current line or to new line if appending impossible."""
1886         nonlocal current_line
1887         try:
1888             current_line.append_safe(leaf, preformatted=True)
1889         except ValueError as ve:
1890             yield current_line
1891
1892             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1893             current_line.append(leaf)
1894
1895     for leaf in line.leaves:
1896         yield from append_to_line(leaf)
1897
1898         for comment_after in line.comments_after(leaf):
1899             yield from append_to_line(comment_after)
1900
1901         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1902         if (
1903             leaf.bracket_depth == lowest_depth
1904             and is_vararg(leaf, within=VARARGS_PARENTS)
1905         ):
1906             trailing_comma_safe = trailing_comma_safe and py36
1907         leaf_priority = delimiters.get(id(leaf))
1908         if leaf_priority == delimiter_priority:
1909             yield current_line
1910
1911             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1912     if current_line:
1913         if (
1914             trailing_comma_safe
1915             and delimiter_priority == COMMA_PRIORITY
1916             and current_line.leaves[-1].type != token.COMMA
1917             and current_line.leaves[-1].type != STANDALONE_COMMENT
1918         ):
1919             current_line.append(Leaf(token.COMMA, ","))
1920         yield current_line
1921
1922
1923 @dont_increase_indentation
1924 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1925     """Split standalone comments from the rest of the line."""
1926     if not line.contains_standalone_comments(0):
1927         raise CannotSplit("Line does not have any standalone comments")
1928
1929     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1930
1931     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1932         """Append `leaf` to current line or to new line if appending impossible."""
1933         nonlocal current_line
1934         try:
1935             current_line.append_safe(leaf, preformatted=True)
1936         except ValueError as ve:
1937             yield current_line
1938
1939             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1940             current_line.append(leaf)
1941
1942     for leaf in line.leaves:
1943         yield from append_to_line(leaf)
1944
1945         for comment_after in line.comments_after(leaf):
1946             yield from append_to_line(comment_after)
1947
1948     if current_line:
1949         yield current_line
1950
1951
1952 def is_import(leaf: Leaf) -> bool:
1953     """Return True if the given leaf starts an import statement."""
1954     p = leaf.parent
1955     t = leaf.type
1956     v = leaf.value
1957     return bool(
1958         t == token.NAME
1959         and (
1960             (v == "import" and p and p.type == syms.import_name)
1961             or (v == "from" and p and p.type == syms.import_from)
1962         )
1963     )
1964
1965
1966 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1967     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1968     else.
1969
1970     Note: don't use backslashes for formatting or you'll lose your voting rights.
1971     """
1972     if not inside_brackets:
1973         spl = leaf.prefix.split("#")
1974         if "\\" not in spl[0]:
1975             nl_count = spl[-1].count("\n")
1976             if len(spl) > 1:
1977                 nl_count -= 1
1978             leaf.prefix = "\n" * nl_count
1979             return
1980
1981     leaf.prefix = ""
1982
1983
1984 def normalize_string_quotes(leaf: Leaf) -> None:
1985     """Prefer double quotes but only if it doesn't cause more escaping.
1986
1987     Adds or removes backslashes as appropriate. Doesn't parse and fix
1988     strings nested in f-strings (yet).
1989
1990     Note: Mutates its argument.
1991     """
1992     value = leaf.value.lstrip("furbFURB")
1993     if value[:3] == '"""':
1994         return
1995
1996     elif value[:3] == "'''":
1997         orig_quote = "'''"
1998         new_quote = '"""'
1999     elif value[0] == '"':
2000         orig_quote = '"'
2001         new_quote = "'"
2002     else:
2003         orig_quote = "'"
2004         new_quote = '"'
2005     first_quote_pos = leaf.value.find(orig_quote)
2006     if first_quote_pos == -1:
2007         return  # There's an internal error
2008
2009     prefix = leaf.value[:first_quote_pos]
2010     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2011     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2012     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2013     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
2014     if "r" in prefix.casefold():
2015         if unescaped_new_quote.search(body):
2016             # There's at least one unescaped new_quote in this raw string
2017             # so converting is impossible
2018             return
2019
2020         # Do not introduce or remove backslashes in raw strings
2021         new_body = body
2022     else:
2023         # remove unnecessary quotes
2024         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2025         if body != new_body:
2026             # Consider the string without unnecessary quotes as the original
2027             body = new_body
2028             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2029         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2030         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2031     if new_quote == '"""' and new_body[-1] == '"':
2032         # edge case:
2033         new_body = new_body[:-1] + '\\"'
2034     orig_escape_count = body.count("\\")
2035     new_escape_count = new_body.count("\\")
2036     if new_escape_count > orig_escape_count:
2037         return  # Do not introduce more escaping
2038
2039     if new_escape_count == orig_escape_count and orig_quote == '"':
2040         return  # Prefer double quotes
2041
2042     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2043
2044
2045 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2046     """Make existing optional parentheses invisible or create new ones.
2047
2048     Standardizes on visible parentheses for single-element tuples, and keeps
2049     existing visible parentheses for other tuples and generator expressions.
2050     """
2051     check_lpar = False
2052     for child in list(node.children):
2053         if check_lpar:
2054             if child.type == syms.atom:
2055                 if not (
2056                     is_empty_tuple(child)
2057                     or is_one_tuple(child)
2058                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2059                 ):
2060                     first = child.children[0]
2061                     last = child.children[-1]
2062                     if first.type == token.LPAR and last.type == token.RPAR:
2063                         # make parentheses invisible
2064                         first.value = ""  # type: ignore
2065                         last.value = ""  # type: ignore
2066             elif is_one_tuple(child):
2067                 # wrap child in visible parentheses
2068                 lpar = Leaf(token.LPAR, "(")
2069                 rpar = Leaf(token.RPAR, ")")
2070                 index = child.remove() or 0
2071                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2072             else:
2073                 # wrap child in invisible parentheses
2074                 lpar = Leaf(token.LPAR, "")
2075                 rpar = Leaf(token.RPAR, "")
2076                 index = child.remove() or 0
2077                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2078
2079         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2080
2081
2082 def is_empty_tuple(node: LN) -> bool:
2083     """Return True if `node` holds an empty tuple."""
2084     return (
2085         node.type == syms.atom
2086         and len(node.children) == 2
2087         and node.children[0].type == token.LPAR
2088         and node.children[1].type == token.RPAR
2089     )
2090
2091
2092 def is_one_tuple(node: LN) -> bool:
2093     """Return True if `node` holds a tuple with one element, with or without parens."""
2094     if node.type == syms.atom:
2095         if len(node.children) != 3:
2096             return False
2097
2098         lpar, gexp, rpar = node.children
2099         if not (
2100             lpar.type == token.LPAR
2101             and gexp.type == syms.testlist_gexp
2102             and rpar.type == token.RPAR
2103         ):
2104             return False
2105
2106         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2107
2108     return (
2109         node.type in IMPLICIT_TUPLE
2110         and len(node.children) == 2
2111         and node.children[1].type == token.COMMA
2112     )
2113
2114
2115 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2116     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2117
2118     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2119     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2120     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2121     generalizations (PEP 448).
2122     """
2123     if leaf.type not in STARS or not leaf.parent:
2124         return False
2125
2126     p = leaf.parent
2127     if p.type == syms.star_expr:
2128         # Star expressions are also used as assignment targets in extended
2129         # iterable unpacking (PEP 3132).  See what its parent is instead.
2130         if not p.parent:
2131             return False
2132
2133         p = p.parent
2134
2135     return p.type in within
2136
2137
2138 def max_delimiter_priority_in_atom(node: LN) -> int:
2139     """Return maximum delimiter priority inside `node`.
2140
2141     This is specific to atoms with contents contained in a pair of parentheses.
2142     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2143     """
2144     if node.type != syms.atom:
2145         return 0
2146
2147     first = node.children[0]
2148     last = node.children[-1]
2149     if not (first.type == token.LPAR and last.type == token.RPAR):
2150         return 0
2151
2152     bt = BracketTracker()
2153     for c in node.children[1:-1]:
2154         if isinstance(c, Leaf):
2155             bt.mark(c)
2156         else:
2157             for leaf in c.leaves():
2158                 bt.mark(leaf)
2159     try:
2160         return bt.max_delimiter_priority()
2161
2162     except ValueError:
2163         return 0
2164
2165
2166 def ensure_visible(leaf: Leaf) -> None:
2167     """Make sure parentheses are visible.
2168
2169     They could be invisible as part of some statements (see
2170     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2171     """
2172     if leaf.type == token.LPAR:
2173         leaf.value = "("
2174     elif leaf.type == token.RPAR:
2175         leaf.value = ")"
2176
2177
2178 def is_python36(node: Node) -> bool:
2179     """Return True if the current file is using Python 3.6+ features.
2180
2181     Currently looking for:
2182     - f-strings; and
2183     - trailing commas after * or ** in function signatures.
2184     """
2185     for n in node.pre_order():
2186         if n.type == token.STRING:
2187             value_head = n.value[:2]  # type: ignore
2188             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2189                 return True
2190
2191         elif (
2192             n.type == syms.typedargslist
2193             and n.children
2194             and n.children[-1].type == token.COMMA
2195         ):
2196             for ch in n.children:
2197                 if ch.type in STARS:
2198                     return True
2199
2200     return False
2201
2202
2203 PYTHON_EXTENSIONS = {".py"}
2204 BLACKLISTED_DIRECTORIES = {
2205     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2206 }
2207
2208
2209 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2210     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2211     and have one of the PYTHON_EXTENSIONS.
2212     """
2213     for child in path.iterdir():
2214         if child.is_dir():
2215             if child.name in BLACKLISTED_DIRECTORIES:
2216                 continue
2217
2218             yield from gen_python_files_in_dir(child)
2219
2220         elif child.suffix in PYTHON_EXTENSIONS:
2221             yield child
2222
2223
2224 @dataclass
2225 class Report:
2226     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2227     check: bool = False
2228     quiet: bool = False
2229     change_count: int = 0
2230     same_count: int = 0
2231     failure_count: int = 0
2232
2233     def done(self, src: Path, changed: Changed) -> None:
2234         """Increment the counter for successful reformatting. Write out a message."""
2235         if changed is Changed.YES:
2236             reformatted = "would reformat" if self.check else "reformatted"
2237             if not self.quiet:
2238                 out(f"{reformatted} {src}")
2239             self.change_count += 1
2240         else:
2241             if not self.quiet:
2242                 if changed is Changed.NO:
2243                     msg = f"{src} already well formatted, good job."
2244                 else:
2245                     msg = f"{src} wasn't modified on disk since last run."
2246                 out(msg, bold=False)
2247             self.same_count += 1
2248
2249     def failed(self, src: Path, message: str) -> None:
2250         """Increment the counter for failed reformatting. Write out a message."""
2251         err(f"error: cannot format {src}: {message}")
2252         self.failure_count += 1
2253
2254     @property
2255     def return_code(self) -> int:
2256         """Return the exit code that the app should use.
2257
2258         This considers the current state of changed files and failures:
2259         - if there were any failures, return 123;
2260         - if any files were changed and --check is being used, return 1;
2261         - otherwise return 0.
2262         """
2263         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2264         # 126 we have special returncodes reserved by the shell.
2265         if self.failure_count:
2266             return 123
2267
2268         elif self.change_count and self.check:
2269             return 1
2270
2271         return 0
2272
2273     def __str__(self) -> str:
2274         """Render a color report of the current state.
2275
2276         Use `click.unstyle` to remove colors.
2277         """
2278         if self.check:
2279             reformatted = "would be reformatted"
2280             unchanged = "would be left unchanged"
2281             failed = "would fail to reformat"
2282         else:
2283             reformatted = "reformatted"
2284             unchanged = "left unchanged"
2285             failed = "failed to reformat"
2286         report = []
2287         if self.change_count:
2288             s = "s" if self.change_count > 1 else ""
2289             report.append(
2290                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2291             )
2292         if self.same_count:
2293             s = "s" if self.same_count > 1 else ""
2294             report.append(f"{self.same_count} file{s} {unchanged}")
2295         if self.failure_count:
2296             s = "s" if self.failure_count > 1 else ""
2297             report.append(
2298                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2299             )
2300         return ", ".join(report) + "."
2301
2302
2303 def assert_equivalent(src: str, dst: str) -> None:
2304     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2305
2306     import ast
2307     import traceback
2308
2309     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2310         """Simple visitor generating strings to compare ASTs by content."""
2311         yield f"{'  ' * depth}{node.__class__.__name__}("
2312
2313         for field in sorted(node._fields):
2314             try:
2315                 value = getattr(node, field)
2316             except AttributeError:
2317                 continue
2318
2319             yield f"{'  ' * (depth+1)}{field}="
2320
2321             if isinstance(value, list):
2322                 for item in value:
2323                     if isinstance(item, ast.AST):
2324                         yield from _v(item, depth + 2)
2325
2326             elif isinstance(value, ast.AST):
2327                 yield from _v(value, depth + 2)
2328
2329             else:
2330                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2331
2332         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2333
2334     try:
2335         src_ast = ast.parse(src)
2336     except Exception as exc:
2337         major, minor = sys.version_info[:2]
2338         raise AssertionError(
2339             f"cannot use --safe with this file; failed to parse source file "
2340             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2341             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2342         )
2343
2344     try:
2345         dst_ast = ast.parse(dst)
2346     except Exception as exc:
2347         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2348         raise AssertionError(
2349             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2350             f"Please report a bug on https://github.com/ambv/black/issues.  "
2351             f"This invalid output might be helpful: {log}"
2352         ) from None
2353
2354     src_ast_str = "\n".join(_v(src_ast))
2355     dst_ast_str = "\n".join(_v(dst_ast))
2356     if src_ast_str != dst_ast_str:
2357         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2358         raise AssertionError(
2359             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2360             f"the source.  "
2361             f"Please report a bug on https://github.com/ambv/black/issues.  "
2362             f"This diff might be helpful: {log}"
2363         ) from None
2364
2365
2366 def assert_stable(src: str, dst: str, line_length: int) -> None:
2367     """Raise AssertionError if `dst` reformats differently the second time."""
2368     newdst = format_str(dst, line_length=line_length)
2369     if dst != newdst:
2370         log = dump_to_file(
2371             diff(src, dst, "source", "first pass"),
2372             diff(dst, newdst, "first pass", "second pass"),
2373         )
2374         raise AssertionError(
2375             f"INTERNAL ERROR: Black produced different code on the second pass "
2376             f"of the formatter.  "
2377             f"Please report a bug on https://github.com/ambv/black/issues.  "
2378             f"This diff might be helpful: {log}"
2379         ) from None
2380
2381
2382 def dump_to_file(*output: str) -> str:
2383     """Dump `output` to a temporary file. Return path to the file."""
2384     import tempfile
2385
2386     with tempfile.NamedTemporaryFile(
2387         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2388     ) as f:
2389         for lines in output:
2390             f.write(lines)
2391             if lines and lines[-1] != "\n":
2392                 f.write("\n")
2393     return f.name
2394
2395
2396 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2397     """Return a unified diff string between strings `a` and `b`."""
2398     import difflib
2399
2400     a_lines = [line + "\n" for line in a.split("\n")]
2401     b_lines = [line + "\n" for line in b.split("\n")]
2402     return "".join(
2403         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2404     )
2405
2406
2407 def cancel(tasks: List[asyncio.Task]) -> None:
2408     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2409     err("Aborted!")
2410     for task in tasks:
2411         task.cancel()
2412
2413
2414 def shutdown(loop: BaseEventLoop) -> None:
2415     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2416     try:
2417         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2418         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2419         if not to_cancel:
2420             return
2421
2422         for task in to_cancel:
2423             task.cancel()
2424         loop.run_until_complete(
2425             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2426         )
2427     finally:
2428         # `concurrent.futures.Future` objects cannot be cancelled once they
2429         # are already running. There might be some when the `shutdown()` happened.
2430         # Silence their logger's spew about the event loop being closed.
2431         cf_logger = logging.getLogger("concurrent.futures")
2432         cf_logger.setLevel(logging.CRITICAL)
2433         loop.close()
2434
2435
2436 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2437     """Replace `regex` with `replacement` twice on `original`.
2438
2439     This is used by string normalization to perform replaces on
2440     overlapping matches.
2441     """
2442     return regex.sub(replacement, regex.sub(replacement, original))
2443
2444
2445 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2446 CACHE_FILE = CACHE_DIR / "cache.pickle"
2447
2448
2449 def read_cache() -> Cache:
2450     """Read the cache if it exists and is well formed.
2451
2452     If it is not well formed, the call to write_cache later should resolve the issue.
2453     """
2454     if not CACHE_FILE.exists():
2455         return {}
2456
2457     with CACHE_FILE.open("rb") as fobj:
2458         try:
2459             cache: Cache = pickle.load(fobj)
2460         except pickle.UnpicklingError:
2461             return {}
2462
2463     return cache
2464
2465
2466 def get_cache_info(path: Path) -> CacheInfo:
2467     """Return the information used to check if a file is already formatted or not."""
2468     stat = path.stat()
2469     return stat.st_mtime, stat.st_size
2470
2471
2472 def filter_cached(
2473     cache: Cache, sources: Iterable[Path]
2474 ) -> Tuple[List[Path], List[Path]]:
2475     """Split a list of paths into two.
2476
2477     The first list contains paths of files that modified on disk or are not in the
2478     cache. The other list contains paths to non-modified files.
2479     """
2480     todo, done = [], []
2481     for src in sources:
2482         src = src.resolve()
2483         if cache.get(src) != get_cache_info(src):
2484             todo.append(src)
2485         else:
2486             done.append(src)
2487     return todo, done
2488
2489
2490 def write_cache(cache: Cache, sources: List[Path]) -> None:
2491     """Update the cache file."""
2492     try:
2493         if not CACHE_DIR.exists():
2494             CACHE_DIR.mkdir(parents=True)
2495         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2496         with CACHE_FILE.open("wb") as fobj:
2497             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
2498     except OSError:
2499         pass
2500
2501
2502 if __name__ == "__main__":
2503     main()