from attr import dataclass, evolve, Factory
import click
import toml
+from typed_ast import ast3, ast27
# lib2to3 fork
from blib2to3.pytree import Node, Leaf, type_repr
from blib2to3.pgen2.parse import ParseError
-__version__ = "18.9b0"
+__version__ = "19.3b0"
+# Default value of FileMode.line_length.
DEFAULT_LINE_LENGTH = 88
DEFAULT_EXCLUDES = (
r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
Priority = int
Index = int
LN = Union[Leaf, Node]
-SplitFunc = Callable[["Line", bool], Iterator["Line"]]
+# A split function receives a line plus the syntactic Features its output may use.
+SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
Timestamp = float
FileSize = int
CacheInfo = Tuple[Timestamp, FileSize]
class TargetVersion(Enum):
+ """Python versions that formatted output may be required to support."""
+ # Value 1 is intentionally unused (it belonged to the removed PYPY35 member).
- PYPY35 = 1
- CPY27 = 2
- CPY33 = 3
- CPY34 = 4
- CPY35 = 5
- CPY36 = 6
- CPY37 = 7
- CPY38 = 8
+ PY27 = 2
+ PY33 = 3
+ PY34 = 4
+ PY35 = 5
+ PY36 = 6
+ PY37 = 7
+ PY38 = 8
def is_python2(self) -> bool:
+ """Return True if this target is Python 2 (currently only PY27)."""
- return self is TargetVersion.CPY27
+ return self is TargetVersion.PY27
-PY36_VERSIONS = {TargetVersion.CPY36, TargetVersion.CPY37, TargetVersion.CPY38}
+# Targets selected by the deprecated --py36 command-line flag.
+PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
class Feature(Enum):
+ """Syntax features whose availability depends on the target Python version."""
UNICODE_LITERALS = 1
F_STRINGS = 2
NUMERIC_UNDERSCORES = 3
- TRAILING_COMMA = 4
+ # Trailing comma after `*`/`**` arguments in a call (arglist / argument).
+ TRAILING_COMMA_IN_CALL = 4
+ # Trailing comma after `*`/`**` parameters in a function signature (typedargslist).
+ TRAILING_COMMA_IN_DEF = 5
+# Syntax features supported by each TargetVersion.
VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
- TargetVersion.CPY27: set(),
- TargetVersion.PYPY35: {Feature.UNICODE_LITERALS, Feature.F_STRINGS},
- TargetVersion.CPY33: {Feature.UNICODE_LITERALS},
- TargetVersion.CPY34: {Feature.UNICODE_LITERALS},
- TargetVersion.CPY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
- TargetVersion.CPY36: {
+ TargetVersion.PY27: set(),
+ TargetVersion.PY33: {Feature.UNICODE_LITERALS},
+ TargetVersion.PY34: {Feature.UNICODE_LITERALS},
+ # Python 3.5 allows the `*`/`**` trailing comma in calls, but not in signatures.
+ TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA_IN_CALL},
+ TargetVersion.PY36: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
- Feature.TRAILING_COMMA,
+ Feature.TRAILING_COMMA_IN_CALL,
+ Feature.TRAILING_COMMA_IN_DEF,
},
- TargetVersion.CPY37: {
+ TargetVersion.PY37: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
- Feature.TRAILING_COMMA,
+ Feature.TRAILING_COMMA_IN_CALL,
+ Feature.TRAILING_COMMA_IN_DEF,
},
- TargetVersion.CPY38: {
+ TargetVersion.PY38: {
Feature.UNICODE_LITERALS,
Feature.F_STRINGS,
Feature.NUMERIC_UNDERSCORES,
- Feature.TRAILING_COMMA,
+ Feature.TRAILING_COMMA_IN_CALL,
+ Feature.TRAILING_COMMA_IN_DEF,
},
}
class FileMode:
target_versions: Set[TargetVersion] = Factory(set)
line_length: int = DEFAULT_LINE_LENGTH
- numeric_underscore_normalization: bool = True
string_normalization: bool = True
is_pyi: bool = False
parts = [
version_str,
str(self.line_length),
- str(int(self.numeric_underscore_normalization)),
str(int(self.string_normalization)),
str(int(self.is_pyi)),
]
@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
+@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
@click.option(
"-l",
"--line-length",
help=(
"Allow using Python 3.6-only syntax on all input files. This will put "
"trailing commas in function signatures and calls also after *args and "
- "**kwargs. [default: per-file auto-detection]"
+ "**kwargs. Deprecated; use --target-version instead. "
+ "[default: per-file auto-detection]"
),
)
@click.option(
is_flag=True,
help="Don't normalize string quotes or prefixes.",
)
-@click.option(
- "-N",
- "--skip-numeric-underscore-normalization",
- is_flag=True,
- help="Don't normalize underscores in numeric literals.",
-)
@click.option(
"--check",
is_flag=True,
@click.pass_context
def main(
ctx: click.Context,
+ code: Optional[str],
line_length: int,
target_version: List[TargetVersion],
check: bool,
pyi: bool,
py36: bool,
skip_string_normalization: bool,
- skip_numeric_underscore_normalization: bool,
quiet: bool,
verbose: bool,
include: str,
else:
versions = set(target_version)
elif py36:
+ err(
+ "--py36 is deprecated and will be removed in a future version. "
+ "Use --target-version py36 instead."
+ )
versions = PY36_VERSIONS
else:
# We'll autodetect later.
line_length=line_length,
is_pyi=pyi,
string_normalization=not skip_string_normalization,
- numeric_underscore_normalization=not skip_numeric_underscore_normalization,
)
if config and verbose:
out(f"Using configuration from {config}.", bold=False, fg="blue")
+ if code is not None:
+ print(format_str(code, mode=mode))
+ ctx.exit(0)
try:
include_regex = re_compile_maybe_verbose(include)
except re.error:
report=report,
)
else:
- loop = asyncio.get_event_loop()
- executor = ProcessPoolExecutor(max_workers=os.cpu_count())
- try:
- loop.run_until_complete(
- schedule_formatting(
- sources=sources,
- fast=fast,
- write_back=write_back,
- mode=mode,
- report=report,
- loop=loop,
- executor=executor,
- )
- )
- finally:
- shutdown(loop)
+ reformat_many(
+ sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
+ )
+
if verbose or not quiet:
bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
out(f"All done! {bang}")
report.failed(src, str(exc))
+def reformat_many(
+    sources: Set[Path],
+    fast: bool,
+    write_back: WriteBack,
+    mode: FileMode,
+    report: "Report",
+) -> None:
+    """Reformat multiple files using a ProcessPoolExecutor.
+
+    Formatting is driven by `schedule_formatting` on the current asyncio event
+    loop. Afterwards the loop is shut down and the executor's worker processes
+    are released, even if formatting raised.
+    """
+    loop = asyncio.get_event_loop()
+    worker_count = os.cpu_count()
+    if sys.platform == "win32" and worker_count is not None:
+        # Work around https://bugs.python.org/issue26903: on Windows a
+        # ProcessPoolExecutor fails with more than 61 workers. os.cpu_count()
+        # may return None, in which case ProcessPoolExecutor picks its own
+        # default, so only a known count is clamped (min(None, 61) would raise).
+        worker_count = min(worker_count, 61)
+    executor = ProcessPoolExecutor(max_workers=worker_count)
+    try:
+        loop.run_until_complete(
+            schedule_formatting(
+                sources=sources,
+                fast=fast,
+                write_back=write_back,
+                mode=mode,
+                report=report,
+                loop=loop,
+                executor=executor,
+            )
+        )
+    finally:
+        shutdown(loop)
+        # Release worker processes now instead of leaving them to interpreter exit.
+        executor.shutdown()
+
+
async def schedule_formatting(
sources: Set[Path],
fast: bool,
manager = Manager()
lock = manager.Lock()
tasks = {
- loop.run_in_executor(
- executor, format_file_in_place, src, fast, mode, write_back, lock
+ asyncio.ensure_future(
+ loop.run_in_executor(
+ executor, format_file_in_place, src, fast, mode, write_back, lock
+ )
): src
for src in sorted(sources)
}
- pending: Iterable[asyncio.Task] = tasks.keys()
+ pending: Iterable[asyncio.Future] = tasks.keys()
try:
loop.add_signal_handler(signal.SIGINT, cancel, pending)
loop.add_signal_handler(signal.SIGTERM, cancel, pending)
or supports_feature(versions, Feature.UNICODE_LITERALS),
is_pyi=mode.is_pyi,
normalize_strings=mode.string_normalization,
- allow_underscores=mode.numeric_underscore_normalization
- and supports_feature(versions, Feature.NUMERIC_UNDERSCORES),
)
elt = EmptyLineTracker(is_pyi=mode.is_pyi)
empty_line = Line()
after = 0
+ split_line_features = {
+ feature
+ for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
+ if supports_feature(versions, feature)
+ }
for current_line in lines.visit(src_node):
for _ in range(after):
dst_contents += str(empty_line)
for _ in range(before):
dst_contents += str(empty_line)
for line in split_line(
- current_line,
- line_length=mode.line_length,
- supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
+ current_line, line_length=mode.line_length, features=split_line_features
):
dst_contents += str(line)
return dst_contents
return tiow.read(), encoding, newline
-GRAMMARS = [
- pygram.python_grammar_no_print_statement_no_exec_statement,
- pygram.python_grammar_no_print_statement,
- pygram.python_grammar,
-]
-
-
def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
+ """Return the lib2to3 grammars to try for source meant for `target_versions`.
+
+ No targets: all three grammars. Only Python 2 targets: the two Python 2
+ grammars. Any Python 3 target: only the Python 3 grammar.
+ """
if not target_versions:
- return GRAMMARS
- elif all(not version.is_python2() for version in target_versions):
- # Python 2-compatible code, so don't try Python 3 grammar.
+ # No target_version specified, so try all grammars.
return [
pygram.python_grammar_no_print_statement_no_exec_statement,
pygram.python_grammar_no_print_statement,
+ pygram.python_grammar,
]
+ elif all(version.is_python2() for version in target_versions):
+ # Python 2-only code, so try Python 2 grammars.
+ return [pygram.python_grammar_no_print_statement, pygram.python_grammar]
else:
- return [pygram.python_grammar]
+ # Python 3-compatible code, so only try Python 3 grammar.
+ return [pygram.python_grammar_no_print_statement_no_exec_statement]
def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
depth: int = 0
leaves: List[Leaf] = Factory(list)
- # The LeafID keys of comments must remain ordered by the corresponding leaf's index
- # in leaves
- comments: Dict[LeafID, List[Leaf]] = Factory(dict)
+ comments: Dict[LeafID, List[Leaf]] = Factory(dict) # keys ordered like `leaves`
bracket_tracker: BracketTracker = Factory(BracketTracker)
inside_brackets: bool = False
should_explode: bool = False
if leaf.type == STANDALONE_COMMENT:
if leaf.bracket_depth <= depth_limit:
return True
+ return False
+
+ def contains_inner_type_comments(self) -> bool:
+ """Return True if a type comment is attached to a leaf other than the last.
+
+ When the line ends with a comma, the leaf before that comma is excluded
+ as well. Returns False for a line with no leaves, or one consisting only
+ of a comma.
+ """
+ ignored_ids = set()
+ try:
+ last_leaf = self.leaves[-1]
+ ignored_ids.add(id(last_leaf))
+ if last_leaf.type == token.COMMA:
+ # When trailing commas are inserted by Black for consistency, comments
+ # after the previous last element are not moved (they don't have to,
+ # rendering will still be correct). So we ignore trailing commas.
+ last_leaf = self.leaves[-2]
+ ignored_ids.add(id(last_leaf))
+ except IndexError:
+ return False
+
+ for leaf_id, comments in self.comments.items():
+ if leaf_id in ignored_ids:
+ continue
+
+ for comment in comments:
+ if is_type_comment(comment):
+ return True
return False
comment.prefix = ""
return False
- else:
- leaf_id = id(self.leaves[-1])
- if leaf_id not in self.comments:
- self.comments[leaf_id] = [comment]
- else:
- self.comments[leaf_id].append(comment)
- return True
+ self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
+ return True
def comments_after(self, leaf: Leaf) -> List[Leaf]:
"""Generate comments that should appear directly after `leaf`."""
def remove_trailing_comma(self) -> None:
"""Remove the trailing comma and moves the comments attached to it."""
- # Remember, the LeafID keys of self.comments are ordered by the
- # corresponding leaf's index in self.leaves
- # If id(self.leaves[-2]) is in self.comments, the order doesn't change.
- # Otherwise, we insert it into self.comments, and it becomes the last entry.
- # However, since we delete id(self.leaves[-1]) from self.comments, the invariant
- # is maintained
- self.comments.setdefault(id(self.leaves[-2]), []).extend(
- self.comments.get(id(self.leaves[-1]), [])
+ trailing_comma = self.leaves.pop()
+ trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
+ self.comments.setdefault(id(self.leaves[-1]), []).extend(
+ trailing_comma_comments
)
- self.comments.pop(id(self.leaves[-1]), None)
- self.leaves.pop()
def is_complex_subscript(self, leaf: Leaf) -> bool:
"""Return True iff `leaf` is part of a slice with non-trivial exprs."""
normalize_strings: bool = True
current_line: Line = Factory(Line)
remove_u_prefix: bool = False
- allow_underscores: bool = False
def line(self, indent: int = 0) -> Iterator[Line]:
"""Generate a line.
normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
normalize_string_quotes(node)
if node.type == token.NUMBER:
- normalize_numeric_literal(node, self.allow_underscores)
+ normalize_numeric_literal(node)
if node.type not in WHITESPACE:
self.current_line.append(node)
yield from super().visit_default(node)
self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
+ self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
self.visit_async_funcdef = self.visit_async_stmt
self.visit_decorated = self.visit_decorators
consumed = 0
nlines = 0
+ ignored_lines = 0
for index, line in enumerate(prefix.split("\n")):
consumed += len(line) + 1 # adding the length of the split '\n'
line = line.lstrip()
if not line:
nlines += 1
if not line.startswith("#"):
+ # Escaped newlines outside of a comment are not really newlines at
+ # all. We treat a single-line comment following an escaped newline
+ # as a simple trailing comment.
+ if line.endswith("\\"):
+ ignored_lines += 1
continue
- if index == 0 and not is_endmarker:
+ if index == ignored_lines and not is_endmarker:
comment_type = token.COMMENT # simple trailing comment
else:
comment_type = STANDALONE_COMMENT
line: Line,
line_length: int,
inner: bool = False,
- supports_trailing_commas: bool = False,
+ features: Collection[Feature] = (),
) -> Iterator[Line]:
"""Split a `line` into potentially many lines.
current `line`, possibly transitively. This means we can fallback to splitting
by delimiters if the LHS/RHS don't yield any results.
- If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
+ `features` are syntactical features that may be used in the output.
"""
if line.is_comment:
yield line
line_str = str(line).strip("\n")
- # we don't want to split special comments like type annotations
- # https://github.com/python/typing/issues/186
- has_special_comment = False
- for leaf in line.leaves:
- for comment in line.comments_after(leaf):
- if leaf.type == token.COMMA and is_special_comment(comment):
- has_special_comment = True
-
if (
- not has_special_comment
+ not line.contains_inner_type_comments()
and not line.should_explode
and is_line_short_enough(line, line_length=line_length, line_str=line_str)
):
split_funcs = [left_hand_split]
else:
- def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
+ def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
for omit in generate_trailers_to_omit(line, line_length):
- lines = list(
- right_hand_split(
- line, line_length, supports_trailing_commas, omit=omit
- )
- )
+ lines = list(right_hand_split(line, line_length, features, omit=omit))
if is_line_short_enough(lines[0], line_length=line_length):
yield from lines
return
# All splits failed, best effort split with no omits.
# This mostly happens to multiline strings that are by definition
# reported as not fitting a single line.
- yield from right_hand_split(line, supports_trailing_commas)
+ yield from right_hand_split(line, line_length, features=features)
if line.inside_brackets:
split_funcs = [delimiter_split, standalone_comment_split, rhs]
# split altogether.
result: List[Line] = []
try:
- for l in split_func(line, supports_trailing_commas):
+ for l in split_func(line, features):
if str(l).strip("\n") == line_str:
raise CannotSplit("Split function returned an unchanged result")
result.extend(
split_line(
- l,
- line_length=line_length,
- inner=True,
- supports_trailing_commas=supports_trailing_commas,
+ l, line_length=line_length, inner=True, features=features
)
)
except CannotSplit:
yield line
-def left_hand_split(
- line: Line, supports_trailing_commas: bool = False
-) -> Iterator[Line]:
+def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
"""Split line into many lines, starting with the first matching bracket pair.
Note: this usually looks weird, only use this for function definitions.
def right_hand_split(
line: Line,
line_length: int,
- supports_trailing_commas: bool = False,
+ features: Collection[Feature] = (),
omit: Collection[LeafID] = (),
) -> Iterator[Line]:
"""Split line into many lines, starting with the last matching bracket pair.
):
omit = {id(closing_bracket), *omit}
try:
- yield from right_hand_split(
- line,
- line_length,
- supports_trailing_commas=supports_trailing_commas,
- omit=omit,
- )
+ yield from right_hand_split(line, line_length, features=features, omit=omit)
return
except CannotSplit:
if leaves:
# Since body is a new indent level, remove spurious leading whitespace.
normalize_prefix(leaves[0], inside_brackets=True)
- # Ensure a trailing comma when expected.
+ # Ensure a trailing comma for imports, but be careful not to add one after
+ # any comments.
if original.is_import:
- if leaves[-1].type != token.COMMA:
- leaves.append(Leaf(token.COMMA, ","))
+ for i in range(len(leaves) - 1, -1, -1):
+ if leaves[i].type == STANDALONE_COMMENT:
+ continue
+ elif leaves[i].type == token.COMMA:
+ break
+ else:
+ leaves.insert(i + 1, Leaf(token.COMMA, ","))
+ break
# Populate the line
for leaf in leaves:
result.append(leaf, preformatted=True)
"""
@wraps(split_func)
- def split_wrapper(
- line: Line, supports_trailing_commas: bool = False
- ) -> Iterator[Line]:
- for l in split_func(line, supports_trailing_commas):
+ def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
+ for l in split_func(line, features):
normalize_prefix(l.leaves[0], inside_brackets=True)
yield l
@dont_increase_indentation
-def delimiter_split(
- line: Line, supports_trailing_commas: bool = False
-) -> Iterator[Line]:
+def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
"""Split according to delimiters of the highest priority.
- If `py36` is True, the split will add trailing commas also in function
- signatures that contain `*` and `**`.
+ If the appropriate Features are given, the split will add trailing commas
+ also in function signatures and calls that contain `*` and `**`.
"""
try:
last_leaf = line.leaves[-1]
yield from append_to_line(comment_after)
lowest_depth = min(lowest_depth, leaf.bracket_depth)
- if leaf.bracket_depth == lowest_depth and is_vararg(
- leaf, within=VARARGS_PARENTS
- ):
- trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
+ if leaf.bracket_depth == lowest_depth:
+ if is_vararg(leaf, within={syms.typedargslist}):
+ trailing_comma_safe = (
+ trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
+ )
+ elif is_vararg(leaf, within={syms.arglist, syms.argument}):
+ trailing_comma_safe = (
+ trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
+ )
+
leaf_priority = bt.delimiters.get(id(leaf))
if leaf_priority == delimiter_priority:
yield current_line
@dont_increase_indentation
def standalone_comment_split(
- line: Line, supports_trailing_commas: bool = False
+ line: Line, features: Collection[Feature] = ()
) -> Iterator[Line]:
"""Split standalone comments from the rest of the line."""
if not line.contains_standalone_comments(0):
)
-def is_special_comment(leaf: Leaf) -> bool:
+def is_type_comment(leaf: Leaf) -> bool:
-    """Return True if the given leaf is a special comment.
-    Only returns true for type comments for now."""
+    """Return True if `leaf` is a comment whose text starts with "# type:".
+
+    Both trailing comments (token.COMMENT) and whole-line comments
+    (STANDALONE_COMMENT) are recognized.
+    """
     t = leaf.type
     v = leaf.value
-    return bool(
-        (t == token.COMMENT or t == STANDALONE_COMMENT) and (v.startswith("# type:"))
-    )
+    # Both token types must be set members. Writing `t == STANDALONE_COMMENT`
+    # inside the literal would store a bool instead, and standalone type
+    # comments would never match.
+    return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:")
def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
-def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
+def normalize_numeric_literal(leaf: Leaf) -> None:
"""Normalizes numeric (float, int, and complex) literals.
All letters used in the representation are normalized to lowercase (except
- in Python 2 long literals), and long number literals are split using underscores.
+ in Python 2 long literals).
"""
text = leaf.value.lower()
if text.startswith(("0o", "0b")):
sign = "-"
elif after.startswith("+"):
after = after[1:]
- before = format_float_or_int_string(before, allow_underscores)
- after = format_int_string(after, allow_underscores)
+ before = format_float_or_int_string(before)
text = f"{before}e{sign}{after}"
elif text.endswith(("j", "l")):
number = text[:-1]
# Capitalize in "2L" because "l" looks too similar to "1".
if suffix == "l":
suffix = "L"
- text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
+ text = f"{format_float_or_int_string(number)}{suffix}"
else:
- text = format_float_or_int_string(text, allow_underscores)
+ text = format_float_or_int_string(text)
leaf.value = text
-def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
+def format_float_or_int_string(text: str) -> str:
"""Formats a float string like "1.0"."""
if "." not in text:
- return format_int_string(text, allow_underscores)
-
- before, after = text.split(".")
- before = format_int_string(before, allow_underscores) if before else "0"
- if after:
- after = format_int_string(after, allow_underscores, count_from_end=False)
- else:
- after = "0"
- return f"{before}.{after}"
-
-
-def format_int_string(
- text: str, allow_underscores: bool, count_from_end: bool = True
-) -> str:
- """Normalizes underscores in a string to e.g. 1_000_000.
-
- Input must be a string of digits and optional underscores.
- If count_from_end is False, we add underscores after groups of three digits
- counting from the beginning instead of the end of the strings. This is used
- for the fractional part of float literals.
- """
- if not allow_underscores:
- return text
-
- text = text.replace("_", "")
- if len(text) <= 5:
- # No underscores for numbers <= 5 digits long.
return text
- if count_from_end:
- # Avoid removing leading zeros, which are important if we're formatting
- # part of a number like "0.001".
- return format(int("1" + text), "3_")[1:].lstrip("_")
- else:
- return "_".join(text[i : i + 3] for i in range(0, len(text), 3))
+ # Fill an omitted side with 0, e.g. ".5" -> "0.5" and "1." -> "1.0".
+ before, after = text.split(".")
+ return f"{before or 0}.{after or 0}"
def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
check_lpar = False
for index, child in enumerate(list(node.children)):
+ # Add parentheses around long tuple unpacking in assignments.
+ if (
+ index == 0
+ and isinstance(child, Node)
+ and child.type == syms.testlist_star_expr
+ ):
+ check_lpar = True
+
if check_lpar:
if child.type == syms.atom:
- if maybe_make_parens_invisible_in_atom(child):
+ if maybe_make_parens_invisible_in_atom(child, parent=node):
lpar = Leaf(token.LPAR, "")
rpar = Leaf(token.RPAR, "")
index = child.remove() or 0
lpar = Leaf(token.LPAR, "")
rpar = Leaf(token.RPAR, "")
index = child.remove() or 0
- node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+ prefix = child.prefix
+ child.prefix = ""
+ new_child = Node(syms.atom, [lpar, child, rpar])
+ new_child.prefix = prefix
+ node.insert_child(index, new_child)
check_lpar = isinstance(child, Leaf) and child.value in parens_after
container = container.next_sibling
-def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
+def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
"""If it's safe, make the parens in the atom `node` invisible, recursively.
Returns whether the node should itself be wrapped in invisible parentheses.
node.type != syms.atom
or is_empty_tuple(node)
or is_one_tuple(node)
- or is_yield(node)
+ or (is_yield(node) and parent.type != syms.expr_stmt)
or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
):
return False
first.value = "" # type: ignore
last.value = "" # type: ignore
if len(node.children) > 1:
- maybe_make_parens_invisible_in_atom(node.children[1])
+ maybe_make_parens_invisible_in_atom(node.children[1], parent=parent)
return False
return True
and n.children
and n.children[-1].type == token.COMMA
):
+ if n.type == syms.typedargslist:
+ feature = Feature.TRAILING_COMMA_IN_DEF
+ else:
+ feature = Feature.TRAILING_COMMA_IN_CALL
+
for ch in n.children:
if ch.type in STARS:
- features.add(Feature.TRAILING_COMMA)
+ features.add(feature)
if ch.type == syms.argument:
for argch in ch.children:
if argch.type in STARS:
- features.add(Feature.TRAILING_COMMA)
+ features.add(feature)
return features
elif child.type == syms.import_as_names:
yield from get_imports_from_children(child.children)
else:
- assert False, "Invalid syntax parsing imports"
+ raise AssertionError("Invalid syntax parsing imports")
for child in node.children:
if child.type != syms.simple_stmt:
return ", ".join(report) + "."
+def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
+ """Parse `src` into a typed_ast tree.
+
+ Tries the Python 3.7 grammar, then 3.6, and finally falls back to the
+ Python 2.7 parser; an error from that last attempt propagates to the caller.
+ """
+ for feature_version in (7, 6):
+ try:
+ return ast3.parse(src, feature_version=feature_version)
+ except SyntaxError:
+ continue
+
+ return ast27.parse(src)
+
+
def assert_equivalent(src: str, dst: str) -> None:
"""Raise AssertionError if `src` and `dst` aren't equivalent."""
- import ast
import traceback
- def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
+ def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
"""Simple visitor generating strings to compare ASTs by content."""
yield f"{' ' * depth}{node.__class__.__name__}("
for field in sorted(node._fields):
+ # TypeIgnore has only one field 'lineno' which breaks this comparison
+ if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
+ break
+
+ # Ignore str kind which is case sensitive / and ignores unicode_literals
+ if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
+ continue
+
try:
value = getattr(node, field)
except AttributeError:
if isinstance(value, list):
for item in value:
- if isinstance(item, ast.AST):
+ # Ignore nested tuples within del statements, because we may insert
+ # parentheses and they change the AST.
+ if (
+ field == "targets"
+ and isinstance(node, (ast3.Delete, ast27.Delete))
+ and isinstance(item, (ast3.Tuple, ast27.Tuple))
+ ):
+ for item in item.elts:
+ yield from _v(item, depth + 2)
+ elif isinstance(item, (ast3.AST, ast27.AST)):
yield from _v(item, depth + 2)
- elif isinstance(value, ast.AST):
+ elif isinstance(value, (ast3.AST, ast27.AST)):
yield from _v(value, depth + 2)
else:
yield f"{' ' * depth}) # /{node.__class__.__name__}"
try:
- src_ast = ast.parse(src)
+ src_ast = parse_ast(src)
except Exception as exc:
- major, minor = sys.version_info[:2]
raise AssertionError(
- f"cannot use --safe with this file; failed to parse source file "
- f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
- f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
+ f"cannot use --safe with this file; failed to parse source file. "
+ f"AST error message: {exc}"
)
try:
- dst_ast = ast.parse(dst)
+ dst_ast = parse_ast(dst)
except Exception as exc:
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
raise AssertionError(
f"INTERNAL ERROR: Black produced invalid code: {exc}. "
- f"Please report a bug on https://github.com/ambv/black/issues. "
+ f"Please report a bug on https://github.com/python/black/issues. "
f"This invalid output might be helpful: {log}"
) from None
raise AssertionError(
f"INTERNAL ERROR: Black produced code that is not equivalent to "
f"the source. "
- f"Please report a bug on https://github.com/ambv/black/issues. "
+ f"Please report a bug on https://github.com/python/black/issues. "
f"This diff might be helpful: {log}"
) from None
raise AssertionError(
f"INTERNAL ERROR: Black produced different code on the second pass "
f"of the formatter. "
- f"Please report a bug on https://github.com/ambv/black/issues. "
+ f"Please report a bug on https://github.com/python/black/issues. "
f"This diff might be helpful: {log}"
) from None
def shutdown(loop: BaseEventLoop) -> None:
"""Cancel all pending tasks on `loop`, wait for them, and close the loop."""
try:
+ if sys.version_info[:2] >= (3, 7):
+ all_tasks = asyncio.all_tasks
+ else:
+ all_tasks = asyncio.Task.all_tasks
# This part is borrowed from asyncio/runners.py in Python 3.7b2.
- to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
+ to_cancel = [task for task in all_tasks(loop) if not task.done()]
if not to_cancel:
return