git.madduck.net Git - etc/vim.git/blobdiff - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Include underlying error when AST safety check parsing fails (#2693)
[etc/vim.git] / tests / test_black.py
index 1fc63c942e9089aed7e34320d8f23a94910c5ba2..63cd716c0bb80aeea9cf31c274ead6b1b930d4e6 100644 (file)
@@ -31,7 +31,7 @@ from unittest.mock import MagicMock, patch
 
 import click
 import pytest
-import regex as re
+import re
 from click import unstyle
 from click.testing import CliRunner
 from pathspec import PathSpec
@@ -50,6 +50,7 @@ from tests.util import (
     DATA_DIR,
     DEFAULT_MODE,
     DETERMINISTIC_HEADER,
+    PROJECT_ROOT,
     PY36_VERSIONS,
     THIS_DIR,
     BlackBaseTestCase,
@@ -69,7 +70,7 @@ T = TypeVar("T")
 R = TypeVar("R")
 
 # Match the time output in a diff, but nothing else
-DIFF_TIME = re.compile(r"\t[\d-:+\. ]+")
+DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
 
 
 @contextmanager
@@ -121,7 +122,7 @@ def invokeBlack(
     runner = BlackRunner()
     if ignore_config:
         args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
-    result = runner.invoke(black.main, args)
+    result = runner.invoke(black.main, args, catch_exceptions=False)
     assert result.stdout_bytes is not None
     assert result.stderr_bytes is not None
     msg = (
@@ -199,7 +200,7 @@ class BlackTestCase(BlackBaseTestCase):
         )
         actual = result.output
         # Again, the contents are checked in a different test, so only look for colors.
-        self.assertIn("\033[1;37m", actual)
+        self.assertIn("\033[1m", actual)
         self.assertIn("\033[36m", actual)
         self.assertIn("\033[32m", actual)
         self.assertIn("\033[31m", actual)
@@ -322,7 +323,7 @@ class BlackTestCase(BlackBaseTestCase):
         actual = result.output
         # We check the contents of the diff in `test_expression_diff`. All
         # we need to check here is that color codes exist in the result.
-        self.assertIn("\033[1;37m", actual)
+        self.assertIn("\033[1m", actual)
         self.assertIn("\033[36m", actual)
         self.assertIn("\033[32m", actual)
         self.assertIn("\033[31m", actual)
@@ -810,6 +811,24 @@ class BlackTestCase(BlackBaseTestCase):
         node = black.lib2to3_parse("def fn(a, /, b): ...")
         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
 
+    def test_get_features_used_for_future_flags(self) -> None:
+        for src, features in [
+            ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
+            (
+                "from __future__ import (other, annotations)",
+                {Feature.FUTURE_ANNOTATIONS},
+            ),
+            ("a = 1 + 2\nfrom something import annotations", set()),
+            ("from __future__ import x, y", set()),
+        ]:
+            with self.subTest(src=src, features=features):
+                node = black.lib2to3_parse(src)
+                future_imports = black.get_future_imports(node)
+                self.assertEqual(
+                    black.get_features_used(node, future_imports=future_imports),
+                    features,
+                )
+
     def test_get_future_imports(self) -> None:
         node = black.lib2to3_parse("\n")
         self.assertEqual(set(), black.get_future_imports(node))
@@ -840,6 +859,7 @@ class BlackTestCase(BlackBaseTestCase):
         )
         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
 
+    @pytest.mark.incompatible_with_mypyc
     def test_debug_visitor(self) -> None:
         source, _ = read_data("debug_visitor.py")
         expected, _ = read_data("debug_visitor.out")
@@ -890,6 +910,7 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertEqual(len(n.children), 1)
         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
 
+    @pytest.mark.incompatible_with_mypyc
     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
     def test_assertFormatEqual(self) -> None:
         out_lines = []
@@ -943,7 +964,7 @@ class BlackTestCase(BlackBaseTestCase):
             symlink = workspace / "broken_link.py"
             try:
                 symlink.symlink_to("nonexistent.py")
-            except OSError as e:
+            except (OSError, NotImplementedError) as e:
                 self.skipTest(f"Can't create symlinks: {e}")
             self.invokeBlack([str(workspace.resolve())])
 
@@ -1054,6 +1075,7 @@ class BlackTestCase(BlackBaseTestCase):
         actual = result.output
         self.assertFormatEqual(actual, expected)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1071,6 +1093,7 @@ class BlackTestCase(BlackBaseTestCase):
             fsts.assert_called_once()
             report.done.assert_called_with(path, black.Changed.YES)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin_filename(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1093,6 +1116,7 @@ class BlackTestCase(BlackBaseTestCase):
             # __BLACK_STDIN_FILENAME__ should have been stripped
             report.done.assert_called_with(expected, black.Changed.YES)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1117,6 +1141,7 @@ class BlackTestCase(BlackBaseTestCase):
             # __BLACK_STDIN_FILENAME__ should have been stripped
             report.done.assert_called_with(expected, black.Changed.YES)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1141,6 +1166,7 @@ class BlackTestCase(BlackBaseTestCase):
             # __BLACK_STDIN_FILENAME__ should have been stripped
             report.done.assert_called_with(expected, black.Changed.YES)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1295,6 +1321,7 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertEqual(config["exclude"], r"\.pyi?$")
         self.assertEqual(config["include"], r"\.py?$")
 
+    @pytest.mark.incompatible_with_mypyc
     def test_find_project_root(self) -> None:
         with TemporaryDirectory() as workspace:
             root = Path(workspace)
@@ -1400,14 +1427,14 @@ class BlackTestCase(BlackBaseTestCase):
         )
         expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
 
-        result = CliRunner().invoke(
+        result = BlackRunner().invoke(
             black.main,
             ["-", "-q", "--target-version=py27"],
             input=BytesIO(source),
         )
 
         self.assertEqual(result.exit_code, 0)
-        actual = result.output
+        actual = result.stdout
         self.assertFormatEqual(actual, expected)
 
     @staticmethod
@@ -1482,6 +1509,7 @@ class BlackTestCase(BlackBaseTestCase):
         assert output == result_diff, "The output did not match the expected value."
         assert result.exit_code == 0, "The exit code is incorrect."
 
+    @pytest.mark.incompatible_with_mypyc
     def test_code_option_safe(self) -> None:
         """Test that the code option throws an error when the sanity checks fail."""
         # Patch black.assert_equivalent to ensure the sanity checks fail
@@ -1506,15 +1534,18 @@ class BlackTestCase(BlackBaseTestCase):
 
             self.compare_results(result, formatted, 0)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_code_option_config(self) -> None:
         """
         Test that the code option finds the pyproject.toml in the current directory.
         """
         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
             args = ["--code", "print"]
-            CliRunner().invoke(black.main, args)
+            # This is the only directory known to contain a pyproject.toml
+            with change_directory(PROJECT_ROOT):
+                CliRunner().invoke(black.main, args)
+                pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
 
-            pyproject_path = Path(Path().cwd(), "pyproject.toml").resolve()
             assert (
                 len(parse.mock_calls) >= 1
             ), "Expected config parse to be called with the current directory."
@@ -1524,12 +1555,13 @@ class BlackTestCase(BlackBaseTestCase):
                 call_args[0].lower() == str(pyproject_path).lower()
             ), "Incorrect config loaded."
 
+    @pytest.mark.incompatible_with_mypyc
     def test_code_option_parent_config(self) -> None:
         """
         Test that the code option finds the pyproject.toml in the parent directory.
         """
         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
-            with change_directory(Path("tests")):
+            with change_directory(THIS_DIR):
                 args = ["--code", "print"]
                 CliRunner().invoke(black.main, args)
 
@@ -1543,6 +1575,25 @@ class BlackTestCase(BlackBaseTestCase):
                     call_args[0].lower() == str(pyproject_path).lower()
                 ), "Incorrect config loaded."
 
+    def test_for_handled_unexpected_eof_error(self) -> None:
+        """
+        Test that an unexpected EOF SyntaxError is nicely presented.
+        """
+        with pytest.raises(black.parsing.InvalidInput) as exc_info:
+            black.lib2to3_parse("print(", {})
+
+        exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
+
+    def test_equivalency_ast_parse_failure_includes_error(self) -> None:
+        with pytest.raises(AssertionError) as err:
+            black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
+
+        err.match("--safe")
+        # Unfortunately the SyntaxError message has changed in newer versions so we
+        # can't match it directly.
+        err.match("invalid character")
+        err.match(r"\(<unknown>, line 1\)")
+
 
 class TestCaching:
     def test_cache_broken_file(self) -> None:
@@ -1741,7 +1792,7 @@ def assert_collected_sources(
         report=black.Report(),
         stdin_filename=stdin_filename,
     )
-    assert sorted(list(collected)) == sorted(gs_expected)
+    assert sorted(collected) == sorted(gs_expected)
 
 
 class TestFileCollection:
@@ -1891,6 +1942,7 @@ class TestFileCollection:
             src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
         )
 
+    @pytest.mark.incompatible_with_mypyc
     def test_symlink_out_of_root_directory(self) -> None:
         path = MagicMock()
         root = THIS_DIR.resolve()
@@ -2014,8 +2066,42 @@ class TestFileCollection:
         )
 
 
-with open(black.__file__, "r", encoding="utf-8") as _bf:
-    black_source_lines = _bf.readlines()
+@pytest.mark.python2
+@pytest.mark.parametrize("explicit", [True, False], ids=["explicit", "autodetection"])
+def test_python_2_deprecation_with_target_version(explicit: bool) -> None:
+    args = [
+        "--config",
+        str(THIS_DIR / "empty.toml"),
+        str(DATA_DIR / "python2.py"),
+        "--check",
+    ]
+    if explicit:
+        args.append("--target-version=py27")
+    with cache_dir():
+        result = BlackRunner().invoke(black.main, args)
+    assert "DEPRECATION: Python 2 support will be removed" in result.stderr
+
+
+@pytest.mark.python2
+def test_python_2_deprecation_autodetection_extended() -> None:
+    # this test has a similar construction to test_get_features_used_decorator
+    python2, non_python2 = read_data("python2_detection")
+    for python2_case in python2.split("###"):
+        node = black.lib2to3_parse(python2_case)
+        assert black.detect_target_versions(node) == {TargetVersion.PY27}, python2_case
+    for non_python2_case in non_python2.split("###"):
+        node = black.lib2to3_parse(non_python2_case)
+        assert black.detect_target_versions(node) != {
+            TargetVersion.PY27
+        }, non_python2_case
+
+
+try:
+    with open(black.__file__, "r", encoding="utf-8") as _bf:
+        black_source_lines = _bf.readlines()
+except UnicodeDecodeError:
+    if not black.COMPILED:
+        raise
 
 
 def tracefunc(