]> git.madduck.net Git - etc/vim.git/blob - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

3bb83daef28c396cd7670503d7ad21c7d20336c2
[etc/vim.git] / black.py
1 #!/usr/bin/env python3
2
3 import asyncio
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial, wraps
7 import keyword
8 import logging
9 import os
10 from pathlib import Path
11 import tokenize
12 import signal
13 import sys
14 from typing import (
15     Callable,
16     Dict,
17     Generic,
18     Iterable,
19     Iterator,
20     List,
21     Optional,
22     Set,
23     Tuple,
24     Type,
25     TypeVar,
26     Union,
27 )
28
29 from attr import dataclass, Factory
30 import click
31
32 # lib2to3 fork
33 from blib2to3.pytree import Node, Leaf, type_repr
34 from blib2to3 import pygram, pytree
35 from blib2to3.pgen2 import driver, token
36 from blib2to3.pgen2.parse import ParseError
37
38 __version__ = "18.3a4"
39 DEFAULT_LINE_LENGTH = 88
40 # types
41 syms = pygram.python_symbols
42 FileContent = str
43 Encoding = str
44 Depth = int
45 NodeType = int
46 LeafID = int
47 Priority = int
48 Index = int
49 LN = Union[Leaf, Node]
50 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
51 out = partial(click.secho, bold=True, err=True)
52 err = partial(click.secho, fg="red", err=True)
53
54
55 class NothingChanged(UserWarning):
56     """Raised by :func:`format_file` when reformatted code is the same as source."""
57
58
59 class CannotSplit(Exception):
60     """A readable split that fits the allotted line length is impossible.
61
62     Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
63     :func:`delimiter_split`.
64     """
65
66
67 class FormatError(Exception):
68     """Base exception for `# fmt: on` and `# fmt: off` handling.
69
70     It holds the number of bytes of the prefix consumed before the format
71     control comment appeared.
72     """
73
74     def __init__(self, consumed: int) -> None:
75         super().__init__(consumed)
76         self.consumed = consumed
77
78     def trim_prefix(self, leaf: Leaf) -> None:
79         leaf.prefix = leaf.prefix[self.consumed:]
80
81     def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
82         """Returns a new Leaf from the consumed part of the prefix."""
83         unformatted_prefix = leaf.prefix[:self.consumed]
84         return Leaf(token.NEWLINE, unformatted_prefix)
85
86
87 class FormatOn(FormatError):
88     """Found a comment like `# fmt: on` in the file."""
89
90
91 class FormatOff(FormatError):
92     """Found a comment like `# fmt: off` in the file."""
93
94
95 @click.command()
96 @click.option(
97     "-l",
98     "--line-length",
99     type=int,
100     default=DEFAULT_LINE_LENGTH,
101     help="How many character per line to allow.",
102     show_default=True,
103 )
104 @click.option(
105     "--check",
106     is_flag=True,
107     help=(
108         "Don't write back the files, just return the status.  Return code 0 "
109         "means nothing would change.  Return code 1 means some files would be "
110         "reformatted.  Return code 123 means there was an internal error."
111     ),
112 )
113 @click.option(
114     "--fast/--safe",
115     is_flag=True,
116     help="If --fast given, skip temporary sanity checks. [default: --safe]",
117 )
118 @click.version_option(version=__version__)
119 @click.argument(
120     "src",
121     nargs=-1,
122     type=click.Path(
123         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
124     ),
125 )
126 @click.pass_context
127 def main(
128     ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
129 ) -> None:
130     """The uncompromising code formatter."""
131     sources: List[Path] = []
132     for s in src:
133         p = Path(s)
134         if p.is_dir():
135             sources.extend(gen_python_files_in_dir(p))
136         elif p.is_file():
137             # if a file was explicitly given, we don't care about its extension
138             sources.append(p)
139         elif s == "-":
140             sources.append(Path("-"))
141         else:
142             err(f"invalid path: {s}")
143     if len(sources) == 0:
144         ctx.exit(0)
145     elif len(sources) == 1:
146         p = sources[0]
147         report = Report(check=check)
148         try:
149             if not p.is_file() and str(p) == "-":
150                 changed = format_stdin_to_stdout(
151                     line_length=line_length, fast=fast, write_back=not check
152                 )
153             else:
154                 changed = format_file_in_place(
155                     p, line_length=line_length, fast=fast, write_back=not check
156                 )
157             report.done(p, changed)
158         except Exception as exc:
159             report.failed(p, str(exc))
160         ctx.exit(report.return_code)
161     else:
162         loop = asyncio.get_event_loop()
163         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
164         return_code = 1
165         try:
166             return_code = loop.run_until_complete(
167                 schedule_formatting(
168                     sources, line_length, not check, fast, loop, executor
169                 )
170             )
171         finally:
172             shutdown(loop)
173             ctx.exit(return_code)
174
175
176 async def schedule_formatting(
177     sources: List[Path],
178     line_length: int,
179     write_back: bool,
180     fast: bool,
181     loop: BaseEventLoop,
182     executor: Executor,
183 ) -> int:
184     """Run formatting of `sources` in parallel using the provided `executor`.
185
186     (Use ProcessPoolExecutors for actual parallelism.)
187
188     `line_length`, `write_back`, and `fast` options are passed to
189     :func:`format_file_in_place`.
190     """
191     tasks = {
192         src: loop.run_in_executor(
193             executor, format_file_in_place, src, line_length, fast, write_back
194         )
195         for src in sources
196     }
197     _task_values = list(tasks.values())
198     loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
199     loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
200     await asyncio.wait(tasks.values())
201     cancelled = []
202     report = Report(check=not write_back)
203     for src, task in tasks.items():
204         if not task.done():
205             report.failed(src, "timed out, cancelling")
206             task.cancel()
207             cancelled.append(task)
208         elif task.cancelled():
209             cancelled.append(task)
210         elif task.exception():
211             report.failed(src, str(task.exception()))
212         else:
213             report.done(src, task.result())
214     if cancelled:
215         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
216     else:
217         out("All done! ✨ 🍰 ✨")
218     click.echo(str(report))
219     return report.return_code
220
221
222 def format_file_in_place(
223     src: Path, line_length: int, fast: bool, write_back: bool = False
224 ) -> bool:
225     """Format file under `src` path. Return True if changed.
226
227     If `write_back` is True, write reformatted code back to stdout.
228     `line_length` and `fast` options are passed to :func:`format_file_contents`.
229     """
230     with tokenize.open(src) as src_buffer:
231         src_contents = src_buffer.read()
232     try:
233         contents = format_file_contents(
234             src_contents, line_length=line_length, fast=fast
235         )
236     except NothingChanged:
237         return False
238
239     if write_back:
240         with open(src, "w", encoding=src_buffer.encoding) as f:
241             f.write(contents)
242     return True
243
244
245 def format_stdin_to_stdout(
246     line_length: int, fast: bool, write_back: bool = False
247 ) -> bool:
248     """Format file on stdin. Return True if changed.
249
250     If `write_back` is True, write reformatted code back to stdout.
251     `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
252     """
253     contents = sys.stdin.read()
254     try:
255         contents = format_file_contents(contents, line_length=line_length, fast=fast)
256         return True
257
258     except NothingChanged:
259         return False
260
261     finally:
262         if write_back:
263             sys.stdout.write(contents)
264
265
266 def format_file_contents(
267     src_contents: str, line_length: int, fast: bool
268 ) -> FileContent:
269     """Reformat contents a file and return new contents.
270
271     If `fast` is False, additionally confirm that the reformatted code is
272     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
273     `line_length` is passed to :func:`format_str`.
274     """
275     if src_contents.strip() == "":
276         raise NothingChanged
277
278     dst_contents = format_str(src_contents, line_length=line_length)
279     if src_contents == dst_contents:
280         raise NothingChanged
281
282     if not fast:
283         assert_equivalent(src_contents, dst_contents)
284         assert_stable(src_contents, dst_contents, line_length=line_length)
285     return dst_contents
286
287
288 def format_str(src_contents: str, line_length: int) -> FileContent:
289     """Reformat a string and return new contents.
290
291     `line_length` determines how many characters per line are allowed.
292     """
293     src_node = lib2to3_parse(src_contents)
294     dst_contents = ""
295     lines = LineGenerator()
296     elt = EmptyLineTracker()
297     py36 = is_python36(src_node)
298     empty_line = Line()
299     after = 0
300     for current_line in lines.visit(src_node):
301         for _ in range(after):
302             dst_contents += str(empty_line)
303         before, after = elt.maybe_empty_lines(current_line)
304         for _ in range(before):
305             dst_contents += str(empty_line)
306         for line in split_line(current_line, line_length=line_length, py36=py36):
307             dst_contents += str(line)
308     return dst_contents
309
310
311 GRAMMARS = [
312     pygram.python_grammar_no_print_statement_no_exec_statement,
313     pygram.python_grammar_no_print_statement,
314     pygram.python_grammar_no_exec_statement,
315     pygram.python_grammar,
316 ]
317
318
319 def lib2to3_parse(src_txt: str) -> Node:
320     """Given a string with source, return the lib2to3 Node."""
321     grammar = pygram.python_grammar_no_print_statement
322     if src_txt[-1] != "\n":
323         nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
324         src_txt += nl
325     for grammar in GRAMMARS:
326         drv = driver.Driver(grammar, pytree.convert)
327         try:
328             result = drv.parse_string(src_txt, True)
329             break
330
331         except ParseError as pe:
332             lineno, column = pe.context[1]
333             lines = src_txt.splitlines()
334             try:
335                 faulty_line = lines[lineno - 1]
336             except IndexError:
337                 faulty_line = "<line number missing in source>"
338             exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
339     else:
340         raise exc from None
341
342     if isinstance(result, Leaf):
343         result = Node(syms.file_input, [result])
344     return result
345
346
347 def lib2to3_unparse(node: Node) -> str:
348     """Given a lib2to3 node, return its string representation."""
349     code = str(node)
350     return code
351
352
353 T = TypeVar("T")
354
355
356 class Visitor(Generic[T]):
357     """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
358
359     def visit(self, node: LN) -> Iterator[T]:
360         """Main method to visit `node` and its children.
361
362         It tries to find a `visit_*()` method for the given `node.type`, like
363         `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
364         If no dedicated `visit_*()` method is found, chooses `visit_default()`
365         instead.
366
367         Then yields objects of type `T` from the selected visitor.
368         """
369         if node.type < 256:
370             name = token.tok_name[node.type]
371         else:
372             name = type_repr(node.type)
373         yield from getattr(self, f"visit_{name}", self.visit_default)(node)
374
375     def visit_default(self, node: LN) -> Iterator[T]:
376         """Default `visit_*()` implementation. Recurses to children of `node`."""
377         if isinstance(node, Node):
378             for child in node.children:
379                 yield from self.visit(child)
380
381
382 @dataclass
383 class DebugVisitor(Visitor[T]):
384     tree_depth: int = 0
385
386     def visit_default(self, node: LN) -> Iterator[T]:
387         indent = " " * (2 * self.tree_depth)
388         if isinstance(node, Node):
389             _type = type_repr(node.type)
390             out(f"{indent}{_type}", fg="yellow")
391             self.tree_depth += 1
392             for child in node.children:
393                 yield from self.visit(child)
394
395             self.tree_depth -= 1
396             out(f"{indent}/{_type}", fg="yellow", bold=False)
397         else:
398             _type = token.tok_name.get(node.type, str(node.type))
399             out(f"{indent}{_type}", fg="blue", nl=False)
400             if node.prefix:
401                 # We don't have to handle prefixes for `Node` objects since
402                 # that delegates to the first child anyway.
403                 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
404             out(f" {node.value!r}", fg="blue", bold=False)
405
406     @classmethod
407     def show(cls, code: str) -> None:
408         """Pretty-print the lib2to3 AST of a given string of `code`.
409
410         Convenience method for debugging.
411         """
412         v: DebugVisitor[None] = DebugVisitor()
413         list(v.visit(lib2to3_parse(code)))
414
415
416 KEYWORDS = set(keyword.kwlist)
417 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
418 FLOW_CONTROL = {"return", "raise", "break", "continue"}
419 STATEMENT = {
420     syms.if_stmt,
421     syms.while_stmt,
422     syms.for_stmt,
423     syms.try_stmt,
424     syms.except_clause,
425     syms.with_stmt,
426     syms.funcdef,
427     syms.classdef,
428 }
429 STANDALONE_COMMENT = 153
430 LOGIC_OPERATORS = {"and", "or"}
431 COMPARATORS = {
432     token.LESS,
433     token.GREATER,
434     token.EQEQUAL,
435     token.NOTEQUAL,
436     token.LESSEQUAL,
437     token.GREATEREQUAL,
438 }
439 MATH_OPERATORS = {
440     token.PLUS,
441     token.MINUS,
442     token.STAR,
443     token.SLASH,
444     token.VBAR,
445     token.AMPER,
446     token.PERCENT,
447     token.CIRCUMFLEX,
448     token.TILDE,
449     token.LEFTSHIFT,
450     token.RIGHTSHIFT,
451     token.DOUBLESTAR,
452     token.DOUBLESLASH,
453 }
454 VARARGS = {token.STAR, token.DOUBLESTAR}
455 COMPREHENSION_PRIORITY = 20
456 COMMA_PRIORITY = 10
457 LOGIC_PRIORITY = 5
458 STRING_PRIORITY = 4
459 COMPARATOR_PRIORITY = 3
460 MATH_PRIORITY = 1
461
462
463 @dataclass
464 class BracketTracker:
465     """Keeps track of brackets on a line."""
466
467     depth: int = 0
468     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
469     delimiters: Dict[LeafID, Priority] = Factory(dict)
470     previous: Optional[Leaf] = None
471
472     def mark(self, leaf: Leaf) -> None:
473         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
474
475         All leaves receive an int `bracket_depth` field that stores how deep
476         within brackets a given leaf is. 0 means there are no enclosing brackets
477         that started on this line.
478
479         If a leaf is itself a closing bracket, it receives an `opening_bracket`
480         field that it forms a pair with. This is a one-directional link to
481         avoid reference cycles.
482
483         If a leaf is a delimiter (a token on which Black can split the line if
484         needed) and it's on depth 0, its `id()` is stored in the tracker's
485         `delimiters` field.
486         """
487         if leaf.type == token.COMMENT:
488             return
489
490         if leaf.type in CLOSING_BRACKETS:
491             self.depth -= 1
492             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
493             leaf.opening_bracket = opening_bracket
494         leaf.bracket_depth = self.depth
495         if self.depth == 0:
496             after_delim = is_split_after_delimiter(leaf, self.previous)
497             before_delim = is_split_before_delimiter(leaf, self.previous)
498             if after_delim > before_delim:
499                 self.delimiters[id(leaf)] = after_delim
500             elif before_delim > after_delim and self.previous is not None:
501                 self.delimiters[id(self.previous)] = before_delim
502         if leaf.type in OPENING_BRACKETS:
503             self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
504             self.depth += 1
505         self.previous = leaf
506
507     def any_open_brackets(self) -> bool:
508         """Return True if there is an yet unmatched open bracket on the line."""
509         return bool(self.bracket_match)
510
511     def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
512         """Return the highest priority of a delimiter found on the line.
513
514         Values are consistent with what `is_delimiter()` returns.
515         """
516         return max(v for k, v in self.delimiters.items() if k not in exclude)
517
518
519 @dataclass
520 class Line:
521     """Holds leaves and comments. Can be printed with `str(line)`."""
522
523     depth: int = 0
524     leaves: List[Leaf] = Factory(list)
525     comments: List[Tuple[Index, Leaf]] = Factory(list)
526     bracket_tracker: BracketTracker = Factory(BracketTracker)
527     inside_brackets: bool = False
528     has_for: bool = False
529     _for_loop_variable: bool = False
530
531     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
532         """Add a new `leaf` to the end of the line.
533
534         Unless `preformatted` is True, the `leaf` will receive a new consistent
535         whitespace prefix and metadata applied by :class:`BracketTracker`.
536         Trailing commas are maybe removed, unpacked for loop variables are
537         demoted from being delimiters.
538
539         Inline comments are put aside.
540         """
541         has_value = leaf.value.strip()
542         if not has_value:
543             return
544
545         if self.leaves and not preformatted:
546             # Note: at this point leaf.prefix should be empty except for
547             # imports, for which we only preserve newlines.
548             leaf.prefix += whitespace(leaf)
549         if self.inside_brackets or not preformatted:
550             self.maybe_decrement_after_for_loop_variable(leaf)
551             self.bracket_tracker.mark(leaf)
552             self.maybe_remove_trailing_comma(leaf)
553             self.maybe_increment_for_loop_variable(leaf)
554
555         if not self.append_comment(leaf):
556             self.leaves.append(leaf)
557
558     def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
559         """Like :func:`append()` but disallow invalid standalone comment structure.
560
561         Raises ValueError when any `leaf` is appended after a standalone comment
562         or when a standalone comment is not the first leaf on the line.
563         """
564         if self.bracket_tracker.depth == 0:
565             if self.is_comment:
566                 raise ValueError("cannot append to standalone comments")
567
568             if self.leaves and leaf.type == STANDALONE_COMMENT:
569                 raise ValueError(
570                     "cannot append standalone comments to a populated line"
571                 )
572
573         self.append(leaf, preformatted=preformatted)
574
575     @property
576     def is_comment(self) -> bool:
577         """Is this line a standalone comment?"""
578         return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
579
580     @property
581     def is_decorator(self) -> bool:
582         """Is this line a decorator?"""
583         return bool(self) and self.leaves[0].type == token.AT
584
585     @property
586     def is_import(self) -> bool:
587         """Is this an import line?"""
588         return bool(self) and is_import(self.leaves[0])
589
590     @property
591     def is_class(self) -> bool:
592         """Is this line a class definition?"""
593         return (
594             bool(self)
595             and self.leaves[0].type == token.NAME
596             and self.leaves[0].value == "class"
597         )
598
599     @property
600     def is_def(self) -> bool:
601         """Is this a function definition? (Also returns True for async defs.)"""
602         try:
603             first_leaf = self.leaves[0]
604         except IndexError:
605             return False
606
607         try:
608             second_leaf: Optional[Leaf] = self.leaves[1]
609         except IndexError:
610             second_leaf = None
611         return (
612             (first_leaf.type == token.NAME and first_leaf.value == "def")
613             or (
614                 first_leaf.type == token.ASYNC
615                 and second_leaf is not None
616                 and second_leaf.type == token.NAME
617                 and second_leaf.value == "def"
618             )
619         )
620
621     @property
622     def is_flow_control(self) -> bool:
623         """Is this line a flow control statement?
624
625         Those are `return`, `raise`, `break`, and `continue`.
626         """
627         return (
628             bool(self)
629             and self.leaves[0].type == token.NAME
630             and self.leaves[0].value in FLOW_CONTROL
631         )
632
633     @property
634     def is_yield(self) -> bool:
635         """Is this line a yield statement?"""
636         return (
637             bool(self)
638             and self.leaves[0].type == token.NAME
639             and self.leaves[0].value == "yield"
640         )
641
642     @property
643     def contains_standalone_comments(self) -> bool:
644         """If so, needs to be split before emitting."""
645         for leaf in self.leaves:
646             if leaf.type == STANDALONE_COMMENT:
647                 return True
648
649         return False
650
651     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
652         """Remove trailing comma if there is one and it's safe."""
653         if not (
654             self.leaves
655             and self.leaves[-1].type == token.COMMA
656             and closing.type in CLOSING_BRACKETS
657         ):
658             return False
659
660         if closing.type == token.RBRACE:
661             self.remove_trailing_comma()
662             return True
663
664         if closing.type == token.RSQB:
665             comma = self.leaves[-1]
666             if comma.parent and comma.parent.type == syms.listmaker:
667                 self.remove_trailing_comma()
668                 return True
669
670         # For parens let's check if it's safe to remove the comma.  If the
671         # trailing one is the only one, we might mistakenly change a tuple
672         # into a different type by removing the comma.
673         depth = closing.bracket_depth + 1
674         commas = 0
675         opening = closing.opening_bracket
676         for _opening_index, leaf in enumerate(self.leaves):
677             if leaf is opening:
678                 break
679
680         else:
681             return False
682
683         for leaf in self.leaves[_opening_index + 1:]:
684             if leaf is closing:
685                 break
686
687             bracket_depth = leaf.bracket_depth
688             if bracket_depth == depth and leaf.type == token.COMMA:
689                 commas += 1
690                 if leaf.parent and leaf.parent.type == syms.arglist:
691                     commas += 1
692                     break
693
694         if commas > 1:
695             self.remove_trailing_comma()
696             return True
697
698         return False
699
700     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
701         """In a for loop, or comprehension, the variables are often unpacks.
702
703         To avoid splitting on the comma in this situation, increase the depth of
704         tokens between `for` and `in`.
705         """
706         if leaf.type == token.NAME and leaf.value == "for":
707             self.has_for = True
708             self.bracket_tracker.depth += 1
709             self._for_loop_variable = True
710             return True
711
712         return False
713
714     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
715         """See `maybe_increment_for_loop_variable` above for explanation."""
716         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
717             self.bracket_tracker.depth -= 1
718             self._for_loop_variable = False
719             return True
720
721         return False
722
723     def append_comment(self, comment: Leaf) -> bool:
724         """Add an inline or standalone comment to the line."""
725         if (
726             comment.type == STANDALONE_COMMENT
727             and self.bracket_tracker.any_open_brackets()
728         ):
729             comment.prefix = ""
730             return False
731
732         if comment.type != token.COMMENT:
733             return False
734
735         after = len(self.leaves) - 1
736         if after == -1:
737             comment.type = STANDALONE_COMMENT
738             comment.prefix = ""
739             return False
740
741         else:
742             self.comments.append((after, comment))
743             return True
744
745     def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
746         """Generate comments that should appear directly after `leaf`."""
747         for _leaf_index, _leaf in enumerate(self.leaves):
748             if leaf is _leaf:
749                 break
750
751         else:
752             return
753
754         for index, comment_after in self.comments:
755             if _leaf_index == index:
756                 yield comment_after
757
758     def remove_trailing_comma(self) -> None:
759         """Remove the trailing comma and moves the comments attached to it."""
760         comma_index = len(self.leaves) - 1
761         for i in range(len(self.comments)):
762             comment_index, comment = self.comments[i]
763             if comment_index == comma_index:
764                 self.comments[i] = (comma_index - 1, comment)
765         self.leaves.pop()
766
767     def __str__(self) -> str:
768         """Render the line."""
769         if not self:
770             return "\n"
771
772         indent = "    " * self.depth
773         leaves = iter(self.leaves)
774         first = next(leaves)
775         res = f"{first.prefix}{indent}{first.value}"
776         for leaf in leaves:
777             res += str(leaf)
778         for _, comment in self.comments:
779             res += str(comment)
780         return res + "\n"
781
782     def __bool__(self) -> bool:
783         """Return True if the line has leaves or comments."""
784         return bool(self.leaves or self.comments)
785
786
787 class UnformattedLines(Line):
788     """Just like :class:`Line` but stores lines which aren't reformatted."""
789
790     def append(self, leaf: Leaf, preformatted: bool = True) -> None:
791         """Just add a new `leaf` to the end of the lines.
792
793         The `preformatted` argument is ignored.
794
795         Keeps track of indentation `depth`, which is useful when the user
796         says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
797         """
798         try:
799             list(generate_comments(leaf))
800         except FormatOn as f_on:
801             self.leaves.append(f_on.leaf_from_consumed(leaf))
802             raise
803
804         self.leaves.append(leaf)
805         if leaf.type == token.INDENT:
806             self.depth += 1
807         elif leaf.type == token.DEDENT:
808             self.depth -= 1
809
810     def __str__(self) -> str:
811         """Render unformatted lines from leaves which were added with `append()`.
812
813         `depth` is not used for indentation in this case.
814         """
815         if not self:
816             return "\n"
817
818         res = ""
819         for leaf in self.leaves:
820             res += str(leaf)
821         return res
822
823     def append_comment(self, comment: Leaf) -> bool:
824         """Not implemented in this class. Raises `NotImplementedError`."""
825         raise NotImplementedError("Unformatted lines don't store comments separately.")
826
827     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
828         """Does nothing and returns False."""
829         return False
830
831     def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
832         """Does nothing and returns False."""
833         return False
834
835
836 @dataclass
837 class EmptyLineTracker:
838     """Provides a stateful method that returns the number of potential extra
839     empty lines needed before and after the currently processed line.
840
841     Note: this tracker works on lines that haven't been split yet.  It assumes
842     the prefix of the first leaf consists of optional newlines.  Those newlines
843     are consumed by `maybe_empty_lines()` and included in the computation.
844     """
845     previous_line: Optional[Line] = None
846     previous_after: int = 0
847     previous_defs: List[int] = Factory(list)
848
849     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
850         """Return the number of extra empty lines before and after the `current_line`.
851
852         This is for separating `def`, `async def` and `class` with extra empty
853         lines (two on module-level), as well as providing an extra empty line
854         after flow control keywords to make them more prominent.
855         """
856         if isinstance(current_line, UnformattedLines):
857             return 0, 0
858
859         before, after = self._maybe_empty_lines(current_line)
860         before -= self.previous_after
861         self.previous_after = after
862         self.previous_line = current_line
863         return before, after
864
865     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
866         max_allowed = 1
867         if current_line.depth == 0:
868             max_allowed = 2
869         if current_line.leaves:
870             # Consume the first leaf's extra newlines.
871             first_leaf = current_line.leaves[0]
872             before = first_leaf.prefix.count("\n")
873             before = min(before, max_allowed)
874             first_leaf.prefix = ""
875         else:
876             before = 0
877         depth = current_line.depth
878         while self.previous_defs and self.previous_defs[-1] >= depth:
879             self.previous_defs.pop()
880             before = 1 if depth else 2
881         is_decorator = current_line.is_decorator
882         if is_decorator or current_line.is_def or current_line.is_class:
883             if not is_decorator:
884                 self.previous_defs.append(depth)
885             if self.previous_line is None:
886                 # Don't insert empty lines before the first line in the file.
887                 return 0, 0
888
889             if self.previous_line and self.previous_line.is_decorator:
890                 # Don't insert empty lines between decorators.
891                 return 0, 0
892
893             newlines = 2
894             if current_line.depth:
895                 newlines -= 1
896             return newlines, 0
897
898         if current_line.is_flow_control:
899             return before, 1
900
901         if (
902             self.previous_line
903             and self.previous_line.is_import
904             and not current_line.is_import
905             and depth == self.previous_line.depth
906         ):
907             return (before or 1), 0
908
909         if (
910             self.previous_line
911             and self.previous_line.is_yield
912             and (not current_line.is_yield or depth != self.previous_line.depth)
913         ):
914             return (before or 1), 0
915
916         return before, 0
917
918
919 @dataclass
920 class LineGenerator(Visitor[Line]):
921     """Generates reformatted Line objects.  Empty lines are not emitted.
922
923     Note: destroys the tree it's visiting by mutating prefixes of its leaves
924     in ways that will no longer stringify to valid Python code on the tree.
925     """
926     current_line: Line = Factory(Line)
927
928     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
929         """Generate a line.
930
931         If the line is empty, only emit if it makes sense.
932         If the line is too long, split it first and then generate.
933
934         If any lines were generated, set up a new current_line.
935         """
936         if not self.current_line:
937             if self.current_line.__class__ == type:
938                 self.current_line.depth += indent
939             else:
940                 self.current_line = type(depth=self.current_line.depth + indent)
941             return  # Line is empty, don't emit. Creating a new one unnecessary.
942
943         complete_line = self.current_line
944         self.current_line = type(depth=complete_line.depth + indent)
945         yield complete_line
946
947     def visit(self, node: LN) -> Iterator[Line]:
948         """Main method to visit `node` and its children.
949
950         Yields :class:`Line` objects.
951         """
952         if isinstance(self.current_line, UnformattedLines):
953             # File contained `# fmt: off`
954             yield from self.visit_unformatted(node)
955
956         else:
957             yield from super().visit(node)
958
959     def visit_default(self, node: LN) -> Iterator[Line]:
960         """Default `visit_*()` implementation. Recurses to children of `node`."""
961         if isinstance(node, Leaf):
962             any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
963             try:
964                 for comment in generate_comments(node):
965                     if any_open_brackets:
966                         # any comment within brackets is subject to splitting
967                         self.current_line.append(comment)
968                     elif comment.type == token.COMMENT:
969                         # regular trailing comment
970                         self.current_line.append(comment)
971                         yield from self.line()
972
973                     else:
974                         # regular standalone comment
975                         yield from self.line()
976
977                         self.current_line.append(comment)
978                         yield from self.line()
979
980             except FormatOff as f_off:
981                 f_off.trim_prefix(node)
982                 yield from self.line(type=UnformattedLines)
983                 yield from self.visit(node)
984
985             except FormatOn as f_on:
986                 # This only happens here if somebody says "fmt: on" multiple
987                 # times in a row.
988                 f_on.trim_prefix(node)
989                 yield from self.visit_default(node)
990
991             else:
992                 normalize_prefix(node, inside_brackets=any_open_brackets)
993                 if node.type == token.STRING:
994                     normalize_string_quotes(node)
995                 if node.type not in WHITESPACE:
996                     self.current_line.append(node)
997         yield from super().visit_default(node)
998
999     def visit_INDENT(self, node: Node) -> Iterator[Line]:
1000         """Increase indentation level, maybe yield a line."""
1001         # In blib2to3 INDENT never holds comments.
1002         yield from self.line(+1)
1003         yield from self.visit_default(node)
1004
1005     def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1006         """Decrease indentation level, maybe yield a line."""
1007         # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1008         yield from self.line(-1)
1009
1010     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
1011         """Visit a statement.
1012
1013         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1014         `def`, `with`, and `class`.
1015
1016         The relevant Python language `keywords` for a given statement will be NAME
1017         leaves within it. This methods puts those on a separate line.
1018         """
1019         for child in node.children:
1020             if child.type == token.NAME and child.value in keywords:  # type: ignore
1021                 yield from self.line()
1022
1023             yield from self.visit(child)
1024
1025     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1026         """Visit a statement without nested statements."""
1027         is_suite_like = node.parent and node.parent.type in STATEMENT
1028         if is_suite_like:
1029             yield from self.line(+1)
1030             yield from self.visit_default(node)
1031             yield from self.line(-1)
1032
1033         else:
1034             yield from self.line()
1035             yield from self.visit_default(node)
1036
1037     def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1038         """Visit `async def`, `async for`, `async with`."""
1039         yield from self.line()
1040
1041         children = iter(node.children)
1042         for child in children:
1043             yield from self.visit(child)
1044
1045             if child.type == token.ASYNC:
1046                 break
1047
1048         internal_stmt = next(children)
1049         for child in internal_stmt.children:
1050             yield from self.visit(child)
1051
1052     def visit_decorators(self, node: Node) -> Iterator[Line]:
1053         """Visit decorators."""
1054         for child in node.children:
1055             yield from self.line()
1056             yield from self.visit(child)
1057
1058     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1059         """Remove a semicolon and put the other statement on a separate line."""
1060         yield from self.line()
1061
1062     def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1063         """End of file. Process outstanding comments and end with a newline."""
1064         yield from self.visit_default(leaf)
1065         yield from self.line()
1066
1067     def visit_unformatted(self, node: LN) -> Iterator[Line]:
1068         """Used when file contained a `# fmt: off`."""
1069         if isinstance(node, Node):
1070             for child in node.children:
1071                 yield from self.visit(child)
1072
1073         else:
1074             try:
1075                 self.current_line.append(node)
1076             except FormatOn as f_on:
1077                 f_on.trim_prefix(node)
1078                 yield from self.line()
1079                 yield from self.visit(node)
1080
1081     def __attrs_post_init__(self) -> None:
1082         """You are in a twisty little maze of passages."""
1083         v = self.visit_stmt
1084         self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"})
1085         self.visit_while_stmt = partial(v, keywords={"while", "else"})
1086         self.visit_for_stmt = partial(v, keywords={"for", "else"})
1087         self.visit_try_stmt = partial(v, keywords={"try", "except", "else", "finally"})
1088         self.visit_except_clause = partial(v, keywords={"except"})
1089         self.visit_funcdef = partial(v, keywords={"def"})
1090         self.visit_with_stmt = partial(v, keywords={"with"})
1091         self.visit_classdef = partial(v, keywords={"class"})
1092         self.visit_async_funcdef = self.visit_async_stmt
1093         self.visit_decorated = self.visit_decorators
1094
1095
1096 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1097 OPENING_BRACKETS = set(BRACKET.keys())
1098 CLOSING_BRACKETS = set(BRACKET.values())
1099 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1100 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1101
1102
1103 def whitespace(leaf: Leaf) -> str:  # noqa C901
1104     """Return whitespace prefix if needed for the given `leaf`."""
1105     NO = ""
1106     SPACE = " "
1107     DOUBLESPACE = "  "
1108     t = leaf.type
1109     p = leaf.parent
1110     v = leaf.value
1111     if t in ALWAYS_NO_SPACE:
1112         return NO
1113
1114     if t == token.COMMENT:
1115         return DOUBLESPACE
1116
1117     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1118     if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1119         return NO
1120
1121     prev = leaf.prev_sibling
1122     if not prev:
1123         prevp = preceding_leaf(p)
1124         if not prevp or prevp.type in OPENING_BRACKETS:
1125             return NO
1126
1127         if t == token.COLON:
1128             return SPACE if prevp.type == token.COMMA else NO
1129
1130         if prevp.type == token.EQUAL:
1131             if prevp.parent:
1132                 if prevp.parent.type in {
1133                     syms.arglist, syms.argument, syms.parameters, syms.varargslist
1134                 }:
1135                     return NO
1136
1137                 elif prevp.parent.type == syms.typedargslist:
1138                     # A bit hacky: if the equal sign has whitespace, it means we
1139                     # previously found it's a typed argument.  So, we're using
1140                     # that, too.
1141                     return prevp.prefix
1142
1143         elif prevp.type == token.DOUBLESTAR:
1144             if prevp.parent and prevp.parent.type in {
1145                 syms.arglist,
1146                 syms.argument,
1147                 syms.dictsetmaker,
1148                 syms.parameters,
1149                 syms.typedargslist,
1150                 syms.varargslist,
1151             }:
1152                 return NO
1153
1154         elif prevp.type == token.COLON:
1155             if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1156                 return NO
1157
1158         elif (
1159             prevp.parent
1160             and prevp.parent.type in {syms.factor, syms.star_expr}
1161             and prevp.type in MATH_OPERATORS
1162         ):
1163             return NO
1164
1165         elif (
1166             prevp.type == token.RIGHTSHIFT
1167             and prevp.parent
1168             and prevp.parent.type == syms.shift_expr
1169             and prevp.prev_sibling
1170             and prevp.prev_sibling.type == token.NAME
1171             and prevp.prev_sibling.value == "print"  # type: ignore
1172         ):
1173             # Python 2 print chevron
1174             return NO
1175
1176     elif prev.type in OPENING_BRACKETS:
1177         return NO
1178
1179     if p.type in {syms.parameters, syms.arglist}:
1180         # untyped function signatures or calls
1181         if t == token.RPAR:
1182             return NO
1183
1184         if not prev or prev.type != token.COMMA:
1185             return NO
1186
1187     elif p.type == syms.varargslist:
1188         # lambdas
1189         if t == token.RPAR:
1190             return NO
1191
1192         if prev and prev.type != token.COMMA:
1193             return NO
1194
1195     elif p.type == syms.typedargslist:
1196         # typed function signatures
1197         if not prev:
1198             return NO
1199
1200         if t == token.EQUAL:
1201             if prev.type != syms.tname:
1202                 return NO
1203
1204         elif prev.type == token.EQUAL:
1205             # A bit hacky: if the equal sign has whitespace, it means we
1206             # previously found it's a typed argument.  So, we're using that, too.
1207             return prev.prefix
1208
1209         elif prev.type != token.COMMA:
1210             return NO
1211
1212     elif p.type == syms.tname:
1213         # type names
1214         if not prev:
1215             prevp = preceding_leaf(p)
1216             if not prevp or prevp.type != token.COMMA:
1217                 return NO
1218
1219     elif p.type == syms.trailer:
1220         # attributes and calls
1221         if t == token.LPAR or t == token.RPAR:
1222             return NO
1223
1224         if not prev:
1225             if t == token.DOT:
1226                 prevp = preceding_leaf(p)
1227                 if not prevp or prevp.type != token.NUMBER:
1228                     return NO
1229
1230             elif t == token.LSQB:
1231                 return NO
1232
1233         elif prev.type != token.COMMA:
1234             return NO
1235
1236     elif p.type == syms.argument:
1237         # single argument
1238         if t == token.EQUAL:
1239             return NO
1240
1241         if not prev:
1242             prevp = preceding_leaf(p)
1243             if not prevp or prevp.type == token.LPAR:
1244                 return NO
1245
1246         elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1247             return NO
1248
1249     elif p.type == syms.decorator:
1250         # decorators
1251         return NO
1252
1253     elif p.type == syms.dotted_name:
1254         if prev:
1255             return NO
1256
1257         prevp = preceding_leaf(p)
1258         if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1259             return NO
1260
1261     elif p.type == syms.classdef:
1262         if t == token.LPAR:
1263             return NO
1264
1265         if prev and prev.type == token.LPAR:
1266             return NO
1267
1268     elif p.type == syms.subscript:
1269         # indexing
1270         if not prev:
1271             assert p.parent is not None, "subscripts are always parented"
1272             if p.parent.type == syms.subscriptlist:
1273                 return SPACE
1274
1275             return NO
1276
1277         else:
1278             return NO
1279
1280     elif p.type == syms.atom:
1281         if prev and t == token.DOT:
1282             # dots, but not the first one.
1283             return NO
1284
1285     elif (
1286         p.type == syms.listmaker
1287         or p.type == syms.testlist_gexp
1288         or p.type == syms.subscriptlist
1289     ):
1290         # list interior, including unpacking
1291         if not prev:
1292             return NO
1293
1294     elif p.type == syms.dictsetmaker:
1295         # dict and set interior, including unpacking
1296         if not prev:
1297             return NO
1298
1299         if prev.type == token.DOUBLESTAR:
1300             return NO
1301
1302     elif p.type in {syms.factor, syms.star_expr}:
1303         # unary ops
1304         if not prev:
1305             prevp = preceding_leaf(p)
1306             if not prevp or prevp.type in OPENING_BRACKETS:
1307                 return NO
1308
1309             prevp_parent = prevp.parent
1310             assert prevp_parent is not None
1311             if prevp.type == token.COLON and prevp_parent.type in {
1312                 syms.subscript, syms.sliceop
1313             }:
1314                 return NO
1315
1316             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1317                 return NO
1318
1319         elif t == token.NAME or t == token.NUMBER:
1320             return NO
1321
1322     elif p.type == syms.import_from:
1323         if t == token.DOT:
1324             if prev and prev.type == token.DOT:
1325                 return NO
1326
1327         elif t == token.NAME:
1328             if v == "import":
1329                 return SPACE
1330
1331             if prev and prev.type == token.DOT:
1332                 return NO
1333
1334     elif p.type == syms.sliceop:
1335         return NO
1336
1337     return SPACE
1338
1339
1340 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1341     """Return the first leaf that precedes `node`, if any."""
1342     while node:
1343         res = node.prev_sibling
1344         if res:
1345             if isinstance(res, Leaf):
1346                 return res
1347
1348             try:
1349                 return list(res.leaves())[-1]
1350
1351             except IndexError:
1352                 return None
1353
1354         node = node.parent
1355     return None
1356
1357
1358 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1359     """Return the priority of the `leaf` delimiter, given a line break after it.
1360
1361     The delimiter priorities returned here are from those delimiters that would
1362     cause a line break after themselves.
1363
1364     Higher numbers are higher priority.
1365     """
1366     if leaf.type == token.COMMA:
1367         return COMMA_PRIORITY
1368
1369     if (
1370         leaf.type in VARARGS
1371         and leaf.parent
1372         and leaf.parent.type in {syms.argument, syms.typedargslist}
1373     ):
1374         return MATH_PRIORITY
1375
1376     return 0
1377
1378
1379 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1380     """Return the priority of the `leaf` delimiter, given a line before after it.
1381
1382     The delimiter priorities returned here are from those delimiters that would
1383     cause a line break before themselves.
1384
1385     Higher numbers are higher priority.
1386     """
1387     if (
1388         leaf.type in MATH_OPERATORS
1389         and leaf.parent
1390         and leaf.parent.type not in {syms.factor, syms.star_expr}
1391     ):
1392         return MATH_PRIORITY
1393
1394     if leaf.type in COMPARATORS:
1395         return COMPARATOR_PRIORITY
1396
1397     if (
1398         leaf.type == token.STRING
1399         and previous is not None
1400         and previous.type == token.STRING
1401     ):
1402         return STRING_PRIORITY
1403
1404     if (
1405         leaf.type == token.NAME
1406         and leaf.value == "for"
1407         and leaf.parent
1408         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1409     ):
1410         return COMPREHENSION_PRIORITY
1411
1412     if (
1413         leaf.type == token.NAME
1414         and leaf.value == "if"
1415         and leaf.parent
1416         and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1417     ):
1418         return COMPREHENSION_PRIORITY
1419
1420     if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1421         return LOGIC_PRIORITY
1422
1423     return 0
1424
1425
1426 def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1427     """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
1428
1429     Higher numbers are higher priority.
1430     """
1431     return max(
1432         is_split_before_delimiter(leaf, previous),
1433         is_split_after_delimiter(leaf, previous),
1434     )
1435
1436
1437 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1438     """Clean the prefix of the `leaf` and generate comments from it, if any.
1439
1440     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
1441     in `pgen2/driver.py:Driver.parse_tokens()`.  This was a brilliant implementation
1442     move because it does away with modifying the grammar to include all the
1443     possible places in which comments can be placed.
1444
1445     The sad consequence for us though is that comments don't "belong" anywhere.
1446     This is why this function generates simple parentless Leaf objects for
1447     comments.  We simply don't know what the correct parent should be.
1448
1449     No matter though, we can live without this.  We really only need to
1450     differentiate between inline and standalone comments.  The latter don't
1451     share the line with any code.
1452
1453     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
1454     are emitted with a fake STANDALONE_COMMENT token identifier.
1455     """
1456     p = leaf.prefix
1457     if not p:
1458         return
1459
1460     if "#" not in p:
1461         return
1462
1463     consumed = 0
1464     nlines = 0
1465     for index, line in enumerate(p.split("\n")):
1466         consumed += len(line) + 1  # adding the length of the split '\n'
1467         line = line.lstrip()
1468         if not line:
1469             nlines += 1
1470         if not line.startswith("#"):
1471             continue
1472
1473         if index == 0 and leaf.type != token.ENDMARKER:
1474             comment_type = token.COMMENT  # simple trailing comment
1475         else:
1476             comment_type = STANDALONE_COMMENT
1477         comment = make_comment(line)
1478         yield Leaf(comment_type, comment, prefix="\n" * nlines)
1479
1480         if comment in {"# fmt: on", "# yapf: enable"}:
1481             raise FormatOn(consumed)
1482
1483         if comment in {"# fmt: off", "# yapf: disable"}:
1484             raise FormatOff(consumed)
1485
1486         nlines = 0
1487
1488
1489 def make_comment(content: str) -> str:
1490     """Return a consistently formatted comment from the given `content` string.
1491
1492     All comments (except for "##", "#!", "#:") should have a single space between
1493     the hash sign and the content.
1494
1495     If `content` didn't start with a hash sign, one is provided.
1496     """
1497     content = content.rstrip()
1498     if not content:
1499         return "#"
1500
1501     if content[0] == "#":
1502         content = content[1:]
1503     if content and content[0] not in " !:#":
1504         content = " " + content
1505     return "#" + content
1506
1507
1508 def split_line(
1509     line: Line, line_length: int, inner: bool = False, py36: bool = False
1510 ) -> Iterator[Line]:
1511     """Split a `line` into potentially many lines.
1512
1513     They should fit in the allotted `line_length` but might not be able to.
1514     `inner` signifies that there were a pair of brackets somewhere around the
1515     current `line`, possibly transitively. This means we can fallback to splitting
1516     by delimiters if the LHS/RHS don't yield any results.
1517
1518     If `py36` is True, splitting may generate syntax that is only compatible
1519     with Python 3.6 and later.
1520     """
1521     if isinstance(line, UnformattedLines) or line.is_comment:
1522         yield line
1523         return
1524
1525     line_str = str(line).strip("\n")
1526     if (
1527         len(line_str) <= line_length
1528         and "\n" not in line_str  # multiline strings
1529         and not line.contains_standalone_comments
1530     ):
1531         yield line
1532         return
1533
1534     split_funcs: List[SplitFunc]
1535     if line.is_def:
1536         split_funcs = [left_hand_split]
1537     elif line.inside_brackets:
1538         split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1539     else:
1540         split_funcs = [right_hand_split]
1541     for split_func in split_funcs:
1542         # We are accumulating lines in `result` because we might want to abort
1543         # mission and return the original line in the end, or attempt a different
1544         # split altogether.
1545         result: List[Line] = []
1546         try:
1547             for l in split_func(line, py36):
1548                 if str(l).strip("\n") == line_str:
1549                     raise CannotSplit("Split function returned an unchanged result")
1550
1551                 result.extend(
1552                     split_line(l, line_length=line_length, inner=True, py36=py36)
1553                 )
1554         except CannotSplit as cs:
1555             continue
1556
1557         else:
1558             yield from result
1559             break
1560
1561     else:
1562         yield line
1563
1564
1565 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1566     """Split line into many lines, starting with the first matching bracket pair.
1567
1568     Note: this usually looks weird, only use this for function definitions.
1569     Prefer RHS otherwise.
1570     """
1571     head = Line(depth=line.depth)
1572     body = Line(depth=line.depth + 1, inside_brackets=True)
1573     tail = Line(depth=line.depth)
1574     tail_leaves: List[Leaf] = []
1575     body_leaves: List[Leaf] = []
1576     head_leaves: List[Leaf] = []
1577     current_leaves = head_leaves
1578     matching_bracket = None
1579     for leaf in line.leaves:
1580         if (
1581             current_leaves is body_leaves
1582             and leaf.type in CLOSING_BRACKETS
1583             and leaf.opening_bracket is matching_bracket
1584         ):
1585             current_leaves = tail_leaves if body_leaves else head_leaves
1586         current_leaves.append(leaf)
1587         if current_leaves is head_leaves:
1588             if leaf.type in OPENING_BRACKETS:
1589                 matching_bracket = leaf
1590                 current_leaves = body_leaves
1591     # Since body is a new indent level, remove spurious leading whitespace.
1592     if body_leaves:
1593         normalize_prefix(body_leaves[0], inside_brackets=True)
1594     # Build the new lines.
1595     for result, leaves in (
1596         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1597     ):
1598         for leaf in leaves:
1599             result.append(leaf, preformatted=True)
1600             for comment_after in line.comments_after(leaf):
1601                 result.append(comment_after, preformatted=True)
1602     bracket_split_succeeded_or_raise(head, body, tail)
1603     for result in (head, body, tail):
1604         if result:
1605             yield result
1606
1607
1608 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1609     """Split line into many lines, starting with the last matching bracket pair."""
1610     head = Line(depth=line.depth)
1611     body = Line(depth=line.depth + 1, inside_brackets=True)
1612     tail = Line(depth=line.depth)
1613     tail_leaves: List[Leaf] = []
1614     body_leaves: List[Leaf] = []
1615     head_leaves: List[Leaf] = []
1616     current_leaves = tail_leaves
1617     opening_bracket = None
1618     for leaf in reversed(line.leaves):
1619         if current_leaves is body_leaves:
1620             if leaf is opening_bracket:
1621                 current_leaves = head_leaves if body_leaves else tail_leaves
1622         current_leaves.append(leaf)
1623         if current_leaves is tail_leaves:
1624             if leaf.type in CLOSING_BRACKETS:
1625                 opening_bracket = leaf.opening_bracket
1626                 current_leaves = body_leaves
1627     tail_leaves.reverse()
1628     body_leaves.reverse()
1629     head_leaves.reverse()
1630     # Since body is a new indent level, remove spurious leading whitespace.
1631     if body_leaves:
1632         normalize_prefix(body_leaves[0], inside_brackets=True)
1633     # Build the new lines.
1634     for result, leaves in (
1635         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1636     ):
1637         for leaf in leaves:
1638             result.append(leaf, preformatted=True)
1639             for comment_after in line.comments_after(leaf):
1640                 result.append(comment_after, preformatted=True)
1641     bracket_split_succeeded_or_raise(head, body, tail)
1642     for result in (head, body, tail):
1643         if result:
1644             yield result
1645
1646
1647 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1648     """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1649
1650     Do nothing otherwise.
1651
1652     A left- or right-hand split is based on a pair of brackets. Content before
1653     (and including) the opening bracket is left on one line, content inside the
1654     brackets is put on a separate line, and finally content starting with and
1655     following the closing bracket is put on a separate line.
1656
1657     Those are called `head`, `body`, and `tail`, respectively. If the split
1658     produced the same line (all content in `head`) or ended up with an empty `body`
1659     and the `tail` is just the closing bracket, then it's considered failed.
1660     """
1661     tail_len = len(str(tail).strip())
1662     if not body:
1663         if tail_len == 0:
1664             raise CannotSplit("Splitting brackets produced the same line")
1665
1666         elif tail_len < 3:
1667             raise CannotSplit(
1668                 f"Splitting brackets on an empty body to save "
1669                 f"{tail_len} characters is not worth it"
1670             )
1671
1672
1673 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1674     """Normalize prefix of the first leaf in every line returned by `split_func`.
1675
1676     This is a decorator over relevant split functions.
1677     """
1678
1679     @wraps(split_func)
1680     def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1681         for l in split_func(line, py36):
1682             normalize_prefix(l.leaves[0], inside_brackets=True)
1683             yield l
1684
1685     return split_wrapper
1686
1687
1688 @dont_increase_indentation
1689 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1690     """Split according to delimiters of the highest priority.
1691
1692     If `py36` is True, the split will add trailing commas also in function
1693     signatures that contain `*` and `**`.
1694     """
1695     try:
1696         last_leaf = line.leaves[-1]
1697     except IndexError:
1698         raise CannotSplit("Line empty")
1699
1700     delimiters = line.bracket_tracker.delimiters
1701     try:
1702         delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1703             exclude={id(last_leaf)}
1704         )
1705     except ValueError:
1706         raise CannotSplit("No delimiters found")
1707
1708     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1709     lowest_depth = sys.maxsize
1710     trailing_comma_safe = True
1711
1712     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1713         """Append `leaf` to current line or to new line if appending impossible."""
1714         nonlocal current_line
1715         try:
1716             current_line.append_safe(leaf, preformatted=True)
1717         except ValueError as ve:
1718             yield current_line
1719
1720             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1721             current_line.append(leaf)
1722
1723     for leaf in line.leaves:
1724         yield from append_to_line(leaf)
1725
1726         for comment_after in line.comments_after(leaf):
1727             yield from append_to_line(comment_after)
1728
1729         lowest_depth = min(lowest_depth, leaf.bracket_depth)
1730         if (
1731             leaf.bracket_depth == lowest_depth
1732             and leaf.type == token.STAR
1733             or leaf.type == token.DOUBLESTAR
1734         ):
1735             trailing_comma_safe = trailing_comma_safe and py36
1736         leaf_priority = delimiters.get(id(leaf))
1737         if leaf_priority == delimiter_priority:
1738             yield current_line
1739
1740             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1741     if current_line:
1742         if (
1743             delimiter_priority == COMMA_PRIORITY
1744             and current_line.leaves[-1].type != token.COMMA
1745             and trailing_comma_safe
1746         ):
1747             current_line.append(Leaf(token.COMMA, ","))
1748         yield current_line
1749
1750
1751 @dont_increase_indentation
1752 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1753     """Split standalone comments from the rest of the line."""
1754     for leaf in line.leaves:
1755         if leaf.type == STANDALONE_COMMENT:
1756             if leaf.bracket_depth == 0:
1757                 break
1758
1759     else:
1760         raise CannotSplit("Line does not have any standalone comments")
1761
1762     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1763
1764     def append_to_line(leaf: Leaf) -> Iterator[Line]:
1765         """Append `leaf` to current line or to new line if appending impossible."""
1766         nonlocal current_line
1767         try:
1768             current_line.append_safe(leaf, preformatted=True)
1769         except ValueError as ve:
1770             yield current_line
1771
1772             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1773             current_line.append(leaf)
1774
1775     for leaf in line.leaves:
1776         yield from append_to_line(leaf)
1777
1778         for comment_after in line.comments_after(leaf):
1779             yield from append_to_line(comment_after)
1780
1781     if current_line:
1782         yield current_line
1783
1784
1785 def is_import(leaf: Leaf) -> bool:
1786     """Return True if the given leaf starts an import statement."""
1787     p = leaf.parent
1788     t = leaf.type
1789     v = leaf.value
1790     return bool(
1791         t == token.NAME
1792         and (
1793             (v == "import" and p and p.type == syms.import_name)
1794             or (v == "from" and p and p.type == syms.import_from)
1795         )
1796     )
1797
1798
1799 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1800     """Leave existing extra newlines if not `inside_brackets`. Remove everything
1801     else.
1802
1803     Note: don't use backslashes for formatting or you'll lose your voting rights.
1804     """
1805     if not inside_brackets:
1806         spl = leaf.prefix.split("#")
1807         if "\\" not in spl[0]:
1808             nl_count = spl[-1].count("\n")
1809             if len(spl) > 1:
1810                 nl_count -= 1
1811             leaf.prefix = "\n" * nl_count
1812             return
1813
1814     leaf.prefix = ""
1815
1816
1817 def normalize_string_quotes(leaf: Leaf) -> None:
1818     """Prefer double quotes but only if it doesn't cause more escaping.
1819
1820     Adds or removes backslashes as appropriate. Doesn't parse and fix
1821     strings nested in f-strings (yet).
1822
1823     Note: Mutates its argument.
1824     """
1825     value = leaf.value.lstrip("furbFURB")
1826     if value[:3] == '"""':
1827         return
1828
1829     elif value[:3] == "'''":
1830         orig_quote = "'''"
1831         new_quote = '"""'
1832     elif value[0] == '"':
1833         orig_quote = '"'
1834         new_quote = "'"
1835     else:
1836         orig_quote = "'"
1837         new_quote = '"'
1838     first_quote_pos = leaf.value.find(orig_quote)
1839     if first_quote_pos == -1:
1840         return  # There's an internal error
1841
1842     body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
1843     new_body = body.replace(f"\\{orig_quote}", orig_quote).replace(
1844         new_quote, f"\\{new_quote}"
1845     )
1846     if new_quote == '"""' and new_body[-1] == '"':
1847         # edge case:
1848         new_body = new_body[:-1] + '\\"'
1849     orig_escape_count = body.count("\\")
1850     new_escape_count = new_body.count("\\")
1851     if new_escape_count > orig_escape_count:
1852         return  # Do not introduce more escaping
1853
1854     if new_escape_count == orig_escape_count and orig_quote == '"':
1855         return  # Prefer double quotes
1856
1857     prefix = leaf.value[:first_quote_pos]
1858     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
1859
1860
1861 def is_python36(node: Node) -> bool:
1862     """Return True if the current file is using Python 3.6+ features.
1863
1864     Currently looking for:
1865     - f-strings; and
1866     - trailing commas after * or ** in function signatures.
1867     """
1868     for n in node.pre_order():
1869         if n.type == token.STRING:
1870             value_head = n.value[:2]  # type: ignore
1871             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
1872                 return True
1873
1874         elif (
1875             n.type == syms.typedargslist
1876             and n.children
1877             and n.children[-1].type == token.COMMA
1878         ):
1879             for ch in n.children:
1880                 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1881                     return True
1882
1883     return False
1884
1885
1886 PYTHON_EXTENSIONS = {".py"}
1887 BLACKLISTED_DIRECTORIES = {
1888     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
1889 }
1890
1891
1892 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1893     """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
1894     and have one of the PYTHON_EXTENSIONS.
1895     """
1896     for child in path.iterdir():
1897         if child.is_dir():
1898             if child.name in BLACKLISTED_DIRECTORIES:
1899                 continue
1900
1901             yield from gen_python_files_in_dir(child)
1902
1903         elif child.suffix in PYTHON_EXTENSIONS:
1904             yield child
1905
1906
1907 @dataclass
1908 class Report:
1909     """Provides a reformatting counter. Can be rendered with `str(report)`."""
1910     check: bool = False
1911     change_count: int = 0
1912     same_count: int = 0
1913     failure_count: int = 0
1914
1915     def done(self, src: Path, changed: bool) -> None:
1916         """Increment the counter for successful reformatting. Write out a message."""
1917         if changed:
1918             reformatted = "would reformat" if self.check else "reformatted"
1919             out(f"{reformatted} {src}")
1920             self.change_count += 1
1921         else:
1922             out(f"{src} already well formatted, good job.", bold=False)
1923             self.same_count += 1
1924
1925     def failed(self, src: Path, message: str) -> None:
1926         """Increment the counter for failed reformatting. Write out a message."""
1927         err(f"error: cannot format {src}: {message}")
1928         self.failure_count += 1
1929
1930     @property
1931     def return_code(self) -> int:
1932         """Return the exit code that the app should use.
1933
1934         This considers the current state of changed files and failures:
1935         - if there were any failures, return 123;
1936         - if any files were changed and --check is being used, return 1;
1937         - otherwise return 0.
1938         """
1939         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
1940         # 126 we have special returncodes reserved by the shell.
1941         if self.failure_count:
1942             return 123
1943
1944         elif self.change_count and self.check:
1945             return 1
1946
1947         return 0
1948
1949     def __str__(self) -> str:
1950         """Render a color report of the current state.
1951
1952         Use `click.unstyle` to remove colors.
1953         """
1954         if self.check:
1955             reformatted = "would be reformatted"
1956             unchanged = "would be left unchanged"
1957             failed = "would fail to reformat"
1958         else:
1959             reformatted = "reformatted"
1960             unchanged = "left unchanged"
1961             failed = "failed to reformat"
1962         report = []
1963         if self.change_count:
1964             s = "s" if self.change_count > 1 else ""
1965             report.append(
1966                 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
1967             )
1968         if self.same_count:
1969             s = "s" if self.same_count > 1 else ""
1970             report.append(f"{self.same_count} file{s} {unchanged}")
1971         if self.failure_count:
1972             s = "s" if self.failure_count > 1 else ""
1973             report.append(
1974                 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
1975             )
1976         return ", ".join(report) + "."
1977
1978
1979 def assert_equivalent(src: str, dst: str) -> None:
1980     """Raise AssertionError if `src` and `dst` aren't equivalent."""
1981
1982     import ast
1983     import traceback
1984
1985     def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
1986         """Simple visitor generating strings to compare ASTs by content."""
1987         yield f"{'  ' * depth}{node.__class__.__name__}("
1988
1989         for field in sorted(node._fields):
1990             try:
1991                 value = getattr(node, field)
1992             except AttributeError:
1993                 continue
1994
1995             yield f"{'  ' * (depth+1)}{field}="
1996
1997             if isinstance(value, list):
1998                 for item in value:
1999                     if isinstance(item, ast.AST):
2000                         yield from _v(item, depth + 2)
2001
2002             elif isinstance(value, ast.AST):
2003                 yield from _v(value, depth + 2)
2004
2005             else:
2006                 yield f"{'  ' * (depth+2)}{value!r},  # {value.__class__.__name__}"
2007
2008         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
2009
2010     try:
2011         src_ast = ast.parse(src)
2012     except Exception as exc:
2013         major, minor = sys.version_info[:2]
2014         raise AssertionError(
2015             f"cannot use --safe with this file; failed to parse source file "
2016             f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2017             f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2018         )
2019
2020     try:
2021         dst_ast = ast.parse(dst)
2022     except Exception as exc:
2023         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2024         raise AssertionError(
2025             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2026             f"Please report a bug on https://github.com/ambv/black/issues.  "
2027             f"This invalid output might be helpful: {log}"
2028         ) from None
2029
2030     src_ast_str = "\n".join(_v(src_ast))
2031     dst_ast_str = "\n".join(_v(dst_ast))
2032     if src_ast_str != dst_ast_str:
2033         log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2034         raise AssertionError(
2035             f"INTERNAL ERROR: Black produced code that is not equivalent to "
2036             f"the source.  "
2037             f"Please report a bug on https://github.com/ambv/black/issues.  "
2038             f"This diff might be helpful: {log}"
2039         ) from None
2040
2041
2042 def assert_stable(src: str, dst: str, line_length: int) -> None:
2043     """Raise AssertionError if `dst` reformats differently the second time."""
2044     newdst = format_str(dst, line_length=line_length)
2045     if dst != newdst:
2046         log = dump_to_file(
2047             diff(src, dst, "source", "first pass"),
2048             diff(dst, newdst, "first pass", "second pass"),
2049         )
2050         raise AssertionError(
2051             f"INTERNAL ERROR: Black produced different code on the second pass "
2052             f"of the formatter.  "
2053             f"Please report a bug on https://github.com/ambv/black/issues.  "
2054             f"This diff might be helpful: {log}"
2055         ) from None
2056
2057
2058 def dump_to_file(*output: str) -> str:
2059     """Dump `output` to a temporary file. Return path to the file."""
2060     import tempfile
2061
2062     with tempfile.NamedTemporaryFile(
2063         mode="w", prefix="blk_", suffix=".log", delete=False
2064     ) as f:
2065         for lines in output:
2066             f.write(lines)
2067             f.write("\n")
2068     return f.name
2069
2070
2071 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2072     """Return a unified diff string between strings `a` and `b`."""
2073     import difflib
2074
2075     a_lines = [line + "\n" for line in a.split("\n")]
2076     b_lines = [line + "\n" for line in b.split("\n")]
2077     return "".join(
2078         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2079     )
2080
2081
2082 def cancel(tasks: List[asyncio.Task]) -> None:
2083     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2084     err("Aborted!")
2085     for task in tasks:
2086         task.cancel()
2087
2088
2089 def shutdown(loop: BaseEventLoop) -> None:
2090     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2091     try:
2092         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2093         to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2094         if not to_cancel:
2095             return
2096
2097         for task in to_cancel:
2098             task.cancel()
2099         loop.run_until_complete(
2100             asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2101         )
2102     finally:
2103         # `concurrent.futures.Future` objects cannot be cancelled once they
2104         # are already running. There might be some when the `shutdown()` happened.
2105         # Silence their logger's spew about the event loop being closed.
2106         cf_logger = logging.getLogger("concurrent.futures")
2107         cf_logger.setLevel(logging.CRITICAL)
2108         loop.close()
2109
2110
2111 if __name__ == "__main__":
2112     main()