git.madduck.net Git - etc/vim.git/blobdiff - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you would read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

docs: Improve pre-commit use (#1551)
[etc/vim.git] / tests / test_black.py
index fdd19947bcb7970b04ea0ab6349189f5f860ff55..3ed5daa4b494afb5b8a0918d8f60e7426f623201 100644 (file)
@@ -10,10 +10,11 @@ from pathlib import Path
 import regex as re
 import sys
 from tempfile import TemporaryDirectory
 import regex as re
 import sys
 from tempfile import TemporaryDirectory
-from typing import Any, BinaryIO, Generator, List, Tuple, Iterator, TypeVar
+from typing import Any, BinaryIO, Dict, Generator, List, Tuple, Iterator, TypeVar
 import unittest
 from unittest.mock import patch, MagicMock
 
 import unittest
 from unittest.mock import patch, MagicMock
 
+import click
 from click import unstyle
 from click.testing import CliRunner
 
 from click import unstyle
 from click.testing import CliRunner
 
@@ -31,10 +32,15 @@ else:
 
 from pathspec import PathSpec
 
 
 from pathspec import PathSpec
 
+# Import other test classes
+from .test_primer import PrimerCLITests  # noqa: F401
+
+
 ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
 fs = partial(black.format_str, mode=black.FileMode())
 THIS_FILE = Path(__file__)
 THIS_DIR = THIS_FILE.parent
 ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
 fs = partial(black.format_str, mode=black.FileMode())
 THIS_FILE = Path(__file__)
 THIS_DIR = THIS_FILE.parent
+PROJECT_ROOT = THIS_DIR.parent
 DETERMINISTIC_HEADER = "[Deterministic header]"
 EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
 PY36_ARGS = [
 DETERMINISTIC_HEADER = "[Deterministic header]"
 EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
 PY36_ARGS = [
@@ -54,7 +60,7 @@ def read_data(name: str, data: bool = True) -> Tuple[str, str]:
         name += ".py"
     _input: List[str] = []
     _output: List[str] = []
         name += ".py"
     _input: List[str] = []
     _output: List[str] = []
-    base_dir = THIS_DIR / "data" if data else THIS_DIR
+    base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
     with open(base_dir / name, "r", encoding="utf8") as test:
         lines = test.readlines()
     result = _input
     with open(base_dir / name, "r", encoding="utf8") as test:
         lines = test.readlines()
     result = _input
@@ -82,7 +88,7 @@ def cache_dir(exists: bool = True) -> Iterator[Path]:
 
 
 @contextmanager
 
 
 @contextmanager
-def event_loop(close: bool) -> Iterator[None]:
+def event_loop() -> Iterator[None]:
     policy = asyncio.get_event_loop_policy()
     loop = policy.new_event_loop()
     asyncio.set_event_loop(loop)
     policy = asyncio.get_event_loop_policy()
     loop = policy.new_event_loop()
     asyncio.set_event_loop(loop)
@@ -90,8 +96,7 @@ def event_loop(close: bool) -> Iterator[None]:
         yield
 
     finally:
         yield
 
     finally:
-        if close:
-            loop.close()
+        loop.close()
 
 
 @contextmanager
 
 
 @contextmanager
@@ -157,9 +162,18 @@ class BlackTestCase(unittest.TestCase):
     ) -> None:
         runner = BlackRunner()
         if ignore_config:
     ) -> None:
         runner = BlackRunner()
         if ignore_config:
-            args = ["--config", str(THIS_DIR / "empty.toml"), *args]
+            args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
         result = runner.invoke(black.main, args)
         result = runner.invoke(black.main, args)
-        self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
+        self.assertEqual(
+            result.exit_code,
+            exit_code,
+            msg=(
+                f"Failed with args: {args}\n"
+                f"stdout: {runner.stdout_bytes.decode()!r}\n"
+                f"stderr: {runner.stderr_bytes.decode()!r}\n"
+                f"exception: {result.exception}"
+            ),
+        )
 
     @patch("black.dump_to_file", dump_to_stderr)
     def checkSourceFile(self, name: str) -> None:
 
     @patch("black.dump_to_file", dump_to_stderr)
     def checkSourceFile(self, name: str) -> None:
@@ -194,43 +208,43 @@ class BlackTestCase(unittest.TestCase):
         self.checkSourceFile("tests/test_black.py")
 
     def test_black(self) -> None:
         self.checkSourceFile("tests/test_black.py")
 
     def test_black(self) -> None:
-        self.checkSourceFile("black.py")
+        self.checkSourceFile("src/black/__init__.py")
 
     def test_pygram(self) -> None:
 
     def test_pygram(self) -> None:
-        self.checkSourceFile("blib2to3/pygram.py")
+        self.checkSourceFile("src/blib2to3/pygram.py")
 
     def test_pytree(self) -> None:
 
     def test_pytree(self) -> None:
-        self.checkSourceFile("blib2to3/pytree.py")
+        self.checkSourceFile("src/blib2to3/pytree.py")
 
     def test_conv(self) -> None:
 
     def test_conv(self) -> None:
-        self.checkSourceFile("blib2to3/pgen2/conv.py")
+        self.checkSourceFile("src/blib2to3/pgen2/conv.py")
 
     def test_driver(self) -> None:
 
     def test_driver(self) -> None:
-        self.checkSourceFile("blib2to3/pgen2/driver.py")
+        self.checkSourceFile("src/blib2to3/pgen2/driver.py")
 
     def test_grammar(self) -> None:
 
     def test_grammar(self) -> None:
-        self.checkSourceFile("blib2to3/pgen2/grammar.py")
+        self.checkSourceFile("src/blib2to3/pgen2/grammar.py")
 
     def test_literals(self) -> None:
 
     def test_literals(self) -> None:
-        self.checkSourceFile("blib2to3/pgen2/literals.py")
+        self.checkSourceFile("src/blib2to3/pgen2/literals.py")
 
     def test_parse(self) -> None:
 
     def test_parse(self) -> None:
-        self.checkSourceFile("blib2to3/pgen2/parse.py")
+        self.checkSourceFile("src/blib2to3/pgen2/parse.py")
 
     def test_pgen(self) -> None:
 
     def test_pgen(self) -> None:
-        self.checkSourceFile("blib2to3/pgen2/pgen.py")
+        self.checkSourceFile("src/blib2to3/pgen2/pgen.py")
 
     def test_tokenize(self) -> None:
 
     def test_tokenize(self) -> None:
-        self.checkSourceFile("blib2to3/pgen2/tokenize.py")
+        self.checkSourceFile("src/blib2to3/pgen2/tokenize.py")
 
     def test_token(self) -> None:
 
     def test_token(self) -> None:
-        self.checkSourceFile("blib2to3/pgen2/token.py")
+        self.checkSourceFile("src/blib2to3/pgen2/token.py")
 
     def test_setup(self) -> None:
         self.checkSourceFile("setup.py")
 
     def test_piping(self) -> None:
 
     def test_setup(self) -> None:
         self.checkSourceFile("setup.py")
 
     def test_piping(self) -> None:
-        source, expected = read_data("../black", data=False)
+        source, expected = read_data("src/black/__init__", data=False)
         result = BlackRunner().invoke(
             black.main,
             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
         result = BlackRunner().invoke(
             black.main,
             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
@@ -243,8 +257,8 @@ class BlackTestCase(unittest.TestCase):
 
     def test_piping_diff(self) -> None:
         diff_header = re.compile(
 
     def test_piping_diff(self) -> None:
         diff_header = re.compile(
-            rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
-            rf"\+\d\d\d\d"
+            r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d "
+            r"\+\d\d\d\d"
         )
         source, _ = read_data("expression.py")
         expected, _ = read_data("expression.diff")
         )
         source, _ = read_data("expression.py")
         expected, _ = read_data("expression.diff")
@@ -264,6 +278,28 @@ class BlackTestCase(unittest.TestCase):
         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
         self.assertEqual(expected, actual)
 
         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
         self.assertEqual(expected, actual)
 
+    def test_piping_diff_with_color(self) -> None:
+        source, _ = read_data("expression.py")
+        config = THIS_DIR / "data" / "empty_pyproject.toml"
+        args = [
+            "-",
+            "--fast",
+            f"--line-length={black.DEFAULT_LINE_LENGTH}",
+            "--diff",
+            "--color",
+            f"--config={config}",
+        ]
+        result = BlackRunner().invoke(
+            black.main, args, input=BytesIO(source.encode("utf8"))
+        )
+        actual = result.output
+        # Again, the contents are checked in a different test, so only look for colors.
+        self.assertIn("\033[1;37m", actual)
+        self.assertIn("\033[36m", actual)
+        self.assertIn("\033[32m", actual)
+        self.assertIn("\033[31m", actual)
+        self.assertIn("\033[0m", actual)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_function(self) -> None:
         source, expected = read_data("function")
     @patch("black.dump_to_file", dump_to_stderr)
     def test_function(self) -> None:
         source, expected = read_data("function")
@@ -352,6 +388,25 @@ class BlackTestCase(unittest.TestCase):
             )
             self.assertEqual(expected, actual, msg)
 
             )
             self.assertEqual(expected, actual, msg)
 
+    def test_expression_diff_with_color(self) -> None:
+        source, _ = read_data("expression.py")
+        expected, _ = read_data("expression.diff")
+        tmp_file = Path(black.dump_to_file(source))
+        try:
+            result = BlackRunner().invoke(
+                black.main, ["--diff", "--color", str(tmp_file)]
+            )
+        finally:
+            os.unlink(tmp_file)
+        actual = result.output
+        # We check the contents of the diff in `test_expression_diff`. All
+        # we need to check here is that color codes exist in the result.
+        self.assertIn("\033[1;37m", actual)
+        self.assertIn("\033[36m", actual)
+        self.assertIn("\033[32m", actual)
+        self.assertIn("\033[31m", actual)
+        self.assertIn("\033[0m", actual)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_fstring(self) -> None:
         source, expected = read_data("fstring")
     @patch("black.dump_to_file", dump_to_stderr)
     def test_fstring(self) -> None:
         source, expected = read_data("fstring")
@@ -1210,7 +1265,7 @@ class BlackTestCase(unittest.TestCase):
             with src.open("r") as fobj:
                 self.assertEqual(fobj.read(), "print('hello')")
 
             with src.open("r") as fobj:
                 self.assertEqual(fobj.read(), "print('hello')")
 
-    @event_loop(close=False)
+    @event_loop()
     def test_cache_multiple_files(self) -> None:
         mode = black.FileMode()
         with cache_dir() as workspace, patch(
     def test_cache_multiple_files(self) -> None:
         mode = black.FileMode()
         with cache_dir() as workspace, patch(
@@ -1290,7 +1345,7 @@ class BlackTestCase(unittest.TestCase):
             black.write_cache({}, [], mode)
             self.assertTrue(workspace.exists())
 
             black.write_cache({}, [], mode)
             self.assertTrue(workspace.exists())
 
-    @event_loop(close=False)
+    @event_loop()
     def test_failed_formatting_does_not_get_cached(self) -> None:
         mode = black.FileMode()
         with cache_dir() as workspace, patch(
     def test_failed_formatting_does_not_get_cached(self) -> None:
         mode = black.FileMode()
         with cache_dir() as workspace, patch(
@@ -1313,7 +1368,18 @@ class BlackTestCase(unittest.TestCase):
             mock.side_effect = OSError
             black.write_cache({}, [], mode)
 
             mock.side_effect = OSError
             black.write_cache({}, [], mode)
 
-    @event_loop(close=False)
+    @event_loop()
+    @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
+    def test_works_in_mono_process_only_environment(self) -> None:
+        with cache_dir() as workspace:
+            for f in [
+                (workspace / "one.py").resolve(),
+                (workspace / "two.py").resolve(),
+            ]:
+                f.write_text('print("hello")\n')
+            self.invokeBlack([str(workspace)])
+
+    @event_loop()
     def test_check_diff_use_together(self) -> None:
         with cache_dir():
             # Files which will be reformatted.
     def test_check_diff_use_together(self) -> None:
         with cache_dir():
             # Files which will be reformatted.
@@ -1376,7 +1442,7 @@ class BlackTestCase(unittest.TestCase):
             self.assertNotIn(path, normal_cache)
         self.assertEqual(actual, expected)
 
             self.assertNotIn(path, normal_cache)
         self.assertEqual(actual, expected)
 
-    @event_loop(close=False)
+    @event_loop()
     def test_multi_file_force_pyi(self) -> None:
         reg_mode = black.FileMode()
         pyi_mode = black.FileMode(is_pyi=True)
     def test_multi_file_force_pyi(self) -> None:
         reg_mode = black.FileMode()
         pyi_mode = black.FileMode(is_pyi=True)
@@ -1428,7 +1494,7 @@ class BlackTestCase(unittest.TestCase):
             self.assertNotIn(path, normal_cache)
         self.assertEqual(actual, expected)
 
             self.assertNotIn(path, normal_cache)
         self.assertEqual(actual, expected)
 
-    @event_loop(close=False)
+    @event_loop()
     def test_multi_file_force_py36(self) -> None:
         reg_mode = black.FileMode()
         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
     def test_multi_file_force_py36(self) -> None:
         reg_mode = black.FileMode()
         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
@@ -1484,8 +1550,8 @@ class BlackTestCase(unittest.TestCase):
         ]
         this_abs = THIS_DIR.resolve()
         sources.extend(
         ]
         this_abs = THIS_DIR.resolve()
         sources.extend(
-            black.gen_python_files_in_dir(
-                path, this_abs, include, exclude, report, gitignore
+            black.gen_python_files(
+                path.iterdir(), this_abs, include, [exclude], report, gitignore
             )
         )
         self.assertEqual(sorted(expected), sorted(sources))
             )
         )
         self.assertEqual(sorted(expected), sorted(sources))
@@ -1505,8 +1571,8 @@ class BlackTestCase(unittest.TestCase):
         ]
         this_abs = THIS_DIR.resolve()
         sources.extend(
         ]
         this_abs = THIS_DIR.resolve()
         sources.extend(
-            black.gen_python_files_in_dir(
-                path, this_abs, include, exclude, report, gitignore
+            black.gen_python_files(
+                path.iterdir(), this_abs, include, [exclude], report, gitignore
             )
         )
         self.assertEqual(sorted(expected), sorted(sources))
             )
         )
         self.assertEqual(sorted(expected), sorted(sources))
@@ -1530,11 +1596,11 @@ class BlackTestCase(unittest.TestCase):
         ]
         this_abs = THIS_DIR.resolve()
         sources.extend(
         ]
         this_abs = THIS_DIR.resolve()
         sources.extend(
-            black.gen_python_files_in_dir(
-                path,
+            black.gen_python_files(
+                path.iterdir(),
                 this_abs,
                 empty,
                 this_abs,
                 empty,
-                re.compile(black.DEFAULT_EXCLUDES),
+                [re.compile(black.DEFAULT_EXCLUDES)],
                 report,
                 gitignore,
             )
                 report,
                 gitignore,
             )
@@ -1557,11 +1623,11 @@ class BlackTestCase(unittest.TestCase):
         ]
         this_abs = THIS_DIR.resolve()
         sources.extend(
         ]
         this_abs = THIS_DIR.resolve()
         sources.extend(
-            black.gen_python_files_in_dir(
-                path,
+            black.gen_python_files(
+                path.iterdir(),
                 this_abs,
                 re.compile(black.DEFAULT_INCLUDES),
                 this_abs,
                 re.compile(black.DEFAULT_INCLUDES),
-                empty,
+                [empty],
                 report,
                 gitignore,
             )
                 report,
                 gitignore,
             )
@@ -1603,7 +1669,7 @@ class BlackTestCase(unittest.TestCase):
 
     def test_symlink_out_of_root_directory(self) -> None:
         path = MagicMock()
 
     def test_symlink_out_of_root_directory(self) -> None:
         path = MagicMock()
-        root = THIS_DIR
+        root = THIS_DIR.resolve()
         child = MagicMock()
         include = re.compile(black.DEFAULT_INCLUDES)
         exclude = re.compile(black.DEFAULT_EXCLUDES)
         child = MagicMock()
         include = re.compile(black.DEFAULT_INCLUDES)
         exclude = re.compile(black.DEFAULT_EXCLUDES)
@@ -1617,8 +1683,8 @@ class BlackTestCase(unittest.TestCase):
         child.is_symlink.return_value = True
         try:
             list(
         child.is_symlink.return_value = True
         try:
             list(
-                black.gen_python_files_in_dir(
-                    path, root, include, exclude, report, gitignore
+                black.gen_python_files(
+                    path.iterdir(), root, include, exclude, report, gitignore
                 )
             )
         except ValueError as ve:
                 )
             )
         except ValueError as ve:
@@ -1631,8 +1697,8 @@ class BlackTestCase(unittest.TestCase):
         child.is_symlink.return_value = False
         with self.assertRaises(ValueError):
             list(
         child.is_symlink.return_value = False
         with self.assertRaises(ValueError):
             list(
-                black.gen_python_files_in_dir(
-                    path, root, include, exclude, report, gitignore
+                black.gen_python_files(
+                    path.iterdir(), root, include, exclude, report, gitignore
                 )
             )
         path.iterdir.assert_called()
                 )
             )
         path.iterdir.assert_called()
@@ -1697,6 +1763,66 @@ class BlackTestCase(unittest.TestCase):
         finally:
             tmp_file.unlink()
 
         finally:
             tmp_file.unlink()
 
+    def test_parse_pyproject_toml(self) -> None:
+        test_toml_file = THIS_DIR / "test.toml"
+        config = black.parse_pyproject_toml(str(test_toml_file))
+        self.assertEqual(config["verbose"], 1)
+        self.assertEqual(config["check"], "no")
+        self.assertEqual(config["diff"], "y")
+        self.assertEqual(config["color"], True)
+        self.assertEqual(config["line_length"], 79)
+        self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
+        self.assertEqual(config["exclude"], r"\.pyi?$")
+        self.assertEqual(config["include"], r"\.py?$")
+
+    def test_read_pyproject_toml(self) -> None:
+        test_toml_file = THIS_DIR / "test.toml"
+
+        # Fake a click context and parameter so mypy stays happy
+        class FakeContext(click.Context):
+            def __init__(self) -> None:
+                self.default_map: Dict[str, Any] = {}
+
+        class FakeParameter(click.Parameter):
+            def __init__(self) -> None:
+                pass
+
+        fake_ctx = FakeContext()
+        black.read_pyproject_toml(
+            fake_ctx, FakeParameter(), str(test_toml_file),
+        )
+        config = fake_ctx.default_map
+        self.assertEqual(config["verbose"], "1")
+        self.assertEqual(config["check"], "no")
+        self.assertEqual(config["diff"], "y")
+        self.assertEqual(config["color"], "True")
+        self.assertEqual(config["line_length"], "79")
+        self.assertEqual(config["target_version"], ["py36", "py37", "py38"])
+        self.assertEqual(config["exclude"], r"\.pyi?$")
+        self.assertEqual(config["include"], r"\.py?$")
+
+    def test_find_project_root(self) -> None:
+        with TemporaryDirectory() as workspace:
+            root = Path(workspace)
+            test_dir = root / "test"
+            test_dir.mkdir()
+
+            src_dir = root / "src"
+            src_dir.mkdir()
+
+            root_pyproject = root / "pyproject.toml"
+            root_pyproject.touch()
+            src_pyproject = src_dir / "pyproject.toml"
+            src_pyproject.touch()
+            src_python = src_dir / "foo.py"
+            src_python.touch()
+
+            self.assertEqual(
+                black.find_project_root((src_dir, test_dir)), root.resolve()
+            )
+            self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
+            self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
+
 
 class BlackDTestCase(AioHTTPTestCase):
     async def get_application(self) -> web.Application:
 
 class BlackDTestCase(AioHTTPTestCase):
     async def get_application(self) -> web.Application:
@@ -1787,7 +1913,7 @@ class BlackDTestCase(AioHTTPTestCase):
     @unittest_run_loop
     async def test_blackd_diff(self) -> None:
         diff_header = re.compile(
     @unittest_run_loop
     async def test_blackd_diff(self) -> None:
         diff_header = re.compile(
-            rf"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
+            r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
         )
 
         source, _ = read_data("blackd_diff.py")
         )
 
         source, _ = read_data("blackd_diff.py")