from attr import dataclass, evolve, Factory
import click
import toml
+from typed_ast import ast3, ast27
# lib2to3 fork
from blib2to3.pytree import Node, Leaf, type_repr
report=report,
)
else:
- loop = asyncio.get_event_loop()
- executor = ProcessPoolExecutor(max_workers=os.cpu_count())
- try:
- loop.run_until_complete(
- schedule_formatting(
- sources=sources,
- fast=fast,
- write_back=write_back,
- mode=mode,
- report=report,
- loop=loop,
- executor=executor,
- )
- )
- finally:
- shutdown(loop)
+ reformat_many(
+ sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
+ )
+
if verbose or not quiet:
bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
out(f"All done! {bang}")
report.failed(src, str(exc))
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: FileMode,
    report: "Report",
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    A dedicated event loop and process pool are created for this call and torn
    down before returning, so the function can safely be called more than once
    in the same process.

    Args:
        sources: paths of the files to reformat.
        fast: forwarded to `schedule_formatting`.
        write_back: forwarded to `schedule_formatting`.
        mode: forwarded to `schedule_formatting`.
        report: forwarded to `schedule_formatting`, which records results on it.
    """
    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to a single worker instead of crashing in min() below.
    worker_count = os.cpu_count() or 1
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        worker_count = min(worker_count, 61)
    # Build the executor before the loop so a failure here cannot leak a loop.
    executor = ProcessPoolExecutor(max_workers=worker_count)
    # Use a fresh loop instead of asyncio.get_event_loop(): reusing the
    # thread's current loop ties this call to whatever state a previous call
    # left behind (NOTE(review): `shutdown` is defined elsewhere and presumably
    # closes the loop, which would make a reused loop unusable on a second call).
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            shutdown(loop)
        finally:
            # Don't leave our (now finished) loop installed as current, and
            # release the worker processes instead of relying on interpreter exit.
            asyncio.set_event_loop(None)
            executor.shutdown()
+
+
async def schedule_formatting(
sources: Set[Path],
fast: bool,
consumed = 0
nlines = 0
+ ignored_lines = 0
for index, line in enumerate(prefix.split("\n")):
consumed += len(line) + 1 # adding the length of the split '\n'
line = line.lstrip()
if not line:
nlines += 1
if not line.startswith("#"):
+ # Escaped newlines outside of a comment are not really newlines at
+ # all. We treat a single-line comment following an escaped newline
+ # as a simple trailing comment.
+ if line.endswith("\\"):
+ ignored_lines += 1
continue
- if index == 0 and not is_endmarker:
+ if index == ignored_lines and not is_endmarker:
comment_type = token.COMMENT # simple trailing comment
else:
comment_type = STANDALONE_COMMENT
return ", ".join(report) + "."
def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
    """Parse `src` into an AST, trying the newest supported grammar first.

    The Python 3 grammar is tried with typed_ast feature versions 3.7 and then
    3.6. If neither accepts the source, the Python 2.7 grammar is used.

    Raises:
        SyntaxError: if no grammar accepts `src`. The raised error is the one
            from the Python 2.7 parser (as before). The first Python 3 parse
            error is attached as its ``__cause__``, so a genuine Python 3
            syntax error is still visible in the traceback instead of being
            hidden behind a Python 2 diagnostic.
    """
    first_py3_error = None
    for feature_version in (7, 6):
        try:
            return ast3.parse(src, feature_version=feature_version)
        except SyntaxError as exc:
            # `exc` is unbound when the except clause ends, so keep a reference.
            if first_py3_error is None:
                first_py3_error = exc

    try:
        return ast27.parse(src)
    except SyntaxError as exc:
        raise exc from first_py3_error
+
+
def assert_equivalent(src: str, dst: str) -> None:
"""Raise AssertionError if `src` and `dst` aren't equivalent."""
- import ast
import traceback
- def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
+ def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
"""Simple visitor generating strings to compare ASTs by content."""
yield f"{' ' * depth}{node.__class__.__name__}("
for field in sorted(node._fields):
+ # TypeIgnore has only one field 'lineno' which breaks this comparison
+ if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
+ break
+
+ # Ignore str kind which is case sensitive / and ignores unicode_literals
+ if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
+ continue
+
try:
value = getattr(node, field)
except AttributeError:
# parentheses and they change the AST.
if (
field == "targets"
- and isinstance(node, ast.Delete)
- and isinstance(item, ast.Tuple)
+ and isinstance(node, (ast3.Delete, ast27.Delete))
+ and isinstance(item, (ast3.Tuple, ast27.Tuple))
):
for item in item.elts:
yield from _v(item, depth + 2)
- elif isinstance(item, ast.AST):
+ elif isinstance(item, (ast3.AST, ast27.AST)):
yield from _v(item, depth + 2)
- elif isinstance(value, ast.AST):
+ elif isinstance(value, (ast3.AST, ast27.AST)):
yield from _v(value, depth + 2)
else:
yield f"{' ' * depth}) # /{node.__class__.__name__}"
try:
- src_ast = ast.parse(src)
+ src_ast = parse_ast(src)
except Exception as exc:
- major, minor = sys.version_info[:2]
raise AssertionError(
- f"cannot use --safe with this file; failed to parse source file "
- f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
- f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
+ f"cannot use --safe with this file; failed to parse source file. "
+ f"AST error message: {exc}"
)
try:
- dst_ast = ast.parse(dst)
+ dst_ast = parse_ast(dst)
except Exception as exc:
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
raise AssertionError(