git.madduck.net Git - etc/vim.git/blobdiff - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Update description for GitHub Action `options:` argument (GH-2858)
[etc/vim.git] / tests / test_black.py
index 63cd716c0bb80aeea9cf31c274ead6b1b930d4e6..cd38d9e2c0d63c1648f0003b58cedc77d865ea31 100644 (file)
@@ -40,7 +40,7 @@ import black
 import black.files
 from black import Feature, TargetVersion
 from black import re_compile_maybe_verbose as compile_pattern
-from black.cache import get_cache_file
+from black.cache import get_cache_dir, get_cache_file
 from black.debug import DebugVisitor
 from black.output import color_diff, diff
 from black.report import Report
@@ -63,6 +63,7 @@ from tests.util import (
 )
 
 THIS_FILE = Path(__file__)
+EMPTY_CONFIG = THIS_DIR / "data" / "empty_pyproject.toml"
 PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
 DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
 DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
@@ -100,6 +101,8 @@ class FakeContext(click.Context):
 
     def __init__(self) -> None:
         self.default_map: Dict[str, Any] = {}
+        # Dummy root, since most of the tests don't care about it
+        self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
 
 
 class FakeParameter(click.Parameter):
@@ -148,11 +151,21 @@ class BlackTestCase(BlackBaseTestCase):
             os.unlink(tmp_file)
         self.assertFormatEqual(expected, actual)
 
+    def test_experimental_string_processing_warns(self) -> None:
+        self.assertWarns(
+            black.mode.Deprecated, black.Mode, experimental_string_processing=True
+        )
+
     def test_piping(self) -> None:
         source, expected = read_data("src/black/__init__", data=False)
         result = BlackRunner().invoke(
             black.main,
-            ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
+            [
+                "-",
+                "--fast",
+                f"--line-length={black.DEFAULT_LINE_LENGTH}",
+                f"--config={EMPTY_CONFIG}",
+            ],
             input=BytesIO(source.encode("utf8")),
         )
         self.assertEqual(result.exit_code, 0)
@@ -168,13 +181,12 @@ class BlackTestCase(BlackBaseTestCase):
         )
         source, _ = read_data("expression.py")
         expected, _ = read_data("expression.diff")
-        config = THIS_DIR / "data" / "empty_pyproject.toml"
         args = [
             "-",
             "--fast",
             f"--line-length={black.DEFAULT_LINE_LENGTH}",
             "--diff",
-            f"--config={config}",
+            f"--config={EMPTY_CONFIG}",
         ]
         result = BlackRunner().invoke(
             black.main, args, input=BytesIO(source.encode("utf8"))
@@ -186,14 +198,13 @@ class BlackTestCase(BlackBaseTestCase):
 
     def test_piping_diff_with_color(self) -> None:
         source, _ = read_data("expression.py")
-        config = THIS_DIR / "data" / "empty_pyproject.toml"
         args = [
             "-",
             "--fast",
             f"--line-length={black.DEFAULT_LINE_LENGTH}",
             "--diff",
             "--color",
-            f"--config={config}",
+            f"--config={EMPTY_CONFIG}",
         ]
         result = BlackRunner().invoke(
             black.main, args, input=BytesIO(source.encode("utf8"))
@@ -221,45 +232,6 @@ class BlackTestCase(BlackBaseTestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, black.FileMode())
 
-    @unittest.expectedFailure
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_trailing_comma_optional_parens_stability1(self) -> None:
-        source, _expected = read_data("trailing_comma_optional_parens1")
-        actual = fs(source)
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
-    @unittest.expectedFailure
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_trailing_comma_optional_parens_stability2(self) -> None:
-        source, _expected = read_data("trailing_comma_optional_parens2")
-        actual = fs(source)
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
-    @unittest.expectedFailure
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_trailing_comma_optional_parens_stability3(self) -> None:
-        source, _expected = read_data("trailing_comma_optional_parens3")
-        actual = fs(source)
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_trailing_comma_optional_parens_stability1_pass2(self) -> None:
-        source, _expected = read_data("trailing_comma_optional_parens1")
-        actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_trailing_comma_optional_parens_stability2_pass2(self) -> None:
-        source, _expected = read_data("trailing_comma_optional_parens2")
-        actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_trailing_comma_optional_parens_stability3_pass2(self) -> None:
-        source, _expected = read_data("trailing_comma_optional_parens3")
-        actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
     def test_pep_572_version_detection(self) -> None:
         source, _ = read_data("pep_572")
         root = black.lib2to3_parse(source)
@@ -284,7 +256,6 @@ class BlackTestCase(BlackBaseTestCase):
 
     def test_expression_diff(self) -> None:
         source, _ = read_data("expression.py")
-        config = THIS_DIR / "data" / "empty_pyproject.toml"
         expected, _ = read_data("expression.diff")
         tmp_file = Path(black.dump_to_file(source))
         diff_header = re.compile(
@@ -293,7 +264,7 @@ class BlackTestCase(BlackBaseTestCase):
         )
         try:
             result = BlackRunner().invoke(
-                black.main, ["--diff", str(tmp_file), f"--config={config}"]
+                black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
             )
             self.assertEqual(result.exit_code, 0)
         finally:
@@ -311,12 +282,12 @@ class BlackTestCase(BlackBaseTestCase):
 
     def test_expression_diff_with_color(self) -> None:
         source, _ = read_data("expression.py")
-        config = THIS_DIR / "data" / "empty_pyproject.toml"
         expected, _ = read_data("expression.diff")
         tmp_file = Path(black.dump_to_file(source))
         try:
             result = BlackRunner().invoke(
-                black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"]
+                black.main,
+                ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"],
             )
         finally:
             os.unlink(tmp_file)
@@ -340,7 +311,7 @@ class BlackTestCase(BlackBaseTestCase):
     @patch("black.dump_to_file", dump_to_stderr)
     def test_string_quotes(self) -> None:
         source, expected = read_data("string_quotes")
-        mode = black.Mode(experimental_string_processing=True)
+        mode = black.Mode(preview=True)
         assert_format(source, expected, mode)
         mode = replace(mode, string_normalization=False)
         not_normalized = fs(source, mode=mode)
@@ -357,7 +328,9 @@ class BlackTestCase(BlackBaseTestCase):
             r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
         )
         try:
-            result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)])
+            result = BlackRunner().invoke(
+                black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"]
+            )
             self.assertEqual(result.exit_code, 0)
         finally:
             os.unlink(tmp_file)
@@ -724,24 +697,15 @@ class BlackTestCase(BlackBaseTestCase):
 
         straddling = "x + y"
         black.lib2to3_parse(straddling)
-        black.lib2to3_parse(straddling, {TargetVersion.PY27})
         black.lib2to3_parse(straddling, {TargetVersion.PY36})
-        black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
 
         py2_only = "print x"
-        black.lib2to3_parse(py2_only)
-        black.lib2to3_parse(py2_only, {TargetVersion.PY27})
         with self.assertRaises(black.InvalidInput):
             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
-        with self.assertRaises(black.InvalidInput):
-            black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
 
         py3_only = "exec(x, end=y)"
         black.lib2to3_parse(py3_only)
-        with self.assertRaises(black.InvalidInput):
-            black.lib2to3_parse(py3_only, {TargetVersion.PY27})
         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
-        black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
 
     def test_get_features_used_decorator(self) -> None:
         # Test the feature detection of new decorator syntax
@@ -810,6 +774,26 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
         node = black.lib2to3_parse("def fn(a, /, b): ...")
         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
+        node = black.lib2to3_parse("def fn(): yield a, b")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("def fn(): return a, b")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("def fn(): yield *b, c")
+        self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
+        node = black.lib2to3_parse("def fn(): return a, *b, c")
+        self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
+        node = black.lib2to3_parse("x = a, *b, c")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("x: Any = regular")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("x: Any = (regular, regular)")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
+        self.assertEqual(
+            black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
+        )
 
     def test_get_features_used_for_future_flags(self) -> None:
         for src, features in [
@@ -927,8 +911,8 @@ class BlackTestCase(BlackBaseTestCase):
                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
 
         out_str = "".join(out_lines)
-        self.assertTrue("Expected tree:" in out_str)
-        self.assertTrue("Actual tree:" in out_str)
+        self.assertIn("Expected tree:", out_str)
+        self.assertIn("Actual tree:", out_str)
         self.assertEqual("".join(err_lines), "")
 
     @event_loop()
@@ -954,10 +938,13 @@ class BlackTestCase(BlackBaseTestCase):
             # Multi file command.
             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
 
-    def test_no_files(self) -> None:
+    def test_no_src_fails(self) -> None:
+        with cache_dir():
+            self.invokeBlack([], exit_code=1)
+
+    def test_src_and_code_fails(self) -> None:
         with cache_dir():
-            # Without an argument, black exits with error code 0.
-            self.invokeBlack([])
+            self.invokeBlack([".", "-c", "0"], exit_code=1)
 
     def test_broken_symlink(self) -> None:
         with cache_dir() as workspace:
@@ -1211,14 +1198,33 @@ class BlackTestCase(BlackBaseTestCase):
 
     def test_required_version_matches_version(self) -> None:
         self.invokeBlack(
-            ["--required-version", black.__version__], exit_code=0, ignore_config=True
+            ["--required-version", black.__version__, "-c", "0"],
+            exit_code=0,
+            ignore_config=True,
         )
 
-    def test_required_version_does_not_match_version(self) -> None:
+    def test_required_version_matches_partial_version(self) -> None:
         self.invokeBlack(
-            ["--required-version", "20.99b"], exit_code=1, ignore_config=True
+            ["--required-version", black.__version__.split(".")[0], "-c", "0"],
+            exit_code=0,
+            ignore_config=True,
         )
 
+    def test_required_version_does_not_match_on_minor_version(self) -> None:
+        self.invokeBlack(
+            ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"],
+            exit_code=1,
+            ignore_config=True,
+        )
+
+    def test_required_version_does_not_match_version(self) -> None:
+        result = BlackRunner().invoke(
+            black.main,
+            ["--required-version", "20.99b", "-c", "0"],
+        )
+        self.assertEqual(result.exit_code, 1)
+        self.assertIn("required version", result.stderr)
+
     def test_preserves_line_endings(self) -> None:
         with TemporaryDirectory() as workspace:
             test_file = Path(workspace) / "test.py"
@@ -1304,6 +1310,7 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertEqual(config["color"], True)
         self.assertEqual(config["line_length"], 79)
         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
+        self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
         self.assertEqual(config["exclude"], r"\.pyi?$")
         self.assertEqual(config["include"], r"\.py?$")
 
@@ -1339,10 +1346,17 @@ class BlackTestCase(BlackBaseTestCase):
             src_python.touch()
 
             self.assertEqual(
-                black.find_project_root((src_dir, test_dir)), root.resolve()
+                black.find_project_root((src_dir, test_dir)),
+                (root.resolve(), "pyproject.toml"),
+            )
+            self.assertEqual(
+                black.find_project_root((src_dir,)),
+                (src_dir.resolve(), "pyproject.toml"),
+            )
+            self.assertEqual(
+                black.find_project_root((src_python,)),
+                (src_dir.resolve(), "pyproject.toml"),
             )
-            self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
-            self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
 
     @patch(
         "black.files.find_user_pyproject_toml",
@@ -1416,27 +1430,6 @@ class BlackTestCase(BlackBaseTestCase):
         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
         self.assertEqual(actual, expected)
 
-    @pytest.mark.python2
-    def test_docstring_reformat_for_py27(self) -> None:
-        """
-        Check that stripping trailing whitespace from Python 2 docstrings
-        doesn't trigger a "not equivalent to source" error
-        """
-        source = (
-            b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
-        )
-        expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
-
-        result = BlackRunner().invoke(
-            black.main,
-            ["-", "-q", "--target-version=py27"],
-            input=BytesIO(source),
-        )
-
-        self.assertEqual(result.exit_code, 0)
-        actual = result.stdout
-        self.assertFormatEqual(actual, expected)
-
     @staticmethod
     def compare_results(
         result: click.testing.Result, expected_value: str, expected_exit_code: int
@@ -1596,6 +1589,33 @@ class BlackTestCase(BlackBaseTestCase):
 
 
 class TestCaching:
+    def test_get_cache_dir(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        # Create multiple cache directories
+        workspace1 = tmp_path / "ws1"
+        workspace1.mkdir()
+        workspace2 = tmp_path / "ws2"
+        workspace2.mkdir()
+
+        # Force user_cache_dir to use the temporary directory for easier assertions
+        patch_user_cache_dir = patch(
+            target="black.cache.user_cache_dir",
+            autospec=True,
+            return_value=str(workspace1),
+        )
+
+        # If BLACK_CACHE_DIR is not set, use user_cache_dir
+        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
+        with patch_user_cache_dir:
+            assert get_cache_dir() == workspace1
+
+        # If it is set, use the path provided in the env var.
+        monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
+        assert get_cache_dir() == workspace2
+
     def test_cache_broken_file(self) -> None:
         mode = DEFAULT_MODE
         with cache_dir() as workspace:
@@ -1766,6 +1786,7 @@ def assert_collected_sources(
     src: Sequence[Union[str, Path]],
     expected: Sequence[Union[str, Path]],
     *,
+    ctx: Optional[FakeContext] = None,
     exclude: Optional[str] = None,
     include: Optional[str] = None,
     extend_exclude: Optional[str] = None,
@@ -1781,7 +1802,7 @@ def assert_collected_sources(
     )
     gs_force_exclude = None if force_exclude is None else compile_pattern(force_exclude)
     collected = black.get_sources(
-        ctx=FakeContext(),
+        ctx=ctx or FakeContext(),
         src=gs_src,
         quiet=False,
         verbose=False,
@@ -1817,9 +1838,11 @@ class TestFileCollection:
             base / "b/.definitely_exclude/a.pyi",
         ]
         src = [base / "b/"]
-        assert_collected_sources(src, expected, extend_exclude=r"/exclude/")
+        ctx = FakeContext()
+        ctx.obj["root"] = base
+        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_exclude_for_issue_1572(self) -> None:
         # Exclude shouldn't touch files that were explicitly given to Black through the
         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
@@ -2002,13 +2025,13 @@ class TestFileCollection:
         child.is_symlink.assert_called()
         assert child.is_symlink.call_count == 2
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin(self) -> None:
         src = ["-"]
         expected = ["-"]
         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin_filename(self) -> None:
         src = ["-"]
         stdin_filename = str(THIS_DIR / "data/collections.py")
@@ -2020,7 +2043,7 @@ class TestFileCollection:
             stdin_filename=stdin_filename,
         )
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
         # Exclude shouldn't exclude stdin_filename since it is mimicking the
         # file being passed directly. This is the same as
@@ -2036,7 +2059,7 @@ class TestFileCollection:
             stdin_filename=stdin_filename,
         )
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
         # file being passed directly. This is the same as
@@ -2052,7 +2075,7 @@ class TestFileCollection:
             stdin_filename=stdin_filename,
         )
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
         # Force exclude should exclude the file when passing it through
         # stdin_filename
@@ -2066,36 +2089,6 @@ class TestFileCollection:
         )
 
 
-@pytest.mark.python2
-@pytest.mark.parametrize("explicit", [True, False], ids=["explicit", "autodetection"])
-def test_python_2_deprecation_with_target_version(explicit: bool) -> None:
-    args = [
-        "--config",
-        str(THIS_DIR / "empty.toml"),
-        str(DATA_DIR / "python2.py"),
-        "--check",
-    ]
-    if explicit:
-        args.append("--target-version=py27")
-    with cache_dir():
-        result = BlackRunner().invoke(black.main, args)
-    assert "DEPRECATION: Python 2 support will be removed" in result.stderr
-
-
-@pytest.mark.python2
-def test_python_2_deprecation_autodetection_extended() -> None:
-    # this test has a similar construction to test_get_features_used_decorator
-    python2, non_python2 = read_data("python2_detection")
-    for python2_case in python2.split("###"):
-        node = black.lib2to3_parse(python2_case)
-        assert black.detect_target_versions(node) == {TargetVersion.PY27}, python2_case
-    for non_python2_case in non_python2.split("###"):
-        node = black.lib2to3_parse(non_python2_case)
-        assert black.detect_target_versions(node) != {
-            TargetVersion.PY27
-        }, non_python2_case
-
-
 try:
     with open(black.__file__, "r", encoding="utf-8") as _bf:
         black_source_lines = _bf.readlines()