import asyncio
-from asyncio.base_events import BaseEventLoop
from concurrent.futures import Executor, ProcessPoolExecutor
+from contextlib import contextmanager
from datetime import datetime
from enum import Enum
from functools import lru_cache, partial, wraps
import sys
import tempfile
import tokenize
+import traceback
from typing import (
Any,
Callable,
"--quiet",
is_flag=True,
help=(
- "Don't emit non-error messages to stderr. Errors are still emitted, "
+ "Don't emit non-error messages to stderr. Errors are still emitted; "
"silence those with 2>/dev/null."
),
)
) -> None:
"""Reformat a single file under `src` without spawning child processes.
- If `quiet` is True, non-error messages are not output. `line_length`,
- `write_back`, `fast` and `pyi` options are passed to
+ `fast`, `write_back`, and `mode` options are passed to
:func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
"""
try:
)
finally:
shutdown(loop)
+ executor.shutdown()
async def schedule_formatting(
write_back: WriteBack,
mode: FileMode,
report: "Report",
- loop: BaseEventLoop,
+ loop: asyncio.AbstractEventLoop,
executor: Executor,
) -> None:
"""Run formatting of `sources` in parallel using the provided `executor`.
(Use ProcessPoolExecutors for actual parallelism.)
- `line_length`, `write_back`, `fast`, and `pyi` options are passed to
+ `write_back`, `fast`, and `mode` options are passed to
:func:`format_file_in_place`.
"""
cache: Cache = {}
If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
code to the file.
- `line_length` and `fast` options are passed to :func:`format_file_contents`.
+ `mode` and `fast` options are passed to :func:`format_file_contents`.
"""
if src.suffix == ".pyi":
mode = evolve(mode, is_pyi=True)
src_name = f"{src}\t{then} +0000"
dst_name = f"{src}\t{now} +0000"
diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
- if lock:
- lock.acquire()
- try:
+
+ with lock or nullcontext():
f = io.TextIOWrapper(
sys.stdout.buffer,
encoding=encoding,
)
f.write(diff_contents)
f.detach()
- finally:
- if lock:
- lock.release()
+
return True
If `fast` is False, additionally confirm that the reformatted code is
valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
- `line_length` is passed to :func:`format_str`.
+ `mode` is passed to :func:`format_str`.
"""
if src_contents.strip() == "":
raise NothingChanged
def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
"""Reformat a string and return new contents.
- `line_length` determines how many characters per line are allowed.
+ `mode` determines formatting options, such as how many characters per line are
+ allowed.
"""
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
- dst_contents = ""
+ dst_contents = []
future_imports = get_future_imports(src_node)
if mode.target_versions:
versions = mode.target_versions
}
for current_line in lines.visit(src_node):
for _ in range(after):
- dst_contents += str(empty_line)
+ dst_contents.append(str(empty_line))
before, after = elt.maybe_empty_lines(current_line)
for _ in range(before):
- dst_contents += str(empty_line)
+ dst_contents.append(str(empty_line))
for line in split_line(
current_line, line_length=mode.line_length, features=split_line_features
):
- dst_contents += str(line)
- return dst_contents
+ dst_contents.append(str(line))
+ return "".join(dst_contents)
def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
"""Return True if there is an yet unmatched open bracket on the line."""
return bool(self.bracket_match)
- def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
+ def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
"""Return the highest priority of a delimiter found on the line.
Values are consistent with what `is_split_*_delimiter()` return.
"""
return max(v for k, v in self.delimiters.items() if k not in exclude)
- def delimiter_count_with_priority(self, priority: int = 0) -> int:
+ def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
"""Return the number of delimiters with the given `priority`.
If no `priority` is passed, defaults to max priority on the line.
try:
last_leaf = self.leaves[-1]
ignored_ids.add(id(last_leaf))
- if last_leaf.type == token.COMMA:
- # When trailing commas are inserted by Black for consistency, comments
- # after the previous last element are not moved (they don't have to,
- # rendering will still be correct). So we ignore trailing commas.
+ if last_leaf.type == token.COMMA or (
+ last_leaf.type == token.RPAR and not last_leaf.value
+ ):
+ # When trailing commas or optional parens are inserted by Black for
+ # consistency, comments after the previous last element are not moved
+ # (they don't have to, rendering will still be correct). So we ignore
+ # trailing commas and invisible parens.
last_leaf = self.leaves[-2]
ignored_ids.add(id(last_leaf))
except IndexError:
bracket_depth = leaf.bracket_depth
if bracket_depth == depth and leaf.type == token.COMMA:
commas += 1
- if leaf.parent and leaf.parent.type == syms.arglist:
+ if leaf.parent and leaf.parent.type in {
+ syms.arglist,
+ syms.typedargslist,
+ }:
commas += 1
break
comment.prefix = ""
return False
- self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
+ last_leaf = self.leaves[-1]
+ if (
+ last_leaf.type == token.RPAR
+ and not last_leaf.value
+ and last_leaf.parent
+ and len(list(last_leaf.parent.leaves())) <= 3
+ and not is_type_comment(comment)
+ ):
+ # Comments on optional parens wrapping a single leaf should belong to
+ # the wrapped node except if it's a type comment. Pinning the comment like
+ # this avoids unstable formatting caused by comment migration.
+ if len(self.leaves) < 2:
+ comment.type = STANDALONE_COMMENT
+ comment.prefix = ""
+ return False
+ last_leaf = self.leaves[-2]
+ self.comments.setdefault(id(last_leaf), []).append(comment)
return True
def comments_after(self, leaf: Leaf) -> List[Leaf]:
node.children[2].value = ""
yield from super().visit_default(node)
+ def visit_factor(self, node: Node) -> Iterator[Line]:
+ """Force parentheses between a unary op and a binary power:
+
+ -2 ** 8 -> -(2 ** 8)
+ """
+ child = node.children[1]
+ if child.type == syms.power and len(child.children) == 3:
+ lpar = Leaf(token.LPAR, "(")
+ rpar = Leaf(token.RPAR, ")")
+ index = child.remove() or 0
+ node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+ yield from self.visit_default(node)
+
def visit_INDENT(self, node: Node) -> Iterator[Line]:
"""Increase indentation level, maybe yield a line."""
# In blib2to3 INDENT never holds comments.
return container
-def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
+def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
"""Return the priority of the `leaf` delimiter, given a line break after it.
The delimiter priorities returned here are from those delimiters that would
return 0
-def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
+def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
"""Return the priority of the `leaf` delimiter, given a line break before it.
The delimiter priorities returned here are from those delimiters that would
if leaves:
# Since body is a new indent level, remove spurious leading whitespace.
normalize_prefix(leaves[0], inside_brackets=True)
- # Ensure a trailing comma for imports, but be careful not to add one after
- # any comments.
- if original.is_import:
+ # Ensure a trailing comma for imports and standalone function arguments, but
+ # be careful not to add one after any comments.
+ no_commas = original.is_def and not any(
+ l.type == token.COMMA for l in leaves
+ )
+
+ if original.is_import or no_commas:
for i in range(len(leaves) - 1, -1, -1):
if leaves[i].type == STANDALONE_COMMENT:
continue
new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
if "f" in prefix.casefold():
- matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
+ matches = re.findall(
+ r"""
+ (?:[^{]|^)\{ # start of the string or a non-{ followed by a single {
+ ([^{].*?) # contents of the brackets except if begins with {{
+ \}(?:[^}]|$) # A } followed by end of the string or a non-}
+ """,
+ new_body,
+ re.VERBOSE,
+ )
for m in matches:
if "\\" in str(m):
# Do not introduce backslashes in interpolated expressions
def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
"""Make existing optional parentheses invisible or create new ones.
- `parens_after` is a set of string leaf values immeditely after which parens
+ `parens_after` is a set of string leaf values immediately after which parens
should be put.
Standardizes on visible parentheses for single-element tuples, and keeps
)
-def max_delimiter_priority_in_atom(node: LN) -> int:
+def max_delimiter_priority_in_atom(node: LN) -> Priority:
"""Return maximum delimiter priority inside `node`.
This is specific to atoms with contents contained in a pair of parentheses.
"""Make sure parentheses are visible.
They could be invisible as part of some statements (see
- :func:`normalize_invible_parens` and :func:`visit_import_from`).
+ :func:`normalize_invisible_parens` and :func:`visit_import_from`).
"""
if leaf.type == token.LPAR:
leaf.value = "("
def assert_equivalent(src: str, dst: str) -> None:
"""Raise AssertionError if `src` and `dst` aren't equivalent."""
- import traceback
-
def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
"""Simple visitor generating strings to compare ASTs by content."""
yield f"{' ' * depth}{node.__class__.__name__}("
def dump_to_file(*output: str) -> str:
"""Dump `output` to a temporary file. Return path to the file."""
- import tempfile
-
with tempfile.NamedTemporaryFile(
mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
) as f:
return f.name
+@contextmanager
+def nullcontext() -> Iterator[None]:
+ """Return context manager that does nothing.
+ Similar to `contextlib.nullcontext` from Python 3.7."""
+ yield
+
+
def diff(a: str, b: str, a_name: str, b_name: str) -> str:
"""Return a unified diff string between strings `a` and `b`."""
import difflib
task.cancel()
-def shutdown(loop: BaseEventLoop) -> None:
+def shutdown(loop: asyncio.AbstractEventLoop) -> None:
"""Cancel all pending tasks on `loop`, wait for them, and close the loop."""
try:
if sys.version_info[:2] >= (3, 7):
if "\n" in leaf.value:
return # Multiline strings, we can't continue.
- comment: Optional[Leaf]
for comment in line.comments_after(leaf):
length += len(comment.value)