X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/e4b4fb02b91e0f5a60a9678604653aecedff513b..a9eab85f226df3b3070aca122d089dbd62b42b9c:/tests/test_black.py

diff --git a/tests/test_black.py b/tests/test_black.py
index b8e526a..455cb33 100644
--- a/tests/test_black.py
+++ b/tests/test_black.py
@@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from dataclasses import replace
 import inspect
-from io import BytesIO, TextIOWrapper
+from io import BytesIO
 import os
 from pathlib import Path
 from platform import system
@@ -16,10 +16,8 @@ from tempfile import TemporaryDirectory
 import types
 from typing import (
     Any,
-    BinaryIO,
     Callable,
     Dict,
-    Generator,
     List,
     Iterator,
     TypeVar,
@@ -34,6 +32,11 @@ from click.testing import CliRunner
 
 import black
 from black import Feature, TargetVersion
+from black.cache import get_cache_file
+from black.debug import DebugVisitor
+from black.output import diff, color_diff
+from black.report import Report
+import black.files
 
 from pathspec import PathSpec
 
@@ -48,7 +51,6 @@ from tests.util import (
     ff,
     dump_to_stderr,
 )
-from .test_primer import PrimerCLITests  # noqa: F401
 
 
 THIS_FILE = Path(__file__)
@@ -62,6 +64,9 @@ PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERS
 T = TypeVar("T")
 R = TypeVar("R")
 
+# Match the time output in a diff, but nothing else ("-" escaped: not a range)
+DIFF_TIME = re.compile(r"\t[\d\-:+\. ]+")
+
 
 @contextmanager
 def cache_dir(exists: bool = True) -> Iterator[Path]:
@@ -69,7 +74,7 @@ def cache_dir(exists: bool = True) -> Iterator[Path]:
         cache_dir = Path(workspace)
         if not exists:
             cache_dir = cache_dir / "new"
-        with patch("black.CACHE_DIR", cache_dir):
+        with patch("black.cache.CACHE_DIR", cache_dir):
             yield cache_dir
 
 
@@ -100,28 +105,10 @@ class FakeParameter(click.Parameter):
 
 
 class BlackRunner(CliRunner):
-    """Modify CliRunner so that stderr is not merged with stdout.
-
-    This is a hack that can be removed once we depend on Click 7.x"""
+    """Make sure STDOUT and STDERR are kept separate when testing Black via its CLI."""
 
     def __init__(self) -> None:
-        self.stderrbuf = BytesIO()
-        self.stdoutbuf = BytesIO()
-        self.stdout_bytes = b""
-        self.stderr_bytes = b""
-        super().__init__()
-
-    @contextmanager
-    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
-        with super().isolation(*args, **kwargs) as output:
-            try:
-                hold_stderr = sys.stderr
-                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
-                yield output
-            finally:
-                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
-                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
-                sys.stderr = hold_stderr
+        super().__init__(mix_stderr=False)
 
 
 class BlackTestCase(BlackBaseTestCase):
@@ -137,8 +124,8 @@ class BlackTestCase(BlackBaseTestCase):
             exit_code,
             msg=(
                 f"Failed with args: {args}\n"
-                f"stdout: {runner.stdout_bytes.decode()!r}\n"
-                f"stderr: {runner.stderr_bytes.decode()!r}\n"
+                f"stdout: {result.stdout_bytes.decode()!r}\n"
+                f"stderr: {result.stderr_bytes.decode()!r}\n"
                 f"exception: {result.exception}"
             ),
         )
@@ -171,8 +158,9 @@ class BlackTestCase(BlackBaseTestCase):
         )
         self.assertEqual(result.exit_code, 0)
         self.assertFormatEqual(expected, result.output)
-        black.assert_equivalent(source, result.output)
-        black.assert_stable(source, result.output, DEFAULT_MODE)
+        if source != result.output:
+            black.assert_equivalent(source, result.output)
+            black.assert_stable(source, result.output, DEFAULT_MODE)
 
     def test_piping_diff(self) -> None:
         diff_header = re.compile(
@@ -478,7 +466,7 @@ class BlackTestCase(BlackBaseTestCase):
         finally:
             os.unlink(tmp_file)
         actual = (
-            runner.stderr_bytes.decode()
+            result.stderr_bytes.decode()
             .replace("\n", "")
             .replace("\\n", "")
             .replace("\\r", "")
@@ -582,7 +570,7 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertFormatEqual(contents_spc, fs(contents_tab))
 
     def test_report_verbose(self) -> None:
-        report = black.Report(verbose=True)
+        report = Report(verbose=True)
         out_lines = []
         err_lines = []
 
@@ -592,7 +580,7 @@ class BlackTestCase(BlackBaseTestCase):
         def err(msg: str, **kwargs: Any) -> None:
             err_lines.append(msg)
 
-        with patch("black.out", out), patch("black.err", err):
+        with patch("black.output._out", out), patch("black.output._err", err):
             report.done(Path("f1"), black.Changed.NO)
             self.assertEqual(len(out_lines), 1)
             self.assertEqual(len(err_lines), 0)
@@ -684,7 +672,7 @@ class BlackTestCase(BlackBaseTestCase):
             )
 
     def test_report_quiet(self) -> None:
-        report = black.Report(quiet=True)
+        report = Report(quiet=True)
         out_lines = []
         err_lines = []
 
@@ -694,7 +682,7 @@ class BlackTestCase(BlackBaseTestCase):
         def err(msg: str, **kwargs: Any) -> None:
             err_lines.append(msg)
 
-        with patch("black.out", out), patch("black.err", err):
+        with patch("black.output._out", out), patch("black.output._err", err):
             report.done(Path("f1"), black.Changed.NO)
             self.assertEqual(len(out_lines), 0)
             self.assertEqual(len(err_lines), 0)
@@ -788,7 +776,7 @@ class BlackTestCase(BlackBaseTestCase):
         def err(msg: str, **kwargs: Any) -> None:
             err_lines.append(msg)
 
-        with patch("black.out", out), patch("black.err", err):
+        with patch("black.output._out", out), patch("black.output._err", err):
             report.done(Path("f1"), black.Changed.NO)
             self.assertEqual(len(out_lines), 0)
             self.assertEqual(len(err_lines), 0)
@@ -1005,8 +993,8 @@ class BlackTestCase(BlackBaseTestCase):
         def err(msg: str, **kwargs: Any) -> None:
             err_lines.append(msg)
 
-        with patch("black.out", out), patch("black.err", err):
-            black.DebugVisitor.show(source)
+        with patch("black.debug.out", out):
+            DebugVisitor.show(source)
         actual = "\n".join(out_lines) + "\n"
         log_name = ""
         if expected != actual:
@@ -1054,7 +1042,7 @@ class BlackTestCase(BlackBaseTestCase):
         def err(msg: str, **kwargs: Any) -> None:
             err_lines.append(msg)
 
-        with patch("black.out", out), patch("black.err", err):
+        with patch("black.output._out", out), patch("black.output._err", err):
             with self.assertRaises(AssertionError):
                 self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]")
 
@@ -1066,7 +1054,7 @@ class BlackTestCase(BlackBaseTestCase):
     def test_cache_broken_file(self) -> None:
         mode = DEFAULT_MODE
         with cache_dir() as workspace:
-            cache_file = black.get_cache_file(mode)
+            cache_file = get_cache_file(mode)
             with cache_file.open("w") as fobj:
                 fobj.write("this is not a pickle")
             self.assertEqual(black.read_cache(mode), {})
@@ -1120,7 +1108,7 @@ class BlackTestCase(BlackBaseTestCase):
                 "black.write_cache"
             ) as write_cache:
                 self.invokeBlack([str(src), "--diff"])
-                cache_file = black.get_cache_file(mode)
+                cache_file = get_cache_file(mode)
                 self.assertFalse(cache_file.exists())
                 write_cache.assert_not_called()
                 read_cache.assert_not_called()
@@ -1135,7 +1123,7 @@ class BlackTestCase(BlackBaseTestCase):
                 "black.write_cache"
             ) as write_cache:
                 self.invokeBlack([str(src), "--diff", "--color"])
-                cache_file = black.get_cache_file(mode)
+                cache_file = get_cache_file(mode)
                 self.assertFalse(cache_file.exists())
                 write_cache.assert_not_called()
                 read_cache.assert_not_called()
@@ -1173,7 +1161,7 @@ class BlackTestCase(BlackBaseTestCase):
                 black.main, ["-"], input=BytesIO(b"print('hello')")
             )
             self.assertEqual(result.exit_code, 0)
-            cache_file = black.get_cache_file(mode)
+            cache_file = get_cache_file(mode)
             self.assertFalse(cache_file.exists())
 
     def test_read_cache_no_cachefile(self) -> None:
@@ -1422,7 +1410,7 @@ class BlackTestCase(BlackBaseTestCase):
         )
         self.assertEqual(sorted(expected), sorted(sources))
 
-    def test_gitingore_used_as_default(self) -> None:
+    def test_gitignore_used_as_default(self) -> None:
         path = Path(THIS_DIR / "data" / "include_exclude_tests")
         include = re.compile(r"\.pyi?$")
         extend_exclude = re.compile(r"/exclude/")
@@ -1524,7 +1512,7 @@ class BlackTestCase(BlackBaseTestCase):
 
     @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
     def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
-        # Exclude shouldn't exclude stdin_filename since it is mimicing the
+        # Exclude shouldn't exclude stdin_filename since it is mimicking the
         # file being passed directly. This is the same as
         # test_exclude_for_issue_1572
         path = THIS_DIR / "data" / "include_exclude_tests"
@@ -1719,6 +1707,33 @@ class BlackTestCase(BlackBaseTestCase):
         )
         self.assertEqual(sorted(expected), sorted(sources))
 
+    def test_nested_gitignore(self) -> None:
+        path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
+        include = re.compile(r"\.pyi?$")
+        exclude = re.compile(r"")
+        root_gitignore = black.files.get_gitignore(path)
+        report = black.Report()
+        expected: List[Path] = [
+            Path(path / "x.py"),
+            Path(path / "root/b.py"),
+            Path(path / "root/c.py"),
+            Path(path / "root/child/c.py"),
+        ]
+        this_abs = THIS_DIR.resolve()
+        sources = list(
+            black.gen_python_files(
+                path.iterdir(),
+                this_abs,
+                include,
+                exclude,
+                None,
+                None,
+                report,
+                root_gitignore,
+            )
+        )
+        self.assertEqual(sorted(expected), sorted(sources))
+
     def test_empty_include(self) -> None:
         path = THIS_DIR / "data" / "include_exclude_tests"
         report = black.Report()
@@ -1781,6 +1796,16 @@ class BlackTestCase(BlackBaseTestCase):
         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
 
+    def test_required_version_matches_version(self) -> None:
+        self.invokeBlack(
+            ["--required-version", black.__version__], exit_code=0, ignore_config=True
+        )
+
+    def test_required_version_does_not_match_version(self) -> None:
+        self.invokeBlack(
+            ["--required-version", "20.99b"], exit_code=1, ignore_config=True
+        )
+
     def test_preserves_line_endings(self) -> None:
         with TemporaryDirectory() as workspace:
             test_file = Path(workspace) / "test.py"
@@ -1801,7 +1826,7 @@ class BlackTestCase(BlackBaseTestCase):
                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
             )
             self.assertEqual(result.exit_code, 0)
-            output = runner.stdout_bytes
+            output = result.stdout_bytes
             self.assertIn(nl.encode("utf8"), output)
             if nl == "\n":
                 self.assertNotIn(b"\r\n", output)
@@ -1900,7 +1925,7 @@ class BlackTestCase(BlackBaseTestCase):
             critical=fail,
             log=fail,
         ):
-            ff(THIS_FILE)
+            ff(THIS_DIR / "util.py")
 
     def test_invalid_config_return_code(self) -> None:
         tmp_file = Path(black.dump_to_file())
@@ -1960,7 +1985,10 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
             self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
 
-    @patch("black.find_user_pyproject_toml", black.find_user_pyproject_toml.__wrapped__)
+    @patch(
+        "black.files.find_user_pyproject_toml",
+        black.files.find_user_pyproject_toml.__wrapped__,
+    )
     def test_find_user_pyproject_toml_linux(self) -> None:
         if system() == "Windows":
             return
@@ -1970,7 +1998,7 @@ class BlackTestCase(BlackBaseTestCase):
             tmp_user_config = Path(workspace) / "black"
             with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}):
                 self.assertEqual(
-                    black.find_user_pyproject_toml(), tmp_user_config.resolve()
+                    black.files.find_user_pyproject_toml(), tmp_user_config.resolve()
                 )
 
         # Test fallback for XDG_CONFIG_HOME
@@ -1978,7 +2006,7 @@ class BlackTestCase(BlackBaseTestCase):
             os.environ.pop("XDG_CONFIG_HOME", None)
             fallback_user_config = Path("~/.config").expanduser() / "black"
             self.assertEqual(
-                black.find_user_pyproject_toml(), fallback_user_config.resolve()
+                black.files.find_user_pyproject_toml(), fallback_user_config.resolve()
             )
 
     def test_find_user_pyproject_toml_windows(self) -> None:
@@ -1986,7 +2014,9 @@ class BlackTestCase(BlackBaseTestCase):
             return
 
         user_config_path = Path.home() / ".black"
-        self.assertEqual(black.find_user_pyproject_toml(), user_config_path.resolve())
+        self.assertEqual(
+            black.files.find_user_pyproject_toml(), user_config_path.resolve()
+        )
 
     def test_bpo_33660_workaround(self) -> None:
         if system() == "Windows":
@@ -2053,6 +2083,146 @@ class BlackTestCase(BlackBaseTestCase):
         actual = result.output
         self.assertFormatEqual(actual, expected)
 
+    @staticmethod
+    def compare_results(
+        result: click.testing.Result, expected_value: str, expected_exit_code: int
+    ) -> None:
+        """Helper method to test the value and exit code of a click Result."""
+        assert (
+            result.output == expected_value
+        ), "The output did not match the expected value."
+        assert result.exit_code == expected_exit_code, "The exit code is incorrect."
+
+    def test_code_option(self) -> None:
+        """Test the code option with no changes."""
+        code = 'print("Hello world")\n'
+        args = ["--code", code]
+        result = CliRunner().invoke(black.main, args)
+
+        self.compare_results(result, code, 0)
+
+    def test_code_option_changed(self) -> None:
+        """Test the code option when changes are required."""
+        code = "print('hello world')"
+        formatted = black.format_str(code, mode=DEFAULT_MODE)
+
+        args = ["--code", code]
+        result = CliRunner().invoke(black.main, args)
+
+        self.compare_results(result, formatted, 0)
+
+    def test_code_option_check(self) -> None:
+        """Test the code option when check is passed."""
+        args = ["--check", "--code", 'print("Hello world")\n']
+        result = CliRunner().invoke(black.main, args)
+        self.compare_results(result, "", 0)
+
+    def test_code_option_check_changed(self) -> None:
+        """Test the code option when changes are required, and check is passed."""
+        args = ["--check", "--code", "print('hello world')"]
+        result = CliRunner().invoke(black.main, args)
+        self.compare_results(result, "", 1)
+
+    def test_code_option_diff(self) -> None:
+        """Test the code option when diff is passed."""
+        code = "print('hello world')"
+        formatted = black.format_str(code, mode=DEFAULT_MODE)
+        result_diff = diff(code, formatted, "STDIN", "STDOUT")
+
+        args = ["--diff", "--code", code]
+        result = CliRunner().invoke(black.main, args)
+
+        # Remove time from diff
+        output = DIFF_TIME.sub("", result.output)
+
+        assert output == result_diff, "The output did not match the expected value."
+        assert result.exit_code == 0, "The exit code is incorrect."
+
+    def test_code_option_color_diff(self) -> None:
+        """Test the code option when color and diff are passed."""
+        code = "print('hello world')"
+        formatted = black.format_str(code, mode=DEFAULT_MODE)
+
+        result_diff = diff(code, formatted, "STDIN", "STDOUT")
+        result_diff = color_diff(result_diff)
+
+        args = ["--diff", "--color", "--code", code]
+        result = CliRunner().invoke(black.main, args)
+
+        # Remove time from diff
+        output = DIFF_TIME.sub("", result.output)
+
+        assert output == result_diff, "The output did not match the expected value."
+        assert result.exit_code == 0, "The exit code is incorrect."
+
+    def test_code_option_safe(self) -> None:
+        """Test that the code option throws an error when the sanity checks fail."""
+        # Patch black.assert_equivalent to ensure the sanity checks fail
+        with patch.object(black, "assert_equivalent", side_effect=AssertionError):
+            code = 'print("Hello world")'
+            error_msg = f"{code}\nerror: cannot format <string>: \n"
+
+            args = ["--safe", "--code", code]
+            result = CliRunner().invoke(black.main, args)
+
+            self.compare_results(result, error_msg, 123)
+
+    def test_code_option_fast(self) -> None:
+        """Test that the code option ignores errors when the sanity checks fail."""
+        # Patch black.assert_equivalent to ensure the sanity checks fail
+        with patch.object(black, "assert_equivalent", side_effect=AssertionError):
+            code = 'print("Hello world")'
+            formatted = black.format_str(code, mode=DEFAULT_MODE)
+
+            args = ["--fast", "--code", code]
+            result = CliRunner().invoke(black.main, args)
+
+            self.compare_results(result, formatted, 0)
+
+    def test_code_option_config(self) -> None:
+        """Test that the code option finds pyproject.toml in the current directory."""
+        original_cwd = os.getcwd()
+        with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
+            try:
+                # Make sure we are in the project root with the pyproject file
+                if not Path("tests").exists():
+                    os.chdir("..")
+                CliRunner().invoke(black.main, ["--code", "print"])
+                pyproject_path = Path(Path.cwd(), "pyproject.toml").resolve()
+            finally:
+                # Restore the cwd so this test cannot leak a chdir into others.
+                os.chdir(original_cwd)
+
+            assert (
+                len(parse.mock_calls) >= 1
+            ), "Expected config parse to be called with the current directory."
+            _, call_args, _ = parse.mock_calls[0]
+            assert (
+                call_args[0].lower() == str(pyproject_path).lower()
+            ), "Incorrect config loaded."
+
+    def test_code_option_parent_config(self) -> None:
+        """Test that the code option finds pyproject.toml in the parent directory."""
+        original_cwd = os.getcwd()
+        with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
+            try:
+                # Make sure we are in the tests directory
+                if Path("tests").exists():
+                    os.chdir("tests")
+                CliRunner().invoke(black.main, ["--code", "print"])
+                pyproject_path = Path(Path.cwd().parent, "pyproject.toml").resolve()
+            finally:
+                # Restore the cwd so this test cannot leak a chdir into others.
+                os.chdir(original_cwd)
+
+            assert (
+                len(parse.mock_calls) >= 1
+            ), "Expected config parse to be called with the parent directory."
+            _, call_args, _ = parse.mock_calls[0]
+            assert (
+                call_args[0].lower() == str(pyproject_path).lower()
+            ), "Incorrect config loaded."
+
 
 with open(black.__file__, "r", encoding="utf-8") as _bf:
     black_source_lines = _bf.readlines()