From ba64fc757c12e59fb35f2306eb4fa75fdc566647 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sat, 16 Mar 2019 11:35:18 -0700 Subject: [PATCH] redo grammar selection, add test (#765) --- black.py | 18 +++++++----------- tests/test_black.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/black.py b/black.py index 680b1f4..2dee826 100644 --- a/black.py +++ b/black.py @@ -715,24 +715,20 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]: return tiow.read(), encoding, newline -GRAMMARS = [ - pygram.python_grammar_no_print_statement_no_exec_statement, - pygram.python_grammar_no_print_statement, - pygram.python_grammar, -] - - def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]: if not target_versions: - return GRAMMARS - elif all(not version.is_python2() for version in target_versions): - # Python 3-compatible code, so don't try Python 2 grammar + # No target_version specified, so try all grammars. return [ pygram.python_grammar_no_print_statement_no_exec_statement, pygram.python_grammar_no_print_statement, + pygram.python_grammar, ] - else: + elif all(version.is_python2() for version in target_versions): + # Python 2-only code, so try Python 2 grammars. return [pygram.python_grammar_no_print_statement, pygram.python_grammar] + else: + # Python 3-compatible code, so only try Python 3 grammar. + return [pygram.python_grammar_no_print_statement_no_exec_statement] def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node: diff --git a/tests/test_black.py b/tests/test_black.py index 645eec7..a3e2ff8 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -28,7 +28,7 @@ from click import unstyle from click.testing import CliRunner import black -from black import Feature +from black import Feature, TargetVersion try: import blackd @@ -464,7 +464,7 @@ class BlackTestCase(unittest.TestCase): @patch("black.dump_to_file", dump_to_stderr) def test_python2_print_function(self) -> None: source, expected = read_data("python2_print_function") - mode = black.FileMode(target_versions={black.TargetVersion.PY27}) + mode = black.FileMode(target_versions={TargetVersion.PY27}) actual = fs(source, mode=mode) self.assertFormatEqual(expected, actual) black.assert_stable(source, actual, mode) @@ -828,6 +828,31 @@ class BlackTestCase(unittest.TestCase): "2 files would fail to reformat.", ) + def test_lib2to3_parse(self) -> None: + with self.assertRaises(black.InvalidInput): + black.lib2to3_parse("invalid syntax") + + straddling = "x + y" + black.lib2to3_parse(straddling) + black.lib2to3_parse(straddling, {TargetVersion.PY27}) + black.lib2to3_parse(straddling, {TargetVersion.PY36}) + black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36}) + + py2_only = "print x" + black.lib2to3_parse(py2_only) + black.lib2to3_parse(py2_only, {TargetVersion.PY27}) + with self.assertRaises(black.InvalidInput): + black.lib2to3_parse(py2_only, {TargetVersion.PY36}) + with self.assertRaises(black.InvalidInput): + black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36}) + + py3_only = "exec(x, end=y)" + black.lib2to3_parse(py3_only) + with self.assertRaises(black.InvalidInput): + black.lib2to3_parse(py3_only, {TargetVersion.PY27}) + black.lib2to3_parse(py3_only, {TargetVersion.PY36}) + black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36}) + def test_get_features_used(self) -> None: node = black.lib2to3_parse("def f(*, arg): ...\n") self.assertEqual(black.get_features_used(node), set()) -- 2.39.2