All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you read over the Git project's submission guidelines and adhere to them,
I'd be especially grateful.
3 from asyncio.base_events import BaseEventLoop
4 from concurrent.futures import Executor, ProcessPoolExecutor
6 from functools import partial, wraps
9 from multiprocessing import Manager
11 from pathlib import Path
35 from appdirs import user_cache_dir
36 from attr import dataclass, Factory
40 from blib2to3.pytree import Node, Leaf, type_repr
41 from blib2to3 import pygram, pytree
42 from blib2to3.pgen2 import driver, token
43 from blib2to3.pgen2.parse import ParseError
46 __version__ = "18.4a6"
47 DEFAULT_LINE_LENGTH = 88
50 syms = pygram.python_symbols
58 LN = Union[Leaf, Node]
59 SplitFunc = Callable[["Line", bool], Iterator["Line"]]
62 CacheInfo = Tuple[Timestamp, FileSize]
63 Cache = Dict[Path, CacheInfo]
64 out = partial(click.secho, bold=True, err=True)
65 err = partial(click.secho, fg="red", err=True)
68 class NothingChanged(UserWarning):
69 """Raised by :func:`format_file` when reformatted code is the same as source."""
72 class CannotSplit(Exception):
73 """A readable split that fits the allotted line length is impossible.
75 Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
76 :func:`delimiter_split`.
80 class FormatError(Exception):
81 """Base exception for `# fmt: on` and `# fmt: off` handling.
83 It holds the number of bytes of the prefix consumed before the format
84 control comment appeared.
87 def __init__(self, consumed: int) -> None:
88 super().__init__(consumed)
89 self.consumed = consumed
91 def trim_prefix(self, leaf: Leaf) -> None:
92 leaf.prefix = leaf.prefix[self.consumed :]
94 def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
95 """Returns a new Leaf from the consumed part of the prefix."""
96 unformatted_prefix = leaf.prefix[: self.consumed]
97 return Leaf(token.NEWLINE, unformatted_prefix)
100 class FormatOn(FormatError):
101 """Found a comment like `# fmt: on` in the file."""
104 class FormatOff(FormatError):
105 """Found a comment like `# fmt: off` in the file."""
108 class WriteBack(Enum):
125 default=DEFAULT_LINE_LENGTH,
126 help="How many character per line to allow.",
133 "Don't write the files back, just return the status. Return code 0 "
134 "means nothing would change. Return code 1 means some files would be "
135 "reformatted. Return code 123 means there was an internal error."
141 help="Don't write the files back, just output a diff for each file on stdout.",
146 help="If --fast given, skip temporary sanity checks. [default: --safe]",
153 "Don't emit non-error messages to stderr. Errors are still emitted, "
154 "silence those with 2>/dev/null."
157 @click.version_option(version=__version__)
162 exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
175 """The uncompromising code formatter."""
176 sources: List[Path] = []
180 sources.extend(gen_python_files_in_dir(p))
182 # if a file was explicitly given, we don't care about its extension
185 sources.append(Path("-"))
187 err(f"invalid path: {s}")
189 if check and not diff:
190 write_back = WriteBack.NO
192 write_back = WriteBack.DIFF
194 write_back = WriteBack.YES
195 report = Report(check=check, quiet=quiet)
196 if len(sources) == 0:
197 out("No paths given. Nothing to do 😴")
201 elif len(sources) == 1:
202 reformat_one(sources[0], line_length, fast, write_back, report)
204 loop = asyncio.get_event_loop()
205 executor = ProcessPoolExecutor(max_workers=os.cpu_count())
207 loop.run_until_complete(
209 sources, line_length, fast, write_back, report, loop, executor
215 out("All done! ✨ 🍰 ✨")
216 click.echo(str(report))
217 ctx.exit(report.return_code)
221 src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
223 """Reformat a single file under `src` without spawning child processes.
225 If `quiet` is True, non-error messages are not output. `line_length`,
226 `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
230 if not src.is_file() and str(src) == "-":
231 if format_stdin_to_stdout(
232 line_length=line_length, fast=fast, write_back=write_back
234 changed = Changed.YES
237 if write_back != WriteBack.DIFF:
238 cache = read_cache(line_length)
240 if src in cache and cache[src] == get_cache_info(src):
241 changed = Changed.CACHED
242 if changed is not Changed.CACHED and format_file_in_place(
243 src, line_length=line_length, fast=fast, write_back=write_back
245 changed = Changed.YES
246 if write_back == WriteBack.YES and changed is not Changed.NO:
247 write_cache(cache, [src], line_length)
248 report.done(src, changed)
249 except Exception as exc:
250 report.failed(src, str(exc))
253 async def schedule_formatting(
257 write_back: WriteBack,
262 """Run formatting of `sources` in parallel using the provided `executor`.
264 (Use ProcessPoolExecutors for actual parallelism.)
266 `line_length`, `write_back`, and `fast` options are passed to
267 :func:`format_file_in_place`.
270 if write_back != WriteBack.DIFF:
271 cache = read_cache(line_length)
272 sources, cached = filter_cached(cache, sources)
274 report.done(src, Changed.CACHED)
279 if write_back == WriteBack.DIFF:
280 # For diff output, we need locks to ensure we don't interleave output
281 # from different processes.
283 lock = manager.Lock()
285 loop.run_in_executor(
286 executor, format_file_in_place, src, line_length, fast, write_back, lock
288 for src in sorted(sources)
290 pending: Iterable[asyncio.Task] = tasks.keys()
292 loop.add_signal_handler(signal.SIGINT, cancel, pending)
293 loop.add_signal_handler(signal.SIGTERM, cancel, pending)
294 except NotImplementedError:
295 # There are no good alternatives for these on Windows
298 done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
300 src = tasks.pop(task)
302 cancelled.append(task)
303 elif task.exception():
304 report.failed(src, str(task.exception()))
306 formatted.append(src)
307 report.done(src, Changed.YES if task.result() else Changed.NO)
309 await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
310 if write_back == WriteBack.YES and formatted:
311 write_cache(cache, formatted, line_length)
314 def format_file_in_place(
318 write_back: WriteBack = WriteBack.NO,
319 lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
321 """Format file under `src` path. Return True if changed.
323 If `write_back` is True, write reformatted code back to stdout.
324 `line_length` and `fast` options are passed to :func:`format_file_contents`.
326 is_pyi = src.suffix == ".pyi"
328 with tokenize.open(src) as src_buffer:
329 src_contents = src_buffer.read()
331 dst_contents = format_file_contents(
332 src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
334 except NothingChanged:
337 if write_back == write_back.YES:
338 with open(src, "w", encoding=src_buffer.encoding) as f:
339 f.write(dst_contents)
340 elif write_back == write_back.DIFF:
341 src_name = f"{src} (original)"
342 dst_name = f"{src} (formatted)"
343 diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
347 sys.stdout.write(diff_contents)
354 def format_stdin_to_stdout(
355 line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
357 """Format file on stdin. Return True if changed.
359 If `write_back` is True, write reformatted code back to stdout.
360 `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
362 src = sys.stdin.read()
365 dst = format_file_contents(src, line_length=line_length, fast=fast)
368 except NothingChanged:
372 if write_back == WriteBack.YES:
373 sys.stdout.write(dst)
374 elif write_back == WriteBack.DIFF:
375 src_name = "<stdin> (original)"
376 dst_name = "<stdin> (formatted)"
377 sys.stdout.write(diff(src, dst, src_name, dst_name))
380 def format_file_contents(
381 src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
383 """Reformat contents a file and return new contents.
385 If `fast` is False, additionally confirm that the reformatted code is
386 valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
387 `line_length` is passed to :func:`format_str`.
389 if src_contents.strip() == "":
392 dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
393 if src_contents == dst_contents:
397 assert_equivalent(src_contents, dst_contents)
399 src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
405 src_contents: str, line_length: int, *, is_pyi: bool = False
407 """Reformat a string and return new contents.
409 `line_length` determines how many characters per line are allowed.
411 src_node = lib2to3_parse(src_contents)
413 future_imports = get_future_imports(src_node)
414 elt = EmptyLineTracker(is_pyi=is_pyi)
415 py36 = is_python36(src_node)
416 lines = LineGenerator(
417 remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
421 for current_line in lines.visit(src_node):
422 for _ in range(after):
423 dst_contents += str(empty_line)
424 before, after = elt.maybe_empty_lines(current_line)
425 for _ in range(before):
426 dst_contents += str(empty_line)
427 for line in split_line(current_line, line_length=line_length, py36=py36):
428 dst_contents += str(line)
433 pygram.python_grammar_no_print_statement_no_exec_statement,
434 pygram.python_grammar_no_print_statement,
435 pygram.python_grammar,
439 def lib2to3_parse(src_txt: str) -> Node:
440 """Given a string with source, return the lib2to3 Node."""
441 grammar = pygram.python_grammar_no_print_statement
442 if src_txt[-1] != "\n":
443 nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
445 for grammar in GRAMMARS:
446 drv = driver.Driver(grammar, pytree.convert)
448 result = drv.parse_string(src_txt, True)
451 except ParseError as pe:
452 lineno, column = pe.context[1]
453 lines = src_txt.splitlines()
455 faulty_line = lines[lineno - 1]
457 faulty_line = "<line number missing in source>"
458 exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
462 if isinstance(result, Leaf):
463 result = Node(syms.file_input, [result])
467 def lib2to3_unparse(node: Node) -> str:
468 """Given a lib2to3 node, return its string representation."""
476 class Visitor(Generic[T]):
477 """Basic lib2to3 visitor that yields things of type `T` on `visit()`."""
479 def visit(self, node: LN) -> Iterator[T]:
480 """Main method to visit `node` and its children.
482 It tries to find a `visit_*()` method for the given `node.type`, like
483 `visit_simple_stmt` for Node objects or `visit_INDENT` for Leaf objects.
484 If no dedicated `visit_*()` method is found, chooses `visit_default()`
487 Then yields objects of type `T` from the selected visitor.
490 name = token.tok_name[node.type]
492 name = type_repr(node.type)
493 yield from getattr(self, f"visit_{name}", self.visit_default)(node)
495 def visit_default(self, node: LN) -> Iterator[T]:
496 """Default `visit_*()` implementation. Recurses to children of `node`."""
497 if isinstance(node, Node):
498 for child in node.children:
499 yield from self.visit(child)
503 class DebugVisitor(Visitor[T]):
506 def visit_default(self, node: LN) -> Iterator[T]:
507 indent = " " * (2 * self.tree_depth)
508 if isinstance(node, Node):
509 _type = type_repr(node.type)
510 out(f"{indent}{_type}", fg="yellow")
512 for child in node.children:
513 yield from self.visit(child)
516 out(f"{indent}/{_type}", fg="yellow", bold=False)
518 _type = token.tok_name.get(node.type, str(node.type))
519 out(f"{indent}{_type}", fg="blue", nl=False)
521 # We don't have to handle prefixes for `Node` objects since
522 # that delegates to the first child anyway.
523 out(f" {node.prefix!r}", fg="green", bold=False, nl=False)
524 out(f" {node.value!r}", fg="blue", bold=False)
527 def show(cls, code: str) -> None:
528 """Pretty-print the lib2to3 AST of a given string of `code`.
530 Convenience method for debugging.
532 v: DebugVisitor[None] = DebugVisitor()
533 list(v.visit(lib2to3_parse(code)))
536 KEYWORDS = set(keyword.kwlist)
537 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
538 FLOW_CONTROL = {"return", "raise", "break", "continue"}
549 STANDALONE_COMMENT = 153
550 LOGIC_OPERATORS = {"and", "or"}
575 STARS = {token.STAR, token.DOUBLESTAR}
578 syms.argument, # double star in arglist
579 syms.trailer, # single argument to call
581 syms.varargslist, # lambdas
583 UNPACKING_PARENTS = {
584 syms.atom, # single element of a list or set literal
622 COMPREHENSION_PRIORITY = 20
624 TERNARY_PRIORITY = 16
627 COMPARATOR_PRIORITY = 10
638 token.DOUBLESLASH: 4,
648 class BracketTracker:
649 """Keeps track of brackets on a line."""
652 bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
653 delimiters: Dict[LeafID, Priority] = Factory(dict)
654 previous: Optional[Leaf] = None
655 _for_loop_variable: int = 0
656 _lambda_arguments: int = 0
658 def mark(self, leaf: Leaf) -> None:
659 """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
661 All leaves receive an int `bracket_depth` field that stores how deep
662 within brackets a given leaf is. 0 means there are no enclosing brackets
663 that started on this line.
665 If a leaf is itself a closing bracket, it receives an `opening_bracket`
666 field that it forms a pair with. This is a one-directional link to
667 avoid reference cycles.
669 If a leaf is a delimiter (a token on which Black can split the line if
670 needed) and it's on depth 0, its `id()` is stored in the tracker's
673 if leaf.type == token.COMMENT:
676 self.maybe_decrement_after_for_loop_variable(leaf)
677 self.maybe_decrement_after_lambda_arguments(leaf)
678 if leaf.type in CLOSING_BRACKETS:
680 opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
681 leaf.opening_bracket = opening_bracket
682 leaf.bracket_depth = self.depth
684 delim = is_split_before_delimiter(leaf, self.previous)
685 if delim and self.previous is not None:
686 self.delimiters[id(self.previous)] = delim
688 delim = is_split_after_delimiter(leaf, self.previous)
690 self.delimiters[id(leaf)] = delim
691 if leaf.type in OPENING_BRACKETS:
692 self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
695 self.maybe_increment_lambda_arguments(leaf)
696 self.maybe_increment_for_loop_variable(leaf)
698 def any_open_brackets(self) -> bool:
699 """Return True if there is an yet unmatched open bracket on the line."""
700 return bool(self.bracket_match)
702 def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
703 """Return the highest priority of a delimiter found on the line.
705 Values are consistent with what `is_split_*_delimiter()` return.
706 Raises ValueError on no delimiters.
708 return max(v for k, v in self.delimiters.items() if k not in exclude)
710 def delimiter_count_with_priority(self, priority: int = 0) -> int:
711 """Return the number of delimiters with the given `priority`.
713 If no `priority` is passed, defaults to max priority on the line.
715 if not self.delimiters:
718 priority = priority or self.max_delimiter_priority()
719 return sum(1 for p in self.delimiters.values() if p == priority)
721 def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
722 """In a for loop, or comprehension, the variables are often unpacks.
724 To avoid splitting on the comma in this situation, increase the depth of
725 tokens between `for` and `in`.
727 if leaf.type == token.NAME and leaf.value == "for":
729 self._for_loop_variable += 1
734 def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
735 """See `maybe_increment_for_loop_variable` above for explanation."""
736 if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
738 self._for_loop_variable -= 1
743 def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
744 """In a lambda expression, there might be more than one argument.
746 To avoid splitting on the comma in this situation, increase the depth of
747 tokens between `lambda` and `:`.
749 if leaf.type == token.NAME and leaf.value == "lambda":
751 self._lambda_arguments += 1
756 def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
757 """See `maybe_increment_lambda_arguments` above for explanation."""
758 if self._lambda_arguments and leaf.type == token.COLON:
760 self._lambda_arguments -= 1
765 def get_open_lsqb(self) -> Optional[Leaf]:
766 """Return the most recent opening square bracket (if any)."""
767 return self.bracket_match.get((self.depth - 1, token.RSQB))
772 """Holds leaves and comments. Can be printed with `str(line)`."""
775 leaves: List[Leaf] = Factory(list)
776 comments: List[Tuple[Index, Leaf]] = Factory(list)
777 bracket_tracker: BracketTracker = Factory(BracketTracker)
778 inside_brackets: bool = False
780 def append(self, leaf: Leaf, preformatted: bool = False) -> None:
781 """Add a new `leaf` to the end of the line.
783 Unless `preformatted` is True, the `leaf` will receive a new consistent
784 whitespace prefix and metadata applied by :class:`BracketTracker`.
785 Trailing commas are maybe removed, unpacked for loop variables are
786 demoted from being delimiters.
788 Inline comments are put aside.
790 has_value = leaf.type in BRACKETS or bool(leaf.value.strip())
794 if token.COLON == leaf.type and self.is_class_paren_empty:
796 if self.leaves and not preformatted:
797 # Note: at this point leaf.prefix should be empty except for
798 # imports, for which we only preserve newlines.
799 leaf.prefix += whitespace(
800 leaf, complex_subscript=self.is_complex_subscript(leaf)
802 if self.inside_brackets or not preformatted:
803 self.bracket_tracker.mark(leaf)
804 self.maybe_remove_trailing_comma(leaf)
805 if not self.append_comment(leaf):
806 self.leaves.append(leaf)
808 def append_safe(self, leaf: Leaf, preformatted: bool = False) -> None:
809 """Like :func:`append()` but disallow invalid standalone comment structure.
811 Raises ValueError when any `leaf` is appended after a standalone comment
812 or when a standalone comment is not the first leaf on the line.
814 if self.bracket_tracker.depth == 0:
816 raise ValueError("cannot append to standalone comments")
818 if self.leaves and leaf.type == STANDALONE_COMMENT:
820 "cannot append standalone comments to a populated line"
823 self.append(leaf, preformatted=preformatted)
826 def is_comment(self) -> bool:
827 """Is this line a standalone comment?"""
828 return len(self.leaves) == 1 and self.leaves[0].type == STANDALONE_COMMENT
831 def is_decorator(self) -> bool:
832 """Is this line a decorator?"""
833 return bool(self) and self.leaves[0].type == token.AT
836 def is_import(self) -> bool:
837 """Is this an import line?"""
838 return bool(self) and is_import(self.leaves[0])
841 def is_class(self) -> bool:
842 """Is this line a class definition?"""
845 and self.leaves[0].type == token.NAME
846 and self.leaves[0].value == "class"
850 def is_stub_class(self) -> bool:
851 """Is this line a class definition with a body consisting only of "..."?"""
852 return self.is_class and self.leaves[-3:] == [
853 Leaf(token.DOT, ".") for _ in range(3)
857 def is_def(self) -> bool:
858 """Is this a function definition? (Also returns True for async defs.)"""
860 first_leaf = self.leaves[0]
865 second_leaf: Optional[Leaf] = self.leaves[1]
868 return (first_leaf.type == token.NAME and first_leaf.value == "def") or (
869 first_leaf.type == token.ASYNC
870 and second_leaf is not None
871 and second_leaf.type == token.NAME
872 and second_leaf.value == "def"
876 def is_flow_control(self) -> bool:
877 """Is this line a flow control statement?
879 Those are `return`, `raise`, `break`, and `continue`.
883 and self.leaves[0].type == token.NAME
884 and self.leaves[0].value in FLOW_CONTROL
888 def is_yield(self) -> bool:
889 """Is this line a yield statement?"""
892 and self.leaves[0].type == token.NAME
893 and self.leaves[0].value == "yield"
897 def is_class_paren_empty(self) -> bool:
898 """Is this a class with no base classes but using parentheses?
900 Those are unnecessary and should be removed.
904 and len(self.leaves) == 4
906 and self.leaves[2].type == token.LPAR
907 and self.leaves[2].value == "("
908 and self.leaves[3].type == token.RPAR
909 and self.leaves[3].value == ")"
912 def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
913 """If so, needs to be split before emitting."""
914 for leaf in self.leaves:
915 if leaf.type == STANDALONE_COMMENT:
916 if leaf.bracket_depth <= depth_limit:
921 def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
922 """Remove trailing comma if there is one and it's safe."""
925 and self.leaves[-1].type == token.COMMA
926 and closing.type in CLOSING_BRACKETS
930 if closing.type == token.RBRACE:
931 self.remove_trailing_comma()
934 if closing.type == token.RSQB:
935 comma = self.leaves[-1]
936 if comma.parent and comma.parent.type == syms.listmaker:
937 self.remove_trailing_comma()
940 # For parens let's check if it's safe to remove the comma.
941 # Imports are always safe.
943 self.remove_trailing_comma()
946 # Otheriwsse, if the trailing one is the only one, we might mistakenly
947 # change a tuple into a different type by removing the comma.
948 depth = closing.bracket_depth + 1
950 opening = closing.opening_bracket
951 for _opening_index, leaf in enumerate(self.leaves):
958 for leaf in self.leaves[_opening_index + 1 :]:
962 bracket_depth = leaf.bracket_depth
963 if bracket_depth == depth and leaf.type == token.COMMA:
965 if leaf.parent and leaf.parent.type == syms.arglist:
970 self.remove_trailing_comma()
975 def append_comment(self, comment: Leaf) -> bool:
976 """Add an inline or standalone comment to the line."""
978 comment.type == STANDALONE_COMMENT
979 and self.bracket_tracker.any_open_brackets()
984 if comment.type != token.COMMENT:
987 after = len(self.leaves) - 1
989 comment.type = STANDALONE_COMMENT
994 self.comments.append((after, comment))
997 def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
998 """Generate comments that should appear directly after `leaf`.
1000 Provide a non-negative leaf `_index` to speed up the function.
1003 for _index, _leaf in enumerate(self.leaves):
1010 for index, comment_after in self.comments:
1014 def remove_trailing_comma(self) -> None:
1015 """Remove the trailing comma and moves the comments attached to it."""
1016 comma_index = len(self.leaves) - 1
1017 for i in range(len(self.comments)):
1018 comment_index, comment = self.comments[i]
1019 if comment_index == comma_index:
1020 self.comments[i] = (comma_index - 1, comment)
1023 def is_complex_subscript(self, leaf: Leaf) -> bool:
1024 """Return True iff `leaf` is part of a slice with non-trivial exprs."""
1026 leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
1028 if open_lsqb is None:
1031 subscript_start = open_lsqb.next_sibling
1033 isinstance(subscript_start, Node)
1034 and subscript_start.type == syms.subscriptlist
1036 subscript_start = child_towards(subscript_start, leaf)
1037 return subscript_start is not None and any(
1038 n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
1041 def __str__(self) -> str:
1042 """Render the line."""
1046 indent = " " * self.depth
1047 leaves = iter(self.leaves)
1048 first = next(leaves)
1049 res = f"{first.prefix}{indent}{first.value}"
1052 for _, comment in self.comments:
1056 def __bool__(self) -> bool:
1057 """Return True if the line has leaves or comments."""
1058 return bool(self.leaves or self.comments)
1061 class UnformattedLines(Line):
1062 """Just like :class:`Line` but stores lines which aren't reformatted."""
1064 def append(self, leaf: Leaf, preformatted: bool = True) -> None:
1065 """Just add a new `leaf` to the end of the lines.
1067 The `preformatted` argument is ignored.
1069 Keeps track of indentation `depth`, which is useful when the user
1070 says `# fmt: on`. Otherwise, doesn't do anything with the `leaf`.
1073 list(generate_comments(leaf))
1074 except FormatOn as f_on:
1075 self.leaves.append(f_on.leaf_from_consumed(leaf))
1078 self.leaves.append(leaf)
1079 if leaf.type == token.INDENT:
1081 elif leaf.type == token.DEDENT:
1084 def __str__(self) -> str:
1085 """Render unformatted lines from leaves which were added with `append()`.
1087 `depth` is not used for indentation in this case.
1093 for leaf in self.leaves:
1097 def append_comment(self, comment: Leaf) -> bool:
1098 """Not implemented in this class. Raises `NotImplementedError`."""
1099 raise NotImplementedError("Unformatted lines don't store comments separately.")
1101 def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
1102 """Does nothing and returns False."""
1105 def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
1106 """Does nothing and returns False."""
1111 class EmptyLineTracker:
1112 """Provides a stateful method that returns the number of potential extra
1113 empty lines needed before and after the currently processed line.
1115 Note: this tracker works on lines that haven't been split yet. It assumes
1116 the prefix of the first leaf consists of optional newlines. Those newlines
1117 are consumed by `maybe_empty_lines()` and included in the computation.
1119 is_pyi: bool = False
1120 previous_line: Optional[Line] = None
1121 previous_after: int = 0
1122 previous_defs: List[int] = Factory(list)
1124 def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1125 """Return the number of extra empty lines before and after the `current_line`.
1127 This is for separating `def`, `async def` and `class` with extra empty
1128 lines (two on module-level), as well as providing an extra empty line
1129 after flow control keywords to make them more prominent.
1131 if isinstance(current_line, UnformattedLines):
1134 before, after = self._maybe_empty_lines(current_line)
1135 before -= self.previous_after
1136 self.previous_after = after
1137 self.previous_line = current_line
1138 return before, after
1140 def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
1142 if current_line.depth == 0:
1143 max_allowed = 1 if self.is_pyi else 2
1144 if current_line.leaves:
1145 # Consume the first leaf's extra newlines.
1146 first_leaf = current_line.leaves[0]
1147 before = first_leaf.prefix.count("\n")
1148 before = min(before, max_allowed)
1149 first_leaf.prefix = ""
1152 depth = current_line.depth
1153 while self.previous_defs and self.previous_defs[-1] >= depth:
1154 self.previous_defs.pop()
1156 before = 0 if depth else 1
1158 before = 1 if depth else 2
1159 is_decorator = current_line.is_decorator
1160 if is_decorator or current_line.is_def or current_line.is_class:
1161 if not is_decorator:
1162 self.previous_defs.append(depth)
1163 if self.previous_line is None:
1164 # Don't insert empty lines before the first line in the file.
1167 if self.previous_line.is_decorator:
1171 self.previous_line.is_comment
1172 and self.previous_line.depth == current_line.depth
1178 if self.previous_line.depth > current_line.depth:
1180 elif current_line.is_class or self.previous_line.is_class:
1181 if current_line.is_stub_class and self.previous_line.is_stub_class:
1189 if current_line.depth and newlines:
1195 and self.previous_line.is_import
1196 and not current_line.is_import
1197 and depth == self.previous_line.depth
1199 return (before or 1), 0
1205 class LineGenerator(Visitor[Line]):
1206 """Generates reformatted Line objects. Empty lines are not emitted.
1208 Note: destroys the tree it's visiting by mutating prefixes of its leaves
1209 in ways that will no longer stringify to valid Python code on the tree.
1211 is_pyi: bool = False
1212 current_line: Line = Factory(Line)
1213 remove_u_prefix: bool = False
1215 def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
1218 If the line is empty, only emit if it makes sense.
1219 If the line is too long, split it first and then generate.
1221 If any lines were generated, set up a new current_line.
1223 if not self.current_line:
1224 if self.current_line.__class__ == type:
1225 self.current_line.depth += indent
1227 self.current_line = type(depth=self.current_line.depth + indent)
1228 return # Line is empty, don't emit. Creating a new one unnecessary.
1230 complete_line = self.current_line
1231 self.current_line = type(depth=complete_line.depth + indent)
1234 def visit(self, node: LN) -> Iterator[Line]:
1235 """Main method to visit `node` and its children.
1237 Yields :class:`Line` objects.
1239 if isinstance(self.current_line, UnformattedLines):
1240 # File contained `# fmt: off`
1241 yield from self.visit_unformatted(node)
1244 yield from super().visit(node)
1246 def visit_default(self, node: LN) -> Iterator[Line]:
1247 """Default `visit_*()` implementation. Recurses to children of `node`."""
1248 if isinstance(node, Leaf):
1249 any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
1251 for comment in generate_comments(node):
1252 if any_open_brackets:
1253 # any comment within brackets is subject to splitting
1254 self.current_line.append(comment)
1255 elif comment.type == token.COMMENT:
1256 # regular trailing comment
1257 self.current_line.append(comment)
1258 yield from self.line()
1261 # regular standalone comment
1262 yield from self.line()
1264 self.current_line.append(comment)
1265 yield from self.line()
1267 except FormatOff as f_off:
1268 f_off.trim_prefix(node)
1269 yield from self.line(type=UnformattedLines)
1270 yield from self.visit(node)
1272 except FormatOn as f_on:
1273 # This only happens here if somebody says "fmt: on" multiple
1275 f_on.trim_prefix(node)
1276 yield from self.visit_default(node)
1279 normalize_prefix(node, inside_brackets=any_open_brackets)
1280 if node.type == token.STRING:
1281 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
1282 normalize_string_quotes(node)
1283 if node.type not in WHITESPACE:
1284 self.current_line.append(node)
1285 yield from super().visit_default(node)
1287 def visit_INDENT(self, node: Node) -> Iterator[Line]:
1288 """Increase indentation level, maybe yield a line."""
1289 # In blib2to3 INDENT never holds comments.
1290 yield from self.line(+1)
1291 yield from self.visit_default(node)
1293 def visit_DEDENT(self, node: Node) -> Iterator[Line]:
1294 """Decrease indentation level, maybe yield a line."""
1295 # The current line might still wait for trailing comments. At DEDENT time
1296 # there won't be any (they would be prefixes on the preceding NEWLINE).
1297 # Emit the line then.
1298 yield from self.line()
1300 # While DEDENT has no value, its prefix may contain standalone comments
1301 # that belong to the current indentation level. Get 'em.
1302 yield from self.visit_default(node)
1304 # Finally, emit the dedent.
1305 yield from self.line(-1)
1308 self, node: Node, keywords: Set[str], parens: Set[str]
1309 ) -> Iterator[Line]:
1310 """Visit a statement.
1312 This implementation is shared for `if`, `while`, `for`, `try`, `except`,
1313 `def`, `with`, `class`, `assert` and assignments.
1315 The relevant Python language `keywords` for a given statement will be
1316 NAME leaves within it. This methods puts those on a separate line.
1318 `parens` holds a set of string leaf values immediately after which
1319 invisible parens should be put.
1321 normalize_invisible_parens(node, parens_after=parens)
1322 for child in node.children:
1323 if child.type == token.NAME and child.value in keywords: # type: ignore
1324 yield from self.line()
1326 yield from self.visit(child)
1328 def visit_suite(self, node: Node) -> Iterator[Line]:
1329 """Visit a suite."""
1330 if self.is_pyi and is_stub_suite(node):
1331 yield from self.visit(node.children[2])
1333 yield from self.visit_default(node)
1335 def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
1336 """Visit a statement without nested statements."""
1337 is_suite_like = node.parent and node.parent.type in STATEMENT
1339 if self.is_pyi and is_stub_body(node):
1340 yield from self.visit_default(node)
1342 yield from self.line(+1)
1343 yield from self.visit_default(node)
1344 yield from self.line(-1)
1347 if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
1348 yield from self.line()
1349 yield from self.visit_default(node)
1351 def visit_async_stmt(self, node: Node) -> Iterator[Line]:
1352 """Visit `async def`, `async for`, `async with`."""
1353 yield from self.line()
1355 children = iter(node.children)
1356 for child in children:
1357 yield from self.visit(child)
1359 if child.type == token.ASYNC:
1362 internal_stmt = next(children)
1363 for child in internal_stmt.children:
1364 yield from self.visit(child)
def visit_decorators(self, node: Node) -> Iterator[Line]:
    """Visit every child of `node`, starting a fresh line before each one.

    Besides `decorators` nodes, this is also reused for `decorated` nodes
    (see the `visit_decorated` alias), where the children are the decorator(s)
    followed by the decorated definition.
    """
    for element in node.children:
        # Close whatever line is pending so each child begins on its own line.
        yield from self.line()
        yield from self.visit(element)
1372 def visit_import_from(self, node: Node) -> Iterator[Line]:
1373 """Visit import_from and maybe put invisible parentheses.
1375 This is separate from `visit_stmt` because import statements don't
1376 support arbitrary atoms and thus handling of parentheses is custom.
1379 for index, child in enumerate(node.children):
1381 if child.type == token.LPAR:
1382 # make parentheses invisible
1383 child.value = "" # type: ignore
1384 node.children[-1].value = "" # type: ignore
1386 # insert invisible parentheses
1387 node.insert_child(index, Leaf(token.LPAR, ""))
1388 node.append_child(Leaf(token.RPAR, ""))
1392 child.type == token.NAME and child.value == "import" # type: ignore
1395 for child in node.children:
1396 yield from self.visit(child)
def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
    """Drop a semicolon and move the following statement to its own line.

    The `leaf` itself is intentionally not visited or appended anywhere; the
    only effect is ending the current line.
    """
    yield from self.line()
def visit_ENDMARKER(self, leaf: Leaf) -> Iterator[Line]:
    """Finish the file: handle outstanding comments, then end with a newline."""
    # The ENDMARKER leaf's prefix may still hold trailing comments; the
    # default visit is relied on to process them (NOTE(review): behavior of
    # visit_default is defined elsewhere — confirm it emits prefix comments).
    yield from self.visit_default(leaf)
    # Emit whatever line is still pending.
    yield from self.line()
1407 def visit_unformatted(self, node: LN) -> Iterator[Line]:
1408 """Used when file contained a `# fmt: off`."""
1409 if isinstance(node, Node):
1410 for child in node.children:
1411 yield from self.visit(child)
1415 self.current_line.append(node)
1416 except FormatOn as f_on:
1417 f_on.trim_prefix(node)
1418 yield from self.line()
1419 yield from self.visit(node)
1421 if node.type == token.ENDMARKER:
1422 # somebody decided not to put a final `# fmt: on`
1423 yield from self.line()
1425 def __attrs_post_init__(self) -> None:
1426 """You are in a twisty little maze of passages."""
1429 self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
1430 self.visit_if_stmt = partial(
1431 v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
1433 self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
1434 self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
1435 self.visit_try_stmt = partial(
1436 v, keywords={"try", "except", "else", "finally"}, parens=Ø
1438 self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø)
1439 self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
1440 self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
1441 self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
1442 self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
1443 self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
1444 self.visit_async_funcdef = self.visit_async_stmt
1445 self.visit_decorated = self.visit_decorators
# Grammar symbols that form a tuple without surrounding parentheses.
IMPLICIT_TUPLE = {syms.testlist, syms.testlist_star_expr, syms.exprlist}
# Maps each opening bracket token type to its matching closing token type.
BRACKET = {
    token.LPAR: token.RPAR,
    token.LSQB: token.RSQB,
    token.LBRACE: token.RBRACE,
}
OPENING_BRACKETS = set(BRACKET)
CLOSING_BRACKETS = set(BRACKET.values())
BRACKETS = OPENING_BRACKETS.union(CLOSING_BRACKETS)
# Token types that never get whitespace in front of them.
ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
1456 def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa C901
1457 """Return whitespace prefix if needed for the given `leaf`.
1459 `complex_subscript` signals whether the given leaf is part of a subscription
1460 which has non-trivial arguments, like arithmetic expressions or function calls.
1468 if t in ALWAYS_NO_SPACE:
1471 if t == token.COMMENT:
1474 assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
1475 if t == token.COLON and p.type not in {
1476 syms.subscript, syms.subscriptlist, syms.sliceop
1480 prev = leaf.prev_sibling
1482 prevp = preceding_leaf(p)
1483 if not prevp or prevp.type in OPENING_BRACKETS:
1486 if t == token.COLON:
1487 if prevp.type == token.COLON:
1490 elif prevp.type != token.COMMA and not complex_subscript:
1495 if prevp.type == token.EQUAL:
1497 if prevp.parent.type in {
1498 syms.arglist, syms.argument, syms.parameters, syms.varargslist
1502 elif prevp.parent.type == syms.typedargslist:
1503 # A bit hacky: if the equal sign has whitespace, it means we
1504 # previously found it's a typed argument. So, we're using
1508 elif prevp.type in STARS:
1509 if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1512 elif prevp.type == token.COLON:
1513 if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
1514 return SPACE if complex_subscript else NO
1518 and prevp.parent.type == syms.factor
1519 and prevp.type in MATH_OPERATORS
1524 prevp.type == token.RIGHTSHIFT
1526 and prevp.parent.type == syms.shift_expr
1527 and prevp.prev_sibling
1528 and prevp.prev_sibling.type == token.NAME
1529 and prevp.prev_sibling.value == "print" # type: ignore
1531 # Python 2 print chevron
1534 elif prev.type in OPENING_BRACKETS:
1537 if p.type in {syms.parameters, syms.arglist}:
1538 # untyped function signatures or calls
1539 if not prev or prev.type != token.COMMA:
1542 elif p.type == syms.varargslist:
1544 if prev and prev.type != token.COMMA:
1547 elif p.type == syms.typedargslist:
1548 # typed function signatures
1552 if t == token.EQUAL:
1553 if prev.type != syms.tname:
1556 elif prev.type == token.EQUAL:
1557 # A bit hacky: if the equal sign has whitespace, it means we
1558 # previously found it's a typed argument. So, we're using that, too.
1561 elif prev.type != token.COMMA:
1564 elif p.type == syms.tname:
1567 prevp = preceding_leaf(p)
1568 if not prevp or prevp.type != token.COMMA:
1571 elif p.type == syms.trailer:
1572 # attributes and calls
1573 if t == token.LPAR or t == token.RPAR:
1578 prevp = preceding_leaf(p)
1579 if not prevp or prevp.type != token.NUMBER:
1582 elif t == token.LSQB:
1585 elif prev.type != token.COMMA:
1588 elif p.type == syms.argument:
1590 if t == token.EQUAL:
1594 prevp = preceding_leaf(p)
1595 if not prevp or prevp.type == token.LPAR:
1598 elif prev.type in {token.EQUAL} | STARS:
1601 elif p.type == syms.decorator:
1605 elif p.type == syms.dotted_name:
1609 prevp = preceding_leaf(p)
1610 if not prevp or prevp.type == token.AT or prevp.type == token.DOT:
1613 elif p.type == syms.classdef:
1617 if prev and prev.type == token.LPAR:
1620 elif p.type in {syms.subscript, syms.sliceop}:
1623 assert p.parent is not None, "subscripts are always parented"
1624 if p.parent.type == syms.subscriptlist:
1629 elif not complex_subscript:
1632 elif p.type == syms.atom:
1633 if prev and t == token.DOT:
1634 # dots, but not the first one.
1637 elif p.type == syms.dictsetmaker:
1639 if prev and prev.type == token.DOUBLESTAR:
1642 elif p.type in {syms.factor, syms.star_expr}:
1645 prevp = preceding_leaf(p)
1646 if not prevp or prevp.type in OPENING_BRACKETS:
1649 prevp_parent = prevp.parent
1650 assert prevp_parent is not None
1651 if prevp.type == token.COLON and prevp_parent.type in {
1652 syms.subscript, syms.sliceop
1656 elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
1659 elif t == token.NAME or t == token.NUMBER:
1662 elif p.type == syms.import_from:
1664 if prev and prev.type == token.DOT:
1667 elif t == token.NAME:
1671 if prev and prev.type == token.DOT:
1674 elif p.type == syms.sliceop:
1680 def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]:
1681 """Return the first leaf that precedes `node`, if any."""
1683 res = node.prev_sibling
1685 if isinstance(res, Leaf):
1689 return list(res.leaves())[-1]
1698 def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
1699 """Return the child of `ancestor` that contains `descendant`."""
1700 node: Optional[LN] = descendant
1701 while node and node.parent != ancestor:
1706 def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1707 """Return the priority of the `leaf` delimiter, given a line break after it.
1709 The delimiter priorities returned here are from those delimiters that would
1710 cause a line break after themselves.
1712 Higher numbers are higher priority.
1714 if leaf.type == token.COMMA:
1715 return COMMA_PRIORITY
1720 def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
1721 """Return the priority of the `leaf` delimiter, given a line before after it.
1723 The delimiter priorities returned here are from those delimiters that would
1724 cause a line break before themselves.
1726 Higher numbers are higher priority.
1728 if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
1729 # * and ** might also be MATH_OPERATORS but in this case they are not.
1730 # Don't treat them as a delimiter.
1734 leaf.type == token.DOT
1736 and leaf.parent.type not in {syms.import_from, syms.dotted_name}
1737 and (previous is None or previous.type != token.NAME)
1742 leaf.type in MATH_OPERATORS
1744 and leaf.parent.type not in {syms.factor, syms.star_expr}
1746 return MATH_PRIORITIES[leaf.type]
1748 if leaf.type in COMPARATORS:
1749 return COMPARATOR_PRIORITY
1752 leaf.type == token.STRING
1753 and previous is not None
1754 and previous.type == token.STRING
1756 return STRING_PRIORITY
1758 if leaf.type != token.NAME:
1764 and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
1766 return COMPREHENSION_PRIORITY
1771 and leaf.parent.type in {syms.comp_if, syms.old_comp_if}
1773 return COMPREHENSION_PRIORITY
1775 if leaf.value in {"if", "else"} and leaf.parent and leaf.parent.type == syms.test:
1776 return TERNARY_PRIORITY
1778 if leaf.value == "is":
1779 return COMPARATOR_PRIORITY
1784 and leaf.parent.type in {syms.comp_op, syms.comparison}
1786 previous is not None
1787 and previous.type == token.NAME
1788 and previous.value == "not"
1791 return COMPARATOR_PRIORITY
1796 and leaf.parent.type == syms.comp_op
1798 previous is not None
1799 and previous.type == token.NAME
1800 and previous.value == "is"
1803 return COMPARATOR_PRIORITY
1805 if leaf.value in LOGIC_OPERATORS and leaf.parent:
1806 return LOGIC_PRIORITY
1811 def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
1812 """Clean the prefix of the `leaf` and generate comments from it, if any.
1814 Comments in lib2to3 are shoved into the whitespace prefix. This happens
1815 in `pgen2/driver.py:Driver.parse_tokens()`. This was a brilliant implementation
1816 move because it does away with modifying the grammar to include all the
1817 possible places in which comments can be placed.
1819 The sad consequence for us though is that comments don't "belong" anywhere.
1820 This is why this function generates simple parentless Leaf objects for
1821 comments. We simply don't know what the correct parent should be.
1823 No matter though, we can live without this. We really only need to
1824 differentiate between inline and standalone comments. The latter don't
1825 share the line with any code.
1827 Inline comments are emitted as regular token.COMMENT leaves. Standalone
1828 are emitted with a fake STANDALONE_COMMENT token identifier.
1839 for index, line in enumerate(p.split("\n")):
1840 consumed += len(line) + 1 # adding the length of the split '\n'
1841 line = line.lstrip()
1844 if not line.startswith("#"):
1847 if index == 0 and leaf.type != token.ENDMARKER:
1848 comment_type = token.COMMENT # simple trailing comment
1850 comment_type = STANDALONE_COMMENT
1851 comment = make_comment(line)
1852 yield Leaf(comment_type, comment, prefix="\n" * nlines)
1854 if comment in {"# fmt: on", "# yapf: enable"}:
1855 raise FormatOn(consumed)
1857 if comment in {"# fmt: off", "# yapf: disable"}:
1858 if comment_type == STANDALONE_COMMENT:
1859 raise FormatOff(consumed)
1861 prev = preceding_leaf(leaf)
1862 if not prev or prev.type in WHITESPACE: # standalone comment in disguise
1863 raise FormatOff(consumed)
1868 def make_comment(content: str) -> str:
1869 """Return a consistently formatted comment from the given `content` string.
1871 All comments (except for "##", "#!", "#:") should have a single space between
1872 the hash sign and the content.
1874 If `content` didn't start with a hash sign, one is provided.
1876 content = content.rstrip()
1880 if content[0] == "#":
1881 content = content[1:]
1882 if content and content[0] not in " !:#":
1883 content = " " + content
1884 return "#" + content
1888 line: Line, line_length: int, inner: bool = False, py36: bool = False
1889 ) -> Iterator[Line]:
1890 """Split a `line` into potentially many lines.
1892 They should fit in the allotted `line_length` but might not be able to.
1893 `inner` signifies that there were a pair of brackets somewhere around the
1894 current `line`, possibly transitively. This means we can fallback to splitting
1895 by delimiters if the LHS/RHS don't yield any results.
1897 If `py36` is True, splitting may generate syntax that is only compatible
1898 with Python 3.6 and later.
1900 if isinstance(line, UnformattedLines) or line.is_comment:
1904 line_str = str(line).strip("\n")
1905 if is_line_short_enough(line, line_length=line_length, line_str=line_str):
1909 split_funcs: List[SplitFunc]
1911 split_funcs = [left_hand_split]
1912 elif line.is_import:
1913 split_funcs = [explode_split]
1916 def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
1917 for omit in generate_trailers_to_omit(line, line_length):
1918 lines = list(right_hand_split(line, py36, omit=omit))
1919 if is_line_short_enough(lines[0], line_length=line_length):
1923 # All splits failed, best effort split with no omits.
1924 yield from right_hand_split(line, py36)
1926 if line.inside_brackets:
1927 split_funcs = [delimiter_split, standalone_comment_split, rhs]
1930 for split_func in split_funcs:
1931 # We are accumulating lines in `result` because we might want to abort
1932 # mission and return the original line in the end, or attempt a different
1934 result: List[Line] = []
1936 for l in split_func(line, py36):
1937 if str(l).strip("\n") == line_str:
1938 raise CannotSplit("Split function returned an unchanged result")
1941 split_line(l, line_length=line_length, inner=True, py36=py36)
1943 except CannotSplit as cs:
1954 def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
1955 """Split line into many lines, starting with the first matching bracket pair.
1957 Note: this usually looks weird, only use this for function definitions.
1958 Prefer RHS otherwise. This is why this function is not symmetrical with
1959 :func:`right_hand_split` which also handles optional parentheses.
1961 head = Line(depth=line.depth)
1962 body = Line(depth=line.depth + 1, inside_brackets=True)
1963 tail = Line(depth=line.depth)
1964 tail_leaves: List[Leaf] = []
1965 body_leaves: List[Leaf] = []
1966 head_leaves: List[Leaf] = []
1967 current_leaves = head_leaves
1968 matching_bracket = None
1969 for leaf in line.leaves:
1971 current_leaves is body_leaves
1972 and leaf.type in CLOSING_BRACKETS
1973 and leaf.opening_bracket is matching_bracket
1975 current_leaves = tail_leaves if body_leaves else head_leaves
1976 current_leaves.append(leaf)
1977 if current_leaves is head_leaves:
1978 if leaf.type in OPENING_BRACKETS:
1979 matching_bracket = leaf
1980 current_leaves = body_leaves
1981 # Since body is a new indent level, remove spurious leading whitespace.
1983 normalize_prefix(body_leaves[0], inside_brackets=True)
1984 # Build the new lines.
1985 for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
1987 result.append(leaf, preformatted=True)
1988 for comment_after in line.comments_after(leaf):
1989 result.append(comment_after, preformatted=True)
1990 bracket_split_succeeded_or_raise(head, body, tail)
1991 for result in (head, body, tail):
1996 def right_hand_split(
1997 line: Line, py36: bool = False, omit: Collection[LeafID] = ()
1998 ) -> Iterator[Line]:
1999 """Split line into many lines, starting with the last matching bracket pair.
2001 If the split was by optional parentheses, attempt splitting without them, too.
2002 `omit` is a collection of closing bracket IDs that shouldn't be considered for
2005 Note: running this function modifies `bracket_depth` on the leaves of `line`.
2007 head = Line(depth=line.depth)
2008 body = Line(depth=line.depth + 1, inside_brackets=True)
2009 tail = Line(depth=line.depth)
2010 tail_leaves: List[Leaf] = []
2011 body_leaves: List[Leaf] = []
2012 head_leaves: List[Leaf] = []
2013 current_leaves = tail_leaves
2014 opening_bracket = None
2015 closing_bracket = None
2016 for leaf in reversed(line.leaves):
2017 if current_leaves is body_leaves:
2018 if leaf is opening_bracket:
2019 current_leaves = head_leaves if body_leaves else tail_leaves
2020 current_leaves.append(leaf)
2021 if current_leaves is tail_leaves:
2022 if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit:
2023 opening_bracket = leaf.opening_bracket
2024 closing_bracket = leaf
2025 current_leaves = body_leaves
2026 tail_leaves.reverse()
2027 body_leaves.reverse()
2028 head_leaves.reverse()
2029 # Since body is a new indent level, remove spurious leading whitespace.
2031 normalize_prefix(body_leaves[0], inside_brackets=True)
2033 # No `head` means the split failed. Either `tail` has all content or
2034 # the matching `opening_bracket` wasn't available on `line` anymore.
2035 raise CannotSplit("No brackets found")
2037 # Build the new lines.
2038 for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
2040 result.append(leaf, preformatted=True)
2041 for comment_after in line.comments_after(leaf):
2042 result.append(comment_after, preformatted=True)
2043 bracket_split_succeeded_or_raise(head, body, tail)
2044 assert opening_bracket and closing_bracket
2046 # the opening bracket is an optional paren
2047 opening_bracket.type == token.LPAR
2048 and not opening_bracket.value
2049 # the closing bracket is an optional paren
2050 and closing_bracket.type == token.RPAR
2051 and not closing_bracket.value
2052 # there are no standalone comments in the body
2053 and not line.contains_standalone_comments(0)
2054 # and it's not an import (optional parens are the only thing we can split
2055 # on in this case; attempting a split without them is a waste of time)
2056 and not line.is_import
2058 omit = {id(closing_bracket), *omit}
2059 delimiter_count = body.bracket_tracker.delimiter_count_with_priority()
2061 delimiter_count == 0
2062 or delimiter_count == 1
2064 body.leaves[0].type in OPENING_BRACKETS
2065 or body.leaves[-1].type in CLOSING_BRACKETS
2069 yield from right_hand_split(line, py36=py36, omit=omit)
2074 ensure_visible(opening_bracket)
2075 ensure_visible(closing_bracket)
2076 for result in (head, body, tail):
2081 def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
2082 """Raise :exc:`CannotSplit` if the last left- or right-hand split failed.
2084 Do nothing otherwise.
2086 A left- or right-hand split is based on a pair of brackets. Content before
2087 (and including) the opening bracket is left on one line, content inside the
2088 brackets is put on a separate line, and finally content starting with and
2089 following the closing bracket is put on a separate line.
2091 Those are called `head`, `body`, and `tail`, respectively. If the split
2092 produced the same line (all content in `head`) or ended up with an empty `body`
2093 and the `tail` is just the closing bracket, then it's considered failed.
2095 tail_len = len(str(tail).strip())
2098 raise CannotSplit("Splitting brackets produced the same line")
2102 f"Splitting brackets on an empty body to save "
2103 f"{tail_len} characters is not worth it"
2107 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
2108 """Normalize prefix of the first leaf in every line returned by `split_func`.
2110 This is a decorator over relevant split functions.
2114 def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
2115 for l in split_func(line, py36):
2116 normalize_prefix(l.leaves[0], inside_brackets=True)
2119 return split_wrapper
2122 @dont_increase_indentation
2123 def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
2124 """Split according to delimiters of the highest priority.
2126 If `py36` is True, the split will add trailing commas also in function
2127 signatures that contain `*` and `**`.
2130 last_leaf = line.leaves[-1]
2132 raise CannotSplit("Line empty")
2134 bt = line.bracket_tracker
2136 delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
2138 raise CannotSplit("No delimiters found")
2140 if delimiter_priority == DOT_PRIORITY:
2141 if bt.delimiter_count_with_priority(delimiter_priority) == 1:
2142 raise CannotSplit("Splitting a single attribute from its owner looks wrong")
2144 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2145 lowest_depth = sys.maxsize
2146 trailing_comma_safe = True
2148 def append_to_line(leaf: Leaf) -> Iterator[Line]:
2149 """Append `leaf` to current line or to new line if appending impossible."""
2150 nonlocal current_line
2152 current_line.append_safe(leaf, preformatted=True)
2153 except ValueError as ve:
2156 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2157 current_line.append(leaf)
2159 for index, leaf in enumerate(line.leaves):
2160 yield from append_to_line(leaf)
2162 for comment_after in line.comments_after(leaf, index):
2163 yield from append_to_line(comment_after)
2165 lowest_depth = min(lowest_depth, leaf.bracket_depth)
2166 if leaf.bracket_depth == lowest_depth and is_vararg(
2167 leaf, within=VARARGS_PARENTS
2169 trailing_comma_safe = trailing_comma_safe and py36
2170 leaf_priority = bt.delimiters.get(id(leaf))
2171 if leaf_priority == delimiter_priority:
2174 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2178 and delimiter_priority == COMMA_PRIORITY
2179 and current_line.leaves[-1].type != token.COMMA
2180 and current_line.leaves[-1].type != STANDALONE_COMMENT
2182 current_line.append(Leaf(token.COMMA, ","))
2186 @dont_increase_indentation
2187 def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
2188 """Split standalone comments from the rest of the line."""
2189 if not line.contains_standalone_comments(0):
2190 raise CannotSplit("Line does not have any standalone comments")
2192 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2194 def append_to_line(leaf: Leaf) -> Iterator[Line]:
2195 """Append `leaf` to current line or to new line if appending impossible."""
2196 nonlocal current_line
2198 current_line.append_safe(leaf, preformatted=True)
2199 except ValueError as ve:
2202 current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
2203 current_line.append(leaf)
2205 for index, leaf in enumerate(line.leaves):
2206 yield from append_to_line(leaf)
2208 for comment_after in line.comments_after(leaf, index):
2209 yield from append_to_line(comment_after)
2216 line: Line, py36: bool = False, omit: Collection[LeafID] = ()
2217 ) -> Iterator[Line]:
2218 """Split by rightmost bracket and immediately split contents by a delimiter."""
2219 new_lines = list(right_hand_split(line, py36, omit))
2220 if len(new_lines) != 3:
2221 yield from new_lines
2227 yield from delimiter_split(new_lines[1], py36)
2235 def is_import(leaf: Leaf) -> bool:
2236 """Return True if the given leaf starts an import statement."""
2243 (v == "import" and p and p.type == syms.import_name)
2244 or (v == "from" and p and p.type == syms.import_from)
2249 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
2250 """Leave existing extra newlines if not `inside_brackets`. Remove everything
2253 Note: don't use backslashes for formatting or you'll lose your voting rights.
2255 if not inside_brackets:
2256 spl = leaf.prefix.split("#")
2257 if "\\" not in spl[0]:
2258 nl_count = spl[-1].count("\n")
2261 leaf.prefix = "\n" * nl_count
2267 def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
2268 """Make all string prefixes lowercase.
2270 If remove_u_prefix is given, also removes any u prefix from the string.
2272 Note: Mutates its argument.
2274 match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
2275 assert match is not None, f"failed to match string {leaf.value!r}"
2276 orig_prefix = match.group(1)
2277 new_prefix = orig_prefix.lower()
2279 new_prefix = new_prefix.replace("u", "")
2280 leaf.value = f"{new_prefix}{match.group(2)}"
2283 def normalize_string_quotes(leaf: Leaf) -> None:
2284 """Prefer double quotes but only if it doesn't cause more escaping.
2286 Adds or removes backslashes as appropriate. Doesn't parse and fix
2287 strings nested in f-strings (yet).
2289 Note: Mutates its argument.
2291 value = leaf.value.lstrip("furbFURB")
2292 if value[:3] == '"""':
2295 elif value[:3] == "'''":
2298 elif value[0] == '"':
2304 first_quote_pos = leaf.value.find(orig_quote)
2305 if first_quote_pos == -1:
2306 return # There's an internal error
2308 prefix = leaf.value[:first_quote_pos]
2309 unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
2310 escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
2311 escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
2312 body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
2313 if "r" in prefix.casefold():
2314 if unescaped_new_quote.search(body):
2315 # There's at least one unescaped new_quote in this raw string
2316 # so converting is impossible
2319 # Do not introduce or remove backslashes in raw strings
2322 # remove unnecessary quotes
2323 new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
2324 if body != new_body:
2325 # Consider the string without unnecessary quotes as the original
2327 leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
2328 new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
2329 new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
2330 if new_quote == '"""' and new_body[-1] == '"':
2332 new_body = new_body[:-1] + '\\"'
2333 orig_escape_count = body.count("\\")
2334 new_escape_count = new_body.count("\\")
2335 if new_escape_count > orig_escape_count:
2336 return # Do not introduce more escaping
2338 if new_escape_count == orig_escape_count and orig_quote == '"':
2339 return # Prefer double quotes
2341 leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
2344 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
2345 """Make existing optional parentheses invisible or create new ones.
2347 `parens_after` is a set of string leaf values immeditely after which parens
2350 Standardizes on visible parentheses for single-element tuples, and keeps
2351 existing visible parentheses for other tuples and generator expressions.
2354 for child in list(node.children):
2356 if child.type == syms.atom:
2357 maybe_make_parens_invisible_in_atom(child)
2358 elif is_one_tuple(child):
2359 # wrap child in visible parentheses
2360 lpar = Leaf(token.LPAR, "(")
2361 rpar = Leaf(token.RPAR, ")")
2362 index = child.remove() or 0
2363 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2364 elif not (isinstance(child, Leaf) and is_multiline_string(child)):
2365 # wrap child in invisible parentheses
2366 lpar = Leaf(token.LPAR, "")
2367 rpar = Leaf(token.RPAR, "")
2368 index = child.remove() or 0
2369 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
2371 check_lpar = isinstance(child, Leaf) and child.value in parens_after
2374 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
2375 """If it's safe, make the parens in the atom `node` invisible, recusively."""
2377 node.type != syms.atom
2378 or is_empty_tuple(node)
2379 or is_one_tuple(node)
2381 or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
2385 first = node.children[0]
2386 last = node.children[-1]
2387 if first.type == token.LPAR and last.type == token.RPAR:
2388 # make parentheses invisible
2389 first.value = "" # type: ignore
2390 last.value = "" # type: ignore
2391 if len(node.children) > 1:
2392 maybe_make_parens_invisible_in_atom(node.children[1])
2398 def is_empty_tuple(node: LN) -> bool:
2399 """Return True if `node` holds an empty tuple."""
2401 node.type == syms.atom
2402 and len(node.children) == 2
2403 and node.children[0].type == token.LPAR
2404 and node.children[1].type == token.RPAR
2408 def is_one_tuple(node: LN) -> bool:
2409 """Return True if `node` holds a tuple with one element, with or without parens."""
2410 if node.type == syms.atom:
2411 if len(node.children) != 3:
2414 lpar, gexp, rpar = node.children
2416 lpar.type == token.LPAR
2417 and gexp.type == syms.testlist_gexp
2418 and rpar.type == token.RPAR
2422 return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
2425 node.type in IMPLICIT_TUPLE
2426 and len(node.children) == 2
2427 and node.children[1].type == token.COMMA
2431 def is_yield(node: LN) -> bool:
2432 """Return True if `node` holds a `yield` or `yield from` expression."""
2433 if node.type == syms.yield_expr:
2436 if node.type == token.NAME and node.value == "yield": # type: ignore
2439 if node.type != syms.atom:
2442 if len(node.children) != 3:
2445 lpar, expr, rpar = node.children
2446 if lpar.type == token.LPAR and rpar.type == token.RPAR:
2447 return is_yield(expr)
2452 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
2453 """Return True if `leaf` is a star or double star in a vararg or kwarg.
2455 If `within` includes VARARGS_PARENTS, this applies to function signatures.
2456 If `within` includes UNPACKING_PARENTS, it applies to right hand-side
2457 extended iterable unpacking (PEP 3132) and additional unpacking
2458 generalizations (PEP 448).
2460 if leaf.type not in STARS or not leaf.parent:
2464 if p.type == syms.star_expr:
2465 # Star expressions are also used as assignment targets in extended
2466 # iterable unpacking (PEP 3132). See what its parent is instead.
2472 return p.type in within
def is_multiline_string(leaf: Leaf) -> bool:
    """Tell whether `leaf` is a triple-quoted string that spans several lines.

    Any string prefix characters (f/u/r/b, either case) are ignored. A
    triple-quoted string written on a single line yields False.
    """
    unprefixed = leaf.value.lstrip("furbFURB")
    if not unprefixed.startswith(('"""', "'''")):
        return False
    return "\n" in unprefixed
2481 def is_stub_suite(node: Node) -> bool:
2482 """Return True if `node` is a suite with a stub body."""
2484 len(node.children) != 4
2485 or node.children[0].type != token.NEWLINE
2486 or node.children[1].type != token.INDENT
2487 or node.children[3].type != token.DEDENT
2491 return is_stub_body(node.children[2])
2494 def is_stub_body(node: LN) -> bool:
2495 """Return True if `node` is a simple statement containing an ellipsis."""
2496 if not isinstance(node, Node) or node.type != syms.simple_stmt:
2499 if len(node.children) != 2:
2502 child = node.children[0]
2504 child.type == syms.atom
2505 and len(child.children) == 3
2506 and all(leaf == Leaf(token.DOT, ".") for leaf in child.children)
2510 def max_delimiter_priority_in_atom(node: LN) -> int:
2511 """Return maximum delimiter priority inside `node`.
2513 This is specific to atoms with contents contained in a pair of parentheses.
2514 If `node` isn't an atom or there are no enclosing parentheses, returns 0.
2516 if node.type != syms.atom:
2519 first = node.children[0]
2520 last = node.children[-1]
2521 if not (first.type == token.LPAR and last.type == token.RPAR):
2524 bt = BracketTracker()
2525 for c in node.children[1:-1]:
2526 if isinstance(c, Leaf):
2529 for leaf in c.leaves():
2532 return bt.max_delimiter_priority()
2538 def ensure_visible(leaf: Leaf) -> None:
2539 """Make sure parentheses are visible.
2541 They could be invisible as part of some statements (see
2542 :func:`normalize_invible_parens` and :func:`visit_import_from`).
2544 if leaf.type == token.LPAR:
2546 elif leaf.type == token.RPAR:
2550 def is_python36(node: Node) -> bool:
2551 """Return True if the current file is using Python 3.6+ features.
2553 Currently looking for:
2555 - trailing commas after * or ** in function signatures and calls.
2557 for n in node.pre_order():
2558 if n.type == token.STRING:
2559 value_head = n.value[:2] # type: ignore
2560 if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
2564 n.type in {syms.typedargslist, syms.arglist}
2566 and n.children[-1].type == token.COMMA
2568 for ch in n.children:
2569 if ch.type in STARS:
2572 if ch.type == syms.argument:
2573 for argch in ch.children:
2574 if argch.type in STARS:
2580 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
2581 """Generate sets of closing bracket IDs that should be omitted in a RHS.
2583 Brackets can be omitted if the entire trailer up to and including
2584 a preceding closing bracket fits in one line.
2586 Yielded sets are cumulative (contain results of previous yields, too). First
2590 omit: Set[LeafID] = set()
2593 length = 4 * line.depth
2594 opening_bracket = None
2595 closing_bracket = None
2596 optional_brackets: Set[LeafID] = set()
2597 inner_brackets: Set[LeafID] = set()
2598 for index, leaf in enumerate_reversed(line.leaves):
2599 length += len(leaf.prefix) + len(leaf.value)
2600 if length > line_length:
2603 comment: Optional[Leaf]
2604 for comment in line.comments_after(leaf, index):
2605 if "\n" in comment.prefix:
2606 break # Oops, standalone comment!
2608 length += len(comment.value)
2611 if comment is not None:
2612 break # There was a standalone comment, we can't continue.
2614 optional_brackets.discard(id(leaf))
2616 if leaf is opening_bracket:
2617 opening_bracket = None
2618 elif leaf.type in CLOSING_BRACKETS:
2619 inner_brackets.add(id(leaf))
2620 elif leaf.type in CLOSING_BRACKETS:
2622 optional_brackets.add(id(opening_bracket))
2625 if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
2626 # Empty brackets would fail a split so treat them as "inner"
2627 # brackets (e.g. only add them to the `omit` set if another
2628 # pair of brackets was good enough.
2629 inner_brackets.add(id(leaf))
2632 opening_bracket = leaf.opening_bracket
2634 omit.add(id(closing_bracket))
2635 omit.update(inner_brackets)
2636 inner_brackets.clear()
2638 closing_bracket = leaf
2641 def get_future_imports(node: Node) -> Set[str]:
2642 """Return a set of __future__ imports in the file."""
2644 for child in node.children:
2645 if child.type != syms.simple_stmt:
2647 first_child = child.children[0]
2648 if isinstance(first_child, Leaf):
2649 # Continue looking if we see a docstring; otherwise stop.
2651 len(child.children) == 2
2652 and first_child.type == token.STRING
2653 and child.children[1].type == token.NEWLINE
2658 elif first_child.type == syms.import_from:
2659 module_name = first_child.children[1]
2660 if not isinstance(module_name, Leaf) or module_name.value != "__future__":
2662 for import_from_child in first_child.children[3:]:
2663 if isinstance(import_from_child, Leaf):
2664 if import_from_child.type == token.NAME:
2665 imports.add(import_from_child.value)
2667 assert import_from_child.type == syms.import_as_names
2668 for leaf in import_from_child.children:
2669 if isinstance(leaf, Leaf) and leaf.type == token.NAME:
2670 imports.add(leaf.value)
2676 PYTHON_EXTENSIONS = {".py", ".pyi"}
2677 BLACKLISTED_DIRECTORIES = {
2678 "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
2682 def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
2683 """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
2684 and have one of the PYTHON_EXTENSIONS.
2686 for child in path.iterdir():
2688 if child.name in BLACKLISTED_DIRECTORIES:
2691 yield from gen_python_files_in_dir(child)
2693 elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
2699 """Provides a reformatting counter. Can be rendered with `str(report)`."""
2702 change_count: int = 0
2704 failure_count: int = 0
2706 def done(self, src: Path, changed: Changed) -> None:
2707 """Increment the counter for successful reformatting. Write out a message."""
2708 if changed is Changed.YES:
2709 reformatted = "would reformat" if self.check else "reformatted"
2711 out(f"{reformatted} {src}")
2712 self.change_count += 1
2715 if changed is Changed.NO:
2716 msg = f"{src} already well formatted, good job."
2718 msg = f"{src} wasn't modified on disk since last run."
2719 out(msg, bold=False)
2720 self.same_count += 1
2722 def failed(self, src: Path, message: str) -> None:
2723 """Increment the counter for failed reformatting. Write out a message."""
2724 err(f"error: cannot format {src}: {message}")
2725 self.failure_count += 1
2728 def return_code(self) -> int:
2729 """Return the exit code that the app should use.
2731 This considers the current state of changed files and failures:
2732 - if there were any failures, return 123;
2733 - if any files were changed and --check is being used, return 1;
2734 - otherwise return 0.
2736 # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
2737 # 126 we have special returncodes reserved by the shell.
2738 if self.failure_count:
2741 elif self.change_count and self.check:
2746 def __str__(self) -> str:
2747 """Render a color report of the current state.
2749 Use `click.unstyle` to remove colors.
2752 reformatted = "would be reformatted"
2753 unchanged = "would be left unchanged"
2754 failed = "would fail to reformat"
2756 reformatted = "reformatted"
2757 unchanged = "left unchanged"
2758 failed = "failed to reformat"
2760 if self.change_count:
2761 s = "s" if self.change_count > 1 else ""
2763 click.style(f"{self.change_count} file{s} {reformatted}", bold=True)
2766 s = "s" if self.same_count > 1 else ""
2767 report.append(f"{self.same_count} file{s} {unchanged}")
2768 if self.failure_count:
2769 s = "s" if self.failure_count > 1 else ""
2771 click.style(f"{self.failure_count} file{s} {failed}", fg="red")
2773 return ", ".join(report) + "."
2776 def assert_equivalent(src: str, dst: str) -> None:
2777 """Raise AssertionError if `src` and `dst` aren't equivalent."""
2782 def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
2783 """Simple visitor generating strings to compare ASTs by content."""
2784 yield f"{' ' * depth}{node.__class__.__name__}("
2786 for field in sorted(node._fields):
2788 value = getattr(node, field)
2789 except AttributeError:
2792 yield f"{' ' * (depth+1)}{field}="
2794 if isinstance(value, list):
2796 if isinstance(item, ast.AST):
2797 yield from _v(item, depth + 2)
2799 elif isinstance(value, ast.AST):
2800 yield from _v(value, depth + 2)
2803 yield f"{' ' * (depth+2)}{value!r}, # {value.__class__.__name__}"
2805 yield f"{' ' * depth}) # /{node.__class__.__name__}"
2808 src_ast = ast.parse(src)
2809 except Exception as exc:
2810 major, minor = sys.version_info[:2]
2811 raise AssertionError(
2812 f"cannot use --safe with this file; failed to parse source file "
2813 f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
2814 f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
2818 dst_ast = ast.parse(dst)
2819 except Exception as exc:
2820 log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
2821 raise AssertionError(
2822 f"INTERNAL ERROR: Black produced invalid code: {exc}. "
2823 f"Please report a bug on https://github.com/ambv/black/issues. "
2824 f"This invalid output might be helpful: {log}"
2827 src_ast_str = "\n".join(_v(src_ast))
2828 dst_ast_str = "\n".join(_v(dst_ast))
2829 if src_ast_str != dst_ast_str:
2830 log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
2831 raise AssertionError(
2832 f"INTERNAL ERROR: Black produced code that is not equivalent to "
2834 f"Please report a bug on https://github.com/ambv/black/issues. "
2835 f"This diff might be helpful: {log}"
2839 def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
2840 """Raise AssertionError if `dst` reformats differently the second time."""
2841 newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
2844 diff(src, dst, "source", "first pass"),
2845 diff(dst, newdst, "first pass", "second pass"),
2847 raise AssertionError(
2848 f"INTERNAL ERROR: Black produced different code on the second pass "
2849 f"of the formatter. "
2850 f"Please report a bug on https://github.com/ambv/black/issues. "
2851 f"This diff might be helpful: {log}"
2855 def dump_to_file(*output: str) -> str:
2856 """Dump `output` to a temporary file. Return path to the file."""
2859 with tempfile.NamedTemporaryFile(
2860 mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
2862 for lines in output:
2864 if lines and lines[-1] != "\n":
2869 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
2870 """Return a unified diff string between strings `a` and `b`."""
2873 a_lines = [line + "\n" for line in a.split("\n")]
2874 b_lines = [line + "\n" for line in b.split("\n")]
2876 difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
2880 def cancel(tasks: Iterable[asyncio.Task]) -> None:
2881 """asyncio signal handler that cancels all `tasks` and reports to stderr."""
2887 def shutdown(loop: BaseEventLoop) -> None:
2888 """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
2890 # This part is borrowed from asyncio/runners.py in Python 3.7b2.
2891 to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
2895 for task in to_cancel:
2897 loop.run_until_complete(
2898 asyncio.gather(*to_cancel, loop=loop, return_exceptions=True)
2901 # `concurrent.futures.Future` objects cannot be cancelled once they
2902 # are already running. There might be some when the `shutdown()` happened.
2903 # Silence their logger's spew about the event loop being closed.
2904 cf_logger = logging.getLogger("concurrent.futures")
2905 cf_logger.setLevel(logging.CRITICAL)
2909 def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
2910 """Replace `regex` with `replacement` twice on `original`.
2912 This is used by string normalization to perform replaces on
2913 overlapping matches.
2915 return regex.sub(replacement, regex.sub(replacement, original))
2918 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
2919 """Like `reversed(enumerate(sequence))` if that were possible."""
2920 index = len(sequence) - 1
2921 for element in reversed(sequence):
2922 yield (index, element)
2926 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
2927 """Return True if `line` is no longer than `line_length`.
2929 Uses the provided `line_str` rendering, if any, otherwise computes a new one.
2932 line_str = str(line).strip("\n")
2934 len(line_str) <= line_length
2935 and "\n" not in line_str # multiline strings
2936 and not line.contains_standalone_comments()
2940 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
2943 def get_cache_file(line_length: int) -> Path:
2944 return CACHE_DIR / f"cache.{line_length}.pickle"
2947 def read_cache(line_length: int) -> Cache:
2948 """Read the cache if it exists and is well formed.
2950 If it is not well formed, the call to write_cache later should resolve the issue.
2952 cache_file = get_cache_file(line_length)
2953 if not cache_file.exists():
2956 with cache_file.open("rb") as fobj:
2958 cache: Cache = pickle.load(fobj)
2959 except pickle.UnpicklingError:
2965 def get_cache_info(path: Path) -> CacheInfo:
2966 """Return the information used to check if a file is already formatted or not."""
2968 return stat.st_mtime, stat.st_size
2972 cache: Cache, sources: Iterable[Path]
2973 ) -> Tuple[List[Path], List[Path]]:
2974 """Split a list of paths into two.
2976 The first list contains paths of files that modified on disk or are not in the
2977 cache. The other list contains paths to non-modified files.
2982 if cache.get(src) != get_cache_info(src):
2989 def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
2990 """Update the cache file."""
2991 cache_file = get_cache_file(line_length)
2993 if not CACHE_DIR.exists():
2994 CACHE_DIR.mkdir(parents=True)
2995 new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
2996 with cache_file.open("wb") as fobj:
2997 pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
3002 if __name__ == "__main__":