git.madduck.net Git - etc/vim.git/blobdiff - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Mark Felix and Batuhan as maintainers (#2794)
[etc/vim.git] / tests / test_black.py
index 2d0a7dfd4e233ebecc906115c637e15dfcc145ba..559690938a8bc36962b4fe143b0b8adca945dd09 100644 (file)
@@ -40,7 +40,7 @@ import black
 import black.files
 from black import Feature, TargetVersion
 from black import re_compile_maybe_verbose as compile_pattern
 import black.files
 from black import Feature, TargetVersion
 from black import re_compile_maybe_verbose as compile_pattern
-from black.cache import get_cache_file
+from black.cache import get_cache_dir, get_cache_file
 from black.debug import DebugVisitor
 from black.output import color_diff, diff
 from black.report import Report
 from black.debug import DebugVisitor
 from black.output import color_diff, diff
 from black.report import Report
@@ -100,6 +100,8 @@ class FakeContext(click.Context):
 
     def __init__(self) -> None:
         self.default_map: Dict[str, Any] = {}
 
     def __init__(self) -> None:
         self.default_map: Dict[str, Any] = {}
+        # Dummy root, since most of the tests don't care about it
+        self.obj: Dict[str, Any] = {"root": PROJECT_ROOT}
 
 
 class FakeParameter(click.Parameter):
 
 
 class FakeParameter(click.Parameter):
@@ -148,6 +150,11 @@ class BlackTestCase(BlackBaseTestCase):
             os.unlink(tmp_file)
         self.assertFormatEqual(expected, actual)
 
             os.unlink(tmp_file)
         self.assertFormatEqual(expected, actual)
 
+    def test_experimental_string_processing_warns(self) -> None:
+        self.assertWarns(
+            black.mode.Deprecated, black.Mode, experimental_string_processing=True
+        )
+
     def test_piping(self) -> None:
         source, expected = read_data("src/black/__init__", data=False)
         result = BlackRunner().invoke(
     def test_piping(self) -> None:
         source, expected = read_data("src/black/__init__", data=False)
         result = BlackRunner().invoke(
@@ -200,7 +207,7 @@ class BlackTestCase(BlackBaseTestCase):
         )
         actual = result.output
         # Again, the contents are checked in a different test, so only look for colors.
         )
         actual = result.output
         # Again, the contents are checked in a different test, so only look for colors.
-        self.assertIn("\033[1;37m", actual)
+        self.assertIn("\033[1m", actual)
         self.assertIn("\033[36m", actual)
         self.assertIn("\033[32m", actual)
         self.assertIn("\033[31m", actual)
         self.assertIn("\033[36m", actual)
         self.assertIn("\033[32m", actual)
         self.assertIn("\033[31m", actual)
@@ -323,7 +330,7 @@ class BlackTestCase(BlackBaseTestCase):
         actual = result.output
         # We check the contents of the diff in `test_expression_diff`. All
         # we need to check here is that color codes exist in the result.
         actual = result.output
         # We check the contents of the diff in `test_expression_diff`. All
         # we need to check here is that color codes exist in the result.
-        self.assertIn("\033[1;37m", actual)
+        self.assertIn("\033[1m", actual)
         self.assertIn("\033[36m", actual)
         self.assertIn("\033[32m", actual)
         self.assertIn("\033[31m", actual)
         self.assertIn("\033[36m", actual)
         self.assertIn("\033[32m", actual)
         self.assertIn("\033[31m", actual)
@@ -340,7 +347,7 @@ class BlackTestCase(BlackBaseTestCase):
     @patch("black.dump_to_file", dump_to_stderr)
     def test_string_quotes(self) -> None:
         source, expected = read_data("string_quotes")
     @patch("black.dump_to_file", dump_to_stderr)
     def test_string_quotes(self) -> None:
         source, expected = read_data("string_quotes")
-        mode = black.Mode(experimental_string_processing=True)
+        mode = black.Mode(preview=True)
         assert_format(source, expected, mode)
         mode = replace(mode, string_normalization=False)
         not_normalized = fs(source, mode=mode)
         assert_format(source, expected, mode)
         mode = replace(mode, string_normalization=False)
         not_normalized = fs(source, mode=mode)
@@ -724,24 +731,15 @@ class BlackTestCase(BlackBaseTestCase):
 
         straddling = "x + y"
         black.lib2to3_parse(straddling)
 
         straddling = "x + y"
         black.lib2to3_parse(straddling)
-        black.lib2to3_parse(straddling, {TargetVersion.PY27})
         black.lib2to3_parse(straddling, {TargetVersion.PY36})
         black.lib2to3_parse(straddling, {TargetVersion.PY36})
-        black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
 
         py2_only = "print x"
 
         py2_only = "print x"
-        black.lib2to3_parse(py2_only)
-        black.lib2to3_parse(py2_only, {TargetVersion.PY27})
         with self.assertRaises(black.InvalidInput):
             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
         with self.assertRaises(black.InvalidInput):
             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
-        with self.assertRaises(black.InvalidInput):
-            black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
 
         py3_only = "exec(x, end=y)"
         black.lib2to3_parse(py3_only)
 
         py3_only = "exec(x, end=y)"
         black.lib2to3_parse(py3_only)
-        with self.assertRaises(black.InvalidInput):
-            black.lib2to3_parse(py3_only, {TargetVersion.PY27})
         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
-        black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
 
     def test_get_features_used_decorator(self) -> None:
         # Test the feature detection of new decorator syntax
 
     def test_get_features_used_decorator(self) -> None:
         # Test the feature detection of new decorator syntax
@@ -810,6 +808,26 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
         node = black.lib2to3_parse("def fn(a, /, b): ...")
         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
         node = black.lib2to3_parse("def fn(a, /, b): ...")
         self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS})
+        node = black.lib2to3_parse("def fn(): yield a, b")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("def fn(): return a, b")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("def fn(): yield *b, c")
+        self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
+        node = black.lib2to3_parse("def fn(): return a, *b, c")
+        self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW})
+        node = black.lib2to3_parse("x = a, *b, c")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("x: Any = regular")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("x: Any = (regular, regular)")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]")
+        self.assertEqual(black.get_features_used(node), set())
+        node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c")
+        self.assertEqual(
+            black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS}
+        )
 
     def test_get_features_used_for_future_flags(self) -> None:
         for src, features in [
 
     def test_get_features_used_for_future_flags(self) -> None:
         for src, features in [
@@ -927,8 +945,8 @@ class BlackTestCase(BlackBaseTestCase):
                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
 
         out_str = "".join(out_lines)
                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
 
         out_str = "".join(out_lines)
-        self.assertTrue("Expected tree:" in out_str)
-        self.assertTrue("Actual tree:" in out_str)
+        self.assertIn("Expected tree:", out_str)
+        self.assertIn("Actual tree:", out_str)
         self.assertEqual("".join(err_lines), "")
 
     @event_loop()
         self.assertEqual("".join(err_lines), "")
 
     @event_loop()
@@ -1304,6 +1322,7 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertEqual(config["color"], True)
         self.assertEqual(config["line_length"], 79)
         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
         self.assertEqual(config["color"], True)
         self.assertEqual(config["line_length"], 79)
         self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
+        self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"])
         self.assertEqual(config["exclude"], r"\.pyi?$")
         self.assertEqual(config["include"], r"\.py?$")
 
         self.assertEqual(config["exclude"], r"\.pyi?$")
         self.assertEqual(config["include"], r"\.py?$")
 
@@ -1339,10 +1358,17 @@ class BlackTestCase(BlackBaseTestCase):
             src_python.touch()
 
             self.assertEqual(
             src_python.touch()
 
             self.assertEqual(
-                black.find_project_root((src_dir, test_dir)), root.resolve()
+                black.find_project_root((src_dir, test_dir)),
+                (root.resolve(), "pyproject.toml"),
+            )
+            self.assertEqual(
+                black.find_project_root((src_dir,)),
+                (src_dir.resolve(), "pyproject.toml"),
+            )
+            self.assertEqual(
+                black.find_project_root((src_python,)),
+                (src_dir.resolve(), "pyproject.toml"),
             )
             )
-            self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
-            self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
 
     @patch(
         "black.files.find_user_pyproject_toml",
 
     @patch(
         "black.files.find_user_pyproject_toml",
@@ -1416,27 +1442,6 @@ class BlackTestCase(BlackBaseTestCase):
         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
         self.assertEqual(actual, expected)
 
         actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
         self.assertEqual(actual, expected)
 
-    @pytest.mark.python2
-    def test_docstring_reformat_for_py27(self) -> None:
-        """
-        Check that stripping trailing whitespace from Python 2 docstrings
-        doesn't trigger a "not equivalent to source" error
-        """
-        source = (
-            b'def foo():\r\n    """Testing\r\n    Testing """\r\n    print "Foo"\r\n'
-        )
-        expected = 'def foo():\n    """Testing\n    Testing"""\n    print "Foo"\n'
-
-        result = BlackRunner().invoke(
-            black.main,
-            ["-", "-q", "--target-version=py27"],
-            input=BytesIO(source),
-        )
-
-        self.assertEqual(result.exit_code, 0)
-        actual = result.stdout
-        self.assertFormatEqual(actual, expected)
-
     @staticmethod
     def compare_results(
         result: click.testing.Result, expected_value: str, expected_exit_code: int
     @staticmethod
     def compare_results(
         result: click.testing.Result, expected_value: str, expected_exit_code: int
@@ -1584,8 +1589,45 @@ class BlackTestCase(BlackBaseTestCase):
 
         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
 
 
         exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")
 
+    def test_equivalency_ast_parse_failure_includes_error(self) -> None:
+        with pytest.raises(AssertionError) as err:
+            black.assert_equivalent("a«»a  = 1", "a«»a  = 1")
+
+        err.match("--safe")
+        # Unfortunately the SyntaxError message has changed in newer versions so we
+        # can't match it directly.
+        err.match("invalid character")
+        err.match(r"\(<unknown>, line 1\)")
+
 
 class TestCaching:
 
 class TestCaching:
+    def test_get_cache_dir(
+        self,
+        tmp_path: Path,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        # Create multiple cache directories
+        workspace1 = tmp_path / "ws1"
+        workspace1.mkdir()
+        workspace2 = tmp_path / "ws2"
+        workspace2.mkdir()
+
+        # Force user_cache_dir to use the temporary directory for easier assertions
+        patch_user_cache_dir = patch(
+            target="black.cache.user_cache_dir",
+            autospec=True,
+            return_value=str(workspace1),
+        )
+
+        # If BLACK_CACHE_DIR is not set, use user_cache_dir
+        monkeypatch.delenv("BLACK_CACHE_DIR", raising=False)
+        with patch_user_cache_dir:
+            assert get_cache_dir() == workspace1
+
+        # If it is set, use the path provided in the env var.
+        monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
+        assert get_cache_dir() == workspace2
+
     def test_cache_broken_file(self) -> None:
         mode = DEFAULT_MODE
         with cache_dir() as workspace:
     def test_cache_broken_file(self) -> None:
         mode = DEFAULT_MODE
         with cache_dir() as workspace:
@@ -1756,6 +1798,7 @@ def assert_collected_sources(
     src: Sequence[Union[str, Path]],
     expected: Sequence[Union[str, Path]],
     *,
     src: Sequence[Union[str, Path]],
     expected: Sequence[Union[str, Path]],
     *,
+    ctx: Optional[FakeContext] = None,
     exclude: Optional[str] = None,
     include: Optional[str] = None,
     extend_exclude: Optional[str] = None,
     exclude: Optional[str] = None,
     include: Optional[str] = None,
     extend_exclude: Optional[str] = None,
@@ -1771,7 +1814,7 @@ def assert_collected_sources(
     )
     gs_force_exclude = None if force_exclude is None else compile_pattern(force_exclude)
     collected = black.get_sources(
     )
     gs_force_exclude = None if force_exclude is None else compile_pattern(force_exclude)
     collected = black.get_sources(
-        ctx=FakeContext(),
+        ctx=ctx or FakeContext(),
         src=gs_src,
         quiet=False,
         verbose=False,
         src=gs_src,
         quiet=False,
         verbose=False,
@@ -1807,9 +1850,11 @@ class TestFileCollection:
             base / "b/.definitely_exclude/a.pyi",
         ]
         src = [base / "b/"]
             base / "b/.definitely_exclude/a.pyi",
         ]
         src = [base / "b/"]
-        assert_collected_sources(src, expected, extend_exclude=r"/exclude/")
+        ctx = FakeContext()
+        ctx.obj["root"] = base
+        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")
 
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_exclude_for_issue_1572(self) -> None:
         # Exclude shouldn't touch files that were explicitly given to Black through the
         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
     def test_exclude_for_issue_1572(self) -> None:
         # Exclude shouldn't touch files that were explicitly given to Black through the
         # CLI. Exclude is supposed to only apply to the recursive discovery of files.
@@ -1992,13 +2037,13 @@ class TestFileCollection:
         child.is_symlink.assert_called()
         assert child.is_symlink.call_count == 2
 
         child.is_symlink.assert_called()
         assert child.is_symlink.call_count == 2
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin(self) -> None:
         src = ["-"]
         expected = ["-"]
         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
 
     def test_get_sources_with_stdin(self) -> None:
         src = ["-"]
         expected = ["-"]
         assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin_filename(self) -> None:
         src = ["-"]
         stdin_filename = str(THIS_DIR / "data/collections.py")
     def test_get_sources_with_stdin_filename(self) -> None:
         src = ["-"]
         stdin_filename = str(THIS_DIR / "data/collections.py")
@@ -2010,7 +2055,7 @@ class TestFileCollection:
             stdin_filename=stdin_filename,
         )
 
             stdin_filename=stdin_filename,
         )
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
         # Exclude shouldn't exclude stdin_filename since it is mimicking the
         # file being passed directly. This is the same as
     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
         # Exclude shouldn't exclude stdin_filename since it is mimicking the
         # file being passed directly. This is the same as
@@ -2026,7 +2071,7 @@ class TestFileCollection:
             stdin_filename=stdin_filename,
         )
 
             stdin_filename=stdin_filename,
         )
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
         # file being passed directly. This is the same as
     def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
         # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
         # file being passed directly. This is the same as
@@ -2042,7 +2087,7 @@ class TestFileCollection:
             stdin_filename=stdin_filename,
         )
 
             stdin_filename=stdin_filename,
         )
 
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
         # Force exclude should exclude the file when passing it through
         # stdin_filename
     def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
         # Force exclude should exclude the file when passing it through
         # stdin_filename
@@ -2056,36 +2101,6 @@ class TestFileCollection:
         )
 
 
         )
 
 
-@pytest.mark.python2
-@pytest.mark.parametrize("explicit", [True, False], ids=["explicit", "autodetection"])
-def test_python_2_deprecation_with_target_version(explicit: bool) -> None:
-    args = [
-        "--config",
-        str(THIS_DIR / "empty.toml"),
-        str(DATA_DIR / "python2.py"),
-        "--check",
-    ]
-    if explicit:
-        args.append("--target-version=py27")
-    with cache_dir():
-        result = BlackRunner().invoke(black.main, args)
-    assert "DEPRECATION: Python 2 support will be removed" in result.stderr
-
-
-@pytest.mark.python2
-def test_python_2_deprecation_autodetection_extended() -> None:
-    # this test has a similar construction to test_get_features_used_decorator
-    python2, non_python2 = read_data("python2_detection")
-    for python2_case in python2.split("###"):
-        node = black.lib2to3_parse(python2_case)
-        assert black.detect_target_versions(node) == {TargetVersion.PY27}, python2_case
-    for non_python2_case in non_python2.split("###"):
-        node = black.lib2to3_parse(non_python2_case)
-        assert black.detect_target_versions(node) != {
-            TargetVersion.PY27
-        }, non_python2_case
-
-
 try:
     with open(black.__file__, "r", encoding="utf-8") as _bf:
         black_source_lines = _bf.readlines()
 try:
     with open(black.__file__, "r", encoding="utf-8") as _bf:
         black_source_lines = _bf.readlines()