]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

add sublack plugin for sublimetext (#137)
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from enum import Enum
7 from functools import partial, wraps
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import re
14 import tokenize
15 import signal
16 import sys
17 from typing import (
18     Any,
19     Callable,
20     Collection,
21     Dict,
22     Generic,
23     Iterable,
24     Iterator,
25     List,
26     Optional,
27     Pattern,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33 )
34
35 from attr import dataclass, Factory
36 import click
37
38 # lib2to3 fork
39 from blib2to3.pytree import Node, Leaf, type_repr
40 from blib2to3 import pygram, pytree
41 from blib2to3.pgen2 import driver, token
42 from blib2to3.pgen2.parse import ParseError
43
44 __version__ = "18.4a2"
45 DEFAULT_LINE_LENGTH = 88
46 # types
47 syms = pygram.python_symbols
48 FileContent = str
49 Encoding = str
50 Depth = int
51 NodeType = int
52 LeafID = int
53 Priority = int
54 Index = int
55 LN = Union[Leaf, Node]
56 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
57 out = partial(click.secho, bold=True, err=True)
58 err = partial(click.secho, fg="red", err=True)
59
60
61 class NothingChanged(UserWarning):
62     """Raised by :func:`format_file` when reformatted code is the same as source."""
63
64
65 class CannotSplit(Exception):
66     """A readable split that fits the allotted line length is impossible.
67
68     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
69     :func:`delimiter_split`.
70     """
71
72
73 class FormatError(Exception):
74     """Base exception for `# fmt: on` and `# fmt: off` handling.
75
76     It holds the number of bytes of the prefix consumed before the format
77     control comment appeared.
78     """
79
80     def __init__(self, consumed: int) -> None:
81         super().__init__(consumed)
82         self.consumed = consumed
83
84     def trim_prefix(self, leaf: Leaf) -> None:
85         leaf.prefix = leaf.prefix[self.consumed:]
86
87     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
88         """Returns a new Leaf from the consumed part of the prefix."""
89         unformatted_prefix = leaf.prefix[:self.consumed]
90         return Leaf(token.NEWLINE, unformatted_prefix)
91
92
93 class FormatOn(FormatError):
94     """Found a comment like `# fmt: on` in the file."""
95
96
97 class FormatOff(FormatError):
98     """Found a comment like `# fmt: off` in the file."""
99
100
101 class WriteBack(Enum):
102     NO = 0
103     YES = 1
104     DIFF = 2
105
106
107 @click.command()
108 @click.option(
109     "-l",
110     "--line-length",
111     type=int,
112     default=DEFAULT_LINE_LENGTH,
113     help="How many character per line to allow.",
114     show_default=True,
115 )
116 @click.option(
117     "--check",
118     is_flag=True,
119     help=(
120         "Don't write the files back, just return the status.  Return code 0 "
121         "means nothing would change.  Return code 1 means some files would be "
122         "reformatted.  Return code 123 means there was an internal error."
123     ),
124 )
125 @click.option(
126     "--diff",
127     is_flag=True,
128     help="Don't write the files back, just output a diff for each file on stdout.",
129 )
130 @click.option(
131     "--fast/--safe",
132     is_flag=True,
133     help="If --fast given, skip temporary sanity checks. [default: --safe]",
134 )
135 @click.option(
136     "-q",
137     "--quiet",
138     is_flag=True,
139     help=(
140         "Don't emit non-error messages to stderr. Errors are still emitted, "
141         "silence those with 2>/dev/null."
142     ),
143 )
144 @click.version_option(version=__version__)
145 @click.argument(
146     "src",
147     nargs=-1,
148     type=click.Path(
149         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
150     ),
151 )
152 @click.pass_context
153 def main(
154     ctx: click.Context,
155     line_length: int,
156     check: bool,
157     diff: bool,
158     fast: bool,
159     quiet: bool,
160     src: List[str],
161 ) -> None:
162     """The uncompromising code formatter."""
163     sources: List[Path] = []
164     for s in src:
165         p = Path(s)
166         if p.is_dir():
167             sources.extend(gen_python_files_in_dir(p))
168         elif p.is_file():
169             # if a file was explicitly given, we don't care about its extension
170             sources.append(p)
171         elif s == "-":
172             sources.append(Path("-"))
173         else:
174             err(f"invalid path: {s}")
175     if check and diff:
176         exc = click.ClickException("Options --check and --diff are mutually exclusive")
177         exc.exit_code = 2
178         raise exc
179
180     if check:
181         write_back = WriteBack.NO
182     elif diff:
183         write_back = WriteBack.DIFF
184     else:
185         write_back = WriteBack.YES
186     if len(sources) == 0:
187         ctx.exit(0)
188     elif len(sources) == 1:
189         p = sources[0]
190         report = Report(check=check, quiet=quiet)
191         try:
192             if not p.is_file() and str(p) == "-":
193                 changed = format_stdin_to_stdout(
194                     line_length=line_length, fast=fast, write_back=write_back
195                 )
196             else:
197                 changed = format_file_in_place(
198                     p, line_length=line_length, fast=fast, write_back=write_back
199                 )
200             report.done(p, changed)
201         except Exception as exc:
202             report.failed(p, str(exc))
203         ctx.exit(report.return_code)
204     else:
205         loop = asyncio.get_event_loop()
206         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
207         return_code = 1
208         try:
209             return_code = loop.run_until_complete(
210                 schedule_formatting(
211                     sources, line_length, write_back, fast, quiet, loop, executor
212                 )
213             )
214         finally:
215             shutdown(loop)
216             ctx.exit(return_code)
217
218
219 async def schedule_formatting(
220     sources: List[Path],
221     line_length: int,
222     write_back: WriteBack,
223     fast: bool,
224     quiet: bool,
225     loop: BaseEventLoop,
226     executor: Executor,
227 ) -> int:
228     """Run formatting of `sources` in parallel using the provided `executor`.
229
230     (Use ProcessPoolExecutors for actual parallelism.)
231
232     `line_length`, `write_back`, and `fast` options are passed to
233     :func:`format_file_in_place`.
234     """
235     lock = None
236     if write_back == WriteBack.DIFF:
237         # For diff output, we need locks to ensure we don't interleave output
238         # from different processes.
239         manager = Manager()
240         lock = manager.Lock()
241     tasks = {
242         src: loop.run_in_executor(
243             executor, format_file_in_place, src, line_length, fast, write_back, lock
244         )
245         for src in sources
246     }
247     _task_values = list(tasks.values())
248     loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
249     loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
250     await asyncio.wait(tasks.values())
251     cancelled = []
252     report = Report(check=write_back is WriteBack.NO, quiet=quiet)
253     for src, task in tasks.items():
254         if not task.done():
255             report.failed(src, "timed out, cancelling")
256             task.cancel()
257             cancelled.append(task)
258         elif task.cancelled():
259             cancelled.append(task)
260         elif task.exception():
261             report.failed(src, str(task.exception()))
262         else:
263             report.done(src, task.result())
264     if cancelled:
265         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
266     elif not quiet:
267         out("All done! ✨ 🍰 ✨")
268     if not quiet:
269         click.echo(str(report))
270     return report.return_code
271
272
273 def format_file_in_place(
274     src: Path,
275     line_length: int,
276     fast: bool,
277     write_back: WriteBack = WriteBack.NO,
278     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
279 ) -> bool:
280     """Format file under `src` path. Return True if changed.
281
282     If `write_back` is True, write reformatted code back to stdout.
283     `line_length` and `fast` options are passed to :func:`format_file_contents`.
284     """
285     with tokenize.open(src) as src_buffer:
286         src_contents = src_buffer.read()
287     try:
288         dst_contents = format_file_contents(
289             src_contents, line_length=line_length, fast=fast
290         )
291     except NothingChanged:
292         return False
293
294     if write_back == write_back.YES:
295         with open(src, "w", encoding=src_buffer.encoding) as f:
296             f.write(dst_contents)
297     elif write_back == write_back.DIFF:
298         src_name = f"{src.name}  (original)"
299         dst_name = f"{src.name}  (formatted)"
300         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
301         if lock:
302             lock.acquire()
303         try:
304             sys.stdout.write(diff_contents)
305         finally:
306             if lock:
307                 lock.release()
308     return True
309
310
311 def format_stdin_to_stdout(
312     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
313 ) -> bool:
314     """Format file on stdin. Return True if changed.
315
316     If `write_back` is True, write reformatted code back to stdout.
317     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
318     """
319     src = sys.stdin.read()
320     dst = src
321     try:
322         dst = format_file_contents(src, line_length=line_length, fast=fast)
323         return True
324
325     except NothingChanged:
326         return False
327
328     finally:
329         if write_back == WriteBack.YES:
330             sys.stdout.write(dst)
331         elif write_back == WriteBack.DIFF:
332             src_name = "<stdin>  (original)"
333             dst_name = "<stdin>  (formatted)"
334             sys.stdout.write(diff(src, dst, src_name, dst_name))
335
336
337 def format_file_contents(
338     src_contents: str, line_length: int, fast: bool
339 ) -> FileContent:
340     """Reformat contents a file and return new contents.
341
342     If `fast` is False, additionally confirm that the reformatted code is
343     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
344     `line_length` is passed to :func:`format_str`.
345     """
346     if src_contents.strip() == "":
347         raise NothingChanged
348
349     dst_contents = format_str(src_contents, line_length=line_length)
350     if src_contents == dst_contents:
351         raise NothingChanged
352
353     if not fast:
354         assert_equivalent(src_contents, dst_contents)
355         assert_stable(src_contents, dst_contents, line_length=line_length)
356     return dst_contents
357
358
359 def format_str(src_contents: str, line_length: int) -> FileContent:
360     """Reformat a string and return new contents.
361
362     `line_length` determines how many characters per line are allowed.
363     """
364     src_node = lib2to3_parse(src_contents)
365     dst_contents = ""
366     lines = LineGenerator()
367     elt = EmptyLineTracker()
368     py36 = is_python36(src_node)
369     empty_line = Line()
370     after = 0
371     for current_line in lines.visit(src_node):
372         for _ in range(after):
373             dst_contents += str(empty_line)
374         before, after = elt.maybe_empty_lines(current_line)
375         for _ in range(before):
376             dst_contents += str(empty_line)
377         for line in split_line(current_line, line_length=line_length, py36=py36):
378             dst_contents += str(line)
379     return dst_contents
380
381
382 GRAMMARS = [
383     pygram.python_grammar_no_print_statement_no_exec_statement,
384     pygram.python_grammar_no_print_statement,
385     pygram.python_grammar_no_exec_statement,
386     pygram.python_grammar,
387 ]
388
389
390 def lib2to3_parse(src_txt: str) -> Node:
391     """Given a string with source, return the lib2to3 Node."""
392     grammar = pygram.python_grammar_no_print_statement
393     if src_txt[-1] != "\n":
394         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
395         src_txt += nl
396     for grammar in GRAMMARS:
397         drv = driver.Driver(grammar, pytree.convert)
398         try:
399             result = drv.parse_string(src_txt, True)
400             break
401
402         except ParseError as pe:
403             lineno, column = pe.context[1]
404             lines = src_txt.splitlines()
405             try:
406                 faulty_line = lines[lineno - 1]
407             except IndexError:
408                 faulty_line = "<line number missing in source>"
409             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
410     else:
411         raise exc from None
412
413     if isinstance(result, Leaf):
414         result = Node(syms.file_input, [result])
415     return result
416
417
418 def lib2to3_unparse(node: Node) -> str:
419     """Given a lib2to3 node, return its string representation."""
420     code = str(node)
421     return code
422
423
424 T = TypeVar("T")
425
426
427 class Visitor(Generic[T]):
428     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
429
430     def visit(self, node: LN) -> Iterator[T]:
431         """Main method to visit `node` and its children.
432
433         It tries to find a `visit_*()` method for the given `node.type`, like
434         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
435         If no dedicated `visit_*()` method is found, chooses `visit_default()`
436         instead.
437
438         Then yields objects of type `T` from the selected visitor.
439         """
440         if node.type < 256:
441             name = token.tok_name[node.type]
442         else:
443             name = type_repr(node.type)
444         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
445
446     def visit_default(self, node: LN) -> Iterator[T]:
447         """Default `visit_*()` implementation. Recurses to children of `node`."""
448         if isinstance(node, Node):
449             for child in node.children:
450                 yield from self.visit(child)
451
452
453 @dataclass
454 class DebugVisitor(Visitor[T]):
455     tree_depth: int = 0
456
457     def visit_default(self, node: LN) -> Iterator[T]:
458         indent = " " * (2 * self.tree_depth)
459         if isinstance(node, Node):
460             _type = type_repr(node.type)
461             out(f"{indent}{_type}", fg="yellow")
462             self.tree_depth += 1
463             for child in node.children:
464                 yield from self.visit(child)
465
466             self.tree_depth -= 1
467             out(f"{indent}/{_type}", fg="yellow", bold=False)
468         else:
469             _type = token.tok_name.get(node.type, str(node.type))
470             out(f"{indent}{_type}", fg="blue", nl=False)
471             if node.prefix:
472                 # We don't have to handle prefixes for `Node` objects since
473                 # that delegates to the first child anyway.
474                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
475             out(f" {node.value!r}", fg="blue", bold=False)
476
477     @classmethod
478     def show(cls, code: str) -> None:
479         """Pretty-print the lib2to3 AST of a given string of `code`.
480
481         Convenience method for debugging.
482         """
483         v: DebugVisitor[None] = DebugVisitor()
484         list(v.visit(lib2to3_parse(code)))
485
486
487 KEYWORDS = set(keyword.kwlist)
488 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
489 FLOW_CONTROL = {"return", "raise", "break", "continue"}
490 STATEMENT = {
491     syms.if_stmt,
492     syms.while_stmt,
493     syms.for_stmt,
494     syms.try_stmt,
495     syms.except_clause,
496     syms.with_stmt,
497     syms.funcdef,
498     syms.classdef,
499 }
500 STANDALONE_COMMENT = 153
501 LOGIC_OPERATORS = {"and", "or"}
502 COMPARATORS = {
503     token.LESS,
504     token.GREATER,
505     token.EQEQUAL,
506     token.NOTEQUAL,
507     token.LESSEQUAL,
508     token.GREATEREQUAL,
509 }
510 MATH_OPERATORS = {
511     token.PLUS,
512     token.MINUS,
513     token.STAR,
514     token.SLASH,
515     token.VBAR,
516     token.AMPER,
517     token.PERCENT,
518     token.CIRCUMFLEX,
519     token.TILDE,
520     token.LEFTSHIFT,
521     token.RIGHTSHIFT,
522     token.DOUBLESTAR,
523     token.DOUBLESLASH,
524 }
525 STARS = {token.STAR, token.DOUBLESTAR}
526 VARARGS_PARENTS = {
527     syms.arglist,
528     syms.argument,  # double star in arglist
529     syms.trailer,  # single argument to call
530     syms.typedargslist,
531     syms.varargslist,  # lambdas
532 }
533 UNPACKING_PARENTS = {
534     syms.atom,  # single element of a list or set literal
535     syms.dictsetmaker,
536     syms.listmaker,
537     syms.testlist_gexp,
538 }
539 COMPREHENSION_PRIORITY = 20
540 COMMA_PRIORITY = 10
541 LOGIC_PRIORITY = 5
542 STRING_PRIORITY = 4
543 COMPARATOR_PRIORITY = 3
544 MATH_PRIORITY = 1
545
546
547 @dataclass
548 class BracketTracker:
549     """Keeps track of brackets on a line."""
550
551     depth: int = 0
552     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
553     delimiters: Dict[LeafID, Priority] = Factory(dict)
554     previous: Optional[Leaf] = None
555
556     def mark(self, leaf: Leaf) -> None:
557         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
558
559         All leaves receive an int `bracket_depth` field that stores how deep
560         within brackets a given leaf is. 0 means there are no enclosing brackets
561         that started on this line.
562
563         If a leaf is itself a closing bracket, it receives an `opening_bracket`
564         field that it forms a pair with. This is a one-directional link to
565         avoid reference cycles.
566
567         If a leaf is a delimiter (a token on which Black can split the line if
568         needed) and it's on depth 0, its `id()` is stored in the tracker's
569         `delimiters` field.
570         """
571         if leaf.type == token.COMMENT:
572             return
573
574         if leaf.type in CLOSING_BRACKETS:
575             self.depth -= 1
576             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
577             leaf.opening_bracket = opening_bracket
578         leaf.bracket_depth = self.depth
579         if self.depth == 0:
580             delim = is_split_before_delimiter(leaf, self.previous)
581             if delim and self.previous is not None:
582                 self.delimiters[id(self.previous)] = delim
583             else:
584                 delim = is_split_after_delimiter(leaf, self.previous)
585                 if delim:
586                     self.delimiters[id(leaf)] = delim
587         if leaf.type in OPENING_BRACKETS:
588             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
589             self.depth += 1
590         self.previous = leaf
591
592     def any_open_brackets(self) -> bool:
593         """Return True if there is an yet unmatched open bracket on the line."""
594         return bool(self.bracket_match)
595
596     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
597         """Return the highest priority of a delimiter found on the line.
598
599         Values are consistent with what `is_delimiter()` returns.
600         Raises ValueError on no delimiters.
601         """
602         return max(v for k, v in self.delimiters.items() if k not in exclude)
603
604
605 @dataclass
606 class Line:
607     """Holds leaves and comments. Can be printed with `str(line)`."""
608
609     depth: int = 0
610     leaves: List[Leaf] = Factory(list)
611     comments: List[Tuple[Index, Leaf]] = Factory(list)
612     bracket_tracker: BracketTracker = Factory(BracketTracker)
613     inside_brackets: bool = False
614     has_for: bool = False
615     _for_loop_variable: bool = False
616
617     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
618         """Add a new `leaf` to the end of the line.
619
620         Unless `preformatted` is True, the `leaf` will receive a new consistent
621         whitespace prefix and metadata applied by :class:`BracketTracker`.
622         Trailing commas are maybe removed, unpacked for loop variables are
623         demoted from being delimiters.
624
625         Inline comments are put aside.
626         """
627         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
628         if not has_value:
629             return
630
631         if self.leaves and not preformatted:
632             # Note: at this point leaf.prefix should be empty except for
633             # imports, for which we only preserve newlines.
634             leaf.prefix += whitespace(leaf)
635         if self.inside_brackets or not preformatted:
636             self.maybe_decrement_after_for_loop_variable(leaf)
637             self.bracket_tracker.mark(leaf)
638             self.maybe_remove_trailing_comma(leaf)
639             self.maybe_increment_for_loop_variable(leaf)
640
641         if not self.append_comment(leaf):
642             self.leaves.append(leaf)
643
644     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
645         """Like :func:`append()` but disallow invalid standalone comment structure.
646
647         Raises ValueError when any `leaf` is appended after a standalone comment
648         or when a standalone comment is not the first leaf on the line.
649         """
650         if self.bracket_tracker.depth == 0:
651             if self.is_comment:
652                 raise ValueError("cannot append to standalone comments")
653
654             if self.leaves and leaf.type == STANDALONE_COMMENT:
655                 raise ValueError(
656                     "cannot append standalone comments to a populated line"
657                 )
658
659         self.append(leaf, preformatted=preformatted)
660
661     @property
662     def is_comment(self) -> bool:
663         """Is this line a standalone comment?"""
664         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
665
666     @property
667     def is_decorator(self) -> bool:
668         """Is this line a decorator?"""
669         return bool(self) and self.leaves[0].type == token.AT
670
671     @property
672     def is_import(self) -> bool:
673         """Is this an import line?"""
674         return bool(self) and is_import(self.leaves[0])
675
676     @property
677     def is_class(self) -> bool:
678         """Is this line a class definition?"""
679         return (
680             bool(self)
681             and self.leaves[0].type == token.NAME
682             and self.leaves[0].value == "class"
683         )
684
685     @property
686     def is_def(self) -> bool:
687         """Is this a function definition? (Also returns True for async defs.)"""
688         try:
689             first_leaf = self.leaves[0]
690         except IndexError:
691             return False
692
693         try:
694             second_leaf: Optional[Leaf] = self.leaves[1]
695         except IndexError:
696             second_leaf = None
697         return (
698             (first_leaf.type == token.NAME and first_leaf.value == "def")
699             or (
700                 first_leaf.type == token.ASYNC
701                 and second_leaf is not None
702                 and second_leaf.type == token.NAME
703                 and second_leaf.value == "def"
704             )
705         )
706
707     @property
708     def is_flow_control(self) -> bool:
709         """Is this line a flow control statement?
710
711         Those are `return`, `raise`, `break`, and `continue`.
712         """
713         return (
714             bool(self)
715             and self.leaves[0].type == token.NAME
716             and self.leaves[0].value in FLOW_CONTROL
717         )
718
719     @property
720     def is_yield(self) -> bool:
721         """Is this line a yield statement?"""
722         return (
723             bool(self)
724             and self.leaves[0].type == token.NAME
725             and self.leaves[0].value == "yield"
726         )
727
728     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
729         """If so, needs to be split before emitting."""
730         for leaf in self.leaves:
731             if leaf.type == STANDALONE_COMMENT:
732                 if leaf.bracket_depth <= depth_limit:
733                     return True
734
735         return False
736
737     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
738         """Remove trailing comma if there is one and it's safe."""
739         if not (
740             self.leaves
741             and self.leaves[-1].type == token.COMMA
742             and closing.type in CLOSING_BRACKETS
743         ):
744             return False
745
746         if closing.type == token.RBRACE:
747             self.remove_trailing_comma()
748             return True
749
750         if closing.type == token.RSQB:
751             comma = self.leaves[-1]
752             if comma.parent and comma.parent.type == syms.listmaker:
753                 self.remove_trailing_comma()
754                 return True
755
756         # For parens let's check if it's safe to remove the comma.  If the
757         # trailing one is the only one, we might mistakenly change a tuple
758         # into a different type by removing the comma.
759         depth = closing.bracket_depth + 1
760         commas = 0
761         opening = closing.opening_bracket
762         for _opening_index, leaf in enumerate(self.leaves):
763             if leaf is opening:
764                 break
765
766         else:
767             return False
768
769         for leaf in self.leaves[_opening_index + 1:]:
770             if leaf is closing:
771                 break
772
773             bracket_depth = leaf.bracket_depth
774             if bracket_depth == depth and leaf.type == token.COMMA:
775                 commas += 1
776                 if leaf.parent and leaf.parent.type == syms.arglist:
777                     commas += 1
778                     break
779
780         if commas > 1:
781             self.remove_trailing_comma()
782             return True
783
784         return False
785
786     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
787         """In a for loop, or comprehension, the variables are often unpacks.
788
789         To avoid splitting on the comma in this situation, increase the depth of
790         tokens between `for` and `in`.
791         """
792         if leaf.type == token.NAME and leaf.value == "for":
793             self.has_for = True
794             self.bracket_tracker.depth += 1
795             self._for_loop_variable = True
796             return True
797
798         return False
799
800     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
801         """See `maybe_increment_for_loop_variable` above for explanation."""
802         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
803             self.bracket_tracker.depth -= 1
804             self._for_loop_variable = False
805             return True
806
807         return False
808
809     def append_comment(self, comment: Leaf) -> bool:
810         """Add an inline or standalone comment to the line."""
811         if (
812             comment.type == STANDALONE_COMMENT
813             and self.bracket_tracker.any_open_brackets()
814         ):
815             comment.prefix = ""
816             return False
817
818         if comment.type != token.COMMENT:
819             return False
820
821         after = len(self.leaves) - 1
822         if after == -1:
823             comment.type = STANDALONE_COMMENT
824             comment.prefix = ""
825             return False
826
827         else:
828             self.comments.append((after, comment))
829             return True
830
831     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
832         """Generate comments that should appear directly after `leaf`."""
833         for _leaf_index, _leaf in enumerate(self.leaves):
834             if leaf is _leaf:
835                 break
836
837         else:
838             return
839
840         for index, comment_after in self.comments:
841             if _leaf_index == index:
842                 yield comment_after
843
844     def remove_trailing_comma(self) -> None:
845         """Remove the trailing comma and moves the comments attached to it."""
846         comma_index = len(self.leaves) - 1
847         for i in range(len(self.comments)):
848             comment_index, comment = self.comments[i]
849             if comment_index == comma_index:
850                 self.comments[i] = (comma_index - 1, comment)
851         self.leaves.pop()
852
853     def __str__(self) -> str:
854         """Render the line."""
855         if not self:
856             return "\n"
857
858         indent = "    " * self.depth
859         leaves = iter(self.leaves)
860         first = next(leaves)
861         res = f"{first.prefix}{indent}{first.value}"
862         for leaf in leaves:
863             res += str(leaf)
864         for _, comment in self.comments:
865             res += str(comment)
866         return res + "\n"
867
868     def __bool__(self) -> bool:
869         """Return True if the line has leaves or comments."""
870         return bool(self.leaves or self.comments)
871
872
873 class UnformattedLines(Line):
874     """Just like :class:`Line` but stores lines which aren't reformatted."""
875
876     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
877         """Just add a new `leaf` to the end of the lines.
878
879         The `preformatted` argument is ignored.
880
881         Keeps track of indentation `depth`, which is useful when the user
882         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
883         """
884         try:
885             list(generate_comments(leaf))
886         except FormatOn as f_on:
887             self.leaves.append(f_on.leaf_from_consumed(leaf))
888             raise
889
890         self.leaves.append(leaf)
891         if leaf.type == token.INDENT:
892             self.depth += 1
893         elif leaf.type == token.DEDENT:
894             self.depth -= 1
895
896     def __str__(self) -> str:
897         """Render unformatted lines from leaves which were added with `append()`.
898
899         `depth` is not used for indentation in this case.
900         """
901         if not self:
902             return "\n"
903
904         res = ""
905         for leaf in self.leaves:
906             res += str(leaf)
907         return res
908
909     def append_comment(self, comment: Leaf) -> bool:
910         """Not implemented in this class. Raises `NotImplementedError`."""
911         raise NotImplementedError("Unformatted lines don't store comments separately.")
912
913     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
914         """Does nothing and returns False."""
915         return False
916
917     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
918         """Does nothing and returns False."""
919         return False
920
921
922 @dataclass
923 class EmptyLineTracker:
924     """Provides a stateful method that returns the number of potential extra
925     empty lines needed before and after the currently processed line.
926
927     Note: this tracker works on lines that haven't been split yet.  It assumes
928     the prefix of the first leaf consists of optional newlines.  Those newlines
929     are consumed by `maybe_empty_lines()` and included in the computation.
930     """
931     previous_line: Optional[Line] = None
932     previous_after: int = 0
933     previous_defs: List[int] = Factory(list)
934
935     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
936         """Return the number of extra empty lines before and after the `current_line`.
937
938         This is for separating `def`, `async def` and `class` with extra empty
939         lines (two on module-level), as well as providing an extra empty line
940         after flow control keywords to make them more prominent.
941         """
942         if isinstance(current_line, UnformattedLines):
943             return 0, 0
944
945         before, after = self._maybe_empty_lines(current_line)
946         before -= self.previous_after
947         self.previous_after = after
948         self.previous_line = current_line
949         return before, after
950
951     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
952         max_allowed = 1
953         if current_line.depth == 0:
954             max_allowed = 2
955         if current_line.leaves:
956             # Consume the first leaf's extra newlines.
957             first_leaf = current_line.leaves[0]
958             before = first_leaf.prefix.count("\n")
959             before = min(before, max_allowed)
960             first_leaf.prefix = ""
961         else:
962             before = 0
963         depth = current_line.depth
964         while self.previous_defs and self.previous_defs[-1] >= depth:
965             self.previous_defs.pop()
966             before = 1 if depth else 2
967         is_decorator = current_line.is_decorator
968         if is_decorator or current_line.is_def or current_line.is_class:
969             if not is_decorator:
970                 self.previous_defs.append(depth)
971             if self.previous_line is None:
972                 # Don't insert empty lines before the first line in the file.
973                 return 0, 0
974
975             if self.previous_line and self.previous_line.is_decorator:
976                 # Don't insert empty lines between decorators.
977                 return 0, 0
978
979             newlines = 2
980             if current_line.depth:
981                 newlines -= 1
982             return newlines, 0
983
984         if current_line.is_flow_control:
985             return before, 1
986
987         if (
988             self.previous_line
989             and self.previous_line.is_import
990             and not current_line.is_import
991             and depth == self.previous_line.depth
992         ):
993             return (before or 1), 0
994
995         if (
996             self.previous_line
997             and self.previous_line.is_yield
998             and (not current_line.is_yield or depth != self.previous_line.depth)
999         ):
1000             return (before or 1), 0
1001
1002         return before, 0
1003
1004
1005 @dataclass
1006 class LineGenerator(Visitor[Line]):
1007     """Generates reformatted Line objects.  Empty lines are not emitted.
1008
1009     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1010     in ways that will no longer stringify to valid Python code on the tree.
1011     """
1012     current_line: Line = Factory(Line)
1013
1014     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1015         """Generate a line.
1016
1017         If the line is empty, only emit if it makes sense.
1018         If the line is too long, split it first and then generate.
1019
1020         If any lines were generated, set up a new current_line.
1021         """
1022         if not self.current_line:
1023             if self.current_line.__class__ == type:
1024                 self.current_line.depth += indent
1025             else:
1026                 self.current_line = type(depth=self.current_line.depth + indent)
1027             return  # Line is empty, don't emit. Creating a new one unnecessary.
1028
1029         complete_line = self.current_line
1030         self.current_line = type(depth=complete_line.depth + indent)
1031         yield complete_line
1032
1033     def visit(self, node: LN) -> Iterator[Line]:
1034         """Main method to visit `node` and its children.
1035
1036         Yields :class:`Line` objects.
1037         """
1038         if isinstance(self.current_line, UnformattedLines):
1039             # File contained `# fmt: off`
1040             yield from self.visit_unformatted(node)
1041
1042         else:
1043             yield from super().visit(node)
1044
1045     def visit_default(self, node: LN) -> Iterator[Line]:
1046         """Default `visit_*()` implementation. Recurses to children of `node`."""
1047         if isinstance(node, Leaf):
1048             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1049             try:
1050                 for comment in generate_comments(node):
1051                     if any_open_brackets:
1052                         # any comment within brackets is subject to splitting
1053                         self.current_line.append(comment)
1054                     elif comment.type == token.COMMENT:
1055                         # regular trailing comment
1056                         self.current_line.append(comment)
1057                         yield from self.line()
1058
1059                     else:
1060                         # regular standalone comment
1061                         yield from self.line()
1062
1063                         self.current_line.append(comment)
1064                         yield from self.line()
1065
1066             except FormatOff as f_off:
1067                 f_off.trim_prefix(node)
1068                 yield from self.line(type=UnformattedLines)
1069                 yield from self.visit(node)
1070
1071             except FormatOn as f_on:
1072                 # This only happens here if somebody says "fmt: on" multiple
1073                 # times in a row.
1074                 f_on.trim_prefix(node)
1075                 yield from self.visit_default(node)
1076
1077             else:
1078                 normalize_prefix(node, inside_brackets=any_open_brackets)
1079                 if node.type == token.STRING:
1080                     normalize_string_quotes(node)
1081                 if node.type not in WHITESPACE:
1082                     self.current_line.append(node)
1083         yield from super().visit_default(node)
1084
1085     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1086         """Increase indentation level, maybe yield a line."""
1087         # In blib2to3 INDENT never holds comments.
1088         yield from self.line(+1)
1089         yield from self.visit_default(node)
1090
1091     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1092         """Decrease indentation level, maybe yield a line."""
1093         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1094         yield from self.line(-1)
1095
1096     def visit_stmt(
1097         self, node: Node, keywords: Set[str], parens: Set[str]
1098     ) -> Iterator[Line]:
1099         """Visit a statement.
1100
1101         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1102         `def`, `with`, `class`, and `assert`.
1103
1104         The relevant Python language `keywords` for a given statement will be
1105         NAME leaves within it. This methods puts those on a separate line.
1106
1107         `parens` holds pairs of nodes where invisible parentheses should be put.
1108         Keys hold nodes after which opening parentheses should be put, values
1109         hold nodes before which closing parentheses should be put.
1110         """
1111         normalize_invisible_parens(node, parens_after=parens)
1112         for child in node.children:
1113             if child.type == token.NAME and child.value in keywords:  # type: ignore
1114                 yield from self.line()
1115
1116             yield from self.visit(child)
1117
1118     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1119         """Visit a statement without nested statements."""
1120         is_suite_like = node.parent and node.parent.type in STATEMENT
1121         if is_suite_like:
1122             yield from self.line(+1)
1123             yield from self.visit_default(node)
1124             yield from self.line(-1)
1125
1126         else:
1127             yield from self.line()
1128             yield from self.visit_default(node)
1129
1130     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1131         """Visit `async def`, `async for`, `async with`."""
1132         yield from self.line()
1133
1134         children = iter(node.children)
1135         for child in children:
1136             yield from self.visit(child)
1137
1138             if child.type == token.ASYNC:
1139                 break
1140
1141         internal_stmt = next(children)
1142         for child in internal_stmt.children:
1143             yield from self.visit(child)
1144
1145     def visit_decorators(self, node: Node) -> Iterator[Line]:
1146         """Visit decorators."""
1147         for child in node.children:
1148             yield from self.line()
1149             yield from self.visit(child)
1150
1151     def visit_import_from(self, node: Node) -> Iterator[Line]:
1152         """Visit import_from and maybe put invisible parentheses.
1153
1154         This is separate from `visit_stmt` because import statements don't
1155         support arbitrary atoms and thus handling of parentheses is custom.
1156         """
1157         check_lpar = False
1158         for index, child in enumerate(node.children):
1159             if check_lpar:
1160                 if child.type == token.LPAR:
1161                     # make parentheses invisible
1162                     child.value = ""  # type: ignore
1163                     node.children[-1].value = ""  # type: ignore
1164                 else:
1165                     # insert invisible parentheses
1166                     node.insert_child(index, Leaf(token.LPAR, ""))
1167                     node.append_child(Leaf(token.RPAR, ""))
1168                 break
1169
1170             check_lpar = (
1171                 child.type == token.NAME and child.value == "import"  # type: ignore
1172             )
1173
1174         for child in node.children:
1175             yield from self.visit(child)
1176
1177     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1178         """Remove a semicolon and put the other statement on a separate line."""
1179         yield from self.line()
1180
1181     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1182         """End of file. Process outstanding comments and end with a newline."""
1183         yield from self.visit_default(leaf)
1184         yield from self.line()
1185
1186     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1187         """Used when file contained a `# fmt: off`."""
1188         if isinstance(node, Node):
1189             for child in node.children:
1190                 yield from self.visit(child)
1191
1192         else:
1193             try:
1194                 self.current_line.append(node)
1195             except FormatOn as f_on:
1196                 f_on.trim_prefix(node)
1197                 yield from self.line()
1198                 yield from self.visit(node)
1199
1200             if node.type == token.ENDMARKER:
1201                 # somebody decided not to put a final `# fmt: on`
1202                 yield from self.line()
1203
1204     def __attrs_post_init__(self) -> None:
1205         """You are in a twisty little maze of passages."""
1206         v = self.visit_stmt
1207         Ø: Set[str] = set()
1208         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1209         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
1210         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1211         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1212         self.visit_try_stmt = partial(
1213             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1214         )
1215         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1216         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1217         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1218         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1219         self.visit_async_funcdef = self.visit_async_stmt
1220         self.visit_decorated = self.visit_decorators
1221
1222
1223 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1224 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1225 OPENING_BRACKETS = set(BRACKET.keys())
1226 CLOSING_BRACKETS = set(BRACKET.values())
1227 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1228 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1229
1230
1231 def whitespace(leaf: Leaf) -> str:  # noqa C901
1232     """Return whitespace prefix if needed for the given `leaf`."""
1233     NO = ""
1234     SPACE = " "
1235     DOUBLESPACE = "  "
1236     t = leaf.type
1237     p = leaf.parent
1238     v = leaf.value
1239     if t in ALWAYS_NO_SPACE:
1240         return NO
1241
1242     if t == token.COMMENT:
1243         return DOUBLESPACE
1244
1245     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1246     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1247         return NO
1248
1249     prev = leaf.prev_sibling
1250     if not prev:
1251         prevp = preceding_leaf(p)
1252         if not prevp or prevp.type in OPENING_BRACKETS:
1253             return NO
1254
1255         if t == token.COLON:
1256             return SPACE if prevp.type == token.COMMA else NO
1257
1258         if prevp.type == token.EQUAL:
1259             if prevp.parent:
1260                 if prevp.parent.type in {
1261                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1262                 }:
1263                     return NO
1264
1265                 elif prevp.parent.type == syms.typedargslist:
1266                     # A bit hacky: if the equal sign has whitespace, it means we
1267                     # previously found it's a typed argument.  So, we're using
1268                     # that, too.
1269                     return prevp.prefix
1270
1271         elif prevp.type in STARS:
1272             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1273                 return NO
1274
1275         elif prevp.type == token.COLON:
1276             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1277                 return NO
1278
1279         elif (
1280             prevp.parent
1281             and prevp.parent.type == syms.factor
1282             and prevp.type in MATH_OPERATORS
1283         ):
1284             return NO
1285
1286         elif (
1287             prevp.type == token.RIGHTSHIFT
1288             and prevp.parent
1289             and prevp.parent.type == syms.shift_expr
1290             and prevp.prev_sibling
1291             and prevp.prev_sibling.type == token.NAME
1292             and prevp.prev_sibling.value == "print"  # type: ignore
1293         ):
1294             # Python 2 print chevron
1295             return NO
1296
1297     elif prev.type in OPENING_BRACKETS:
1298         return NO
1299
1300     if p.type in {syms.parameters, syms.arglist}:
1301         # untyped function signatures or calls
1302         if t == token.RPAR:
1303             return NO
1304
1305         if not prev or prev.type != token.COMMA:
1306             return NO
1307
1308     elif p.type == syms.varargslist:
1309         # lambdas
1310         if t == token.RPAR:
1311             return NO
1312
1313         if prev and prev.type != token.COMMA:
1314             return NO
1315
1316     elif p.type == syms.typedargslist:
1317         # typed function signatures
1318         if not prev:
1319             return NO
1320
1321         if t == token.EQUAL:
1322             if prev.type != syms.tname:
1323                 return NO
1324
1325         elif prev.type == token.EQUAL:
1326             # A bit hacky: if the equal sign has whitespace, it means we
1327             # previously found it's a typed argument.  So, we're using that, too.
1328             return prev.prefix
1329
1330         elif prev.type != token.COMMA:
1331             return NO
1332
1333     elif p.type == syms.tname:
1334         # type names
1335         if not prev:
1336             prevp = preceding_leaf(p)
1337             if not prevp or prevp.type != token.COMMA:
1338                 return NO
1339
1340     elif p.type == syms.trailer:
1341         # attributes and calls
1342         if t == token.LPAR or t == token.RPAR:
1343             return NO
1344
1345         if not prev:
1346             if t == token.DOT:
1347                 prevp = preceding_leaf(p)
1348                 if not prevp or prevp.type != token.NUMBER:
1349                     return NO
1350
1351             elif t == token.LSQB:
1352                 return NO
1353
1354         elif prev.type != token.COMMA:
1355             return NO
1356
1357     elif p.type == syms.argument:
1358         # single argument
1359         if t == token.EQUAL:
1360             return NO
1361
1362         if not prev:
1363             prevp = preceding_leaf(p)
1364             if not prevp or prevp.type == token.LPAR:
1365                 return NO
1366
1367         elif prev.type in {token.EQUAL} | STARS:
1368             return NO
1369
1370     elif p.type == syms.decorator:
1371         # decorators
1372         return NO
1373
1374     elif p.type == syms.dotted_name:
1375         if prev:
1376             return NO
1377
1378         prevp = preceding_leaf(p)
1379         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1380             return NO
1381
1382     elif p.type == syms.classdef:
1383         if t == token.LPAR:
1384             return NO
1385
1386         if prev and prev.type == token.LPAR:
1387             return NO
1388
1389     elif p.type == syms.subscript:
1390         # indexing
1391         if not prev:
1392             assert p.parent is not None, "subscripts are always parented"
1393             if p.parent.type == syms.subscriptlist:
1394                 return SPACE
1395
1396             return NO
1397
1398         else:
1399             return NO
1400
1401     elif p.type == syms.atom:
1402         if prev and t == token.DOT:
1403             # dots, but not the first one.
1404             return NO
1405
1406     elif (
1407         p.type == syms.listmaker
1408         or p.type == syms.testlist_gexp
1409         or p.type == syms.subscriptlist
1410     ):
1411         # list interior, including unpacking
1412         if not prev:
1413             return NO
1414
1415     elif p.type == syms.dictsetmaker:
1416         # dict and set interior, including unpacking
1417         if not prev:
1418             return NO
1419
1420         if prev.type == token.DOUBLESTAR:
1421             return NO
1422
1423     elif p.type in {syms.factor, syms.star_expr}:
1424         # unary ops
1425         if not prev:
1426             prevp = preceding_leaf(p)
1427             if not prevp or prevp.type in OPENING_BRACKETS:
1428                 return NO
1429
1430             prevp_parent = prevp.parent
1431             assert prevp_parent is not None
1432             if (
1433                 prevp.type == token.COLON
1434                 and prevp_parent.type in {syms.subscript, syms.sliceop}
1435             ):
1436                 return NO
1437
1438             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1439                 return NO
1440
1441         elif t == token.NAME or t == token.NUMBER:
1442             return NO
1443
1444     elif p.type == syms.import_from:
1445         if t == token.DOT:
1446             if prev and prev.type == token.DOT:
1447                 return NO
1448
1449         elif t == token.NAME:
1450             if v == "import":
1451                 return SPACE
1452
1453             if prev and prev.type == token.DOT:
1454                 return NO
1455
1456     elif p.type == syms.sliceop:
1457         return NO
1458
1459     return SPACE
1460
1461
1462 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1463     """Return the first leaf that precedes `node`, if any."""
1464     while node:
1465         res = node.prev_sibling
1466         if res:
1467             if isinstance(res, Leaf):
1468                 return res
1469
1470             try:
1471                 return list(res.leaves())[-1]
1472
1473             except IndexError:
1474                 return None
1475
1476         node = node.parent
1477     return None
1478
1479
1480 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1481     """Return the priority of the `leaf` delimiter, given a line break after it.
1482
1483     The delimiter priorities returned here are from those delimiters that would
1484     cause a line break after themselves.
1485
1486     Higher numbers are higher priority.
1487     """
1488     if leaf.type == token.COMMA:
1489         return COMMA_PRIORITY
1490
1491     return 0
1492
1493
1494 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1495     """Return the priority of the `leaf` delimiter, given a line before after it.
1496
1497     The delimiter priorities returned here are from those delimiters that would
1498     cause a line break before themselves.
1499
1500     Higher numbers are higher priority.
1501     """
1502     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1503         # * and ** might also be MATH_OPERATORS but in this case they are not.
1504         # Don't treat them as a delimiter.
1505         return 0
1506
1507     if (
1508         leaf.type in MATH_OPERATORS
1509         and leaf.parent
1510         and leaf.parent.type not in {syms.factor, syms.star_expr}
1511     ):
1512         return MATH_PRIORITY
1513
1514     if leaf.type in COMPARATORS:
1515         return COMPARATOR_PRIORITY
1516
1517     if (
1518         leaf.type == token.STRING
1519         and previous is not None
1520         and previous.type == token.STRING
1521     ):
1522         return STRING_PRIORITY
1523
1524     if (
1525         leaf.type == token.NAME
1526         and leaf.value == "for"
1527         and leaf.parent
1528         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1529     ):
1530         return COMPREHENSION_PRIORITY
1531
1532     if (
1533         leaf.type == token.NAME
1534         and leaf.value == "if"
1535         and leaf.parent
1536         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1537     ):
1538         return COMPREHENSION_PRIORITY
1539
1540     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1541         return LOGIC_PRIORITY
1542
1543     return 0
1544
1545
1546 def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1547     """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
1548
1549     Higher numbers are higher priority.
1550     """
1551     return max(
1552         is_split_before_delimiter(leaf, previous),
1553         is_split_after_delimiter(leaf, previous),
1554     )
1555
1556
1557 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1558     """Clean the prefix of the `leaf` and generate comments from it, if any.
1559
1560     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1561     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1562     move because it does away with modifying the grammar to include all the
1563     possible places in which comments can be placed.
1564
1565     The sad consequence for us though is that comments don't "belong" anywhere.
1566     This is why this function generates simple parentless Leaf objects for
1567     comments.  We simply don't know what the correct parent should be.
1568
1569     No matter though, we can live without this.  We really only need to
1570     differentiate between inline and standalone comments.  The latter don't
1571     share the line with any code.
1572
1573     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1574     are emitted with a fake STANDALONE_COMMENT token identifier.
1575     """
1576     p = leaf.prefix
1577     if not p:
1578         return
1579
1580     if "#" not in p:
1581         return
1582
1583     consumed = 0
1584     nlines = 0
1585     for index, line in enumerate(p.split("\n")):
1586         consumed += len(line) + 1  # adding the length of the split '\n'
1587         line = line.lstrip()
1588         if not line:
1589             nlines += 1
1590         if not line.startswith("#"):
1591             continue
1592
1593         if index == 0 and leaf.type != token.ENDMARKER:
1594             comment_type = token.COMMENT  # simple trailing comment
1595         else:
1596             comment_type = STANDALONE_COMMENT
1597         comment = make_comment(line)
1598         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1599
1600         if comment in {"# fmt: on", "# yapf: enable"}:
1601             raise FormatOn(consumed)
1602
1603         if comment in {"# fmt: off", "# yapf: disable"}:
1604             if comment_type == STANDALONE_COMMENT:
1605                 raise FormatOff(consumed)
1606
1607             prev = preceding_leaf(leaf)
1608             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1609                 raise FormatOff(consumed)
1610
1611         nlines = 0
1612
1613
1614 def make_comment(content: str) -> str:
1615     """Return a consistently formatted comment from the given `content` string.
1616
1617     All comments (except for "##", "#!", "#:") should have a single space between
1618     the hash sign and the content.
1619
1620     If `content` didn't start with a hash sign, one is provided.
1621     """
1622     content = content.rstrip()
1623     if not content:
1624         return "#"
1625
1626     if content[0] == "#":
1627         content = content[1:]
1628     if content and content[0] not in " !:#":
1629         content = " " + content
1630     return "#" + content
1631
1632
1633 def split_line(
1634     line: Line, line_length: int, inner: bool = False, py36: bool = False
1635 ) -> Iterator[Line]:
1636     """Split a `line` into potentially many lines.
1637
1638     They should fit in the allotted `line_length` but might not be able to.
1639     `inner` signifies that there were a pair of brackets somewhere around the
1640     current `line`, possibly transitively. This means we can fallback to splitting
1641     by delimiters if the LHS/RHS don't yield any results.
1642
1643     If `py36` is True, splitting may generate syntax that is only compatible
1644     with Python 3.6 and later.
1645     """
1646     if isinstance(line, UnformattedLines) or line.is_comment:
1647         yield line
1648         return
1649
1650     line_str = str(line).strip("\n")
1651     if (
1652         len(line_str) <= line_length
1653         and "\n" not in line_str  # multiline strings
1654         and not line.contains_standalone_comments()
1655     ):
1656         yield line
1657         return
1658
1659     split_funcs: List[SplitFunc]
1660     if line.is_def:
1661         split_funcs = [left_hand_split]
1662     elif line.inside_brackets:
1663         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1664     else:
1665         split_funcs = [right_hand_split]
1666     for split_func in split_funcs:
1667         # We are accumulating lines in `result` because we might want to abort
1668         # mission and return the original line in the end, or attempt a different
1669         # split altogether.
1670         result: List[Line] = []
1671         try:
1672             for l in split_func(line, py36):
1673                 if str(l).strip("\n") == line_str:
1674                     raise CannotSplit("Split function returned an unchanged result")
1675
1676                 result.extend(
1677                     split_line(l, line_length=line_length, inner=True, py36=py36)
1678                 )
1679         except CannotSplit as cs:
1680             continue
1681
1682         else:
1683             yield from result
1684             break
1685
1686     else:
1687         yield line
1688
1689
1690 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1691     """Split line into many lines, starting with the first matching bracket pair.
1692
1693     Note: this usually looks weird, only use this for function definitions.
1694     Prefer RHS otherwise.
1695     """
1696     head = Line(depth=line.depth)
1697     body = Line(depth=line.depth + 1, inside_brackets=True)
1698     tail = Line(depth=line.depth)
1699     tail_leaves: List[Leaf] = []
1700     body_leaves: List[Leaf] = []
1701     head_leaves: List[Leaf] = []
1702     current_leaves = head_leaves
1703     matching_bracket = None
1704     for leaf in line.leaves:
1705         if (
1706             current_leaves is body_leaves
1707             and leaf.type in CLOSING_BRACKETS
1708             and leaf.opening_bracket is matching_bracket
1709         ):
1710             current_leaves = tail_leaves if body_leaves else head_leaves
1711         current_leaves.append(leaf)
1712         if current_leaves is head_leaves:
1713             if leaf.type in OPENING_BRACKETS:
1714                 matching_bracket = leaf
1715                 current_leaves = body_leaves
1716     # Since body is a new indent level, remove spurious leading whitespace.
1717     if body_leaves:
1718         normalize_prefix(body_leaves[0], inside_brackets=True)
1719     # Build the new lines.
1720     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1721         for leaf in leaves:
1722             result.append(leaf, preformatted=True)
1723             for comment_after in line.comments_after(leaf):
1724                 result.append(comment_after, preformatted=True)
1725     bracket_split_succeeded_or_raise(head, body, tail)
1726     for result in (head, body, tail):
1727         if result:
1728             yield result
1729
1730
1731 def right_hand_split(
1732     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1733 ) -> Iterator[Line]:
1734     """Split line into many lines, starting with the last matching bracket pair."""
1735     head = Line(depth=line.depth)
1736     body = Line(depth=line.depth + 1, inside_brackets=True)
1737     tail = Line(depth=line.depth)
1738     tail_leaves: List[Leaf] = []
1739     body_leaves: List[Leaf] = []
1740     head_leaves: List[Leaf] = []
1741     current_leaves = tail_leaves
1742     opening_bracket = None
1743     closing_bracket = None
1744     for leaf in reversed(line.leaves):
1745         if current_leaves is body_leaves:
1746             if leaf is opening_bracket:
1747                 current_leaves = head_leaves if body_leaves else tail_leaves
1748         current_leaves.append(leaf)
1749         if current_leaves is tail_leaves:
1750             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
1751                 opening_bracket = leaf.opening_bracket
1752                 closing_bracket = leaf
1753                 current_leaves = body_leaves
1754     tail_leaves.reverse()
1755     body_leaves.reverse()
1756     head_leaves.reverse()
1757     # Since body is a new indent level, remove spurious leading whitespace.
1758     if body_leaves:
1759         normalize_prefix(body_leaves[0], inside_brackets=True)
1760     elif not head_leaves:
1761         # No `head` and no `body` means the split failed. `tail` has all content.
1762         raise CannotSplit("No brackets found")
1763
1764     # Build the new lines.
1765     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1766         for leaf in leaves:
1767             result.append(leaf, preformatted=True)
1768             for comment_after in line.comments_after(leaf):
1769                 result.append(comment_after, preformatted=True)
1770     bracket_split_succeeded_or_raise(head, body, tail)
1771     assert opening_bracket and closing_bracket
1772     if (
1773         opening_bracket.type == token.LPAR
1774         and not opening_bracket.value
1775         and closing_bracket.type == token.RPAR
1776         and not closing_bracket.value
1777     ):
1778         # These parens were optional. If there aren't any delimiters or standalone
1779         # comments in the body, they were unnecessary and another split without
1780         # them should be attempted.
1781         if not (
1782             body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
1783         ):
1784             omit = {id(closing_bracket), *omit}
1785             yield from right_hand_split(line, py36=py36, omit=omit)
1786             return
1787
1788     ensure_visible(opening_bracket)
1789     ensure_visible(closing_bracket)
1790     for result in (head, body, tail):
1791         if result:
1792             yield result
1793
1794
1795 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1796     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1797
1798     Do nothing otherwise.
1799
1800     A left- or right-hand split is based on a pair of brackets. Content before
1801     (and including) the opening bracket is left on one line, content inside the
1802     brackets is put on a separate line, and finally content starting with and
1803     following the closing bracket is put on a separate line.
1804
1805     Those are called `head`, `body`, and `tail`, respectively. If the split
1806     produced the same line (all content in `head`) or ended up with an empty `body`
1807     and the `tail` is just the closing bracket, then it's considered failed.
1808     """
1809     tail_len = len(str(tail).strip())
1810     if not body:
1811         if tail_len == 0:
1812             raise CannotSplit("Splitting brackets produced the same line")
1813
1814         elif tail_len < 3:
1815             raise CannotSplit(
1816                 f"Splitting brackets on an empty body to save "
1817                 f"{tail_len} characters is not worth it"
1818             )
1819
1820
1821 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1822     """Normalize prefix of the first leaf in every line returned by `split_func`.
1823
1824     This is a decorator over relevant split functions.
1825     """
1826
1827     @wraps(split_func)
1828     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1829         for l in split_func(line, py36):
1830             normalize_prefix(l.leaves[0], inside_brackets=True)
1831             yield l
1832
1833     return split_wrapper
1834
1835
1836 @dont_increase_indentation
1837 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1838     """Split according to delimiters of the highest priority.
1839
1840     If `py36` is True, the split will add trailing commas also in function
1841     signatures that contain `*` and `**`.
1842     """
1843     try:
1844         last_leaf = line.leaves[-1]
1845     except IndexError:
1846         raise CannotSplit("Line empty")
1847
1848     delimiters = line.bracket_tracker.delimiters
1849     try:
1850         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1851             exclude={id(last_leaf)}
1852         )
1853     except ValueError:
1854         raise CannotSplit("No delimiters found")
1855
1856     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1857     lowest_depth = sys.maxsize
1858     trailing_comma_safe = True
1859
1860     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1861         """Append `leaf` to current line or to new line if appending impossible."""
1862         nonlocal current_line
1863         try:
1864             current_line.append_safe(leaf, preformatted=True)
1865         except ValueError as ve:
1866             yield current_line
1867
1868             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1869             current_line.append(leaf)
1870
1871     for leaf in line.leaves:
1872         yield from append_to_line(leaf)
1873
1874         for comment_after in line.comments_after(leaf):
1875             yield from append_to_line(comment_after)
1876
1877         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1878         if (
1879             leaf.bracket_depth == lowest_depth
1880             and is_vararg(leaf, within=VARARGS_PARENTS)
1881         ):
1882             trailing_comma_safe = trailing_comma_safe and py36
1883         leaf_priority = delimiters.get(id(leaf))
1884         if leaf_priority == delimiter_priority:
1885             yield current_line
1886
1887             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1888     if current_line:
1889         if (
1890             trailing_comma_safe
1891             and delimiter_priority == COMMA_PRIORITY
1892             and current_line.leaves[-1].type != token.COMMA
1893             and current_line.leaves[-1].type != STANDALONE_COMMENT
1894         ):
1895             current_line.append(Leaf(token.COMMA, ","))
1896         yield current_line
1897
1898
1899 @dont_increase_indentation
1900 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1901     """Split standalone comments from the rest of the line."""
1902     if not line.contains_standalone_comments(0):
1903         raise CannotSplit("Line does not have any standalone comments")
1904
1905     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1906
1907     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1908         """Append `leaf` to current line or to new line if appending impossible."""
1909         nonlocal current_line
1910         try:
1911             current_line.append_safe(leaf, preformatted=True)
1912         except ValueError as ve:
1913             yield current_line
1914
1915             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1916             current_line.append(leaf)
1917
1918     for leaf in line.leaves:
1919         yield from append_to_line(leaf)
1920
1921         for comment_after in line.comments_after(leaf):
1922             yield from append_to_line(comment_after)
1923
1924     if current_line:
1925         yield current_line
1926
1927
1928 def is_import(leaf: Leaf) -> bool:
1929     """Return True if the given leaf starts an import statement."""
1930     p = leaf.parent
1931     t = leaf.type
1932     v = leaf.value
1933     return bool(
1934         t == token.NAME
1935         and (
1936             (v == "import" and p and p.type == syms.import_name)
1937             or (v == "from" and p and p.type == syms.import_from)
1938         )
1939     )
1940
1941
1942 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1943     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1944     else.
1945
1946     Note: don't use backslashes for formatting or you'll lose your voting rights.
1947     """
1948     if not inside_brackets:
1949         spl = leaf.prefix.split("#")
1950         if "\\" not in spl[0]:
1951             nl_count = spl[-1].count("\n")
1952             if len(spl) > 1:
1953                 nl_count -= 1
1954             leaf.prefix = "\n" * nl_count
1955             return
1956
1957     leaf.prefix = ""
1958
1959
1960 def normalize_string_quotes(leaf: Leaf) -> None:
1961     """Prefer double quotes but only if it doesn't cause more escaping.
1962
1963     Adds or removes backslashes as appropriate. Doesn't parse and fix
1964     strings nested in f-strings (yet).
1965
1966     Note: Mutates its argument.
1967     """
1968     value = leaf.value.lstrip("furbFURB")
1969     if value[:3] == '"""':
1970         return
1971
1972     elif value[:3] == "'''":
1973         orig_quote = "'''"
1974         new_quote = '"""'
1975     elif value[0] == '"':
1976         orig_quote = '"'
1977         new_quote = "'"
1978     else:
1979         orig_quote = "'"
1980         new_quote = '"'
1981     first_quote_pos = leaf.value.find(orig_quote)
1982     if first_quote_pos == -1:
1983         return  # There's an internal error
1984
1985     prefix = leaf.value[:first_quote_pos]
1986     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
1987     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
1988     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
1989     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
1990     if "r" in prefix.casefold():
1991         if unescaped_new_quote.search(body):
1992             # There's at least one unescaped new_quote in this raw string
1993             # so converting is impossible
1994             return
1995
1996         # Do not introduce or remove backslashes in raw strings
1997         new_body = body
1998     else:
1999         # remove unnecessary quotes
2000         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2001         if body != new_body:
2002             # Consider the string without unnecessary quotes as the original
2003             body = new_body
2004             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2005         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2006         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2007     if new_quote == '"""' and new_body[-1] == '"':
2008         # edge case:
2009         new_body = new_body[:-1] + '\\"'
2010     orig_escape_count = body.count("\\")
2011     new_escape_count = new_body.count("\\")
2012     if new_escape_count > orig_escape_count:
2013         return  # Do not introduce more escaping
2014
2015     if new_escape_count == orig_escape_count and orig_quote == '"':
2016         return  # Prefer double quotes
2017
2018     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2019
2020
2021 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2022     """Make existing optional parentheses invisible or create new ones.
2023
2024     Standardizes on visible parentheses for single-element tuples, and keeps
2025     existing visible parentheses for other tuples and generator expressions.
2026     """
2027     check_lpar = False
2028     for child in list(node.children):
2029         if check_lpar:
2030             if child.type == syms.atom:
2031                 if not (
2032                     is_empty_tuple(child)
2033                     or is_one_tuple(child)
2034                     or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
2035                 ):
2036                     first = child.children[0]
2037                     last = child.children[-1]
2038                     if first.type == token.LPAR and last.type == token.RPAR:
2039                         # make parentheses invisible
2040                         first.value = ""  # type: ignore
2041                         last.value = ""  # type: ignore
2042             elif is_one_tuple(child):
2043                 # wrap child in visible parentheses
2044                 lpar = Leaf(token.LPAR, "(")
2045                 rpar = Leaf(token.RPAR, ")")
2046                 index = child.remove() or 0
2047                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2048             else:
2049                 # wrap child in invisible parentheses
2050                 lpar = Leaf(token.LPAR, "")
2051                 rpar = Leaf(token.RPAR, "")
2052                 index = child.remove() or 0
2053                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2054
2055         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2056
2057
2058 def is_empty_tuple(node: LN) -> bool:
2059     """Return True if `node` holds an empty tuple."""
2060     return (
2061         node.type == syms.atom
2062         and len(node.children) == 2
2063         and node.children[0].type == token.LPAR
2064         and node.children[1].type == token.RPAR
2065     )
2066
2067
2068 def is_one_tuple(node: LN) -> bool:
2069     """Return True if `node` holds a tuple with one element, with or without parens."""
2070     if node.type == syms.atom:
2071         if len(node.children) != 3:
2072             return False
2073
2074         lpar, gexp, rpar = node.children
2075         if not (
2076             lpar.type == token.LPAR
2077             and gexp.type == syms.testlist_gexp
2078             and rpar.type == token.RPAR
2079         ):
2080             return False
2081
2082         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2083
2084     return (
2085         node.type in IMPLICIT_TUPLE
2086         and len(node.children) == 2
2087         and node.children[1].type == token.COMMA
2088     )
2089
2090
2091 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2092     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2093
2094     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2095     If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
2096     hand-side extended iterable unpacking (PEP 3132) and additional unpacking
2097     generalizations (PEP 448).
2098     """
2099     if leaf.type not in STARS or not leaf.parent:
2100         return False
2101
2102     p = leaf.parent
2103     if p.type == syms.star_expr:
2104         # Star expressions are also used as assignment targets in extended
2105         # iterable unpacking (PEP 3132).  See what its parent is instead.
2106         if not p.parent:
2107             return False
2108
2109         p = p.parent
2110
2111     return p.type in within
2112
2113
2114 def max_delimiter_priority_in_atom(node: LN) -> int:
2115     if node.type != syms.atom:
2116         return 0
2117
2118     first = node.children[0]
2119     last = node.children[-1]
2120     if not (first.type == token.LPAR and last.type == token.RPAR):
2121         return 0
2122
2123     bt = BracketTracker()
2124     for c in node.children[1:-1]:
2125         if isinstance(c, Leaf):
2126             bt.mark(c)
2127         else:
2128             for leaf in c.leaves():
2129                 bt.mark(leaf)
2130     try:
2131         return bt.max_delimiter_priority()
2132
2133     except ValueError:
2134         return 0
2135
2136
2137 def ensure_visible(leaf: Leaf) -> None:
2138     """Make sure parentheses are visible.
2139
2140     They could be invisible as part of some statements (see
2141     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2142     """
2143     if leaf.type == token.LPAR:
2144         leaf.value = "("
2145     elif leaf.type == token.RPAR:
2146         leaf.value = ")"
2147
2148
2149 def is_python36(node: Node) -> bool:
2150     """Return True if the current file is using Python 3.6+ features.
2151
2152     Currently looking for:
2153     - f-strings; and
2154     - trailing commas after * or ** in function signatures.
2155     """
2156     for n in node.pre_order():
2157         if n.type == token.STRING:
2158             value_head = n.value[:2]  # type: ignore
2159             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2160                 return True
2161
2162         elif (
2163             n.type == syms.typedargslist
2164             and n.children
2165             and n.children[-1].type == token.COMMA
2166         ):
2167             for ch in n.children:
2168                 if ch.type in STARS:
2169                     return True
2170
2171     return False
2172
2173
2174 PYTHON_EXTENSIONS = {".py"}
2175 BLACKLISTED_DIRECTORIES = {
2176     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2177 }
2178
2179
2180 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2181     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2182     and have one of the PYTHON_EXTENSIONS.
2183     """
2184     for child in path.iterdir():
2185         if child.is_dir():
2186             if child.name in BLACKLISTED_DIRECTORIES:
2187                 continue
2188
2189             yield from gen_python_files_in_dir(child)
2190
2191         elif child.suffix in PYTHON_EXTENSIONS:
2192             yield child
2193
2194
2195 @dataclass
2196 class Report:
2197     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2198     check: bool = False
2199     quiet: bool = False
2200     change_count: int = 0
2201     same_count: int = 0
2202     failure_count: int = 0
2203
2204     def done(self, src: Path, changed: bool) -> None:
2205         """Increment the counter for successful reformatting. Write out a message."""
2206         if changed:
2207             reformatted = "would reformat" if self.check else "reformatted"
2208             if not self.quiet:
2209                 out(f"{reformatted} {src}")
2210             self.change_count += 1
2211         else:
2212             if not self.quiet:
2213                 out(f"{src} already well formatted, good job.", bold=False)
2214             self.same_count += 1
2215
2216     def failed(self, src: Path, message: str) -> None:
2217         """Increment the counter for failed reformatting. Write out a message."""
2218         err(f"error: cannot format {src}: {message}")
2219         self.failure_count += 1
2220
2221     @property
2222     def return_code(self) -> int:
2223         """Return the exit code that the app should use.
2224
2225         This considers the current state of changed files and failures:
2226         - if there were any failures, return 123;
2227         - if any files were changed and --check is being used, return 1;
2228         - otherwise return 0.
2229         """
2230         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2231         # 126 we have special returncodes reserved by the shell.
2232         if self.failure_count:
2233             return 123
2234
2235         elif self.change_count and self.check:
2236             return 1
2237
2238         return 0
2239
2240     def __str__(self) -> str:
2241         """Render a color report of the current state.
2242
2243         Use `click.unstyle` to remove colors.
2244         """
2245         if self.check:
2246             reformatted = "would be reformatted"
2247             unchanged = "would be left unchanged"
2248             failed = "would fail to reformat"
2249         else:
2250             reformatted = "reformatted"
2251             unchanged = "left unchanged"
2252             failed = "failed to reformat"
2253         report = []
2254         if self.change_count:
2255             s = "s" if self.change_count > 1 else ""
2256             report.append(
2257                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2258             )
2259         if self.same_count:
2260             s = "s" if self.same_count > 1 else ""
2261             report.append(f"{self.same_count} file{s} {unchanged}")
2262         if self.failure_count:
2263             s = "s" if self.failure_count > 1 else ""
2264             report.append(
2265                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2266             )
2267         return ", ".join(report) + "."
2268
2269
2270 def assert_equivalent(src: str, dst: str) -> None:
2271     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2272
2273     import ast
2274     import traceback
2275
2276     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2277         """Simple visitor generating strings to compare ASTs by content."""
2278         yield f"{'  ' * depth}{node.__class__.__name__}("
2279
2280         for field in sorted(node._fields):
2281             try:
2282                 value = getattr(node, field)
2283             except AttributeError:
2284                 continue
2285
2286             yield f"{'  ' * (depth+1)}{field}="
2287
2288             if isinstance(value, list):
2289                 for item in value:
2290                     if isinstance(item, ast.AST):
2291                         yield from _v(item, depth + 2)
2292
2293             elif isinstance(value, ast.AST):
2294                 yield from _v(value, depth + 2)
2295
2296             else:
2297                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2298
2299         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2300
2301     try:
2302         src_ast = ast.parse(src)
2303     except Exception as exc:
2304         major, minor = sys.version_info[:2]
2305         raise AssertionError(
2306             f"cannot use --safe with this file; failed to parse source file "
2307             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2308             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2309         )
2310
2311     try:
2312         dst_ast = ast.parse(dst)
2313     except Exception as exc:
2314         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2315         raise AssertionError(
2316             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2317             f"Please report a bug on https://github.com/ambv/black/issues.  "
2318             f"This invalid output might be helpful: {log}"
2319         ) from None
2320
2321     src_ast_str = "\n".join(_v(src_ast))
2322     dst_ast_str = "\n".join(_v(dst_ast))
2323     if src_ast_str != dst_ast_str:
2324         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2325         raise AssertionError(
2326             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2327             f"the source.  "
2328             f"Please report a bug on https://github.com/ambv/black/issues.  "
2329             f"This diff might be helpful: {log}"
2330         ) from None
2331
2332
2333 def assert_stable(src: str, dst: str, line_length: int) -> None:
2334     """Raise AssertionError if `dst` reformats differently the second time."""
2335     newdst = format_str(dst, line_length=line_length)
2336     if dst != newdst:
2337         log = dump_to_file(
2338             diff(src, dst, "source", "first pass"),
2339             diff(dst, newdst, "first pass", "second pass"),
2340         )
2341         raise AssertionError(
2342             f"INTERNAL ERROR: Black produced different code on the second pass "
2343             f"of the formatter.  "
2344             f"Please report a bug on https://github.com/ambv/black/issues.  "
2345             f"This diff might be helpful: {log}"
2346         ) from None
2347
2348
2349 def dump_to_file(*output: str) -> str:
2350     """Dump `output` to a temporary file. Return path to the file."""
2351     import tempfile
2352
2353     with tempfile.NamedTemporaryFile(
2354         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2355     ) as f:
2356         for lines in output:
2357             f.write(lines)
2358             if lines and lines[-1] != "\n":
2359                 f.write("\n")
2360     return f.name
2361
2362
2363 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2364     """Return a unified diff string between strings `a` and `b`."""
2365     import difflib
2366
2367     a_lines = [line + "\n" for line in a.split("\n")]
2368     b_lines = [line + "\n" for line in b.split("\n")]
2369     return "".join(
2370         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2371     )
2372
2373
2374 def cancel(tasks: List[asyncio.Task]) -> None:
2375     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2376     err("Aborted!")
2377     for task in tasks:
2378         task.cancel()
2379
2380
2381 def shutdown(loop: BaseEventLoop) -> None:
2382     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2383     try:
2384         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2385         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2386         if not to_cancel:
2387             return
2388
2389         for task in to_cancel:
2390             task.cancel()
2391         loop.run_until_complete(
2392             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2393         )
2394     finally:
2395         # `concurrent.futures.Future` objects cannot be cancelled once they
2396         # are already running. There might be some when the `shutdown()` happened.
2397         # Silence their logger's spew about the event loop being closed.
2398         cf_logger = logging.getLogger("concurrent.futures")
2399         cf_logger.setLevel(logging.CRITICAL)
2400         loop.close()
2401
2402
2403 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2404     """Replace `regex` with `replacement` twice on `original`.
2405
2406     This is used by string normalization to perform replaces on
2407     overlapping matches.
2408     """
2409     return regex.sub(replacement, regex.sub(replacement, original))
2410
2411
2412 if __name__ == "__main__":
2413     main()