X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/2f3fa1f6d0cbc2a3f31c7440c422da173b068e7b..e0c572833a3e2b42cd45237c26a67c6f5be4b09d:/tests/test_black.py?ds=sidebyside

diff --git a/tests/test_black.py b/tests/test_black.py
index beb56cf1..fd01425a 100644
--- a/tests/test_black.py
+++ b/tests/test_black.py
@@ -31,7 +31,7 @@ from unittest.mock import MagicMock, patch
 
 import click
 import pytest
-import regex as re
+import re
 from click import unstyle
 from click.testing import CliRunner
 from pathspec import PathSpec
@@ -50,6 +50,7 @@ from tests.util import (
     DATA_DIR,
     DEFAULT_MODE,
     DETERMINISTIC_HEADER,
+    PROJECT_ROOT,
     PY36_VERSIONS,
     THIS_DIR,
     BlackBaseTestCase,
@@ -69,7 +70,7 @@ T = TypeVar("T")
 R = TypeVar("R")
 
 # Match the time output in a diff, but nothing else
-DIFF_TIME = re.compile(r"\t[\d-:+\. ]+")
+DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
 
 
 @contextmanager
@@ -99,6 +100,8 @@ class FakeContext(click.Context):
 
     def __init__(self) -> None:
         self.default_map: Dict[str, Any] = {}
+        # Dummy root, since most of the tests don't care about it
+        self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
 
 
 class FakeParameter(click.Parameter):
@@ -121,7 +124,7 @@ def invokeBlack(
     runner = BlackRunner()
     if ignore_config:
         args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
-    result = runner.invoke(black.main, args)
+    result = runner.invoke(black.main, args, catch_exceptions=False)
     assert result.stdout_bytes is not None
     assert result.stderr_bytes is not None
     msg = (
@@ -147,6 +150,11 @@ class BlackTestCase(BlackBaseTestCase):
             os.unlink(tmp_file)
         self.assertFormatEqual(expected, actual)
 
+    def test_experimental_string_processing_warns(self) -> None:
+        self.assertWarns(
+            black.mode.Deprecated, black.Mode, experimental_string_processing=True
+        )
+
     def test_piping(self) -> None:
         source, expected = read_data("src/black/__init__", data=False)
         result = BlackRunner().invoke(
@@ -199,7 +207,7 @@ class BlackTestCase(BlackBaseTestCase):
         )
         actual = result.output
         # Again, the contents are checked in a different test, so only look for colors.
-        self.assertIn("\033[1;37m", actual)
+        self.assertIn("\033[1m", actual)
         self.assertIn("\033[36m", actual)
         self.assertIn("\033[32m", actual)
         self.assertIn("\033[31m", actual)
@@ -322,7 +330,7 @@ class BlackTestCase(BlackBaseTestCase):
         actual = result.output
         # We check the contents of the diff in `test_expression_diff`. All
         # we need to check here is that color codes exist in the result.
-        self.assertIn("\033[1;37m", actual)
+        self.assertIn("\033[1m", actual)
         self.assertIn("\033[36m", actual)
         self.assertIn("\033[32m", actual)
         self.assertIn("\033[31m", actual)
@@ -339,7 +347,7 @@ class BlackTestCase(BlackBaseTestCase):
     @patch("black.dump_to_file", dump_to_stderr)
     def test_string_quotes(self) -> None:
         source, expected = read_data("string_quotes")
-        mode = black.Mode(experimental_string_processing=True)
+        mode = black.Mode(preview=True)
         assert_format(source, expected, mode)
         mode = replace(mode, string_normalization=False)
         not_normalized = fs(source, mode=mode)
@@ -723,24 +731,15 @@ class BlackTestCase(BlackBaseTestCase):
 
         straddling = "x + y"
         black.lib2to3_parse(straddling)
-        black.lib2to3_parse(straddling, {TargetVersion.PY27})
         black.lib2to3_parse(straddling, {TargetVersion.PY36})
-        black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
 
         py2_only = "print x"
-        black.lib2to3_parse(py2_only)
-        black.lib2to3_parse(py2_only, {TargetVersion.PY27})
         with self.assertRaises(black.InvalidInput):
             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
-        with self.assertRaises(black.InvalidInput):
-            black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
 
         py3_only = "exec(x, end=y)"
         black.lib2to3_parse(py3_only)
-        with self.assertRaises(black.InvalidInput):
-            black.lib2to3_parse(py3_only, {TargetVersion.PY27})
         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
-        black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
 
     def test_get_features_used_decorator(self) -> None:
         # Test the feature detection of new decorator syntax
@@ -809,6 +808,44 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
         node = black.lib2to3_parse("def fn(a, /, b): ...")
         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
+        node = black.lib2to3_parse("def fn(): yield a, b")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("def fn(): return a, b")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("def fn(): yield *b, c")
+        self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
+        node = black.lib2to3_parse("def fn(): return a, *b, c")
+        self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
+        node = black.lib2to3_parse("x = a, *b, c")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("x: Any = regular")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("x: Any = (regular, regular)")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
+        self.assertEqual(
+            black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
+        )
+
+    def test_get_features_used_for_future_flags(self) -> None:
+        for src, features in [
+            ("from __future__ import annotations", {Feature.FUTURE_ANNOTATIONS}),
+            (
+                "from __future__ import (other, annotations)",
+                {Feature.FUTURE_ANNOTATIONS},
+            ),
+            ("a = 1 + 2\nfrom something import annotations", set()),
+            ("from __future__ import x, y", set()),
+        ]:
+            with self.subTest(src=src, features=features):
+                node = black.lib2to3_parse(src)
+                future_imports = black.get_future_imports(node)
+                self.assertEqual(
+                    black.get_features_used(node, future_imports=future_imports),
+                    features,
+                )
 
     def test_get_future_imports(self) -> None:
         node = black.lib2to3_parse("\n")
@@ -840,6 +877,7 @@ class BlackTestCase(BlackBaseTestCase):
         )
         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
 
+    @pytest.mark.incompatible_with_mypyc
     def test_debug_visitor(self) -> None:
         source, _ = read_data("debug_visitor.py")
         expected, _ = read_data("debug_visitor.out")
@@ -890,6 +928,7 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertEqual(len(n.children), 1)
         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
 
+    @pytest.mark.incompatible_with_mypyc
     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
     def test_assertFormatEqual(self) -> None:
         out_lines = []
@@ -906,8 +945,8 @@ class BlackTestCase(BlackBaseTestCase):
                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
 
         out_str = "".join(out_lines)
-        self.assertTrue("Expected tree:" in out_str)
-        self.assertTrue("Actual tree:" in out_str)
+        self.assertIn("Expected tree:", out_str)
+        self.assertIn("Actual tree:", out_str)
         self.assertEqual("".join(err_lines), "")
 
     @event_loop()
@@ -943,7 +982,7 @@ class BlackTestCase(BlackBaseTestCase):
             symlink = workspace / "broken_link.py"
             try:
                 symlink.symlink_to("nonexistent.py")
-            except OSError as e:
+            except (OSError, NotImplementedError) as e:
                 self.skipTest(f"Can't create symlinks: {e}")
             self.invokeBlack([str(workspace.resolve())])
 
@@ -1054,6 +1093,7 @@ class BlackTestCase(BlackBaseTestCase):
         actual = result.output
         self.assertFormatEqual(actual, expected)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1071,6 +1111,7 @@ class BlackTestCase(BlackBaseTestCase):
             fsts.assert_called_once()
             report.done.assert_called_with(path, black.Changed.YES)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin_filename(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1093,6 +1134,7 @@ class BlackTestCase(BlackBaseTestCase):
             # __BLACK_STDIN_FILENAME__ should have been stripped
             report.done.assert_called_with(expected, black.Changed.YES)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin_filename_pyi(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1117,6 +1159,7 @@ class BlackTestCase(BlackBaseTestCase):
             # __BLACK_STDIN_FILENAME__ should have been stripped
             report.done.assert_called_with(expected, black.Changed.YES)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1141,6 +1184,7 @@ class BlackTestCase(BlackBaseTestCase):
             # __BLACK_STDIN_FILENAME__ should have been stripped
             report.done.assert_called_with(expected, black.Changed.YES)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1278,6 +1322,7 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertEqual(config["color"], True)
         self.assertEqual(config["line_length"], 79)
         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
+        self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
         self.assertEqual(config["exclude"], r"\.pyi?$")
         self.assertEqual(config["include"], r"\.py?$")
 
@@ -1295,6 +1340,7 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertEqual(config["exclude"], r"\.pyi?$")
         self.assertEqual(config["include"], r"\.py?$")
 
+    @pytest.mark.incompatible_with_mypyc
     def test_find_project_root(self) -> None:
         with TemporaryDirectory() as workspace:
             root = Path(workspace)
@@ -1312,10 +1358,17 @@ class BlackTestCase(BlackBaseTestCase):
             src_python.touch()
 
             self.assertEqual(
-                black.find_project_root((src_dir, test_dir)), root.resolve()
+                black.find_project_root((src_dir, test_dir)),
+                (root.resolve(), "pyproject.toml"),
+            )
+            self.assertEqual(
+                black.find_project_root((src_dir,)),
+                (src_dir.resolve(), "pyproject.toml"),
+            )
+            self.assertEqual(
+                black.find_project_root((src_python,)),
+                (src_dir.resolve(), "pyproject.toml"),
             )
-            self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
-            self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
 
     @patch(
         "black.files.find_user_pyproject_toml",
@@ -1389,27 +1442,6 @@ class BlackTestCase(BlackBaseTestCase):
         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
         self.assertEqual(actual, expected)
 
-    @pytest.mark.python2
-    def test_docstring_reformat_for_py27(self) -> None:
-        """
-        Check that stripping trailing whitespace from Python 2 docstrings
-        doesn't trigger a "not equivalent to source" error
-        """
-        source = (
-            b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
-        )
-        expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
-
-        result = CliRunner().invoke(
-            black.main,
-            ["-", "-q", "--target-version=py27"],
-            input=BytesIO(source),
-        )
-
-        self.assertEqual(result.exit_code, 0)
-        actual = result.output
-        self.assertFormatEqual(actual, expected)
-
     @staticmethod
     def compare_results(
         result: click.testing.Result, expected_value: str, expected_exit_code: int
@@ -1482,6 +1514,7 @@ class BlackTestCase(BlackBaseTestCase):
         assert output == result_diff, "The output did not match the expected value."
         assert result.exit_code == 0, "The exit code is incorrect."
 
+    @pytest.mark.incompatible_with_mypyc
     def test_code_option_safe(self) -> None:
         """Test that the code option throws an error when the sanity checks fail."""
         # Patch black.assert_equivalent to ensure the sanity checks fail
@@ -1506,15 +1539,18 @@ class BlackTestCase(BlackBaseTestCase):
 
             self.compare_results(result, formatted, 0)
 
+    @pytest.mark.incompatible_with_mypyc
     def test_code_option_config(self) -> None:
         """
         Test that the code option finds the pyproject.toml in the current directory.
         """
         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
             args = ["--code", "print"]
-            CliRunner().invoke(black.main, args)
+            # This is the only directory known to contain a pyproject.toml
+            with change_directory(PROJECT_ROOT):
+                CliRunner().invoke(black.main, args)
+                pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
 
-            pyproject_path = Path(Path().cwd(), "pyproject.toml").resolve()
             assert (
                 len(parse.mock_calls) >= 1
             ), "Expected config parse to be called with the current directory."
@@ -1524,12 +1560,13 @@ class BlackTestCase(BlackBaseTestCase):
                 call_args[0].lower() == str(pyproject_path).lower()
             ), "Incorrect config loaded."
 
+    @pytest.mark.incompatible_with_mypyc
     def test_code_option_parent_config(self) -> None:
         """
         Test that the code option finds the pyproject.toml in the parent directory.
         """
         with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
-            with change_directory(Path("tests")):
+            with change_directory(THIS_DIR):
                 args = ["--code", "print"]
                 CliRunner().invoke(black.main, args)
 
@@ -1543,6 +1580,25 @@ class BlackTestCase(BlackBaseTestCase):
                     call_args[0].lower() == str(pyproject_path).lower()
                 ), "Incorrect config loaded."
 
+    def test_for_handled_unexpected_eof_error(self) -> None:
+        """
+        Test that an unexpected EOF SyntaxError is nicely presented.
+        """
+        with pytest.raises(black.parsing.InvalidInput) as exc_info:
+            black.lib2to3_parse("print(", {})
+
+        exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
+
+    def test_equivalency_ast_parse_failure_includes_error(self) -> None:
+        with pytest.raises(AssertionError) as err:
+            black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
+
+        err.match("--safe")
+        # Unfortunately the SyntaxError message has changed in newer versions so we
+        # can't match it directly.
+        err.match("invalid character")
+        err.match(r"\(<unknown>, line 1\)")
+
 
 class TestCaching:
     def test_cache_broken_file(self) -> None:
@@ -1715,6 +1771,7 @@ def assert_collected_sources(
     src: Sequence[Union[str, Path]],
     expected: Sequence[Union[str, Path]],
     *,
+    ctx: Optional[FakeContext] = None,
     exclude: Optional[str] = None,
     include: Optional[str] = None,
     extend_exclude: Optional[str] = None,
@@ -1730,7 +1787,7 @@ def assert_collected_sources(
     )
     gs_force_exclude = None if force_exclude is None else compile_pattern(force_exclude)
     collected = black.get_sources(
-        ctx=FakeContext(),
+        ctx=ctx or FakeContext(),
         src=gs_src,
         quiet=False,
         verbose=False,
@@ -1741,7 +1798,7 @@ def assert_collected_sources(
         report=black.Report(),
         stdin_filename=stdin_filename,
     )
-    assert sorted(list(collected)) == sorted(gs_expected)
+    assert sorted(collected) == sorted(gs_expected)
 
 
 class TestFileCollection:
@@ -1766,9 +1823,11 @@ class TestFileCollection:
             base / "b/.definitely_exclude/a.pyi",
         ]
         src = [base / "b/"]
-        assert_collected_sources(src, expected, extend_exclude=r"/exclude/")
+        ctx = FakeContext()
+        ctx.obj["root"] = base
+        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_exclude_for_issue_1572(self) -> None:
         # Exclude shouldn't touch files that were explicitly given to Black through the
         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
@@ -1891,6 +1950,7 @@ class TestFileCollection:
             src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
         )
 
+    @pytest.mark.incompatible_with_mypyc
     def test_symlink_out_of_root_directory(self) -> None:
         path = MagicMock()
         root = THIS_DIR.resolve()
@@ -1950,13 +2010,13 @@ class TestFileCollection:
         child.is_symlink.assert_called()
         assert child.is_symlink.call_count == 2
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin(self) -> None:
         src = ["-"]
         expected = ["-"]
         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin_filename(self) -> None:
         src = ["-"]
         stdin_filename = str(THIS_DIR / "data/collections.py")
@@ -1968,7 +2028,7 @@ class TestFileCollection:
             stdin_filename=stdin_filename,
         )
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
         # Exclude shouldn't exclude stdin_filename since it is mimicking the
         # file being passed directly. This is the same as
@@ -1984,7 +2044,7 @@ class TestFileCollection:
             stdin_filename=stdin_filename,
         )
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
         # file being passed directly. This is the same as
@@ -2000,7 +2060,7 @@ class TestFileCollection:
             stdin_filename=stdin_filename,
         )
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
         # Force exclude should exclude the file when passing it through
         # stdin_filename
@@ -2014,11 +2074,17 @@ class TestFileCollection:
         )
 
 
-with open(black.__file__, "r", encoding="utf-8") as _bf:
-    black_source_lines = _bf.readlines()
+try:
+    with open(black.__file__, "r", encoding="utf-8") as _bf:
+        black_source_lines = _bf.readlines()
+except UnicodeDecodeError:
+    if not black.COMPILED:
+        raise
 
 
-def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
+def tracefunc(
+    frame: types.FrameType, event: str, arg: Any
+) -> Callable[[types.FrameType, str, Any], Any]:
     """Show function calls `from black/__init__.py` as they happen.
 
     Register this with `sys.settrace()` in a test you're debugging.