X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/4ea75cd49521ed7fd8384e7a739e1abb1b6de46a..fda2561f79e10826dbdeb900b6124d642766229f:/tests/test_black.py

diff --git a/tests/test_black.py b/tests/test_black.py
index fd01425..2dd284f 100644
--- a/tests/test_black.py
+++ b/tests/test_black.py
@@ -40,7 +40,7 @@ import black
 import black.files
 from black import Feature, TargetVersion
 from black import re_compile_maybe_verbose as compile_pattern
-from black.cache import get_cache_file
+from black.cache import get_cache_dir, get_cache_file
 from black.debug import DebugVisitor
 from black.output import color_diff, diff
 from black.report import Report
@@ -228,45 +228,6 @@ class BlackTestCase(BlackBaseTestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, black.FileMode())
 
-    @unittest.expectedFailure
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_trailing_comma_optional_parens_stability1(self) -> None:
-        source, _expected = read_data("trailing_comma_optional_parens1")
-        actual = fs(source)
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
-    @unittest.expectedFailure
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_trailing_comma_optional_parens_stability2(self) -> None:
-        source, _expected = read_data("trailing_comma_optional_parens2")
-        actual = fs(source)
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
-    @unittest.expectedFailure
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_trailing_comma_optional_parens_stability3(self) -> None:
-        source, _expected = read_data("trailing_comma_optional_parens3")
-        actual = fs(source)
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
-        source, _expected = read_data("trailing_comma_optional_parens1")
-        actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
-        source, _expected = read_data("trailing_comma_optional_parens2")
-        actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
-        source, _expected = read_data("trailing_comma_optional_parens3")
-        actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
     def test_pep_572_version_detection(self) -> None:
         source, _ = read_data("pep_572")
         root = black.lib2to3_parse(source)
@@ -972,10 +933,13 @@ class BlackTestCase(BlackBaseTestCase):
             # Multi file command.
             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
 
-    def test_no_files(self) -> None:
+    def test_no_src_fails(self) -> None:
+        with cache_dir():
+            self.invokeBlack([], exit_code=1)
+
+    def test_src_and_code_fails(self) -> None:
         with cache_dir():
-            # Without an argument, black exits with error code 0.
-            self.invokeBlack([])
+            self.invokeBlack([".", "-c", "0"], exit_code=1)
 
     def test_broken_symlink(self) -> None:
         with cache_dir() as workspace:
@@ -1229,13 +1193,18 @@ class BlackTestCase(BlackBaseTestCase):
 
     def test_required_version_matches_version(self) -> None:
         self.invokeBlack(
-            ["--required-version", black.__version__], exit_code=0, ignore_config=True
+            ["--required-version", black.__version__, "-c", "0"],
+            exit_code=0,
+            ignore_config=True,
         )
 
     def test_required_version_does_not_match_version(self) -> None:
-        self.invokeBlack(
-            ["--required-version", "20.99b"], exit_code=1, ignore_config=True
+        result = BlackRunner().invoke(
+            black.main,
+            ["--required-version", "20.99b", "-c", "0"],
         )
+        self.assertEqual(result.exit_code, 1)
+        self.assertIn("required version", result.stderr)
 
     def test_preserves_line_endings(self) -> None:
         with TemporaryDirectory() as workspace:
@@ -1601,6 +1570,33 @@ class BlackTestCase(BlackBaseTestCase):
 
 
 class TestCaching:
+    def test_get_cache_dir(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        # Create multiple cache directories
+        workspace1 = tmp_path / "ws1"
+        workspace1.mkdir()
+        workspace2 = tmp_path / "ws2"
+        workspace2.mkdir()
+
+        # Force user_cache_dir to use the temporary directory for easier assertions
+        patch_user_cache_dir = patch(
+            target="black.cache.user_cache_dir",
+            autospec=True,
+            return_value=str(workspace1),
+        )
+
+        # If BLACK_CACHE_DIR is not set, use user_cache_dir
+        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
+        with patch_user_cache_dir:
+            assert get_cache_dir() == workspace1
+
+        # If it is set, use the path provided in the env var.
+        monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
+        assert get_cache_dir() == workspace2
+
     def test_cache_broken_file(self) -> None:
         mode = DEFAULT_MODE
         with cache_dir() as workspace: