git.madduck.net Git - etc/vim.git/blob - src/black/parsing.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

ee6aae1e7ff9db26be2f5f4c439a25d31baf2b7d
[etc/vim.git] / src / black / parsing.py
1 """
2 Parse Python code and perform AST validation.
3 """
4 import ast
5 import platform
6 import sys
7 from typing import Iterable, Iterator, List, Set, Union, Tuple
8
9 # lib2to3 fork
10 from blib2to3.pytree import Node, Leaf
11 from blib2to3 import pygram, pytree
12 from blib2to3.pgen2 import driver
13 from blib2to3.pgen2.grammar import Grammar
14 from blib2to3.pgen2.parse import ParseError
15
16 from black.mode import TargetVersion, Feature, supports_feature
17 from black.nodes import syms
18
# True when running on PyPy; several version checks below depend on it.
_IS_PYPY = platform.python_implementation() == "PyPy"

try:
    from typed_ast import ast3, ast27
except ImportError:
    # Without typed_ast we can only continue where the builtin ast module is an
    # acceptable stand-in: CPython 3.8+, or PyPy 3.7+.  Anything older must
    # have typed_ast installed.
    if not (
        sys.version_info >= (3, 8) or (sys.version_info >= (3, 7) and _IS_PYPY)
    ):
        print(
            "The typed_ast package is required but not installed.\n"
            "You can upgrade to Python 3.8+ or install typed_ast with\n"
            "`python3 -m pip install typed-ast`.",
            file=sys.stderr,
        )
        sys.exit(1)
    # Fall back to the builtin module for both the Python 3 and 2.7 parsers.
    ast3 = ast27 = ast
35
36
37 class InvalidInput(ValueError):
38     """Raised when input source code fails all parse attempts."""
39
40
def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
    """Return the blib2to3 grammars to try, in order, for *target_versions*.

    - No target versions: every grammar, the Python 3 ones before Python 2.7.
    - Only Python 2 targets: the two Python 2.7 grammars.
    - Otherwise only Python 3 grammars, chosen by the features the targets
      support: the soft-keyword (pattern matching) grammar first when
      supported, then the async-as-keyword grammar, then the
      async-as-identifier grammar.
    """
    if not target_versions:
        # No target_version specified, so try all grammars.
        return [
            # Python 3.7+
            pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
            # Python 3.0-3.6
            pygram.python_grammar_no_print_statement_no_exec_statement,
            # Python 2.7 with future print_function import
            pygram.python_grammar_no_print_statement,
            # Python 2.7
            pygram.python_grammar,
        ]

    if all(version.is_python2() for version in target_versions):
        # Python 2-only code: print_function variant first, then plain 2.7.
        return [pygram.python_grammar_no_print_statement, pygram.python_grammar]

    supports_match = supports_feature(target_versions, Feature.PATTERN_MATCHING)
    allows_async_names = supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS)
    allows_async_keywords = supports_feature(target_versions, Feature.ASYNC_KEYWORDS)

    # (enabled?, grammar) pairs in the order they should be tried.  When the
    # targets need both async grammars, async-as-keyword is attempted first.
    # Every Python version has exactly one of the two ASYNC_* features, so at
    # least one of the async grammars is always enabled.
    candidates = [
        (supports_match, pygram.python_grammar_soft_keywords),  # 3.10+
        (
            not allows_async_names,
            pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
        ),  # 3.7+
        (
            not allows_async_keywords,
            pygram.python_grammar_no_print_statement_no_exec_statement,
        ),  # 3.0-3.6
    ]
    return [grammar for enabled, grammar in candidates if enabled]
81
82
def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
    """Given a string with source, return the lib2to3 Node.

    The grammars from ``get_grammars(set(target_versions))`` are tried in order
    and the first one that parses *src_txt* wins.  A trailing newline is added
    to *src_txt* if it lacks one, and a bare Leaf result is wrapped in a
    ``file_input`` Node so callers always receive a Node.

    Raises:
        InvalidInput: if no grammar can parse the source.  The message gives
            ``lineno:column`` and the offending source line as reported by the
            last grammar tried.
    """
    if not src_txt.endswith("\n"):
        src_txt += "\n"

    for grammar in get_grammars(set(target_versions)):
        drv = driver.Driver(grammar, pytree.convert)
        try:
            result = drv.parse_string(src_txt, True)
            break

        except ParseError as pe:
            lineno, column = pe.context[1]
            # Split on "\n" only, mirroring how the tokenizer numbers lines
            # (driver.parse_string() feeds it through io.StringIO(...).readline,
            # which ends lines at "\n" only -- see blib2to3/pgen2/driver.py).
            # str.splitlines() would also break on "\x0c", "\v", "\x85",
            # "\u2028", ..., which may legally appear in source (e.g. inside
            # comments or string literals) and would quote the wrong line.
            # The final "\n" is dropped so it does not create a phantom line.
            lines = src_txt[:-1].split("\n")
            if 1 <= lineno <= len(lines):
                # Strip the "\r" that "\r\n" line endings leave behind.
                faulty_line = lines[lineno - 1].rstrip("\r")
            else:
                faulty_line = "<line number missing in source>"
            exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
    else:
        # Every grammar failed: report the last error without chaining the
        # ParseError traceback.
        raise exc from None

    if isinstance(result, Leaf):
        result = Node(syms.file_input, [result])
    return result
108
109
def lib2to3_unparse(node: Node) -> str:
    """Return the source text of a lib2to3 *node*, i.e. ``str(node)``."""
    return str(node)
114
115
def parse_single_version(
    src: str, version: Tuple[int, int]
) -> Union[ast.AST, ast3.AST, ast27.AST]:
    """Parse *src* using the grammar of Python *version*.

    Python 3 versions use the builtin ``ast`` module on Python 3.8+ (which
    accepts ``feature_version``) and ``ast3`` otherwise; ``(2, 7)`` uses
    ``ast27``.

    Raises:
        SyntaxError: propagated from the underlying parser.
        AssertionError: if *version* is neither 3.x nor 2.7.
    """
    filename = "<unknown>"

    if version < (3,):
        if version == (2, 7):
            return ast27.parse(src)
        raise AssertionError("INTERNAL ERROR: Tried parsing unsupported Python version!")

    # typed_ast is needed because of feature version limitations in the builtin
    # ast before 3.8.
    if sys.version_info >= (3, 8):
        return ast.parse(src, filename, feature_version=version)
    if _IS_PYPY:
        return ast3.parse(src, filename)
    return ast3.parse(src, filename, feature_version=version[1])
131
132
def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
    """Parse *src*, trying the newest supported grammar first.

    Python 3 minor versions are tried from the running interpreter's minor
    version down to 3.3, followed by 2.7 when a separate ``ast27`` module
    (typed_ast) is available.  The first successful parse is returned.

    Raises:
        SyntaxError: if every version fails; carries the message of the first
            (newest-grammar) failure.
    """
    # TODO: support Python 4+ ;)
    candidates = [(3, minor) for minor in range(sys.version_info[1], 2, -1)]
    if ast27.__name__ != "ast":
        candidates.append((2, 7))

    message = ""
    for version in candidates:
        try:
            return parse_single_version(src, version)
        except SyntaxError as err:
            # Keep the first non-empty message only.
            message = message or str(err)

    raise SyntaxError(message)
149
150
def stringify_ast(
    node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0
) -> Iterator[str]:
    """Simple visitor generating strings to compare ASTs by content.

    Yields one line per node header, field name, leaf value and node footer,
    indented by *depth*, so two trees can be compared line by line.  Some
    differences that formatting legitimately introduces are normalized away:

    - deprecated constant nodes are first mapped to ``ast.Constant`` (see
      ``fixup_ast_constants``);
    - ``TypeIgnore`` nodes are emitted without their fields;
    - tuples among ``del`` targets are flattened into their elements, because
      parentheses may be inserted there;
    - str/bytes constants have each line stripped and the whole value
      stripped, because docstrings may be re-indented.
    """

    node = fixup_ast_constants(node)

    yield f"{'  ' * depth}{node.__class__.__name__}("

    # TypeIgnore's fields (e.g. 'lineno') break this comparison, so they are
    # skipped.  TypeIgnore will not be present using pypy < 3.8, so there is
    # no need for the check there.  Decided once per node, not once per field.
    skip_fields = False
    if not (_IS_PYPY and sys.version_info < (3, 8)):
        type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
        if sys.version_info >= (3, 8):
            type_ignore_classes += (ast.TypeIgnore,)
        skip_fields = isinstance(node, type_ignore_classes)

    fields = () if skip_fields else sorted(node._fields)

    for field in fields:  # noqa: F402
        try:
            value = getattr(node, field)
        except AttributeError:
            continue

        yield f"{'  ' * (depth+1)}{field}="

        if isinstance(value, list):
            for item in value:
                # Ignore nested tuples within del statements, because we may insert
                # parentheses and they change the AST.
                if (
                    field == "targets"
                    and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
                    and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
                ):
                    for element in item.elts:
                        yield from stringify_ast(element, depth + 2)

                elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
                    yield from stringify_ast(item, depth + 2)

        elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
            yield from stringify_ast(value, depth + 2)

        else:
            # Constant strings may be indented across newlines, if they are
            # docstrings; fold spaces after newlines when comparing. Similarly,
            # trailing and leading space may be removed.
            # Note that when formatting Python 2 code, at least with Windows
            # line-endings, docstrings can end up here as bytes instead of
            # str so make sure that we handle both cases.
            if (
                isinstance(node, ast.Constant)
                and field == "value"
                and isinstance(value, (str, bytes))
            ):
                lineend = "\n" if isinstance(value, str) else b"\n"
                # To normalize, we strip any leading and trailing space from
                # each line...
                stripped = [line.strip() for line in value.splitlines()]
                normalized = lineend.join(stripped)  # type: ignore[attr-defined]
                # ...and remove any blank lines at the beginning and end of
                # the whole string
                normalized = normalized.strip()
            else:
                normalized = value
            yield f"{'  ' * (depth+2)}{normalized!r},  # {value.__class__.__name__}"

    yield f"{'  ' * depth})  # /{node.__class__.__name__}"
220
221
def _legacy_ast_classes(*names: str) -> Tuple[type, ...]:
    """Collect the pre-3.8 constant node classes called *names*.

    Looks them up on ``ast``, ``ast3`` and ``ast27``, skipping any that do not
    exist (``ast.Str`` and friends were removed in Python 3.14, and the 2.7
    AST has no ``Bytes`` or ``NameConstant``).  On Python 3.8+ the stdlib
    versions are only deprecated aliases whose isinstance check inspects an
    ``ast.Constant``'s value, so they are skipped here; ``fixup_ast_constants``
    handles ``ast.Constant`` directly.  Skipping them also avoids the
    DeprecationWarning that accessing them emits from Python 3.12.
    """
    found: List[type] = []
    for module in (ast, ast3, ast27):
        # ast3/ast27 may be the stdlib ast itself when typed_ast is missing.
        if module is ast and sys.version_info >= (3, 8):
            continue
        for name in names:
            cls = getattr(module, name, None)
            if isinstance(cls, type) and cls not in found:
                found.append(cls)
    return tuple(found)


# Looked up once at import time; any of these may be empty.
_LEGACY_STR_CLASSES = _legacy_ast_classes("Str", "Bytes")
_LEGACY_NUM_CLASSES = _legacy_ast_classes("Num")
_LEGACY_NAME_CONSTANT_CLASSES = _legacy_ast_classes("NameConstant")

# Value types for which a stdlib ast.Constant matched one of the deprecated
# Str/Bytes/Num/NameConstant classes (Python 3.8-3.13); Ellipsis is excluded.
_LEGACY_CONSTANT_VALUE_TYPES = (str, bytes, int, float, complex, bool, type(None))


def fixup_ast_constants(
    node: Union[ast.AST, ast3.AST, ast27.AST]
) -> Union[ast.AST, ast3.AST, ast27.AST]:
    """Map ast nodes deprecated in 3.8 to Constant.

    String, bytes, number and name-constant nodes, whether typed_ast
    ``Str``/``Bytes``/``Num``/``NameConstant`` or a stdlib ``ast.Constant``
    holding such a value, are replaced by a fresh ``ast.Constant(value=...)``.
    Rebuilding drops attributes such as ``kind`` (e.g. the ``u`` string
    prefix), so both sides of a comparison look alike.  Any other node is
    returned unchanged.

    Works on interpreters where the deprecated ``ast`` classes no longer
    exist (Python 3.14+).
    """
    if sys.version_info >= (3, 8) and isinstance(node, ast.Constant):
        if isinstance(node.value, _LEGACY_CONSTANT_VALUE_TYPES):
            return ast.Constant(value=node.value)
        return node

    if isinstance(node, _LEGACY_STR_CLASSES):
        return ast.Constant(value=node.s)

    if isinstance(node, _LEGACY_NUM_CLASSES):
        return ast.Constant(value=node.n)

    if isinstance(node, _LEGACY_NAME_CONSTANT_CLASSES):
        return ast.Constant(value=node.value)

    return node