All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you read over the Git project's submission guidelines and adhere to them,
I'd be especially grateful.
4 from asyncio.base_events import BaseEventLoop
5 from concurrent.futures import Executor, ProcessPoolExecutor
7 from functools import partial, wraps
10 from multiprocessing import Manager
12 from pathlib import Path
32 from attr import dataclass, Factory
36 from blib2to3.pytree import Node, Leaf, type_repr
37 from blib2to3 import pygram, pytree
38 from blib2to3.pgen2 import driver, token
39 from blib2to3.pgen2.parse import ParseError
41 __version__ = "18.4a0"
42 DEFAULT_LINE_LENGTH = 88
44 syms = pygram.python_symbols
52 LN = Union[Leaf, Node]
53 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
54 out = partial(click.secho, bold=True, err=True)
55 err = partial(click.secho, fg="red", err=True)
class NothingChanged(UserWarning):
    """Raised when reformatted code is the same as source.

    Callers of :func:`format_file_contents` (:func:`format_file_in_place` and
    :func:`format_stdin_to_stdout`) catch it and treat it as "no change".
    """
62 class CannotSplit(Exception):
63 """A readable split that fits the allotted line length is impossible.
65 Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
66 :func:`delimiter_split`.
70 class FormatError(Exception):
71 """Base exception for `# fmt: on` and `# fmt: off` handling.
73 It holds the number of bytes of the prefix consumed before the format
74 control comment appeared.
def __init__(self, consumed: int) -> None:
    """Remember how much of the leaf prefix was consumed.

    `consumed` is also handed to the base exception, so it appears in
    `args` as the exception's only argument.
    """
    self.consumed = consumed
    super().__init__(consumed)
def trim_prefix(self, leaf: Leaf) -> None:
    """Drop the first `self.consumed` characters of `leaf.prefix`, in place."""
    kept = leaf.prefix[self.consumed:]
    leaf.prefix = kept
def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
    """Return a new NEWLINE leaf whose prefix is the consumed part of `leaf.prefix`.

    The first `self.consumed` characters are copied verbatim; `leaf` itself
    is not modified.
    """
    consumed_text = leaf.prefix[:self.consumed]
    return Leaf(token.NEWLINE, consumed_text)
class FormatOn(FormatError):
    """Signals that a `# fmt: on` comment was found in a prefix."""
class FormatOff(FormatError):
    """Signals that a `# fmt: off` comment was found in a prefix."""
98 class WriteBack(Enum):
109 default=DEFAULT_LINE_LENGTH,
110 help="How many character per line to allow.",
117 "Don't write the files back, just return the status. Return code 0 "
118 "means nothing would change. Return code 1 means some files would be "
119 "reformatted. Return code 123 means there was an internal error."
125 help="Don't write the files back, just output a diff for each file on stdout.",
130 help="If --fast given, skip temporary sanity checks. [default: --safe]",
137 "Don't emit non-error messages to stderr. Errors are still emitted, "
138 "silence those with 2>/dev/null."
141 @click.version_option(version=__version__)
146 exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
159 """The uncompromising code formatter."""
160 sources: List[Path] = []
164 sources.extend(gen_python_files_in_dir(p))
166 # if a file was explicitly given, we don't care about its extension
169 sources.append(Path("-"))
171 err(f"invalid path: {s}")
173 exc = click.ClickException("Options --check and --diff are mutually exclusive")
178 write_back = WriteBack.NO
180 write_back = WriteBack.DIFF
182 write_back = WriteBack.YES
183 if len(sources) == 0:
185 elif len(sources) == 1:
187 report = Report(check=check, quiet=quiet)
189 if not p.is_file() and str(p) == "-":
190 changed = format_stdin_to_stdout(
191 line_length=line_length, fast=fast, write_back=write_back
194 changed = format_file_in_place(
195 p, line_length=line_length, fast=fast, write_back=write_back
197 report.done(p, changed)
198 except Exception as exc:
199 report.failed(p, str(exc))
200 ctx.exit(report.return_code)
202 loop = asyncio.get_event_loop()
203 executor = ProcessPoolExecutor(max_workers=os.cpu_count())
206 return_code = loop.run_until_complete(
208 sources, line_length, write_back, fast, quiet, loop, executor
213 ctx.exit(return_code)
216 async def schedule_formatting(
219 write_back: WriteBack,
225 """Run formatting of `sources` in parallel using the provided `executor`.
227 (Use ProcessPoolExecutors for actual parallelism.)
229 `line_length`, `write_back`, and `fast` options are passed to
230 :func:`format_file_in_place`.
233 if write_back == WriteBack.DIFF:
234 # For diff output, we need locks to ensure we don't interleave output
235 # from different processes.
237 lock = manager.Lock()
239 src: loop.run_in_executor(
240 executor, format_file_in_place, src, line_length, fast, write_back, lock
244 _task_values = list(tasks.values())
245 loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
246 loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
247 await asyncio.wait(tasks.values())
249 report = Report(check=not write_back, quiet=quiet)
250 for src, task in tasks.items():
252 report.failed(src, "timed out, cancelling")
254 cancelled.append(task)
255 elif task.cancelled():
256 cancelled.append(task)
257 elif task.exception():
258 report.failed(src, str(task.exception()))
260 report.done(src, task.result())
262 await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
264 out("All done! ✨ 🍰 ✨")
266 click.echo(str(report))
267 return report.return_code
270 def format_file_in_place(
274 write_back: WriteBack = WriteBack.NO,
275 lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
277 """Format file under `src` path. Return True if changed.
279 If `write_back` is True, write reformatted code back to stdout.
280 `line_length` and `fast` options are passed to :func:`format_file_contents`.
282 with tokenize.open(src) as src_buffer:
283 src_contents = src_buffer.read()
285 dst_contents = format_file_contents(
286 src_contents, line_length=line_length, fast=fast
288 except NothingChanged:
291 if write_back == write_back.YES:
292 with open(src, "w", encoding=src_buffer.encoding) as f:
293 f.write(dst_contents)
294 elif write_back == write_back.DIFF:
295 src_name = f"{src.name} (original)"
296 dst_name = f"{src.name} (formatted)"
297 diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
301 sys.stdout.write(diff_contents)
308 def format_stdin_to_stdout(
309 line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
311 """Format file on stdin. Return True if changed.
313 If `write_back` is True, write reformatted code back to stdout.
314 `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
316 src = sys.stdin.read()
318 dst = format_file_contents(src, line_length=line_length, fast=fast)
321 except NothingChanged:
326 if write_back == WriteBack.YES:
327 sys.stdout.write(dst)
328 elif write_back == WriteBack.DIFF:
329 src_name = "<stdin> (original)"
330 dst_name = "<stdin> (formatted)"
331 sys.stdout.write(diff(src, dst, src_name, dst_name))
334 def format_file_contents(
335 src_contents: str, line_length: int, fast: bool
337 """Reformat contents a file and return new contents.
339 If `fast` is False, additionally confirm that the reformatted code is
340 valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
341 `line_length` is passed to :func:`format_str`.
343 if src_contents.strip() == "":
346 dst_contents = format_str(src_contents, line_length=line_length)
347 if src_contents == dst_contents:
351 assert_equivalent(src_contents, dst_contents)
352 assert_stable(src_contents, dst_contents, line_length=line_length)
356 def format_str(src_contents: str, line_length: int) -> FileContent:
357 """Reformat a string and return new contents.
359 `line_length` determines how many characters per line are allowed.
361 src_node = lib2to3_parse(src_contents)
363 lines = LineGenerator()
364 elt = EmptyLineTracker()
365 py36 = is_python36(src_node)
368 for current_line in lines.visit(src_node):
369 for _ in range(after):
370 dst_contents += str(empty_line)
371 before, after = elt.maybe_empty_lines(current_line)
372 for _ in range(before):
373 dst_contents += str(empty_line)
374 for line in split_line(current_line, line_length=line_length, py36=py36):
375 dst_contents += str(line)
380 pygram.python_grammar_no_print_statement_no_exec_statement,
381 pygram.python_grammar_no_print_statement,
382 pygram.python_grammar_no_exec_statement,
383 pygram.python_grammar,
387 def lib2to3_parse(src_txt: str) -> Node:
388 """Given a string with source, return the lib2to3 Node."""
389 grammar = pygram.python_grammar_no_print_statement
390 if src_txt[-1] != "\n":
391 nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
393 for grammar in GRAMMARS:
394 drv = driver.Driver(grammar, pytree.convert)
396 result = drv.parse_string(src_txt, True)
399 except ParseError as pe:
400 lineno, column = pe.context[1]
401 lines = src_txt.splitlines()
403 faulty_line = lines[lineno - 1]
405 faulty_line = "<line number missing in source>"
406 exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
410 if isinstance(result, Leaf):
411 result = Node(syms.file_input, [result])
415 def lib2to3_unparse(node: Node) -> str:
416 """Given a lib2to3 node, return its string representation."""
424 class Visitor(Generic[T]):
425 """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
427 def visit(self, node: LN) -> Iterator[T]:
428 """Main method to visit `node` and its children.
430 It tries to find a `visit_*()` method for the given `node.type`, like
431 `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
432 If no dedicated `visit_*()` method is found, chooses `visit_default()`
435 Then yields objects of type `T` from the selected visitor.
438 name = token.tok_name[node.type]
440 name = type_repr(node.type)
441 yield from getattr(self, f"visit_{name}", self.visit_default)(node)
def visit_default(self, node: LN) -> Iterator[T]:
    """Fallback visitor: recurse into the children of a Node.

    For anything that is not a Node (i.e. a Leaf) nothing is yielded.
    """
    if not isinstance(node, Node):
        return
    for child in node.children:
        yield from self.visit(child)
451 class DebugVisitor(Visitor[T]):
454 def visit_default(self, node: LN) -> Iterator[T]:
455 indent = " " * (2 * self.tree_depth)
456 if isinstance(node, Node):
457 _type = type_repr(node.type)
458 out(f"{indent}{_type}", fg="yellow")
460 for child in node.children:
461 yield from self.visit(child)
464 out(f"{indent}/{_type}", fg="yellow", bold=False)
466 _type = token.tok_name.get(node.type, str(node.type))
467 out(f"{indent}{_type}", fg="blue", nl=False)
469 # We don't have to handle prefixes for `Node` objects since
470 # that delegates to the first child anyway.
471 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
472 out(f" {node.value!r}", fg="blue", bold=False)
475 def show(cls, code: str) -> None:
476 """Pretty-print the lib2to3 AST of a given string of `code`.
478 Convenience method for debugging.
480 v: DebugVisitor[None] = DebugVisitor()
481 list(v.visit(lib2to3_parse(code)))
484 KEYWORDS = set(keyword.kwlist)
485 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
486 FLOW_CONTROL = {"return", "raise", "break", "continue"}
497 STANDALONE_COMMENT = 153
498 LOGIC_OPERATORS = {"and", "or"}
522 VARARGS = {token.STAR, token.DOUBLESTAR}
523 COMPREHENSION_PRIORITY = 20
527 COMPARATOR_PRIORITY = 3
532 class BracketTracker:
533 """Keeps track of brackets on a line."""
536 bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
537 delimiters: Dict[LeafID, Priority] = Factory(dict)
538 previous: Optional[Leaf] = None
540 def mark(self, leaf: Leaf) -> None:
541 """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
543 All leaves receive an int `bracket_depth` field that stores how deep
544 within brackets a given leaf is. 0 means there are no enclosing brackets
545 that started on this line.
547 If a leaf is itself a closing bracket, it receives an `opening_bracket`
548 field that it forms a pair with. This is a one-directional link to
549 avoid reference cycles.
551 If a leaf is a delimiter (a token on which Black can split the line if
552 needed) and it's on depth 0, its `id()` is stored in the tracker's
555 if leaf.type == token.COMMENT:
558 if leaf.type in CLOSING_BRACKETS:
560 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
561 leaf.opening_bracket = opening_bracket
562 leaf.bracket_depth = self.depth
564 after_delim = is_split_after_delimiter(leaf, self.previous)
565 before_delim = is_split_before_delimiter(leaf, self.previous)
566 if after_delim > before_delim:
567 self.delimiters[id(leaf)] = after_delim
568 elif before_delim > after_delim and self.previous is not None:
569 self.delimiters[id(self.previous)] = before_delim
570 if leaf.type in OPENING_BRACKETS:
571 self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
def any_open_brackets(self) -> bool:
    """Report whether an opening bracket on this line is still unmatched.

    `mark()` records every opening bracket in `bracket_match` and pops it
    when the matching closing bracket arrives, so a non-empty mapping means
    at least one bracket is still open.
    """
    return len(self.bracket_match) != 0
579 def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
580 """Return the highest priority of a delimiter found on the line.
582 Values are consistent with what `is_delimiter()` returns.
584 return max(v for k, v in self.delimiters.items() if k not in exclude)
589 """Holds leaves and comments. Can be printed with `str(line)`."""
592 leaves: List[Leaf] = Factory(list)
593 comments: List[Tuple[Index, Leaf]] = Factory(list)
594 bracket_tracker: BracketTracker = Factory(BracketTracker)
595 inside_brackets: bool = False
596 has_for: bool = False
597 _for_loop_variable: bool = False
599 def append(self, leaf: Leaf, preformatted: bool = False) -> None:
600 """Add a new `leaf` to the end of the line.
602 Unless `preformatted` is True, the `leaf` will receive a new consistent
603 whitespace prefix and metadata applied by :class:`BracketTracker`.
604 Trailing commas are maybe removed, unpacked for loop variables are
605 demoted from being delimiters.
607 Inline comments are put aside.
609 has_value = leaf.value.strip()
613 if self.leaves and not preformatted:
614 # Note: at this point leaf.prefix should be empty except for
615 # imports, for which we only preserve newlines.
616 leaf.prefix += whitespace(leaf)
617 if self.inside_brackets or not preformatted:
618 self.maybe_decrement_after_for_loop_variable(leaf)
619 self.bracket_tracker.mark(leaf)
620 self.maybe_remove_trailing_comma(leaf)
621 self.maybe_increment_for_loop_variable(leaf)
623 if not self.append_comment(leaf):
624 self.leaves.append(leaf)
626 def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
627 """Like :func:`append()` but disallow invalid standalone comment structure.
629 Raises ValueError when any `leaf` is appended after a standalone comment
630 or when a standalone comment is not the first leaf on the line.
632 if self.bracket_tracker.depth == 0:
634 raise ValueError("cannot append to standalone comments")
636 if self.leaves and leaf.type == STANDALONE_COMMENT:
638 "cannot append standalone comments to a populated line"
641 self.append(leaf, preformatted=preformatted)
def is_comment(self) -> bool:
    """Tell whether the line consists of exactly one standalone comment."""
    if len(self.leaves) != 1:
        return False
    (only_leaf,) = self.leaves
    return only_leaf.type == STANDALONE_COMMENT
def is_decorator(self) -> bool:
    """Is this line a decorator?

    True when the first leaf is an `@` token. The emptiness guard checks
    `self.leaves` directly: `bool(self)` is also true for a line that holds
    only comments, in which case `self.leaves[0]` would raise IndexError.
    """
    return bool(self.leaves) and self.leaves[0].type == token.AT
def is_import(self) -> bool:
    """Is this an import line?

    Delegates the check of the first leaf to the module-level `is_import()`
    helper. The guard checks `self.leaves` rather than `bool(self)`, which is
    also true for a comments-only line, so `self.leaves[0]` cannot raise
    IndexError.
    """
    return bool(self.leaves) and is_import(self.leaves[0])
659 def is_class(self) -> bool:
660 """Is this line a class definition?"""
663 and self.leaves[0].type == token.NAME
664 and self.leaves[0].value == "class"
668 def is_def(self) -> bool:
669 """Is this a function definition? (Also returns True for async defs.)"""
671 first_leaf = self.leaves[0]
676 second_leaf: Optional[Leaf] = self.leaves[1]
680 (first_leaf.type == token.NAME and first_leaf.value == "def")
682 first_leaf.type == token.ASYNC
683 and second_leaf is not None
684 and second_leaf.type == token.NAME
685 and second_leaf.value == "def"
690 def is_flow_control(self) -> bool:
691 """Is this line a flow control statement?
693 Those are `return`, `raise`, `break`, and `continue`.
697 and self.leaves[0].type == token.NAME
698 and self.leaves[0].value in FLOW_CONTROL
702 def is_yield(self) -> bool:
703 """Is this line a yield statement?"""
706 and self.leaves[0].type == token.NAME
707 and self.leaves[0].value == "yield"
711 def contains_standalone_comments(self) -> bool:
712 """If so, needs to be split before emitting."""
713 for leaf in self.leaves:
714 if leaf.type == STANDALONE_COMMENT:
719 def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
720 """Remove trailing comma if there is one and it's safe."""
723 and self.leaves[-1].type == token.COMMA
724 and closing.type in CLOSING_BRACKETS
728 if closing.type == token.RBRACE:
729 self.remove_trailing_comma()
732 if closing.type == token.RSQB:
733 comma = self.leaves[-1]
734 if comma.parent and comma.parent.type == syms.listmaker:
735 self.remove_trailing_comma()
738 # For parens let's check if it's safe to remove the comma. If the
739 # trailing one is the only one, we might mistakenly change a tuple
740 # into a different type by removing the comma.
741 depth = closing.bracket_depth + 1
743 opening = closing.opening_bracket
744 for _opening_index, leaf in enumerate(self.leaves):
751 for leaf in self.leaves[_opening_index + 1:]:
755 bracket_depth = leaf.bracket_depth
756 if bracket_depth == depth and leaf.type == token.COMMA:
758 if leaf.parent and leaf.parent.type == syms.arglist:
763 self.remove_trailing_comma()
768 def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
769 """In a for loop, or comprehension, the variables are often unpacks.
771 To avoid splitting on the comma in this situation, increase the depth of
772 tokens between `for` and `in`.
774 if leaf.type == token.NAME and leaf.value == "for":
776 self.bracket_tracker.depth += 1
777 self._for_loop_variable = True
782 def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
783 """See `maybe_increment_for_loop_variable` above for explanation."""
784 if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
785 self.bracket_tracker.depth -= 1
786 self._for_loop_variable = False
791 def append_comment(self, comment: Leaf) -> bool:
792 """Add an inline or standalone comment to the line."""
794 comment.type == STANDALONE_COMMENT
795 and self.bracket_tracker.any_open_brackets()
800 if comment.type != token.COMMENT:
803 after = len(self.leaves) - 1
805 comment.type = STANDALONE_COMMENT
810 self.comments.append((after, comment))
813 def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
814 """Generate comments that should appear directly after `leaf`."""
815 for _leaf_index, _leaf in enumerate(self.leaves):
822 for index, comment_after in self.comments:
823 if _leaf_index == index:
826 def remove_trailing_comma(self) -> None:
827 """Remove the trailing comma and moves the comments attached to it."""
828 comma_index = len(self.leaves) - 1
829 for i in range(len(self.comments)):
830 comment_index, comment = self.comments[i]
831 if comment_index == comma_index:
832 self.comments[i] = (comma_index - 1, comment)
835 def __str__(self) -> str:
836 """Render the line."""
840 indent = " " * self.depth
841 leaves = iter(self.leaves)
843 res = f"{first.prefix}{indent}{first.value}"
846 for _, comment in self.comments:
def __bool__(self) -> bool:
    """A line is truthy when it holds at least one leaf or one comment."""
    return bool(self.leaves) or bool(self.comments)
855 class UnformattedLines(Line):
856 """Just like :class:`Line` but stores lines which aren't reformatted."""
858 def append(self, leaf: Leaf, preformatted: bool = True) -> None:
859 """Just add a new `leaf` to the end of the lines.
861 The `preformatted` argument is ignored.
863 Keeps track of indentation `depth`, which is useful when the user
864 says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
867 list(generate_comments(leaf))
868 except FormatOn as f_on:
869 self.leaves.append(f_on.leaf_from_consumed(leaf))
872 self.leaves.append(leaf)
873 if leaf.type == token.INDENT:
875 elif leaf.type == token.DEDENT:
878 def __str__(self) -> str:
879 """Render unformatted lines from leaves which were added with `append()`.
881 `depth` is not used for indentation in this case.
887 for leaf in self.leaves:
def append_comment(self, comment: Leaf) -> bool:
    """Always raise `NotImplementedError`.

    Unformatted lines keep leaves verbatim and do not track comments in a
    separate list, so there is nowhere to put `comment`.
    """
    message = "Unformatted lines don't store comments separately."
    raise NotImplementedError(message)
895 def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
896 """Does nothing and returns False."""
899 def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
900 """Does nothing and returns False."""
905 class EmptyLineTracker:
906 """Provides a stateful method that returns the number of potential extra
907 empty lines needed before and after the currently processed line.
909 Note: this tracker works on lines that haven't been split yet. It assumes
910 the prefix of the first leaf consists of optional newlines. Those newlines
911 are consumed by `maybe_empty_lines()` and included in the computation.
913 previous_line: Optional[Line] = None
914 previous_after: int = 0
915 previous_defs: List[int] = Factory(list)
917 def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
918 """Return the number of extra empty lines before and after the `current_line`.
920 This is for separating `def`, `async def` and `class` with extra empty
921 lines (two on module-level), as well as providing an extra empty line
922 after flow control keywords to make them more prominent.
924 if isinstance(current_line, UnformattedLines):
927 before, after = self._maybe_empty_lines(current_line)
928 before -= self.previous_after
929 self.previous_after = after
930 self.previous_line = current_line
933 def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
935 if current_line.depth == 0:
937 if current_line.leaves:
938 # Consume the first leaf's extra newlines.
939 first_leaf = current_line.leaves[0]
940 before = first_leaf.prefix.count("\n")
941 before = min(before, max_allowed)
942 first_leaf.prefix = ""
945 depth = current_line.depth
946 while self.previous_defs and self.previous_defs[-1] >= depth:
947 self.previous_defs.pop()
948 before = 1 if depth else 2
949 is_decorator = current_line.is_decorator
950 if is_decorator or current_line.is_def or current_line.is_class:
952 self.previous_defs.append(depth)
953 if self.previous_line is None:
954 # Don't insert empty lines before the first line in the file.
957 if self.previous_line and self.previous_line.is_decorator:
958 # Don't insert empty lines between decorators.
962 if current_line.depth:
966 if current_line.is_flow_control:
971 and self.previous_line.is_import
972 and not current_line.is_import
973 and depth == self.previous_line.depth
975 return (before or 1), 0
979 and self.previous_line.is_yield
980 and (not current_line.is_yield or depth != self.previous_line.depth)
982 return (before or 1), 0
988 class LineGenerator(Visitor[Line]):
989 """Generates reformatted Line objects. Empty lines are not emitted.
991 Note: destroys the tree it's visiting by mutating prefixes of its leaves
992 in ways that will no longer stringify to valid Python code on the tree.
994 current_line: Line = Factory(Line)
996 def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
999 If the line is empty, only emit if it makes sense.
1000 If the line is too long, split it first and then generate.
1002 If any lines were generated, set up a new current_line.
1004 if not self.current_line:
1005 if self.current_line.__class__ == type:
1006 self.current_line.depth += indent
1008 self.current_line = type(depth=self.current_line.depth + indent)
1009 return # Line is empty, don't emit. Creating a new one unnecessary.
1011 complete_line = self.current_line
1012 self.current_line = type(depth=complete_line.depth + indent)
1015 def visit(self, node: LN) -> Iterator[Line]:
1016 """Main method to visit `node` and its children.
1018 Yields :class:`Line` objects.
1020 if isinstance(self.current_line, UnformattedLines):
1021 # File contained `# fmt: off`
1022 yield from self.visit_unformatted(node)
1025 yield from super().visit(node)
1027 def visit_default(self, node: LN) -> Iterator[Line]:
1028 """Default `visit_*()` implementation. Recurses to children of `node`."""
1029 if isinstance(node, Leaf):
1030 any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1032 for comment in generate_comments(node):
1033 if any_open_brackets:
1034 # any comment within brackets is subject to splitting
1035 self.current_line.append(comment)
1036 elif comment.type == token.COMMENT:
1037 # regular trailing comment
1038 self.current_line.append(comment)
1039 yield from self.line()
1042 # regular standalone comment
1043 yield from self.line()
1045 self.current_line.append(comment)
1046 yield from self.line()
1048 except FormatOff as f_off:
1049 f_off.trim_prefix(node)
1050 yield from self.line(type=UnformattedLines)
1051 yield from self.visit(node)
1053 except FormatOn as f_on:
1054 # This only happens here if somebody says "fmt: on" multiple
1056 f_on.trim_prefix(node)
1057 yield from self.visit_default(node)
1060 normalize_prefix(node, inside_brackets=any_open_brackets)
1061 if node.type == token.STRING:
1062 normalize_string_quotes(node)
1063 if node.type not in WHITESPACE:
1064 self.current_line.append(node)
1065 yield from super().visit_default(node)
def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
    """Increase indentation level, maybe yield a line.

    INDENT is a token type, so `Visitor.visit()` dispatches here with a Leaf.
    The pending line (if non-empty) is yielded and a new one is started one
    level deeper.
    """
    # In blib2to3 INDENT never holds comments.
    yield from self.line(+1)
    yield from self.visit_default(node)
def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
    """Decrease indentation level, maybe yield a line.

    DEDENT is a token type, so `Visitor.visit()` dispatches here with a Leaf.
    """
    # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
    yield from self.line(-1)
1078 def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
1079 """Visit a statement.
1081 This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1082 `def`, `with`, and `class`.
1084 The relevant Python language `keywords` for a given statement will be NAME
1085 leaves within it. This methods puts those on a separate line.
1087 for child in node.children:
1088 if child.type == token.NAME and child.value in keywords: # type: ignore
1089 yield from self.line()
1091 yield from self.visit(child)
1093 def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1094 """Visit a statement without nested statements."""
1095 is_suite_like = node.parent and node.parent.type in STATEMENT
1097 yield from self.line(+1)
1098 yield from self.visit_default(node)
1099 yield from self.line(-1)
1102 yield from self.line()
1103 yield from self.visit_default(node)
1105 def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1106 """Visit `async def`, `async for`, `async with`."""
1107 yield from self.line()
1109 children = iter(node.children)
1110 for child in children:
1111 yield from self.visit(child)
1113 if child.type == token.ASYNC:
1116 internal_stmt = next(children)
1117 for child in internal_stmt.children:
1118 yield from self.visit(child)
def visit_decorators(self, node: Node) -> Iterator[Line]:
    """Put each child of a decorators (or decorated) node on its own line.

    Before visiting every child, the pending line is flushed so the child
    starts a fresh one.
    """
    for decorator in node.children:
        yield from self.line()
        yield from self.visit(decorator)
def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
    """Drop the semicolon and end the current line.

    The `;` leaf is never appended; flushing the line makes the following
    statement start on a separate line.
    """
    yield from self.line()
def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
    """Finish the file: handle the end marker like any leaf, then flush.

    The default leaf handling processes comments left in the marker's
    prefix; the final `line()` call emits whatever line is still pending.
    """
    # Comments trailing at end of file live in the ENDMARKER prefix.
    yield from self.visit_default(leaf)
    # Flush the last pending line.
    yield from self.line()
1135 def visit_unformatted(self, node: LN) -> Iterator[Line]:
1136 """Used when file contained a `# fmt: off`."""
1137 if isinstance(node, Node):
1138 for child in node.children:
1139 yield from self.visit(child)
1143 self.current_line.append(node)
1144 except FormatOn as f_on:
1145 f_on.trim_prefix(node)
1146 yield from self.line()
1147 yield from self.visit(node)
1149 if node.type == token.ENDMARKER:
1150 # somebody decided not to put a final `# fmt: on`
1151 yield from self.line()
1153 def __attrs_post_init__(self) -> None:
1154 """You are in a twisty little maze of passages."""
1156 self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"})
1157 self.visit_while_stmt = partial(v, keywords={"while", "else"})
1158 self.visit_for_stmt = partial(v, keywords={"for", "else"})
1159 self.visit_try_stmt = partial(v, keywords={"try", "except", "else", "finally"})
1160 self.visit_except_clause = partial(v, keywords={"except"})
1161 self.visit_funcdef = partial(v, keywords={"def"})
1162 self.visit_with_stmt = partial(v, keywords={"with"})
1163 self.visit_classdef = partial(v, keywords={"class"})
1164 self.visit_async_funcdef = self.visit_async_stmt
1165 self.visit_decorated = self.visit_decorators
1168 BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.RBRACE}
1169 OPENING_BRACKETS = set(BRACKET.keys())
1170 CLOSING_BRACKETS = set(BRACKET.values())
1171 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
1172 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1175 def whitespace(leaf: Leaf) -> str: # noqa C901
1176 """Return whitespace prefix if needed for the given `leaf`."""
1183 if t in ALWAYS_NO_SPACE:
1186 if t == token.COMMENT:
1189 assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1190 if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
1193 prev = leaf.prev_sibling
1195 prevp = preceding_leaf(p)
1196 if not prevp or prevp.type in OPENING_BRACKETS:
1199 if t == token.COLON:
1200 return SPACE if prevp.type == token.COMMA else NO
1202 if prevp.type == token.EQUAL:
1204 if prevp.parent.type in {
1205 syms.arglist, syms.argument, syms.parameters, syms.varargslist
1209 elif prevp.parent.type == syms.typedargslist:
1210 # A bit hacky: if the equal sign has whitespace, it means we
1211 # previously found it's a typed argument. So, we're using
1215 elif prevp.type == token.DOUBLESTAR:
1216 if prevp.parent and prevp.parent.type in {
1226 elif prevp.type == token.COLON:
1227 if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1232 and prevp.parent.type in {syms.factor, syms.star_expr}
1233 and prevp.type in MATH_OPERATORS
1238 prevp.type == token.RIGHTSHIFT
1240 and prevp.parent.type == syms.shift_expr
1241 and prevp.prev_sibling
1242 and prevp.prev_sibling.type == token.NAME
1243 and prevp.prev_sibling.value == "print" # type: ignore
1245 # Python 2 print chevron
1248 elif prev.type in OPENING_BRACKETS:
1251 if p.type in {syms.parameters, syms.arglist}:
1252 # untyped function signatures or calls
1256 if not prev or prev.type != token.COMMA:
1259 elif p.type == syms.varargslist:
1264 if prev and prev.type != token.COMMA:
1267 elif p.type == syms.typedargslist:
1268 # typed function signatures
1272 if t == token.EQUAL:
1273 if prev.type != syms.tname:
1276 elif prev.type == token.EQUAL:
1277 # A bit hacky: if the equal sign has whitespace, it means we
1278 # previously found it's a typed argument. So, we're using that, too.
1281 elif prev.type != token.COMMA:
1284 elif p.type == syms.tname:
1287 prevp = preceding_leaf(p)
1288 if not prevp or prevp.type != token.COMMA:
1291 elif p.type == syms.trailer:
1292 # attributes and calls
1293 if t == token.LPAR or t == token.RPAR:
1298 prevp = preceding_leaf(p)
1299 if not prevp or prevp.type != token.NUMBER:
1302 elif t == token.LSQB:
1305 elif prev.type != token.COMMA:
1308 elif p.type == syms.argument:
1310 if t == token.EQUAL:
1314 prevp = preceding_leaf(p)
1315 if not prevp or prevp.type == token.LPAR:
1318 elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
1321 elif p.type == syms.decorator:
1325 elif p.type == syms.dotted_name:
1329 prevp = preceding_leaf(p)
1330 if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1333 elif p.type == syms.classdef:
1337 if prev and prev.type == token.LPAR:
1340 elif p.type == syms.subscript:
1343 assert p.parent is not None, "subscripts are always parented"
1344 if p.parent.type == syms.subscriptlist:
1352 elif p.type == syms.atom:
1353 if prev and t == token.DOT:
1354 # dots, but not the first one.
1358 p.type == syms.listmaker
1359 or p.type == syms.testlist_gexp
1360 or p.type == syms.subscriptlist
1362 # list interior, including unpacking
1366 elif p.type == syms.dictsetmaker:
1367 # dict and set interior, including unpacking
1371 if prev.type == token.DOUBLESTAR:
1374 elif p.type in {syms.factor, syms.star_expr}:
1377 prevp = preceding_leaf(p)
1378 if not prevp or prevp.type in OPENING_BRACKETS:
1381 prevp_parent = prevp.parent
1382 assert prevp_parent is not None
1383 if prevp.type == token.COLON and prevp_parent.type in {
1384 syms.subscript, syms.sliceop
1388 elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1391 elif t == token.NAME or t == token.NUMBER:
1394 elif p.type == syms.import_from:
1396 if prev and prev.type == token.DOT:
1399 elif t == token.NAME:
1403 if prev and prev.type == token.DOT:
1406 elif p.type == syms.sliceop:
1412 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1413 """Return the first leaf that precedes `node`, if any."""
1415 res = node.prev_sibling
1417 if isinstance(res, Leaf):
1421 return list(res.leaves())[-1]
1430 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1431 """Return the priority of the `leaf` delimiter, given a line break after it.
1433 The delimiter priorities returned here are from those delimiters that would
1434 cause a line break after themselves.
1436 Higher numbers are higher priority.
1438 if leaf.type == token.COMMA:
1439 return COMMA_PRIORITY
1442 leaf.type in VARARGS
1444 and leaf.parent.type in {syms.argument, syms.typedargslist}
1446 return MATH_PRIORITY
1451 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1452 """Return the priority of the `leaf` delimiter, given a line before after it.
1454 The delimiter priorities returned here are from those delimiters that would
1455 cause a line break before themselves.
1457 Higher numbers are higher priority.
1460 leaf.type in MATH_OPERATORS
1462 and leaf.parent.type not in {syms.factor, syms.star_expr}
1464 return MATH_PRIORITY
1466 if leaf.type in COMPARATORS:
1467 return COMPARATOR_PRIORITY
1470 leaf.type == token.STRING
1471 and previous is not None
1472 and previous.type == token.STRING
1474 return STRING_PRIORITY
1477 leaf.type == token.NAME
1478 and leaf.value == "for"
1480 and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1482 return COMPREHENSION_PRIORITY
1485 leaf.type == token.NAME
1486 and leaf.value == "if"
1488 and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1490 return COMPREHENSION_PRIORITY
1492 if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
1493 return LOGIC_PRIORITY
1498 def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1499 """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
1501 Higher numbers are higher priority.
1504 is_split_before_delimiter(leaf, previous),
1505 is_split_after_delimiter(leaf, previous),
1509 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1510 """Clean the prefix of the `leaf` and generate comments from it, if any.
1512 Comments in lib2to3 are shoved into the whitespace prefix. This happens
1513 in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation
1514 move because it does away with modifying the grammar to include all the
1515 possible places in which comments can be placed.
1517 The sad consequence for us though is that comments don't "belong" anywhere.
1518 This is why this function generates simple parentless Leaf objects for
1519 comments. We simply don't know what the correct parent should be.
1521 No matter though, we can live without this. We really only need to
1522 differentiate between inline and standalone comments. The latter don't
1523 share the line with any code.
1525 Inline comments are emitted as regular token.COMMENT leaves. Standalone
1526 are emitted with a fake STANDALONE_COMMENT token identifier.
1537 for index, line in enumerate(p.split("\n")):
1538 consumed += len(line) + 1 # adding the length of the split '\n'
1539 line = line.lstrip()
1542 if not line.startswith("#"):
1545 if index == 0 and leaf.type != token.ENDMARKER:
1546 comment_type = token.COMMENT # simple trailing comment
1548 comment_type = STANDALONE_COMMENT
1549 comment = make_comment(line)
1550 yield Leaf(comment_type, comment, prefix="\n" * nlines)
1552 if comment in {"# fmt: on", "# yapf: enable"}:
1553 raise FormatOn(consumed)
1555 if comment in {"# fmt: off", "# yapf: disable"}:
1556 if comment_type == STANDALONE_COMMENT:
1557 raise FormatOff(consumed)
1559 prev = preceding_leaf(leaf)
1560 if not prev or prev.type in WHITESPACE: # standalone comment in disguise
1561 raise FormatOff(consumed)
1566 def make_comment(content: str) -> str:
1567 """Return a consistently formatted comment from the given `content` string.
1569 All comments (except for "##", "#!", "#:") should have a single space between
1570 the hash sign and the content.
1572 If `content` didn't start with a hash sign, one is provided.
1574 content = content.rstrip()
1578 if content[0] == "#":
1579 content = content[1:]
1580 if content and content[0] not in " !:#":
1581 content = " " + content
1582 return "#" + content
1586 line: Line, line_length: int, inner: bool = False, py36: bool = False
1587 ) -> Iterator[Line]:
1588 """Split a `line` into potentially many lines.
1590 They should fit in the allotted `line_length` but might not be able to.
1591 `inner` signifies that there were a pair of brackets somewhere around the
1592 current `line`, possibly transitively. This means we can fallback to splitting
1593 by delimiters if the LHS/RHS don't yield any results.
1595 If `py36` is True, splitting may generate syntax that is only compatible
1596 with Python 3.6 and later.
1598 if isinstance(line, UnformattedLines) or line.is_comment:
1602 line_str = str(line).strip("\n")
1604 len(line_str) <= line_length
1605 and "\n" not in line_str # multiline strings
1606 and not line.contains_standalone_comments
1611 split_funcs: List[SplitFunc]
1613 split_funcs = [left_hand_split]
1614 elif line.inside_brackets:
1615 split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
1617 split_funcs = [right_hand_split]
1618 for split_func in split_funcs:
1619 # We are accumulating lines in `result` because we might want to abort
1620 # mission and return the original line in the end, or attempt a different
1622 result: List[Line] = []
1624 for l in split_func(line, py36):
1625 if str(l).strip("\n") == line_str:
1626 raise CannotSplit("Split function returned an unchanged result")
1629 split_line(l, line_length=line_length, inner=True, py36=py36)
1631 except CannotSplit as cs:
1642 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1643 """Split line into many lines, starting with the first matching bracket pair.
1645 Note: this usually looks weird, only use this for function definitions.
1646 Prefer RHS otherwise.
1648 head = Line(depth=line.depth)
1649 body = Line(depth=line.depth + 1, inside_brackets=True)
1650 tail = Line(depth=line.depth)
1651 tail_leaves: List[Leaf] = []
1652 body_leaves: List[Leaf] = []
1653 head_leaves: List[Leaf] = []
1654 current_leaves = head_leaves
1655 matching_bracket = None
1656 for leaf in line.leaves:
1658 current_leaves is body_leaves
1659 and leaf.type in CLOSING_BRACKETS
1660 and leaf.opening_bracket is matching_bracket
1662 current_leaves = tail_leaves if body_leaves else head_leaves
1663 current_leaves.append(leaf)
1664 if current_leaves is head_leaves:
1665 if leaf.type in OPENING_BRACKETS:
1666 matching_bracket = leaf
1667 current_leaves = body_leaves
1668 # Since body is a new indent level, remove spurious leading whitespace.
1670 normalize_prefix(body_leaves[0], inside_brackets=True)
1671 # Build the new lines.
1672 for result, leaves in (
1673 (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1676 result.append(leaf, preformatted=True)
1677 for comment_after in line.comments_after(leaf):
1678 result.append(comment_after, preformatted=True)
1679 bracket_split_succeeded_or_raise(head, body, tail)
1680 for result in (head, body, tail):
1685 def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1686 """Split line into many lines, starting with the last matching bracket pair."""
1687 head = Line(depth=line.depth)
1688 body = Line(depth=line.depth + 1, inside_brackets=True)
1689 tail = Line(depth=line.depth)
1690 tail_leaves: List[Leaf] = []
1691 body_leaves: List[Leaf] = []
1692 head_leaves: List[Leaf] = []
1693 current_leaves = tail_leaves
1694 opening_bracket = None
1695 for leaf in reversed(line.leaves):
1696 if current_leaves is body_leaves:
1697 if leaf is opening_bracket:
1698 current_leaves = head_leaves if body_leaves else tail_leaves
1699 current_leaves.append(leaf)
1700 if current_leaves is tail_leaves:
1701 if leaf.type in CLOSING_BRACKETS:
1702 opening_bracket = leaf.opening_bracket
1703 current_leaves = body_leaves
1704 tail_leaves.reverse()
1705 body_leaves.reverse()
1706 head_leaves.reverse()
1707 # Since body is a new indent level, remove spurious leading whitespace.
1709 normalize_prefix(body_leaves[0], inside_brackets=True)
1710 # Build the new lines.
1711 for result, leaves in (
1712 (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
1715 result.append(leaf, preformatted=True)
1716 for comment_after in line.comments_after(leaf):
1717 result.append(comment_after, preformatted=True)
1718 bracket_split_succeeded_or_raise(head, body, tail)
1719 for result in (head, body, tail):
1724 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
1725 """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
1727 Do nothing otherwise.
1729 A left- or right-hand split is based on a pair of brackets. Content before
1730 (and including) the opening bracket is left on one line, content inside the
1731 brackets is put on a separate line, and finally content starting with and
1732 following the closing bracket is put on a separate line.
1734 Those are called `head`, `body`, and `tail`, respectively. If the split
1735 produced the same line (all content in `head`) or ended up with an empty `body`
1736 and the `tail` is just the closing bracket, then it's considered failed.
1738 tail_len = len(str(tail).strip())
1741 raise CannotSplit("Splitting brackets produced the same line")
1745 f"Splitting brackets on an empty body to save "
1746 f"{tail_len} characters is not worth it"
1750 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
1751 """Normalize prefix of the first leaf in every line returned by `split_func`.
1753 This is a decorator over relevant split functions.
1757 def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
1758 for l in split_func(line, py36):
1759 normalize_prefix(l.leaves[0], inside_brackets=True)
1762 return split_wrapper
1765 @dont_increase_indentation
1766 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
1767 """Split according to delimiters of the highest priority.
1769 If `py36` is True, the split will add trailing commas also in function
1770 signatures that contain `*` and `**`.
1773 last_leaf = line.leaves[-1]
1775 raise CannotSplit("Line empty")
1777 delimiters = line.bracket_tracker.delimiters
1779 delimiter_priority = line.bracket_tracker.max_delimiter_priority(
1780 exclude={id(last_leaf)}
1783 raise CannotSplit("No delimiters found")
1785 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1786 lowest_depth = sys.maxsize
1787 trailing_comma_safe = True
1789 def append_to_line(leaf: Leaf) -> Iterator[Line]:
1790 """Append `leaf` to current line or to new line if appending impossible."""
1791 nonlocal current_line
1793 current_line.append_safe(leaf, preformatted=True)
1794 except ValueError as ve:
1797 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1798 current_line.append(leaf)
1800 for leaf in line.leaves:
1801 yield from append_to_line(leaf)
1803 for comment_after in line.comments_after(leaf):
1804 yield from append_to_line(comment_after)
1806 lowest_depth = min(lowest_depth, leaf.bracket_depth)
1808 leaf.bracket_depth == lowest_depth
1809 and leaf.type == token.STAR
1810 or leaf.type == token.DOUBLESTAR
1812 trailing_comma_safe = trailing_comma_safe and py36
1813 leaf_priority = delimiters.get(id(leaf))
1814 if leaf_priority == delimiter_priority:
1817 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1821 and delimiter_priority == COMMA_PRIORITY
1822 and current_line.leaves[-1].type != token.COMMA
1823 and current_line.leaves[-1].type != STANDALONE_COMMENT
1825 current_line.append(Leaf(token.COMMA, ","))
1829 @dont_increase_indentation
1830 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
1831 """Split standalone comments from the rest of the line."""
1832 for leaf in line.leaves:
1833 if leaf.type == STANDALONE_COMMENT:
1834 if leaf.bracket_depth == 0:
1838 raise CannotSplit("Line does not have any standalone comments")
1840 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1842 def append_to_line(leaf: Leaf) -> Iterator[Line]:
1843 """Append `leaf` to current line or to new line if appending impossible."""
1844 nonlocal current_line
1846 current_line.append_safe(leaf, preformatted=True)
1847 except ValueError as ve:
1850 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
1851 current_line.append(leaf)
1853 for leaf in line.leaves:
1854 yield from append_to_line(leaf)
1856 for comment_after in line.comments_after(leaf):
1857 yield from append_to_line(comment_after)
1863 def is_import(leaf: Leaf) -> bool:
1864 """Return True if the given leaf starts an import statement."""
1871 (v == "import" and p and p.type == syms.import_name)
1872 or (v == "from" and p and p.type == syms.import_from)
1877 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
1878 """Leave existing extra newlines if not `inside_brackets`. Remove everything
1881 Note: don't use backslashes for formatting or you'll lose your voting rights.
1883 if not inside_brackets:
1884 spl = leaf.prefix.split("#")
1885 if "\\" not in spl[0]:
1886 nl_count = spl[-1].count("\n")
1889 leaf.prefix = "\n" * nl_count
1895 def normalize_string_quotes(leaf: Leaf) -> None:
1896 """Prefer double quotes but only if it doesn't cause more escaping.
1898 Adds or removes backslashes as appropriate. Doesn't parse and fix
1899 strings nested in f-strings (yet).
1901 Note: Mutates its argument.
1903 value = leaf.value.lstrip("furbFURB")
1904 if value[:3] == '"""':
1907 elif value[:3] == "'''":
1910 elif value[0] == '"':
1916 first_quote_pos = leaf.value.find(orig_quote)
1917 if first_quote_pos == -1:
1918 return # There's an internal error
1920 body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
1921 new_body = body.replace(f"\\{orig_quote}", orig_quote).replace(
1922 new_quote, f"\\{new_quote}"
1924 if new_quote == '"""' and new_body[-1] == '"':
1926 new_body = new_body[:-1] + '\\"'
1927 orig_escape_count = body.count("\\")
1928 new_escape_count = new_body.count("\\")
1929 if new_escape_count > orig_escape_count:
1930 return # Do not introduce more escaping
1932 if new_escape_count == orig_escape_count and orig_quote == '"':
1933 return # Prefer double quotes
1935 prefix = leaf.value[:first_quote_pos]
1936 leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
1939 def is_python36(node: Node) -> bool:
1940 """Return True if the current file is using Python 3.6+ features.
1942 Currently looking for:
1944 - trailing commas after * or ** in function signatures.
1946 for n in node.pre_order():
1947 if n.type == token.STRING:
1948 value_head = n.value[:2] # type: ignore
1949 if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
1953 n.type == syms.typedargslist
1955 and n.children[-1].type == token.COMMA
1957 for ch in n.children:
1958 if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
1964 PYTHON_EXTENSIONS = {".py"}
1965 BLACKLISTED_DIRECTORIES = {
1966 "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
1970 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
1971 """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
1972 and have one of the PYTHON_EXTENSIONS.
1974 for child in path.iterdir():
1976 if child.name in BLACKLISTED_DIRECTORIES:
1979 yield from gen_python_files_in_dir(child)
1981 elif child.suffix in PYTHON_EXTENSIONS:
1987 """Provides a reformatting counter. Can be rendered with `str(report)`."""
1990 change_count: int = 0
1992 failure_count: int = 0
1994 def done(self, src: Path, changed: bool) -> None:
1995 """Increment the counter for successful reformatting. Write out a message."""
1997 reformatted = "would reformat" if self.check else "reformatted"
1999 out(f"{reformatted} {src}")
2000 self.change_count += 1
2003 out(f"{src} already well formatted, good job.", bold=False)
2004 self.same_count += 1
2006 def failed(self, src: Path, message: str) -> None:
2007 """Increment the counter for failed reformatting. Write out a message."""
2008 err(f"error: cannot format {src}: {message}")
2009 self.failure_count += 1
2012 def return_code(self) -> int:
2013 """Return the exit code that the app should use.
2015 This considers the current state of changed files and failures:
2016 - if there were any failures, return 123;
2017 - if any files were changed and --check is being used, return 1;
2018 - otherwise return 0.
2020 # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2021 # 126 we have special returncodes reserved by the shell.
2022 if self.failure_count:
2025 elif self.change_count and self.check:
2030 def __str__(self) -> str:
2031 """Render a color report of the current state.
2033 Use `click.unstyle` to remove colors.
2036 reformatted = "would be reformatted"
2037 unchanged = "would be left unchanged"
2038 failed = "would fail to reformat"
2040 reformatted = "reformatted"
2041 unchanged = "left unchanged"
2042 failed = "failed to reformat"
2044 if self.change_count:
2045 s = "s" if self.change_count > 1 else ""
2047 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2050 s = "s" if self.same_count > 1 else ""
2051 report.append(f"{self.same_count} file{s} {unchanged}")
2052 if self.failure_count:
2053 s = "s" if self.failure_count > 1 else ""
2055 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2057 return ", ".join(report) + "."
2060 def assert_equivalent(src: str, dst: str) -> None:
2061 """Raise AssertionError if `src` and `dst` aren't equivalent."""
2066 def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2067 """Simple visitor generating strings to compare ASTs by content."""
2068 yield f"{' ' * depth}{node.__class__.__name__}("
2070 for field in sorted(node._fields):
2072 value = getattr(node, field)
2073 except AttributeError:
2076 yield f"{' ' * (depth+1)}{field}="
2078 if isinstance(value, list):
2080 if isinstance(item, ast.AST):
2081 yield from _v(item, depth + 2)
2083 elif isinstance(value, ast.AST):
2084 yield from _v(value, depth + 2)
2087 yield f"{' ' * (depth+2)}{value!r}, # {value.__class__.__name__}"
2089 yield f"{' ' * depth}) # /{node.__class__.__name__}"
2092 src_ast = ast.parse(src)
2093 except Exception as exc:
2094 major, minor = sys.version_info[:2]
2095 raise AssertionError(
2096 f"cannot use --safe with this file; failed to parse source file "
2097 f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2098 f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2102 dst_ast = ast.parse(dst)
2103 except Exception as exc:
2104 log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2105 raise AssertionError(
2106 f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2107 f"Please report a bug on https://github.com/ambv/black/issues. "
2108 f"This invalid output might be helpful: {log}"
2111 src_ast_str = "\n".join(_v(src_ast))
2112 dst_ast_str = "\n".join(_v(dst_ast))
2113 if src_ast_str != dst_ast_str:
2114 log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2115 raise AssertionError(
2116 f"INTERNAL ERROR: Black produced code that is not equivalent to "
2118 f"Please report a bug on https://github.com/ambv/black/issues. "
2119 f"This diff might be helpful: {log}"
2123 def assert_stable(src: str, dst: str, line_length: int) -> None:
2124 """Raise AssertionError if `dst` reformats differently the second time."""
2125 newdst = format_str(dst, line_length=line_length)
2128 diff(src, dst, "source", "first pass"),
2129 diff(dst, newdst, "first pass", "second pass"),
2131 raise AssertionError(
2132 f"INTERNAL ERROR: Black produced different code on the second pass "
2133 f"of the formatter. "
2134 f"Please report a bug on https://github.com/ambv/black/issues. "
2135 f"This diff might be helpful: {log}"
2139 def dump_to_file(*output: str) -> str:
2140 """Dump `output` to a temporary file. Return path to the file."""
2143 with tempfile.NamedTemporaryFile(
2144 mode="w", prefix="blk_", suffix=".log", delete=False
2146 for lines in output:
2148 if lines and lines[-1] != "\n":
2153 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2154 """Return a unified diff string between strings `a` and `b`."""
2157 a_lines = [line + "\n" for line in a.split("\n")]
2158 b_lines = [line + "\n" for line in b.split("\n")]
2160 difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2164 def cancel(tasks: List[asyncio.Task]) -> None:
2165 """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2171 def shutdown(loop: BaseEventLoop) -> None:
2172 """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2174 # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2175 to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2179 for task in to_cancel:
2181 loop.run_until_complete(
2182 asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2185 # `concurrent.futures.Future` objects cannot be cancelled once they
2186 # are already running. There might be some when the `shutdown()` happened.
2187 # Silence their logger's spew about the event loop being closed.
2188 cf_logger = logging.getLogger("concurrent.futures")
2189 cf_logger.setLevel(logging.CRITICAL)
2193 if __name__ == "__main__":