]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

57cd9a6b9da63eedabccd18d9f624108b2b3e3ee
[etc/vim.git] / black.py
1 import asyncio
2 import pickle
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from enum import Enum
6 from functools import partial, wraps
7 import keyword
8 import logging
9 from multiprocessing import Manager
10 import os
11 from pathlib import Path
12 import re
13 import tokenize
14 import signal
15 import sys
16 from typing import (
17     Any,
18     Callable,
19     Collection,
20     Dict,
21     Generic,
22     Iterable,
23     Iterator,
24     List,
25     Optional,
26     Pattern,
27     Sequence,
28     Set,
29     Tuple,
30     Type,
31     TypeVar,
32     Union,
33     cast,
34 )
35
36 from appdirs import user_cache_dir
37 from attr import dataclass, Factory
38 import click
39
40 # lib2to3 fork
41 from blib2to3.pytree import Node, Leaf, type_repr
42 from blib2to3 import pygram, pytree
43 from blib2to3.pgen2 import driver, token
44 from blib2to3.pgen2.parse import ParseError
45
46
47 __version__ = "18.5b0"
48 DEFAULT_LINE_LENGTH = 88
49
50 # types
51 syms = pygram.python_symbols
52 FileContent = str
53 Encoding = str
54 Depth = int
55 NodeType = int
56 LeafID = int
57 Priority = int
58 Index = int
59 LN = Union[Leaf, Node]
60 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
61 Timestamp = float
62 FileSize = int
63 CacheInfo = Tuple[Timestamp, FileSize]
64 Cache = Dict[Path, CacheInfo]
65 out = partial(click.secho, bold=True, err=True)
66 err = partial(click.secho, fg="red", err=True)
67
68
69 class NothingChanged(UserWarning):
70     """Raised by :func:`format_file` when reformatted code is the same as source."""
71
72
73 class CannotSplit(Exception):
74     """A readable split that fits the allotted line length is impossible.
75
76     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
77     :func:`delimiter_split`.
78     """
79
80
81 class FormatError(Exception):
82     """Base exception for `# fmt: on` and `# fmt: off` handling.
83
84     It holds the number of bytes of the prefix consumed before the format
85     control comment appeared.
86     """
87
88     def __init__(self, consumed: int) -> None:
89         super().__init__(consumed)
90         self.consumed = consumed
91
92     def trim_prefix(self, leaf: Leaf) -> None:
93         leaf.prefix = leaf.prefix[self.consumed :]
94
95     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
96         """Returns a new Leaf from the consumed part of the prefix."""
97         unformatted_prefix = leaf.prefix[: self.consumed]
98         return Leaf(token.NEWLINE, unformatted_prefix)
99
100
101 class FormatOn(FormatError):
102     """Found a comment like `# fmt: on` in the file."""
103
104
105 class FormatOff(FormatError):
106     """Found a comment like `# fmt: off` in the file."""
107
108
109 class WriteBack(Enum):
110     NO = 0
111     YES = 1
112     DIFF = 2
113
114
115 class Changed(Enum):
116     NO = 0
117     CACHED = 1
118     YES = 2
119
120
121 @click.command()
122 @click.option(
123     "-l",
124     "--line-length",
125     type=int,
126     default=DEFAULT_LINE_LENGTH,
127     help="How many character per line to allow.",
128     show_default=True,
129 )
130 @click.option(
131     "--check",
132     is_flag=True,
133     help=(
134         "Don't write the files back, just return the status.  Return code 0 "
135         "means nothing would change.  Return code 1 means some files would be "
136         "reformatted.  Return code 123 means there was an internal error."
137     ),
138 )
139 @click.option(
140     "--diff",
141     is_flag=True,
142     help="Don't write the files back, just output a diff for each file on stdout.",
143 )
144 @click.option(
145     "--fast/--safe",
146     is_flag=True,
147     help="If --fast given, skip temporary sanity checks. [default: --safe]",
148 )
149 @click.option(
150     "-q",
151     "--quiet",
152     is_flag=True,
153     help=(
154         "Don't emit non-error messages to stderr. Errors are still emitted, "
155         "silence those with 2>/dev/null."
156     ),
157 )
158 @click.version_option(version=__version__)
159 @click.argument(
160     "src",
161     nargs=-1,
162     type=click.Path(
163         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
164     ),
165 )
166 @click.pass_context
167 def main(
168     ctx: click.Context,
169     line_length: int,
170     check: bool,
171     diff: bool,
172     fast: bool,
173     quiet: bool,
174     src: List[str],
175 ) -> None:
176     """The uncompromising code formatter."""
177     sources: List[Path] = []
178     for s in src:
179         p = Path(s)
180         if p.is_dir():
181             sources.extend(gen_python_files_in_dir(p))
182         elif p.is_file():
183             # if a file was explicitly given, we don't care about its extension
184             sources.append(p)
185         elif s == "-":
186             sources.append(Path("-"))
187         else:
188             err(f"invalid path: {s}")
189
190     if check and not diff:
191         write_back = WriteBack.NO
192     elif diff:
193         write_back = WriteBack.DIFF
194     else:
195         write_back = WriteBack.YES
196     report = Report(check=check, quiet=quiet)
197     if len(sources) == 0:
198         out("No paths given. Nothing to do 😴")
199         ctx.exit(0)
200         return
201
202     elif len(sources) == 1:
203         reformat_one(sources[0], line_length, fast, write_back, report)
204     else:
205         loop = asyncio.get_event_loop()
206         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
207         try:
208             loop.run_until_complete(
209                 schedule_formatting(
210                     sources, line_length, fast, write_back, report, loop, executor
211                 )
212             )
213         finally:
214             shutdown(loop)
215         if not quiet:
216             out("All done! ✨ 🍰 ✨")
217             click.echo(str(report))
218     ctx.exit(report.return_code)
219
220
221 def reformat_one(
222     src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
223 ) -> None:
224     """Reformat a single file under `src` without spawning child processes.
225
226     If `quiet` is True, non-error messages are not output. `line_length`,
227     `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
228     """
229     try:
230         changed = Changed.NO
231         if not src.is_file() and str(src) == "-":
232             if format_stdin_to_stdout(
233                 line_length=line_length, fast=fast, write_back=write_back
234             ):
235                 changed = Changed.YES
236         else:
237             cache: Cache = {}
238             if write_back != WriteBack.DIFF:
239                 cache = read_cache(line_length)
240                 src = src.resolve()
241                 if src in cache and cache[src] == get_cache_info(src):
242                     changed = Changed.CACHED
243             if changed is not Changed.CACHED and format_file_in_place(
244                 src, line_length=line_length, fast=fast, write_back=write_back
245             ):
246                 changed = Changed.YES
247             if write_back == WriteBack.YES and changed is not Changed.NO:
248                 write_cache(cache, [src], line_length)
249         report.done(src, changed)
250     except Exception as exc:
251         report.failed(src, str(exc))
252
253
254 async def schedule_formatting(
255     sources: List[Path],
256     line_length: int,
257     fast: bool,
258     write_back: WriteBack,
259     report: "Report",
260     loop: BaseEventLoop,
261     executor: Executor,
262 ) -> None:
263     """Run formatting of `sources` in parallel using the provided `executor`.
264
265     (Use ProcessPoolExecutors for actual parallelism.)
266
267     `line_length`, `write_back`, and `fast` options are passed to
268     :func:`format_file_in_place`.
269     """
270     cache: Cache = {}
271     if write_back != WriteBack.DIFF:
272         cache = read_cache(line_length)
273         sources, cached = filter_cached(cache, sources)
274         for src in cached:
275             report.done(src, Changed.CACHED)
276     cancelled = []
277     formatted = []
278     if sources:
279         lock = None
280         if write_back == WriteBack.DIFF:
281             # For diff output, we need locks to ensure we don't interleave output
282             # from different processes.
283             manager = Manager()
284             lock = manager.Lock()
285         tasks = {
286             loop.run_in_executor(
287                 executor, format_file_in_place, src, line_length, fast, write_back, lock
288             ): src
289             for src in sorted(sources)
290         }
291         pending: Iterable[asyncio.Task] = tasks.keys()
292         try:
293             loop.add_signal_handler(signal.SIGINT, cancel, pending)
294             loop.add_signal_handler(signal.SIGTERM, cancel, pending)
295         except NotImplementedError:
296             # There are no good alternatives for these on Windows
297             pass
298         while pending:
299             done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
300             for task in done:
301                 src = tasks.pop(task)
302                 if task.cancelled():
303                     cancelled.append(task)
304                 elif task.exception():
305                     report.failed(src, str(task.exception()))
306                 else:
307                     formatted.append(src)
308                     report.done(src, Changed.YES if task.result() else Changed.NO)
309     if cancelled:
310         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
311     if write_back == WriteBack.YES and formatted:
312         write_cache(cache, formatted, line_length)
313
314
315 def format_file_in_place(
316     src: Path,
317     line_length: int,
318     fast: bool,
319     write_back: WriteBack = WriteBack.NO,
320     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
321 ) -> bool:
322     """Format file under `src` path. Return True if changed.
323
324     If `write_back` is True, write reformatted code back to stdout.
325     `line_length` and `fast` options are passed to :func:`format_file_contents`.
326     """
327     is_pyi = src.suffix == ".pyi"
328
329     with tokenize.open(src) as src_buffer:
330         src_contents = src_buffer.read()
331     try:
332         dst_contents = format_file_contents(
333             src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
334         )
335     except NothingChanged:
336         return False
337
338     if write_back == write_back.YES:
339         with open(src, "w", encoding=src_buffer.encoding) as f:
340             f.write(dst_contents)
341     elif write_back == write_back.DIFF:
342         src_name = f"{src}  (original)"
343         dst_name = f"{src}  (formatted)"
344         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
345         if lock:
346             lock.acquire()
347         try:
348             sys.stdout.write(diff_contents)
349         finally:
350             if lock:
351                 lock.release()
352     return True
353
354
355 def format_stdin_to_stdout(
356     line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
357 ) -> bool:
358     """Format file on stdin. Return True if changed.
359
360     If `write_back` is True, write reformatted code back to stdout.
361     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
362     """
363     src = sys.stdin.read()
364     dst = src
365     try:
366         dst = format_file_contents(src, line_length=line_length, fast=fast)
367         return True
368
369     except NothingChanged:
370         return False
371
372     finally:
373         if write_back == WriteBack.YES:
374             sys.stdout.write(dst)
375         elif write_back == WriteBack.DIFF:
376             src_name = "<stdin>  (original)"
377             dst_name = "<stdin>  (formatted)"
378             sys.stdout.write(diff(src, dst, src_name, dst_name))
379
380
381 def format_file_contents(
382     src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
383 ) -> FileContent:
384     """Reformat contents a file and return new contents.
385
386     If `fast` is False, additionally confirm that the reformatted code is
387     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
388     `line_length` is passed to :func:`format_str`.
389     """
390     if src_contents.strip() == "":
391         raise NothingChanged
392
393     dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
394     if src_contents == dst_contents:
395         raise NothingChanged
396
397     if not fast:
398         assert_equivalent(src_contents, dst_contents)
399         assert_stable(
400             src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
401         )
402     return dst_contents
403
404
405 def format_str(
406     src_contents: str, line_length: int, *, is_pyi: bool = False
407 ) -> FileContent:
408     """Reformat a string and return new contents.
409
410     `line_length` determines how many characters per line are allowed.
411     """
412     src_node = lib2to3_parse(src_contents)
413     dst_contents = ""
414     future_imports = get_future_imports(src_node)
415     elt = EmptyLineTracker(is_pyi=is_pyi)
416     py36 = is_python36(src_node)
417     lines = LineGenerator(
418         remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
419     )
420     empty_line = Line()
421     after = 0
422     for current_line in lines.visit(src_node):
423         for _ in range(after):
424             dst_contents += str(empty_line)
425         before, after = elt.maybe_empty_lines(current_line)
426         for _ in range(before):
427             dst_contents += str(empty_line)
428         for line in split_line(current_line, line_length=line_length, py36=py36):
429             dst_contents += str(line)
430     return dst_contents
431
432
433 GRAMMARS = [
434     pygram.python_grammar_no_print_statement_no_exec_statement,
435     pygram.python_grammar_no_print_statement,
436     pygram.python_grammar,
437 ]
438
439
440 def lib2to3_parse(src_txt: str) -> Node:
441     """Given a string with source, return the lib2to3 Node."""
442     grammar = pygram.python_grammar_no_print_statement
443     if src_txt[-1] != "\n":
444         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
445         src_txt += nl
446     for grammar in GRAMMARS:
447         drv = driver.Driver(grammar, pytree.convert)
448         try:
449             result = drv.parse_string(src_txt, True)
450             break
451
452         except ParseError as pe:
453             lineno, column = pe.context[1]
454             lines = src_txt.splitlines()
455             try:
456                 faulty_line = lines[lineno - 1]
457             except IndexError:
458                 faulty_line = "<line number missing in source>"
459             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
460     else:
461         raise exc from None
462
463     if isinstance(result, Leaf):
464         result = Node(syms.file_input, [result])
465     return result
466
467
468 def lib2to3_unparse(node: Node) -> str:
469     """Given a lib2to3 node, return its string representation."""
470     code = str(node)
471     return code
472
473
474 T = TypeVar("T")
475
476
477 class Visitor(Generic[T]):
478     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
479
480     def visit(self, node: LN) -> Iterator[T]:
481         """Main method to visit `node` and its children.
482
483         It tries to find a `visit_*()` method for the given `node.type`, like
484         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
485         If no dedicated `visit_*()` method is found, chooses `visit_default()`
486         instead.
487
488         Then yields objects of type `T` from the selected visitor.
489         """
490         if node.type < 256:
491             name = token.tok_name[node.type]
492         else:
493             name = type_repr(node.type)
494         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
495
496     def visit_default(self, node: LN) -> Iterator[T]:
497         """Default `visit_*()` implementation. Recurses to children of `node`."""
498         if isinstance(node, Node):
499             for child in node.children:
500                 yield from self.visit(child)
501
502
503 @dataclass
504 class DebugVisitor(Visitor[T]):
505     tree_depth: int = 0
506
507     def visit_default(self, node: LN) -> Iterator[T]:
508         indent = " " * (2 * self.tree_depth)
509         if isinstance(node, Node):
510             _type = type_repr(node.type)
511             out(f"{indent}{_type}", fg="yellow")
512             self.tree_depth += 1
513             for child in node.children:
514                 yield from self.visit(child)
515
516             self.tree_depth -= 1
517             out(f"{indent}/{_type}", fg="yellow", bold=False)
518         else:
519             _type = token.tok_name.get(node.type, str(node.type))
520             out(f"{indent}{_type}", fg="blue", nl=False)
521             if node.prefix:
522                 # We don't have to handle prefixes for `Node` objects since
523                 # that delegates to the first child anyway.
524                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
525             out(f" {node.value!r}", fg="blue", bold=False)
526
527     @classmethod
528     def show(cls, code: str) -> None:
529         """Pretty-print the lib2to3 AST of a given string of `code`.
530
531         Convenience method for debugging.
532         """
533         v: DebugVisitor[None] = DebugVisitor()
534         list(v.visit(lib2to3_parse(code)))
535
536
537 KEYWORDS = set(keyword.kwlist)
538 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
539 FLOW_CONTROL = {"return", "raise", "break", "continue"}
540 STATEMENT = {
541     syms.if_stmt,
542     syms.while_stmt,
543     syms.for_stmt,
544     syms.try_stmt,
545     syms.except_clause,
546     syms.with_stmt,
547     syms.funcdef,
548     syms.classdef,
549 }
550 STANDALONE_COMMENT = 153
551 LOGIC_OPERATORS = {"and", "or"}
552 COMPARATORS = {
553     token.LESS,
554     token.GREATER,
555     token.EQEQUAL,
556     token.NOTEQUAL,
557     token.LESSEQUAL,
558     token.GREATEREQUAL,
559 }
560 MATH_OPERATORS = {
561     token.VBAR,
562     token.CIRCUMFLEX,
563     token.AMPER,
564     token.LEFTSHIFT,
565     token.RIGHTSHIFT,
566     token.PLUS,
567     token.MINUS,
568     token.STAR,
569     token.SLASH,
570     token.DOUBLESLASH,
571     token.PERCENT,
572     token.AT,
573     token.TILDE,
574     token.DOUBLESTAR,
575 }
576 STARS = {token.STAR, token.DOUBLESTAR}
577 VARARGS_PARENTS = {
578     syms.arglist,
579     syms.argument,  # double star in arglist
580     syms.trailer,  # single argument to call
581     syms.typedargslist,
582     syms.varargslist,  # lambdas
583 }
584 UNPACKING_PARENTS = {
585     syms.atom,  # single element of a list or set literal
586     syms.dictsetmaker,
587     syms.listmaker,
588     syms.testlist_gexp,
589 }
590 TEST_DESCENDANTS = {
591     syms.test,
592     syms.lambdef,
593     syms.or_test,
594     syms.and_test,
595     syms.not_test,
596     syms.comparison,
597     syms.star_expr,
598     syms.expr,
599     syms.xor_expr,
600     syms.and_expr,
601     syms.shift_expr,
602     syms.arith_expr,
603     syms.trailer,
604     syms.term,
605     syms.power,
606 }
607 ASSIGNMENTS = {
608     "=",
609     "+=",
610     "-=",
611     "*=",
612     "@=",
613     "/=",
614     "%=",
615     "&=",
616     "|=",
617     "^=",
618     "<<=",
619     ">>=",
620     "**=",
621     "//=",
622 }
623 COMPREHENSION_PRIORITY = 20
624 COMMA_PRIORITY = 18
625 TERNARY_PRIORITY = 16
626 LOGIC_PRIORITY = 14
627 STRING_PRIORITY = 12
628 COMPARATOR_PRIORITY = 10
629 MATH_PRIORITIES = {
630     token.VBAR: 9,
631     token.CIRCUMFLEX: 8,
632     token.AMPER: 7,
633     token.LEFTSHIFT: 6,
634     token.RIGHTSHIFT: 6,
635     token.PLUS: 5,
636     token.MINUS: 5,
637     token.STAR: 4,
638     token.SLASH: 4,
639     token.DOUBLESLASH: 4,
640     token.PERCENT: 4,
641     token.AT: 4,
642     token.TILDE: 3,
643     token.DOUBLESTAR: 2,
644 }
645 DOT_PRIORITY = 1
646
647
648 @dataclass
649 class BracketTracker:
650     """Keeps track of brackets on a line."""
651
652     depth: int = 0
653     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
654     delimiters: Dict[LeafID, Priority] = Factory(dict)
655     previous: Optional[Leaf] = None
656     _for_loop_variable: int = 0
657     _lambda_arguments: int = 0
658
659     def mark(self, leaf: Leaf) -> None:
660         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
661
662         All leaves receive an int `bracket_depth` field that stores how deep
663         within brackets a given leaf is. 0 means there are no enclosing brackets
664         that started on this line.
665
666         If a leaf is itself a closing bracket, it receives an `opening_bracket`
667         field that it forms a pair with. This is a one-directional link to
668         avoid reference cycles.
669
670         If a leaf is a delimiter (a token on which Black can split the line if
671         needed) and it's on depth 0, its `id()` is stored in the tracker's
672         `delimiters` field.
673         """
674         if leaf.type == token.COMMENT:
675             return
676
677         self.maybe_decrement_after_for_loop_variable(leaf)
678         self.maybe_decrement_after_lambda_arguments(leaf)
679         if leaf.type in CLOSING_BRACKETS:
680             self.depth -= 1
681             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
682             leaf.opening_bracket = opening_bracket
683         leaf.bracket_depth = self.depth
684         if self.depth == 0:
685             delim = is_split_before_delimiter(leaf, self.previous)
686             if delim and self.previous is not None:
687                 self.delimiters[id(self.previous)] = delim
688             else:
689                 delim = is_split_after_delimiter(leaf, self.previous)
690                 if delim:
691                     self.delimiters[id(leaf)] = delim
692         if leaf.type in OPENING_BRACKETS:
693             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
694             self.depth += 1
695         self.previous = leaf
696         self.maybe_increment_lambda_arguments(leaf)
697         self.maybe_increment_for_loop_variable(leaf)
698
699     def any_open_brackets(self) -> bool:
700         """Return True if there is an yet unmatched open bracket on the line."""
701         return bool(self.bracket_match)
702
703     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
704         """Return the highest priority of a delimiter found on the line.
705
706         Values are consistent with what `is_split_*_delimiter()` return.
707         Raises ValueError on no delimiters.
708         """
709         return max(v for k, v in self.delimiters.items() if k not in exclude)
710
711     def delimiter_count_with_priority(self, priority: int = 0) -> int:
712         """Return the number of delimiters with the given `priority`.
713
714         If no `priority` is passed, defaults to max priority on the line.
715         """
716         if not self.delimiters:
717             return 0
718
719         priority = priority or self.max_delimiter_priority()
720         return sum(1 for p in self.delimiters.values() if p == priority)
721
722     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
723         """In a for loop, or comprehension, the variables are often unpacks.
724
725         To avoid splitting on the comma in this situation, increase the depth of
726         tokens between `for` and `in`.
727         """
728         if leaf.type == token.NAME and leaf.value == "for":
729             self.depth += 1
730             self._for_loop_variable += 1
731             return True
732
733         return False
734
735     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
736         """See `maybe_increment_for_loop_variable` above for explanation."""
737         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
738             self.depth -= 1
739             self._for_loop_variable -= 1
740             return True
741
742         return False
743
744     def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
745         """In a lambda expression, there might be more than one argument.
746
747         To avoid splitting on the comma in this situation, increase the depth of
748         tokens between `lambda` and `:`.
749         """
750         if leaf.type == token.NAME and leaf.value == "lambda":
751             self.depth += 1
752             self._lambda_arguments += 1
753             return True
754
755         return False
756
757     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
758         """See `maybe_increment_lambda_arguments` above for explanation."""
759         if self._lambda_arguments and leaf.type == token.COLON:
760             self.depth -= 1
761             self._lambda_arguments -= 1
762             return True
763
764         return False
765
766     def get_open_lsqb(self) -> Optional[Leaf]:
767         """Return the most recent opening square bracket (if any)."""
768         return self.bracket_match.get((self.depth - 1, token.RSQB))
769
770
771 @dataclass
772 class Line:
773     """Holds leaves and comments. Can be printed with `str(line)`."""
774
775     depth: int = 0
776     leaves: List[Leaf] = Factory(list)
777     comments: List[Tuple[Index, Leaf]] = Factory(list)
778     bracket_tracker: BracketTracker = Factory(BracketTracker)
779     inside_brackets: bool = False
780     should_explode: bool = False
781
782     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
783         """Add a new `leaf` to the end of the line.
784
785         Unless `preformatted` is True, the `leaf` will receive a new consistent
786         whitespace prefix and metadata applied by :class:`BracketTracker`.
787         Trailing commas are maybe removed, unpacked for loop variables are
788         demoted from being delimiters.
789
790         Inline comments are put aside.
791         """
792         has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
793         if not has_value:
794             return
795
796         if token.COLON == leaf.type and self.is_class_paren_empty:
797             del self.leaves[-2:]
798         if self.leaves and not preformatted:
799             # Note: at this point leaf.prefix should be empty except for
800             # imports, for which we only preserve newlines.
801             leaf.prefix += whitespace(
802                 leaf, complex_subscript=self.is_complex_subscript(leaf)
803             )
804         if self.inside_brackets or not preformatted:
805             self.bracket_tracker.mark(leaf)
806             self.maybe_remove_trailing_comma(leaf)
807         if not self.append_comment(leaf):
808             self.leaves.append(leaf)
809
810     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
811         """Like :func:`append()` but disallow invalid standalone comment structure.
812
813         Raises ValueError when any `leaf` is appended after a standalone comment
814         or when a standalone comment is not the first leaf on the line.
815         """
816         if self.bracket_tracker.depth == 0:
817             if self.is_comment:
818                 raise ValueError("cannot append to standalone comments")
819
820             if self.leaves and leaf.type == STANDALONE_COMMENT:
821                 raise ValueError(
822                     "cannot append standalone comments to a populated line"
823                 )
824
825         self.append(leaf, preformatted=preformatted)
826
827     @property
828     def is_comment(self) -> bool:
829         """Is this line a standalone comment?"""
830         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
831
832     @property
833     def is_decorator(self) -> bool:
834         """Is this line a decorator?"""
835         return bool(self) and self.leaves[0].type == token.AT
836
837     @property
838     def is_import(self) -> bool:
839         """Is this an import line?"""
840         return bool(self) and is_import(self.leaves[0])
841
842     @property
843     def is_class(self) -> bool:
844         """Is this line a class definition?"""
845         return (
846             bool(self)
847             and self.leaves[0].type == token.NAME
848             and self.leaves[0].value == "class"
849         )
850
851     @property
852     def is_stub_class(self) -> bool:
853         """Is this line a class definition with a body consisting only of "..."?"""
854         return self.is_class and self.leaves[-3:] == [
855             Leaf(token.DOT, ".") for _ in range(3)
856         ]
857
858     @property
859     def is_def(self) -> bool:
860         """Is this a function definition? (Also returns True for async defs.)"""
861         try:
862             first_leaf = self.leaves[0]
863         except IndexError:
864             return False
865
866         try:
867             second_leaf: Optional[Leaf] = self.leaves[1]
868         except IndexError:
869             second_leaf = None
870         return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
871             first_leaf.type == token.ASYNC
872             and second_leaf is not None
873             and second_leaf.type == token.NAME
874             and second_leaf.value == "def"
875         )
876
877     @property
878     def is_flow_control(self) -> bool:
879         """Is this line a flow control statement?
880
881         Those are `return`, `raise`, `break`, and `continue`.
882         """
883         return (
884             bool(self)
885             and self.leaves[0].type == token.NAME
886             and self.leaves[0].value in FLOW_CONTROL
887         )
888
889     @property
890     def is_yield(self) -> bool:
891         """Is this line a yield statement?"""
892         return (
893             bool(self)
894             and self.leaves[0].type == token.NAME
895             and self.leaves[0].value == "yield"
896         )
897
898     @property
899     def is_class_paren_empty(self) -> bool:
900         """Is this a class with no base classes but using parentheses?
901
902         Those are unnecessary and should be removed.
903         """
904         return (
905             bool(self)
906             and len(self.leaves) == 4
907             and self.is_class
908             and self.leaves[2].type == token.LPAR
909             and self.leaves[2].value == "("
910             and self.leaves[3].type == token.RPAR
911             and self.leaves[3].value == ")"
912         )
913
914     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
915         """If so, needs to be split before emitting."""
916         for leaf in self.leaves:
917             if leaf.type == STANDALONE_COMMENT:
918                 if leaf.bracket_depth <= depth_limit:
919                     return True
920
921         return False
922
923     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
924         """Remove trailing comma if there is one and it's safe."""
925         if not (
926             self.leaves
927             and self.leaves[-1].type == token.COMMA
928             and closing.type in CLOSING_BRACKETS
929         ):
930             return False
931
932         if closing.type == token.RBRACE:
933             self.remove_trailing_comma()
934             return True
935
936         if closing.type == token.RSQB:
937             comma = self.leaves[-1]
938             if comma.parent and comma.parent.type == syms.listmaker:
939                 self.remove_trailing_comma()
940                 return True
941
942         # For parens let's check if it's safe to remove the comma.
943         # Imports are always safe.
944         if self.is_import:
945             self.remove_trailing_comma()
946             return True
947
948         # Otheriwsse, if the trailing one is the only one, we might mistakenly
949         # change a tuple into a different type by removing the comma.
950         depth = closing.bracket_depth + 1
951         commas = 0
952         opening = closing.opening_bracket
953         for _opening_index, leaf in enumerate(self.leaves):
954             if leaf is opening:
955                 break
956
957         else:
958             return False
959
960         for leaf in self.leaves[_opening_index + 1 :]:
961             if leaf is closing:
962                 break
963
964             bracket_depth = leaf.bracket_depth
965             if bracket_depth == depth and leaf.type == token.COMMA:
966                 commas += 1
967                 if leaf.parent and leaf.parent.type == syms.arglist:
968                     commas += 1
969                     break
970
971         if commas > 1:
972             self.remove_trailing_comma()
973             return True
974
975         return False
976
977     def append_comment(self, comment: Leaf) -> bool:
978         """Add an inline or standalone comment to the line."""
979         if (
980             comment.type == STANDALONE_COMMENT
981             and self.bracket_tracker.any_open_brackets()
982         ):
983             comment.prefix = ""
984             return False
985
986         if comment.type != token.COMMENT:
987             return False
988
989         after = len(self.leaves) - 1
990         if after == -1:
991             comment.type = STANDALONE_COMMENT
992             comment.prefix = ""
993             return False
994
995         else:
996             self.comments.append((after, comment))
997             return True
998
999     def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
1000         """Generate comments that should appear directly after `leaf`.
1001
1002         Provide a non-negative leaf `_index` to speed up the function.
1003         """
1004         if _index == -1:
1005             for _index, _leaf in enumerate(self.leaves):
1006                 if leaf is _leaf:
1007                     break
1008
1009             else:
1010                 return
1011
1012         for index, comment_after in self.comments:
1013             if _index == index:
1014                 yield comment_after
1015
1016     def remove_trailing_comma(self) -> None:
1017         """Remove the trailing comma and moves the comments attached to it."""
1018         comma_index = len(self.leaves) - 1
1019         for i in range(len(self.comments)):
1020             comment_index, comment = self.comments[i]
1021             if comment_index == comma_index:
1022                 self.comments[i] = (comma_index - 1, comment)
1023         self.leaves.pop()
1024
1025     def is_complex_subscript(self, leaf: Leaf) -> bool:
1026         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1027         open_lsqb = (
1028             leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1029         )
1030         if open_lsqb is None:
1031             return False
1032
1033         subscript_start = open_lsqb.next_sibling
1034         if (
1035             isinstance(subscript_start, Node)
1036             and subscript_start.type == syms.subscriptlist
1037         ):
1038             subscript_start = child_towards(subscript_start, leaf)
1039         return subscript_start is not None and any(
1040             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1041         )
1042
1043     def __str__(self) -> str:
1044         """Render the line."""
1045         if not self:
1046             return "\n"
1047
1048         indent = "    " * self.depth
1049         leaves = iter(self.leaves)
1050         first = next(leaves)
1051         res = f"{first.prefix}{indent}{first.value}"
1052         for leaf in leaves:
1053             res += str(leaf)
1054         for _, comment in self.comments:
1055             res += str(comment)
1056         return res + "\n"
1057
1058     def __bool__(self) -> bool:
1059         """Return True if the line has leaves or comments."""
1060         return bool(self.leaves or self.comments)
1061
1062
1063 class UnformattedLines(Line):
1064     """Just like :class:`Line` but stores lines which aren't reformatted."""
1065
1066     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1067         """Just add a new `leaf` to the end of the lines.
1068
1069         The `preformatted` argument is ignored.
1070
1071         Keeps track of indentation `depth`, which is useful when the user
1072         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1073         """
1074         try:
1075             list(generate_comments(leaf))
1076         except FormatOn as f_on:
1077             self.leaves.append(f_on.leaf_from_consumed(leaf))
1078             raise
1079
1080         self.leaves.append(leaf)
1081         if leaf.type == token.INDENT:
1082             self.depth += 1
1083         elif leaf.type == token.DEDENT:
1084             self.depth -= 1
1085
1086     def __str__(self) -> str:
1087         """Render unformatted lines from leaves which were added with `append()`.
1088
1089         `depth` is not used for indentation in this case.
1090         """
1091         if not self:
1092             return "\n"
1093
1094         res = ""
1095         for leaf in self.leaves:
1096             res += str(leaf)
1097         return res
1098
1099     def append_comment(self, comment: Leaf) -> bool:
1100         """Not implemented in this class. Raises `NotImplementedError`."""
1101         raise NotImplementedError("Unformatted lines don't store comments separately.")
1102
1103     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1104         """Does nothing and returns False."""
1105         return False
1106
1107     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1108         """Does nothing and returns False."""
1109         return False
1110
1111
1112 @dataclass
1113 class EmptyLineTracker:
1114     """Provides a stateful method that returns the number of potential extra
1115     empty lines needed before and after the currently processed line.
1116
1117     Note: this tracker works on lines that haven't been split yet.  It assumes
1118     the prefix of the first leaf consists of optional newlines.  Those newlines
1119     are consumed by `maybe_empty_lines()` and included in the computation.
1120     """
1121     is_pyi: bool = False
1122     previous_line: Optional[Line] = None
1123     previous_after: int = 0
1124     previous_defs: List[int] = Factory(list)
1125
1126     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1127         """Return the number of extra empty lines before and after the `current_line`.
1128
1129         This is for separating `def`, `async def` and `class` with extra empty
1130         lines (two on module-level), as well as providing an extra empty line
1131         after flow control keywords to make them more prominent.
1132         """
1133         if isinstance(current_line, UnformattedLines):
1134             return 0, 0
1135
1136         before, after = self._maybe_empty_lines(current_line)
1137         before -= self.previous_after
1138         self.previous_after = after
1139         self.previous_line = current_line
1140         return before, after
1141
1142     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1143         max_allowed = 1
1144         if current_line.depth == 0:
1145             max_allowed = 1 if self.is_pyi else 2
1146         if current_line.leaves:
1147             # Consume the first leaf's extra newlines.
1148             first_leaf = current_line.leaves[0]
1149             before = first_leaf.prefix.count("\n")
1150             before = min(before, max_allowed)
1151             first_leaf.prefix = ""
1152         else:
1153             before = 0
1154         depth = current_line.depth
1155         while self.previous_defs and self.previous_defs[-1] >= depth:
1156             self.previous_defs.pop()
1157             if self.is_pyi:
1158                 before = 0 if depth else 1
1159             else:
1160                 before = 1 if depth else 2
1161         is_decorator = current_line.is_decorator
1162         if is_decorator or current_line.is_def or current_line.is_class:
1163             if not is_decorator:
1164                 self.previous_defs.append(depth)
1165             if self.previous_line is None:
1166                 # Don't insert empty lines before the first line in the file.
1167                 return 0, 0
1168
1169             if self.previous_line.is_decorator:
1170                 return 0, 0
1171
1172             if (
1173                 self.previous_line.is_comment
1174                 and self.previous_line.depth == current_line.depth
1175                 and before == 0
1176             ):
1177                 return 0, 0
1178
1179             if self.is_pyi:
1180                 if self.previous_line.depth > current_line.depth:
1181                     newlines = 1
1182                 elif current_line.is_class or self.previous_line.is_class:
1183                     if current_line.is_stub_class and self.previous_line.is_stub_class:
1184                         newlines = 0
1185                     else:
1186                         newlines = 1
1187                 else:
1188                     newlines = 0
1189             else:
1190                 newlines = 2
1191             if current_line.depth and newlines:
1192                 newlines -= 1
1193             return newlines, 0
1194
1195         if (
1196             self.previous_line
1197             and self.previous_line.is_import
1198             and not current_line.is_import
1199             and depth == self.previous_line.depth
1200         ):
1201             return (before or 1), 0
1202
1203         return before, 0
1204
1205
1206 @dataclass
1207 class LineGenerator(Visitor[Line]):
1208     """Generates reformatted Line objects.  Empty lines are not emitted.
1209
1210     Note: destroys the tree it's visiting by mutating prefixes of its leaves
1211     in ways that will no longer stringify to valid Python code on the tree.
1212     """
1213     is_pyi: bool = False
1214     current_line: Line = Factory(Line)
1215     remove_u_prefix: bool = False
1216
1217     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1218         """Generate a line.
1219
1220         If the line is empty, only emit if it makes sense.
1221         If the line is too long, split it first and then generate.
1222
1223         If any lines were generated, set up a new current_line.
1224         """
1225         if not self.current_line:
1226             if self.current_line.__class__ == type:
1227                 self.current_line.depth += indent
1228             else:
1229                 self.current_line = type(depth=self.current_line.depth + indent)
1230             return  # Line is empty, don't emit. Creating a new one unnecessary.
1231
1232         complete_line = self.current_line
1233         self.current_line = type(depth=complete_line.depth + indent)
1234         yield complete_line
1235
1236     def visit(self, node: LN) -> Iterator[Line]:
1237         """Main method to visit `node` and its children.
1238
1239         Yields :class:`Line` objects.
1240         """
1241         if isinstance(self.current_line, UnformattedLines):
1242             # File contained `# fmt: off`
1243             yield from self.visit_unformatted(node)
1244
1245         else:
1246             yield from super().visit(node)
1247
1248     def visit_default(self, node: LN) -> Iterator[Line]:
1249         """Default `visit_*()` implementation. Recurses to children of `node`."""
1250         if isinstance(node, Leaf):
1251             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1252             try:
1253                 for comment in generate_comments(node):
1254                     if any_open_brackets:
1255                         # any comment within brackets is subject to splitting
1256                         self.current_line.append(comment)
1257                     elif comment.type == token.COMMENT:
1258                         # regular trailing comment
1259                         self.current_line.append(comment)
1260                         yield from self.line()
1261
1262                     else:
1263                         # regular standalone comment
1264                         yield from self.line()
1265
1266                         self.current_line.append(comment)
1267                         yield from self.line()
1268
1269             except FormatOff as f_off:
1270                 f_off.trim_prefix(node)
1271                 yield from self.line(type=UnformattedLines)
1272                 yield from self.visit(node)
1273
1274             except FormatOn as f_on:
1275                 # This only happens here if somebody says "fmt: on" multiple
1276                 # times in a row.
1277                 f_on.trim_prefix(node)
1278                 yield from self.visit_default(node)
1279
1280             else:
1281                 normalize_prefix(node, inside_brackets=any_open_brackets)
1282                 if node.type == token.STRING:
1283                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1284                     normalize_string_quotes(node)
1285                 if node.type not in WHITESPACE:
1286                     self.current_line.append(node)
1287         yield from super().visit_default(node)
1288
1289     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1290         """Increase indentation level, maybe yield a line."""
1291         # In blib2to3 INDENT never holds comments.
1292         yield from self.line(+1)
1293         yield from self.visit_default(node)
1294
1295     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1296         """Decrease indentation level, maybe yield a line."""
1297         # The current line might still wait for trailing comments.  At DEDENT time
1298         # there won't be any (they would be prefixes on the preceding NEWLINE).
1299         # Emit the line then.
1300         yield from self.line()
1301
1302         # While DEDENT has no value, its prefix may contain standalone comments
1303         # that belong to the current indentation level.  Get 'em.
1304         yield from self.visit_default(node)
1305
1306         # Finally, emit the dedent.
1307         yield from self.line(-1)
1308
1309     def visit_stmt(
1310         self, node: Node, keywords: Set[str], parens: Set[str]
1311     ) -> Iterator[Line]:
1312         """Visit a statement.
1313
1314         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1315         `def`, `with`, `class`, `assert` and assignments.
1316
1317         The relevant Python language `keywords` for a given statement will be
1318         NAME leaves within it. This methods puts those on a separate line.
1319
1320         `parens` holds a set of string leaf values immediately after which
1321         invisible parens should be put.
1322         """
1323         normalize_invisible_parens(node, parens_after=parens)
1324         for child in node.children:
1325             if child.type == token.NAME and child.value in keywords:  # type: ignore
1326                 yield from self.line()
1327
1328             yield from self.visit(child)
1329
1330     def visit_suite(self, node: Node) -> Iterator[Line]:
1331         """Visit a suite."""
1332         if self.is_pyi and is_stub_suite(node):
1333             yield from self.visit(node.children[2])
1334         else:
1335             yield from self.visit_default(node)
1336
1337     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1338         """Visit a statement without nested statements."""
1339         is_suite_like = node.parent and node.parent.type in STATEMENT
1340         if is_suite_like:
1341             if self.is_pyi and is_stub_body(node):
1342                 yield from self.visit_default(node)
1343             else:
1344                 yield from self.line(+1)
1345                 yield from self.visit_default(node)
1346                 yield from self.line(-1)
1347
1348         else:
1349             if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1350                 yield from self.line()
1351             yield from self.visit_default(node)
1352
1353     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1354         """Visit `async def`, `async for`, `async with`."""
1355         yield from self.line()
1356
1357         children = iter(node.children)
1358         for child in children:
1359             yield from self.visit(child)
1360
1361             if child.type == token.ASYNC:
1362                 break
1363
1364         internal_stmt = next(children)
1365         for child in internal_stmt.children:
1366             yield from self.visit(child)
1367
1368     def visit_decorators(self, node: Node) -> Iterator[Line]:
1369         """Visit decorators."""
1370         for child in node.children:
1371             yield from self.line()
1372             yield from self.visit(child)
1373
1374     def visit_import_from(self, node: Node) -> Iterator[Line]:
1375         """Visit import_from and maybe put invisible parentheses.
1376
1377         This is separate from `visit_stmt` because import statements don't
1378         support arbitrary atoms and thus handling of parentheses is custom.
1379         """
1380         check_lpar = False
1381         for index, child in enumerate(node.children):
1382             if check_lpar:
1383                 if child.type == token.LPAR:
1384                     # make parentheses invisible
1385                     child.value = ""  # type: ignore
1386                     node.children[-1].value = ""  # type: ignore
1387                 else:
1388                     # insert invisible parentheses
1389                     node.insert_child(index, Leaf(token.LPAR, ""))
1390                     node.append_child(Leaf(token.RPAR, ""))
1391                 break
1392
1393             check_lpar = (
1394                 child.type == token.NAME and child.value == "import"  # type: ignore
1395             )
1396
1397         for child in node.children:
1398             yield from self.visit(child)
1399
1400     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1401         """Remove a semicolon and put the other statement on a separate line."""
1402         yield from self.line()
1403
1404     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1405         """End of file. Process outstanding comments and end with a newline."""
1406         yield from self.visit_default(leaf)
1407         yield from self.line()
1408
1409     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1410         """Used when file contained a `# fmt: off`."""
1411         if isinstance(node, Node):
1412             for child in node.children:
1413                 yield from self.visit(child)
1414
1415         else:
1416             try:
1417                 self.current_line.append(node)
1418             except FormatOn as f_on:
1419                 f_on.trim_prefix(node)
1420                 yield from self.line()
1421                 yield from self.visit(node)
1422
1423             if node.type == token.ENDMARKER:
1424                 # somebody decided not to put a final `# fmt: on`
1425                 yield from self.line()
1426
1427     def __attrs_post_init__(self) -> None:
1428         """You are in a twisty little maze of passages."""
1429         v = self.visit_stmt
1430         Ø: Set[str] = set()
1431         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1432         self.visit_if_stmt = partial(
1433             v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1434         )
1435         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1436         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1437         self.visit_try_stmt = partial(
1438             v, keywords={"try", "except", "else", "finally"}, parens=Ø
1439         )
1440         self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1441         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1442         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1443         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1444         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1445         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1446         self.visit_async_funcdef = self.visit_async_stmt
1447         self.visit_decorated = self.visit_decorators
1448
1449
1450 IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
1451 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1452 OPENING_BRACKETS = set(BRACKET.keys())
1453 CLOSING_BRACKETS = set(BRACKET.values())
1454 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1455 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1456
1457
1458 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
1459     """Return whitespace prefix if needed for the given `leaf`.
1460
1461     `complex_subscript` signals whether the given leaf is part of a subscription
1462     which has non-trivial arguments, like arithmetic expressions or function calls.
1463     """
1464     NO = ""
1465     SPACE = " "
1466     DOUBLESPACE = "  "
1467     t = leaf.type
1468     p = leaf.parent
1469     v = leaf.value
1470     if t in ALWAYS_NO_SPACE:
1471         return NO
1472
1473     if t == token.COMMENT:
1474         return DOUBLESPACE
1475
1476     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1477     if t == token.COLON and p.type not in {
1478         syms.subscript,
1479         syms.subscriptlist,
1480         syms.sliceop,
1481     }:
1482         return NO
1483
1484     prev = leaf.prev_sibling
1485     if not prev:
1486         prevp = preceding_leaf(p)
1487         if not prevp or prevp.type in OPENING_BRACKETS:
1488             return NO
1489
1490         if t == token.COLON:
1491             if prevp.type == token.COLON:
1492                 return NO
1493
1494             elif prevp.type != token.COMMA and not complex_subscript:
1495                 return NO
1496
1497             return SPACE
1498
1499         if prevp.type == token.EQUAL:
1500             if prevp.parent:
1501                 if prevp.parent.type in {
1502                     syms.arglist,
1503                     syms.argument,
1504                     syms.parameters,
1505                     syms.varargslist,
1506                 }:
1507                     return NO
1508
1509                 elif prevp.parent.type == syms.typedargslist:
1510                     # A bit hacky: if the equal sign has whitespace, it means we
1511                     # previously found it's a typed argument.  So, we're using
1512                     # that, too.
1513                     return prevp.prefix
1514
1515         elif prevp.type in STARS:
1516             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1517                 return NO
1518
1519         elif prevp.type == token.COLON:
1520             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1521                 return SPACE if complex_subscript else NO
1522
1523         elif (
1524             prevp.parent
1525             and prevp.parent.type == syms.factor
1526             and prevp.type in MATH_OPERATORS
1527         ):
1528             return NO
1529
1530         elif (
1531             prevp.type == token.RIGHTSHIFT
1532             and prevp.parent
1533             and prevp.parent.type == syms.shift_expr
1534             and prevp.prev_sibling
1535             and prevp.prev_sibling.type == token.NAME
1536             and prevp.prev_sibling.value == "print"  # type: ignore
1537         ):
1538             # Python 2 print chevron
1539             return NO
1540
1541     elif prev.type in OPENING_BRACKETS:
1542         return NO
1543
1544     if p.type in {syms.parameters, syms.arglist}:
1545         # untyped function signatures or calls
1546         if not prev or prev.type != token.COMMA:
1547             return NO
1548
1549     elif p.type == syms.varargslist:
1550         # lambdas
1551         if prev and prev.type != token.COMMA:
1552             return NO
1553
1554     elif p.type == syms.typedargslist:
1555         # typed function signatures
1556         if not prev:
1557             return NO
1558
1559         if t == token.EQUAL:
1560             if prev.type != syms.tname:
1561                 return NO
1562
1563         elif prev.type == token.EQUAL:
1564             # A bit hacky: if the equal sign has whitespace, it means we
1565             # previously found it's a typed argument.  So, we're using that, too.
1566             return prev.prefix
1567
1568         elif prev.type != token.COMMA:
1569             return NO
1570
1571     elif p.type == syms.tname:
1572         # type names
1573         if not prev:
1574             prevp = preceding_leaf(p)
1575             if not prevp or prevp.type != token.COMMA:
1576                 return NO
1577
1578     elif p.type == syms.trailer:
1579         # attributes and calls
1580         if t == token.LPAR or t == token.RPAR:
1581             return NO
1582
1583         if not prev:
1584             if t == token.DOT:
1585                 prevp = preceding_leaf(p)
1586                 if not prevp or prevp.type != token.NUMBER:
1587                     return NO
1588
1589             elif t == token.LSQB:
1590                 return NO
1591
1592         elif prev.type != token.COMMA:
1593             return NO
1594
1595     elif p.type == syms.argument:
1596         # single argument
1597         if t == token.EQUAL:
1598             return NO
1599
1600         if not prev:
1601             prevp = preceding_leaf(p)
1602             if not prevp or prevp.type == token.LPAR:
1603                 return NO
1604
1605         elif prev.type in {token.EQUAL} | STARS:
1606             return NO
1607
1608     elif p.type == syms.decorator:
1609         # decorators
1610         return NO
1611
1612     elif p.type == syms.dotted_name:
1613         if prev:
1614             return NO
1615
1616         prevp = preceding_leaf(p)
1617         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1618             return NO
1619
1620     elif p.type == syms.classdef:
1621         if t == token.LPAR:
1622             return NO
1623
1624         if prev and prev.type == token.LPAR:
1625             return NO
1626
1627     elif p.type in {syms.subscript, syms.sliceop}:
1628         # indexing
1629         if not prev:
1630             assert p.parent is not None, "subscripts are always parented"
1631             if p.parent.type == syms.subscriptlist:
1632                 return SPACE
1633
1634             return NO
1635
1636         elif not complex_subscript:
1637             return NO
1638
1639     elif p.type == syms.atom:
1640         if prev and t == token.DOT:
1641             # dots, but not the first one.
1642             return NO
1643
1644     elif p.type == syms.dictsetmaker:
1645         # dict unpacking
1646         if prev and prev.type == token.DOUBLESTAR:
1647             return NO
1648
1649     elif p.type in {syms.factor, syms.star_expr}:
1650         # unary ops
1651         if not prev:
1652             prevp = preceding_leaf(p)
1653             if not prevp or prevp.type in OPENING_BRACKETS:
1654                 return NO
1655
1656             prevp_parent = prevp.parent
1657             assert prevp_parent is not None
1658             if prevp.type == token.COLON and prevp_parent.type in {
1659                 syms.subscript,
1660                 syms.sliceop,
1661             }:
1662                 return NO
1663
1664             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1665                 return NO
1666
1667         elif t == token.NAME or t == token.NUMBER:
1668             return NO
1669
1670     elif p.type == syms.import_from:
1671         if t == token.DOT:
1672             if prev and prev.type == token.DOT:
1673                 return NO
1674
1675         elif t == token.NAME:
1676             if v == "import":
1677                 return SPACE
1678
1679             if prev and prev.type == token.DOT:
1680                 return NO
1681
1682     elif p.type == syms.sliceop:
1683         return NO
1684
1685     return SPACE
1686
1687
1688 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1689     """Return the first leaf that precedes `node`, if any."""
1690     while node:
1691         res = node.prev_sibling
1692         if res:
1693             if isinstance(res, Leaf):
1694                 return res
1695
1696             try:
1697                 return list(res.leaves())[-1]
1698
1699             except IndexError:
1700                 return None
1701
1702         node = node.parent
1703     return None
1704
1705
1706 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1707     """Return the child of `ancestor` that contains `descendant`."""
1708     node: Optional[LN] = descendant
1709     while node and node.parent != ancestor:
1710         node = node.parent
1711     return node
1712
1713
1714 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1715     """Return the priority of the `leaf` delimiter, given a line break after it.
1716
1717     The delimiter priorities returned here are from those delimiters that would
1718     cause a line break after themselves.
1719
1720     Higher numbers are higher priority.
1721     """
1722     if leaf.type == token.COMMA:
1723         return COMMA_PRIORITY
1724
1725     return 0
1726
1727
1728 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1729     """Return the priority of the `leaf` delimiter, given a line before after it.
1730
1731     The delimiter priorities returned here are from those delimiters that would
1732     cause a line break before themselves.
1733
1734     Higher numbers are higher priority.
1735     """
1736     if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1737         # * and ** might also be MATH_OPERATORS but in this case they are not.
1738         # Don't treat them as a delimiter.
1739         return 0
1740
1741     if (
1742         leaf.type == token.DOT
1743         and leaf.parent
1744         and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1745         and (previous is None or previous.type in CLOSING_BRACKETS)
1746     ):
1747         return DOT_PRIORITY
1748
1749     if (
1750         leaf.type in MATH_OPERATORS
1751         and leaf.parent
1752         and leaf.parent.type not in {syms.factor, syms.star_expr}
1753     ):
1754         return MATH_PRIORITIES[leaf.type]
1755
1756     if leaf.type in COMPARATORS:
1757         return COMPARATOR_PRIORITY
1758
1759     if (
1760         leaf.type == token.STRING
1761         and previous is not None
1762         and previous.type == token.STRING
1763     ):
1764         return STRING_PRIORITY
1765
1766     if leaf.type != token.NAME:
1767         return 0
1768
1769     if (
1770         leaf.value == "for"
1771         and leaf.parent
1772         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1773     ):
1774         return COMPREHENSION_PRIORITY
1775
1776     if (
1777         leaf.value == "if"
1778         and leaf.parent
1779         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1780     ):
1781         return COMPREHENSION_PRIORITY
1782
1783     if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1784         return TERNARY_PRIORITY
1785
1786     if leaf.value == "is":
1787         return COMPARATOR_PRIORITY
1788
1789     if (
1790         leaf.value == "in"
1791         and leaf.parent
1792         and leaf.parent.type in {syms.comp_op, syms.comparison}
1793         and not (
1794             previous is not None
1795             and previous.type == token.NAME
1796             and previous.value == "not"
1797         )
1798     ):
1799         return COMPARATOR_PRIORITY
1800
1801     if (
1802         leaf.value == "not"
1803         and leaf.parent
1804         and leaf.parent.type == syms.comp_op
1805         and not (
1806             previous is not None
1807             and previous.type == token.NAME
1808             and previous.value == "is"
1809         )
1810     ):
1811         return COMPARATOR_PRIORITY
1812
1813     if leaf.value in LOGIC_OPERATORS and leaf.parent:
1814         return LOGIC_PRIORITY
1815
1816     return 0
1817
1818
1819 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1820     """Clean the prefix of the `leaf` and generate comments from it, if any.
1821
1822     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1823     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1824     move because it does away with modifying the grammar to include all the
1825     possible places in which comments can be placed.
1826
1827     The sad consequence for us though is that comments don't "belong" anywhere.
1828     This is why this function generates simple parentless Leaf objects for
1829     comments.  We simply don't know what the correct parent should be.
1830
1831     No matter though, we can live without this.  We really only need to
1832     differentiate between inline and standalone comments.  The latter don't
1833     share the line with any code.
1834
1835     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1836     are emitted with a fake STANDALONE_COMMENT token identifier.
1837     """
1838     p = leaf.prefix
1839     if not p:
1840         return
1841
1842     if "#" not in p:
1843         return
1844
1845     consumed = 0
1846     nlines = 0
1847     for index, line in enumerate(p.split("\n")):
1848         consumed += len(line) + 1  # adding the length of the split '\n'
1849         line = line.lstrip()
1850         if not line:
1851             nlines += 1
1852         if not line.startswith("#"):
1853             continue
1854
1855         if index == 0 and leaf.type != token.ENDMARKER:
1856             comment_type = token.COMMENT  # simple trailing comment
1857         else:
1858             comment_type = STANDALONE_COMMENT
1859         comment = make_comment(line)
1860         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1861
1862         if comment in {"# fmt: on", "# yapf: enable"}:
1863             raise FormatOn(consumed)
1864
1865         if comment in {"# fmt: off", "# yapf: disable"}:
1866             if comment_type == STANDALONE_COMMENT:
1867                 raise FormatOff(consumed)
1868
1869             prev = preceding_leaf(leaf)
1870             if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
1871                 raise FormatOff(consumed)
1872
1873         nlines = 0
1874
1875
1876 def make_comment(content: str) -> str:
1877     """Return a consistently formatted comment from the given `content` string.
1878
1879     All comments (except for "##", "#!", "#:") should have a single space between
1880     the hash sign and the content.
1881
1882     If `content` didn't start with a hash sign, one is provided.
1883     """
1884     content = content.rstrip()
1885     if not content:
1886         return "#"
1887
1888     if content[0] == "#":
1889         content = content[1:]
1890     if content and content[0] not in " !:#":
1891         content = " " + content
1892     return "#" + content
1893
1894
1895 def split_line(
1896     line: Line, line_length: int, inner: bool = False, py36: bool = False
1897 ) -> Iterator[Line]:
1898     """Split a `line` into potentially many lines.
1899
1900     They should fit in the allotted `line_length` but might not be able to.
1901     `inner` signifies that there were a pair of brackets somewhere around the
1902     current `line`, possibly transitively. This means we can fallback to splitting
1903     by delimiters if the LHS/RHS don't yield any results.
1904
1905     If `py36` is True, splitting may generate syntax that is only compatible
1906     with Python 3.6 and later.
1907     """
1908     if isinstance(line, UnformattedLines) or line.is_comment:
1909         yield line
1910         return
1911
1912     line_str = str(line).strip("\n")
1913     if not line.should_explode and is_line_short_enough(
1914         line, line_length=line_length, line_str=line_str
1915     ):
1916         yield line
1917         return
1918
1919     split_funcs: List[SplitFunc]
1920     if line.is_def:
1921         split_funcs = [left_hand_split]
1922     else:
1923
1924         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
1925             for omit in generate_trailers_to_omit(line, line_length):
1926                 lines = list(right_hand_split(line, line_length, py36, omit=omit))
1927                 if is_line_short_enough(lines[0], line_length=line_length):
1928                     yield from lines
1929                     return
1930
1931             # All splits failed, best effort split with no omits.
1932             # This mostly happens to multiline strings that are by definition
1933             # reported as not fitting a single line.
1934             yield from right_hand_split(line, py36)
1935
1936         if line.inside_brackets:
1937             split_funcs = [delimiter_split, standalone_comment_split, rhs]
1938         else:
1939             split_funcs = [rhs]
1940     for split_func in split_funcs:
1941         # We are accumulating lines in `result` because we might want to abort
1942         # mission and return the original line in the end, or attempt a different
1943         # split altogether.
1944         result: List[Line] = []
1945         try:
1946             for l in split_func(line, py36):
1947                 if str(l).strip("\n") == line_str:
1948                     raise CannotSplit("Split function returned an unchanged result")
1949
1950                 result.extend(
1951                     split_line(l, line_length=line_length, inner=True, py36=py36)
1952                 )
1953         except CannotSplit as cs:
1954             continue
1955
1956         else:
1957             yield from result
1958             break
1959
1960     else:
1961         yield line
1962
1963
1964 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1965     """Split line into many lines, starting with the first matching bracket pair.
1966
1967     Note: this usually looks weird, only use this for function definitions.
1968     Prefer RHS otherwise.  This is why this function is not symmetrical with
1969     :func:`right_hand_split` which also handles optional parentheses.
1970     """
1971     head = Line(depth=line.depth)
1972     body = Line(depth=line.depth + 1, inside_brackets=True)
1973     tail = Line(depth=line.depth)
1974     tail_leaves: List[Leaf] = []
1975     body_leaves: List[Leaf] = []
1976     head_leaves: List[Leaf] = []
1977     current_leaves = head_leaves
1978     matching_bracket = None
1979     for leaf in line.leaves:
1980         if (
1981             current_leaves is body_leaves
1982             and leaf.type in CLOSING_BRACKETS
1983             and leaf.opening_bracket is matching_bracket
1984         ):
1985             current_leaves = tail_leaves if body_leaves else head_leaves
1986         current_leaves.append(leaf)
1987         if current_leaves is head_leaves:
1988             if leaf.type in OPENING_BRACKETS:
1989                 matching_bracket = leaf
1990                 current_leaves = body_leaves
1991     # Since body is a new indent level, remove spurious leading whitespace.
1992     if body_leaves:
1993         normalize_prefix(body_leaves[0], inside_brackets=True)
1994     # Build the new lines.
1995     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1996         for leaf in leaves:
1997             result.append(leaf, preformatted=True)
1998             for comment_after in line.comments_after(leaf):
1999                 result.append(comment_after, preformatted=True)
2000     bracket_split_succeeded_or_raise(head, body, tail)
2001     for result in (head, body, tail):
2002         if result:
2003             yield result
2004
2005
2006 def right_hand_split(
2007     line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
2008 ) -> Iterator[Line]:
2009     """Split line into many lines, starting with the last matching bracket pair.
2010
2011     If the split was by optional parentheses, attempt splitting without them, too.
2012     `omit` is a collection of closing bracket IDs that shouldn't be considered for
2013     this split.
2014
2015     Note: running this function modifies `bracket_depth` on the leaves of `line`.
2016     """
2017     head = Line(depth=line.depth)
2018     body = Line(depth=line.depth + 1, inside_brackets=True)
2019     tail = Line(depth=line.depth)
2020     tail_leaves: List[Leaf] = []
2021     body_leaves: List[Leaf] = []
2022     head_leaves: List[Leaf] = []
2023     current_leaves = tail_leaves
2024     opening_bracket = None
2025     closing_bracket = None
2026     for leaf in reversed(line.leaves):
2027         if current_leaves is body_leaves:
2028             if leaf is opening_bracket:
2029                 current_leaves = head_leaves if body_leaves else tail_leaves
2030         current_leaves.append(leaf)
2031         if current_leaves is tail_leaves:
2032             if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2033                 opening_bracket = leaf.opening_bracket
2034                 closing_bracket = leaf
2035                 current_leaves = body_leaves
2036     tail_leaves.reverse()
2037     body_leaves.reverse()
2038     head_leaves.reverse()
2039     # Since body is a new indent level, remove spurious leading whitespace.
2040     if body_leaves:
2041         normalize_prefix(body_leaves[0], inside_brackets=True)
2042     if not head_leaves:
2043         # No `head` means the split failed. Either `tail` has all content or
2044         # the matching `opening_bracket` wasn't available on `line` anymore.
2045         raise CannotSplit("No brackets found")
2046
2047     # Build the new lines.
2048     for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2049         for leaf in leaves:
2050             result.append(leaf, preformatted=True)
2051             for comment_after in line.comments_after(leaf):
2052                 result.append(comment_after, preformatted=True)
2053     bracket_split_succeeded_or_raise(head, body, tail)
2054     assert opening_bracket and closing_bracket
2055     if (
2056         # the opening bracket is an optional paren
2057         opening_bracket.type == token.LPAR
2058         and not opening_bracket.value
2059         # the closing bracket is an optional paren
2060         and closing_bracket.type == token.RPAR
2061         and not closing_bracket.value
2062         # there are no standalone comments in the body
2063         and not line.contains_standalone_comments(0)
2064         # and it's not an import (optional parens are the only thing we can split
2065         # on in this case; attempting a split without them is a waste of time)
2066         and not line.is_import
2067     ):
2068         omit = {id(closing_bracket), *omit}
2069         if can_omit_invisible_parens(body, line_length):
2070             try:
2071                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
2072                 return
2073             except CannotSplit:
2074                 pass
2075
2076     ensure_visible(opening_bracket)
2077     ensure_visible(closing_bracket)
2078     body.should_explode = should_explode(body, opening_bracket)
2079     for result in (head, body, tail):
2080         if result:
2081             yield result
2082
2083
2084 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2085     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2086
2087     Do nothing otherwise.
2088
2089     A left- or right-hand split is based on a pair of brackets. Content before
2090     (and including) the opening bracket is left on one line, content inside the
2091     brackets is put on a separate line, and finally content starting with and
2092     following the closing bracket is put on a separate line.
2093
2094     Those are called `head`, `body`, and `tail`, respectively. If the split
2095     produced the same line (all content in `head`) or ended up with an empty `body`
2096     and the `tail` is just the closing bracket, then it's considered failed.
2097     """
2098     tail_len = len(str(tail).strip())
2099     if not body:
2100         if tail_len == 0:
2101             raise CannotSplit("Splitting brackets produced the same line")
2102
2103         elif tail_len < 3:
2104             raise CannotSplit(
2105                 f"Splitting brackets on an empty body to save "
2106                 f"{tail_len} characters is not worth it"
2107             )
2108
2109
2110 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2111     """Normalize prefix of the first leaf in every line returned by `split_func`.
2112
2113     This is a decorator over relevant split functions.
2114     """
2115
2116     @wraps(split_func)
2117     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2118         for l in split_func(line, py36):
2119             normalize_prefix(l.leaves[0], inside_brackets=True)
2120             yield l
2121
2122     return split_wrapper
2123
2124
2125 @dont_increase_indentation
2126 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2127     """Split according to delimiters of the highest priority.
2128
2129     If `py36` is True, the split will add trailing commas also in function
2130     signatures that contain `*` and `**`.
2131     """
2132     try:
2133         last_leaf = line.leaves[-1]
2134     except IndexError:
2135         raise CannotSplit("Line empty")
2136
2137     bt = line.bracket_tracker
2138     try:
2139         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2140     except ValueError:
2141         raise CannotSplit("No delimiters found")
2142
2143     if delimiter_priority == DOT_PRIORITY:
2144         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2145             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2146
2147     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2148     lowest_depth = sys.maxsize
2149     trailing_comma_safe = True
2150
2151     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2152         """Append `leaf` to current line or to new line if appending impossible."""
2153         nonlocal current_line
2154         try:
2155             current_line.append_safe(leaf, preformatted=True)
2156         except ValueError as ve:
2157             yield current_line
2158
2159             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2160             current_line.append(leaf)
2161
2162     for index, leaf in enumerate(line.leaves):
2163         yield from append_to_line(leaf)
2164
2165         for comment_after in line.comments_after(leaf, index):
2166             yield from append_to_line(comment_after)
2167
2168         lowest_depth = min(lowest_depth, leaf.bracket_depth)
2169         if leaf.bracket_depth == lowest_depth and is_vararg(
2170             leaf, within=VARARGS_PARENTS
2171         ):
2172             trailing_comma_safe = trailing_comma_safe and py36
2173         leaf_priority = bt.delimiters.get(id(leaf))
2174         if leaf_priority == delimiter_priority:
2175             yield current_line
2176
2177             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2178     if current_line:
2179         if (
2180             trailing_comma_safe
2181             and delimiter_priority == COMMA_PRIORITY
2182             and current_line.leaves[-1].type != token.COMMA
2183             and current_line.leaves[-1].type != STANDALONE_COMMENT
2184         ):
2185             current_line.append(Leaf(token.COMMA, ","))
2186         yield current_line
2187
2188
2189 @dont_increase_indentation
2190 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2191     """Split standalone comments from the rest of the line."""
2192     if not line.contains_standalone_comments(0):
2193         raise CannotSplit("Line does not have any standalone comments")
2194
2195     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2196
2197     def append_to_line(leaf: Leaf) -> Iterator[Line]:
2198         """Append `leaf` to current line or to new line if appending impossible."""
2199         nonlocal current_line
2200         try:
2201             current_line.append_safe(leaf, preformatted=True)
2202         except ValueError as ve:
2203             yield current_line
2204
2205             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2206             current_line.append(leaf)
2207
2208     for index, leaf in enumerate(line.leaves):
2209         yield from append_to_line(leaf)
2210
2211         for comment_after in line.comments_after(leaf, index):
2212             yield from append_to_line(comment_after)
2213
2214     if current_line:
2215         yield current_line
2216
2217
2218 def is_import(leaf: Leaf) -> bool:
2219     """Return True if the given leaf starts an import statement."""
2220     p = leaf.parent
2221     t = leaf.type
2222     v = leaf.value
2223     return bool(
2224         t == token.NAME
2225         and (
2226             (v == "import" and p and p.type == syms.import_name)
2227             or (v == "from" and p and p.type == syms.import_from)
2228         )
2229     )
2230
2231
2232 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2233     """Leave existing extra newlines if not `inside_brackets`. Remove everything
2234     else.
2235
2236     Note: don't use backslashes for formatting or you'll lose your voting rights.
2237     """
2238     if not inside_brackets:
2239         spl = leaf.prefix.split("#")
2240         if "\\" not in spl[0]:
2241             nl_count = spl[-1].count("\n")
2242             if len(spl) > 1:
2243                 nl_count -= 1
2244             leaf.prefix = "\n" * nl_count
2245             return
2246
2247     leaf.prefix = ""
2248
2249
2250 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2251     """Make all string prefixes lowercase.
2252
2253     If remove_u_prefix is given, also removes any u prefix from the string.
2254
2255     Note: Mutates its argument.
2256     """
2257     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2258     assert match is not None, f"failed to match string {leaf.value!r}"
2259     orig_prefix = match.group(1)
2260     new_prefix = orig_prefix.lower()
2261     if remove_u_prefix:
2262         new_prefix = new_prefix.replace("u", "")
2263     leaf.value = f"{new_prefix}{match.group(2)}"
2264
2265
2266 def normalize_string_quotes(leaf: Leaf) -> None:
2267     """Prefer double quotes but only if it doesn't cause more escaping.
2268
2269     Adds or removes backslashes as appropriate. Doesn't parse and fix
2270     strings nested in f-strings (yet).
2271
2272     Note: Mutates its argument.
2273     """
2274     value = leaf.value.lstrip("furbFURB")
2275     if value[:3] == '"""':
2276         return
2277
2278     elif value[:3] == "'''":
2279         orig_quote = "'''"
2280         new_quote = '"""'
2281     elif value[0] == '"':
2282         orig_quote = '"'
2283         new_quote = "'"
2284     else:
2285         orig_quote = "'"
2286         new_quote = '"'
2287     first_quote_pos = leaf.value.find(orig_quote)
2288     if first_quote_pos == -1:
2289         return  # There's an internal error
2290
2291     prefix = leaf.value[:first_quote_pos]
2292     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2293     escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2294     escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2295     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2296     if "r" in prefix.casefold():
2297         if unescaped_new_quote.search(body):
2298             # There's at least one unescaped new_quote in this raw string
2299             # so converting is impossible
2300             return
2301
2302         # Do not introduce or remove backslashes in raw strings
2303         new_body = body
2304     else:
2305         # remove unnecessary quotes
2306         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2307         if body != new_body:
2308             # Consider the string without unnecessary quotes as the original
2309             body = new_body
2310             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2311         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2312         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2313     if new_quote == '"""' and new_body[-1] == '"':
2314         # edge case:
2315         new_body = new_body[:-1] + '\\"'
2316     orig_escape_count = body.count("\\")
2317     new_escape_count = new_body.count("\\")
2318     if new_escape_count > orig_escape_count:
2319         return  # Do not introduce more escaping
2320
2321     if new_escape_count == orig_escape_count and orig_quote == '"':
2322         return  # Prefer double quotes
2323
2324     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2325
2326
2327 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2328     """Make existing optional parentheses invisible or create new ones.
2329
2330     `parens_after` is a set of string leaf values immeditely after which parens
2331     should be put.
2332
2333     Standardizes on visible parentheses for single-element tuples, and keeps
2334     existing visible parentheses for other tuples and generator expressions.
2335     """
2336     check_lpar = False
2337     for child in list(node.children):
2338         if check_lpar:
2339             if child.type == syms.atom:
2340                 maybe_make_parens_invisible_in_atom(child)
2341             elif is_one_tuple(child):
2342                 # wrap child in visible parentheses
2343                 lpar = Leaf(token.LPAR, "(")
2344                 rpar = Leaf(token.RPAR, ")")
2345                 index = child.remove() or 0
2346                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2347             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2348                 # wrap child in invisible parentheses
2349                 lpar = Leaf(token.LPAR, "")
2350                 rpar = Leaf(token.RPAR, "")
2351                 index = child.remove() or 0
2352                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2353
2354         check_lpar = isinstance(child, Leaf) and child.value in parens_after
2355
2356
2357 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2358     """If it's safe, make the parens in the atom `node` invisible, recusively."""
2359     if (
2360         node.type != syms.atom
2361         or is_empty_tuple(node)
2362         or is_one_tuple(node)
2363         or is_yield(node)
2364         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2365     ):
2366         return False
2367
2368     first = node.children[0]
2369     last = node.children[-1]
2370     if first.type == token.LPAR and last.type == token.RPAR:
2371         # make parentheses invisible
2372         first.value = ""  # type: ignore
2373         last.value = ""  # type: ignore
2374         if len(node.children) > 1:
2375             maybe_make_parens_invisible_in_atom(node.children[1])
2376         return True
2377
2378     return False
2379
2380
2381 def is_empty_tuple(node: LN) -> bool:
2382     """Return True if `node` holds an empty tuple."""
2383     return (
2384         node.type == syms.atom
2385         and len(node.children) == 2
2386         and node.children[0].type == token.LPAR
2387         and node.children[1].type == token.RPAR
2388     )
2389
2390
2391 def is_one_tuple(node: LN) -> bool:
2392     """Return True if `node` holds a tuple with one element, with or without parens."""
2393     if node.type == syms.atom:
2394         if len(node.children) != 3:
2395             return False
2396
2397         lpar, gexp, rpar = node.children
2398         if not (
2399             lpar.type == token.LPAR
2400             and gexp.type == syms.testlist_gexp
2401             and rpar.type == token.RPAR
2402         ):
2403             return False
2404
2405         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2406
2407     return (
2408         node.type in IMPLICIT_TUPLE
2409         and len(node.children) == 2
2410         and node.children[1].type == token.COMMA
2411     )
2412
2413
2414 def is_yield(node: LN) -> bool:
2415     """Return True if `node` holds a `yield` or `yield from` expression."""
2416     if node.type == syms.yield_expr:
2417         return True
2418
2419     if node.type == token.NAME and node.value == "yield":  # type: ignore
2420         return True
2421
2422     if node.type != syms.atom:
2423         return False
2424
2425     if len(node.children) != 3:
2426         return False
2427
2428     lpar, expr, rpar = node.children
2429     if lpar.type == token.LPAR and rpar.type == token.RPAR:
2430         return is_yield(expr)
2431
2432     return False
2433
2434
2435 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2436     """Return True if `leaf` is a star or double star in a vararg or kwarg.
2437
2438     If `within` includes VARARGS_PARENTS, this applies to function signatures.
2439     If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2440     extended iterable unpacking (PEP 3132) and additional unpacking
2441     generalizations (PEP 448).
2442     """
2443     if leaf.type not in STARS or not leaf.parent:
2444         return False
2445
2446     p = leaf.parent
2447     if p.type == syms.star_expr:
2448         # Star expressions are also used as assignment targets in extended
2449         # iterable unpacking (PEP 3132).  See what its parent is instead.
2450         if not p.parent:
2451             return False
2452
2453         p = p.parent
2454
2455     return p.type in within
2456
2457
2458 def is_multiline_string(leaf: Leaf) -> bool:
2459     """Return True if `leaf` is a multiline string that actually spans many lines."""
2460     value = leaf.value.lstrip("furbFURB")
2461     return value[:3] in {'"""', "'''"} and "\n" in value
2462
2463
2464 def is_stub_suite(node: Node) -> bool:
2465     """Return True if `node` is a suite with a stub body."""
2466     if (
2467         len(node.children) != 4
2468         or node.children[0].type != token.NEWLINE
2469         or node.children[1].type != token.INDENT
2470         or node.children[3].type != token.DEDENT
2471     ):
2472         return False
2473
2474     return is_stub_body(node.children[2])
2475
2476
2477 def is_stub_body(node: LN) -> bool:
2478     """Return True if `node` is a simple statement containing an ellipsis."""
2479     if not isinstance(node, Node) or node.type != syms.simple_stmt:
2480         return False
2481
2482     if len(node.children) != 2:
2483         return False
2484
2485     child = node.children[0]
2486     return (
2487         child.type == syms.atom
2488         and len(child.children) == 3
2489         and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2490     )
2491
2492
2493 def max_delimiter_priority_in_atom(node: LN) -> int:
2494     """Return maximum delimiter priority inside `node`.
2495
2496     This is specific to atoms with contents contained in a pair of parentheses.
2497     If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2498     """
2499     if node.type != syms.atom:
2500         return 0
2501
2502     first = node.children[0]
2503     last = node.children[-1]
2504     if not (first.type == token.LPAR and last.type == token.RPAR):
2505         return 0
2506
2507     bt = BracketTracker()
2508     for c in node.children[1:-1]:
2509         if isinstance(c, Leaf):
2510             bt.mark(c)
2511         else:
2512             for leaf in c.leaves():
2513                 bt.mark(leaf)
2514     try:
2515         return bt.max_delimiter_priority()
2516
2517     except ValueError:
2518         return 0
2519
2520
2521 def ensure_visible(leaf: Leaf) -> None:
2522     """Make sure parentheses are visible.
2523
2524     They could be invisible as part of some statements (see
2525     :func:`normalize_invible_parens` and :func:`visit_import_from`).
2526     """
2527     if leaf.type == token.LPAR:
2528         leaf.value = "("
2529     elif leaf.type == token.RPAR:
2530         leaf.value = ")"
2531
2532
2533 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
2534     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
2535     if not (
2536         opening_bracket.parent
2537         and opening_bracket.parent.type in {syms.atom, syms.import_from}
2538         and opening_bracket.value in "[{("
2539     ):
2540         return False
2541
2542     try:
2543         last_leaf = line.leaves[-1]
2544         exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
2545         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
2546     except (IndexError, ValueError):
2547         return False
2548
2549     return max_priority == COMMA_PRIORITY
2550
2551
2552 def is_python36(node: Node) -> bool:
2553     """Return True if the current file is using Python 3.6+ features.
2554
2555     Currently looking for:
2556     - f-strings; and
2557     - trailing commas after * or ** in function signatures and calls.
2558     """
2559     for n in node.pre_order():
2560         if n.type == token.STRING:
2561             value_head = n.value[:2]  # type: ignore
2562             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2563                 return True
2564
2565         elif (
2566             n.type in {syms.typedargslist, syms.arglist}
2567             and n.children
2568             and n.children[-1].type == token.COMMA
2569         ):
2570             for ch in n.children:
2571                 if ch.type in STARS:
2572                     return True
2573
2574                 if ch.type == syms.argument:
2575                     for argch in ch.children:
2576                         if argch.type in STARS:
2577                             return True
2578
2579     return False
2580
2581
2582 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2583     """Generate sets of closing bracket IDs that should be omitted in a RHS.
2584
2585     Brackets can be omitted if the entire trailer up to and including
2586     a preceding closing bracket fits in one line.
2587
2588     Yielded sets are cumulative (contain results of previous yields, too).  First
2589     set is empty.
2590     """
2591
2592     omit: Set[LeafID] = set()
2593     yield omit
2594
2595     length = 4 * line.depth
2596     opening_bracket = None
2597     closing_bracket = None
2598     optional_brackets: Set[LeafID] = set()
2599     inner_brackets: Set[LeafID] = set()
2600     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
2601         length += leaf_length
2602         if length > line_length:
2603             break
2604
2605         optional_brackets.discard(id(leaf))
2606         if opening_bracket:
2607             if leaf is opening_bracket:
2608                 opening_bracket = None
2609             elif leaf.type in CLOSING_BRACKETS:
2610                 inner_brackets.add(id(leaf))
2611         elif leaf.type in CLOSING_BRACKETS:
2612             if not leaf.value:
2613                 optional_brackets.add(id(opening_bracket))
2614                 continue
2615
2616             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2617                 # Empty brackets would fail a split so treat them as "inner"
2618                 # brackets (e.g. only add them to the `omit` set if another
2619                 # pair of brackets was good enough.
2620                 inner_brackets.add(id(leaf))
2621                 continue
2622
2623             opening_bracket = leaf.opening_bracket
2624             if closing_bracket:
2625                 omit.add(id(closing_bracket))
2626                 omit.update(inner_brackets)
2627                 inner_brackets.clear()
2628                 yield omit
2629             closing_bracket = leaf
2630
2631
2632 def get_future_imports(node: Node) -> Set[str]:
2633     """Return a set of __future__ imports in the file."""
2634     imports = set()
2635     for child in node.children:
2636         if child.type != syms.simple_stmt:
2637             break
2638         first_child = child.children[0]
2639         if isinstance(first_child, Leaf):
2640             # Continue looking if we see a docstring; otherwise stop.
2641             if (
2642                 len(child.children) == 2
2643                 and first_child.type == token.STRING
2644                 and child.children[1].type == token.NEWLINE
2645             ):
2646                 continue
2647             else:
2648                 break
2649         elif first_child.type == syms.import_from:
2650             module_name = first_child.children[1]
2651             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2652                 break
2653             for import_from_child in first_child.children[3:]:
2654                 if isinstance(import_from_child, Leaf):
2655                     if import_from_child.type == token.NAME:
2656                         imports.add(import_from_child.value)
2657                 else:
2658                     assert import_from_child.type == syms.import_as_names
2659                     for leaf in import_from_child.children:
2660                         if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2661                             imports.add(leaf.value)
2662         else:
2663             break
2664     return imports
2665
2666
2667 PYTHON_EXTENSIONS = {".py", ".pyi"}
2668 BLACKLISTED_DIRECTORIES = {
2669     "build",
2670     "buck-out",
2671     "dist",
2672     "_build",
2673     ".git",
2674     ".hg",
2675     ".mypy_cache",
2676     ".tox",
2677     ".venv",
2678 }
2679
2680
2681 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2682     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2683     and have one of the PYTHON_EXTENSIONS.
2684     """
2685     for child in path.iterdir():
2686         if child.is_dir():
2687             if child.name in BLACKLISTED_DIRECTORIES:
2688                 continue
2689
2690             yield from gen_python_files_in_dir(child)
2691
2692         elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
2693             yield child
2694
2695
2696 @dataclass
2697 class Report:
2698     """Provides a reformatting counter. Can be rendered with `str(report)`."""
2699     check: bool = False
2700     quiet: bool = False
2701     change_count: int = 0
2702     same_count: int = 0
2703     failure_count: int = 0
2704
2705     def done(self, src: Path, changed: Changed) -> None:
2706         """Increment the counter for successful reformatting. Write out a message."""
2707         if changed is Changed.YES:
2708             reformatted = "would reformat" if self.check else "reformatted"
2709             if not self.quiet:
2710                 out(f"{reformatted} {src}")
2711             self.change_count += 1
2712         else:
2713             if not self.quiet:
2714                 if changed is Changed.NO:
2715                     msg = f"{src} already well formatted, good job."
2716                 else:
2717                     msg = f"{src} wasn't modified on disk since last run."
2718                 out(msg, bold=False)
2719             self.same_count += 1
2720
2721     def failed(self, src: Path, message: str) -> None:
2722         """Increment the counter for failed reformatting. Write out a message."""
2723         err(f"error: cannot format {src}: {message}")
2724         self.failure_count += 1
2725
2726     @property
2727     def return_code(self) -> int:
2728         """Return the exit code that the app should use.
2729
2730         This considers the current state of changed files and failures:
2731         - if there were any failures, return 123;
2732         - if any files were changed and --check is being used, return 1;
2733         - otherwise return 0.
2734         """
2735         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2736         # 126 we have special returncodes reserved by the shell.
2737         if self.failure_count:
2738             return 123
2739
2740         elif self.change_count and self.check:
2741             return 1
2742
2743         return 0
2744
2745     def __str__(self) -> str:
2746         """Render a color report of the current state.
2747
2748         Use `click.unstyle` to remove colors.
2749         """
2750         if self.check:
2751             reformatted = "would be reformatted"
2752             unchanged = "would be left unchanged"
2753             failed = "would fail to reformat"
2754         else:
2755             reformatted = "reformatted"
2756             unchanged = "left unchanged"
2757             failed = "failed to reformat"
2758         report = []
2759         if self.change_count:
2760             s = "s" if self.change_count > 1 else ""
2761             report.append(
2762                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2763             )
2764         if self.same_count:
2765             s = "s" if self.same_count > 1 else ""
2766             report.append(f"{self.same_count} file{s} {unchanged}")
2767         if self.failure_count:
2768             s = "s" if self.failure_count > 1 else ""
2769             report.append(
2770                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2771             )
2772         return ", ".join(report) + "."
2773
2774
2775 def assert_equivalent(src: str, dst: str) -> None:
2776     """Raise AssertionError if `src` and `dst` aren't equivalent."""
2777
2778     import ast
2779     import traceback
2780
2781     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2782         """Simple visitor generating strings to compare ASTs by content."""
2783         yield f"{'  ' * depth}{node.__class__.__name__}("
2784
2785         for field in sorted(node._fields):
2786             try:
2787                 value = getattr(node, field)
2788             except AttributeError:
2789                 continue
2790
2791             yield f"{'  ' * (depth+1)}{field}="
2792
2793             if isinstance(value, list):
2794                 for item in value:
2795                     if isinstance(item, ast.AST):
2796                         yield from _v(item, depth + 2)
2797
2798             elif isinstance(value, ast.AST):
2799                 yield from _v(value, depth + 2)
2800
2801             else:
2802                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2803
2804         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2805
2806     try:
2807         src_ast = ast.parse(src)
2808     except Exception as exc:
2809         major, minor = sys.version_info[:2]
2810         raise AssertionError(
2811             f"cannot use --safe with this file; failed to parse source file "
2812             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2813             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2814         )
2815
2816     try:
2817         dst_ast = ast.parse(dst)
2818     except Exception as exc:
2819         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2820         raise AssertionError(
2821             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2822             f"Please report a bug on https://github.com/ambv/black/issues.  "
2823             f"This invalid output might be helpful: {log}"
2824         ) from None
2825
2826     src_ast_str = "\n".join(_v(src_ast))
2827     dst_ast_str = "\n".join(_v(dst_ast))
2828     if src_ast_str != dst_ast_str:
2829         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2830         raise AssertionError(
2831             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2832             f"the source.  "
2833             f"Please report a bug on https://github.com/ambv/black/issues.  "
2834             f"This diff might be helpful: {log}"
2835         ) from None
2836
2837
2838 def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
2839     """Raise AssertionError if `dst` reformats differently the second time."""
2840     newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
2841     if dst != newdst:
2842         log = dump_to_file(
2843             diff(src, dst, "source", "first pass"),
2844             diff(dst, newdst, "first pass", "second pass"),
2845         )
2846         raise AssertionError(
2847             f"INTERNAL ERROR: Black produced different code on the second pass "
2848             f"of the formatter.  "
2849             f"Please report a bug on https://github.com/ambv/black/issues.  "
2850             f"This diff might be helpful: {log}"
2851         ) from None
2852
2853
2854 def dump_to_file(*output: str) -> str:
2855     """Dump `output` to a temporary file. Return path to the file."""
2856     import tempfile
2857
2858     with tempfile.NamedTemporaryFile(
2859         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2860     ) as f:
2861         for lines in output:
2862             f.write(lines)
2863             if lines and lines[-1] != "\n":
2864                 f.write("\n")
2865     return f.name
2866
2867
2868 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2869     """Return a unified diff string between strings `a` and `b`."""
2870     import difflib
2871
2872     a_lines = [line + "\n" for line in a.split("\n")]
2873     b_lines = [line + "\n" for line in b.split("\n")]
2874     return "".join(
2875         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2876     )
2877
2878
2879 def cancel(tasks: Iterable[asyncio.Task]) -> None:
2880     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2881     err("Aborted!")
2882     for task in tasks:
2883         task.cancel()
2884
2885
2886 def shutdown(loop: BaseEventLoop) -> None:
2887     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2888     try:
2889         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2890         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2891         if not to_cancel:
2892             return
2893
2894         for task in to_cancel:
2895             task.cancel()
2896         loop.run_until_complete(
2897             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2898         )
2899     finally:
2900         # `concurrent.futures.Future` objects cannot be cancelled once they
2901         # are already running. There might be some when the `shutdown()` happened.
2902         # Silence their logger's spew about the event loop being closed.
2903         cf_logger = logging.getLogger("concurrent.futures")
2904         cf_logger.setLevel(logging.CRITICAL)
2905         loop.close()
2906
2907
2908 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2909     """Replace `regex` with `replacement` twice on `original`.
2910
2911     This is used by string normalization to perform replaces on
2912     overlapping matches.
2913     """
2914     return regex.sub(replacement, regex.sub(replacement, original))
2915
2916
2917 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
2918     """Like `reversed(enumerate(sequence))` if that were possible."""
2919     index = len(sequence) - 1
2920     for element in reversed(sequence):
2921         yield (index, element)
2922         index -= 1
2923
2924
2925 def enumerate_with_length(
2926     line: Line, reversed: bool = False
2927 ) -> Iterator[Tuple[Index, Leaf, int]]:
2928     """Return an enumeration of leaves with their length.
2929
2930     Stops prematurely on multiline strings and standalone comments.
2931     """
2932     op = cast(
2933         Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
2934         enumerate_reversed if reversed else enumerate,
2935     )
2936     for index, leaf in op(line.leaves):
2937         length = len(leaf.prefix) + len(leaf.value)
2938         if "\n" in leaf.value:
2939             return  # Multiline strings, we can't continue.
2940
2941         comment: Optional[Leaf]
2942         for comment in line.comments_after(leaf, index):
2943             if "\n" in comment.prefix:
2944                 return  # Oops, standalone comment!
2945
2946             length += len(comment.value)
2947
2948         yield index, leaf, length
2949
2950
2951 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
2952     """Return True if `line` is no longer than `line_length`.
2953
2954     Uses the provided `line_str` rendering, if any, otherwise computes a new one.
2955     """
2956     if not line_str:
2957         line_str = str(line).strip("\n")
2958     return (
2959         len(line_str) <= line_length
2960         and "\n" not in line_str  # multiline strings
2961         and not line.contains_standalone_comments()
2962     )
2963
2964
2965 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
2966     """Does `line` have a shape safe to reformat without optional parens around it?
2967
2968     Returns True for only a subset of potentially nice looking formattings but
2969     the point is to not return false positives that end up producing lines that
2970     are too long.
2971     """
2972     bt = line.bracket_tracker
2973     if not bt.delimiters:
2974         # Without delimiters the optional parentheses are useless.
2975         return True
2976
2977     max_priority = bt.max_delimiter_priority()
2978     if bt.delimiter_count_with_priority(max_priority) > 1:
2979         # With more than one delimiter of a kind the optional parentheses read better.
2980         return False
2981
2982     if max_priority == DOT_PRIORITY:
2983         # A single stranded method call doesn't require optional parentheses.
2984         return True
2985
2986     assert len(line.leaves) >= 2, "Stranded delimiter"
2987
2988     first = line.leaves[0]
2989     second = line.leaves[1]
2990     penultimate = line.leaves[-2]
2991     last = line.leaves[-1]
2992
2993     # With a single delimiter, omit if the expression starts or ends with
2994     # a bracket.
2995     if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
2996         remainder = False
2997         length = 4 * line.depth
2998         for _index, leaf, leaf_length in enumerate_with_length(line):
2999             if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
3000                 remainder = True
3001             if remainder:
3002                 length += leaf_length
3003                 if length > line_length:
3004                     break
3005
3006                 if leaf.type in OPENING_BRACKETS:
3007                     # There are brackets we can further split on.
3008                     remainder = False
3009
3010         else:
3011             # checked the entire string and line length wasn't exceeded
3012             if len(line.leaves) == _index + 1:
3013                 return True
3014
3015         # Note: we are not returning False here because a line might have *both*
3016         # a leading opening bracket and a trailing closing bracket.  If the
3017         # opening bracket doesn't match our rule, maybe the closing will.
3018
3019     if (
3020         last.type == token.RPAR
3021         or last.type == token.RBRACE
3022         or (
3023             # don't use indexing for omitting optional parentheses;
3024             # it looks weird
3025             last.type == token.RSQB
3026             and last.parent
3027             and last.parent.type != syms.trailer
3028         )
3029     ):
3030         if penultimate.type in OPENING_BRACKETS:
3031             # Empty brackets don't help.
3032             return False
3033
3034         if is_multiline_string(first):
3035             # Additional wrapping of a multiline string in this situation is
3036             # unnecessary.
3037             return True
3038
3039         length = 4 * line.depth
3040         seen_other_brackets = False
3041         for _index, leaf, leaf_length in enumerate_with_length(line):
3042             length += leaf_length
3043             if leaf is last.opening_bracket:
3044                 if seen_other_brackets or length <= line_length:
3045                     return True
3046
3047             elif leaf.type in OPENING_BRACKETS:
3048                 # There are brackets we can further split on.
3049                 seen_other_brackets = True
3050
3051     return False
3052
3053
3054 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
3055
3056
3057 def get_cache_file(line_length: int) -> Path:
3058     return CACHE_DIR / f"cache.{line_length}.pickle"
3059
3060
3061 def read_cache(line_length: int) -> Cache:
3062     """Read the cache if it exists and is well formed.
3063
3064     If it is not well formed, the call to write_cache later should resolve the issue.
3065     """
3066     cache_file = get_cache_file(line_length)
3067     if not cache_file.exists():
3068         return {}
3069
3070     with cache_file.open("rb") as fobj:
3071         try:
3072             cache: Cache = pickle.load(fobj)
3073         except pickle.UnpicklingError:
3074             return {}
3075
3076     return cache
3077
3078
3079 def get_cache_info(path: Path) -> CacheInfo:
3080     """Return the information used to check if a file is already formatted or not."""
3081     stat = path.stat()
3082     return stat.st_mtime, stat.st_size
3083
3084
3085 def filter_cached(
3086     cache: Cache, sources: Iterable[Path]
3087 ) -> Tuple[List[Path], List[Path]]:
3088     """Split a list of paths into two.
3089
3090     The first list contains paths of files that modified on disk or are not in the
3091     cache. The other list contains paths to non-modified files.
3092     """
3093     todo, done = [], []
3094     for src in sources:
3095         src = src.resolve()
3096         if cache.get(src) != get_cache_info(src):
3097             todo.append(src)
3098         else:
3099             done.append(src)
3100     return todo, done
3101
3102
3103 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
3104     """Update the cache file."""
3105     cache_file = get_cache_file(line_length)
3106     try:
3107         if not CACHE_DIR.exists():
3108             CACHE_DIR.mkdir(parents=True)
3109         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
3110         with cache_file.open("wb") as fobj:
3111             pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3112     except OSError:
3113         pass
3114
3115
3116 if __name__ == "__main__":
3117     main()