-#!/usr/bin/env python3
-
import asyncio
+import pickle
from asyncio.base_events import BaseEventLoop
from concurrent.futures import Executor, ProcessPoolExecutor
from enum import Enum
Iterator,
List,
Optional,
+ Pattern,
+ Sequence,
Set,
Tuple,
Type,
Union,
)
+from appdirs import user_cache_dir
from attr import dataclass, Factory
import click
from blib2to3.pgen2 import driver, token
from blib2to3.pgen2.parse import ParseError
-__version__ = "18.4a0"
+
+__version__ = "18.4a6"
DEFAULT_LINE_LENGTH = 88
+
# types
syms = pygram.python_symbols
FileContent = str
Index = int
LN = Union[Leaf, Node]
SplitFunc = Callable[["Line", bool], Iterator["Line"]]
+Timestamp = float
+FileSize = int
+CacheInfo = Tuple[Timestamp, FileSize]
+Cache = Dict[Path, CacheInfo]
out = partial(click.secho, bold=True, err=True)
err = partial(click.secho, fg="red", err=True)
self.consumed = consumed
def trim_prefix(self, leaf: Leaf) -> None:
- leaf.prefix = leaf.prefix[self.consumed:]
+ leaf.prefix = leaf.prefix[self.consumed :]
def leaf_from_consumed(self, leaf: Leaf) -> Leaf:
"""Returns a new Leaf from the consumed part of the prefix."""
- unformatted_prefix = leaf.prefix[:self.consumed]
+ unformatted_prefix = leaf.prefix[: self.consumed]
return Leaf(token.NEWLINE, unformatted_prefix)
DIFF = 2
class Changed(Enum):
    """Outcome of processing a single source, as passed to `Report.done()`."""

    # The formatter ran and reported that nothing had to be changed.
    NO = 0
    # The cache entry matched the file on disk, so formatting was skipped.
    CACHED = 1
    # The formatter reported that the source was reformatted.
    YES = 2
+
+
@click.command()
@click.option(
"-l",
sources.append(Path("-"))
else:
err(f"invalid path: {s}")
- if check and diff:
- exc = click.ClickException("Options --check and --diff are mutually exclusive")
- exc.exit_code = 2
- raise exc
- if check:
+ if check and not diff:
write_back = WriteBack.NO
elif diff:
write_back = WriteBack.DIFF
else:
write_back = WriteBack.YES
+ report = Report(check=check, quiet=quiet)
if len(sources) == 0:
+ out("No paths given. Nothing to do 😴")
ctx.exit(0)
+ return
+
elif len(sources) == 1:
- p = sources[0]
- report = Report(check=check, quiet=quiet)
- try:
- if not p.is_file() and str(p) == "-":
- changed = format_stdin_to_stdout(
- line_length=line_length, fast=fast, write_back=write_back
- )
- else:
- changed = format_file_in_place(
- p, line_length=line_length, fast=fast, write_back=write_back
- )
- report.done(p, changed)
- except Exception as exc:
- report.failed(p, str(exc))
- ctx.exit(report.return_code)
+ reformat_one(sources[0], line_length, fast, write_back, report)
else:
loop = asyncio.get_event_loop()
executor = ProcessPoolExecutor(max_workers=os.cpu_count())
- return_code = 1
try:
- return_code = loop.run_until_complete(
+ loop.run_until_complete(
schedule_formatting(
- sources, line_length, write_back, fast, quiet, loop, executor
+ sources, line_length, fast, write_back, report, loop, executor
)
)
finally:
shutdown(loop)
- ctx.exit(return_code)
+ if not quiet:
+ out("All done! ✨ 🍰 ✨")
+ click.echo(str(report))
+ ctx.exit(report.return_code)
+
+
def reformat_one(
    src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
) -> None:
    """Reformat a single file under `src` without spawning child processes.

    If `src` is not an existing file and is spelled "-", the code is handled by
    :func:`format_stdin_to_stdout`.  Otherwise `line_length`, `fast` and
    `write_back` are passed to :func:`format_file_in_place`, unless the cache
    says the file is unchanged since it was last processed.

    The outcome is recorded on `report`.  Any exception raised while processing
    is recorded as a failure of `src` instead of being propagated.
    """
    try:
        changed = Changed.NO
        if not src.is_file() and str(src) == "-":
            if format_stdin_to_stdout(
                line_length=line_length, fast=fast, write_back=write_back
            ):
                changed = Changed.YES
        else:
            cache: Cache = {}
            # The cache is not consulted for diffs.
            if write_back != WriteBack.DIFF:
                cache = read_cache(line_length)
                src = src.resolve()
                if src in cache and cache[src] == get_cache_info(src):
                    changed = Changed.CACHED
            if (
                changed is not Changed.CACHED
                and format_file_in_place(
                    src, line_length=line_length, fast=fast, write_back=write_back
                )
            ):
                changed = Changed.YES
            # Record every file that went through the formatter, including the
            # ones that were already well formatted, so that the next run can
            # skip them.  This mirrors `schedule_formatting`, which caches all
            # successfully processed sources.  A cache hit needs no rewrite.
            if write_back == WriteBack.YES and changed is not Changed.CACHED:
                write_cache(cache, [src], line_length)
        report.done(src, changed)
    except Exception as exc:
        report.failed(src, str(exc))
async def schedule_formatting(
sources: List[Path],
line_length: int,
- write_back: WriteBack,
fast: bool,
- quiet: bool,
+ write_back: WriteBack,
+ report: "Report",
loop: BaseEventLoop,
executor: Executor,
-) -> int:
+) -> None:
"""Run formatting of `sources` in parallel using the provided `executor`.
(Use ProcessPoolExecutors for actual parallelism.)
`line_length`, `write_back`, and `fast` options are passed to
:func:`format_file_in_place`.
"""
- lock = None
- if write_back == WriteBack.DIFF:
- # For diff output, we need locks to ensure we don't interleave output
- # from different processes.
- manager = Manager()
- lock = manager.Lock()
- tasks = {
- src: loop.run_in_executor(
- executor, format_file_in_place, src, line_length, fast, write_back, lock
- )
- for src in sources
- }
- _task_values = list(tasks.values())
- loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
- loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
- await asyncio.wait(tasks.values())
+ cache: Cache = {}
+ if write_back != WriteBack.DIFF:
+ cache = read_cache(line_length)
+ sources, cached = filter_cached(cache, sources)
+ for src in cached:
+ report.done(src, Changed.CACHED)
cancelled = []
- report = Report(check=write_back is WriteBack.NO, quiet=quiet)
- for src, task in tasks.items():
- if not task.done():
- report.failed(src, "timed out, cancelling")
- task.cancel()
- cancelled.append(task)
- elif task.cancelled():
- cancelled.append(task)
- elif task.exception():
- report.failed(src, str(task.exception()))
- else:
- report.done(src, task.result())
+ formatted = []
+ if sources:
+ lock = None
+ if write_back == WriteBack.DIFF:
+ # For diff output, we need locks to ensure we don't interleave output
+ # from different processes.
+ manager = Manager()
+ lock = manager.Lock()
+ tasks = {
+ src: loop.run_in_executor(
+ executor, format_file_in_place, src, line_length, fast, write_back, lock
+ )
+ for src in sources
+ }
+ _task_values = list(tasks.values())
+ try:
+ loop.add_signal_handler(signal.SIGINT, cancel, _task_values)
+ loop.add_signal_handler(signal.SIGTERM, cancel, _task_values)
+ except NotImplementedError:
+ # There are no good alternatives for these on Windows
+ pass
+ await asyncio.wait(_task_values)
+ for src, task in tasks.items():
+ if not task.done():
+ report.failed(src, "timed out, cancelling")
+ task.cancel()
+ cancelled.append(task)
+ elif task.cancelled():
+ cancelled.append(task)
+ elif task.exception():
+ report.failed(src, str(task.exception()))
+ else:
+ formatted.append(src)
+ report.done(src, Changed.YES if task.result() else Changed.NO)
+
if cancelled:
await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
- elif not quiet:
- out("All done! ✨ 🍰 ✨")
- if not quiet:
- click.echo(str(report))
- return report.return_code
+ if write_back == WriteBack.YES and formatted:
+ write_cache(cache, formatted, line_length)
def format_file_in_place(
If `write_back` is True, write reformatted code back to stdout.
`line_length` and `fast` options are passed to :func:`format_file_contents`.
"""
+ is_pyi = src.suffix == ".pyi"
+
with tokenize.open(src) as src_buffer:
src_contents = src_buffer.read()
try:
dst_contents = format_file_contents(
- src_contents, line_length=line_length, fast=fast
+ src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
)
except NothingChanged:
return False
with open(src, "w", encoding=src_buffer.encoding) as f:
f.write(dst_contents)
elif write_back == write_back.DIFF:
- src_name = f"{src.name} (original)"
- dst_name = f"{src.name} (formatted)"
+ src_name = f"{src} (original)"
+ dst_name = f"{src} (formatted)"
diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
if lock:
lock.acquire()
def format_file_contents(
- src_contents: str, line_length: int, fast: bool
+ src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
) -> FileContent:
"""Reformat contents of a file and return new contents.
if src_contents.strip() == "":
raise NothingChanged
- dst_contents = format_str(src_contents, line_length=line_length)
+ dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
if src_contents == dst_contents:
raise NothingChanged
if not fast:
assert_equivalent(src_contents, dst_contents)
- assert_stable(src_contents, dst_contents, line_length=line_length)
+ assert_stable(
+ src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
+ )
return dst_contents
-def format_str(src_contents: str, line_length: int) -> FileContent:
+def format_str(
+ src_contents: str, line_length: int, *, is_pyi: bool = False
+) -> FileContent:
"""Reformat a string and return new contents.
`line_length` determines how many characters per line are allowed.
"""
src_node = lib2to3_parse(src_contents)
dst_contents = ""
- lines = LineGenerator()
- elt = EmptyLineTracker()
+ future_imports = get_future_imports(src_node)
+ elt = EmptyLineTracker(is_pyi=is_pyi)
py36 = is_python36(src_node)
+ lines = LineGenerator(
+ remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
+ )
empty_line = Line()
after = 0
for current_line in lines.visit(src_node):
GRAMMARS = [
pygram.python_grammar_no_print_statement_no_exec_statement,
pygram.python_grammar_no_print_statement,
- pygram.python_grammar_no_exec_statement,
pygram.python_grammar,
]
token.GREATEREQUAL,
}
MATH_OPERATORS = {
+ token.VBAR,
+ token.CIRCUMFLEX,
+ token.AMPER,
+ token.LEFTSHIFT,
+ token.RIGHTSHIFT,
token.PLUS,
token.MINUS,
token.STAR,
token.SLASH,
- token.VBAR,
- token.AMPER,
+ token.DOUBLESLASH,
token.PERCENT,
- token.CIRCUMFLEX,
+ token.AT,
token.TILDE,
- token.LEFTSHIFT,
- token.RIGHTSHIFT,
token.DOUBLESTAR,
- token.DOUBLESLASH,
}
-VARARGS = {token.STAR, token.DOUBLESTAR}
+STARS = {token.STAR, token.DOUBLESTAR}
+VARARGS_PARENTS = {
+ syms.arglist,
+ syms.argument, # double star in arglist
+ syms.trailer, # single argument to call
+ syms.typedargslist,
+ syms.varargslist, # lambdas
+}
+UNPACKING_PARENTS = {
+ syms.atom, # single element of a list or set literal
+ syms.dictsetmaker,
+ syms.listmaker,
+ syms.testlist_gexp,
+}
+TEST_DESCENDANTS = {
+ syms.test,
+ syms.lambdef,
+ syms.or_test,
+ syms.and_test,
+ syms.not_test,
+ syms.comparison,
+ syms.star_expr,
+ syms.expr,
+ syms.xor_expr,
+ syms.and_expr,
+ syms.shift_expr,
+ syms.arith_expr,
+ syms.trailer,
+ syms.term,
+ syms.power,
+}
+ASSIGNMENTS = {
+ "=",
+ "+=",
+ "-=",
+ "*=",
+ "@=",
+ "/=",
+ "%=",
+ "&=",
+ "|=",
+ "^=",
+ "<<=",
+ ">>=",
+ "**=",
+ "//=",
+}
COMPREHENSION_PRIORITY = 20
-COMMA_PRIORITY = 10
-LOGIC_PRIORITY = 5
-STRING_PRIORITY = 4
-COMPARATOR_PRIORITY = 3
-MATH_PRIORITY = 1
+COMMA_PRIORITY = 18
+TERNARY_PRIORITY = 16
+LOGIC_PRIORITY = 14
+STRING_PRIORITY = 12
+COMPARATOR_PRIORITY = 10
+MATH_PRIORITIES = {
+ token.VBAR: 8,
+ token.CIRCUMFLEX: 7,
+ token.AMPER: 6,
+ token.LEFTSHIFT: 5,
+ token.RIGHTSHIFT: 5,
+ token.PLUS: 4,
+ token.MINUS: 4,
+ token.STAR: 3,
+ token.SLASH: 3,
+ token.DOUBLESLASH: 3,
+ token.PERCENT: 3,
+ token.AT: 3,
+ token.TILDE: 2,
+ token.DOUBLESTAR: 1,
+}
@dataclass
bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
delimiters: Dict[LeafID, Priority] = Factory(dict)
previous: Optional[Leaf] = None
+ _for_loop_variable: int = 0
+ _lambda_arguments: int = 0
def mark(self, leaf: Leaf) -> None:
"""Mark `leaf` with bracket-related metadata. Keep track of delimiters.
if leaf.type == token.COMMENT:
return
+ self.maybe_decrement_after_for_loop_variable(leaf)
+ self.maybe_decrement_after_lambda_arguments(leaf)
if leaf.type in CLOSING_BRACKETS:
self.depth -= 1
opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf
self.depth += 1
self.previous = leaf
+ self.maybe_increment_lambda_arguments(leaf)
+ self.maybe_increment_for_loop_variable(leaf)
def any_open_brackets(self) -> bool:
"""Return True if there is a yet-unmatched open bracket on the line."""
def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
"""Return the highest priority of a delimiter found on the line.
- Values are consistent with what `is_delimiter()` returns.
+ Values are consistent with what `is_split_*_delimiter()` return.
Raises ValueError on no delimiters.
"""
return max(v for k, v in self.delimiters.items() if k not in exclude)
def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
    """Raise the depth when `leaf` is the `for` keyword.

    The targets of a for loop or comprehension are often unpacked, i.e.
    comma-separated.  Tokens between `for` and `in` get an increased depth
    to avoid splitting on those commas.

    Return True if `leaf` was `for` and the depth was raised.
    """
    if leaf.type != token.NAME or leaf.value != "for":
        return False

    self._for_loop_variable += 1
    self.depth += 1
    return True
+
def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
    """Lower the depth again when `leaf` is the `in` that ends a loop target.

    Counterpart of `maybe_increment_for_loop_variable`: it only acts while at
    least one `for` is still waiting for its `in`.

    Return True if the depth was lowered.
    """
    if not self._for_loop_variable:
        return False

    if leaf.type != token.NAME or leaf.value != "in":
        return False

    self._for_loop_variable -= 1
    self.depth -= 1
    return True
+
def maybe_increment_lambda_arguments(self, leaf: Leaf) -> bool:
    """Raise the depth when `leaf` is the `lambda` keyword.

    A lambda may take more than one argument.  Tokens between `lambda` and
    `:` get an increased depth to avoid splitting on the commas between them.

    Return True if `leaf` was `lambda` and the depth was raised.
    """
    if leaf.type != token.NAME or leaf.value != "lambda":
        return False

    self._lambda_arguments += 1
    self.depth += 1
    return True
+
def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
    """Lower the depth again when `leaf` is the `:` that ends lambda arguments.

    Counterpart of `maybe_increment_lambda_arguments`: it only acts while at
    least one `lambda` is still waiting for its colon.

    Return True if the depth was lowered.
    """
    if not self._lambda_arguments or leaf.type != token.COLON:
        return False

    self._lambda_arguments -= 1
    self.depth -= 1
    return True
+
def get_open_lsqb(self) -> Optional[Leaf]:
    """Return the most recent opening square bracket (if any).

    The lookup uses the key of the enclosing depth paired with the closing
    square bracket type; None is returned when no such entry exists.
    """
    enclosing_key = (self.depth - 1, token.RSQB)
    return self.bracket_match.get(enclosing_key)
+
@dataclass
class Line:
comments: List[Tuple[Index, Leaf]] = Factory(list)
bracket_tracker: BracketTracker = Factory(BracketTracker)
inside_brackets: bool = False
- has_for: bool = False
- _for_loop_variable: bool = False
def append(self, leaf: Leaf, preformatted: bool = False) -> None:
"""Add a new `leaf` to the end of the line.
if not has_value:
return
+ if token.COLON == leaf.type and self.is_class_paren_empty:
+ del self.leaves[-2:]
if self.leaves and not preformatted:
# Note: at this point leaf.prefix should be empty except for
# imports, for which we only preserve newlines.
- leaf.prefix += whitespace(leaf)
+ leaf.prefix += whitespace(
+ leaf, complex_subscript=self.is_complex_subscript(leaf)
+ )
if self.inside_brackets or not preformatted:
- self.maybe_decrement_after_for_loop_variable(leaf)
self.bracket_tracker.mark(leaf)
self.maybe_remove_trailing_comma(leaf)
- self.maybe_increment_for_loop_variable(leaf)
-
if not self.append_comment(leaf):
self.leaves.append(leaf)
and self.leaves[0].value == "class"
)
@property
def is_trivial_class(self) -> bool:
    """Is this line a class definition with a body consisting only of "..."?"""
    if not self.is_class:
        return False

    # The body is "..." when the line ends with three consecutive dot leaves.
    ellipsis_leaves = [Leaf(token.DOT, ".") for _ in range(3)]
    return self.leaves[-3:] == ellipsis_leaves
+
@property
def is_def(self) -> bool:
"""Is this a function definition? (Also returns True for async defs.)"""
and self.leaves[0].value == "yield"
)
@property
def is_class_paren_empty(self) -> bool:
    """Is this a class with no base classes but using parentheses?

    Those are unnecessary and should be removed.  The line qualifies only when
    it holds exactly four leaves, the last two being "(" and ")".
    """
    if not self or len(self.leaves) != 4 or not self.is_class:
        return False

    lpar = self.leaves[2]
    rpar = self.leaves[3]
    return (
        lpar.type == token.LPAR
        and lpar.value == "("
        and rpar.type == token.RPAR
        and rpar.value == ")"
    )
+
def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
"""If so, needs to be split before emitting."""
for leaf in self.leaves:
self.remove_trailing_comma()
return True
- # For parens let's check if it's safe to remove the comma. If the
- # trailing one is the only one, we might mistakenly change a tuple
- # into a different type by removing the comma.
+ # For parens let's check if it's safe to remove the comma.
+ # Imports are always safe.
+ if self.is_import:
+ self.remove_trailing_comma()
+ return True
+
+ # Otherwise, if the trailing one is the only one, we might mistakenly
+ # change a tuple into a different type by removing the comma.
depth = closing.bracket_depth + 1
commas = 0
opening = closing.opening_bracket
else:
return False
- for leaf in self.leaves[_opening_index + 1:]:
+ for leaf in self.leaves[_opening_index + 1 :]:
if leaf is closing:
break
return False
- def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
- """In a for loop, or comprehension, the variables are often unpacks.
-
- To avoid splitting on the comma in this situation, increase the depth of
- tokens between `for` and `in`.
- """
- if leaf.type == token.NAME and leaf.value == "for":
- self.has_for = True
- self.bracket_tracker.depth += 1
- self._for_loop_variable = True
- return True
-
- return False
-
- def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
- """See `maybe_increment_for_loop_variable` above for explanation."""
- if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
- self.bracket_tracker.depth -= 1
- self._for_loop_variable = False
- return True
-
- return False
-
def append_comment(self, comment: Leaf) -> bool:
"""Add an inline or standalone comment to the line."""
if (
self.comments.append((after, comment))
return True
- def comments_after(self, leaf: Leaf) -> Iterator[Leaf]:
- """Generate comments that should appear directly after `leaf`."""
- for _leaf_index, _leaf in enumerate(self.leaves):
- if leaf is _leaf:
- break
+ def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
+ """Generate comments that should appear directly after `leaf`.
- else:
- return
+ Provide a non-negative leaf `_index` to speed up the function.
+ """
+ if _index == -1:
+ for _index, _leaf in enumerate(self.leaves):
+ if leaf is _leaf:
+ break
+
+ else:
+ return
for index, comment_after in self.comments:
- if _leaf_index == index:
+ if _index == index:
yield comment_after
def remove_trailing_comma(self) -> None:
self.comments[i] = (comma_index - 1, comment)
self.leaves.pop()
def is_complex_subscript(self, leaf: Leaf) -> bool:
    """Return True iff `leaf` is part of a slice with non-trivial exprs.

    The subscript is located through its opening square bracket: `leaf`
    itself when it is one, otherwise whatever the bracket tracker reports as
    the most recent open one.  Without such a bracket the answer is False.
    """
    if leaf.type == token.LSQB:
        bracket = leaf
    else:
        bracket = self.bracket_tracker.get_open_lsqb()
    if bracket is None:
        return False

    start = bracket.next_sibling
    if isinstance(start, Node) and start.type == syms.subscriptlist:
        # Several subscripts: narrow down to the one that contains `leaf`.
        start = child_towards(start, leaf)
    if start is None:
        return False

    return any(n.type in TEST_DESCENDANTS for n in start.pre_order())
+
def __str__(self) -> str:
"""Render the line."""
if not self:
the prefix of the first leaf consists of optional newlines. Those newlines
are consumed by `maybe_empty_lines()` and included in the computation.
"""
+ is_pyi: bool = False
previous_line: Optional[Line] = None
previous_after: int = 0
previous_defs: List[int] = Factory(list)
def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
max_allowed = 1
if current_line.depth == 0:
- max_allowed = 2
+ max_allowed = 1 if self.is_pyi else 2
if current_line.leaves:
# Consume the first leaf's extra newlines.
first_leaf = current_line.leaves[0]
depth = current_line.depth
while self.previous_defs and self.previous_defs[-1] >= depth:
self.previous_defs.pop()
- before = 1 if depth else 2
+ if self.is_pyi:
+ before = 0 if depth else 1
+ else:
+ before = 1 if depth else 2
is_decorator = current_line.is_decorator
if is_decorator or current_line.is_def or current_line.is_class:
if not is_decorator:
# Don't insert empty lines before the first line in the file.
return 0, 0
- if self.previous_line and self.previous_line.is_decorator:
- # Don't insert empty lines between decorators.
+ if self.previous_line.is_decorator:
+ return 0, 0
+
+ if (
+ self.previous_line.is_comment
+ and self.previous_line.depth == current_line.depth
+ and before == 0
+ ):
return 0, 0
- newlines = 2
- if current_line.depth:
+ if self.is_pyi:
+ if self.previous_line.depth > current_line.depth:
+ newlines = 1
+ elif current_line.is_class or self.previous_line.is_class:
+ if (
+ current_line.is_trivial_class
+ and self.previous_line.is_trivial_class
+ ):
+ newlines = 0
+ else:
+ newlines = 1
+ else:
+ newlines = 0
+ else:
+ newlines = 2
+ if current_line.depth and newlines:
newlines -= 1
return newlines, 0
- if current_line.is_flow_control:
- return before, 1
-
if (
self.previous_line
and self.previous_line.is_import
):
return (before or 1), 0
- if (
- self.previous_line
- and self.previous_line.is_yield
- and (not current_line.is_yield or depth != self.previous_line.depth)
- ):
- return (before or 1), 0
-
return before, 0
Note: destroys the tree it's visiting by mutating prefixes of its leaves
in ways that will no longer stringify to valid Python code on the tree.
"""
+ is_pyi: bool = False
current_line: Line = Factory(Line)
+ remove_u_prefix: bool = False
def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
"""Generate a line.
else:
normalize_prefix(node, inside_brackets=any_open_brackets)
if node.type == token.STRING:
+ normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
normalize_string_quotes(node)
if node.type not in WHITESPACE:
self.current_line.append(node)
def visit_DEDENT(self, node: Node) -> Iterator[Line]:
"""Decrease indentation level, maybe yield a line."""
- # DEDENT has no value. Additionally, in blib2to3 it never holds comments.
+ # The current line might still wait for trailing comments. At DEDENT time
+ # there won't be any (they would be prefixes on the preceding NEWLINE).
+ # Emit the line then.
+ yield from self.line()
+
+ # While DEDENT has no value, its prefix may contain standalone comments
+ # that belong to the current indentation level. Get 'em.
+ yield from self.visit_default(node)
+
+ # Finally, emit the dedent.
yield from self.line(-1)
def visit_stmt(
"""Visit a statement.
This implementation is shared for `if`, `while`, `for`, `try`, `except`,
- `def`, `with`, `class`, and `assert`.
+ `def`, `with`, `class`, `assert` and assignments.
The relevant Python language `keywords` for a given statement will be
NAME leaves within it. This method puts those on a separate line.
- `parens` holds pairs of nodes where invisible parentheses should be put.
- Keys hold nodes after which opening parentheses should be put, values
- hold nodes before which closing parentheses should be put.
+ `parens` holds a set of string leaf values immediately after which
+ invisible parens should be put.
"""
normalize_invisible_parens(node, parens_after=parens)
for child in node.children:
yield from self.visit(child)
def visit_suite(self, node: Node) -> Iterator[Line]:
    """Visit a suite.

    In `is_pyi` mode a trivial suite is reduced to its single statement (the
    third child); every other suite goes through `visit_default()`.
    """
    collapse = self.is_pyi and self.is_trivial_suite(node)
    if collapse:
        target_lines = self.visit(node.children[2])
    else:
        target_lines = self.visit_default(node)
    yield from target_lines
+
def is_trivial_suite(self, node: Node) -> bool:
    """Return True if `node` is a suite wrapping exactly one trivial statement.

    Such a suite has four children: a NEWLINE leaf, an INDENT leaf, a
    statement node accepted by `is_trivial_body()`, and a DEDENT leaf.
    """
    if len(node.children) != 4:
        return False

    # Positions 0, 1 and 3 must be leaves of these exact token types.
    framing = ((0, token.NEWLINE), (1, token.INDENT), (3, token.DEDENT))
    for position, expected_type in framing:
        child = node.children[position]
        if not isinstance(child, Leaf) or child.type != expected_type:
            return False

    stmt = node.children[2]
    if not isinstance(stmt, Node):
        return False

    return self.is_trivial_body(stmt)
+
def is_trivial_body(self, stmt: Node) -> bool:
    """Return True if `stmt` is a simple statement consisting only of "...".

    That is a `simple_stmt` node with two children, the first being an atom
    made of exactly three dot leaves.
    """
    if not (isinstance(stmt, Node) and stmt.type == syms.simple_stmt):
        return False

    if len(stmt.children) != 2:
        return False

    expr = stmt.children[0]
    if expr.type != syms.atom or len(expr.children) != 3:
        return False

    dot = Leaf(token.DOT, ".")
    return all(part == dot for part in expr.children)
+
def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
"""Visit a statement without nested statements."""
is_suite_like = node.parent and node.parent.type in STATEMENT
if is_suite_like:
- yield from self.line(+1)
- yield from self.visit_default(node)
- yield from self.line(-1)
+ if self.is_pyi and self.is_trivial_body(node):
+ yield from self.visit_default(node)
+ else:
+ yield from self.line(+1)
+ yield from self.visit_default(node)
+ yield from self.line(-1)
else:
- yield from self.line()
+ if (
+ not self.is_pyi
+ or not node.parent
+ or not self.is_trivial_suite(node.parent)
+ ):
+ yield from self.line()
yield from self.visit_default(node)
def visit_async_stmt(self, node: Node) -> Iterator[Line]:
v = self.visit_stmt
Ø: Set[str] = set()
self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
- self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
+ self.visit_if_stmt = partial(
+ v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
+ )
self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
self.visit_try_stmt = partial(
self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
+ self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
+ self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
self.visit_async_funcdef = self.visit_async_stmt
self.visit_decorated = self.visit_decorators
ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
-def whitespace(leaf: Leaf) -> str: # noqa C901
- """Return whitespace prefix if needed for the given `leaf`."""
+def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa C901
+ """Return whitespace prefix if needed for the given `leaf`.
+
+ `complex_subscript` signals whether the given leaf is part of a subscription
+ which has non-trivial arguments, like arithmetic expressions or function calls.
+ """
NO = ""
SPACE = " "
DOUBLESPACE = " "
return DOUBLESPACE
assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
- if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
+ if (
+ t == token.COLON
+ and p.type not in {syms.subscript, syms.subscriptlist, syms.sliceop}
+ ):
return NO
prev = leaf.prev_sibling
return NO
if t == token.COLON:
- return SPACE if prevp.type == token.COMMA else NO
+ if prevp.type == token.COLON:
+ return NO
+
+ elif prevp.type != token.COMMA and not complex_subscript:
+ return NO
+
+ return SPACE
if prevp.type == token.EQUAL:
if prevp.parent:
# that, too.
return prevp.prefix
- elif prevp.type == token.DOUBLESTAR:
- if (
- prevp.parent
- and prevp.parent.type in {
- syms.arglist,
- syms.argument,
- syms.dictsetmaker,
- syms.parameters,
- syms.typedargslist,
- syms.varargslist,
- }
- ):
+ elif prevp.type in STARS:
+ if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
return NO
elif prevp.type == token.COLON:
if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
- return NO
+ return SPACE if complex_subscript else NO
elif (
prevp.parent
- and prevp.parent.type in {syms.factor, syms.star_expr}
+ and prevp.parent.type == syms.factor
and prevp.type in MATH_OPERATORS
):
return NO
if p.type in {syms.parameters, syms.arglist}:
# untyped function signatures or calls
- if t == token.RPAR:
- return NO
-
if not prev or prev.type != token.COMMA:
return NO
elif p.type == syms.varargslist:
# lambdas
- if t == token.RPAR:
- return NO
-
if prev and prev.type != token.COMMA:
return NO
if not prevp or prevp.type == token.LPAR:
return NO
- elif prev.type == token.EQUAL or prev.type == token.DOUBLESTAR:
+ elif prev.type in {token.EQUAL} | STARS:
return NO
elif p.type == syms.decorator:
if prev and prev.type == token.LPAR:
return NO
- elif p.type == syms.subscript:
+ elif p.type in {syms.subscript, syms.sliceop}:
# indexing
if not prev:
assert p.parent is not None, "subscripts are always parented"
return NO
- else:
+ elif not complex_subscript:
return NO
elif p.type == syms.atom:
# dots, but not the first one.
return NO
- elif (
- p.type == syms.listmaker
- or p.type == syms.testlist_gexp
- or p.type == syms.subscriptlist
- ):
- # list interior, including unpacking
- if not prev:
- return NO
-
elif p.type == syms.dictsetmaker:
- # dict and set interior, including unpacking
- if not prev:
- return NO
-
- if prev.type == token.DOUBLESTAR:
+ # dict unpacking
+ if prev and prev.type == token.DOUBLESTAR:
return NO
elif p.type in {syms.factor, syms.star_expr}:
return None
def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]:
    """Return the child of `ancestor` that contains `descendant`.

    Walks up the parent chain starting at `descendant` itself, so `descendant`
    is returned when it is a direct child.  If `ancestor` is never reached, the
    walk runs off the top of the tree and None is returned.
    """
    current: Optional[LN] = descendant
    while current:
        parent = current.parent
        if parent != ancestor:
            current = parent
        else:
            break
    return current
+
+
def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
"""Return the priority of the `leaf` delimiter, given a line break after it.
Higher numbers are higher priority.
"""
- if (
- leaf.type in VARARGS
- and leaf.parent
- and leaf.parent.type in {syms.argument, syms.typedargslist}
- ):
+ if is_vararg(leaf, within=VARARGS_PARENTS | UNPACKING_PARENTS):
# * and ** might also be MATH_OPERATORS but in this case they are not.
# Don't treat them as a delimiter.
return 0
and leaf.parent
and leaf.parent.type not in {syms.factor, syms.star_expr}
):
- return MATH_PRIORITY
+ return MATH_PRIORITIES[leaf.type]
if leaf.type in COMPARATORS:
return COMPARATOR_PRIORITY
):
return COMPREHENSION_PRIORITY
+ if (
+ leaf.type == token.NAME
+ and leaf.value in {"if", "else"}
+ and leaf.parent
+ and leaf.parent.type == syms.test
+ ):
+ return TERNARY_PRIORITY
+
if leaf.type == token.NAME and leaf.value in LOGIC_OPERATORS and leaf.parent:
return LOGIC_PRIORITY
return 0
-def is_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
- """Return the priority of the `leaf` delimiter. Return 0 if not delimiter.
-
- Higher numbers are higher priority.
- """
- return max(
- is_split_before_delimiter(leaf, previous),
- is_split_after_delimiter(leaf, previous),
- )
-
-
def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
"""Clean the prefix of the `leaf` and generate comments from it, if any.
return
line_str = str(line).strip("\n")
- if (
- len(line_str) <= line_length
- and "\n" not in line_str # multiline strings
- and not line.contains_standalone_comments()
- ):
+ if is_line_short_enough(line, line_length=line_length, line_str=line_str):
yield line
return
split_funcs: List[SplitFunc]
if line.is_def:
split_funcs = [left_hand_split]
- elif line.inside_brackets:
- split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
+ elif line.is_import:
+ split_funcs = [explode_split]
else:
- split_funcs = [right_hand_split]
+
+ def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
+ for omit in generate_trailers_to_omit(line, line_length):
+ lines = list(right_hand_split(line, py36, omit=omit))
+ if is_line_short_enough(lines[0], line_length=line_length):
+ yield from lines
+ return
+
+ # All splits failed, best effort split with no omits.
+ yield from right_hand_split(line, py36)
+
+ if line.inside_brackets:
+ split_funcs = [delimiter_split, standalone_comment_split, rhs]
+ else:
+ split_funcs = [rhs]
for split_func in split_funcs:
# We are accumulating lines in `result` because we might want to abort
# mission and return the original line in the end, or attempt a different
"""Split line into many lines, starting with the first matching bracket pair.
Note: this usually looks weird, only use this for function definitions.
- Prefer RHS otherwise.
+ Prefer RHS otherwise. This is why this function is not symmetrical with
+ :func:`right_hand_split` which also handles optional parentheses.
"""
head = Line(depth=line.depth)
body = Line(depth=line.depth + 1, inside_brackets=True)
def right_hand_split(
line: Line, py36: bool = False, omit: Collection[LeafID] = ()
) -> Iterator[Line]:
- """Split line into many lines, starting with the last matching bracket pair."""
+ """Split line into many lines, starting with the last matching bracket pair.
+
+ If the split was by optional parentheses, attempt splitting without them, too.
+ `omit` is a collection of closing bracket IDs that shouldn't be considered for
+ this split.
+ """
head = Line(depth=line.depth)
body = Line(depth=line.depth + 1, inside_brackets=True)
tail = Line(depth=line.depth)
bracket_split_succeeded_or_raise(head, body, tail)
assert opening_bracket and closing_bracket
if (
+ # the opening bracket is an optional paren
opening_bracket.type == token.LPAR
and not opening_bracket.value
+ # the closing bracket is an optional paren
and closing_bracket.type == token.RPAR
and not closing_bracket.value
+ # there are no delimiters or standalone comments in the body
+ and not body.bracket_tracker.delimiters
+ and not line.contains_standalone_comments(0)
+ # and it's not an import (optional parens are the only thing we can split
+ # on in this case; attempting a split without them is a waste of time)
+ and not line.is_import
):
- # These parens were optional. If there aren't any delimiters or standalone
- # comments in the body, they were unnecessary and another split without
- # them should be attempted.
- if not (
- body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
- ):
- omit = {id(closing_bracket), *omit}
+ omit = {id(closing_bracket), *omit}
+ try:
yield from right_hand_split(line, py36=py36, omit=omit)
return
+ except CannotSplit:
+ pass
ensure_visible(opening_bracket)
ensure_visible(closing_bracket)
current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
current_line.append(leaf)
- for leaf in line.leaves:
+ for index, leaf in enumerate(line.leaves):
yield from append_to_line(leaf)
- for comment_after in line.comments_after(leaf):
+ for comment_after in line.comments_after(leaf, index):
yield from append_to_line(comment_after)
lowest_depth = min(lowest_depth, leaf.bracket_depth)
if (
leaf.bracket_depth == lowest_depth
- and leaf.type == token.STAR
- or leaf.type == token.DOUBLESTAR
+ and is_vararg(leaf, within=VARARGS_PARENTS)
):
trailing_comma_safe = trailing_comma_safe and py36
leaf_priority = delimiters.get(id(leaf))
current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
current_line.append(leaf)
- for leaf in line.leaves:
+ for index, leaf in enumerate(line.leaves):
yield from append_to_line(leaf)
- for comment_after in line.comments_after(leaf):
+ for comment_after in line.comments_after(leaf, index):
yield from append_to_line(comment_after)
if current_line:
yield current_line
+def explode_split(
+ line: Line, py36: bool = False, omit: Collection[LeafID] = ()
+) -> Iterator[Line]:
+ """Split by rightmost bracket and immediately split contents by a delimiter."""
+ new_lines = list(right_hand_split(line, py36, omit))
+ if len(new_lines) != 3:
+ yield from new_lines
+ return
+
+ yield new_lines[0]
+
+ try:
+ yield from delimiter_split(new_lines[1], py36)
+
+ except CannotSplit:
+ yield new_lines[1]
+
+ yield new_lines[2]
+
+
def is_import(leaf: Leaf) -> bool:
"""Return True if the given leaf starts an import statement."""
p = leaf.parent
leaf.prefix = ""
+def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
+ """Make all string prefixes lowercase.
+
+ If remove_u_prefix is given, also removes any u prefix from the string.
+
+ Note: Mutates its argument.
+ """
+ match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
+ assert match is not None, f"failed to match string {leaf.value!r}"
+ orig_prefix = match.group(1)
+ new_prefix = orig_prefix.lower()
+ if remove_u_prefix:
+ new_prefix = new_prefix.replace("u", "")
+ leaf.value = f"{new_prefix}{match.group(2)}"
+
+
def normalize_string_quotes(leaf: Leaf) -> None:
"""Prefer double quotes but only if it doesn't cause more escaping.
return # There's an internal error
prefix = leaf.value[:first_quote_pos]
- body = leaf.value[first_quote_pos + len(orig_quote):-len(orig_quote)]
unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
- escaped_orig_quote = re.compile(rf"\\(\\\\)*{orig_quote}")
+ escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
+ escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
+ body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
if "r" in prefix.casefold():
if unescaped_new_quote.search(body):
# There's at least one unescaped new_quote in this raw string
# Do not introduce or remove backslashes in raw strings
new_body = body
else:
- new_body = escaped_orig_quote.sub(rf"\1{orig_quote}", body)
- new_body = unescaped_new_quote.sub(rf"\1\\{new_quote}", new_body)
+ # remove unnecessary escapes
+ new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
+ if body != new_body:
+ # Consider the string without unnecessary escapes as the original
+ body = new_body
+ leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
+ new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
+ new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
if new_quote == '"""' and new_body[-1] == '"':
# edge case:
new_body = new_body[:-1] + '\\"'
def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
"""Make existing optional parentheses invisible or create new ones.
+ `parens_after` is a set of string leaf values immediately after which parens
+ should be put.
+
Standardizes on visible parentheses for single-element tuples, and keeps
existing visible parentheses for other tuples and generator expressions.
"""
for child in list(node.children):
if check_lpar:
if child.type == syms.atom:
- if not (
- is_empty_tuple(child)
- or is_one_tuple(child)
- or max_delimiter_priority_in_atom(child) >= COMMA_PRIORITY
- ):
- first = child.children[0]
- last = child.children[-1]
- if first.type == token.LPAR and last.type == token.RPAR:
- # make parentheses invisible
- first.value = "" # type: ignore
- last.value = "" # type: ignore
+ maybe_make_parens_invisible_in_atom(child)
elif is_one_tuple(child):
# wrap child in visible parentheses
lpar = Leaf(token.LPAR, "(")
check_lpar = isinstance(child, Leaf) and child.value in parens_after
+def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
+ """If it's safe, make the parens in the atom `node` invisible, recursively."""
+ if (
+ node.type != syms.atom
+ or is_empty_tuple(node)
+ or is_one_tuple(node)
+ or is_yield(node)
+ or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
+ ):
+ return False
+
+ first = node.children[0]
+ last = node.children[-1]
+ if first.type == token.LPAR and last.type == token.RPAR:
+ # make parentheses invisible
+ first.value = "" # type: ignore
+ last.value = "" # type: ignore
+ if len(node.children) > 1:
+ maybe_make_parens_invisible_in_atom(node.children[1])
+ return True
+
+ return False
+
+
def is_empty_tuple(node: LN) -> bool:
"""Return True if `node` holds an empty tuple."""
return (
)
+def is_yield(node: LN) -> bool:
+ """Return True if `node` holds a `yield` or `yield from` expression."""
+ if node.type == syms.yield_expr:
+ return True
+
+ if node.type == token.NAME and node.value == "yield": # type: ignore
+ return True
+
+ if node.type != syms.atom:
+ return False
+
+ if len(node.children) != 3:
+ return False
+
+ lpar, expr, rpar = node.children
+ if lpar.type == token.LPAR and rpar.type == token.RPAR:
+ return is_yield(expr)
+
+ return False
+
+
+def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
+ """Return True if `leaf` is a star or double star in a vararg or kwarg.
+
+ If `within` includes VARARGS_PARENTS, this applies to function signatures.
+ If `within` includes UNPACKING_PARENTS, it applies to right hand-side
+ extended iterable unpacking (PEP 3132) and additional unpacking
+ generalizations (PEP 448).
+ """
+ if leaf.type not in STARS or not leaf.parent:
+ return False
+
+ p = leaf.parent
+ if p.type == syms.star_expr:
+ # Star expressions are also used as assignment targets in extended
+ # iterable unpacking (PEP 3132). See what its parent is instead.
+ if not p.parent:
+ return False
+
+ p = p.parent
+
+ return p.type in within
+
+
def max_delimiter_priority_in_atom(node: LN) -> int:
+ """Return maximum delimiter priority inside `node`.
+
+ This is specific to atoms with contents contained in a pair of parentheses.
+ If `node` isn't an atom or there are no enclosing parentheses, returns 0.
+ """
if node.type != syms.atom:
return 0
Currently looking for:
- f-strings; and
- - trailing commas after * or ** in function signatures.
+ - trailing commas after * or ** in function signatures and calls.
"""
for n in node.pre_order():
if n.type == token.STRING:
return True
elif (
- n.type == syms.typedargslist
+ n.type in {syms.typedargslist, syms.arglist}
and n.children
and n.children[-1].type == token.COMMA
):
for ch in n.children:
- if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
+ if ch.type in STARS:
return True
+ if ch.type == syms.argument:
+ for argch in ch.children:
+ if argch.type in STARS:
+ return True
+
return False
-PYTHON_EXTENSIONS = {".py"}
+def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
+ """Generate sets of closing bracket IDs that should be omitted in a RHS.
+
+ Brackets can be omitted if the entire trailer up to and including
+ a preceding closing bracket fits in one line.
+
+ Yielded sets are cumulative (contain results of previous yields, too). First
+ set is empty.
+ """
+
+ omit: Set[LeafID] = set()
+ yield omit
+
+ length = 4 * line.depth
+ opening_bracket = None
+ closing_bracket = None
+ optional_brackets: Set[LeafID] = set()
+ inner_brackets: Set[LeafID] = set()
+ for index, leaf in enumerate_reversed(line.leaves):
+ length += len(leaf.prefix) + len(leaf.value)
+ if length > line_length:
+ break
+
+ comment: Optional[Leaf]
+ for comment in line.comments_after(leaf, index):
+ if "\n" in comment.prefix:
+ break # Oops, standalone comment!
+
+ length += len(comment.value)
+ else:
+ comment = None
+ if comment is not None:
+ break # There was a standalone comment, we can't continue.
+
+ optional_brackets.discard(id(leaf))
+ if opening_bracket:
+ if leaf is opening_bracket:
+ opening_bracket = None
+ elif leaf.type in CLOSING_BRACKETS:
+ inner_brackets.add(id(leaf))
+ elif leaf.type in CLOSING_BRACKETS:
+ if not leaf.value:
+ optional_brackets.add(id(opening_bracket))
+ continue
+
+ if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
+ # Empty brackets would fail a split so treat them as "inner"
+ # brackets (i.e. only add them to the `omit` set if another
+ # pair of brackets was good enough).
+ inner_brackets.add(id(leaf))
+ continue
+
+ opening_bracket = leaf.opening_bracket
+ if closing_bracket:
+ omit.add(id(closing_bracket))
+ omit.update(inner_brackets)
+ inner_brackets.clear()
+ yield omit
+ closing_bracket = leaf
+
+
+def get_future_imports(node: Node) -> Set[str]:
+ """Return a set of __future__ imports in the file."""
+ imports = set()
+ for child in node.children:
+ if child.type != syms.simple_stmt:
+ break
+ first_child = child.children[0]
+ if isinstance(first_child, Leaf):
+ # Continue looking if we see a docstring; otherwise stop.
+ if (
+ len(child.children) == 2
+ and first_child.type == token.STRING
+ and child.children[1].type == token.NEWLINE
+ ):
+ continue
+ else:
+ break
+ elif first_child.type == syms.import_from:
+ module_name = first_child.children[1]
+ if not isinstance(module_name, Leaf) or module_name.value != "__future__":
+ break
+ for import_from_child in first_child.children[3:]:
+ if isinstance(import_from_child, Leaf):
+ if import_from_child.type == token.NAME:
+ imports.add(import_from_child.value)
+ else:
+ assert import_from_child.type == syms.import_as_names
+ for leaf in import_from_child.children:
+ if isinstance(leaf, Leaf) and leaf.type == token.NAME:
+ imports.add(leaf.value)
+ else:
+ break
+ return imports
+
+
+PYTHON_EXTENSIONS = {".py", ".pyi"}
BLACKLISTED_DIRECTORIES = {
"build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
}
yield from gen_python_files_in_dir(child)
- elif child.suffix in PYTHON_EXTENSIONS:
+ elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
yield child
same_count: int = 0
failure_count: int = 0
- def done(self, src: Path, changed: bool) -> None:
+ def done(self, src: Path, changed: Changed) -> None:
"""Increment the counter for successful reformatting. Write out a message."""
- if changed:
+ if changed is Changed.YES:
reformatted = "would reformat" if self.check else "reformatted"
if not self.quiet:
out(f"{reformatted} {src}")
self.change_count += 1
else:
if not self.quiet:
- out(f"{src} already well formatted, good job.", bold=False)
+ if changed is Changed.NO:
+ msg = f"{src} already well formatted, good job."
+ else:
+ msg = f"{src} wasn't modified on disk since last run."
+ out(msg, bold=False)
self.same_count += 1
def failed(self, src: Path, message: str) -> None:
) from None
-def assert_stable(src: str, dst: str, line_length: int) -> None:
+def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
- newdst = format_str(dst, line_length=line_length)
+ newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
if dst != newdst:
log = dump_to_file(
diff(src, dst, "source", "first pass"),
import tempfile
with tempfile.NamedTemporaryFile(
- mode="w", prefix="blk_", suffix=".log", delete=False
+ mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
) as f:
for lines in output:
f.write(lines)
loop.close()
+def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
+ """Replace `regex` with `replacement` twice on `original`.
+
+ This is used by string normalization to perform replaces on
+ overlapping matches.
+ """
+ return regex.sub(replacement, regex.sub(replacement, original))
+
+
+def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
+ """Like `reversed(enumerate(sequence))` if that were possible."""
+ index = len(sequence) - 1
+ for element in reversed(sequence):
+ yield (index, element)
+ index -= 1
+
+
+def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
+ """Return True if `line` is no longer than `line_length`.
+
+ Uses the provided `line_str` rendering, if any, otherwise computes a new one.
+ """
+ if not line_str:
+ line_str = str(line).strip("\n")
+ return (
+ len(line_str) <= line_length
+ and "\n" not in line_str # multiline strings
+ and not line.contains_standalone_comments()
+ )
+
+
+CACHE_DIR = Path(user_cache_dir("black", version=__version__))
+
+
+def get_cache_file(line_length: int) -> Path:
+ return CACHE_DIR / f"cache.{line_length}.pickle"
+
+
+def read_cache(line_length: int) -> Cache:
+ """Read the cache if it exists and is well formed.
+
+ If it is not well formed, the call to write_cache later should resolve the issue.
+ """
+ cache_file = get_cache_file(line_length)
+ if not cache_file.exists():
+ return {}
+
+ with cache_file.open("rb") as fobj:
+ try:
+ cache: Cache = pickle.load(fobj)
+ except pickle.UnpicklingError:
+ return {}
+
+ return cache
+
+
+def get_cache_info(path: Path) -> CacheInfo:
+ """Return the information used to check if a file is already formatted or not."""
+ stat = path.stat()
+ return stat.st_mtime, stat.st_size
+
+
+def filter_cached(
+ cache: Cache, sources: Iterable[Path]
+) -> Tuple[List[Path], List[Path]]:
+ """Split a list of paths into two.
+
+ The first list contains paths of files that were modified on disk or are not in
+ the cache. The other list contains paths to non-modified files.
+ """
+ todo, done = [], []
+ for src in sources:
+ src = src.resolve()
+ if cache.get(src) != get_cache_info(src):
+ todo.append(src)
+ else:
+ done.append(src)
+ return todo, done
+
+
+def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
+ """Update the cache file."""
+ cache_file = get_cache_file(line_length)
+ try:
+ if not CACHE_DIR.exists():
+ CACHE_DIR.mkdir(parents=True)
+ new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
+ with cache_file.open("wb") as fobj:
+ pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
+ except OSError:
+ pass
+
+
if __name__ == "__main__":
main()