4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
7 from functools import partial, wraps
10 from multiprocessing import Manager
12 from pathlib import Path
32 from attr import dataclass, Factory
36 from blib2to3.pytree import Node, Leaf, type_repr
37 from blib2to3 import pygram, pytree
38 from blib2to3.pgen2 import driver, token
39 from blib2to3.pgen2.parse import ParseError
41 __version__ = "18.4a0"
42 DEFAULT_LINE_LENGTH = 88
44 syms = pygram.python_symbols
52 LN = Union[Leaf, Node]
53 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
54 out = partial(click.secho, bold=True, err=True)
55 err = partial(click.secho, fg="red", err=True)
58 class NothingChanged(UserWarning):
59 """Raised by :func:`format_file` when reformatted code is the same as source."""
62 class CannotSplit(Exception):
63 """A readable split that fits the allotted line length is impossible.
65 Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
66 :func:`delimiter_split`.
70 class FormatError(Exception):
71 """Base exception for `# fmt: on` and `# fmt: off` handling.
73 It holds the number of bytes of the prefix consumed before the format
74 control comment appeared.
77 def __init__(self, consumed: int) -> None:
78 super().__init__(consumed)
79 self.consumed = consumed
81 def trim_prefix(self, leaf: Leaf) -> None:
82 leaf.prefix = leaf.prefix[self.consumed:]
84 def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
85 """Returns a new Leaf from the consumed part of the prefix."""
86 unformatted_prefix = leaf.prefix[:self.consumed]
87 return Leaf(token.NEWLINE, unformatted_prefix)
90 class FormatOn(FormatError):
91 """Found a comment like `# fmt: on` in the file."""
94 class FormatOff(FormatError):
95 """Found a comment like `# fmt: off` in the file."""
98 class WriteBack(Enum):
109 default=DEFAULT_LINE_LENGTH,
110 help="How many characters per line to allow.",
117 "Don't write the files back, just return the status. Return code 0 "
118 "means nothing would change. Return code 1 means some files would be "
119 "reformatted. Return code 123 means there was an internal error."
125 help="Don't write the files back, just output a diff for each file on stdout.",
130 help="If --fast given, skip temporary sanity checks. [default: --safe]",
132 @click.version_option(version=__version__)
137 exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
149 """The uncompromising code formatter."""
150 sources: List[Path] = []
154 sources.extend(gen_python_files_in_dir(p))
156 # if a file was explicitly given, we don't care about its extension
159 sources.append(Path("-"))
161 err(f"invalid path: {s}")
163 exc = click.ClickException("Options --check and --diff are mutually exclusive")
168 write_back = WriteBack.NO
170 write_back = WriteBack.DIFF
172 write_back = WriteBack.YES
173 if len(sources) == 0:
175 elif len(sources) == 1:
177 report = Report(check=check)
179 if not p.is_file() and str(p) == "-":
180 changed = format_stdin_to_stdout(
181 line_length=line_length, fast=fast, write_back=write_back
184 changed = format_file_in_place(
185 p, line_length=line_length, fast=fast, write_back=write_back
187 report.done(p, changed)
188 except Exception as exc:
189 report.failed(p, str(exc))
190 ctx.exit(report.return_code)
192 loop = asyncio.get_event_loop()
193 executor = ProcessPoolExecutor(max_workers=os.cpu_count())
196 return_code = loop.run_until_complete(
198 sources, line_length, write_back, fast, loop, executor
203 ctx.exit(return_code)
206 async def schedule_formatting(
209 write_back: WriteBack,
214 """Run formatting of `sources` in parallel using the provided `executor`.
216 (Use ProcessPoolExecutors for actual parallelism.)
218 `line_length`, `write_back`, and `fast` options are passed to
219 :func:`format_file_in_place`.
222 if write_back == WriteBack.DIFF:
223 # For diff output, we need locks to ensure we don't interleave output
224 # from different processes.
226 lock = manager.Lock()
228 src: loop.run_in_executor(
229 executor, format_file_in_place, src, line_length, fast, write_back, lock
233 _task_values = list(tasks.values())
234 loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
235 loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
236 await asyncio.wait(tasks.values())
238 report = Report(check=not write_back)
239 for src, task in tasks.items():
241 report.failed(src, "timed out, cancelling")
243 cancelled.append(task)
244 elif task.cancelled():
245 cancelled.append(task)
246 elif task.exception():
247 report.failed(src, str(task.exception()))
249 report.done(src, task.result())
251 await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
253 out("All done! ✨ 🍰 ✨")
254 click.echo(str(report))
255 return report.return_code
258 def format_file_in_place(
262 write_back: WriteBack = WriteBack.NO,
263 lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
265 """Format file under `src` path. Return True if changed.
267 If `write_back` is YES, write reformatted code back to the file; if DIFF, write a diff to stdout.
268 `line_length` and `fast` options are passed to :func:`format_file_contents`.
270 with tokenize.open(src) as src_buffer:
271 src_contents = src_buffer.read()
273 dst_contents = format_file_contents(
274 src_contents, line_length=line_length, fast=fast
276 except NothingChanged:
279 if write_back == WriteBack.YES:
280 with open(src, "w", encoding=src_buffer.encoding) as f:
281 f.write(dst_contents)
282 elif write_back == WriteBack.DIFF:
283 src_name = f"{src.name} (original)"
284 dst_name = f"{src.name} (formatted)"
285 diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
289 sys.stdout.write(diff_contents)
296 def format_stdin_to_stdout(
297 line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
299 """Format file on stdin. Return True if changed.
301 If `write_back` is YES, write reformatted code to stdout; if DIFF, write a diff to stdout.
302 `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
304 src = sys.stdin.read()
306 dst = format_file_contents(src, line_length=line_length, fast=fast)
309 except NothingChanged:
314 if write_back == WriteBack.YES:
315 sys.stdout.write(dst)
316 elif write_back == WriteBack.DIFF:
317 src_name = "<stdin> (original)"
318 dst_name = "<stdin> (formatted)"
319 sys.stdout.write(diff(src, dst, src_name, dst_name))
322 def format_file_contents(
323 src_contents: str, line_length: int, fast: bool
325 """Reformat contents a file and return new contents.
327 If `fast` is False, additionally confirm that the reformatted code is
328 valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
329 `line_length` is passed to :func:`format_str`.
331 if src_contents.strip() == "":
334 dst_contents = format_str(src_contents, line_length=line_length)
335 if src_contents == dst_contents:
339 assert_equivalent(src_contents, dst_contents)
340 assert_stable(src_contents, dst_contents, line_length=line_length)
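A minimal usage sketch for the function above (not part of the module), assuming it is importable as `black`; `NothingChanged` is raised when the input is already formatted, so callers typically catch it.

    import black

    src = "x = {  'a':37,'b':42}\n"
    try:
        dst = black.format_file_contents(src, line_length=88, fast=False)
    except black.NothingChanged:
        dst = src
    print(dst, end="")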
344 def format_str(src_contents: str, line_length: int) -> FileContent:
345 """Reformat a string and return new contents.
347 `line_length` determines how many characters per line are allowed.
349 src_node = lib2to3_parse(src_contents)
351 lines = LineGenerator()
352 elt = EmptyLineTracker()
353 py36 = is_python36(src_node)
356 for current_line in lines.visit(src_node):
357 for _ in range(after):
358 dst_contents += str(empty_line)
359 before, after = elt.maybe_empty_lines(current_line)
360 for _ in range(before):
361 dst_contents += str(empty_line)
362 for line in split_line(current_line, line_length=line_length, py36=py36):
363 dst_contents += str(line)
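A hedged before/after illustration of `format_str`, based on the whitespace and string-quote rules implemented elsewhere in this module; assumes the module is importable as `black`.

    import black

    print(black.format_str("l = [1,2,3]\nd={'a':1}\n", line_length=88))
    # expected output:
    # l = [1, 2, 3]
    # d = {"a": 1}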
368 pygram.python_grammar_no_print_statement_no_exec_statement,
369 pygram.python_grammar_no_print_statement,
370 pygram.python_grammar_no_exec_statement,
371 pygram.python_grammar,
375 def lib2to3_parse(src_txt: str) -> Node:
376 """Given a string with source, return the lib2to3 Node."""
377 grammar = pygram.python_grammar_no_print_statement
378 if src_txt[-1] != "\n":
379 nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
381 for grammar in GRAMMARS:
382 drv = driver.Driver(grammar, pytree.convert)
384 result = drv.parse_string(src_txt, True)
387 except ParseError as pe:
388 lineno, column = pe.context[1]
389 lines = src_txt.splitlines()
391 faulty_line = lines[lineno - 1]
393 faulty_line = "<line number missing in source>"
394 exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
398 if isinstance(result, Leaf):
399 result = Node(syms.file_input, [result])
403 def lib2to3_unparse(node: Node) -> str:
404 """Given a lib2to3 node, return its string representation."""
412 class Visitor(Generic[T]):
413 """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
415 def visit(self, node: LN) -> Iterator[T]:
416 """Main method to visit `node` and its children.
418 It tries to find a `visit_*()` method for the given `node.type`, like
419 `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
420 If no dedicated `visit_*()` method is found, chooses `visit_default()`
423 Then yields objects of type `T` from the selected visitor.
426 name = token.tok_name[node.type]
428 name = type_repr(node.type)
429 yield from getattr(self, f"visit_{name}", self.visit_default)(node)
431 def visit_default(self, node: LN) -> Iterator[T]:
432 """Default `visit_*()` implementation. Recurses to children of `node`."""
433 if isinstance(node, Node):
434 for child in node.children:
435 yield from self.visit(child)
439 class DebugVisitor(Visitor[T]):
442 def visit_default(self, node: LN) -> Iterator[T]:
443 indent = " " * (2 * self.tree_depth)
444 if isinstance(node, Node):
445 _type = type_repr(node.type)
446 out(f"{indent}{_type}", fg="yellow")
448 for child in node.children:
449 yield from self.visit(child)
452 out(f"{indent}/{_type}", fg="yellow", bold=False)
454 _type = token.tok_name.get(node.type, str(node.type))
455 out(f"{indent}{_type}", fg="blue", nl=False)
457 # We don't have to handle prefixes for `Node` objects since
458 # that delegates to the first child anyway.
459 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
460 out(f" {node.value!r}", fg="blue", bold=False)
463 def show(cls, code: str) -> None:
464 """Pretty-print the lib2to3 AST of a given string of `code`.
466 Convenience method for debugging.
468 v: DebugVisitor[None] = DebugVisitor()
469 list(v.visit(lib2to3_parse(code)))
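A quick debugging sketch using the convenience method above, assuming the module is importable as `black`; the rendering shown in the comment is indicative only.

    import black

    black.DebugVisitor.show("x = 1")
    # prints a colorized tree to stderr, roughly:
    #   file_input
    #     simple_stmt
    #       expr_stmt
    #         NAME ... 'x'
    #         EQUAL ... '='
    #         NUMBER ... '1'
    #       ...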
472 KEYWORDS = set(keyword.kwlist)
473 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
474 FLOW_CONTROL = {"return", "raise", "break", "continue"}
485 STANDALONE_COMMENT = 153
486 LOGIC_OPERATORS = {"and", "or"}
510 VARARGS = {token.STAR, token.DOUBLESTAR}
511 COMPREHENSION_PRIORITY = 20
515 COMPARATOR_PRIORITY = 3
520 class BracketTracker:
521 """Keeps track of brackets on a line."""
524 bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
525 delimiters: Dict[LeafID, Priority] = Factory(dict)
526 previous: Optional[Leaf] = None
528 def mark(self, leaf: Leaf) -> None:
529 """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
531 All leaves receive an int `bracket_depth` field that stores how deep
532 within brackets a given leaf is. 0 means there are no enclosing brackets
533 that started on this line.
535 If a leaf is itself a closing bracket, it receives an `opening_bracket`
536 field that it forms a pair with. This is a one-directional link to
537 avoid reference cycles.
539 If a leaf is a delimiter (a token on which Black can split the line if
540 needed) and it's on depth 0, its `id()` is stored in the tracker's
543 if leaf.type == token.COMMENT:
546 if leaf.type in CLOSING_BRACKETS:
548 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
549 leaf.opening_bracket = opening_bracket
550 leaf.bracket_depth = self.depth
552 after_delim = is_split_after_delimiter(leaf, self.previous)
553 before_delim = is_split_before_delimiter(leaf, self.previous)
554 if after_delim > before_delim:
555 self.delimiters[id(leaf)] = after_delim
556 elif before_delim > after_delim and self.previous is not None:
557 self.delimiters[id(self.previous)] = before_delim
558 if leaf.type in OPENING_BRACKETS:
559 self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
563 def any_open_brackets(self) -> bool:
564 """Return True if there is an yet unmatched open bracket on the line."""
565 return bool(self.bracket_match)
567 def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
568 """Return the highest priority of a delimiter found on the line.
570 Values are consistent with what `is_delimiter()` returns.
572 return max(v for k, v in self.delimiters.items() if k not in exclude)
577 """Holds leaves and comments. Can be printed with `str(line)`."""
580 leaves: List[Leaf] = Factory(list)
581 comments: List[Tuple[Index, Leaf]] = Factory(list)
582 bracket_tracker: BracketTracker = Factory(BracketTracker)
583 inside_brackets: bool = False
584 has_for: bool = False
585 _for_loop_variable: bool = False
587 def append(self, leaf: Leaf, preformatted: bool = False) -> None:
588 """Add a new `leaf` to the end of the line.
590 Unless `preformatted` is True, the `leaf` will receive a new consistent
591 whitespace prefix and metadata applied by :class:`BracketTracker`.
592 Trailing commas may be removed; unpacked for-loop variables are
593 demoted from being delimiters.
595 Inline comments are put aside.
597 has_value = leaf.value.strip()
601 if self.leaves and not preformatted:
602 # Note: at this point leaf.prefix should be empty except for
603 # imports, for which we only preserve newlines.
604 leaf.prefix += whitespace(leaf)
605 if self.inside_brackets or not preformatted:
606 self.maybe_decrement_after_for_loop_variable(leaf)
607 self.bracket_tracker.mark(leaf)
608 self.maybe_remove_trailing_comma(leaf)
609 self.maybe_increment_for_loop_variable(leaf)
611 if not self.append_comment(leaf):
612 self.leaves.append(leaf)
614 def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
615 """Like :func:`append()` but disallow invalid standalone comment structure.
617 Raises ValueError when any `leaf` is appended after a standalone comment
618 or when a standalone comment is not the first leaf on the line.
620 if self.bracket_tracker.depth == 0:
622 raise ValueError("cannot append to standalone comments")
624 if self.leaves and leaf.type == STANDALONE_COMMENT:
626 "cannot append standalone comments to a populated line"
629 self.append(leaf, preformatted=preformatted)
632 def is_comment(self) -> bool:
633 """Is this line a standalone comment?"""
634 return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
637 def is_decorator(self) -> bool:
638 """Is this line a decorator?"""
639 return bool(self) and self.leaves[0].type == token.AT
642 def is_import(self) -> bool:
643 """Is this an import line?"""
644 return bool(self) and is_import(self.leaves[0])
647 def is_class(self) -> bool:
648 """Is this line a class definition?"""
651 and self.leaves[0].type == token.NAME
652 and self.leaves[0].value == "class"
656 def is_def(self) -> bool:
657 """Is this a function definition? (Also returns True for async defs.)"""
659 first_leaf = self.leaves[0]
664 second_leaf: Optional[Leaf] = self.leaves[1]
668 (first_leaf.type == token.NAME and first_leaf.value == "def")
670 first_leaf.type == token.ASYNC
671 and second_leaf is not None
672 and second_leaf.type == token.NAME
673 and second_leaf.value == "def"
678 def is_flow_control(self) -> bool:
679 """Is this line a flow control statement?
681 Those are `return`, `raise`, `break`, and `continue`.
685 and self.leaves[0].type == token.NAME
686 and self.leaves[0].value in FLOW_CONTROL
690 def is_yield(self) -> bool:
691 """Is this line a yield statement?"""
694 and self.leaves[0].type == token.NAME
695 and self.leaves[0].value == "yield"
699 def contains_standalone_comments(self) -> bool:
700 """If so, needs to be split before emitting."""
701 for leaf in self.leaves:
702 if leaf.type == STANDALONE_COMMENT:
707 def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
708 """Remove trailing comma if there is one and it's safe."""
711 and self.leaves[-1].type == token.COMMA
712 and closing.type in CLOSING_BRACKETS
716 if closing.type == token.RBRACE:
717 self.remove_trailing_comma()
720 if closing.type == token.RSQB:
721 comma = self.leaves[-1]
722 if comma.parent and comma.parent.type == syms.listmaker:
723 self.remove_trailing_comma()
726 # For parens let's check if it's safe to remove the comma. If the
727 # trailing one is the only one, we might mistakenly change a tuple
728 # into a different type by removing the comma.
729 depth = closing.bracket_depth + 1
731 opening = closing.opening_bracket
732 for _opening_index, leaf in enumerate(self.leaves):
739 for leaf in self.leaves[_opening_index + 1:]:
743 bracket_depth = leaf.bracket_depth
744 if bracket_depth == depth and leaf.type == token.COMMA:
746 if leaf.parent and leaf.parent.type == syms.arglist:
751 self.remove_trailing_comma()
756 def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
757 """In a for loop, or comprehension, the variables are often unpacks.
759 To avoid splitting on the comma in this situation, increase the depth of
760 tokens between `for` and `in`.
762 if leaf.type == token.NAME and leaf.value == "for":
764 self.bracket_tracker.depth += 1
765 self._for_loop_variable = True
770 def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
771 """See `maybe_increment_for_loop_variable` above for explanation."""
772 if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
773 self.bracket_tracker.depth -= 1
774 self._for_loop_variable = False
779 def append_comment(self, comment: Leaf) -> bool:
780 """Add an inline or standalone comment to the line."""
782 comment.type == STANDALONE_COMMENT
783 and self.bracket_tracker.any_open_brackets()
788 if comment.type != token.COMMENT:
791 after = len(self.leaves) - 1
793 comment.type = STANDALONE_COMMENT
798 self.comments.append((after, comment))
801 def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
802 """Generate comments that should appear directly after `leaf`."""
803 for _leaf_index, _leaf in enumerate(self.leaves):
810 for index, comment_after in self.comments:
811 if _leaf_index == index:
814 def remove_trailing_comma(self) -> None:
815 """Remove the trailing comma and moves the comments attached to it."""
816 comma_index = len(self.leaves) - 1
817 for i in range(len(self.comments)):
818 comment_index, comment = self.comments[i]
819 if comment_index == comma_index:
820 self.comments[i] = (comma_index - 1, comment)
823 def __str__(self) -> str:
824 """Render the line."""
828 indent = " " * self.depth
829 leaves = iter(self.leaves)
831 res = f"{first.prefix}{indent}{first.value}"
834 for _, comment in self.comments:
838 def __bool__(self) -> bool:
839 """Return True if the line has leaves or comments."""
840 return bool(self.leaves or self.comments)
843 class UnformattedLines(Line):
844 """Just like :class:`Line` but stores lines which aren't reformatted."""
846 def append(self, leaf: Leaf, preformatted: bool = True) -> None:
847 """Just add a new `leaf` to the end of the lines.
849 The `preformatted` argument is ignored.
851 Keeps track of indentation `depth`, which is useful when the user
852 says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
855 list(generate_comments(leaf))
856 except FormatOn as f_on:
857 self.leaves.append(f_on.leaf_from_consumed(leaf))
860 self.leaves.append(leaf)
861 if leaf.type == token.INDENT:
863 elif leaf.type == token.DEDENT:
866 def __str__(self) -> str:
867 """Render unformatted lines from leaves which were added with `append()`.
869 `depth` is not used for indentation in this case.
875 for leaf in self.leaves:
879 def append_comment(self, comment: Leaf) -> bool:
880 """Not implemented in this class. Raises `NotImplementedError`."""
881 raise NotImplementedError("Unformatted lines don't store comments separately.")
883 def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
884 """Does nothing and returns False."""
887 def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
888 """Does nothing and returns False."""
893 class EmptyLineTracker:
894 """Provides a stateful method that returns the number of potential extra
895 empty lines needed before and after the currently processed line.
897 Note: this tracker works on lines that haven't been split yet. It assumes
898 the prefix of the first leaf consists of optional newlines. Those newlines
899 are consumed by `maybe_empty_lines()` and included in the computation.
901 previous_line: Optional[Line] = None
902 previous_after: int = 0
903 previous_defs: List[int] = Factory(list)
905 def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
906 """Return the number of extra empty lines before and after the `current_line`.
908 This is for separating `def`, `async def` and `class` with extra empty
909 lines (two on module-level), as well as providing an extra empty line
910 after flow control keywords to make them more prominent.
912 if isinstance(current_line, UnformattedLines):
915 before, after = self._maybe_empty_lines(current_line)
916 before -= self.previous_after
917 self.previous_after = after
918 self.previous_line = current_line
921 def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
923 if current_line.depth == 0:
925 if current_line.leaves:
926 # Consume the first leaf's extra newlines.
927 first_leaf = current_line.leaves[0]
928 before = first_leaf.prefix.count("\n")
929 before = min(before, max_allowed)
930 first_leaf.prefix = ""
933 depth = current_line.depth
934 while self.previous_defs and self.previous_defs[-1] >= depth:
935 self.previous_defs.pop()
936 before = 1 if depth else 2
937 is_decorator = current_line.is_decorator
938 if is_decorator or current_line.is_def or current_line.is_class:
940 self.previous_defs.append(depth)
941 if self.previous_line is None:
942 # Don't insert empty lines before the first line in the file.
945 if self.previous_line and self.previous_line.is_decorator:
946 # Don't insert empty lines between decorators.
950 if current_line.depth:
954 if current_line.is_flow_control:
959 and self.previous_line.is_import
960 and not current_line.is_import
961 and depth == self.previous_line.depth
963 return (before or 1), 0
967 and self.previous_line.is_yield
968 and (not current_line.is_yield or depth != self.previous_line.depth)
970 return (before or 1), 0
976 class LineGenerator(Visitor[Line]):
977 """Generates reformatted Line objects. Empty lines are not emitted.
979 Note: destroys the tree it's visiting by mutating prefixes of its leaves
980 in ways that will no longer stringify to valid Python code on the tree.
982 current_line: Line = Factory(Line)
984 def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
987 If the line is empty, only emit if it makes sense.
988 If the line is too long, split it first and then generate.
990 If any lines were generated, set up a new current_line.
992 if not self.current_line:
993 if self.current_line.__class__ == type:
994 self.current_line.depth += indent
996 self.current_line = type(depth=self.current_line.depth + indent)
997 return # Line is empty, don't emit. Creating a new one unnecessary.
999 complete_line = self.current_line
1000 self.current_line = type(depth=complete_line.depth + indent)
1003 def visit(self, node: LN) -> Iterator[Line]:
1004 """Main method to visit `node` and its children.
1006 Yields :class:`Line` objects.
1008 if isinstance(self.current_line, UnformattedLines):
1009 # File contained `# fmt: off`
1010 yield from self.visit_unformatted(node)
1013 yield from super().visit(node)
1015 def visit_default(self, node: LN) -> Iterator[Line]:
1016 """Default `visit_*()` implementation. Recurses to children of `node`."""
1017 if isinstance(node, Leaf):
1018 any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1020 for comment in generate_comments(node):
1021 if any_open_brackets:
1022 # any comment within brackets is subject to splitting
1023 self.current_line.append(comment)
1024 elif comment.type == token.COMMENT:
1025 # regular trailing comment
1026 self.current_line.append(comment)
1027 yield from self.line()
1030 # regular standalone comment
1031 yield from self.line()
1033 self.current_line.append(comment)
1034 yield from self.line()
1036 except FormatOff as f_off:
1037 f_off.trim_prefix(node)
1038 yield from self.line(type=UnformattedLines)
1039 yield from self.visit(node)
1041 except FormatOn as f_on:
1042 # This only happens here if somebody says "fmt: on" multiple
1044 f_on.trim_prefix(node)
1045 yield from self.visit_default(node)
1048 normalize_prefix(node, inside_brackets=any_open_brackets)
1049 if node.type == token.STRING:
1050 normalize_string_quotes(node)
1051 if node.type not in WHITESPACE:
1052 self.current_line.append(node)
1053 yield from super().visit_default(node)
1055 def visit_INDENT(self, node: Node) -> Iterator[Line]:
1056 """Increase indentation level, maybe yield a line."""
1057 # In blib2to3 INDENT never holds comments.
1058 yield from self.line(+1)
1059 yield from self.visit_default(node)
1061 def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1062 """Decrease indentation level, maybe yield a line."""
1063 # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
1064 yield from self.line(-1)
1066 def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
1067 """Visit a statement.
1069 This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1070 `def`, `with`, and `class`.
1072 The relevant Python language `keywords` for a given statement will be NAME
1073 leaves within it. This method puts those on a separate line.
1075 for child in node.children:
1076 if child.type == token.NAME and child.value in keywords: # type: ignore
1077 yield from self.line()
1079 yield from self.visit(child)
1081 def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1082 """Visit a statement without nested statements."""
1083 is_suite_like = node.parent and node.parent.type in STATEMENT
1085 yield from self.line(+1)
1086 yield from self.visit_default(node)
1087 yield from self.line(-1)
1090 yield from self.line()
1091 yield from self.visit_default(node)
1093 def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1094 """Visit `async def`, `async for`, `async with`."""
1095 yield from self.line()
1097 children = iter(node.children)
1098 for child in children:
1099 yield from self.visit(child)
1101 if child.type == token.ASYNC:
1104 internal_stmt = next(children)
1105 for child in internal_stmt.children:
1106 yield from self.visit(child)
1108 def visit_decorators(self, node: Node) -> Iterator[Line]:
1109 """Visit decorators."""
1110 for child in node.children:
1111 yield from self.line()
1112 yield from self.visit(child)
1114 def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
1115 """Remove a semicolon and put the other statement on a separate line."""
1116 yield from self.line()
1118 def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
1119 """End of file. Process outstanding comments and end with a newline."""
1120 yield from self.visit_default(leaf)
1121 yield from self.line()
1123 def visit_unformatted(self, node: LN) -> Iterator[Line]:
1124 """Used when file contained a `# fmt: off`."""
1125 if isinstance(node, Node):
1126 for child in node.children:
1127 yield from self.visit(child)
1131 self.current_line.append(node)
1132 except FormatOn as f_on:
1133 f_on.trim_prefix(node)
1134 yield from self.line()
1135 yield from self.visit(node)
1137 if node.type == token.ENDMARKER:
1138 # somebody decided not to put a final `# fmt: on`
1139 yield from self.line()
1141 def __attrs_post_init__(self) -> None:
1142 """You are in a twisty little maze of passages."""
1144 self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"})
1145 self.visit_while_stmt = partial(v, keywords={"while", "else"})
1146 self.visit_for_stmt = partial(v, keywords={"for", "else"})
1147 self.visit_try_stmt = partial(v, keywords={"try", "except", "else", "finally"})
1148 self.visit_except_clause = partial(v, keywords={"except"})
1149 self.visit_funcdef = partial(v, keywords={"def"})
1150 self.visit_with_stmt = partial(v, keywords={"with"})
1151 self.visit_classdef = partial(v, keywords={"class"})
1152 self.visit_async_funcdef = self.visit_async_stmt
1153 self.visit_decorated = self.visit_decorators
1156 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1157 OPENING_BRACKETS = set(BRACKET.keys())
1158 CLOSING_BRACKETS = set(BRACKET.values())
1159 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1160 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1163 def whitespace(leaf: Leaf) -> str: # noqa C901
1164 """Return whitespace prefix if needed for the given `leaf`."""
1171 if t in ALWAYS_NO_SPACE:
1174 if t == token.COMMENT:
1177 assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1178 if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1181 prev = leaf.prev_sibling
1183 prevp = preceding_leaf(p)
1184 if not prevp or prevp.type in OPENING_BRACKETS:
1187 if t == token.COLON:
1188 return SPACE if prevp.type == token.COMMA else NO
1190 if prevp.type == token.EQUAL:
1192 if prevp.parent.type in {
1193 syms.arglist, syms.argument, syms.parameters, syms.varargslist
1197 elif prevp.parent.type == syms.typedargslist:
1198 # A bit hacky: if the equal sign has whitespace, it means we
1199 # previously found it's a typed argument. So, we're using
1203 elif prevp.type == token.DOUBLESTAR:
1204 if prevp.parent and prevp.parent.type in {
1214 elif prevp.type == token.COLON:
1215 if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1220 and prevp.parent.type in {syms.factor, syms.star_expr}
1221 and prevp.type in MATH_OPERATORS
1226 prevp.type == token.RIGHTSHIFT
1228 and prevp.parent.type == syms.shift_expr
1229 and prevp.prev_sibling
1230 and prevp.prev_sibling.type == token.NAME
1231 and prevp.prev_sibling.value == "print" # type: ignore
1233 # Python 2 print chevron
1236 elif prev.type in OPENING_BRACKETS:
1239 if p.type in {syms.parameters, syms.arglist}:
1240 # untyped function signatures or calls
1244 if not prev or prev.type != token.COMMA:
1247 elif p.type == syms.varargslist:
1252 if prev and prev.type != token.COMMA:
1255 elif p.type == syms.typedargslist:
1256 # typed function signatures
1260 if t == token.EQUAL:
1261 if prev.type != syms.tname:
1264 elif prev.type == token.EQUAL:
1265 # A bit hacky: if the equal sign has whitespace, it means we
1266 # previously found it's a typed argument. So, we're using that, too.
1269 elif prev.type != token.COMMA:
1272 elif p.type == syms.tname:
1275 prevp = preceding_leaf(p)
1276 if not prevp or prevp.type != token.COMMA:
1279 elif p.type == syms.trailer:
1280 # attributes and calls
1281 if t == token.LPAR or t == token.RPAR:
1286 prevp = preceding_leaf(p)
1287 if not prevp or prevp.type != token.NUMBER:
1290 elif t == token.LSQB:
1293 elif prev.type != token.COMMA:
1296 elif p.type == syms.argument:
1298 if t == token.EQUAL:
1302 prevp = preceding_leaf(p)
1303 if not prevp or prevp.type == token.LPAR:
1306 elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1309 elif p.type == syms.decorator:
1313 elif p.type == syms.dotted_name:
1317 prevp = preceding_leaf(p)
1318 if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1321 elif p.type == syms.classdef:
1325 if prev and prev.type == token.LPAR:
1328 elif p.type == syms.subscript:
1331 assert p.parent is not None, "subscripts are always parented"
1332 if p.parent.type == syms.subscriptlist:
1340 elif p.type == syms.atom:
1341 if prev and t == token.DOT:
1342 # dots, but not the first one.
1346 p.type == syms.listmaker
1347 or p.type == syms.testlist_gexp
1348 or p.type == syms.subscriptlist
1350 # list interior, including unpacking
1354 elif p.type == syms.dictsetmaker:
1355 # dict and set interior, including unpacking
1359 if prev.type == token.DOUBLESTAR:
1362 elif p.type in {syms.factor, syms.star_expr}:
1365 prevp = preceding_leaf(p)
1366 if not prevp or prevp.type in OPENING_BRACKETS:
1369 prevp_parent = prevp.parent
1370 assert prevp_parent is not None
1371 if prevp.type == token.COLON and prevp_parent.type in {
1372 syms.subscript, syms.sliceop
1376 elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1379 elif t == token.NAME or t == token.NUMBER:
1382 elif p.type == syms.import_from:
1384 if prev and prev.type == token.DOT:
1387 elif t == token.NAME:
1391 if prev and prev.type == token.DOT:
1394 elif p.type == syms.sliceop:
1400 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1401 """Return the first leaf that precedes `node`, if any."""
1403 res = node.prev_sibling
1405 if isinstance(res, Leaf):
1409 return list(res.leaves())[-1]
1418 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1419 """Return the priority of the `leaf` delimiter, given a line break after it.
1421 The delimiter priorities returned here are from those delimiters that would
1422 cause a line break after themselves.
1424 Higher numbers are higher priority.
1426 if leaf.type == token.COMMA:
1427 return COMMA_PRIORITY
1430 leaf.type in VARARGS
1432 and leaf.parent.type in {syms.argument, syms.typedargslist}
1434 return MATH_PRIORITY
1439 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1440 """Return the priority of the `leaf` delimiter, given a line before after it.
1442 The delimiter priorities returned here are from those delimiters that would
1443 cause a line break before themselves.
1445 Higher numbers are higher priority.
1448 leaf.type in MATH_OPERATORS
1450 and leaf.parent.type not in {syms.factor, syms.star_expr}
1452 return MATH_PRIORITY
1454 if leaf.type in COMPARATORS:
1455 return COMPARATOR_PRIORITY
1458 leaf.type == token.STRING
1459 and previous is not None
1460 and previous.type == token.STRING
1462 return STRING_PRIORITY
1465 leaf.type == token.NAME
1466 and leaf.value == "for"
1468 and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1470 return COMPREHENSION_PRIORITY
1473 leaf.type == token.NAME
1474 and leaf.value == "if"
1476 and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1478 return COMPREHENSION_PRIORITY
1480 if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1481 return LOGIC_PRIORITY
1486 def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1487 """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
1489 Higher numbers are higher priority.
1492 is_split_before_delimiter(leaf, previous),
1493 is_split_after_delimiter(leaf, previous),
1497 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1498 """Clean the prefix of the `leaf` and generate comments from it, if any.
1500 Comments in lib2to3 are shoved into the whitespace prefix. This happens
1501 in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation
1502 move because it does away with modifying the grammar to include all the
1503 possible places in which comments can be placed.
1505 The sad consequence for us though is that comments don't "belong" anywhere.
1506 This is why this function generates simple parentless Leaf objects for
1507 comments. We simply don't know what the correct parent should be.
1509 No matter though, we can live without this. We really only need to
1510 differentiate between inline and standalone comments. The latter don't
1511 share the line with any code.
1513 Inline comments are emitted as regular token.COMMENT leaves. Standalone
1514 are emitted with a fake STANDALONE_COMMENT token identifier.
1525 for index, line in enumerate(p.split("\n")):
1526 consumed += len(line) + 1 # adding the length of the split '\n'
1527 line = line.lstrip()
1530 if not line.startswith("#"):
1533 if index == 0 and leaf.type != token.ENDMARKER:
1534 comment_type = token.COMMENT # simple trailing comment
1536 comment_type = STANDALONE_COMMENT
1537 comment = make_comment(line)
1538 yield Leaf(comment_type, comment, prefix="\n" * nlines)
1540 if comment in {"# fmt: on", "# yapf: enable"}:
1541 raise FormatOn(consumed)
1543 if comment in {"# fmt: off", "# yapf: disable"}:
1544 if comment_type == STANDALONE_COMMENT:
1545 raise FormatOff(consumed)
1547 prev = preceding_leaf(leaf)
1548 if not prev or prev.type in WHITESPACE: # standalone comment in disguise
1549 raise FormatOff(consumed)
1554 def make_comment(content: str) -> str:
1555 """Return a consistently formatted comment from the given `content` string.
1557 All comments (except for "##", "#!", "#:") should have a single space between
1558 the hash sign and the content.
1560 If `content` didn't start with a hash sign, one is provided.
1562 content = content.rstrip()
1566 if content[0] == "#":
1567 content = content[1:]
1568 if content and content[0] not in " !:#":
1569 content = " " + content
1570 return "#" + content
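A few checks derived from the visible code above (a sketch, assuming the module is importable as `black`):

    import black

    assert black.make_comment("#comment") == "# comment"
    assert black.make_comment("comment") == "# comment"      # a hash sign is provided
    assert black.make_comment("#!shebang") == "#!shebang"    # no space forced after "#!"
    assert black.make_comment("#: type: int") == "#: type: int"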
1574 line: Line, line_length: int, inner: bool = False, py36: bool = False
1575 ) -> Iterator[Line]:
1576 """Split a `line` into potentially many lines.
1578 They should fit in the allotted `line_length` but might not be able to.
1579 `inner` signifies that there was a pair of brackets somewhere around the
1580 current `line`, possibly transitively. This means we can fall back to splitting
1581 by delimiters if the LHS/RHS don't yield any results.
1583 If `py36` is True, splitting may generate syntax that is only compatible
1584 with Python 3.6 and later.
1586 if isinstance(line, UnformattedLines) or line.is_comment:
1590 line_str = str(line).strip("\n")
1592 len(line_str) <= line_length
1593 and "\n" not in line_str # multiline strings
1594 and not line.contains_standalone_comments
1599 split_funcs: List[SplitFunc]
1601 split_funcs = [left_hand_split]
1602 elif line.inside_brackets:
1603 split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1605 split_funcs = [right_hand_split]
1606 for split_func in split_funcs:
1607 # We are accumulating lines in `result` because we might want to abort
1608 # mission and return the original line in the end, or attempt a different
1610 result: List[Line] = []
1612 for l in split_func(line, py36):
1613 if str(l).strip("\n") == line_str:
1614 raise CannotSplit("Split function returned an unchanged result")
1617 split_line(l, line_length=line_length, inner=True, py36=py36)
1619 except CannotSplit as cs:
1630 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1631 """Split line into many lines, starting with the first matching bracket pair.
1633 Note: this usually looks weird; only use this for function definitions.
1634 Prefer RHS otherwise.
1636 head = Line(depth=line.depth)
1637 body = Line(depth=line.depth + 1, inside_brackets=True)
1638 tail = Line(depth=line.depth)
1639 tail_leaves: List[Leaf] = []
1640 body_leaves: List[Leaf] = []
1641 head_leaves: List[Leaf] = []
1642 current_leaves = head_leaves
1643 matching_bracket = None
1644 for leaf in line.leaves:
1646 current_leaves is body_leaves
1647 and leaf.type in CLOSING_BRACKETS
1648 and leaf.opening_bracket is matching_bracket
1650 current_leaves = tail_leaves if body_leaves else head_leaves
1651 current_leaves.append(leaf)
1652 if current_leaves is head_leaves:
1653 if leaf.type in OPENING_BRACKETS:
1654 matching_bracket = leaf
1655 current_leaves = body_leaves
1656 # Since body is a new indent level, remove spurious leading whitespace.
1658 normalize_prefix(body_leaves[0], inside_brackets=True)
1659 # Build the new lines.
1660 for result, leaves in (
1661 (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1664 result.append(leaf, preformatted=True)
1665 for comment_after in line.comments_after(leaf):
1666 result.append(comment_after, preformatted=True)
1667 bracket_split_succeeded_or_raise(head, body, tail)
1668 for result in (head, body, tail):
1673 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1674 """Split line into many lines, starting with the last matching bracket pair."""
1675 head = Line(depth=line.depth)
1676 body = Line(depth=line.depth + 1, inside_brackets=True)
1677 tail = Line(depth=line.depth)
1678 tail_leaves: List[Leaf] = []
1679 body_leaves: List[Leaf] = []
1680 head_leaves: List[Leaf] = []
1681 current_leaves = tail_leaves
1682 opening_bracket = None
1683 for leaf in reversed(line.leaves):
1684 if current_leaves is body_leaves:
1685 if leaf is opening_bracket:
1686 current_leaves = head_leaves if body_leaves else tail_leaves
1687 current_leaves.append(leaf)
1688 if current_leaves is tail_leaves:
1689 if leaf.type in CLOSING_BRACKETS:
1690 opening_bracket = leaf.opening_bracket
1691 current_leaves = body_leaves
1692 tail_leaves.reverse()
1693 body_leaves.reverse()
1694 head_leaves.reverse()
1695 # Since body is a new indent level, remove spurious leading whitespace.
1697 normalize_prefix(body_leaves[0], inside_brackets=True)
1698 # Build the new lines.
1699 for result, leaves in (
1700 (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1703 result.append(leaf, preformatted=True)
1704 for comment_after in line.comments_after(leaf):
1705 result.append(comment_after, preformatted=True)
1706 bracket_split_succeeded_or_raise(head, body, tail)
1707 for result in (head, body, tail):
1712 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1713 """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1715 Do nothing otherwise.
1717 A left- or right-hand split is based on a pair of brackets. Content before
1718 (and including) the opening bracket is left on one line, content inside the
1719 brackets is put on a separate line, and finally content starting with and
1720 following the closing bracket is put on a separate line.
1722 Those are called `head`, `body`, and `tail`, respectively. If the split
1723 produced the same line (all content in `head`) or ended up with an empty `body`
1724 and the `tail` is just the closing bracket, then it's considered failed.
1726 tail_len = len(str(tail).strip())
1729 raise CannotSplit("Splitting brackets produced the same line")
1733 f"Splitting brackets on an empty body to save "
1734 f"{tail_len} characters is not worth it"
1738 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1739 """Normalize prefix of the first leaf in every line returned by `split_func`.
1741 This is a decorator over relevant split functions.
1745 def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1746 for l in split_func(line, py36):
1747 normalize_prefix(l.leaves[0], inside_brackets=True)
1750 return split_wrapper
1753 @dont_increase_indentation
1754 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1755 """Split according to delimiters of the highest priority.
1757 If `py36` is True, the split will add trailing commas also in function
1758 signatures that contain `*` and `**`.
1761 last_leaf = line.leaves[-1]
1763 raise CannotSplit("Line empty")
1765 delimiters = line.bracket_tracker.delimiters
1767 delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1768 exclude={id(last_leaf)}
1771 raise CannotSplit("No delimiters found")
1773 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1774 lowest_depth = sys.maxsize
1775 trailing_comma_safe = True
1777 def append_to_line(leaf: Leaf) -> Iterator[Line]:
1778 """Append `leaf` to current line or to new line if appending impossible."""
1779 nonlocal current_line
1781 current_line.append_safe(leaf, preformatted=True)
1782 except ValueError as ve:
1785 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1786 current_line.append(leaf)
1788 for leaf in line.leaves:
1789 yield from append_to_line(leaf)
1791 for comment_after in line.comments_after(leaf):
1792 yield from append_to_line(comment_after)
1794 lowest_depth = min(lowest_depth, leaf.bracket_depth)
1796 leaf.bracket_depth == lowest_depth
1797 and leaf.type == token.STAR
1798 or leaf.type == token.DOUBLESTAR
1800 trailing_comma_safe = trailing_comma_safe and py36
1801 leaf_priority = delimiters.get(id(leaf))
1802 if leaf_priority == delimiter_priority:
1805 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1809 and delimiter_priority == COMMA_PRIORITY
1810 and current_line.leaves[-1].type != token.COMMA
1811 and current_line.leaves[-1].type != STANDALONE_COMMENT
1813 current_line.append(Leaf(token.COMMA, ","))
1817 @dont_increase_indentation
1818 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1819 """Split standalone comments from the rest of the line."""
1820 for leaf in line.leaves:
1821 if leaf.type == STANDALONE_COMMENT:
1822 if leaf.bracket_depth == 0:
1826 raise CannotSplit("Line does not have any standalone comments")
1828 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1830 def append_to_line(leaf: Leaf) -> Iterator[Line]:
1831 """Append `leaf` to current line or to new line if appending impossible."""
1832 nonlocal current_line
1834 current_line.append_safe(leaf, preformatted=True)
1835 except ValueError as ve:
1838 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1839 current_line.append(leaf)
1841 for leaf in line.leaves:
1842 yield from append_to_line(leaf)
1844 for comment_after in line.comments_after(leaf):
1845 yield from append_to_line(comment_after)
1851 def is_import(leaf: Leaf) -> bool:
1852 """Return True if the given leaf starts an import statement."""
1859 (v == "import" and p and p.type == syms.import_name)
1860 or (v == "from" and p and p.type == syms.import_from)
1865 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1866 """Leave existing extra newlines if not `inside_brackets`. Remove everything
1869 Note: don't use backslashes for formatting or you'll lose your voting rights.
1871 if not inside_brackets:
1872 spl = leaf.prefix.split("#")
1873 if "\\" not in spl[0]:
1874 nl_count = spl[-1].count("\n")
1877 leaf.prefix = "\n" * nl_count
1883 def normalize_string_quotes(leaf: Leaf) -> None:
1884 """Prefer double quotes but only if it doesn't cause more escaping.
1886 Adds or removes backslashes as appropriate. Doesn't parse and fix
1887 strings nested in f-strings (yet).
1889 Note: Mutates its argument.
1891 value = leaf.value.lstrip("furbFURB")
1892 if value[:3] == '"""':
1895 elif value[:3] == "'''":
1898 elif value[0] == '"':
1904 first_quote_pos = leaf.value.find(orig_quote)
1905 if first_quote_pos == -1:
1906 return # There's an internal error
1908 body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
1909 new_body = body.replace(f"\\{orig_quote}", orig_quote).replace(
1910 new_quote, f"\\{new_quote}"
1912 if new_quote == '"""' and new_body[-1] == '"':
1914 new_body = new_body[:-1] + '\\"'
1915 orig_escape_count = body.count("\\")
1916 new_escape_count = new_body.count("\\")
1917 if new_escape_count > orig_escape_count:
1918 return # Do not introduce more escaping
1920 if new_escape_count == orig_escape_count and orig_quote == '"':
1921 return # Prefer double quotes
1923 prefix = leaf.value[:first_quote_pos]
1924 leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
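A runnable sketch of the quote preference above, assuming the module is importable as `black`:

    from blib2to3.pgen2 import token
    from blib2to3.pytree import Leaf
    import black

    leaf = Leaf(token.STRING, "'hello'")
    black.normalize_string_quotes(leaf)
    assert leaf.value == '"hello"'       # double quotes preferred

    leaf = Leaf(token.STRING, "'say \"hi\"'")
    black.normalize_string_quotes(leaf)
    assert leaf.value == "'say \"hi\"'"  # unchanged: switching would add escaping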
1927 def is_python36(node: Node) -> bool:
1928 """Return True if the current file is using Python 3.6+ features.
1930 Currently looking for:
1931 - f-strings;
1932 - trailing commas after * or ** in function signatures.
1934 for n in node.pre_order():
1935 if n.type == token.STRING:
1936 value_head = n.value[:2] # type: ignore
1937 if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
1941 n.type == syms.typedargslist
1943 and n.children[-1].type == token.COMMA
1945 for ch in n.children:
1946 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
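A sketch of the detection above (assuming the elided lines return True/False as the docstring describes, and that the module is importable as `black`); an f-string anywhere in the tree is enough to trigger it.

    import black

    assert black.is_python36(black.lib2to3_parse('name = f"{user}"\n'))
    assert not black.is_python36(black.lib2to3_parse("name = 'guido'\n"))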
1952 PYTHON_EXTENSIONS = {".py"}
1953 BLACKLISTED_DIRECTORIES = {
1954 "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
1958 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1959 """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
1960 and have one of the PYTHON_EXTENSIONS.
1962 for child in path.iterdir():
1964 if child.name in BLACKLISTED_DIRECTORIES:
1967 yield from gen_python_files_in_dir(child)
1969 elif child.suffix in PYTHON_EXTENSIONS:
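A usage sketch, assuming the elided `yield` and that the module is importable as `black`; `src` is a hypothetical directory.

    from pathlib import Path
    import black

    for path in black.gen_python_files_in_dir(Path("src")):
        print(path)  # every *.py file under src/, skipping directories like .git, build, .tox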
1975 """Provides a reformatting counter. Can be rendered with `str(report)`."""
1977 change_count: int = 0
1979 failure_count: int = 0
1981 def done(self, src: Path, changed: bool) -> None:
1982 """Increment the counter for successful reformatting. Write out a message."""
1984 reformatted = "would reformat" if self.check else "reformatted"
1985 out(f"{reformatted} {src}")
1986 self.change_count += 1
1988 out(f"{src} already well formatted, good job.", bold=False)
1989 self.same_count += 1
1991 def failed(self, src: Path, message: str) -> None:
1992 """Increment the counter for failed reformatting. Write out a message."""
1993 err(f"error: cannot format {src}: {message}")
1994 self.failure_count += 1
1997 def return_code(self) -> int:
1998 """Return the exit code that the app should use.
2000 This considers the current state of changed files and failures:
2001 - if there were any failures, return 123;
2002 - if any files were changed and --check is being used, return 1;
2003 - otherwise return 0.
2005 # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2006 # 126 we have special returncodes reserved by the shell.
2007 if self.failure_count:
2010 elif self.change_count and self.check:
2015 def __str__(self) -> str:
2016 """Render a color report of the current state.
2018 Use `click.unstyle` to remove colors.
2021 reformatted = "would be reformatted"
2022 unchanged = "would be left unchanged"
2023 failed = "would fail to reformat"
2025 reformatted = "reformatted"
2026 unchanged = "left unchanged"
2027 failed = "failed to reformat"
2029 if self.change_count:
2030 s = "s" if self.change_count > 1 else ""
2032 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2035 s = "s" if self.same_count > 1 else ""
2036 report.append(f"{self.same_count} file{s} {unchanged}")
2037 if self.failure_count:
2038 s = "s" if self.failure_count > 1 else ""
2040 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2042 return ", ".join(report) + "."
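A sketch of how the counter above is driven (the CLI does the equivalent); paths are hypothetical, the module is assumed importable as `black`, and the rendered string is shown without its ANSI colors.

    from pathlib import Path
    import black

    report = black.Report(check=True)
    report.done(Path("good.py"), changed=False)
    report.done(Path("bad.py"), changed=True)
    report.failed(Path("ugly.py"), "cannot parse")
    # str(report) renders roughly:
    #   1 file would be reformatted, 1 file would be left unchanged, 1 file would fail to reformat.
    assert report.return_code == 123  # any failure takes precedence over the --check status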
2045 def assert_equivalent(src: str, dst: str) -> None:
2046 """Raise AssertionError if `src` and `dst` aren't equivalent."""
2051 def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2052 """Simple visitor generating strings to compare ASTs by content."""
2053 yield f"{' ' * depth}{node.__class__.__name__}("
2055 for field in sorted(node._fields):
2057 value = getattr(node, field)
2058 except AttributeError:
2061 yield f"{' ' * (depth+1)}{field}="
2063 if isinstance(value, list):
2065 if isinstance(item, ast.AST):
2066 yield from _v(item, depth + 2)
2068 elif isinstance(value, ast.AST):
2069 yield from _v(value, depth + 2)
2072 yield f"{' ' * (depth+2)}{value!r}, # {value.__class__.__name__}"
2074 yield f"{' ' * depth}) # /{node.__class__.__name__}"
2077 src_ast = ast.parse(src)
2078 except Exception as exc:
2079 major, minor = sys.version_info[:2]
2080 raise AssertionError(
2081 f"cannot use --safe with this file; failed to parse source file "
2082 f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2083 f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2087 dst_ast = ast.parse(dst)
2088 except Exception as exc:
2089 log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2090 raise AssertionError(
2091 f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2092 f"Please report a bug on https://github.com/ambv/black/issues. "
2093 f"This invalid output might be helpful: {log}"
2096 src_ast_str = "\n".join(_v(src_ast))
2097 dst_ast_str = "\n".join(_v(dst_ast))
2098 if src_ast_str != dst_ast_str:
2099 log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2100 raise AssertionError(
2101 f"INTERNAL ERROR: Black produced code that is not equivalent to "
2103 f"Please report a bug on https://github.com/ambv/black/issues. "
2104 f"This diff might be helpful: {log}"
2108 def assert_stable(src: str, dst: str, line_length: int) -> None:
2109 """Raise AssertionError if `dst` reformats differently the second time."""
2110 newdst = format_str(dst, line_length=line_length)
2113 diff(src, dst, "source", "first pass"),
2114 diff(dst, newdst, "first pass", "second pass"),
2116 raise AssertionError(
2117 f"INTERNAL ERROR: Black produced different code on the second pass "
2118 f"of the formatter. "
2119 f"Please report a bug on https://github.com/ambv/black/issues. "
2120 f"This diff might be helpful: {log}"
2124 def dump_to_file(*output: str) -> str:
2125 """Dump `output` to a temporary file. Return path to the file."""
2128 with tempfile.NamedTemporaryFile(
2129 mode="w", prefix="blk_", suffix=".log", delete=False
2131 for lines in output:
2133 if lines and lines[-1] != "\n":
2138 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2139 """Return a unified diff string between strings `a` and `b`."""
2142 a_lines = [line + "\n" for line in a.split("\n")]
2143 b_lines = [line + "\n" for line in b.split("\n")]
2145 difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2149 def cancel(tasks: List[asyncio.Task]) -> None:
2150 """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2156 def shutdown(loop: BaseEventLoop) -> None:
2157 """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2159 # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2160 to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2164 for task in to_cancel:
2166 loop.run_until_complete(
2167 asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2170 # `concurrent.futures.Future` objects cannot be cancelled once they
2171 # are already running. There might be some when the `shutdown()` happened.
2172 # Silence their logger's spew about the event loop being closed.
2173 cf_logger = logging.getLogger("concurrent.futures")
2174 cf_logger.setLevel(logging.CRITICAL)
2178 if __name__ == "__main__":