git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Ignore `# fmt: off` as inline comment
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from enum import Enum
7 from functools import partial, wraps
8 import keyword
9 import logging
10 from multiprocessing import Manager
11 import os
12 from pathlib import Path
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Dict,
20     Generic,
21     Iterable,
22     Iterator,
23     List,
24     Optional,
25     Set,
26     Tuple,
27     Type,
28     TypeVar,
29     Union,
30 )
31
32 from attr import dataclass, Factory
33 import click
34
35 # lib2to3 fork
36 from blib2to3.pytree import Node, Leaf, type_repr
37 from blib2to3 import pygram, pytree
38 from blib2to3.pgen2 import driver, token
39 from blib2to3.pgen2.parse import ParseError
40
__version__ = "18.3a4"

# Default maximum number of characters per line (see --line-length).
DEFAULT_LINE_LENGTH = 88

# Grammar symbol table of the lib2to3 fork; node types are looked up on it.
syms = pygram.python_symbols

# types
# Aliases that make signatures self-describing; all are plain builtins.
FileContent = str
Encoding = str
Depth = int
NodeType = int
LeafID = int
Priority = int
Index = int
LN = Union[Leaf, Node]
# A split function takes a Line and a bool flag and yields replacement Lines.
SplitFunc = Callable[["Line", bool], Iterator["Line"]]

# Console helpers writing to stderr: `out` is bold, `err` is red.
out = partial(click.secho, bold=True, err=True)
err = partial(click.secho, fg="red", err=True)
56
57
class NothingChanged(UserWarning):
    """Signals that reformatting would not change the source.

    Raised by :func:`format_file_contents` when the input is blank or when the
    reformatted code is identical to it.
    """
60
61
class CannotSplit(Exception):
    """No readable split that fits within the allotted line length exists.

    Raised by the line-splitting functions :func:`left_hand_split`,
    :func:`right_hand_split`, and :func:`delimiter_split`.
    """
68
69
class FormatError(Exception):
    """Common base for the `# fmt: on` / `# fmt: off` control-comment signals.

    `consumed` is how many characters of a leaf's prefix had already been
    consumed when the control comment was found.
    """

    def __init__(self, consumed: int) -> None:
        super().__init__(consumed)
        self.consumed = consumed

    def trim_prefix(self, leaf: Leaf) -> None:
        """Drop the first `consumed` characters from `leaf.prefix` in place."""
        remaining = leaf.prefix[self.consumed:]
        leaf.prefix = remaining

    def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
        """Returns a new NEWLINE Leaf holding the consumed head of the prefix.

        `leaf` itself is not modified.
        """
        head = leaf.prefix[:self.consumed]
        return Leaf(token.NEWLINE, head)
88
89
class FormatOn(FormatError):
    """Signals that a `# fmt: on` comment was encountered in the file."""
92
93
class FormatOff(FormatError):
    """Signals that a `# fmt: off` comment was encountered in the file."""
96
97
class WriteBack(Enum):
    """What to do with reformatted code.

    NO: only report status (selected by ``--check``);
    YES: write the reformatted code back (the default);
    DIFF: print a diff to stdout instead (selected by ``--diff``).
    """

    NO = 0
    YES = 1
    DIFF = 2
102
103
@click.command()
@click.option(
    "-l",
    "--line-length",
    type=int,
    default=DEFAULT_LINE_LENGTH,
    help="How many character per line to allow.",
    show_default=True,
)
@click.option(
    "--check",
    is_flag=True,
    help=(
        "Don't write the files back, just return the status.  Return code 0 "
        "means nothing would change.  Return code 1 means some files would be "
        "reformatted.  Return code 123 means there was an internal error."
    ),
)
@click.option(
    "--diff",
    is_flag=True,
    help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
    "--fast/--safe",
    is_flag=True,
    help="If --fast given, skip temporary sanity checks. [default: --safe]",
)
@click.version_option(version=__version__)
@click.argument(
    "src",
    nargs=-1,
    type=click.Path(
        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
    ),
)
@click.pass_context
def main(
    ctx: click.Context,
    line_length: int,
    check: bool,
    diff: bool,
    fast: bool,
    src: List[str],
) -> None:
    """The uncompromising code formatter."""
    # Expand the command-line arguments into a flat list of paths to format.
    sources: List[Path] = []
    for s in src:
        candidate = Path(s)
        if candidate.is_dir():
            sources.extend(gen_python_files_in_dir(candidate))
        elif candidate.is_file():
            # An explicitly named file is formatted whatever its extension is.
            sources.append(candidate)
        elif s == "-":
            sources.append(Path("-"))
        else:
            err(f"invalid path: {s}")

    if check and diff:
        exc = click.ClickException("Options --check and --diff are mutually exclusive")
        exc.exit_code = 2
        raise exc

    # Rewriting in place is the default; --check and --diff (mutually
    # exclusive, enforced above) select the other modes.
    write_back = WriteBack.YES
    if check:
        write_back = WriteBack.NO
    elif diff:
        write_back = WriteBack.DIFF

    if not sources:
        ctx.exit(0)
    elif len(sources) == 1:
        # A single source is formatted in this process; no executor needed.
        only = sources[0]
        report = Report(check=check)
        try:
            from_stdin = not only.is_file() and str(only) == "-"
            if from_stdin:
                changed = format_stdin_to_stdout(
                    line_length=line_length, fast=fast, write_back=write_back
                )
            else:
                changed = format_file_in_place(
                    only, line_length=line_length, fast=fast, write_back=write_back
                )
            report.done(only, changed)
        except Exception as exc:
            report.failed(only, str(exc))
        ctx.exit(report.return_code)
    else:
        # Several sources: fan the work out to a process pool sized to the CPUs.
        loop = asyncio.get_event_loop()
        executor = ProcessPoolExecutor(max_workers=os.cpu_count())
        return_code = 1  # used if scheduling fails before producing a code
        try:
            return_code = loop.run_until_complete(
                schedule_formatting(
                    sources, line_length, write_back, fast, loop, executor
                )
            )
        finally:
            shutdown(loop)
            ctx.exit(return_code)
204
205
async def schedule_formatting(
    sources: List[Path],
    line_length: int,
    write_back: WriteBack,
    fast: bool,
    loop: BaseEventLoop,
    executor: Executor,
) -> int:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `line_length`, `write_back`, and `fast` options are passed to
    :func:`format_file_in_place`.  `cancel` is installed as the SIGINT and
    SIGTERM handler for the scheduled tasks.  Prints a summary report and
    returns its exit code.
    """
    lock = None
    if write_back == WriteBack.DIFF:
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()
    tasks = {
        src: loop.run_in_executor(
            executor, format_file_in_place, src, line_length, fast, write_back, lock
        )
        for src in sources
    }
    _task_values = list(tasks.values())
    loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
    loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
    await asyncio.wait(_task_values)
    cancelled = []
    # Compare explicitly: Enum members are always truthy, so `not write_back`
    # would never be True.  Only WriteBack.NO corresponds to --check, which
    # matches `Report(check=check)` in the single-file path of `main`.
    report = Report(check=write_back is WriteBack.NO)
    for src, task in tasks.items():
        if not task.done():
            report.failed(src, "timed out, cancelling")
            task.cancel()
            cancelled.append(task)
        elif task.cancelled():
            cancelled.append(task)
        elif task.exception():
            report.failed(src, str(task.exception()))
        else:
            report.done(src, task.result())
    if cancelled:
        # gather() finds the loop from the futures themselves; its `loop`
        # argument was deprecated in 3.8 and removed in 3.10.
        await asyncio.gather(*cancelled, return_exceptions=True)
    else:
        out("All done! ✨ 🍰 ✨")
    click.echo(str(report))
    return report.return_code
256
257
def format_file_in_place(
    src: Path,
    line_length: int,
    fast: bool,
    write_back: WriteBack = WriteBack.NO,
    lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
) -> bool:
    """Format file under `src` path. Return True if changed.

    If `write_back` is `WriteBack.YES`, write the reformatted code back to
    `src` using the encoding the file was read with.  If it is
    `WriteBack.DIFF`, write a diff of the change to stdout instead, holding
    `lock` (when given) so output from concurrent workers doesn't interleave.
    `line_length` and `fast` options are passed to :func:`format_file_contents`.
    """
    with tokenize.open(src) as src_buffer:
        src_contents = src_buffer.read()
        encoding = src_buffer.encoding
    try:
        dst_contents = format_file_contents(
            src_contents, line_length=line_length, fast=fast
        )
    except NothingChanged:
        return False

    if write_back == WriteBack.YES:
        with open(src, "w", encoding=encoding) as f:
            f.write(dst_contents)
    elif write_back == WriteBack.DIFF:
        src_name = f"{src.name}  (original)"
        dst_name = f"{src.name}  (formatted)"
        diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
        if lock:
            lock.acquire()
        try:
            sys.stdout.write(diff_contents)
            # Flush while still holding the lock: otherwise the text sits in
            # this process's buffer and reaches the terminal later, where it
            # can interleave with other workers' diffs.
            sys.stdout.flush()
        finally:
            if lock:
                lock.release()
    return True
294
295
def format_stdin_to_stdout(
    line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
) -> bool:
    """Format file on stdin. Return True if changed.

    If `write_back` is `WriteBack.YES`, write the resulting code to stdout
    (the unchanged input when nothing changed).  If it is `WriteBack.DIFF`,
    write a diff to stdout instead.  If formatting raises, the unchanged input
    is still written that way before the exception propagates.
    `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
    """
    src = sys.stdin.read()
    # Pre-seed `dst` so the `finally` block always has something to write.
    # Without this, any error other than NothingChanged would leave `dst`
    # unbound and an UnboundLocalError would mask the real exception.
    dst = src
    try:
        dst = format_file_contents(src, line_length=line_length, fast=fast)
        return True

    except NothingChanged:
        return False

    finally:
        if write_back == WriteBack.YES:
            sys.stdout.write(dst)
        elif write_back == WriteBack.DIFF:
            src_name = "<stdin>  (original)"
            dst_name = "<stdin>  (formatted)"
            sys.stdout.write(diff(src, dst, src_name, dst_name))
320
321
def format_file_contents(
    src_contents: str, line_length: int, fast: bool
) -> FileContent:
    """Return `src_contents` reformatted via :func:`format_str`.

    Raises :class:`NothingChanged` when the input is blank or is already
    formatted.  Unless `fast` is set, the result is additionally verified with
    :func:`assert_equivalent` and :func:`assert_stable` before returning.
    `line_length` is passed to :func:`format_str`.
    """
    if not src_contents.strip():
        raise NothingChanged

    dst_contents = format_str(src_contents, line_length=line_length)
    if dst_contents == src_contents:
        raise NothingChanged

    if fast:
        return dst_contents

    # Safe mode: make sure the new code means the same and is a fixed point.
    assert_equivalent(src_contents, dst_contents)
    assert_stable(src_contents, dst_contents, line_length=line_length)
    return dst_contents
342
343
def format_str(src_contents: str, line_length: int) -> FileContent:
    """Reformat a string and return new contents.

    `line_length` determines how many characters per line are allowed.
    """
    src_node = lib2to3_parse(src_contents)
    # Collect output pieces and join once at the end instead of growing a
    # string with `+=`, which is quadratic in the general case.
    chunks: List[str] = []
    lines = LineGenerator()
    elt = EmptyLineTracker()
    py36 = is_python36(src_node)
    empty_line = str(Line())
    after = 0
    for current_line in lines.visit(src_node):
        # Non-positive counts yield "" (same as an empty `range` loop).
        chunks.append(empty_line * after)
        before, after = elt.maybe_empty_lines(current_line)
        chunks.append(empty_line * before)
        for line in split_line(current_line, line_length=line_length, py36=py36):
            chunks.append(str(line))
    return "".join(chunks)
365
366
# Grammars tried in order by `lib2to3_parse`; the first one that parses the
# source wins.  The variants without print/exec statements come first and
# the unrestricted `python_grammar` is the last resort.
GRAMMARS = [
    pygram.python_grammar_no_print_statement_no_exec_statement,
    pygram.python_grammar_no_print_statement,
    pygram.python_grammar_no_exec_statement,
    pygram.python_grammar,
]
373
374
def lib2to3_parse(src_txt: str) -> Node:
    """Given a string with source, return the lib2to3 Node.

    Each grammar in `GRAMMARS` is tried in order and the first successful
    parse is used.  If every grammar fails, raises ValueError naming the line
    and column reported by the last attempt.  A bare Leaf result is wrapped
    in a `file_input` Node.
    """
    if not src_txt.endswith("\n"):
        # The parser needs a trailing newline; reuse the file's line ending.
        # `endswith` also copes with empty input, where `src_txt[-1]` raised
        # IndexError.
        nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
        src_txt += nl
    for grammar in GRAMMARS:
        drv = driver.Driver(grammar, pytree.convert)
        try:
            result = drv.parse_string(src_txt, True)
            break

        except ParseError as pe:
            lineno, column = pe.context[1]
            lines = src_txt.splitlines()
            try:
                faulty_line = lines[lineno - 1]
            except IndexError:
                faulty_line = "<line number missing in source>"
            exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
    else:
        # No grammar succeeded: report the last failure.
        raise exc from None

    if isinstance(result, Leaf):
        result = Node(syms.file_input, [result])
    return result
401
402
def lib2to3_unparse(node: Node) -> str:
    """Return the source text of `node`, i.e. its `str()` rendering."""
    return str(node)
407
408
T = TypeVar("T")


class Visitor(Generic[T]):
    """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""

    def visit(self, node: LN) -> Iterator[T]:
        """Dispatch `node` to its `visit_*()` method and yield what it yields.

        Token types (below 256) are looked up by their token name, e.g.
        `visit_INDENT`; grammar symbols by their symbol name, e.g.
        `visit_simple_stmt`.  Without a dedicated method, `visit_default()`
        handles the node.
        """
        if node.type >= 256:
            name = type_repr(node.type)
        else:
            name = token.tok_name[node.type]
        handler = getattr(self, f"visit_{name}", self.visit_default)
        yield from handler(node)

    def visit_default(self, node: LN) -> Iterator[T]:
        """Fallback visitor: recurse into the children of a Node; skip leaves."""
        if not isinstance(node, Node):
            return

        for child in node.children:
            yield from self.visit(child)
436
437
@dataclass
class DebugVisitor(Visitor[T]):
    """Visitor that prints an indented dump of the tree it walks via `out()`."""

    tree_depth: int = 0

    def visit_default(self, node: LN) -> Iterator[T]:
        """Print `node` (and, for Nodes, its whole subtree) at the current depth."""
        indent = " " * (2 * self.tree_depth)
        if not isinstance(node, Node):
            # Leaf: token name, then optional prefix and value, on one line.
            _type = token.tok_name.get(node.type, str(node.type))
            out(f"{indent}{_type}", fg="blue", nl=False)
            if node.prefix:
                # We don't have to handle prefixes for `Node` objects since
                # that delegates to the first child anyway.
                out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
            out(f" {node.value!r}", fg="blue", bold=False)
            return

        # Node: opening tag, children one level deeper, closing tag.
        _type = type_repr(node.type)
        out(f"{indent}{_type}", fg="yellow")
        self.tree_depth += 1
        for child in node.children:
            yield from self.visit(child)

        self.tree_depth -= 1
        out(f"{indent}/{_type}", fg="yellow", bold=False)

    @classmethod
    def show(cls, code: str) -> None:
        """Pretty-print the lib2to3 AST of a given string of `code`.

        Convenience method for debugging.
        """
        visitor: DebugVisitor[None] = DebugVisitor()
        # Drain the generator; the printing happens as a side effect.
        for _ in visitor.visit(lib2to3_parse(code)):
            pass
470
471
# Python's reserved words.
KEYWORDS = set(keyword.kwlist)
# Structural whitespace token types.
WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
# Leading words of flow control statements (see `Line.is_flow_control`).
FLOW_CONTROL = {"return", "raise", "break", "continue"}
# Grammar symbols of compound statements and definitions.
STATEMENT = {
    syms.if_stmt,
    syms.while_stmt,
    syms.for_stmt,
    syms.try_stmt,
    syms.except_clause,
    syms.with_stmt,
    syms.funcdef,
    syms.classdef,
}
# Custom leaf type for comments that occupy a line of their own.
# NOTE(review): 153 is presumably chosen not to clash with real token
# numbers — confirm against blib2to3's token table.
STANDALONE_COMMENT = 153
LOGIC_OPERATORS = {"and", "or"}
COMPARATORS = {
    token.LESS,
    token.GREATER,
    token.EQEQUAL,
    token.NOTEQUAL,
    token.LESSEQUAL,
    token.GREATEREQUAL,
}
MATH_OPERATORS = {
    token.PLUS,
    token.MINUS,
    token.STAR,
    token.SLASH,
    token.VBAR,
    token.AMPER,
    token.PERCENT,
    token.CIRCUMFLEX,
    token.TILDE,
    token.LEFTSHIFT,
    token.RIGHTSHIFT,
    token.DOUBLESTAR,
    token.DOUBLESLASH,
}
VARARGS = {token.STAR, token.DOUBLESTAR}

# Delimiter priorities, listed from lowest to highest.
# `BracketTracker.max_delimiter_priority()` reports the highest one present.
MATH_PRIORITY = 1
COMPARATOR_PRIORITY = 3
STRING_PRIORITY = 4
LOGIC_PRIORITY = 5
COMMA_PRIORITY = 10
COMPREHENSION_PRIORITY = 20
517
518
@dataclass
class BracketTracker:
    """Keeps track of brackets on a line."""

    depth: int = 0
    # (depth, expected closing bracket type) -> the opening bracket leaf.
    bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
    # id() of a delimiter leaf at depth 0 -> its split priority.
    delimiters: Dict[LeafID, Priority] = Factory(dict)
    previous: Optional[Leaf] = None

    def mark(self, leaf: Leaf) -> None:
        """Mark `leaf` with bracket-related metadata. Keep track of delimiters.

        All leaves receive an int `bracket_depth` field that stores how deep
        within brackets a given leaf is. 0 means there are no enclosing brackets
        that started on this line.

        If a leaf is itself a closing bracket, it receives an `opening_bracket`
        field that it forms a pair with. This is a one-directional link to
        avoid reference cycles.

        If a leaf is a delimiter (a token on which Black can split the line if
        needed) and it's on depth 0, its `id()` is stored in the tracker's
        `delimiters` field.
        """
        if leaf.type == token.COMMENT:
            return

        if leaf.type in CLOSING_BRACKETS:
            # Closing brackets sit at the same depth as their opening pair.
            self.depth -= 1
            opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
            leaf.opening_bracket = opening_bracket
        leaf.bracket_depth = self.depth
        if self.depth == 0:
            after_delim = is_split_after_delimiter(leaf, self.previous)
            before_delim = is_split_before_delimiter(leaf, self.previous)
            if after_delim > before_delim:
                self.delimiters[id(leaf)] = after_delim
            elif before_delim > after_delim and self.previous is not None:
                # Splitting *before* `leaf` is recorded on the leaf preceding it.
                self.delimiters[id(self.previous)] = before_delim
        if leaf.type in OPENING_BRACKETS:
            self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
            self.depth += 1
        self.previous = leaf

    def any_open_brackets(self) -> bool:
        """Return True if there is an yet unmatched open bracket on the line."""
        return bool(self.bracket_match)

    def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
        """Return the highest priority of a delimiter found on the line.

        Delimiters whose leaf ids appear in `exclude` are ignored.  Values are
        consistent with what :func:`is_split_after_delimiter` and
        :func:`is_split_before_delimiter` return.

        Raises ValueError if no delimiter remains after exclusion.
        """
        # Materialize once: `in` against a one-shot iterator would consume it,
        # and against a list it would be a linear scan per delimiter.
        excluded = set(exclude)
        return max(v for k, v in self.delimiters.items() if k not in excluded)
573
574
@dataclass
class Line:
    """Holds leaves and comments. Can be printed with `str(line)`.

    `depth` is the indentation level, rendered as four spaces per level.
    Inline comments are kept aside in `comments` as `(index, comment)` pairs,
    where `index` is the position in `leaves` of the leaf the comment follows;
    they are rendered at the end of the line.
    """

    depth: int = 0
    leaves: List[Leaf] = Factory(list)
    comments: List[Tuple[Index, Leaf]] = Factory(list)
    bracket_tracker: BracketTracker = Factory(BracketTracker)
    inside_brackets: bool = False
    has_for: bool = False
    _for_loop_variable: bool = False

    def append(self, leaf: Leaf, preformatted: bool = False) -> None:
        """Add a new `leaf` to the end of the line.

        Unless `preformatted` is True, the `leaf` will receive a new consistent
        whitespace prefix and metadata applied by :class:`BracketTracker`.
        Trailing commas are maybe removed, unpacked for loop variables are
        demoted from being delimiters.

        Inline comments are put aside.
        """
        has_value = leaf.value.strip()
        if not has_value:
            return

        if self.leaves and not preformatted:
            # Note: at this point leaf.prefix should be empty except for
            # imports, for which we only preserve newlines.
            leaf.prefix += whitespace(leaf)
        if self.inside_brackets or not preformatted:
            # Order matters: an `in` ending a loop variable must undo the
            # artificial depth bump before being marked; trailing-comma removal
            # needs the `opening_bracket` link set by `mark()`; and `for` is
            # marked at the real depth before the bump.
            self.maybe_decrement_after_for_loop_variable(leaf)
            self.bracket_tracker.mark(leaf)
            self.maybe_remove_trailing_comma(leaf)
            self.maybe_increment_for_loop_variable(leaf)

        if not self.append_comment(leaf):
            self.leaves.append(leaf)

    def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
        """Like :func:`append()` but disallow invalid standalone comment structure.

        Raises ValueError when any `leaf` is appended after a standalone comment
        or when a standalone comment is not the first leaf on the line.
        """
        if self.bracket_tracker.depth == 0:
            if self.is_comment:
                raise ValueError("cannot append to standalone comments")

            if self.leaves and leaf.type == STANDALONE_COMMENT:
                raise ValueError(
                    "cannot append standalone comments to a populated line"
                )

        self.append(leaf, preformatted=preformatted)

    @property
    def is_comment(self) -> bool:
        """Is this line a standalone comment?"""
        return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT

    @property
    def is_decorator(self) -> bool:
        """Is this line a decorator?"""
        return bool(self) and self.leaves[0].type == token.AT

    @property
    def is_import(self) -> bool:
        """Is this an import line?"""
        return bool(self) and is_import(self.leaves[0])

    @property
    def is_class(self) -> bool:
        """Is this line a class definition?"""
        return (
            bool(self)
            and self.leaves[0].type == token.NAME
            and self.leaves[0].value == "class"
        )

    @property
    def is_def(self) -> bool:
        """Is this a function definition? (Also returns True for async defs.)"""
        try:
            first_leaf = self.leaves[0]
        except IndexError:
            return False

        try:
            second_leaf: Optional[Leaf] = self.leaves[1]
        except IndexError:
            second_leaf = None
        return (
            (first_leaf.type == token.NAME and first_leaf.value == "def")
            or (
                first_leaf.type == token.ASYNC
                and second_leaf is not None
                and second_leaf.type == token.NAME
                and second_leaf.value == "def"
            )
        )

    @property
    def is_flow_control(self) -> bool:
        """Is this line a flow control statement?

        Those are `return`, `raise`, `break`, and `continue`.
        """
        return (
            bool(self)
            and self.leaves[0].type == token.NAME
            and self.leaves[0].value in FLOW_CONTROL
        )

    @property
    def is_yield(self) -> bool:
        """Is this line a yield statement?"""
        return (
            bool(self)
            and self.leaves[0].type == token.NAME
            and self.leaves[0].value == "yield"
        )

    @property
    def contains_standalone_comments(self) -> bool:
        """If so, needs to be split before emitting."""
        return any(leaf.type == STANDALONE_COMMENT for leaf in self.leaves)

    def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
        """Remove trailing comma if there is one and it's safe."""
        if not (
            self.leaves
            and self.leaves[-1].type == token.COMMA
            and closing.type in CLOSING_BRACKETS
        ):
            return False

        if closing.type == token.RBRACE:
            self.remove_trailing_comma()
            return True

        if closing.type == token.RSQB:
            comma = self.leaves[-1]
            if comma.parent and comma.parent.type == syms.listmaker:
                self.remove_trailing_comma()
                return True

        # For parens let's check if it's safe to remove the comma.  If the
        # trailing one is the only one, we might mistakenly change a tuple
        # into a different type by removing the comma.
        depth = closing.bracket_depth + 1
        commas = 0
        opening = closing.opening_bracket
        for _opening_index, leaf in enumerate(self.leaves):
            if leaf is opening:
                break

        else:
            return False

        for leaf in self.leaves[_opening_index + 1:]:
            if leaf is closing:
                break

            bracket_depth = leaf.bracket_depth
            if bracket_depth == depth and leaf.type == token.COMMA:
                commas += 1
                if leaf.parent and leaf.parent.type == syms.arglist:
                    # A comma inside an argument list is always safe to drop;
                    # count it twice so the check below passes.
                    commas += 1
                    break

        if commas > 1:
            self.remove_trailing_comma()
            return True

        return False

    def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
        """In a for loop, or comprehension, the variables are often unpacks.

        To avoid splitting on the comma in this situation, increase the depth of
        tokens between `for` and `in`.
        """
        if leaf.type == token.NAME and leaf.value == "for":
            self.has_for = True
            self.bracket_tracker.depth += 1
            self._for_loop_variable = True
            return True

        return False

    def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
        """See `maybe_increment_for_loop_variable` above for explanation."""
        if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
            self.bracket_tracker.depth -= 1
            self._for_loop_variable = False
            return True

        return False

    def append_comment(self, comment: Leaf) -> bool:
        """Add an inline or standalone comment to the line.

        Returns True only when `comment` was stored in `comments`; on False the
        caller appends it to `leaves` itself.  A comment arriving on an empty
        line is turned into a STANDALONE_COMMENT leaf.
        """
        if (
            comment.type == STANDALONE_COMMENT
            and self.bracket_tracker.any_open_brackets()
        ):
            comment.prefix = ""
            return False

        if comment.type != token.COMMENT:
            return False

        after = len(self.leaves) - 1
        if after == -1:
            comment.type = STANDALONE_COMMENT
            comment.prefix = ""
            return False

        else:
            self.comments.append((after, comment))
            return True

    def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
        """Generate comments that should appear directly after `leaf`."""
        for _leaf_index, _leaf in enumerate(self.leaves):
            if leaf is _leaf:
                break

        else:
            return

        for index, comment_after in self.comments:
            if _leaf_index == index:
                yield comment_after

    def remove_trailing_comma(self) -> None:
        """Remove the trailing comma and moves the comments attached to it."""
        comma_index = len(self.leaves) - 1
        for i, (comment_index, comment) in enumerate(self.comments):
            if comment_index == comma_index:
                # Re-attach the comment to the leaf before the comma.
                self.comments[i] = (comma_index - 1, comment)
        self.leaves.pop()

    def __str__(self) -> str:
        """Render the line."""
        if not self:
            return "\n"

        indent = "    " * self.depth
        leaves = iter(self.leaves)
        first = next(leaves)
        # The first leaf's prefix goes before the indentation; every other
        # leaf renders with its own prefix via str().
        parts = [f"{first.prefix}{indent}{first.value}"]
        parts.extend(str(leaf) for leaf in leaves)
        parts.extend(str(comment) for _, comment in self.comments)
        parts.append("\n")
        return "".join(parts)

    def __bool__(self) -> bool:
        """Return True if the line has leaves or comments."""
        return bool(self.leaves or self.comments)
841
842
class UnformattedLines(Line):
    """A :class:`Line` whose leaves are kept verbatim instead of reformatted."""

    def append(self, leaf: Leaf, preformatted: bool = True) -> None:
        """Store `leaf` as-is at the end of the lines; `preformatted` is ignored.

        `depth` follows INDENT/DEDENT leaves so it is correct when the user
        turns formatting back on.  If the leaf's prefix holds `# fmt: on`, the
        part of the prefix before it is stored as its own leaf and the
        :class:`FormatOn` is re-raised to the caller.
        """
        try:
            # Only scanning for control comments; the results are discarded.
            for _ in generate_comments(leaf):
                pass
        except FormatOn as f_on:
            self.leaves.append(f_on.leaf_from_consumed(leaf))
            raise

        self.leaves.append(leaf)
        self.depth += {token.INDENT: 1, token.DEDENT: -1}.get(leaf.type, 0)

    def __str__(self) -> str:
        """Concatenate the stored leaves exactly as they were appended.

        `depth` plays no part in the rendering.
        """
        if not self:
            return "\n"

        return "".join(str(leaf) for leaf in self.leaves)

    def append_comment(self, comment: Leaf) -> bool:
        """Not implemented in this class. Raises `NotImplementedError`."""
        raise NotImplementedError("Unformatted lines don't store comments separately.")

    def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
        """No-op for verbatim lines; always returns False."""
        return False

    def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
        """No-op for verbatim lines; always returns False."""
        return False
890
891
@dataclass
class EmptyLineTracker:
    """Provides a stateful method that returns the number of potential extra
    empty lines needed before and after the currently processed line.

    Note: this tracker works on lines that haven't been split yet.  It assumes
    the prefix of the first leaf consists of optional newlines.  Those newlines
    are consumed by `maybe_empty_lines()` and included in the computation.
    """
    # Last line passed to `maybe_empty_lines()` (unformatted lines excluded).
    previous_line: Optional[Line] = None
    # The `after` value returned for `previous_line`.
    previous_after: int = 0
    # Stack of depths of the `def`/`class` lines whose bodies are still open.
    previous_defs: List[int] = Factory(list)

    def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
        """Return the number of extra empty lines before and after the `current_line`.

        This is for separating `def`, `async def` and `class` with extra empty
        lines (two on module-level), as well as providing an extra empty line
        after flow control keywords to make them more prominent.

        Both returned numbers are non-negative.  `UnformattedLines` get
        `(0, 0)` and do not update the tracker's state.
        """
        if isinstance(current_line, UnformattedLines):
            return 0, 0

        before, after = self._maybe_empty_lines(current_line)
        # The empty lines requested after the previous line count towards this
        # line's `before`.  Clamp at zero: a negative number of empty lines is
        # meaningless.
        before = max(before - self.previous_after, 0)
        self.previous_after = after
        self.previous_line = current_line
        return before, after

    def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
        """Compute the raw `(before, after)` pair for `current_line`.

        Unlike `maybe_empty_lines()`, this does not account for the previous
        line's `after`.  Side effects: clears the prefix of the line's first
        leaf and maintains the `previous_defs` stack.
        """
        max_allowed = 1
        if current_line.depth == 0:
            max_allowed = 2
        if current_line.leaves:
            # Consume the first leaf's extra newlines.
            first_leaf = current_line.leaves[0]
            before = first_leaf.prefix.count("\n")
            before = min(before, max_allowed)
            first_leaf.prefix = ""
        else:
            before = 0
        depth = current_line.depth
        # Leaving the body of one or more defs/classes: force the standard
        # separation (two empty lines at module level, one otherwise).
        while self.previous_defs and self.previous_defs[-1] >= depth:
            self.previous_defs.pop()
            before = 1 if depth else 2
        is_decorator = current_line.is_decorator
        if is_decorator or current_line.is_def or current_line.is_class:
            if not is_decorator:
                self.previous_defs.append(depth)
            if self.previous_line is None:
                # Don't insert empty lines before the first line in the file.
                return 0, 0

            if self.previous_line and self.previous_line.is_decorator:
                # Don't insert empty lines between decorators.
                return 0, 0

            newlines = 2
            if current_line.depth:
                newlines -= 1
            return newlines, 0

        if current_line.is_flow_control:
            return before, 1

        # At least one empty line after a block of imports at the same depth.
        if (
            self.previous_line
            and self.previous_line.is_import
            and not current_line.is_import
            and depth == self.previous_line.depth
        ):
            return (before or 1), 0

        # At least one empty line after a yield, unless another yield follows
        # at the same depth.
        if (
            self.previous_line
            and self.previous_line.is_yield
            and (not current_line.is_yield or depth != self.previous_line.depth)
        ):
            return (before or 1), 0

        return before, 0
973
974
@dataclass
class LineGenerator(Visitor[Line]):
    """Generates reformatted Line objects.  Empty lines are not emitted.

    Note: destroys the tree it's visiting by mutating prefixes of its leaves
    in ways that will no longer stringify to valid Python code on the tree.
    """
    # The line currently being built; `line()` hands it out and replaces it.
    current_line: Line = Factory(Line)

    def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
        """Generate a line.

        If the current line is empty, nothing is emitted.  Its depth is
        adjusted by `indent` in place, or it is replaced by a fresh empty `type`
        line at the adjusted depth when its class is not exactly `type`.

        Otherwise the current line is yielded and a new, empty `type` line at
        depth `current depth + indent` becomes `current_line`.

        This method does not split long lines.

        Note: `type` shadows the builtin of the same name; it is part of the
        keyword interface (`self.line(type=UnformattedLines)`).
        """
        if not self.current_line:
            # Exact class comparison: `UnformattedLines` subclasses `Line`, so
            # `isinstance` would not notice a switch back to a plain `Line`.
            if self.current_line.__class__ == type:
                self.current_line.depth += indent
            else:
                self.current_line = type(depth=self.current_line.depth + indent)
            return  # Line is empty, don't emit. Creating a new one unnecessary.

        complete_line = self.current_line
        self.current_line = type(depth=complete_line.depth + indent)
        yield complete_line

    def visit(self, node: LN) -> Iterator[Line]:
        """Main method to visit `node` and its children.

        Yields :class:`Line` objects.

        While the current line is an `UnformattedLines` (after `# fmt: off`),
        every node goes to `visit_unformatted()` instead of the type-specific
        `visit_*()` methods.
        """
        if isinstance(self.current_line, UnformattedLines):
            # File contained `# fmt: off`
            yield from self.visit_unformatted(node)

        else:
            yield from super().visit(node)

    def visit_default(self, node: LN) -> Iterator[Line]:
        """Default `visit_*()` implementation. Recurses to children of `node`.

        For a leaf, the comments in its prefix are emitted first, then the
        leaf is normalized and appended to the current line (unless it is
        whitespace).  `# fmt: off` / `# fmt: on` markers surface here as
        `FormatOff` / `FormatOn` raised by `generate_comments()`.
        """
        if isinstance(node, Leaf):
            any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
            try:
                for comment in generate_comments(node):
                    if any_open_brackets:
                        # any comment within brackets is subject to splitting
                        self.current_line.append(comment)
                    elif comment.type == token.COMMENT:
                        # regular trailing comment
                        self.current_line.append(comment)
                        yield from self.line()

                    else:
                        # regular standalone comment
                        yield from self.line()

                        self.current_line.append(comment)
                        yield from self.line()

            except FormatOff as f_off:
                # Switch to unformatted mode: trim the prefix (semantics of
                # `trim_prefix` are defined elsewhere -- presumably drops the
                # part already processed), flush the current line, and re-visit
                # this leaf, which `visit()` now routes to `visit_unformatted()`.
                f_off.trim_prefix(node)
                yield from self.line(type=UnformattedLines)
                yield from self.visit(node)

            except FormatOn as f_on:
                # This only happens here if somebody says "fmt: on" multiple
                # times in a row.
                f_on.trim_prefix(node)
                yield from self.visit_default(node)

            else:
                normalize_prefix(node, inside_brackets=any_open_brackets)
                if node.type == token.STRING:
                    normalize_string_quotes(node)
                if node.type not in WHITESPACE:
                    self.current_line.append(node)
        # Recurse into children (presumably a no-op for leaves).
        yield from super().visit_default(node)

    def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
        """Increase indentation level, maybe yield a line."""
        # In blib2to3 INDENT never holds comments.
        yield from self.line(+1)
        yield from self.visit_default(node)

    def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
        """Decrease indentation level, maybe yield a line."""
        # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
        yield from self.line(-1)

    def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
        """Visit a statement.

        This implementation is shared for `if`, `while`, `for`, `try`, `except`,
        `def`, `with`, and `class`.

        The relevant Python language `keywords` for a given statement will be NAME
        leaves within it. This methods puts those on a separate line.
        """
        for child in node.children:
            if child.type == token.NAME and child.value in keywords:  # type: ignore
                yield from self.line()

            yield from self.visit(child)

    def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
        """Visit a statement without nested statements."""
        # Presumably a body written on the same line as its compound-statement
        # header (STATEMENT is defined elsewhere): move it to its own,
        # indented line.
        is_suite_like = node.parent and node.parent.type in STATEMENT
        if is_suite_like:
            yield from self.line(+1)
            yield from self.visit_default(node)
            yield from self.line(-1)

        else:
            yield from self.line()
            yield from self.visit_default(node)

    def visit_async_stmt(self, node: Node) -> Iterator[Line]:
        """Visit `async def`, `async for`, `async with`."""
        yield from self.line()

        children = iter(node.children)
        for child in children:
            yield from self.visit(child)

            if child.type == token.ASYNC:
                break

        # Visit the inner statement's children directly rather than the
        # statement itself: its `visit_*` (e.g. `visit_funcdef`) would start a
        # new line before the keyword and separate it from `async`.
        internal_stmt = next(children)
        for child in internal_stmt.children:
            yield from self.visit(child)

    def visit_decorators(self, node: Node) -> Iterator[Line]:
        """Visit decorators."""
        # Every child starts a new line.
        for child in node.children:
            yield from self.line()
            yield from self.visit(child)

    def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
        """Remove a semicolon and put the other statement on a separate line."""
        # The semicolon leaf itself is never appended.
        yield from self.line()

    def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
        """End of file. Process outstanding comments and end with a newline."""
        yield from self.visit_default(leaf)
        yield from self.line()

    def visit_unformatted(self, node: LN) -> Iterator[Line]:
        """Used when file contained a `# fmt: off`."""
        if isinstance(node, Node):
            for child in node.children:
                yield from self.visit(child)

        else:
            try:
                self.current_line.append(node)
            except FormatOn as f_on:
                # `# fmt: on` in this leaf's prefix: close the unformatted
                # line, switch back to a regular `Line` and re-visit the leaf.
                f_on.trim_prefix(node)
                yield from self.line()
                yield from self.visit(node)

            if node.type == token.ENDMARKER:
                # somebody decided not to put a final `# fmt: on`
                yield from self.line()

    def __attrs_post_init__(self) -> None:
        """You are in a twisty little maze of passages."""
        # Register statement visitors as instance attributes.  This assumes the
        # `Visitor` base class dispatches on `visit_<node type name>` --
        # confirm in its definition.
        v = self.visit_stmt
        self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"})
        self.visit_while_stmt = partial(v, keywords={"while", "else"})
        self.visit_for_stmt = partial(v, keywords={"for", "else"})
        self.visit_try_stmt = partial(v, keywords={"try", "except", "else", "finally"})
        self.visit_except_clause = partial(v, keywords={"except"})
        self.visit_funcdef = partial(v, keywords={"def"})
        self.visit_with_stmt = partial(v, keywords={"with"})
        self.visit_classdef = partial(v, keywords={"class"})
        self.visit_async_funcdef = self.visit_async_stmt
        self.visit_decorated = self.visit_decorators
1154
1155
# Opening bracket token type -> matching closing bracket token type.
BRACKET = {
    token.LPAR: token.RPAR,
    token.LSQB: token.RSQB,
    token.LBRACE: token.RBRACE,
}
# Iterating a dict yields its keys, i.e. the opening bracket types.
OPENING_BRACKETS = set(BRACKET)
CLOSING_BRACKETS = set(BRACKET.values())
BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
# Token types for which `whitespace()` always returns an empty prefix.
ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1161
1162
def whitespace(leaf: Leaf) -> str:  # noqa C901
    """Return whitespace prefix if needed for the given `leaf`.

    The result is "" (no space), " " (one space) or "  " (two spaces, before
    trailing comments).  Two `typedargslist` cases instead reuse the prefix
    already assigned to a preceding `=` leaf.

    The decision looks at the leaf's own type, its parent node's type and its
    previous sibling (or, for a first child, the leaf preceding the parent).

    Raises AssertionError for a leaf without a parent unless it is a comment
    or one of the `ALWAYS_NO_SPACE` types.
    """
    NO = ""
    SPACE = " "
    DOUBLESPACE = "  "
    t = leaf.type
    p = leaf.parent
    v = leaf.value
    if t in ALWAYS_NO_SPACE:
        return NO

    if t == token.COMMENT:
        # Trailing comments are separated from code by two spaces.
        return DOUBLESPACE

    assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
    # Colons hug the preceding token, except inside subscripts/slices.
    if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
        return NO

    prev = leaf.prev_sibling
    if not prev:
        # First child of `p`: decide based on the leaf that precedes `p`.
        prevp = preceding_leaf(p)
        if not prevp or prevp.type in OPENING_BRACKETS:
            return NO

        if t == token.COLON:
            return SPACE if prevp.type == token.COMMA else NO

        if prevp.type == token.EQUAL:
            if prevp.parent:
                if prevp.parent.type in {
                    syms.arglist, syms.argument, syms.parameters, syms.varargslist
                }:
                    return NO

                elif prevp.parent.type == syms.typedargslist:
                    # A bit hacky: if the equal sign has whitespace, it means we
                    # previously found it's a typed argument.  So, we're using
                    # that, too.
                    return prevp.prefix

        elif prevp.type == token.DOUBLESTAR:
            # `**` unpacking in calls, signatures and dict displays.
            if prevp.parent and prevp.parent.type in {
                syms.arglist,
                syms.argument,
                syms.dictsetmaker,
                syms.parameters,
                syms.typedargslist,
                syms.varargslist,
            }:
                return NO

        elif prevp.type == token.COLON:
            if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
                return NO

        elif (
            prevp.parent
            and prevp.parent.type in {syms.factor, syms.star_expr}
            and prevp.type in MATH_OPERATORS
        ):
            # Operand of a unary operator or of `*` unpacking.
            return NO

        elif (
            prevp.type == token.RIGHTSHIFT
            and prevp.parent
            and prevp.parent.type == syms.shift_expr
            and prevp.prev_sibling
            and prevp.prev_sibling.type == token.NAME
            and prevp.prev_sibling.value == "print"  # type: ignore
        ):
            # Python 2 print chevron
            return NO

    elif prev.type in OPENING_BRACKETS:
        return NO

    if p.type in {syms.parameters, syms.arglist}:
        # untyped function signatures or calls
        if t == token.RPAR:
            return NO

        if not prev or prev.type != token.COMMA:
            return NO

    elif p.type == syms.varargslist:
        # lambdas
        if t == token.RPAR:
            return NO

        if prev and prev.type != token.COMMA:
            return NO

    elif p.type == syms.typedargslist:
        # typed function signatures
        if not prev:
            return NO

        if t == token.EQUAL:
            # `=` gets a space only after an annotated name (`tname`).
            if prev.type != syms.tname:
                return NO

        elif prev.type == token.EQUAL:
            # A bit hacky: if the equal sign has whitespace, it means we
            # previously found it's a typed argument.  So, we're using that, too.
            return prev.prefix

        elif prev.type != token.COMMA:
            return NO

    elif p.type == syms.tname:
        # type names
        if not prev:
            prevp = preceding_leaf(p)
            if not prevp or prevp.type != token.COMMA:
                return NO

    elif p.type == syms.trailer:
        # attributes and calls
        if t == token.LPAR or t == token.RPAR:
            return NO

        if not prev:
            if t == token.DOT:
                # A dot after a number literal keeps its space (presumably so
                # that `1 .real` does not become the invalid `1.real`).
                prevp = preceding_leaf(p)
                if not prevp or prevp.type != token.NUMBER:
                    return NO

            elif t == token.LSQB:
                return NO

        elif prev.type != token.COMMA:
            return NO

    elif p.type == syms.argument:
        # single argument
        if t == token.EQUAL:
            return NO

        if not prev:
            prevp = preceding_leaf(p)
            if not prevp or prevp.type == token.LPAR:
                return NO

        elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
            return NO

    elif p.type == syms.decorator:
        # decorators
        return NO

    elif p.type == syms.dotted_name:
        if prev:
            return NO

        prevp = preceding_leaf(p)
        if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
            return NO

    elif p.type == syms.classdef:
        if t == token.LPAR:
            return NO

        if prev and prev.type == token.LPAR:
            return NO

    elif p.type == syms.subscript:
        # indexing
        if not prev:
            # First leaf of a subscript inside a subscriptlist (e.g. after the
            # comma in `x[a, b:c]`) keeps its space.
            assert p.parent is not None, "subscripts are always parented"
            if p.parent.type == syms.subscriptlist:
                return SPACE

            return NO

        else:
            return NO

    elif p.type == syms.atom:
        if prev and t == token.DOT:
            # dots, but not the first one.
            return NO

    elif (
        p.type == syms.listmaker
        or p.type == syms.testlist_gexp
        or p.type == syms.subscriptlist
    ):
        # list interior, including unpacking
        if not prev:
            return NO

    elif p.type == syms.dictsetmaker:
        # dict and set interior, including unpacking
        if not prev:
            return NO

        if prev.type == token.DOUBLESTAR:
            return NO

    elif p.type in {syms.factor, syms.star_expr}:
        # unary ops
        if not prev:
            prevp = preceding_leaf(p)
            if not prevp or prevp.type in OPENING_BRACKETS:
                return NO

            prevp_parent = prevp.parent
            assert prevp_parent is not None
            if prevp.type == token.COLON and prevp_parent.type in {
                syms.subscript, syms.sliceop
            }:
                return NO

            elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
                return NO

        elif t == token.NAME or t == token.NUMBER:
            # The operand directly follows the unary operator.
            return NO

    elif p.type == syms.import_from:
        if t == token.DOT:
            # Relative-import dots: only the first one is spaced.
            if prev and prev.type == token.DOT:
                return NO

        elif t == token.NAME:
            if v == "import":
                return SPACE

            if prev and prev.type == token.DOT:
                return NO

    elif p.type == syms.sliceop:
        return NO

    return SPACE
1398
1399
def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
    """Return the first leaf that precedes `node`, if any.

    Climbs from `node` towards the root until some node on the way has a
    previous sibling.  That sibling is returned if it is a leaf; otherwise the
    last leaf inside it is returned.  Returns None if the root is reached
    without finding a previous sibling, or if that sibling contains no leaves.
    """
    current = node
    while current:
        sibling = current.prev_sibling
        if not sibling:
            current = current.parent
            continue

        if isinstance(sibling, Leaf):
            return sibling

        # Walk the sibling's leaves and keep the last one (None if empty).
        last_leaf: Optional[Leaf] = None
        for last_leaf in sibling.leaves():
            pass
        return last_leaf

    return None
1416
1417
def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
    """Return the priority of the `leaf` delimiter, given a line break after it.

    The delimiter priorities returned here are from those delimiters that would
    cause a line break after themselves.

    Higher numbers are higher priority.  Returns 0 if `leaf` is not such a
    delimiter.

    `previous` is not used; it is accepted so that this function can be called
    the same way as `is_split_before_delimiter()`.
    """
    if leaf.type == token.COMMA:
        return COMMA_PRIORITY

    # Star-args markers (presumably `*` / `**`, see VARARGS) in a call argument
    # or a typed signature.
    if (
        leaf.type in VARARGS
        and leaf.parent
        and leaf.parent.type in {syms.argument, syms.typedargslist}
    ):
        return MATH_PRIORITY

    return 0
1437
1438
def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
    """Return the priority of the `leaf` delimiter, given a line break before it.

    The delimiter priorities returned here are from those delimiters that would
    cause a line break before themselves.

    Higher numbers are higher priority.  Returns 0 if `leaf` is not such a
    delimiter.  `previous` is the leaf preceding `leaf`; it is only used to
    detect implicit string concatenation.
    """
    # Binary math operators (unary operators and `*` unpacking excluded).
    if (
        leaf.type in MATH_OPERATORS
        and leaf.parent
        and leaf.parent.type not in {syms.factor, syms.star_expr}
    ):
        return MATH_PRIORITY

    if leaf.type in COMPARATORS:
        return COMPARATOR_PRIORITY

    # Implicitly concatenated string literals: "a" "b".
    if (
        leaf.type == token.STRING
        and previous is not None
        and previous.type == token.STRING
    ):
        return STRING_PRIORITY

    # `for` of a comprehension.
    if (
        leaf.type == token.NAME
        and leaf.value == "for"
        and leaf.parent
        and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
    ):
        return COMPREHENSION_PRIORITY

    # `if` of a comprehension.
    if (
        leaf.type == token.NAME
        and leaf.value == "if"
        and leaf.parent
        and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
    ):
        return COMPREHENSION_PRIORITY

    if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
        return LOGIC_PRIORITY

    return 0
1484
1485
def is_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
    """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.

    Higher numbers are higher priority.  This is the larger of the priorities
    for a line break before and after `leaf`.  `previous` is the leaf
    preceding `leaf`, if known.
    """
    return max(
        is_split_before_delimiter(leaf, previous),
        is_split_after_delimiter(leaf, previous),
    )
1495
1496
def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
    """Clean the prefix of the `leaf` and generate comments from it, if any.

    Comments in lib2to3 are shoved into the whitespace prefix.  This happens
    in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
    move because it does away with modifying the grammar to include all the
    possible places in which comments can be placed.

    The sad consequence for us though is that comments don't "belong" anywhere.
    This is why this function generates simple parentless Leaf objects for
    comments.  We simply don't know what the correct parent should be.

    No matter though, we can live without this.  We really only need to
    differentiate between inline and standalone comments.  The latter don't
    share the line with any code.

    Inline comments are emitted as regular token.COMMENT leaves.  Standalone
    are emitted with a fake STANDALONE_COMMENT token identifier.

    After yielding a `# fmt: on` / `# yapf: enable` comment, raises FormatOn.
    After yielding a `# fmt: off` / `# yapf: disable` comment that is
    standalone (or effectively standalone), raises FormatOff.  Both carry the
    number of prefix characters consumed so far.
    """
    prefix = leaf.prefix
    if "#" not in prefix:
        # Also covers the empty prefix: there is nothing to extract.
        return

    consumed = 0
    blank_lines = 0
    for position, raw_line in enumerate(prefix.split("\n")):
        consumed += len(raw_line) + 1  # +1 for the "\n" removed by split()
        stripped = raw_line.lstrip()
        if not stripped:
            blank_lines += 1
        if not stripped.startswith("#"):
            continue

        # Only a comment on the very first prefix line shares its line with
        # code; the ENDMARKER has no code to trail.
        is_trailing = position == 0 and leaf.type != token.ENDMARKER
        comment_type = token.COMMENT if is_trailing else STANDALONE_COMMENT
        comment = make_comment(stripped)
        yield Leaf(comment_type, comment, prefix="\n" * blank_lines)

        if comment in {"# fmt: on", "# yapf: enable"}:
            raise FormatOn(consumed)

        if comment in {"# fmt: off", "# yapf: disable"}:
            if comment_type == STANDALONE_COMMENT:
                raise FormatOff(consumed)

            before = preceding_leaf(leaf)
            if not before or before.type in WHITESPACE:  # standalone comment in disguise
                raise FormatOff(consumed)

        blank_lines = 0
1552
1553
def make_comment(content: str) -> str:
    """Return a consistently formatted comment from the given `content` string.

    All comments (except for "##", "#!", "#:") should have a single space between
    the hash sign and the content.

    If `content` didn't start with a hash sign, one is provided.  Trailing
    whitespace is removed; an empty or all-whitespace `content` yields "#".
    """
    text = content.rstrip()
    if not text:
        return "#"

    body = text[1:] if text.startswith("#") else text
    needs_space = bool(body) and body[0] not in " !:#"
    return "#" + (" " + body if needs_space else body)
1571
1572
def split_line(
    line: Line, line_length: int, inner: bool = False, py36: bool = False
) -> Iterator[Line]:
    """Split a `line` into potentially many lines.

    They should fit in the allotted `line_length` but might not be able to.
    `inner` signifies that there was a pair of brackets somewhere around the
    current `line`, possibly transitively. This means we can fallback to splitting
    by delimiters if the LHS/RHS don't yield any results.

    If `py36` is True, splitting may generate syntax that is only compatible
    with Python 3.6 and later.

    Unformatted lines, comment lines and lines that already fit are yielded
    unchanged.  Otherwise the applicable split functions are tried in order.
    The first one that neither raises `CannotSplit` nor reproduces the
    original line wins, and its output is split recursively.  If none
    succeeds, the original line is yielded as is.
    """
    # NOTE(review): `inner` is not read in this function; the delimiter-based
    # fallback is chosen via `line.inside_brackets` instead.
    if isinstance(line, UnformattedLines) or line.is_comment:
        yield line
        return

    line_str = str(line).strip("\n")
    if (
        len(line_str) <= line_length
        and "\n" not in line_str  # multiline strings
        and not line.contains_standalone_comments
    ):
        yield line
        return

    split_funcs: List[SplitFunc]
    if line.is_def:
        split_funcs = [left_hand_split]
    elif line.inside_brackets:
        split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
    else:
        split_funcs = [right_hand_split]
    for split_func in split_funcs:
        # We are accumulating lines in `result` because we might want to abort
        # mission and return the original line in the end, or attempt a different
        # split altogether.
        result: List[Line] = []
        try:
            for l in split_func(line, py36):
                if str(l).strip("\n") == line_str:
                    raise CannotSplit("Split function returned an unchanged result")

                result.extend(
                    split_line(l, line_length=line_length, inner=True, py36=py36)
                )
        except CannotSplit:
            continue

        else:
            yield from result
            break

    else:
        yield line
1628
1629
def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split line into many lines, starting with the first matching bracket pair.

    Note: this usually looks weird, only use this for function definitions.
    Prefer RHS otherwise.

    Everything up to and including the first opening bracket with a non-empty
    content forms the head; the content becomes an indented body; the matching
    closing bracket and everything after it forms the tail.  Raises
    `CannotSplit` (via `bracket_split_succeeded_or_raise`) when this does not
    help.  Empty resulting lines are not yielded.
    """
    head_leaves: List[Leaf] = []
    body_leaves: List[Leaf] = []
    tail_leaves: List[Leaf] = []
    bucket = head_leaves
    opening = None
    for leaf in line.leaves:
        closes_body = (
            bucket is body_leaves
            and leaf.type in CLOSING_BRACKETS
            and leaf.opening_bracket is opening
        )
        if closes_body:
            # An empty bracket pair stays in the head and scanning continues.
            bucket = tail_leaves if body_leaves else head_leaves
        bucket.append(leaf)
        if bucket is head_leaves and leaf.type in OPENING_BRACKETS:
            opening = leaf
            bucket = body_leaves

    # The body starts a new indent level: drop its spurious leading whitespace.
    if body_leaves:
        normalize_prefix(body_leaves[0], inside_brackets=True)

    head = Line(depth=line.depth)
    body = Line(depth=line.depth + 1, inside_brackets=True)
    tail = Line(depth=line.depth)
    for target, leaves in (
        (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
    ):
        for leaf in leaves:
            target.append(leaf, preformatted=True)
            for comment_after in line.comments_after(leaf):
                target.append(comment_after, preformatted=True)

    bracket_split_succeeded_or_raise(head, body, tail)
    yield from (part for part in (head, body, tail) if part)
1671
1672
def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split line into many lines, starting with the last matching bracket pair.

    Scanning from the end, the last closing bracket with a non-empty content
    and everything after it forms the tail; the content becomes an indented
    body; the matching opening bracket and everything before it forms the
    head.  Raises `CannotSplit` (via `bracket_split_succeeded_or_raise`) when
    this does not help.  Empty resulting lines are not yielded.
    """
    tail_leaves: List[Leaf] = []
    body_leaves: List[Leaf] = []
    head_leaves: List[Leaf] = []
    bucket = tail_leaves
    opening_bracket = None
    for leaf in reversed(line.leaves):
        if bucket is body_leaves and leaf is opening_bracket:
            # Reached the matching opening bracket.  An empty pair stays in
            # the tail and scanning continues.
            bucket = head_leaves if body_leaves else tail_leaves
        bucket.append(leaf)
        if bucket is tail_leaves and leaf.type in CLOSING_BRACKETS:
            opening_bracket = leaf.opening_bracket
            bucket = body_leaves

    # Leaves were collected right-to-left; restore source order.
    for collected in (tail_leaves, body_leaves, head_leaves):
        collected.reverse()

    # The body starts a new indent level: drop its spurious leading whitespace.
    if body_leaves:
        normalize_prefix(body_leaves[0], inside_brackets=True)

    head = Line(depth=line.depth)
    body = Line(depth=line.depth + 1, inside_brackets=True)
    tail = Line(depth=line.depth)
    for target, leaves in (
        (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
    ):
        for leaf in leaves:
            target.append(leaf, preformatted=True)
            for comment_after in line.comments_after(leaf):
                target.append(comment_after, preformatted=True)

    bracket_split_succeeded_or_raise(head, body, tail)
    yield from (part for part in (head, body, tail) if part)
1710
1711
def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
    """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.

    Do nothing otherwise.

    A bracket split puts everything before (and including) the opening bracket
    into `head`, the bracket contents into `body`, and the closing bracket plus
    whatever follows into `tail`.  With a non-empty `body` the split is
    accepted.  With an empty `body` it is rejected when `tail` is empty (the
    line came out unchanged) or when `tail` holds fewer than three non-blank
    characters (too little is saved to be worth it).
    """
    if body:
        return

    tail_len = len(str(tail).strip())
    if tail_len == 0:
        raise CannotSplit("Splitting brackets produced the same line")

    if tail_len < 3:
        raise CannotSplit(
            f"Splitting brackets on an empty body to save "
            f"{tail_len} characters is not worth it"
        )
1736
1737
def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
    """Wrap `split_func` so that no line it produces starts with extra whitespace.

    Every line yielded by the wrapped function gets the prefix of its first
    leaf normalized with `inside_brackets=True` before being passed on.
    This is a decorator over relevant split functions.
    """

    @wraps(split_func)
    def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
        for produced in split_func(line, py36):
            leading_leaf = produced.leaves[0]
            normalize_prefix(leading_leaf, inside_brackets=True)
            yield produced

    return split_wrapper
1751
1752
@dont_increase_indentation
def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split according to delimiters of the highest priority.

    The highest delimiter priority in `line` is found, ignoring the line's
    last leaf. The leaves are then re-emitted as consecutive lines, and a new
    line starts right after every leaf carrying that priority. Comments
    attached to a leaf follow it. When splitting on commas, a trailing comma is
    appended to the last line if it is safe (see below) and not already there.

    If `py36` is True, the split will add trailing commas also in function
    signatures that contain `*` and `**`.

    Raises :exc:`CannotSplit` if the line is empty or has no delimiters.
    """
    try:
        last_leaf = line.leaves[-1]
    except IndexError:
        raise CannotSplit("Line empty")

    delimiters = line.bracket_tracker.delimiters
    try:
        delimiter_priority = line.bracket_tracker.max_delimiter_priority(
            exclude={id(last_leaf)}
        )
    except ValueError:
        raise CannotSplit("No delimiters found")

    current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
    lowest_depth = sys.maxsize
    trailing_comma_safe = True

    def append_to_line(leaf: Leaf) -> Iterator[Line]:
        """Append `leaf` to current line or to new line if appending impossible."""
        nonlocal current_line
        try:
            current_line.append_safe(leaf, preformatted=True)
        except ValueError:
            # The current line refused the leaf: emit it and start a fresh one.
            yield current_line

            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
            current_line.append(leaf)

    for leaf in line.leaves:
        yield from append_to_line(leaf)

        for comment_after in line.comments_after(leaf):
            yield from append_to_line(comment_after)

        lowest_depth = min(lowest_depth, leaf.bracket_depth)
        # A `*` or `**` at the outermost depth of the split content (a vararg
        # such as `*args` / `**kwargs`) makes a trailing comma unsafe unless
        # `py36` is set. The parentheses make the depth check apply to both
        # token types; `**` nested deeper (e.g. inside `g(**kw)`) is irrelevant.
        if leaf.bracket_depth == lowest_depth and (
            leaf.type == token.STAR or leaf.type == token.DOUBLESTAR
        ):
            trailing_comma_safe = trailing_comma_safe and py36
        leaf_priority = delimiters.get(id(leaf))
        if leaf_priority == delimiter_priority:
            yield current_line

            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
    if current_line:
        if (
            trailing_comma_safe
            and delimiter_priority == COMMA_PRIORITY
            and current_line.leaves[-1].type != token.COMMA
            and current_line.leaves[-1].type != STANDALONE_COMMENT
        ):
            current_line.append(Leaf(token.COMMA, ","))
        yield current_line
1815
1816
@dont_increase_indentation
def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
    """Split standalone comments from the rest of the line.

    Only applies when `line` holds a STANDALONE_COMMENT leaf at bracket depth
    0; otherwise :exc:`CannotSplit` is raised. Leaves (and the comments that
    follow them) are packed into lines with `append_safe`, and a new line is
    started whenever the current one refuses a leaf.
    """
    if not any(
        leaf.type == STANDALONE_COMMENT and leaf.bracket_depth == 0
        for leaf in line.leaves
    ):
        raise CannotSplit("Line does not have any standalone comments")

    pending = Line(depth=line.depth, inside_brackets=line.inside_brackets)

    def push(item: Leaf) -> Iterator[Line]:
        """Add `item` to `pending`; flush `pending` and restart if it is refused."""
        nonlocal pending
        try:
            pending.append_safe(item, preformatted=True)
        except ValueError:
            yield pending

            pending = Line(depth=line.depth, inside_brackets=line.inside_brackets)
            pending.append(item)

    for item in line.leaves:
        yield from push(item)
        for trailing_comment in line.comments_after(item):
            yield from push(trailing_comment)

    if pending:
        yield pending
1849
1850
def is_import(leaf: Leaf) -> bool:
    """Return True if the given leaf starts an import statement.

    That is a NAME leaf which is either `import` directly under an
    `import_name` node or `from` directly under an `import_from` node.
    """
    if leaf.type != token.NAME:
        return False

    parent = leaf.parent
    if not parent:
        return False

    if leaf.value == "import":
        return parent.type == syms.import_name

    if leaf.value == "from":
        return parent.type == syms.import_from

    return False
1863
1864
def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
    """Leave existing extra newlines if not `inside_brackets`. Remove everything
    else.

    Outside brackets, the prefix is reduced to the newlines found after its
    last `#`, minus one for that comment's own line. If there is no `#`, all
    of its newlines are kept. A backslash before the first `#` clears the
    prefix entirely, as does being inside brackets.

    Note: don't use backslashes for formatting or you'll lose your voting rights.
    """
    if inside_brackets:
        leaf.prefix = ""
        return

    before_comment, *after_comments = leaf.prefix.split("#")
    if "\\" in before_comment:
        leaf.prefix = ""
        return

    if after_comments:
        # The newline ending the last comment line itself is not "extra".
        newlines = after_comments[-1].count("\n") - 1
    else:
        newlines = before_comment.count("\n")
    leaf.prefix = "\n" * newlines
1881
1882
def normalize_string_quotes(leaf: Leaf) -> None:
    """Prefer double quotes but only if it doesn't cause more escaping.

    Adds or removes backslashes as appropriate. Doesn't parse and fix
    strings nested in f-strings (yet).

    A quote counts as escaped only when an odd number of backslashes precedes
    it, so escaped backslashes next to quotes are handled correctly.

    Raw strings are requoted only when no backslash has to be added or
    removed, because in a raw literal every backslash is part of the value.

    f-strings are left alone if requoting would put a backslash inside a
    `{...}` replacement field, where backslashes are not allowed.

    Note: Mutates its argument.
    """
    import re

    value = leaf.value.lstrip("furbFURB")
    if value[:3] == '"""':
        return

    elif value[:3] == "'''":
        orig_quote = "'''"
        new_quote = '"""'
    elif value[0] == '"':
        orig_quote = '"'
        new_quote = "'"
    else:
        orig_quote = "'"
        new_quote = '"'
    first_quote_pos = leaf.value.find(orig_quote)
    if first_quote_pos == -1:
        return  # There's an internal error

    prefix = leaf.value[:first_quote_pos]
    body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
    # `new_quote` preceded by an even number (possibly zero) of backslashes,
    # i.e. a quote character that is not escaped.
    unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
    if "r" in prefix.lower():
        if unescaped_new_quote.search(body):
            # A raw string cannot escape the new delimiter without changing
            # its value, so it has to keep its current quotes.
            return

        # Never add or remove backslashes in raw strings.
        new_body = body
    else:
        # `orig_quote` preceded by an odd number of backslashes, i.e. escaped.
        escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
        new_body = body
        for pattern, replacement in (
            (escaped_orig_quote, rf"\1\2{orig_quote}"),
            (unescaped_new_quote, rf"\1\\{new_quote}"),
        ):
            # Each match consumes the character before the quote, so in a run
            # of adjacent quotes every other one is skipped; a second pass
            # picks those up.
            new_body = pattern.sub(replacement, pattern.sub(replacement, new_body))
    if "f" in prefix.lower():
        # Backslashes are not allowed inside f-string replacement fields.
        fields = re.findall(r"(?:[^{]|^)\{([^{].*?)\}(?:[^}]|$)", new_body)
        if any("\\" in field for field in fields):
            return

    # `[-1:]` instead of `[-1]` keeps empty bodies (e.g. `''''''`) from raising.
    if new_quote == '"""' and new_body[-1:] == '"':
        # edge case: a trailing `"` would merge with the closing `"""`
        new_body = new_body[:-1] + '\\"'
    orig_escape_count = body.count("\\")
    new_escape_count = new_body.count("\\")
    if new_escape_count > orig_escape_count:
        return  # Do not introduce more escaping

    if new_escape_count == orig_escape_count and orig_quote == '"':
        return  # Prefer double quotes

    leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
1925
1926
def is_python36(node: Node) -> bool:
    """Return True if the current file is using Python 3.6+ features.

    Currently looking for:
    - f-strings, i.e. any string literal whose prefix contains `f` or `F`
      (including raw combinations such as `rf`, `Rf`, `fR`); and
    - trailing commas after * or ** in function signatures.
    """
    for n in node.pre_order():
        if n.type == token.STRING:
            value = n.value  # type: ignore
            # String prefixes are case-insensitive and `r` may come before or
            # after `f`, so inspect the whole prefix instead of matching a
            # fixed set of two-character heads.
            prefix = value[: len(value) - len(value.lstrip("furbFURB"))]
            if "f" in prefix.lower():
                return True

        elif (
            n.type == syms.typedargslist
            and n.children
            and n.children[-1].type == token.COMMA
        ):
            for ch in n.children:
                if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
                    return True

    return False
1950
1951
# File suffixes that `gen_python_files_in_dir` treats as Python sources.
PYTHON_EXTENSIONS = {".py"}
# Directory names (compared against `Path.name`) that `gen_python_files_in_dir`
# never descends into.
BLACKLISTED_DIRECTORIES = {
    "build",
    "buck-out",
    "dist",
    "_build",
    ".git",
    ".hg",
    ".mypy_cache",
    ".tox",
    ".venv",
}
1956
1957
def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
    """Recursively yield every file below `path` with a suffix in PYTHON_EXTENSIONS.

    Subdirectories whose name is in BLACKLISTED_DIRECTORIES are not entered.
    Entries are visited in the order `Path.iterdir()` returns them.
    """
    for entry in path.iterdir():
        if not entry.is_dir():
            if entry.suffix in PYTHON_EXTENSIONS:
                yield entry
            continue

        if entry.name not in BLACKLISTED_DIRECTORIES:
            yield from gen_python_files_in_dir(entry)
1971
1972
@dataclass
class Report:
    """Reformatting counter; `str(report)` renders a one-line summary."""
    check: bool = False  # --check mode: changes the wording and the exit code
    change_count: int = 0  # files reformatted (or that would be, under check)
    same_count: int = 0  # files that were already well formatted
    failure_count: int = 0  # files that could not be reformatted

    def done(self, src: Path, changed: bool) -> None:
        """Count `src` as successfully processed and print a status line for it."""
        if not changed:
            out(f"{src} already well formatted, good job.", bold=False)
            self.same_count += 1
            return

        verb = "would reformat" if self.check else "reformatted"
        out(f"{verb} {src}")
        self.change_count += 1

    def failed(self, src: Path, message: str) -> None:
        """Count `src` as a failure and report `message` through `err()`."""
        err(f"error: cannot format {src}: {message}")
        self.failure_count += 1

    @property
    def return_code(self) -> int:
        """Return the exit code that the app should use.

        Any failure gives 123; otherwise 1 when files changed under --check;
        otherwise 0.
        """
        # Codes from 126 upwards are reserved by shells, hence 123 for failures
        # (see http://tldp.org/LDP/abs/html/exitcodes.html).
        if self.failure_count:
            return 123

        return 1 if (self.change_count and self.check) else 0

    def __str__(self) -> str:
        """Render a color report of the current state.

        Use `click.unstyle` to remove colors.
        """
        if self.check:
            reformatted, unchanged, failed = (
                "would be reformatted",
                "would be left unchanged",
                "would fail to reformat",
            )
        else:
            reformatted, unchanged, failed = (
                "reformatted",
                "left unchanged",
                "failed to reformat",
            )

        def files(count: int) -> str:
            return f"{count} file{'s' if count > 1 else ''}"

        report = []
        if self.change_count:
            report.append(
                click.style(f"{files(self.change_count)} {reformatted}", bold=True)
            )
        if self.same_count:
            report.append(f"{files(self.same_count)} {unchanged}")
        if self.failure_count:
            report.append(
                click.style(f"{files(self.failure_count)} {failed}", fg="red")
            )
        return ", ".join(report) + "."
2043
2044
def assert_equivalent(src: str, dst: str) -> None:
    """Raise AssertionError if `src` and `dst` aren't equivalent.

    Both sources are parsed with the running interpreter's builtin `ast`
    module. Each tree is then dumped to text, one line per node and field
    (fields sorted by name), and the two dumps are compared. A diff of the
    dumps is saved with `dump_to_file` when they differ.
    """

    import ast
    import traceback

    def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
        """Simple visitor generating strings to compare ASTs by content."""
        yield f"{'  ' * depth}{node.__class__.__name__}("

        for field in sorted(node._fields):
            try:
                value = getattr(node, field)
            except AttributeError:
                continue

            yield f"{'  ' * (depth+1)}{field}="

            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        yield from _v(item, depth + 2)

                    else:
                        # Plain values inside lists (e.g. the names listed by
                        # `global` / `nonlocal`) are content too; skipping
                        # them would let such differences go unnoticed.
                        yield (
                            f"{'  ' * (depth+2)}{item!r},  "
                            f"# {item.__class__.__name__}"
                        )

            elif isinstance(value, ast.AST):
                yield from _v(value, depth + 2)

            else:
                yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"

        yield f"{'  ' * depth})  # /{node.__class__.__name__}"

    try:
        src_ast = ast.parse(src)
    except Exception as exc:
        major, minor = sys.version_info[:2]
        raise AssertionError(
            f"cannot use --safe with this file; failed to parse source file "
            f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
            f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
        )

    try:
        dst_ast = ast.parse(dst)
    except Exception as exc:
        log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
        raise AssertionError(
            f"INTERNAL ERROR: Black produced invalid code: {exc}. "
            f"Please report a bug on https://github.com/ambv/black/issues.  "
            f"This invalid output might be helpful: {log}"
        ) from None

    src_ast_str = "\n".join(_v(src_ast))
    dst_ast_str = "\n".join(_v(dst_ast))
    if src_ast_str != dst_ast_str:
        log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
        raise AssertionError(
            f"INTERNAL ERROR: Black produced code that is not equivalent to "
            f"the source.  "
            f"Please report a bug on https://github.com/ambv/black/issues.  "
            f"This diff might be helpful: {log}"
        ) from None
2106
2107
def assert_stable(src: str, dst: str, line_length: int) -> None:
    """Raise AssertionError if `dst` reformats differently the second time.

    `dst` is run through `format_str` again with the same `line_length`. If the
    output differs, both diffs (source -> first pass, first pass -> second pass)
    are saved with `dump_to_file`, and the resulting path goes into the error.
    """
    second_pass = format_str(dst, line_length=line_length)
    if second_pass == dst:
        return

    log = dump_to_file(
        diff(src, dst, "source", "first pass"),
        diff(dst, second_pass, "first pass", "second pass"),
    )
    raise AssertionError(
        f"INTERNAL ERROR: Black produced different code on the second pass "
        f"of the formatter.  "
        f"Please report a bug on https://github.com/ambv/black/issues.  "
        f"This diff might be helpful: {log}"
    ) from None
2122
2123
def dump_to_file(*output: str) -> str:
    """Dump `output` to a temporary file. Return path to the file.

    Chunks are written in order. Each non-empty chunk that does not already
    end with a newline gets one appended. The file (prefix `blk_`, suffix
    `.log`) is kept after closing so the returned path can be inspected.
    """
    import tempfile

    # Explicit UTF-8: the dumped source may contain any Unicode character, and
    # the locale's default encoding (e.g. cp1252) may not be able to encode it,
    # which would mask the error this dump is meant to report.
    with tempfile.NamedTemporaryFile(
        mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
    ) as f:
        for lines in output:
            f.write(lines)
            if lines and lines[-1] != "\n":
                f.write("\n")
    return f.name
2136
2137
def diff(a: str, b: str, a_name: str, b_name: str) -> str:
    """Return a unified diff string between strings `a` and `b`.

    Both texts are cut at every newline and each piece gets its newline back,
    so a text ending in a newline contributes a final empty line. The diff is
    produced with five lines of context.
    """
    import difflib

    old, new = ([piece + "\n" for piece in text.split("\n")] for text in (a, b))
    hunks = difflib.unified_diff(old, new, fromfile=a_name, tofile=b_name, n=5)
    return "".join(hunks)
2147
2148
def cancel(tasks: List[asyncio.Task]) -> None:
    """Announce the abort via `err()`, then cancel every task in `tasks`.

    Intended to be used as an asyncio signal handler.
    """
    err("Aborted!")
    for pending_task in tasks:
        pending_task.cancel()
2154
2155
def shutdown(loop: BaseEventLoop) -> None:
    """Cancel all pending tasks on `loop`, wait for them, and close the loop.

    The loop is closed even when there is nothing to cancel. The logger of
    `concurrent.futures` is raised to CRITICAL as a side effect (see below).
    """
    try:
        # This part is borrowed from asyncio/runners.py in Python 3.7b2.
        # `asyncio.all_tasks()` exists since 3.7 while `asyncio.Task.all_tasks()`
        # was removed in 3.9, so use whichever this interpreter provides.
        all_tasks = getattr(asyncio, "all_tasks", None) or asyncio.Task.all_tasks
        to_cancel = [task for task in all_tasks(loop) if not task.done()]
        if not to_cancel:
            return

        for task in to_cancel:
            task.cancel()
        # No `loop=` argument: it was removed from `gather()` in Python 3.10,
        # and the loop is taken from the (non-empty) list of tasks anyway.
        loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
    finally:
        # `concurrent.futures.Future` objects cannot be cancelled once they
        # are already running. There might be some when the `shutdown()` happened.
        # Silence their logger's spew about the event loop being closed.
        cf_logger = logging.getLogger("concurrent.futures")
        cf_logger.setLevel(logging.CRITICAL)
        loop.close()
2176
2177
if __name__ == "__main__":
    # Entry point when this file is executed directly rather than imported.
    main()