git.madduck.net Git - etc/vim.git/blobdiff - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send patches to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Prepare CHANGES.md for release 21.8b0 (#2458)
[etc/vim.git] / tests / test_black.py
index 5ab25cd160158d094b0122bdfd6beba60a9555fc..398a528bee9a4684fd52355597fcbc1f3bf0d3c3 100644 (file)
@@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from dataclasses import replace
 import inspect
 from contextlib import contextmanager
 from dataclasses import replace
 import inspect
+import io
 from io import BytesIO
 import os
 from pathlib import Path
 from io import BytesIO
 import os
 from pathlib import Path
@@ -25,6 +26,7 @@ from typing import (
 import pytest
 import unittest
 from unittest.mock import patch, MagicMock
 import pytest
 import unittest
 from unittest.mock import patch, MagicMock
+from parameterized import parameterized
 
 import click
 from click import unstyle
 
 import click
 from click import unstyle
@@ -34,6 +36,7 @@ import black
 from black import Feature, TargetVersion
 from black.cache import get_cache_file
 from black.debug import DebugVisitor
 from black import Feature, TargetVersion
 from black.cache import get_cache_file
 from black.debug import DebugVisitor
+from black.output import diff, color_diff
 from black.report import Report
 import black.files
 
 from black.report import Report
 import black.files
 
@@ -42,6 +45,7 @@ from pathspec import PathSpec
 # Import other test classes
 from tests.util import (
     THIS_DIR,
 # Import other test classes
 from tests.util import (
     THIS_DIR,
+    change_directory,
     read_data,
     DETERMINISTIC_HEADER,
     BlackBaseTestCase,
     read_data,
     DETERMINISTIC_HEADER,
     BlackBaseTestCase,
@@ -63,6 +67,9 @@ PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERS
 T = TypeVar("T")
 R = TypeVar("R")
 
 T = TypeVar("T")
 R = TypeVar("R")
 
+# Match the time output in a diff, but nothing else
+DIFF_TIME = re.compile(r"\t[\d-:+\. ]+")
+
 
 @contextmanager
 def cache_dir(exists: bool = True) -> Iterator[Path]:
 
 @contextmanager
 def cache_dir(exists: bool = True) -> Iterator[Path]:
@@ -115,6 +122,8 @@ class BlackTestCase(BlackBaseTestCase):
         if ignore_config:
             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
         result = runner.invoke(black.main, args)
         if ignore_config:
             args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
         result = runner.invoke(black.main, args)
+        assert result.stdout_bytes is not None
+        assert result.stderr_bytes is not None
         self.assertEqual(
             result.exit_code,
             exit_code,
         self.assertEqual(
             result.exit_code,
             exit_code,
@@ -291,6 +300,14 @@ class BlackTestCase(BlackBaseTestCase):
         versions = black.detect_target_versions(root)
         self.assertIn(black.TargetVersion.PY38, versions)
 
         versions = black.detect_target_versions(root)
         self.assertIn(black.TargetVersion.PY38, versions)
 
+    @parameterized.expand([(3, 9), (3, 10)])
+    def test_pep_572_newer_syntax(self, major: int, minor: int) -> None:
+        source, expected = read_data(f"pep_572_py{major}{minor}")
+        actual = fs(source, mode=DEFAULT_MODE)
+        self.assertFormatEqual(expected, actual)
+        if sys.version_info >= (major, minor):
+            black.assert_equivalent(source, actual)
+
     def test_expression_ff(self) -> None:
         source, expected = read_data("expression")
         tmp_file = Path(black.dump_to_file(source))
     def test_expression_ff(self) -> None:
         source, expected = read_data("expression")
         tmp_file = Path(black.dump_to_file(source))
@@ -444,37 +461,6 @@ class BlackTestCase(BlackBaseTestCase):
             )
             self.assertEqual(expected, actual, msg)
 
             )
             self.assertEqual(expected, actual, msg)
 
-    @pytest.mark.no_python2
-    def test_python2_should_fail_without_optional_install(self) -> None:
-        if sys.version_info < (3, 8):
-            self.skipTest(
-                "Python 3.6 and 3.7 will install typed-ast to work and as such will be"
-                " able to parse Python 2 syntax without explicitly specifying the"
-                " python2 extra"
-            )
-
-        source = "x = 1234l"
-        tmp_file = Path(black.dump_to_file(source))
-        try:
-            runner = BlackRunner()
-            result = runner.invoke(black.main, [str(tmp_file)])
-            self.assertEqual(result.exit_code, 123)
-        finally:
-            os.unlink(tmp_file)
-        actual = (
-            result.stderr_bytes.decode()
-            .replace("\n", "")
-            .replace("\\n", "")
-            .replace("\\r", "")
-            .replace("\r", "")
-        )
-        msg = (
-            "The requested source code has invalid Python 3 syntax."
-            "If you are trying to format Python 2 files please reinstall Black"
-            " with the 'python2' extra: `python3 -m pip install black[python2]`."
-        )
-        self.assertIn(msg, actual)
-
     @pytest.mark.python2
     @patch("black.dump_to_file", dump_to_stderr)
     def test_python2_print_function(self) -> None:
     @pytest.mark.python2
     @patch("black.dump_to_file", dump_to_stderr)
     def test_python2_print_function(self) -> None:
@@ -1402,11 +1388,13 @@ class BlackTestCase(BlackBaseTestCase):
                 None,
                 report,
                 gitignore,
                 None,
                 report,
                 gitignore,
+                verbose=False,
+                quiet=False,
             )
         )
         self.assertEqual(sorted(expected), sorted(sources))
 
             )
         )
         self.assertEqual(sorted(expected), sorted(sources))
 
-    def test_gitingore_used_as_default(self) -> None:
+    def test_gitignore_used_as_default(self) -> None:
         path = Path(THIS_DIR / "data" / "include_exclude_tests")
         include = re.compile(r"\.pyi?$")
         extend_exclude = re.compile(r"/exclude/")
         path = Path(THIS_DIR / "data" / "include_exclude_tests")
         include = re.compile(r"\.pyi?$")
         extend_exclude = re.compile(r"/exclude/")
@@ -1651,6 +1639,30 @@ class BlackTestCase(BlackBaseTestCase):
             # __BLACK_STDIN_FILENAME__ should have been stripped
             report.done.assert_called_with(expected, black.Changed.YES)
 
             # __BLACK_STDIN_FILENAME__ should have been stripped
             report.done.assert_called_with(expected, black.Changed.YES)
 
+    def test_reformat_one_with_stdin_filename_ipynb(self) -> None:
+        with patch(
+            "black.format_stdin_to_stdout",
+            return_value=lambda *args, **kwargs: black.Changed.YES,
+        ) as fsts:
+            report = MagicMock()
+            p = "foo.ipynb"
+            path = Path(f"__BLACK_STDIN_FILENAME__{p}")
+            expected = Path(p)
+            black.reformat_one(
+                path,
+                fast=True,
+                write_back=black.WriteBack.YES,
+                mode=DEFAULT_MODE,
+                report=report,
+            )
+            fsts.assert_called_once_with(
+                fast=True,
+                write_back=black.WriteBack.YES,
+                mode=replace(DEFAULT_MODE, is_ipynb=True),
+            )
+            # __BLACK_STDIN_FILENAME__ should have been stripped
+            report.done.assert_called_with(expected, black.Changed.YES)
+
     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
     def test_reformat_one_with_stdin_and_existing_path(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
@@ -1675,6 +1687,20 @@ class BlackTestCase(BlackBaseTestCase):
             # __BLACK_STDIN_FILENAME__ should have been stripped
             report.done.assert_called_with(expected, black.Changed.YES)
 
             # __BLACK_STDIN_FILENAME__ should have been stripped
             report.done.assert_called_with(expected, black.Changed.YES)
 
+    def test_reformat_one_with_stdin_empty(self) -> None:
+        output = io.StringIO()
+        with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
+            try:
+                black.format_stdin_to_stdout(
+                    fast=True,
+                    content="",
+                    write_back=black.WriteBack.YES,
+                    mode=DEFAULT_MODE,
+                )
+            except io.UnsupportedOperation:
+                pass  # StringIO does not support detach
+            assert output.getvalue() == ""
+
     def test_gitignore_exclude(self) -> None:
         path = THIS_DIR / "data" / "include_exclude_tests"
         include = re.compile(r"\.pyi?$")
     def test_gitignore_exclude(self) -> None:
         path = THIS_DIR / "data" / "include_exclude_tests"
         include = re.compile(r"\.pyi?$")
@@ -1699,10 +1725,65 @@ class BlackTestCase(BlackBaseTestCase):
                 None,
                 report,
                 gitignore,
                 None,
                 report,
                 gitignore,
+                verbose=False,
+                quiet=False,
             )
         )
         self.assertEqual(sorted(expected), sorted(sources))
 
             )
         )
         self.assertEqual(sorted(expected), sorted(sources))
 
+    def test_nested_gitignore(self) -> None:
+        path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
+        include = re.compile(r"\.pyi?$")
+        exclude = re.compile(r"")
+        root_gitignore = black.files.get_gitignore(path)
+        report = black.Report()
+        expected: List[Path] = [
+            Path(path / "x.py"),
+            Path(path / "root/b.py"),
+            Path(path / "root/c.py"),
+            Path(path / "root/child/c.py"),
+        ]
+        this_abs = THIS_DIR.resolve()
+        sources = list(
+            black.gen_python_files(
+                path.iterdir(),
+                this_abs,
+                include,
+                exclude,
+                None,
+                None,
+                report,
+                root_gitignore,
+                verbose=False,
+                quiet=False,
+            )
+        )
+        self.assertEqual(sorted(expected), sorted(sources))
+
+    def test_invalid_gitignore(self) -> None:
+        path = THIS_DIR / "data" / "invalid_gitignore_tests"
+        empty_config = path / "pyproject.toml"
+        result = BlackRunner().invoke(
+            black.main, ["--verbose", "--config", str(empty_config), str(path)]
+        )
+        assert result.exit_code == 1
+        assert result.stderr_bytes is not None
+
+        gitignore = path / ".gitignore"
+        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
+
+    def test_invalid_nested_gitignore(self) -> None:
+        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
+        empty_config = path / "pyproject.toml"
+        result = BlackRunner().invoke(
+            black.main, ["--verbose", "--config", str(empty_config), str(path)]
+        )
+        assert result.exit_code == 1
+        assert result.stderr_bytes is not None
+
+        gitignore = path / "a" / ".gitignore"
+        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
+
     def test_empty_include(self) -> None:
         path = THIS_DIR / "data" / "include_exclude_tests"
         report = black.Report()
     def test_empty_include(self) -> None:
         path = THIS_DIR / "data" / "include_exclude_tests"
         report = black.Report()
@@ -1733,6 +1814,8 @@ class BlackTestCase(BlackBaseTestCase):
                 None,
                 report,
                 gitignore,
                 None,
                 report,
                 gitignore,
+                verbose=False,
+                quiet=False,
             )
         )
         self.assertEqual(sorted(expected), sorted(sources))
             )
         )
         self.assertEqual(sorted(expected), sorted(sources))
@@ -1757,6 +1840,8 @@ class BlackTestCase(BlackBaseTestCase):
                 None,
                 report,
                 gitignore,
                 None,
                 report,
                 gitignore,
+                verbose=False,
+                quiet=False,
             )
         )
         self.assertEqual(sorted(expected), sorted(sources))
             )
         )
         self.assertEqual(sorted(expected), sorted(sources))
@@ -1765,6 +1850,16 @@ class BlackTestCase(BlackBaseTestCase):
         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
 
         for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
 
+    def test_required_version_matches_version(self) -> None:
+        self.invokeBlack(
+            ["--required-version", black.__version__], exit_code=0, ignore_config=True
+        )
+
+    def test_required_version_does_not_match_version(self) -> None:
+        self.invokeBlack(
+            ["--required-version", "20.99b"], exit_code=1, ignore_config=True
+        )
+
     def test_preserves_line_endings(self) -> None:
         with TemporaryDirectory() as workspace:
             test_file = Path(workspace) / "test.py"
     def test_preserves_line_endings(self) -> None:
         with TemporaryDirectory() as workspace:
             test_file = Path(workspace) / "test.py"
@@ -1819,6 +1914,8 @@ class BlackTestCase(BlackBaseTestCase):
                     None,
                     report,
                     gitignore,
                     None,
                     report,
                     gitignore,
+                    verbose=False,
+                    quiet=False,
                 )
             )
         except ValueError as ve:
                 )
             )
         except ValueError as ve:
@@ -1840,6 +1937,8 @@ class BlackTestCase(BlackBaseTestCase):
                     None,
                     report,
                     gitignore,
                     None,
                     report,
                     gitignore,
+                    verbose=False,
+                    quiet=False,
                 )
             )
         path.iterdir.assert_called()
                 )
             )
         path.iterdir.assert_called()
@@ -1851,7 +1950,7 @@ class BlackTestCase(BlackBaseTestCase):
 
     def test_shhh_click(self) -> None:
         try:
 
     def test_shhh_click(self) -> None:
         try:
-            from click import _unicodefun  # type: ignore
+            from click import _unicodefun
         except ModuleNotFoundError:
             self.skipTest("Incompatible Click version")
         if not hasattr(_unicodefun, "_verify_python3_env"):
         except ModuleNotFoundError:
             self.skipTest("Incompatible Click version")
         if not hasattr(_unicodefun, "_verify_python3_env"):
@@ -1860,14 +1959,14 @@ class BlackTestCase(BlackBaseTestCase):
         with patch("locale.getpreferredencoding") as gpe:
             gpe.return_value = "ASCII"
             with self.assertRaises(RuntimeError):
         with patch("locale.getpreferredencoding") as gpe:
             gpe.return_value = "ASCII"
             with self.assertRaises(RuntimeError):
-                _unicodefun._verify_python3_env()
+                _unicodefun._verify_python3_env()  # type: ignore
         # Now, let's silence Click...
         black.patch_click()
         # ...and confirm it's silent.
         with patch("locale.getpreferredencoding") as gpe:
             gpe.return_value = "ASCII"
             try:
         # Now, let's silence Click...
         black.patch_click()
         # ...and confirm it's silent.
         with patch("locale.getpreferredencoding") as gpe:
             gpe.return_value = "ASCII"
             try:
-                _unicodefun._verify_python3_env()
+                _unicodefun._verify_python3_env()  # type: ignore
             except RuntimeError as re:
                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
 
             except RuntimeError as re:
                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
 
@@ -1982,17 +2081,12 @@ class BlackTestCase(BlackBaseTestCase):
             return
 
         # https://bugs.python.org/issue33660
             return
 
         # https://bugs.python.org/issue33660
-
-        old_cwd = Path.cwd()
-        try:
-            root = Path("/")
-            os.chdir(str(root))
+        root = Path("/")
+        with change_directory(root):
             path = Path("workspace") / "project"
             report = black.Report(verbose=True)
             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
             self.assertEqual(normalized_path, "workspace/project")
             path = Path("workspace") / "project"
             report = black.Report(verbose=True)
             normalized_path = black.normalize_path_maybe_ignore(path, root, report)
             self.assertEqual(normalized_path, "workspace/project")
-        finally:
-            os.chdir(str(old_cwd))
 
     def test_newline_comment_interaction(self) -> None:
         source = "class A:\\\r\n# type: ignore\n pass\n"
 
     def test_newline_comment_interaction(self) -> None:
         source = "class A:\\\r\n# type: ignore\n pass\n"
@@ -2042,6 +2136,139 @@ class BlackTestCase(BlackBaseTestCase):
         actual = result.output
         self.assertFormatEqual(actual, expected)
 
         actual = result.output
         self.assertFormatEqual(actual, expected)
 
+    @staticmethod
+    def compare_results(
+        result: click.testing.Result, expected_value: str, expected_exit_code: int
+    ) -> None:
+        """Helper method to test the value and exit code of a click Result."""
+        assert (
+            result.output == expected_value
+        ), "The output did not match the expected value."
+        assert result.exit_code == expected_exit_code, "The exit code is incorrect."
+
+    def test_code_option(self) -> None:
+        """Test the code option with no changes."""
+        code = 'print("Hello world")\n'
+        args = ["--code", code]
+        result = CliRunner().invoke(black.main, args)
+
+        self.compare_results(result, code, 0)
+
+    def test_code_option_changed(self) -> None:
+        """Test the code option when changes are required."""
+        code = "print('hello world')"
+        formatted = black.format_str(code, mode=DEFAULT_MODE)
+
+        args = ["--code", code]
+        result = CliRunner().invoke(black.main, args)
+
+        self.compare_results(result, formatted, 0)
+
+    def test_code_option_check(self) -> None:
+        """Test the code option when check is passed."""
+        args = ["--check", "--code", 'print("Hello world")\n']
+        result = CliRunner().invoke(black.main, args)
+        self.compare_results(result, "", 0)
+
+    def test_code_option_check_changed(self) -> None:
+        """Test the code option when changes are required, and check is passed."""
+        args = ["--check", "--code", "print('hello world')"]
+        result = CliRunner().invoke(black.main, args)
+        self.compare_results(result, "", 1)
+
+    def test_code_option_diff(self) -> None:
+        """Test the code option when diff is passed."""
+        code = "print('hello world')"
+        formatted = black.format_str(code, mode=DEFAULT_MODE)
+        result_diff = diff(code, formatted, "STDIN", "STDOUT")
+
+        args = ["--diff", "--code", code]
+        result = CliRunner().invoke(black.main, args)
+
+        # Remove time from diff
+        output = DIFF_TIME.sub("", result.output)
+
+        assert output == result_diff, "The output did not match the expected value."
+        assert result.exit_code == 0, "The exit code is incorrect."
+
+    def test_code_option_color_diff(self) -> None:
+        """Test the code option when color and diff are passed."""
+        code = "print('hello world')"
+        formatted = black.format_str(code, mode=DEFAULT_MODE)
+
+        result_diff = diff(code, formatted, "STDIN", "STDOUT")
+        result_diff = color_diff(result_diff)
+
+        args = ["--diff", "--color", "--code", code]
+        result = CliRunner().invoke(black.main, args)
+
+        # Remove time from diff
+        output = DIFF_TIME.sub("", result.output)
+
+        assert output == result_diff, "The output did not match the expected value."
+        assert result.exit_code == 0, "The exit code is incorrect."
+
+    def test_code_option_safe(self) -> None:
+        """Test that the code option throws an error when the sanity checks fail."""
+        # Patch black.assert_equivalent to ensure the sanity checks fail
+        with patch.object(black, "assert_equivalent", side_effect=AssertionError):
+            code = 'print("Hello world")'
+            error_msg = f"{code}\nerror: cannot format <string>: \n"
+
+            args = ["--safe", "--code", code]
+            result = CliRunner().invoke(black.main, args)
+
+            self.compare_results(result, error_msg, 123)
+
+    def test_code_option_fast(self) -> None:
+        """Test that the code option ignores errors when the sanity checks fail."""
+        # Patch black.assert_equivalent to ensure the sanity checks fail
+        with patch.object(black, "assert_equivalent", side_effect=AssertionError):
+            code = 'print("Hello world")'
+            formatted = black.format_str(code, mode=DEFAULT_MODE)
+
+            args = ["--fast", "--code", code]
+            result = CliRunner().invoke(black.main, args)
+
+            self.compare_results(result, formatted, 0)
+
+    def test_code_option_config(self) -> None:
+        """
+        Test that the code option finds the pyproject.toml in the current directory.
+        """
+        with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
+            args = ["--code", "print"]
+            CliRunner().invoke(black.main, args)
+
+            pyproject_path = Path(Path().cwd(), "pyproject.toml").resolve()
+            assert (
+                len(parse.mock_calls) >= 1
+            ), "Expected config parse to be called with the current directory."
+
+            _, call_args, _ = parse.mock_calls[0]
+            assert (
+                call_args[0].lower() == str(pyproject_path).lower()
+            ), "Incorrect config loaded."
+
+    def test_code_option_parent_config(self) -> None:
+        """
+        Test that the code option finds the pyproject.toml in the parent directory.
+        """
+        with patch.object(black, "parse_pyproject_toml", return_value={}) as parse:
+            with change_directory(Path("tests")):
+                args = ["--code", "print"]
+                CliRunner().invoke(black.main, args)
+
+                pyproject_path = Path(Path().cwd().parent, "pyproject.toml").resolve()
+                assert (
+                    len(parse.mock_calls) >= 1
+                ), "Expected config parse to be called with the current directory."
+
+                _, call_args, _ = parse.mock_calls[0]
+                assert (
+                    call_args[0].lower() == str(pyproject_path).lower()
+                ), "Incorrect config loaded."
+
 
 with open(black.__file__, "r", encoding="utf-8") as _bf:
     black_source_lines = _bf.readlines()
 
 with open(black.__file__, "r", encoding="utf-8") as _bf:
     black_source_lines = _bf.readlines()