git.madduck.net Git - etc/vim.git/blobdiff - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Switch to pytest and tox (#1763)
[etc/vim.git] / tests / test_black.py
index 6705490ea13624aab062e12af969346f24eb4f2a..7927b368b5d0fc7d7141875db1933175b261b240 100644 (file)
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+import multiprocessing
 import asyncio
 import logging
 from concurrent.futures import ThreadPoolExecutor
@@ -9,6 +10,7 @@ import inspect
 from io import BytesIO, TextIOWrapper
 import os
 from pathlib import Path
+from platform import system
 import regex as re
 import sys
 from tempfile import TemporaryDirectory
@@ -20,7 +22,6 @@ from typing import (
     Dict,
     Generator,
     List,
-    Tuple,
     Iterator,
     TypeVar,
 )
@@ -34,18 +35,10 @@ from click.testing import CliRunner
 import black
 from black import Feature, TargetVersion
 
-try:
-    import blackd
-    from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
-    from aiohttp import web
-except ImportError:
-    has_blackd_deps = False
-else:
-    has_blackd_deps = True
-
 from pathspec import PathSpec
 
 # Import other test classes
+from tests.util import THIS_DIR, read_data, DETERMINISTIC_HEADER
 from .test_primer import PrimerCLITests  # noqa: F401
 
 
@@ -53,13 +46,13 @@ DEFAULT_MODE = black.FileMode(experimental_string_processing=True)
 ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
 fs = partial(black.format_str, mode=DEFAULT_MODE)
 THIS_FILE = Path(__file__)
-THIS_DIR = THIS_FILE.parent
-PROJECT_ROOT = THIS_DIR.parent
-DETERMINISTIC_HEADER = "[Deterministic header]"
-EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
-PY36_ARGS = [
-    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
-]
+PY36_VERSIONS = {
+    TargetVersion.PY36,
+    TargetVersion.PY37,
+    TargetVersion.PY38,
+    TargetVersion.PY39,
+}
+PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
 T = TypeVar("T")
 R = TypeVar("R")
 
@@ -68,29 +61,6 @@ def dump_to_stderr(*output: str) -> str:
     return "\n" + "\n".join(output) + "\n"
 
 
-def read_data(name: str, data: bool = True) -> Tuple[str, str]:
-    """read_data('test_name') -> 'input', 'output'"""
-    if not name.endswith((".py", ".pyi", ".out", ".diff")):
-        name += ".py"
-    _input: List[str] = []
-    _output: List[str] = []
-    base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
-    with open(base_dir / name, "r", encoding="utf8") as test:
-        lines = test.readlines()
-    result = _input
-    for line in lines:
-        line = line.replace(EMPTY_LINE, "")
-        if line.rstrip() == "# output":
-            result = _output
-            continue
-
-        result.append(line)
-    if _input and not _output:
-        # If there's no output marker, treat the entire file as already pre-formatted.
-        _output = _input[:]
-    return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
-
-
 @contextmanager
 def cache_dir(exists: bool = True) -> Iterator[Path]:
     with TemporaryDirectory() as workspace:
@@ -113,17 +83,6 @@ def event_loop() -> Iterator[None]:
         loop.close()
 
 
-@contextmanager
-def skip_if_exception(e: str) -> Iterator[None]:
-    try:
-        yield
-    except Exception as exc:
-        if exc.__class__.__name__ == e:
-            unittest.skip(f"Encountered expected exception {exc}, skipping")
-        else:
-            raise
-
-
 class FakeContext(click.Context):
     """A fake click Context for when calling functions that need it."""
 
@@ -233,9 +192,12 @@ class BlackTestCase(unittest.TestCase):
             os.unlink(tmp_file)
         self.assertFormatEqual(expected, actual)
 
-    def test_self(self) -> None:
+    def test_run_on_test_black(self) -> None:
         self.checkSourceFile("tests/test_black.py")
 
+    def test_run_on_test_blackd(self) -> None:
+        self.checkSourceFile("tests/test_blackd.py")
+
     def test_black(self) -> None:
         self.checkSourceFile("src/black/__init__.py")
 
@@ -368,6 +330,27 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, DEFAULT_MODE)
 
+    @unittest.expectedFailure
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_trailing_comma_optional_parens_stability1(self) -> None:
+        source, _expected = read_data("trailing_comma_optional_parens1")
+        actual = fs(source)
+        black.assert_stable(source, actual, DEFAULT_MODE)
+
+    @unittest.expectedFailure
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_trailing_comma_optional_parens_stability2(self) -> None:
+        source, _expected = read_data("trailing_comma_optional_parens2")
+        actual = fs(source)
+        black.assert_stable(source, actual, DEFAULT_MODE)
+
+    @unittest.expectedFailure
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_trailing_comma_optional_parens_stability3(self) -> None:
+        source, _expected = read_data("trailing_comma_optional_parens3")
+        actual = fs(source)
+        black.assert_stable(source, actual, DEFAULT_MODE)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_expression(self) -> None:
         source, expected = read_data("expression")
@@ -497,6 +480,16 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, DEFAULT_MODE)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_docstring_no_string_normalization(self) -> None:
+        """Like test_docstring but with string normalization off."""
+        source, expected = read_data("docstring_no_string_normalization")
+        mode = replace(DEFAULT_MODE, string_normalization=False)
+        actual = fs(source, mode=mode)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, mode)
+
     def test_long_strings(self) -> None:
         """Tests for splitting long strings."""
         source, expected = read_data("long_strings")
@@ -672,7 +665,7 @@ class BlackTestCase(unittest.TestCase):
     @patch("black.dump_to_file", dump_to_stderr)
     def test_numeric_literals(self) -> None:
         source, expected = read_data("numeric_literals")
-        mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
+        mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
         actual = fs(source, mode=mode)
         self.assertFormatEqual(expected, actual)
         black.assert_equivalent(source, actual)
@@ -681,7 +674,7 @@ class BlackTestCase(unittest.TestCase):
     @patch("black.dump_to_file", dump_to_stderr)
     def test_numeric_literals_ignoring_underscores(self) -> None:
         source, expected = read_data("numeric_literals_skip_underscores")
-        mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
+        mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
         actual = fs(source, mode=mode)
         self.assertFormatEqual(expected, actual)
         black.assert_equivalent(source, actual)
@@ -767,6 +760,16 @@ class BlackTestCase(unittest.TestCase):
             black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, DEFAULT_MODE)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_python39(self) -> None:
+        source, expected = read_data("python39")
+        actual = fs(source)
+        self.assertFormatEqual(expected, actual)
+        major, minor = sys.version_info[:2]
+        if major > 3 or (major == 3 and minor >= 9):
+            black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, DEFAULT_MODE)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_fmtonoff(self) -> None:
         source, expected = read_data("fmtonoff")
@@ -1179,6 +1182,39 @@ class BlackTestCase(unittest.TestCase):
         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
 
+    def test_get_features_used_decorator(self) -> None:
+        # Test the feature detection of the new decorator syntax.
+        # A failure here also makes some cases of test_get_features_used()
+        # fail, so this is tested first so that a useful case
+        # is identified.
+        simples, relaxed = read_data("decorators")
+        # skip explanation comments at the top of the file
+        for simple_test in simples.split("##")[1:]:
+            node = black.lib2to3_parse(simple_test)
+            decorator = str(node.children[0].children[0]).strip()
+            self.assertNotIn(
+                Feature.RELAXED_DECORATORS,
+                black.get_features_used(node),
+                msg=(
+                    f"decorator '{decorator}' follows python<=3.8 syntax"
+                    "but is detected as 3.9+"
+                    # f"The full node is\n{node!r}"
+                ),
+            )
+        # skip the '# output' comment at the top of the output part
+        for relaxed_test in relaxed.split("##")[1:]:
+            node = black.lib2to3_parse(relaxed_test)
+            decorator = str(node.children[0].children[0]).strip()
+            self.assertIn(
+                Feature.RELAXED_DECORATORS,
+                black.get_features_used(node),
+                msg=(
+                    f"decorator '{decorator}' uses python3.9+ syntax"
+                    "but is detected as python<=3.8"
+                    # f"The full node is\n{node!r}"
+                ),
+            )
+
     def test_get_features_used(self) -> None:
         node = black.lib2to3_parse("def f(*, arg): ...\n")
         self.assertEqual(black.get_features_used(node), set())
@@ -1363,9 +1399,55 @@ class BlackTestCase(unittest.TestCase):
             src = (workspace / "test.py").resolve()
             with src.open("w") as fobj:
                 fobj.write("print('hello')")
-            self.invokeBlack([str(src), "--diff"])
-            cache_file = black.get_cache_file(mode)
-            self.assertFalse(cache_file.exists())
+            with patch("black.read_cache") as read_cache, patch(
+                "black.write_cache"
+            ) as write_cache:
+                self.invokeBlack([str(src), "--diff"])
+                cache_file = black.get_cache_file(mode)
+                self.assertFalse(cache_file.exists())
+                write_cache.assert_not_called()
+                read_cache.assert_not_called()
+
+    def test_no_cache_when_writeback_color_diff(self) -> None:
+        mode = DEFAULT_MODE
+        with cache_dir() as workspace:
+            src = (workspace / "test.py").resolve()
+            with src.open("w") as fobj:
+                fobj.write("print('hello')")
+            with patch("black.read_cache") as read_cache, patch(
+                "black.write_cache"
+            ) as write_cache:
+                self.invokeBlack([str(src), "--diff", "--color"])
+                cache_file = black.get_cache_file(mode)
+                self.assertFalse(cache_file.exists())
+                write_cache.assert_not_called()
+                read_cache.assert_not_called()
+
+    @event_loop()
+    def test_output_locking_when_writeback_diff(self) -> None:
+        with cache_dir() as workspace:
+            for tag in range(0, 4):
+                src = (workspace / f"test{tag}.py").resolve()
+                with src.open("w") as fobj:
+                    fobj.write("print('hello')")
+            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
+                self.invokeBlack(["--diff", str(workspace)], exit_code=0)
+                # this isn't quite doing what we want, but if it _isn't_
+                # called then we cannot be using the lock it provides
+                mgr.assert_called()
+
+    @event_loop()
+    def test_output_locking_when_writeback_color_diff(self) -> None:
+        with cache_dir() as workspace:
+            for tag in range(0, 4):
+                src = (workspace / f"test{tag}.py").resolve()
+                with src.open("w") as fobj:
+                    fobj.write("print('hello')")
+            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
+                self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
+                # this isn't quite doing what we want, but if it _isn't_
+                # called then we cannot be using the lock it provides
+                mgr.assert_called()
 
     def test_no_cache_when_stdin(self) -> None:
         mode = DEFAULT_MODE
@@ -1495,7 +1577,6 @@ class BlackTestCase(unittest.TestCase):
         black.assert_stable(source, actual, DEFAULT_MODE)
 
     def test_single_file_force_pyi(self) -> None:
-        reg_mode = DEFAULT_MODE
         pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
         contents, expected = read_data("force_pyi")
         with cache_dir() as workspace:
@@ -1508,9 +1589,11 @@ class BlackTestCase(unittest.TestCase):
             # verify cache with --pyi is separate
             pyi_cache = black.read_cache(pyi_mode)
             self.assertIn(path, pyi_cache)
-            normal_cache = black.read_cache(reg_mode)
+            normal_cache = black.read_cache(DEFAULT_MODE)
             self.assertNotIn(path, normal_cache)
-        self.assertEqual(actual, expected)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(contents, actual)
+        black.assert_stable(contents, actual, pyi_mode)
 
     @event_loop()
     def test_multi_file_force_pyi(self) -> None:
@@ -1548,7 +1631,7 @@ class BlackTestCase(unittest.TestCase):
 
     def test_single_file_force_py36(self) -> None:
         reg_mode = DEFAULT_MODE
-        py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
+        py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
         source, expected = read_data("force_py36")
         with cache_dir() as workspace:
             path = (workspace / "file.py").resolve()
@@ -1567,7 +1650,7 @@ class BlackTestCase(unittest.TestCase):
     @event_loop()
     def test_multi_file_force_py36(self) -> None:
         reg_mode = DEFAULT_MODE
-        py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
+        py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
         source, expected = read_data("force_py36")
         with cache_dir() as workspace:
             paths = [
@@ -1842,14 +1925,6 @@ class BlackTestCase(unittest.TestCase):
         ):
             ff(THIS_FILE)
 
-    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
-    def test_blackd_main(self) -> None:
-        with patch("blackd.web.run_app"):
-            result = CliRunner().invoke(blackd.main, [])
-            if result.exception is not None:
-                raise result.exception
-            self.assertEqual(result.exit_code, 0)
-
     def test_invalid_config_return_code(self) -> None:
         tmp_file = Path(black.dump_to_file())
         try:
@@ -1908,173 +1983,22 @@ class BlackTestCase(unittest.TestCase):
             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
 
+    def test_bpo_33660_workaround(self) -> None:
+        if system() == "Windows":
+            return
 
-class BlackDTestCase(AioHTTPTestCase):
-    async def get_application(self) -> web.Application:
-        return blackd.make_app()
-
-    # TODO: remove these decorators once the below is released
-    # https://github.com/aio-libs/aiohttp/pull/3727
-    @skip_if_exception("ClientOSError")
-    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
-    @unittest_run_loop
-    async def test_blackd_request_needs_formatting(self) -> None:
-        response = await self.client.post("/", data=b"print('hello world')")
-        self.assertEqual(response.status, 200)
-        self.assertEqual(response.charset, "utf8")
-        self.assertEqual(await response.read(), b'print("hello world")\n')
-
-    @skip_if_exception("ClientOSError")
-    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
-    @unittest_run_loop
-    async def test_blackd_request_no_change(self) -> None:
-        response = await self.client.post("/", data=b'print("hello world")\n')
-        self.assertEqual(response.status, 204)
-        self.assertEqual(await response.read(), b"")
-
-    @skip_if_exception("ClientOSError")
-    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
-    @unittest_run_loop
-    async def test_blackd_request_syntax_error(self) -> None:
-        response = await self.client.post("/", data=b"what even ( is")
-        self.assertEqual(response.status, 400)
-        content = await response.text()
-        self.assertTrue(
-            content.startswith("Cannot parse"),
-            msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
-        )
+        # https://bugs.python.org/issue33660
 
-    @skip_if_exception("ClientOSError")
-    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
-    @unittest_run_loop
-    async def test_blackd_unsupported_version(self) -> None:
-        response = await self.client.post(
-            "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
-        )
-        self.assertEqual(response.status, 501)
-
-    @skip_if_exception("ClientOSError")
-    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
-    @unittest_run_loop
-    async def test_blackd_supported_version(self) -> None:
-        response = await self.client.post(
-            "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
-        )
-        self.assertEqual(response.status, 200)
-
-    @skip_if_exception("ClientOSError")
-    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
-    @unittest_run_loop
-    async def test_blackd_invalid_python_variant(self) -> None:
-        async def check(header_value: str, expected_status: int = 400) -> None:
-            response = await self.client.post(
-                "/", data=b"what", headers={blackd.PYTHON_VARIANT_HEADER: header_value}
-            )
-            self.assertEqual(response.status, expected_status)
-
-        await check("lol")
-        await check("ruby3.5")
-        await check("pyi3.6")
-        await check("py1.5")
-        await check("2.8")
-        await check("py2.8")
-        await check("3.0")
-        await check("pypy3.0")
-        await check("jython3.4")
-
-    @skip_if_exception("ClientOSError")
-    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
-    @unittest_run_loop
-    async def test_blackd_pyi(self) -> None:
-        source, expected = read_data("stub.pyi")
-        response = await self.client.post(
-            "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
-        )
-        self.assertEqual(response.status, 200)
-        self.assertEqual(await response.text(), expected)
-
-    @skip_if_exception("ClientOSError")
-    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
-    @unittest_run_loop
-    async def test_blackd_diff(self) -> None:
-        diff_header = re.compile(
-            r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
-        )
-
-        source, _ = read_data("blackd_diff.py")
-        expected, _ = read_data("blackd_diff.diff")
-
-        response = await self.client.post(
-            "/", data=source, headers={blackd.DIFF_HEADER: "true"}
-        )
-        self.assertEqual(response.status, 200)
-
-        actual = await response.text()
-        actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
-        self.assertEqual(actual, expected)
-
-    @skip_if_exception("ClientOSError")
-    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
-    @unittest_run_loop
-    async def test_blackd_python_variant(self) -> None:
-        code = (
-            "def f(\n"
-            "    and_has_a_bunch_of,\n"
-            "    very_long_arguments_too,\n"
-            "    and_lots_of_them_as_well_lol,\n"
-            "    **and_very_long_keyword_arguments\n"
-            "):\n"
-            "    pass\n"
-        )
-
-        async def check(header_value: str, expected_status: int) -> None:
-            response = await self.client.post(
-                "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
-            )
-            self.assertEqual(
-                response.status, expected_status, msg=await response.text()
-            )
-
-        await check("3.6", 200)
-        await check("py3.6", 200)
-        await check("3.6,3.7", 200)
-        await check("3.6,py3.7", 200)
-        await check("py36,py37", 200)
-        await check("36", 200)
-        await check("3.6.4", 200)
-
-        await check("2", 204)
-        await check("2.7", 204)
-        await check("py2.7", 204)
-        await check("3.4", 204)
-        await check("py3.4", 204)
-        await check("py34,py36", 204)
-        await check("34", 204)
-
-    @skip_if_exception("ClientOSError")
-    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
-    @unittest_run_loop
-    async def test_blackd_line_length(self) -> None:
-        response = await self.client.post(
-            "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
-        )
-        self.assertEqual(response.status, 200)
-
-    @skip_if_exception("ClientOSError")
-    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
-    @unittest_run_loop
-    async def test_blackd_invalid_line_length(self) -> None:
-        response = await self.client.post(
-            "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "NaN"}
-        )
-        self.assertEqual(response.status, 400)
-
-    @skip_if_exception("ClientOSError")
-    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
-    @unittest_run_loop
-    async def test_blackd_response_black_version_header(self) -> None:
-        response = await self.client.post("/")
-        self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
+        old_cwd = Path.cwd()
+        try:
+            root = Path("/")
+            os.chdir(str(root))
+            path = Path("workspace") / "project"
+            report = black.Report(verbose=True)
+            normalized_path = black.normalize_path_maybe_ignore(path, root, report)
+            self.assertEqual(normalized_path, "workspace/project")
+        finally:
+            os.chdir(str(old_cwd))
 
 
 with open(black.__file__, "r", encoding="utf-8") as _bf: