+++ /dev/null
-"""
-Parse Python code and perform AST validation.
-"""
-import ast
-import sys
-from typing import Iterable, Iterator, List, Set, Union, Tuple
-
-# lib2to3 fork
-from blib2to3.pytree import Node, Leaf
-from blib2to3 import pygram, pytree
-from blib2to3.pgen2 import driver
-from blib2to3.pgen2.grammar import Grammar
-from blib2to3.pgen2.parse import ParseError
-
-from black.mode import TargetVersion, Feature, supports_feature
-from black.nodes import syms
-
-try:
- from typed_ast import ast3, ast27
-except ImportError:
- if sys.version_info < (3, 8):
- print(
- "The typed_ast package is required but not installed.\n"
- "You can upgrade to Python 3.8+ or install typed_ast with\n"
- "`python3 -m pip install typed-ast`.",
- file=sys.stderr,
- )
- sys.exit(1)
- else:
- ast3 = ast27 = ast
-
-
-class InvalidInput(ValueError):
- """Raised when input source code fails all parse attempts."""
-
-
-def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
- if not target_versions:
- # No target_version specified, so try all grammars.
- return [
- # Python 3.7+
- pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
- # Python 3.0-3.6
- pygram.python_grammar_no_print_statement_no_exec_statement,
- # Python 2.7 with future print_function import
- pygram.python_grammar_no_print_statement,
- # Python 2.7
- pygram.python_grammar,
- ]
-
- if all(version.is_python2() for version in target_versions):
- # Python 2-only code, so try Python 2 grammars.
- return [
- # Python 2.7 with future print_function import
- pygram.python_grammar_no_print_statement,
- # Python 2.7
- pygram.python_grammar,
- ]
-
- # Python 3-compatible code, so only try Python 3 grammar.
- grammars = []
- # If we have to parse both, try to parse async as a keyword first
- if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
- # Python 3.7+
- grammars.append(
- pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords
- )
- if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
- # Python 3.0-3.6
- grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
- # At least one of the above branches must have been taken, because every Python
- # version has exactly one of the two 'ASYNC_*' flags
- return grammars
-
-
-def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
- """Given a string with source, return the lib2to3 Node."""
- if not src_txt.endswith("\n"):
- src_txt += "\n"
-
- for grammar in get_grammars(set(target_versions)):
- drv = driver.Driver(grammar, pytree.convert)
- try:
- result = drv.parse_string(src_txt, True)
- break
-
- except ParseError as pe:
- lineno, column = pe.context[1]
- lines = src_txt.splitlines()
- try:
- faulty_line = lines[lineno - 1]
- except IndexError:
- faulty_line = "<line number missing in source>"
- exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
- else:
- raise exc from None
-
- if isinstance(result, Leaf):
- result = Node(syms.file_input, [result])
- return result
-
-
-def lib2to3_unparse(node: Node) -> str:
- """Given a lib2to3 node, return its string representation."""
- code = str(node)
- return code
-
-
-def parse_single_version(
- src: str, version: Tuple[int, int]
-) -> Union[ast.AST, ast3.AST, ast27.AST]:
- filename = "<unknown>"
- # typed_ast is needed because of feature version limitations in the builtin ast
- if sys.version_info >= (3, 8) and version >= (3,):
- return ast.parse(src, filename, feature_version=version)
- elif version >= (3,):
- return ast3.parse(src, filename, feature_version=version[1])
- elif version == (2, 7):
- return ast27.parse(src)
- raise AssertionError("INTERNAL ERROR: Tried parsing unsupported Python version!")
-
-
-def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
- # TODO: support Python 4+ ;)
- versions = [(3, minor) for minor in range(3, sys.version_info[1] + 1)]
-
- if ast27.__name__ != "ast":
- versions.append((2, 7))
-
- first_error = ""
- for version in sorted(versions, reverse=True):
- try:
- return parse_single_version(src, version)
- except SyntaxError as e:
- if not first_error:
- first_error = str(e)
-
- raise SyntaxError(first_error)
-
-
-def stringify_ast(
- node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
-) -> Iterator[str]:
- """Simple visitor generating strings to compare ASTs by content."""
-
- node = fixup_ast_constants(node)
-
- yield f"{' ' * depth}{node.__class__.__name__}("
-
- for field in sorted(node._fields): # noqa: F402
- # TypeIgnore has only one field 'lineno' which breaks this comparison
- type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
- if sys.version_info >= (3, 8):
- type_ignore_classes += (ast.TypeIgnore,)
- if isinstance(node, type_ignore_classes):
- break
-
- try:
- value = getattr(node, field)
- except AttributeError:
- continue
-
- yield f"{' ' * (depth+1)}{field}="
-
- if isinstance(value, list):
- for item in value:
- # Ignore nested tuples within del statements, because we may insert
- # parentheses and they change the AST.
- if (
- field == "targets"
- and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
- and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
- ):
- for item in item.elts:
- yield from stringify_ast(item, depth + 2)
-
- elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
- yield from stringify_ast(item, depth + 2)
-
- elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
- yield from stringify_ast(value, depth + 2)
-
- else:
- # Constant strings may be indented across newlines, if they are
- # docstrings; fold spaces after newlines when comparing. Similarly,
- # trailing and leading space may be removed.
- # Note that when formatting Python 2 code, at least with Windows
- # line-endings, docstrings can end up here as bytes instead of
- # str so make sure that we handle both cases.
- if (
- isinstance(node, ast.Constant)
- and field == "value"
- and isinstance(value, (str, bytes))
- ):
- lineend = "\n" if isinstance(value, str) else b"\n"
- # To normalize, we strip any leading and trailing space from
- # each line...
- stripped = [line.strip() for line in value.splitlines()]
- normalized = lineend.join(stripped) # type: ignore[attr-defined]
- # ...and remove any blank lines at the beginning and end of
- # the whole string
- normalized = normalized.strip()
- else:
- normalized = value
- yield f"{' ' * (depth+2)}{normalized!r}, # {value.__class__.__name__}"
-
- yield f"{' ' * depth}) # /{node.__class__.__name__}"
-
-
-def fixup_ast_constants(
- node: Union[ast.AST, ast3.AST, ast27.AST]
-) -> Union[ast.AST, ast3.AST, ast27.AST]:
- """Map ast nodes deprecated in 3.8 to Constant."""
- if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
- return ast.Constant(value=node.s)
-
- if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
- return ast.Constant(value=node.n)
-
- if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
- return ast.Constant(value=node.value)
-
- return node