From: Jason Fried Date: Wed, 8 May 2019 14:45:59 +0000 (-0400) Subject: Make --safe work for Python2.7 syntax, by using typed_ast for safe validation (#840) X-Git-Url: https://git.madduck.net/etc/vim.git/commitdiff_plain/866be066463fc8fd01c16559596641f6ead1e797 Make --safe work for Python2.7 syntax, by using typed_ast for safe validation (#840) --- diff --git a/Pipfile b/Pipfile index 04531cd..a8ef07a 100644 --- a/Pipfile +++ b/Pipfile @@ -11,6 +11,7 @@ appdirs = "*" toml = ">=0.9.4" black = {path = ".",extras = ["d"],editable = true} aiohttp-cors = "*" +typed-ast = ">=1.3.1" [dev-packages] pre-commit = "*" diff --git a/black.py b/black.py index 1978fd5..9494c4c 100644 --- a/black.py +++ b/black.py @@ -40,6 +40,7 @@ from appdirs import user_cache_dir from attr import dataclass, evolve, Factory import click import toml +from typed_ast import ast3, ast27 # lib2to3 fork from blib2to3.pytree import Node, Leaf, type_repr @@ -3380,17 +3381,31 @@ class Report: return ", ".join(report) + "." +def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]: + try: + return ast3.parse(src) + except SyntaxError: + return ast27.parse(src) + + def assert_equivalent(src: str, dst: str) -> None: """Raise AssertionError if `src` and `dst` aren't equivalent.""" - import ast import traceback - def _v(node: ast.AST, depth: int = 0) -> Iterator[str]: + def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]: """Simple visitor generating strings to compare ASTs by content.""" yield f"{' ' * depth}{node.__class__.__name__}(" for field in sorted(node._fields): + # TypeIgnore has only one field 'lineno' which breaks this comparison + if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)): + break + + # Ignore str kind which is case sensitive / and ignores unicode_literals + if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind": + continue + try: value = getattr(node, field) except AttributeError: @@ -3404,15 +3419,15 @@ def assert_equivalent(src: str, dst: str) -> None: # parentheses and they change the AST. if ( field == "targets" - and isinstance(node, ast.Delete) - and isinstance(item, ast.Tuple) + and isinstance(node, (ast3.Delete, ast27.Delete)) + and isinstance(item, (ast3.Tuple, ast27.Tuple)) ): for item in item.elts: yield from _v(item, depth + 2) - elif isinstance(item, ast.AST): + elif isinstance(item, (ast3.AST, ast27.AST)): yield from _v(item, depth + 2) - elif isinstance(value, ast.AST): + elif isinstance(value, (ast3.AST, ast27.AST)): yield from _v(value, depth + 2) else: @@ -3421,7 +3436,7 @@ def assert_equivalent(src: str, dst: str) -> None: yield f"{' ' * depth}) # /{node.__class__.__name__}" try: - src_ast = ast.parse(src) + src_ast = parse_ast(src) except Exception as exc: major, minor = sys.version_info[:2] raise AssertionError( @@ -3431,7 +3446,7 @@ def assert_equivalent(src: str, dst: str) -> None: ) try: - dst_ast = ast.parse(dst) + dst_ast = parse_ast(dst) except Exception as exc: log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst) raise AssertionError( diff --git a/setup.py b/setup.py index b12d4ac..c4ea969 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,13 @@ setup( package_data={"blib2to3": ["*.txt"]}, python_requires=">=3.6", zip_safe=False, - install_requires=["click>=6.5", "attrs>=18.1.0", "appdirs", "toml>=0.9.4"], + install_requires=[ + "click>=6.5", + "attrs>=18.1.0", + "appdirs", + "toml>=0.9.4", + "typed-ast>=1.3.1", + ], extras_require={"d": ["aiohttp>=3.3.2", "aiohttp-cors"]}, test_suite="tests.test_black", classifiers=[ diff --git a/tests/data/comments6.py b/tests/data/comments6.py index 0a0bf46..ce17382 100644 --- a/tests/data/comments6.py +++ b/tests/data/comments6.py @@ -55,8 +55,8 @@ def f( an_element_with_a_long_value = calls() or more_calls() and more() # type: bool tup = ( - another_element, # type: int - another_really_really_long_element_with_a_unnecessarily_long_name_to_describe_what_it_does_enterprise_style, # type: int + another_element, + another_really_really_long_element_with_a_unnecessarily_long_name_to_describe_what_it_does_enterprise_style, ) # type: Tuple[int, int] a = ( @@ -83,4 +83,4 @@ def func( 0.0456, 0.0789, a[-1], # type: ignore - ) \ No newline at end of file + ) diff --git a/tests/test_black.py b/tests/test_black.py index 53d1750..59343ef 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -2,7 +2,7 @@ import asyncio import logging from concurrent.futures import ThreadPoolExecutor -from contextlib import contextmanager, redirect_stderr +from contextlib import contextmanager from functools import partial, wraps from io import BytesIO, TextIOWrapper import os @@ -474,7 +474,7 @@ class BlackTestCase(unittest.TestCase): source, expected = read_data("python2") actual = fs(source) self.assertFormatEqual(expected, actual) - # black.assert_equivalent(source, actual) + black.assert_equivalent(source, actual) black.assert_stable(source, actual, black.FileMode()) @patch("black.dump_to_file", dump_to_stderr) @@ -483,6 +483,7 @@ class BlackTestCase(unittest.TestCase): mode = black.FileMode(target_versions={TargetVersion.PY27}) actual = fs(source, mode=mode) self.assertFormatEqual(expected, actual) + black.assert_equivalent(source, actual) black.assert_stable(source, actual, mode) @patch("black.dump_to_file", dump_to_stderr) @@ -490,6 +491,7 @@ class BlackTestCase(unittest.TestCase): source, expected = read_data("python2_unicode_literals") actual = fs(source) self.assertFormatEqual(expected, actual) + black.assert_equivalent(source, actual) black.assert_stable(source, actual, black.FileMode()) @patch("black.dump_to_file", dump_to_stderr) @@ -1562,20 +1564,6 @@ class BlackTestCase(unittest.TestCase): await check("3.4", 204) await check("py3.4", 204) - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - @async_test - async def test_blackd_fast(self) -> None: - with open(os.devnull, "w") as dn, redirect_stderr(dn): - app = blackd.make_app() - async with TestClient(TestServer(app)) as client: - response = await client.post("/", data=b"ur'hello'") - self.assertEqual(response.status, 500) - self.assertIn("failed to parse source file", await response.text()) - response = await client.post( - "/", data=b"ur'hello'", headers={blackd.FAST_OR_SAFE_HEADER: "fast"} - ) - self.assertEqual(response.status, 200) - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") @async_test async def test_blackd_line_length(self) -> None: