madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before sending them with git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:
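With that snippet in place, git rewrites any URL that begins with `madduck:` by swapping that prefix for the configured base URL. The substitution can be sketched in Python as below. This models only the rule documented for `url.<base>.insteadOf` (when several prefixes match, the longest one wins); it is not git's implementation, and `apply_instead_of` is a hypothetical name.

```python
from typing import Dict


def apply_instead_of(url: str, rules: Dict[str, str]) -> str:
    """Rewrite `url` the way url.<base>.insteadOf rules do.

    `rules` maps each insteadOf prefix to the base URL that replaces it.
    If several prefixes match, the longest one is used.
    """
    matches = [prefix for prefix in rules if url.startswith(prefix)]
    if not matches:
        return url  # no rule applies; the URL is used unchanged
    best = max(matches, key=len)
    return rules[best] + url[len(best):]


# The ~/.gitconfig snippet above, expressed as a single rule.
rules = {"madduck:": "git://git.madduck.net/madduck/"}
print(apply_instead_of("madduck:pub/<projectpath>", rules))
# git://git.madduck.net/madduck/pub/<projectpath>
```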

MNT: remove unnecessary test deps + some refactoring (GH-2510)
author    Richard Si <63936253+ichard26@users.noreply.github.com>
          Sat, 2 Oct 2021 23:37:32 +0000 (19:37 -0400)
committer GitHub <noreply@github.com>
          Sat, 2 Oct 2021 23:37:32 +0000 (19:37 -0400)

The main goals of this commit include:

* improving consistency in how strict the test suite is -- Jelle has
  seen cases where a test did not fail due to an incomplete test setup
  even though it should have
* simplifying tests, making them easier to both write and read, via
  parametrization and helpers
* reorganizing the test suite by grouping related tests together
* dropping test suite dependencies that aren't strictly necessary

The test suite could definitely do with more refactoring, but this is a
good first pass. Anyway, this PR would have grown too big to review
effectively had I continued.

Commit history before squash merge:

* Drop parameterized dep and refactor format tests

Since the test suite is already using pytest-only features we can drop
the parameterized test dependency in favour of pytest's own offering.
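For reference, this is roughly how the removed `@parameterized.expand` decorator on test_pep_572_newer_syntax maps onto pytest's built-in `pytest.mark.parametrize`. The test body here is a hypothetical placeholder so the snippet runs without Black; the real test reads data files and formats them.

```python
import pytest

# Before, with the third-party `parameterized` package on a unittest method:
#
#     @parameterized.expand([(3, 9), (3, 10)])
#     def test_pep_572_newer_syntax(self, major: int, minor: int) -> None: ...
#
# After, with pytest's own parametrization (one test is generated per tuple):
@pytest.mark.parametrize("major, minor", [(3, 9), (3, 10)])
def test_pep_572_newer_syntax(major: int, minor: int) -> None:
    # Placeholder body: derive the data-file name the real test would load.
    data_name = f"pep_572_py{major}{minor}"
    assert data_name in {"pep_572_py39", "pep_572_py310"}
```

Since the suite already relied on pytest-only features, this removes a dependency without changing how the cases are expressed.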

I also added a utility function called assert_format that makes it
even easier to verify that Black formats some code correctly. The
tooling in test_format.py already works great for very simple cases,
but any sort of complication makes it hard to use. Also, if you're
writing a non-standard test case, you have to be careful to include all
of the steps so issues don't go undetected. assert_format aims to
1) improve consistency, 2) avoid wasted CPU cycles, and 3) avoid
logical errors that hide issues.
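The kind of checks such a helper bundles can be sketched as follows. This is a minimal, self-contained sketch, not the helper from tests/util.py. The formatter is passed in as a plain function, whereas the real helper uses Black itself. Comparing `ast.dump` output is a simplified stand-in for black.assert_equivalent. `normalize_quotes` is a hypothetical toy formatter.

```python
import ast
from typing import Callable

# Hypothetical type: any function that maps source text to formatted text.
Formatter = Callable[[str], str]


def assert_format(source: str, expected: str, fmt: Formatter) -> None:
    """Run every check a formatting test needs, so none can be forgotten."""
    actual = fmt(source)
    # 1) The output matches the expected formatting exactly.
    assert actual == expected, f"expected {expected!r}, got {actual!r}"
    # 2) The code still means the same thing: both versions parse to the same AST.
    assert ast.dump(ast.parse(source)) == ast.dump(ast.parse(actual)), "AST changed"
    # 3) The output is stable: formatting it a second time changes nothing.
    assert fmt(actual) == actual, "second pass changed the output"


def normalize_quotes(src: str) -> str:
    """Hypothetical toy formatter that turns single quotes into double quotes.

    Only safe for inputs without embedded quotes; for illustration only.
    """
    return src.replace("'", '"')


assert_format("print('hello')\n", 'print("hello")\n', normalize_quotes)
```

Because every test goes through the same function, a test cannot silently skip the equivalence or stability check.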

Finally, quite a few tests were moved and/or simplified with the new
setup.

* Move file collection tests
* Add assert_collected_sources helper function

Testing source collection involves a lot of repetitive boilerplate,
something that black.files.get_sources's signature does not help with.
So to cut down on boilerplate like `report=black.Report()` I added
a convenience function to tests/test_black.py which wraps
black.get_sources. Its signature is designed to be much more lax to
make it much easier to use. Somehow this leads to cutting 100 lines!

Also, IMO the test cases are much easier to read since they're now more
declarative than procedural.
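A lax wrapper of this kind might look like the sketch below. It assumes a toy collector: `collect_sources` is a hypothetical stand-in that walks a directory with `Path.rglob` and applies include/exclude regexes to root-relative paths. It is not black.get_sources, which takes many more arguments (a Report among them) and handles more cases. The wrapper accepts `str` or `Path`, takes expected paths relative to the root, and fills in default patterns (also hypothetical), which is where the boilerplate savings come from.

```python
import re
from pathlib import Path
from typing import Iterable, Optional, Set, Union

PathLike = Union[str, Path]

# Hypothetical defaults for this sketch only.
SKETCH_INCLUDE = r"\.pyi?$"
SKETCH_EXCLUDE = r"/(\.git|build)/"


def collect_sources(
    root: Path, include: "re.Pattern[str]", exclude: "re.Pattern[str]"
) -> Set[Path]:
    """Toy collector: files under `root` whose root-relative POSIX path
    (with a leading "/") matches `include` and does not match `exclude`."""
    found: Set[Path] = set()
    for path in root.rglob("*"):
        rel = "/" + path.relative_to(root).as_posix()
        if path.is_file() and include.search(rel) and not exclude.search(rel):
            found.add(path)
    return found


def assert_collected_sources(
    root: PathLike,
    expected: Iterable[PathLike],
    *,
    include: Optional[str] = None,
    exclude: Optional[str] = None,
) -> None:
    """Lax wrapper: callers state only the root, the expected files and
    whichever pattern differs from the default."""
    root_path = Path(root)
    include_re = re.compile(include if include is not None else SKETCH_INCLUDE)
    exclude_re = re.compile(exclude if exclude is not None else SKETCH_EXCLUDE)
    collected = collect_sources(root_path, include_re, exclude_re)
    wanted = {root_path / p for p in expected}
    assert collected == wanted, f"{sorted(collected)} != {sorted(wanted)}"
```

A test then reads as a declaration of the expected result, for example `assert_collected_sources(tmp, ["a.py", "b.pyi"])`.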

* Run isort on some test files
* Move cache tests
* Use pytest-style asserts & add parametrization
* Drop now unnecessary test dependencies
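In the diff, invokeBlack moves out of the test class, switches from `self.assertEqual` to a plain `assert`, and is re-attached with `staticmethod` so existing `self.invokeBlack(...)` calls keep working. A minimal sketch of that pattern, with a hypothetical `run_tool` helper and a toy exit-code rule standing in for invoking Black's CLI:

```python
import unittest
from typing import List


def run_tool(args: List[str], exit_code: int = 0) -> None:
    """Hypothetical stand-in for a CLI-invoking helper such as invokeBlack."""
    # Toy "CLI": exits with 1 if any argument is "--fail", otherwise 0.
    actual = 1 if "--fail" in args else 0
    msg = f"Failed with args: {args}"
    # Plain assert: inside a test module, pytest's assertion rewriting
    # reports both compared values on failure.
    assert actual == exit_code, msg


class ToolTestCase(unittest.TestCase):
    # Re-expose the module-level helper as a method; staticmethod keeps
    # Python from passing `self` as the first argument.
    run_tool = staticmethod(run_tool)

    def test_success(self) -> None:
        self.run_tool(["file.py"])

    def test_failure(self) -> None:
        self.run_tool(["--fail"], exit_code=1)
```

Moving the helper to module level also makes it usable from plain pytest functions that have no `self`.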

*pytest-cases might be interesting for further refactoring but I
haven't been able to wrap my head around it for the time being. We
can always revisit anyway.

Pipfile
Pipfile.lock
test_requirements.txt
tests/test_black.py
tests/test_format.py
tests/util.py

diff --git a/Pipfile b/Pipfile
index c6cd8d41ef54aeafd53e0e492a83ff14f8fe532c..534ca50fa5dd118a1e8d54e1e069157fa7ab3bc1 100644
--- a/Pipfile
+++ b/Pipfile
@@ -7,11 +7,8 @@ verify_ssl = true
 # Testing related requirements.
 coverage = ">= 5.3"
 pytest = " >= 6.1.1"
-pytest-mock = ">= 3.3.1"
-pytest-cases = ">= 2.3.0"
 pytest-xdist = ">= 2.2.1"
 pytest-cov = ">= 2.11.1"
-parameterized = ">= 0.7.4"
 tox = "*"
 
 # Linting related requirements.
diff --git a/Pipfile.lock b/Pipfile.lock
index 22b66ba8ca35be487aa38d2d030f8619084261fc..280e6498af14c236dbf180ee3914c92bfc87d2dd 100644
--- a/Pipfile.lock
+++ b/Pipfile.lock
@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "192f075f04e702887745a3f19056b0172d83e4bc494fff4e0bcd6cfcafedd512"
+            "sha256": "6dbdff058fac8e6492f9d64194e98e48e062946ec4c06f9bb7df517d1dd89ce8"
         },
         "pipfile-spec": 6,
         "requires": {},
                 "sha256:8479067f342acf957dc82ec415d355ab5edb7e7646b90dc6e2fd1d96ad084c97"
             ],
             "index": "pypi",
-            "python_version <": "3.7",
-            "version": "==0.8",
-            "version >": "0.1.3"
-        },
-        "decopatch": {
-            "hashes": [
-                "sha256:29a74d5d753423b188d5b537532da4f4b88e33ddccb95a8a20a5eff5b13265d4",
-                "sha256:c66b0815f15db04de7bb52b0b276432b76b7346fe7046f28033f48a14340d144"
-            ],
-            "version": "==1.4.8"
+            "markers": "python_version < '3.7'",
+            "version": "==0.8"
         },
         "decorator": {
             "hashes": [
             "markers": "python_version >= '3.6'",
             "version": "==23.1.0"
         },
-        "makefun": {
-            "hashes": [
-                "sha256:033eed65e2c1804fca84161a38d1fc8bb8650d32a89ac1c5dc7e54b2b2c2e88c",
-                "sha256:a19bddf07efb6bf92e3ccde5d593e49bc59001fd6c17cf7301d7a73a2647ae83"
-            ],
-            "version": "==1.11.3"
-        },
         "markdown-it-py": {
             "hashes": [
                 "sha256:36be6bb3ad987bfdb839f5ba78ddf094552ca38ccbd784ae4f74a4e1419fc6e3",
             "markers": "python_version >= '3.6'",
             "version": "==21.0"
         },
-        "parameterized": {
-            "hashes": [
-                "sha256:41bbff37d6186430f77f900d777e5bb6a24928a1c46fb1de692f8b52b8833b5c",
-                "sha256:9cbb0b69a03e8695d68b3399a8a5825200976536fe1cb79db60ed6a4c8c9efe9"
-            ],
-            "index": "pypi",
-            "version": "==0.8.1"
-        },
         "parso": {
             "hashes": [
                 "sha256:12b83492c6239ce32ff5eed6d3639d6a536170723c6f3f1506869f1ace413398",
             "index": "pypi",
             "version": "==6.2.4"
         },
-        "pytest-cases": {
-            "hashes": [
-                "sha256:13136269240615bc79041f8af8fc96e0e3e085da72dd22b18625451fda2443b8",
-                "sha256:a4abe0ec2b8acf8f8b5ab73060de72eac745c6ed9cfa317d59ae71b4a0bbbdf5"
-            ],
-            "index": "pypi",
-            "version": "==3.6.3"
-        },
         "pytest-cov": {
             "hashes": [
                 "sha256:261bb9e47e65bd099c89c3edf92972865210c36813f80ede5277dceb77a4a62a",
             "markers": "python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4'",
             "version": "==1.3.0"
         },
-        "pytest-mock": {
-            "hashes": [
-                "sha256:30c2f2cc9759e76eee674b81ea28c9f0b94f8f0445a1b87762cadf774f0df7e3",
-                "sha256:40217a058c52a63f1042f0784f62009e976ba824c418cced42e88d5f40ab0e62"
-            ],
-            "index": "pypi",
-            "version": "==3.6.1"
-        },
         "pytest-xdist": {
             "hashes": [
                 "sha256:e8ecde2f85d88fbcadb7d28cb33da0fa29bca5cf7d5967fa89fc0e97e5299ea5",
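diff --git a/test_requirements.txt b/test_requirements.txt
index 31ab2d05feac5e3b3335d4745d18cfe991dc154a..5bc494d599966e2631f37c79d180e260c7f746be 100644
--- a/test_requirements.txt
+++ b/test_requirements.txt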
@@ -1,9 +1,6 @@
 coverage >= 5.3
 pre-commit
 pytest >= 6.1.1
-pytest-mock >= 3.3.1
-pytest-cases >= 2.3.0
 pytest-xdist >= 2.2.1
 pytest-cov >= 2.11.1
-parameterized >= 0.7.4
 tox
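diff --git a/tests/test_black.py b/tests/test_black.py
index 398a528bee9a4684fd52355597fcbc1f3bf0d3c3..f25db1b73d1ef9b102cce37ae1c89fe3566efbb4 100644
--- a/tests/test_black.py
+++ b/tests/test_black.py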
@@ -1,69 +1,70 @@
 #!/usr/bin/env python3
-import multiprocessing
+
 import asyncio
+import inspect
+import io
 import logging
+import multiprocessing
+import os
+import sys
+import types
+import unittest
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from dataclasses import replace
-import inspect
-import io
 from io import BytesIO
-import os
 from pathlib import Path
 from platform import system
-import regex as re
-import sys
 from tempfile import TemporaryDirectory
-import types
 from typing import (
     Any,
     Callable,
     Dict,
-    List,
     Iterator,
+    List,
+    Optional,
+    Sequence,
     TypeVar,
+    Union,
 )
-import pytest
-import unittest
-from unittest.mock import patch, MagicMock
-from parameterized import parameterized
+from unittest.mock import MagicMock, patch
 
 import click
+import pytest
+import regex as re
 from click import unstyle
 from click.testing import CliRunner
+from pathspec import PathSpec
 
 import black
+import black.files
 from black import Feature, TargetVersion
+from black import re_compile_maybe_verbose as compile_pattern
 from black.cache import get_cache_file
 from black.debug import DebugVisitor
-from black.output import diff, color_diff
+from black.output import color_diff, diff
 from black.report import Report
-import black.files
-
-from pathspec import PathSpec
 
 # Import other test classes
 from tests.util import (
-    THIS_DIR,
-    change_directory,
-    read_data,
+    DATA_DIR,
+    DEFAULT_MODE,
     DETERMINISTIC_HEADER,
+    PY36_VERSIONS,
+    THIS_DIR,
     BlackBaseTestCase,
-    DEFAULT_MODE,
-    fs,
-    ff,
+    assert_format,
+    change_directory,
     dump_to_stderr,
+    ff,
+    fs,
+    read_data,
 )
 
-
 THIS_FILE = Path(__file__)
-PY36_VERSIONS = {
-    TargetVersion.PY36,
-    TargetVersion.PY37,
-    TargetVersion.PY38,
-    TargetVersion.PY39,
-}
 PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
+DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES)
+DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES)
 T = TypeVar("T")
 R = TypeVar("R")
 
@@ -114,34 +115,26 @@ class BlackRunner(CliRunner):
         super().__init__(mix_stderr=False)
 
 
-class BlackTestCase(BlackBaseTestCase):
-    def invokeBlack(
-        self, args: List[str], exit_code: int = 0, ignore_config: bool = True
-    ) -> None:
-        runner = BlackRunner()
-        if ignore_config:
-            args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
-        result = runner.invoke(black.main, args)
-        assert result.stdout_bytes is not None
-        assert result.stderr_bytes is not None
-        self.assertEqual(
-            result.exit_code,
-            exit_code,
-            msg=(
-                f"Failed with args: {args}\n"
-                f"stdout: {result.stdout_bytes.decode()!r}\n"
-                f"stderr: {result.stderr_bytes.decode()!r}\n"
-                f"exception: {result.exception}"
-            ),
-        )
+def invokeBlack(
+    args: List[str], exit_code: int = 0, ignore_config: bool = True
+) -> None:
+    runner = BlackRunner()
+    if ignore_config:
+        args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args]
+    result = runner.invoke(black.main, args)
+    assert result.stdout_bytes is not None
+    assert result.stderr_bytes is not None
+    msg = (
+        f"Failed with args: {args}\n"
+        f"stdout: {result.stdout_bytes.decode()!r}\n"
+        f"stderr: {result.stderr_bytes.decode()!r}\n"
+        f"exception: {result.exception}"
+    )
+    assert result.exit_code == exit_code, msg
 
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_empty(self) -> None:
-        source = expected = ""
-        actual = fs(source)
-        self.assertFormatEqual(expected, actual)
-        black.assert_equivalent(source, actual)
-        black.assert_stable(source, actual, DEFAULT_MODE)
+
+class BlackTestCase(BlackBaseTestCase):
+    invokeBlack = staticmethod(invokeBlack)
 
     def test_empty_ff(self) -> None:
         expected = ""
@@ -266,32 +259,6 @@ class BlackTestCase(BlackBaseTestCase):
         actual = fs(fs(source))  # this is what `format_file_contents` does with --safe
         black.assert_stable(source, actual, DEFAULT_MODE)
 
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_pep_572(self) -> None:
-        source, expected = read_data("pep_572")
-        actual = fs(source)
-        self.assertFormatEqual(expected, actual)
-        black.assert_stable(source, actual, DEFAULT_MODE)
-        if sys.version_info >= (3, 8):
-            black.assert_equivalent(source, actual)
-
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_pep_572_remove_parens(self) -> None:
-        source, expected = read_data("pep_572_remove_parens")
-        actual = fs(source)
-        self.assertFormatEqual(expected, actual)
-        black.assert_stable(source, actual, DEFAULT_MODE)
-        if sys.version_info >= (3, 8):
-            black.assert_equivalent(source, actual)
-
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_pep_572_do_not_remove_parens(self) -> None:
-        source, expected = read_data("pep_572_do_not_remove_parens")
-        # the AST safety checks will fail, but that's expected, just make sure no
-        # parentheses are touched
-        actual = black.format_str(source, mode=DEFAULT_MODE)
-        self.assertFormatEqual(expected, actual)
-
     def test_pep_572_version_detection(self) -> None:
         source, _ = read_data("pep_572")
         root = black.lib2to3_parse(source)
@@ -300,14 +267,6 @@ class BlackTestCase(BlackBaseTestCase):
         versions = black.detect_target_versions(root)
         self.assertIn(black.TargetVersion.PY38, versions)
 
-    @parameterized.expand([(3, 9), (3, 10)])
-    def test_pep_572_newer_syntax(self, major: int, minor: int) -> None:
-        source, expected = read_data(f"pep_572_py{major}{minor}")
-        actual = fs(source, mode=DEFAULT_MODE)
-        self.assertFormatEqual(expected, actual)
-        if sys.version_info >= (major, minor):
-            black.assert_equivalent(source, actual)
-
     def test_expression_ff(self) -> None:
         source, expected = read_data("expression")
         tmp_file = Path(black.dump_to_file(source))
@@ -369,15 +328,6 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertIn("\033[31m", actual)
         self.assertIn("\033[0m", actual)
 
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_pep_570(self) -> None:
-        source, expected = read_data("pep_570")
-        actual = fs(source)
-        self.assertFormatEqual(expected, actual)
-        black.assert_stable(source, actual, DEFAULT_MODE)
-        if sys.version_info >= (3, 8):
-            black.assert_equivalent(source, actual)
-
     def test_detect_pos_only_arguments(self) -> None:
         source, _ = read_data("pep_570")
         root = black.lib2to3_parse(source)
@@ -390,52 +340,13 @@ class BlackTestCase(BlackBaseTestCase):
     def test_string_quotes(self) -> None:
         source, expected = read_data("string_quotes")
         mode = black.Mode(experimental_string_processing=True)
-        actual = fs(source, mode=mode)
-        self.assertFormatEqual(expected, actual)
-        black.assert_equivalent(source, actual)
-        black.assert_stable(source, actual, mode)
+        assert_format(source, expected, mode)
         mode = replace(mode, string_normalization=False)
         not_normalized = fs(source, mode=mode)
         self.assertFormatEqual(source.replace("\\\n", ""), not_normalized)
         black.assert_equivalent(source, not_normalized)
         black.assert_stable(source, not_normalized, mode=mode)
 
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_docstring_no_string_normalization(self) -> None:
-        """Like test_docstring but with string normalization off."""
-        source, expected = read_data("docstring_no_string_normalization")
-        mode = replace(DEFAULT_MODE, string_normalization=False)
-        actual = fs(source, mode=mode)
-        self.assertFormatEqual(expected, actual)
-        black.assert_equivalent(source, actual)
-        black.assert_stable(source, actual, mode)
-
-    def test_long_strings_flag_disabled(self) -> None:
-        """Tests for turning off the string processing logic."""
-        source, expected = read_data("long_strings_flag_disabled")
-        mode = replace(DEFAULT_MODE, experimental_string_processing=False)
-        actual = fs(source, mode=mode)
-        self.assertFormatEqual(expected, actual)
-        black.assert_stable(expected, actual, mode)
-
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_numeric_literals(self) -> None:
-        source, expected = read_data("numeric_literals")
-        mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
-        actual = fs(source, mode=mode)
-        self.assertFormatEqual(expected, actual)
-        black.assert_equivalent(source, actual)
-        black.assert_stable(source, actual, mode)
-
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_numeric_literals_ignoring_underscores(self) -> None:
-        source, expected = read_data("numeric_literals_skip_underscores")
-        mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
-        actual = fs(source, mode=mode)
-        self.assertFormatEqual(expected, actual)
-        black.assert_equivalent(source, actual)
-        black.assert_stable(source, actual, mode)
-
     def test_skip_magic_trailing_comma(self) -> None:
         source, _ = read_data("expression.py")
         expected, _ = read_data("expression_skip_magic_trailing_comma.diff")
@@ -461,24 +372,6 @@ class BlackTestCase(BlackBaseTestCase):
             )
             self.assertEqual(expected, actual, msg)
 
-    @pytest.mark.python2
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_python2_print_function(self) -> None:
-        source, expected = read_data("python2_print_function")
-        mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27})
-        actual = fs(source, mode=mode)
-        self.assertFormatEqual(expected, actual)
-        black.assert_equivalent(source, actual)
-        black.assert_stable(source, actual, mode)
-
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_stub(self) -> None:
-        mode = replace(DEFAULT_MODE, is_pyi=True)
-        source, expected = read_data("stub.pyi")
-        actual = fs(source, mode=mode)
-        self.assertFormatEqual(expected, actual)
-        black.assert_stable(source, actual, mode)
-
     @patch("black.dump_to_file", dump_to_stderr)
     def test_async_as_identifier(self) -> None:
         source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve()
@@ -509,26 +402,6 @@ class BlackTestCase(BlackBaseTestCase):
         # but not on 3.6, because we use async as a reserved keyword
         self.invokeBlack([str(source_path), "--target-version", "py36"], exit_code=123)
 
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_python38(self) -> None:
-        source, expected = read_data("python38")
-        actual = fs(source)
-        self.assertFormatEqual(expected, actual)
-        major, minor = sys.version_info[:2]
-        if major > 3 or (major == 3 and minor >= 8):
-            black.assert_equivalent(source, actual)
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_python39(self) -> None:
-        source, expected = read_data("python39")
-        actual = fs(source)
-        self.assertFormatEqual(expected, actual)
-        major, minor = sys.version_info[:2]
-        if major > 3 or (major == 3 and minor >= 9):
-            black.assert_equivalent(source, actual)
-        black.assert_stable(source, actual, DEFAULT_MODE)
-
     def test_tab_comment_indentation(self) -> None:
         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
@@ -1033,256 +906,67 @@ class BlackTestCase(BlackBaseTestCase):
         self.assertTrue("Actual tree:" in out_str)
         self.assertEqual("".join(err_lines), "")
 
-    def test_cache_broken_file(self) -> None:
-        mode = DEFAULT_MODE
-        with cache_dir() as workspace:
-            cache_file = get_cache_file(mode)
-            with cache_file.open("w") as fobj:
-                fobj.write("this is not a pickle")
-            self.assertEqual(black.read_cache(mode), {})
-            src = (workspace / "test.py").resolve()
-            with src.open("w") as fobj:
-                fobj.write("print('hello')")
-            self.invokeBlack([str(src)])
-            cache = black.read_cache(mode)
-            self.assertIn(str(src), cache)
-
-    def test_cache_single_file_already_cached(self) -> None:
-        mode = DEFAULT_MODE
+    @event_loop()
+    @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
+    def test_works_in_mono_process_only_environment(self) -> None:
         with cache_dir() as workspace:
-            src = (workspace / "test.py").resolve()
-            with src.open("w") as fobj:
-                fobj.write("print('hello')")
-            black.write_cache({}, [src], mode)
-            self.invokeBlack([str(src)])
-            with src.open("r") as fobj:
-                self.assertEqual(fobj.read(), "print('hello')")
+            for f in [
+                (workspace / "one.py").resolve(),
+                (workspace / "two.py").resolve(),
+            ]:
+                f.write_text('print("hello")\n')
+            self.invokeBlack([str(workspace)])
 
     @event_loop()
-    def test_cache_multiple_files(self) -> None:
-        mode = DEFAULT_MODE
-        with cache_dir() as workspace, patch(
-            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
-        ):
-            one = (workspace / "one.py").resolve()
-            with one.open("w") as fobj:
-                fobj.write("print('hello')")
-            two = (workspace / "two.py").resolve()
-            with two.open("w") as fobj:
-                fobj.write("print('hello')")
-            black.write_cache({}, [one], mode)
-            self.invokeBlack([str(workspace)])
-            with one.open("r") as fobj:
-                self.assertEqual(fobj.read(), "print('hello')")
-            with two.open("r") as fobj:
-                self.assertEqual(fobj.read(), 'print("hello")\n')
-            cache = black.read_cache(mode)
-            self.assertIn(str(one), cache)
-            self.assertIn(str(two), cache)
+    def test_check_diff_use_together(self) -> None:
+        with cache_dir():
+            # Files which will be reformatted.
+            src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
+            self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
+            # Files which will not be reformatted.
+            src2 = (THIS_DIR / "data" / "composition.py").resolve()
+            self.invokeBlack([str(src2), "--diff", "--check"])
+            # Multi file command.
+            self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
 
-    def test_no_cache_when_writeback_diff(self) -> None:
-        mode = DEFAULT_MODE
-        with cache_dir() as workspace:
-            src = (workspace / "test.py").resolve()
-            with src.open("w") as fobj:
-                fobj.write("print('hello')")
-            with patch("black.read_cache") as read_cache, patch(
-                "black.write_cache"
-            ) as write_cache:
-                self.invokeBlack([str(src), "--diff"])
-                cache_file = get_cache_file(mode)
-                self.assertFalse(cache_file.exists())
-                write_cache.assert_not_called()
-                read_cache.assert_not_called()
+    def test_no_files(self) -> None:
+        with cache_dir():
+            # Without an argument, black exits with error code 0.
+            self.invokeBlack([])
 
-    def test_no_cache_when_writeback_color_diff(self) -> None:
-        mode = DEFAULT_MODE
+    def test_broken_symlink(self) -> None:
         with cache_dir() as workspace:
-            src = (workspace / "test.py").resolve()
-            with src.open("w") as fobj:
-                fobj.write("print('hello')")
-            with patch("black.read_cache") as read_cache, patch(
-                "black.write_cache"
-            ) as write_cache:
-                self.invokeBlack([str(src), "--diff", "--color"])
-                cache_file = get_cache_file(mode)
-                self.assertFalse(cache_file.exists())
-                write_cache.assert_not_called()
-                read_cache.assert_not_called()
+            symlink = workspace / "broken_link.py"
+            try:
+                symlink.symlink_to("nonexistent.py")
+            except OSError as e:
+                self.skipTest(f"Can't create symlinks: {e}")
+            self.invokeBlack([str(workspace.resolve())])
 
-    @event_loop()
-    def test_output_locking_when_writeback_diff(self) -> None:
+    def test_single_file_force_pyi(self) -> None:
+        pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
+        contents, expected = read_data("force_pyi")
         with cache_dir() as workspace:
-            for tag in range(0, 4):
-                src = (workspace / f"test{tag}.py").resolve()
-                with src.open("w") as fobj:
-                    fobj.write("print('hello')")
-            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
-                self.invokeBlack(["--diff", str(workspace)], exit_code=0)
-                # this isn't quite doing what we want, but if it _isn't_
-                # called then we cannot be using the lock it provides
-                mgr.assert_called()
+            path = (workspace / "file.py").resolve()
+            with open(path, "w") as fh:
+                fh.write(contents)
+            self.invokeBlack([str(path), "--pyi"])
+            with open(path, "r") as fh:
+                actual = fh.read()
+            # verify cache with --pyi is separate
+            pyi_cache = black.read_cache(pyi_mode)
+            self.assertIn(str(path), pyi_cache)
+            normal_cache = black.read_cache(DEFAULT_MODE)
+            self.assertNotIn(str(path), normal_cache)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(contents, actual)
+        black.assert_stable(contents, actual, pyi_mode)
 
     @event_loop()
-    def test_output_locking_when_writeback_color_diff(self) -> None:
-        with cache_dir() as workspace:
-            for tag in range(0, 4):
-                src = (workspace / f"test{tag}.py").resolve()
-                with src.open("w") as fobj:
-                    fobj.write("print('hello')")
-            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
-                self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
-                # this isn't quite doing what we want, but if it _isn't_
-                # called then we cannot be using the lock it provides
-                mgr.assert_called()
-
-    def test_no_cache_when_stdin(self) -> None:
-        mode = DEFAULT_MODE
-        with cache_dir():
-            result = CliRunner().invoke(
-                black.main, ["-"], input=BytesIO(b"print('hello')")
-            )
-            self.assertEqual(result.exit_code, 0)
-            cache_file = get_cache_file(mode)
-            self.assertFalse(cache_file.exists())
-
-    def test_read_cache_no_cachefile(self) -> None:
-        mode = DEFAULT_MODE
-        with cache_dir():
-            self.assertEqual(black.read_cache(mode), {})
-
-    def test_write_cache_read_cache(self) -> None:
-        mode = DEFAULT_MODE
-        with cache_dir() as workspace:
-            src = (workspace / "test.py").resolve()
-            src.touch()
-            black.write_cache({}, [src], mode)
-            cache = black.read_cache(mode)
-            self.assertIn(str(src), cache)
-            self.assertEqual(cache[str(src)], black.get_cache_info(src))
-
-    def test_filter_cached(self) -> None:
-        with TemporaryDirectory() as workspace:
-            path = Path(workspace)
-            uncached = (path / "uncached").resolve()
-            cached = (path / "cached").resolve()
-            cached_but_changed = (path / "changed").resolve()
-            uncached.touch()
-            cached.touch()
-            cached_but_changed.touch()
-            cache = {
-                str(cached): black.get_cache_info(cached),
-                str(cached_but_changed): (0.0, 0),
-            }
-            todo, done = black.filter_cached(
-                cache, {uncached, cached, cached_but_changed}
-            )
-            self.assertEqual(todo, {uncached, cached_but_changed})
-            self.assertEqual(done, {cached})
-
-    def test_write_cache_creates_directory_if_needed(self) -> None:
-        mode = DEFAULT_MODE
-        with cache_dir(exists=False) as workspace:
-            self.assertFalse(workspace.exists())
-            black.write_cache({}, [], mode)
-            self.assertTrue(workspace.exists())
-
-    @event_loop()
-    def test_failed_formatting_does_not_get_cached(self) -> None:
-        mode = DEFAULT_MODE
-        with cache_dir() as workspace, patch(
-            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
-        ):
-            failing = (workspace / "failing.py").resolve()
-            with failing.open("w") as fobj:
-                fobj.write("not actually python")
-            clean = (workspace / "clean.py").resolve()
-            with clean.open("w") as fobj:
-                fobj.write('print("hello")\n')
-            self.invokeBlack([str(workspace)], exit_code=123)
-            cache = black.read_cache(mode)
-            self.assertNotIn(str(failing), cache)
-            self.assertIn(str(clean), cache)
-
-    def test_write_cache_write_fail(self) -> None:
-        mode = DEFAULT_MODE
-        with cache_dir(), patch.object(Path, "open") as mock:
-            mock.side_effect = OSError
-            black.write_cache({}, [], mode)
-
-    @event_loop()
-    @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError))
-    def test_works_in_mono_process_only_environment(self) -> None:
-        with cache_dir() as workspace:
-            for f in [
-                (workspace / "one.py").resolve(),
-                (workspace / "two.py").resolve(),
-            ]:
-                f.write_text('print("hello")\n')
-            self.invokeBlack([str(workspace)])
-
-    @event_loop()
-    def test_check_diff_use_together(self) -> None:
-        with cache_dir():
-            # Files which will be reformatted.
-            src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
-            self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
-            # Files which will not be reformatted.
-            src2 = (THIS_DIR / "data" / "composition.py").resolve()
-            self.invokeBlack([str(src2), "--diff", "--check"])
-            # Multi file command.
-            self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
-
-    def test_no_files(self) -> None:
-        with cache_dir():
-            # Without an argument, black exits with error code 0.
-            self.invokeBlack([])
-
-    def test_broken_symlink(self) -> None:
-        with cache_dir() as workspace:
-            symlink = workspace / "broken_link.py"
-            try:
-                symlink.symlink_to("nonexistent.py")
-            except OSError as e:
-                self.skipTest(f"Can't create symlinks: {e}")
-            self.invokeBlack([str(workspace.resolve())])
-
-    def test_read_cache_line_lengths(self) -> None:
-        mode = DEFAULT_MODE
-        short_mode = replace(DEFAULT_MODE, line_length=1)
-        with cache_dir() as workspace:
-            path = (workspace / "file.py").resolve()
-            path.touch()
-            black.write_cache({}, [path], mode)
-            one = black.read_cache(mode)
-            self.assertIn(str(path), one)
-            two = black.read_cache(short_mode)
-            self.assertNotIn(str(path), two)
-
-    def test_single_file_force_pyi(self) -> None:
-        pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
-        contents, expected = read_data("force_pyi")
-        with cache_dir() as workspace:
-            path = (workspace / "file.py").resolve()
-            with open(path, "w") as fh:
-                fh.write(contents)
-            self.invokeBlack([str(path), "--pyi"])
-            with open(path, "r") as fh:
-                actual = fh.read()
-            # verify cache with --pyi is separate
-            pyi_cache = black.read_cache(pyi_mode)
-            self.assertIn(str(path), pyi_cache)
-            normal_cache = black.read_cache(DEFAULT_MODE)
-            self.assertNotIn(str(path), normal_cache)
-        self.assertFormatEqual(expected, actual)
-        black.assert_equivalent(contents, actual)
-        black.assert_stable(contents, actual, pyi_mode)
-
-    @event_loop()
-    def test_multi_file_force_pyi(self) -> None:
-        reg_mode = DEFAULT_MODE
-        pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
-        contents, expected = read_data("force_pyi")
+    def test_multi_file_force_pyi(self) -> None:
+        reg_mode = DEFAULT_MODE
+        pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
+        contents, expected = read_data("force_pyi")
         with cache_dir() as workspace:
             paths = [
                 (workspace / "file1.py").resolve(),
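Two of the tests removed in this hunk, test_read_cache_line_lengths and test_single_file_force_pyi, check that a file cached under one mode is not treated as cached under another. The sketch below illustrates that idea using only the standard library. It assumes the cache is a pickled dict that maps resolved paths to (mtime, size), with one cache file per mode. The Mode dataclass, the cache_file naming scheme and all function signatures are hypothetical and do not reproduce Black's real API.

```python
import pickle
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, Tuple

CacheInfo = Tuple[float, int]  # assumed (mtime, size) of a formatted file


@dataclass(frozen=True)
class Mode:
    """Hypothetical stand-in for Black's Mode: only two options are modelled."""

    line_length: int = 88
    is_pyi: bool = False


def cache_file(cache_dir: Path, mode: Mode) -> Path:
    # One file per distinct mode, so a result cached for one configuration
    # is never reused for another (hypothetical naming scheme).
    return cache_dir / f"cache.{mode.line_length}.{int(mode.is_pyi)}.pickle"


def get_cache_info(path: Path) -> CacheInfo:
    st = path.stat()
    return st.st_mtime, st.st_size


def read_cache(cache_dir: Path, mode: Mode) -> Dict[str, CacheInfo]:
    path = cache_file(cache_dir, mode)
    if not path.exists():
        return {}
    try:
        with path.open("rb") as fobj:
            return pickle.load(fobj)
    except (pickle.UnpicklingError, EOFError, ValueError, IndexError):
        return {}  # a corrupt cache file is treated as an empty cache


def write_cache(cache_dir: Path, mode: Mode, sources: Iterable[Path]) -> None:
    # Merge the new entries into whatever is already cached for this mode.
    cache = read_cache(cache_dir, mode)
    cache.update({str(src.resolve()): get_cache_info(src) for src in sources})
    cache_dir.mkdir(parents=True, exist_ok=True)
    with cache_file(cache_dir, mode).open("wb") as fobj:
        pickle.dump(cache, fobj)


# Usage mirroring test_read_cache_line_lengths: cache under the default mode only.
with tempfile.TemporaryDirectory() as tmp:
    workspace = Path(tmp)
    src = (workspace / "file.py").resolve()
    src.touch()
    write_cache(workspace / "cache", Mode(), [src])
    in_default = str(src) in read_cache(workspace / "cache", Mode())
    in_short = str(src) in read_cache(workspace / "cache", Mode(line_length=1))
```

Keying the file name on the mode means a lookup under a different line length or under pyi mode finds no cache file at all. The cached entry therefore cannot be reused by mistake.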
@@ -1366,216 +1050,6 @@ class BlackTestCase(BlackBaseTestCase):
         actual = result.output
         self.assertFormatEqual(actual, expected)
 
-    def test_include_exclude(self) -> None:
-        path = THIS_DIR / "data" / "include_exclude_tests"
-        include = re.compile(r"\.pyi?$")
-        exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
-        report = black.Report()
-        gitignore = PathSpec.from_lines("gitwildmatch", [])
-        sources: List[Path] = []
-        expected = [
-            Path(path / "b/dont_exclude/a.py"),
-            Path(path / "b/dont_exclude/a.pyi"),
-        ]
-        this_abs = THIS_DIR.resolve()
-        sources.extend(
-            black.gen_python_files(
-                path.iterdir(),
-                this_abs,
-                include,
-                exclude,
-                None,
-                None,
-                report,
-                gitignore,
-                verbose=False,
-                quiet=False,
-            )
-        )
-        self.assertEqual(sorted(expected), sorted(sources))
-
-    def test_gitignore_used_as_default(self) -> None:
-        path = Path(THIS_DIR / "data" / "include_exclude_tests")
-        include = re.compile(r"\.pyi?$")
-        extend_exclude = re.compile(r"/exclude/")
-        src = str(path / "b/")
-        report = black.Report()
-        expected: List[Path] = [
-            path / "b/.definitely_exclude/a.py",
-            path / "b/.definitely_exclude/a.pyi",
-        ]
-        sources = list(
-            black.get_sources(
-                ctx=FakeContext(),
-                src=(src,),
-                quiet=True,
-                verbose=False,
-                include=include,
-                exclude=None,
-                extend_exclude=extend_exclude,
-                force_exclude=None,
-                report=report,
-                stdin_filename=None,
-            )
-        )
-        self.assertEqual(sorted(expected), sorted(sources))
-
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
-    def test_exclude_for_issue_1572(self) -> None:
-        # Exclude shouldn't touch files that were explicitly given to Black through the
-        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
-        # https://github.com/psf/black/issues/1572
-        path = THIS_DIR / "data" / "include_exclude_tests"
-        include = ""
-        exclude = r"/exclude/|a\.py"
-        src = str(path / "b/exclude/a.py")
-        report = black.Report()
-        expected = [Path(path / "b/exclude/a.py")]
-        sources = list(
-            black.get_sources(
-                ctx=FakeContext(),
-                src=(src,),
-                quiet=True,
-                verbose=False,
-                include=re.compile(include),
-                exclude=re.compile(exclude),
-                extend_exclude=None,
-                force_exclude=None,
-                report=report,
-                stdin_filename=None,
-            )
-        )
-        self.assertEqual(sorted(expected), sorted(sources))
-
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
-    def test_get_sources_with_stdin(self) -> None:
-        include = ""
-        exclude = r"/exclude/|a\.py"
-        src = "-"
-        report = black.Report()
-        expected = [Path("-")]
-        sources = list(
-            black.get_sources(
-                ctx=FakeContext(),
-                src=(src,),
-                quiet=True,
-                verbose=False,
-                include=re.compile(include),
-                exclude=re.compile(exclude),
-                extend_exclude=None,
-                force_exclude=None,
-                report=report,
-                stdin_filename=None,
-            )
-        )
-        self.assertEqual(sorted(expected), sorted(sources))
-
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
-    def test_get_sources_with_stdin_filename(self) -> None:
-        include = ""
-        exclude = r"/exclude/|a\.py"
-        src = "-"
-        report = black.Report()
-        stdin_filename = str(THIS_DIR / "data/collections.py")
-        expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
-        sources = list(
-            black.get_sources(
-                ctx=FakeContext(),
-                src=(src,),
-                quiet=True,
-                verbose=False,
-                include=re.compile(include),
-                exclude=re.compile(exclude),
-                extend_exclude=None,
-                force_exclude=None,
-                report=report,
-                stdin_filename=stdin_filename,
-            )
-        )
-        self.assertEqual(sorted(expected), sorted(sources))
-
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
-    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
-        # Exclude shouldn't exclude stdin_filename since it is mimicking the
-        # file being passed directly. This is the same as
-        # test_exclude_for_issue_1572
-        path = THIS_DIR / "data" / "include_exclude_tests"
-        include = ""
-        exclude = r"/exclude/|a\.py"
-        src = "-"
-        report = black.Report()
-        stdin_filename = str(path / "b/exclude/a.py")
-        expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
-        sources = list(
-            black.get_sources(
-                ctx=FakeContext(),
-                src=(src,),
-                quiet=True,
-                verbose=False,
-                include=re.compile(include),
-                exclude=re.compile(exclude),
-                extend_exclude=None,
-                force_exclude=None,
-                report=report,
-                stdin_filename=stdin_filename,
-            )
-        )
-        self.assertEqual(sorted(expected), sorted(sources))
-
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
-    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
-        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
-        # file being passed directly. This is the same as
-        # test_exclude_for_issue_1572
-        path = THIS_DIR / "data" / "include_exclude_tests"
-        include = ""
-        extend_exclude = r"/exclude/|a\.py"
-        src = "-"
-        report = black.Report()
-        stdin_filename = str(path / "b/exclude/a.py")
-        expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")]
-        sources = list(
-            black.get_sources(
-                ctx=FakeContext(),
-                src=(src,),
-                quiet=True,
-                verbose=False,
-                include=re.compile(include),
-                exclude=re.compile(""),
-                extend_exclude=re.compile(extend_exclude),
-                force_exclude=None,
-                report=report,
-                stdin_filename=stdin_filename,
-            )
-        )
-        self.assertEqual(sorted(expected), sorted(sources))
-
-    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
-    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
-        # Force exclude should exclude the file when passing it through
-        # stdin_filename
-        path = THIS_DIR / "data" / "include_exclude_tests"
-        include = ""
-        force_exclude = r"/exclude/|a\.py"
-        src = "-"
-        report = black.Report()
-        stdin_filename = str(path / "b/exclude/a.py")
-        sources = list(
-            black.get_sources(
-                ctx=FakeContext(),
-                src=(src,),
-                quiet=True,
-                verbose=False,
-                include=re.compile(include),
-                exclude=re.compile(""),
-                extend_exclude=None,
-                force_exclude=re.compile(force_exclude),
-                report=report,
-                stdin_filename=stdin_filename,
-            )
-        )
-        self.assertEqual([], sorted(sources))
-
     def test_reformat_one_with_stdin(self) -> None:
         with patch(
             "black.format_stdin_to_stdout",
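The get_sources tests removed in this hunk define when each pattern applies:

- exclude and extend_exclude filter only recursively discovered files (issue 1572);
- a path given explicitly, or through stdin_filename, bypasses those two patterns;
- force_exclude applies in every case.

The sketch below applies those rules to plain path strings. The names select_sources, stdin_source and _excluded are hypothetical. Black's real get_sources also handles .gitignore, project-root normalization and Path objects. The empty-match rule in _excluded is an assumption, chosen so that exclude="" excludes nothing, which is what test_empty_include expects.

```python
import re
from typing import List, Optional, Pattern

STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__"  # prefix seen in the tests above


def _excluded(pattern: Optional[Pattern[str]], path: str) -> bool:
    # An empty match (e.g. from re.compile("")) does not count as excluded.
    match = pattern.search(path) if pattern is not None else None
    return bool(match and match.group(0))


def select_sources(
    paths: List[str],
    *,
    explicit: bool,
    include: Pattern[str],
    exclude: Optional[Pattern[str]] = None,
    extend_exclude: Optional[Pattern[str]] = None,
    force_exclude: Optional[Pattern[str]] = None,
) -> List[str]:
    """Hypothetical filter applying include/exclude rules to path strings."""
    selected = []
    for path in paths:
        # force_exclude applies to every path, even ones named on the CLI.
        if _excluded(force_exclude, path):
            continue
        # include, exclude and extend_exclude only filter recursively
        # discovered files; explicitly given paths skip them.
        if not explicit:
            if _excluded(exclude, path) or _excluded(extend_exclude, path):
                continue
            if not include.search(path):
                continue
        selected.append(path)
    return selected


def stdin_source(stdin_filename: Optional[str]) -> str:
    # "-" stands for stdin; a given filename rides along behind a prefix.
    return "-" if stdin_filename is None else STDIN_PLACEHOLDER + stdin_filename
```

For example, with exclude=r"/exclude/", a recursive walk drops /d/b/exclude/a.py. Passing that same path explicitly keeps it, unless the same pattern is supplied as force_exclude.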
@@ -1701,158 +1175,13 @@ class BlackTestCase(BlackBaseTestCase):
                 pass  # StringIO does not support detach
             assert output.getvalue() == ""
 
-    def test_gitignore_exclude(self) -> None:
-        path = THIS_DIR / "data" / "include_exclude_tests"
-        include = re.compile(r"\.pyi?$")
-        exclude = re.compile(r"")
-        report = black.Report()
-        gitignore = PathSpec.from_lines(
-            "gitwildmatch", ["exclude/", ".definitely_exclude"]
-        )
-        sources: List[Path] = []
-        expected = [
-            Path(path / "b/dont_exclude/a.py"),
-            Path(path / "b/dont_exclude/a.pyi"),
-        ]
-        this_abs = THIS_DIR.resolve()
-        sources.extend(
-            black.gen_python_files(
-                path.iterdir(),
-                this_abs,
-                include,
-                exclude,
-                None,
-                None,
-                report,
-                gitignore,
-                verbose=False,
-                quiet=False,
-            )
-        )
-        self.assertEqual(sorted(expected), sorted(sources))
-
-    def test_nested_gitignore(self) -> None:
-        path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
-        include = re.compile(r"\.pyi?$")
-        exclude = re.compile(r"")
-        root_gitignore = black.files.get_gitignore(path)
-        report = black.Report()
-        expected: List[Path] = [
-            Path(path / "x.py"),
-            Path(path / "root/b.py"),
-            Path(path / "root/c.py"),
-            Path(path / "root/child/c.py"),
-        ]
-        this_abs = THIS_DIR.resolve()
-        sources = list(
-            black.gen_python_files(
-                path.iterdir(),
-                this_abs,
-                include,
-                exclude,
-                None,
-                None,
-                report,
-                root_gitignore,
-                verbose=False,
-                quiet=False,
-            )
-        )
-        self.assertEqual(sorted(expected), sorted(sources))
-
-    def test_invalid_gitignore(self) -> None:
-        path = THIS_DIR / "data" / "invalid_gitignore_tests"
-        empty_config = path / "pyproject.toml"
-        result = BlackRunner().invoke(
-            black.main, ["--verbose", "--config", str(empty_config), str(path)]
-        )
-        assert result.exit_code == 1
-        assert result.stderr_bytes is not None
-
-        gitignore = path / ".gitignore"
-        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
-
-    def test_invalid_nested_gitignore(self) -> None:
-        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
-        empty_config = path / "pyproject.toml"
-        result = BlackRunner().invoke(
-            black.main, ["--verbose", "--config", str(empty_config), str(path)]
-        )
-        assert result.exit_code == 1
-        assert result.stderr_bytes is not None
-
-        gitignore = path / "a" / ".gitignore"
-        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
-
-    def test_empty_include(self) -> None:
-        path = THIS_DIR / "data" / "include_exclude_tests"
-        report = black.Report()
-        gitignore = PathSpec.from_lines("gitwildmatch", [])
-        empty = re.compile(r"")
-        sources: List[Path] = []
-        expected = [
-            Path(path / "b/exclude/a.pie"),
-            Path(path / "b/exclude/a.py"),
-            Path(path / "b/exclude/a.pyi"),
-            Path(path / "b/dont_exclude/a.pie"),
-            Path(path / "b/dont_exclude/a.py"),
-            Path(path / "b/dont_exclude/a.pyi"),
-            Path(path / "b/.definitely_exclude/a.pie"),
-            Path(path / "b/.definitely_exclude/a.py"),
-            Path(path / "b/.definitely_exclude/a.pyi"),
-            Path(path / ".gitignore"),
-            Path(path / "pyproject.toml"),
-        ]
-        this_abs = THIS_DIR.resolve()
-        sources.extend(
-            black.gen_python_files(
-                path.iterdir(),
-                this_abs,
-                empty,
-                re.compile(black.DEFAULT_EXCLUDES),
-                None,
-                None,
-                report,
-                gitignore,
-                verbose=False,
-                quiet=False,
-            )
-        )
-        self.assertEqual(sorted(expected), sorted(sources))
-
-    def test_extend_exclude(self) -> None:
-        path = THIS_DIR / "data" / "include_exclude_tests"
-        report = black.Report()
-        gitignore = PathSpec.from_lines("gitwildmatch", [])
-        sources: List[Path] = []
-        expected = [
-            Path(path / "b/exclude/a.py"),
-            Path(path / "b/dont_exclude/a.py"),
-        ]
-        this_abs = THIS_DIR.resolve()
-        sources.extend(
-            black.gen_python_files(
-                path.iterdir(),
-                this_abs,
-                re.compile(black.DEFAULT_INCLUDES),
-                re.compile(r"\.pyi$"),
-                re.compile(r"\.definitely_exclude"),
-                None,
-                report,
-                gitignore,
-                verbose=False,
-                quiet=False,
-            )
-        )
-        self.assertEqual(sorted(expected), sorted(sources))
-
-    def test_invalid_cli_regex(self) -> None:
-        for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
-            self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
-
-    def test_required_version_matches_version(self) -> None:
-        self.invokeBlack(
-            ["--required-version", black.__version__], exit_code=0, ignore_config=True
+    def test_invalid_cli_regex(self) -> None:
+        for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]:
+            self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
+
+    def test_required_version_matches_version(self) -> None:
+        self.invokeBlack(
+            ["--required-version", black.__version__], exit_code=0, ignore_config=True
         )
 
     def test_required_version_does_not_match_version(self) -> None:
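test_invalid_cli_regex, moved in the hunk above, expects exit code 2 whenever --include, --exclude, --extend-exclude or --force-exclude receives a pattern that does not compile. Black's CLI is built on click. The sketch below swaps in the standard library's argparse, which also exits with status 2 on an invalid argument, to show patterns being validated at parse time. The compile_regex, build_parser and exit_code_for helpers and the "formatter" program name are hypothetical.

```python
import argparse
import re
from typing import List, Pattern

REGEX_OPTIONS = ("--include", "--exclude", "--extend-exclude", "--force-exclude")


def compile_regex(value: str) -> Pattern[str]:
    # argparse turns ArgumentTypeError into a usage error (exit status 2).
    try:
        return re.compile(value)
    except re.error as exc:
        raise argparse.ArgumentTypeError(f"not a valid regular expression: {exc}") from exc


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: only the four regex options plus the sources.
    parser = argparse.ArgumentParser(prog="formatter")
    for option in REGEX_OPTIONS:
        parser.add_argument(option, type=compile_regex)
    parser.add_argument("src", nargs="*")
    return parser


def exit_code_for(argv: List[str]) -> int:
    # Parse argv and report the exit status the parser would have produced.
    try:
        build_parser().parse_args(argv)
    except SystemExit as exc:
        return exc.code if isinstance(exc.code, int) else 1
    return 0


# Mirrors test_invalid_cli_regex: every option rejects the malformed pattern.
codes = {option: exit_code_for(["-", option, "**()(!!*)"]) for option in REGEX_OPTIONS}
```

Compiling inside the type callback reports a bad pattern as a usage error before any files are touched. This is the behaviour the test asserts for each option.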
@@ -1889,65 +1218,6 @@ class BlackTestCase(BlackBaseTestCase):
         with self.assertRaises(AssertionError):
             black.assert_equivalent("{}", "None")
 
-    def test_symlink_out_of_root_directory(self) -> None:
-        path = MagicMock()
-        root = THIS_DIR.resolve()
-        child = MagicMock()
-        include = re.compile(black.DEFAULT_INCLUDES)
-        exclude = re.compile(black.DEFAULT_EXCLUDES)
-        report = black.Report()
-        gitignore = PathSpec.from_lines("gitwildmatch", [])
-        # `child` should behave like a symlink which resolved path is clearly
-        # outside of the `root` directory.
-        path.iterdir.return_value = [child]
-        child.resolve.return_value = Path("/a/b/c")
-        child.as_posix.return_value = "/a/b/c"
-        child.is_symlink.return_value = True
-        try:
-            list(
-                black.gen_python_files(
-                    path.iterdir(),
-                    root,
-                    include,
-                    exclude,
-                    None,
-                    None,
-                    report,
-                    gitignore,
-                    verbose=False,
-                    quiet=False,
-                )
-            )
-        except ValueError as ve:
-            self.fail(f"`get_python_files_in_dir()` failed: {ve}")
-        path.iterdir.assert_called_once()
-        child.resolve.assert_called_once()
-        child.is_symlink.assert_called_once()
-        # `child` should behave like a strange file which resolved path is clearly
-        # outside of the `root` directory.
-        child.is_symlink.return_value = False
-        with self.assertRaises(ValueError):
-            list(
-                black.gen_python_files(
-                    path.iterdir(),
-                    root,
-                    include,
-                    exclude,
-                    None,
-                    None,
-                    report,
-                    gitignore,
-                    verbose=False,
-                    quiet=False,
-                )
-            )
-        path.iterdir.assert_called()
-        self.assertEqual(path.iterdir.call_count, 2)
-        child.resolve.assert_called()
-        self.assertEqual(child.resolve.call_count, 2)
-        child.is_symlink.assert_called()
-        self.assertEqual(child.is_symlink.call_count, 2)
-
     def test_shhh_click(self) -> None:
         try:
             from click import _unicodefun
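test_symlink_out_of_root_directory, removed here and re-added under TestFileCollection, mocks a child whose resolved path lies outside root. It expects a symlink to be skipped without error and a non-symlink to raise ValueError. The sketch below shows that rule with pathlib on real temporary directories. normalized_or_none is a hypothetical helper, not Black's API, and the usage part assumes a POSIX system where symlinks can be created.

```python
import tempfile
from pathlib import Path
from typing import Optional


def normalized_or_none(path: Path, root: Path) -> Optional[str]:
    """Hypothetical helper: `path` relative to `root` as POSIX, or None to skip."""
    try:
        return path.resolve().relative_to(root).as_posix()
    except ValueError:
        if path.is_symlink():
            # A symlink whose target lies outside `root` is skipped quietly.
            return None
        # Any other path resolving outside `root` is unexpected: re-raise.
        raise


# Usage: a regular file inside the root and a symlink pointing outside it.
with tempfile.TemporaryDirectory() as root_dir:
    with tempfile.TemporaryDirectory() as other_dir:
        root = Path(root_dir).resolve()
        target = Path(other_dir).resolve() / "outside.py"
        target.write_text("x = 1\n")
        (root / "inside.py").write_text("y = 2\n")
        link = root / "link.py"
        link.symlink_to(target)
        inside_result = normalized_or_none(root / "inside.py", root)
        link_result = normalized_or_none(link, root)
        try:
            normalized_or_none(target, root)
            outside_raised = False
        except ValueError:
            outside_raised = True
```

Re-raising for non-symlinks keeps unexpected out-of-root paths loud. A symlink pointing elsewhere is a normal situation and is simply ignored.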
@@ -2270,31 +1540,497 @@ class BlackTestCase(BlackBaseTestCase):
                 ), "Incorrect config loaded."
 
 
-with open(black.__file__, "r", encoding="utf-8") as _bf:
-    black_source_lines = _bf.readlines()
+class TestCaching:
+    def test_cache_broken_file(self) -> None:
+        mode = DEFAULT_MODE
+        with cache_dir() as workspace:
+            cache_file = get_cache_file(mode)
+            cache_file.write_text("this is not a pickle")
+            assert black.read_cache(mode) == {}
+            src = (workspace / "test.py").resolve()
+            src.write_text("print('hello')")
+            invokeBlack([str(src)])
+            cache = black.read_cache(mode)
+            assert str(src) in cache
 
+    def test_cache_single_file_already_cached(self) -> None:
+        mode = DEFAULT_MODE
+        with cache_dir() as workspace:
+            src = (workspace / "test.py").resolve()
+            src.write_text("print('hello')")
+            black.write_cache({}, [src], mode)
+            invokeBlack([str(src)])
+            assert src.read_text() == "print('hello')"
 
-def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
-    """Show function calls `from black/__init__.py` as they happen.
+    @event_loop()
+    def test_cache_multiple_files(self) -> None:
+        mode = DEFAULT_MODE
+        with cache_dir() as workspace, patch(
+            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
+        ):
+            one = (workspace / "one.py").resolve()
+            with one.open("w") as fobj:
+                fobj.write("print('hello')")
+            two = (workspace / "two.py").resolve()
+            with two.open("w") as fobj:
+                fobj.write("print('hello')")
+            black.write_cache({}, [one], mode)
+            invokeBlack([str(workspace)])
+            with one.open("r") as fobj:
+                assert fobj.read() == "print('hello')"
+            with two.open("r") as fobj:
+                assert fobj.read() == 'print("hello")\n'
+            cache = black.read_cache(mode)
+            assert str(one) in cache
+            assert str(two) in cache
 
-    Register this with `sys.settrace()` in a test you're debugging.
-    """
-    if event != "call":
-        return tracefunc
+    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
+    def test_no_cache_when_writeback_diff(self, color: bool) -> None:
+        mode = DEFAULT_MODE
+        with cache_dir() as workspace:
+            src = (workspace / "test.py").resolve()
+            with src.open("w") as fobj:
+                fobj.write("print('hello')")
+            with patch("black.read_cache") as read_cache, patch(
+                "black.write_cache"
+            ) as write_cache:
+                cmd = [str(src), "--diff"]
+                if color:
+                    cmd.append("--color")
+                invokeBlack(cmd)
+                cache_file = get_cache_file(mode)
+                assert cache_file.exists() is False
+                write_cache.assert_not_called()
+                read_cache.assert_not_called()
 
-    stack = len(inspect.stack()) - 19
-    stack *= 2
-    filename = frame.f_code.co_filename
-    lineno = frame.f_lineno
-    func_sig_lineno = lineno - 1
-    funcname = black_source_lines[func_sig_lineno].strip()
-    while funcname.startswith("@"):
-        func_sig_lineno += 1
-        funcname = black_source_lines[func_sig_lineno].strip()
-    if "black/__init__.py" in filename:
-        print(f"{' ' * stack}{lineno}:{funcname}")
-    return tracefunc
+    @pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
+    @event_loop()
+    def test_output_locking_when_writeback_diff(self, color: bool) -> None:
+        with cache_dir() as workspace:
+            for tag in range(0, 4):
+                src = (workspace / f"test{tag}.py").resolve()
+                with src.open("w") as fobj:
+                    fobj.write("print('hello')")
+            with patch("black.Manager", wraps=multiprocessing.Manager) as mgr:
+                cmd = ["--diff", str(workspace)]
+                if color:
+                    cmd.append("--color")
+                invokeBlack(cmd, exit_code=0)
+                # this isn't quite doing what we want, but if it _isn't_
+                # called then we cannot be using the lock it provides
+                mgr.assert_called()
 
+    def test_no_cache_when_stdin(self) -> None:
+        mode = DEFAULT_MODE
+        with cache_dir():
+            result = CliRunner().invoke(
+                black.main, ["-"], input=BytesIO(b"print('hello')")
+            )
+            assert not result.exit_code
+            cache_file = get_cache_file(mode)
+            assert not cache_file.exists()
 
-if __name__ == "__main__":
-    unittest.main(module="test_black")
+    def test_read_cache_no_cachefile(self) -> None:
+        mode = DEFAULT_MODE
+        with cache_dir():
+            assert black.read_cache(mode) == {}
+
+    def test_write_cache_read_cache(self) -> None:
+        mode = DEFAULT_MODE
+        with cache_dir() as workspace:
+            src = (workspace / "test.py").resolve()
+            src.touch()
+            black.write_cache({}, [src], mode)
+            cache = black.read_cache(mode)
+            assert str(src) in cache
+            assert cache[str(src)] == black.get_cache_info(src)
+
+    def test_filter_cached(self) -> None:
+        with TemporaryDirectory() as workspace:
+            path = Path(workspace)
+            uncached = (path / "uncached").resolve()
+            cached = (path / "cached").resolve()
+            cached_but_changed = (path / "changed").resolve()
+            uncached.touch()
+            cached.touch()
+            cached_but_changed.touch()
+            cache = {
+                str(cached): black.get_cache_info(cached),
+                str(cached_but_changed): (0.0, 0),
+            }
+            todo, done = black.filter_cached(
+                cache, {uncached, cached, cached_but_changed}
+            )
+            assert todo == {uncached, cached_but_changed}
+            assert done == {cached}
+
+    def test_write_cache_creates_directory_if_needed(self) -> None:
+        mode = DEFAULT_MODE
+        with cache_dir(exists=False) as workspace:
+            assert not workspace.exists()
+            black.write_cache({}, [], mode)
+            assert workspace.exists()
+
+    @event_loop()
+    def test_failed_formatting_does_not_get_cached(self) -> None:
+        mode = DEFAULT_MODE
+        with cache_dir() as workspace, patch(
+            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
+        ):
+            failing = (workspace / "failing.py").resolve()
+            with failing.open("w") as fobj:
+                fobj.write("not actually python")
+            clean = (workspace / "clean.py").resolve()
+            with clean.open("w") as fobj:
+                fobj.write('print("hello")\n')
+            invokeBlack([str(workspace)], exit_code=123)
+            cache = black.read_cache(mode)
+            assert str(failing) not in cache
+            assert str(clean) in cache
+
+    def test_write_cache_write_fail(self) -> None:
+        mode = DEFAULT_MODE
+        with cache_dir(), patch.object(Path, "open") as mock:
+            mock.side_effect = OSError
+            black.write_cache({}, [], mode)
+
+    def test_read_cache_line_lengths(self) -> None:
+        mode = DEFAULT_MODE
+        short_mode = replace(DEFAULT_MODE, line_length=1)
+        with cache_dir() as workspace:
+            path = (workspace / "file.py").resolve()
+            path.touch()
+            black.write_cache({}, [path], mode)
+            one = black.read_cache(mode)
+            assert str(path) in one
+            two = black.read_cache(short_mode)
+            assert str(path) not in two
+
+
+def assert_collected_sources(
+    src: Sequence[Union[str, Path]],
+    expected: Sequence[Union[str, Path]],
+    *,
+    exclude: Optional[str] = None,
+    include: Optional[str] = None,
+    extend_exclude: Optional[str] = None,
+    force_exclude: Optional[str] = None,
+    stdin_filename: Optional[str] = None,
+) -> None:
+    gs_src = tuple(str(Path(s)) for s in src)
+    gs_expected = [Path(s) for s in expected]
+    gs_exclude = None if exclude is None else compile_pattern(exclude)
+    gs_include = DEFAULT_INCLUDE if include is None else compile_pattern(include)
+    gs_extend_exclude = (
+        None if extend_exclude is None else compile_pattern(extend_exclude)
+    )
+    gs_force_exclude = None if force_exclude is None else compile_pattern(force_exclude)
+    collected = black.get_sources(
+        ctx=FakeContext(),
+        src=gs_src,
+        quiet=False,
+        verbose=False,
+        include=gs_include,
+        exclude=gs_exclude,
+        extend_exclude=gs_extend_exclude,
+        force_exclude=gs_force_exclude,
+        report=black.Report(),
+        stdin_filename=stdin_filename,
+    )
+    assert sorted(list(collected)) == sorted(gs_expected)
+
+
+class TestFileCollection:
+    def test_include_exclude(self) -> None:
+        path = THIS_DIR / "data" / "include_exclude_tests"
+        src = [path]
+        expected = [
+            Path(path / "b/dont_exclude/a.py"),
+            Path(path / "b/dont_exclude/a.pyi"),
+        ]
+        assert_collected_sources(
+            src,
+            expected,
+            include=r"\.pyi?$",
+            exclude=r"/exclude/|/\.definitely_exclude/",
+        )
+
+    def test_gitignore_used_as_default(self) -> None:
+        base = Path(DATA_DIR / "include_exclude_tests")
+        expected = [
+            base / "b/.definitely_exclude/a.py",
+            base / "b/.definitely_exclude/a.pyi",
+        ]
+        src = [base / "b/"]
+        assert_collected_sources(src, expected, extend_exclude=r"/exclude/")
+
+    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    def test_exclude_for_issue_1572(self) -> None:
+        # Exclude shouldn't touch files that were explicitly given to Black through the
+        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
+        # https://github.com/psf/black/issues/1572
+        path = DATA_DIR / "include_exclude_tests"
+        src = [path / "b/exclude/a.py"]
+        expected = [path / "b/exclude/a.py"]
+        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
+
+    def test_gitignore_exclude(self) -> None:
+        path = THIS_DIR / "data" / "include_exclude_tests"
+        include = re.compile(r"\.pyi?$")
+        exclude = re.compile(r"")
+        report = black.Report()
+        gitignore = PathSpec.from_lines(
+            "gitwildmatch", ["exclude/", ".definitely_exclude"]
+        )
+        sources: List[Path] = []
+        expected = [
+            Path(path / "b/dont_exclude/a.py"),
+            Path(path / "b/dont_exclude/a.pyi"),
+        ]
+        this_abs = THIS_DIR.resolve()
+        sources.extend(
+            black.gen_python_files(
+                path.iterdir(),
+                this_abs,
+                include,
+                exclude,
+                None,
+                None,
+                report,
+                gitignore,
+                verbose=False,
+                quiet=False,
+            )
+        )
+        assert sorted(expected) == sorted(sources)
+
+    def test_nested_gitignore(self) -> None:
+        path = Path(THIS_DIR / "data" / "nested_gitignore_tests")
+        include = re.compile(r"\.pyi?$")
+        exclude = re.compile(r"")
+        root_gitignore = black.files.get_gitignore(path)
+        report = black.Report()
+        expected: List[Path] = [
+            Path(path / "x.py"),
+            Path(path / "root/b.py"),
+            Path(path / "root/c.py"),
+            Path(path / "root/child/c.py"),
+        ]
+        this_abs = THIS_DIR.resolve()
+        sources = list(
+            black.gen_python_files(
+                path.iterdir(),
+                this_abs,
+                include,
+                exclude,
+                None,
+                None,
+                report,
+                root_gitignore,
+                verbose=False,
+                quiet=False,
+            )
+        )
+        assert sorted(expected) == sorted(sources)
+
+    def test_invalid_gitignore(self) -> None:
+        path = THIS_DIR / "data" / "invalid_gitignore_tests"
+        empty_config = path / "pyproject.toml"
+        result = BlackRunner().invoke(
+            black.main, ["--verbose", "--config", str(empty_config), str(path)]
+        )
+        assert result.exit_code == 1
+        assert result.stderr_bytes is not None
+
+        gitignore = path / ".gitignore"
+        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
+
+    def test_invalid_nested_gitignore(self) -> None:
+        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
+        empty_config = path / "pyproject.toml"
+        result = BlackRunner().invoke(
+            black.main, ["--verbose", "--config", str(empty_config), str(path)]
+        )
+        assert result.exit_code == 1
+        assert result.stderr_bytes is not None
+
+        gitignore = path / "a" / ".gitignore"
+        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
+
+    def test_empty_include(self) -> None:
+        path = DATA_DIR / "include_exclude_tests"
+        src = [path]
+        expected = [
+            Path(path / "b/exclude/a.pie"),
+            Path(path / "b/exclude/a.py"),
+            Path(path / "b/exclude/a.pyi"),
+            Path(path / "b/dont_exclude/a.pie"),
+            Path(path / "b/dont_exclude/a.py"),
+            Path(path / "b/dont_exclude/a.pyi"),
+            Path(path / "b/.definitely_exclude/a.pie"),
+            Path(path / "b/.definitely_exclude/a.py"),
+            Path(path / "b/.definitely_exclude/a.pyi"),
+            Path(path / ".gitignore"),
+            Path(path / "pyproject.toml"),
+        ]
+        # Setting exclude explicitly to an empty string to block .gitignore usage.
+        assert_collected_sources(src, expected, include="", exclude="")
+
+    def test_extend_exclude(self) -> None:
+        path = DATA_DIR / "include_exclude_tests"
+        src = [path]
+        expected = [
+            Path(path / "b/exclude/a.py"),
+            Path(path / "b/dont_exclude/a.py"),
+        ]
+        assert_collected_sources(
+            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
+        )
+
+    def test_symlink_out_of_root_directory(self) -> None:
+        path = MagicMock()
+        root = THIS_DIR.resolve()
+        child = MagicMock()
+        include = re.compile(black.DEFAULT_INCLUDES)
+        exclude = re.compile(black.DEFAULT_EXCLUDES)
+        report = black.Report()
+        gitignore = PathSpec.from_lines("gitwildmatch", [])
+        # `child` should behave like a symlink whose resolved path is clearly
+        # outside of the `root` directory.
+        path.iterdir.return_value = [child]
+        child.resolve.return_value = Path("/a/b/c")
+        child.as_posix.return_value = "/a/b/c"
+        child.is_symlink.return_value = True
+        try:
+            list(
+                black.gen_python_files(
+                    path.iterdir(),
+                    root,
+                    include,
+                    exclude,
+                    None,
+                    None,
+                    report,
+                    gitignore,
+                    verbose=False,
+                    quiet=False,
+                )
+            )
+        except ValueError as ve:
+            pytest.fail(f"`gen_python_files()` failed: {ve}")
+        path.iterdir.assert_called_once()
+        child.resolve.assert_called_once()
+        child.is_symlink.assert_called_once()
+        # `child` should behave like a strange file whose resolved path is clearly
+        # outside of the `root` directory.
+        child.is_symlink.return_value = False
+        with pytest.raises(ValueError):
+            list(
+                black.gen_python_files(
+                    path.iterdir(),
+                    root,
+                    include,
+                    exclude,
+                    None,
+                    None,
+                    report,
+                    gitignore,
+                    verbose=False,
+                    quiet=False,
+                )
+            )
+        path.iterdir.assert_called()
+        assert path.iterdir.call_count == 2
+        child.resolve.assert_called()
+        assert child.resolve.call_count == 2
+        child.is_symlink.assert_called()
+        assert child.is_symlink.call_count == 2
+
+    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    def test_get_sources_with_stdin(self) -> None:
+        src = ["-"]
+        expected = ["-"]
+        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")
+
+    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    def test_get_sources_with_stdin_filename(self) -> None:
+        src = ["-"]
+        stdin_filename = str(THIS_DIR / "data/collections.py")
+        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
+        assert_collected_sources(
+            src,
+            expected,
+            exclude=r"/exclude/a\.py",
+            stdin_filename=stdin_filename,
+        )
+
+    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
+        # Exclude shouldn't exclude stdin_filename since it is mimicking the
+        # file being passed directly. This is the same as
+        # test_exclude_for_issue_1572
+        path = DATA_DIR / "include_exclude_tests"
+        src = ["-"]
+        stdin_filename = str(path / "b/exclude/a.py")
+        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
+        assert_collected_sources(
+            src,
+            expected,
+            exclude=r"/exclude/|a\.py",
+            stdin_filename=stdin_filename,
+        )
+
+    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
+        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
+        # file being passed directly. This is the same as
+        # test_exclude_for_issue_1572
+        src = ["-"]
+        path = THIS_DIR / "data" / "include_exclude_tests"
+        stdin_filename = str(path / "b/exclude/a.py")
+        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
+        assert_collected_sources(
+            src,
+            expected,
+            extend_exclude=r"/exclude/|a\.py",
+            stdin_filename=stdin_filename,
+        )
+
+    @patch("black.find_project_root", lambda *args: THIS_DIR.resolve())
+    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
+        # Force exclude should exclude the file when passing it through
+        # stdin_filename
+        path = THIS_DIR / "data" / "include_exclude_tests"
+        stdin_filename = str(path / "b/exclude/a.py")
+        assert_collected_sources(
+            src=["-"],
+            expected=[],
+            force_exclude=r"/exclude/|a\.py",
+            stdin_filename=stdin_filename,
+        )
+
+
+with open(black.__file__, "r", encoding="utf-8") as _bf:
+    black_source_lines = _bf.readlines()
+
+
+def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
+    """Show function calls `from black/__init__.py` as they happen.
+
+    Register this with `sys.settrace()` in a test you're debugging.
+    """
+    if event != "call":
+        return tracefunc
+
+    stack = len(inspect.stack()) - 19
+    stack *= 2
+    filename = frame.f_code.co_filename
+    lineno = frame.f_lineno
+    func_sig_lineno = lineno - 1
+    funcname = black_source_lines[func_sig_lineno].strip()
+    while funcname.startswith("@"):
+        func_sig_lineno += 1
+        funcname = black_source_lines[func_sig_lineno].strip()
+    if "black/__init__.py" in filename:
+        print(f"{' ' * stack}{lineno}:{funcname}")
+    return tracefunc
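`test_symlink_out_of_root_directory` above expects two behaviours from the file walker: a child that is a symlink resolving outside `root` is skipped, while a non-symlink that resolves outside `root` makes `gen_python_files` raise `ValueError`. The following is a minimal sketch of that contract only, using a hypothetical `relative_or_none` helper (not Black's actual code) and the same `MagicMock` trick the test uses to fake a path. It relies on `Path.relative_to()` raising `ValueError` when the path is not under `root`.

```python
from pathlib import Path
from typing import Optional
from unittest.mock import MagicMock


def relative_or_none(path: Path, root: Path) -> Optional[str]:
    """Hypothetical helper: return `path` relative to `root` as a POSIX string.

    A symlink whose target lies outside `root` is skipped (None is returned);
    any other path that resolves outside `root` is treated as an error.
    """
    try:
        return path.resolve().relative_to(root).as_posix()
    except ValueError:
        # relative_to() raises ValueError when the resolved path is not under root.
        if path.is_symlink():
            return None
        raise


root = Path("/project")

# Fake a child the same way the test does: a MagicMock stands in for a Path.
child = MagicMock()
child.resolve.return_value = Path("/a/b/c")
child.is_symlink.return_value = True
print(relative_or_none(child, root))  # None: out-of-root symlink is skipped

child.is_symlink.return_value = False
try:
    relative_or_none(child, root)
except ValueError:
    print("ValueError: out-of-root path that is not a symlink")
```

Because the mock records every call, the test can also assert how often `resolve()` and `is_symlink()` were consulted, which is what its `call_count` checks do.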
index fc9678ad27cda3c0e6791f54c0bcf6858c24e5c6..a659382092ac7251ce6d9be97d09b856adc0a810 100644
@@ -1,16 +1,17 @@
+from dataclasses import replace
+from typing import Any, Iterator
 from unittest.mock import patch
 
-import black
 import pytest
-from parameterized import parameterized
 
+import black
 from tests.util import (
-    BlackBaseTestCase,
-    fs,
     DEFAULT_MODE,
+    PY36_VERSIONS,
+    THIS_DIR,
+    assert_format,
     dump_to_stderr,
     read_data,
-    THIS_DIR,
 )
 
 SIMPLE_CASES = [
@@ -113,33 +114,121 @@ SOURCES = [
 ]
 
 
-class TestSimpleFormat(BlackBaseTestCase):
-    @parameterized.expand(SIMPLE_CASES_PY2)
-    @pytest.mark.python2
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_simple_format_py2(self, filename: str) -> None:
-        self.check_file(filename, DEFAULT_MODE)
-
-    @parameterized.expand(SIMPLE_CASES)
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_simple_format(self, filename: str) -> None:
-        self.check_file(filename, DEFAULT_MODE)
-
-    @parameterized.expand(EXPERIMENTAL_STRING_PROCESSING_CASES)
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_experimental_format(self, filename: str) -> None:
-        self.check_file(filename, black.Mode(experimental_string_processing=True))
-
-    @parameterized.expand(SOURCES)
-    @patch("black.dump_to_file", dump_to_stderr)
-    def test_source_is_formatted(self, filename: str) -> None:
-        path = THIS_DIR.parent / filename
-        self.check_file(str(path), DEFAULT_MODE, data=False)
-
-    def check_file(self, filename: str, mode: black.Mode, *, data: bool = True) -> None:
-        source, expected = read_data(filename, data=data)
-        actual = fs(source, mode=mode)
-        self.assertFormatEqual(expected, actual)
-        if source != actual:
-            black.assert_equivalent(source, actual)
-            black.assert_stable(source, actual, mode)
+@pytest.fixture(autouse=True)
+def patch_dump_to_file(request: Any) -> Iterator[None]:
+    with patch("black.dump_to_file", dump_to_stderr):
+        yield
+
+
+def check_file(filename: str, mode: black.Mode, *, data: bool = True) -> None:
+    source, expected = read_data(filename, data=data)
+    assert_format(source, expected, mode, fast=False)
+
+
+@pytest.mark.parametrize("filename", SIMPLE_CASES_PY2)
+@pytest.mark.python2
+def test_simple_format_py2(filename: str) -> None:
+    check_file(filename, DEFAULT_MODE)
+
+
+@pytest.mark.parametrize("filename", SIMPLE_CASES)
+def test_simple_format(filename: str) -> None:
+    check_file(filename, DEFAULT_MODE)
+
+
+@pytest.mark.parametrize("filename", EXPERIMENTAL_STRING_PROCESSING_CASES)
+def test_experimental_format(filename: str) -> None:
+    check_file(filename, black.Mode(experimental_string_processing=True))
+
+
+@pytest.mark.parametrize("filename", SOURCES)
+def test_source_is_formatted(filename: str) -> None:
+    path = THIS_DIR.parent / filename
+    check_file(str(path), DEFAULT_MODE, data=False)
+
+
+# ============= #
+# Complex cases #
+# ============= #
+
+
+def test_empty() -> None:
+    source = expected = ""
+    assert_format(source, expected)
+
+
+def test_pep_572() -> None:
+    source, expected = read_data("pep_572")
+    assert_format(source, expected, minimum_version=(3, 8))
+
+
+def test_pep_572_remove_parens() -> None:
+    source, expected = read_data("pep_572_remove_parens")
+    assert_format(source, expected, minimum_version=(3, 8))
+
+
+def test_pep_572_do_not_remove_parens() -> None:
+    source, expected = read_data("pep_572_do_not_remove_parens")
+    # The AST safety checks will fail, but that's expected; just make sure no
+    # parentheses are touched.
+    assert_format(source, expected, fast=True)
+
+
+@pytest.mark.parametrize("major, minor", [(3, 9), (3, 10)])
+def test_pep_572_newer_syntax(major: int, minor: int) -> None:
+    source, expected = read_data(f"pep_572_py{major}{minor}")
+    assert_format(source, expected, minimum_version=(major, minor))
+
+
+def test_pep_570() -> None:
+    source, expected = read_data("pep_570")
+    assert_format(source, expected, minimum_version=(3, 8))
+
+
+def test_docstring_no_string_normalization() -> None:
+    """Like test_docstring but with string normalization off."""
+    source, expected = read_data("docstring_no_string_normalization")
+    mode = replace(DEFAULT_MODE, string_normalization=False)
+    assert_format(source, expected, mode)
+
+
+def test_long_strings_flag_disabled() -> None:
+    """Tests for turning off the string processing logic."""
+    source, expected = read_data("long_strings_flag_disabled")
+    mode = replace(DEFAULT_MODE, experimental_string_processing=False)
+    assert_format(source, expected, mode)
+
+
+def test_numeric_literals() -> None:
+    source, expected = read_data("numeric_literals")
+    mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
+    assert_format(source, expected, mode)
+
+
+def test_numeric_literals_ignoring_underscores() -> None:
+    source, expected = read_data("numeric_literals_skip_underscores")
+    mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
+    assert_format(source, expected, mode)
+
+
+@pytest.mark.python2
+def test_python2_print_function() -> None:
+    source, expected = read_data("python2_print_function")
+    mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY27})
+    assert_format(source, expected, mode)
+
+
+def test_stub() -> None:
+    mode = replace(DEFAULT_MODE, is_pyi=True)
+    source, expected = read_data("stub.pyi")
+    assert_format(source, expected, mode)
+
+
+def test_python38() -> None:
+    source, expected = read_data("python38")
+    assert_format(source, expected, minimum_version=(3, 8))
+
+
+def test_python39() -> None:
+    source, expected = read_data("python39")
+    assert_format(source, expected, minimum_version=(3, 9))
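The autouse `patch_dump_to_file` fixture replaces the `@patch("black.dump_to_file", dump_to_stderr)` decorator that used to sit on every test method; with `autouse=True`, pytest applies it to every test in the module. pytest drives a yield fixture much like a generator-based context manager: code before `yield` is setup and code after it is teardown. A stdlib-only sketch of that shape follows, with `contextlib.contextmanager` playing pytest's role and a hypothetical `fake_black` module standing in for `black`.

```python
import types
from contextlib import contextmanager
from typing import Iterator
from unittest.mock import patch

# Hypothetical stand-in for the `black` module; only dump_to_file matters here.
fake_black = types.ModuleType("fake_black")
fake_black.dump_to_file = lambda *output: "/tmp/blk_example.log"


def dump_to_stderr(*output: str) -> str:
    return "\n" + "\n".join(output) + "\n"


@contextmanager
def patch_dump_to_file() -> Iterator[None]:
    # Same shape as the autouse fixture: the patch is active only while the
    # generator is suspended at `yield`, and is undone when it resumes.
    with patch.object(fake_black, "dump_to_file", dump_to_stderr):
        yield


with patch_dump_to_file():
    print(repr(fake_black.dump_to_file("a", "b")))  # '\na\nb\n'
print(fake_black.dump_to_file("a", "b"))  # /tmp/blk_example.log
```

Moving the patch into one fixture means a newly added test cannot forget it, which is one of the consistency goals stated in the commit message.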
index e83017f5ad3eba54a98ab78bf560f7e46318dc2d..84e98bb0fbde883da8d61545c935c7904641c232 100644
@@ -1,58 +1,97 @@
 import os
+import sys
 import unittest
-from pathlib import Path
-from typing import Iterator, List, Tuple, Any
 from contextlib import contextmanager
 from functools import partial
+from pathlib import Path
+from typing import Any, Iterator, List, Optional, Tuple
 
 import black
-from black.output import out, err
 from black.debug import DebugVisitor
+from black.mode import TargetVersion
+from black.output import err, out
 
 THIS_DIR = Path(__file__).parent
+DATA_DIR = THIS_DIR / "data"
 PROJECT_ROOT = THIS_DIR.parent
 EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
 DETERMINISTIC_HEADER = "[Deterministic header]"
 
+PY36_VERSIONS = {
+    TargetVersion.PY36,
+    TargetVersion.PY37,
+    TargetVersion.PY38,
+    TargetVersion.PY39,
+}
 
 DEFAULT_MODE = black.Mode()
 ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
 fs = partial(black.format_str, mode=DEFAULT_MODE)
 
 
+def _assert_format_equal(expected: str, actual: str) -> None:
+    if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
+        bdv: DebugVisitor[Any]
+        out("Expected tree:", fg="green")
+        try:
+            exp_node = black.lib2to3_parse(expected)
+            bdv = DebugVisitor()
+            list(bdv.visit(exp_node))
+        except Exception as ve:
+            err(str(ve))
+        out("Actual tree:", fg="red")
+        try:
+            act_node = black.lib2to3_parse(actual)
+            bdv = DebugVisitor()
+            list(bdv.visit(act_node))
+        except Exception as ve:
+            err(str(ve))
+
+    assert actual == expected
+
+
+def assert_format(
+    source: str,
+    expected: str,
+    mode: black.Mode = DEFAULT_MODE,
+    *,
+    fast: bool = False,
+    minimum_version: Optional[Tuple[int, int]] = None,
+) -> None:
+    """Convenience function to check that Black formats as expected.
+
+    Pass @minimum_version when the code uses newer syntax, so the AST equivalence
+    check is skipped on older interpreters instead of crashing with a SyntaxError.
+    Note that this is separate from the TargetVersion Mode configuration.
+    """
+    actual = black.format_str(source, mode=mode)
+    _assert_format_equal(expected, actual)
+    # It's not useful to run safety checks if we're expecting no changes anyway.
+    # The assertion right above will raise if Black does make changes, so this
+    # just avoids wasted CPU cycles.
+    if not fast and source != expected:
+        # Unfortunately the AST equivalence check relies on the built-in ast module
+        # being able to parse the code being formatted. This doesn't always work out
+        # when checking modern code on older versions.
+        if minimum_version is None or sys.version_info >= minimum_version:
+            black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, mode=mode)
+
+
 def dump_to_stderr(*output: str) -> str:
     return "\n" + "\n".join(output) + "\n"
 
 
 class BlackBaseTestCase(unittest.TestCase):
-    maxDiff = None
-    _diffThreshold = 2 ** 20
-
     def assertFormatEqual(self, expected: str, actual: str) -> None:
-        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
-            bdv: DebugVisitor[Any]
-            out("Expected tree:", fg="green")
-            try:
-                exp_node = black.lib2to3_parse(expected)
-                bdv = DebugVisitor()
-                list(bdv.visit(exp_node))
-            except Exception as ve:
-                err(str(ve))
-            out("Actual tree:", fg="red")
-            try:
-                exp_node = black.lib2to3_parse(actual)
-                bdv = DebugVisitor()
-                list(bdv.visit(exp_node))
-            except Exception as ve:
-                err(str(ve))
-        self.assertMultiLineEqual(expected, actual)
+        _assert_format_equal(expected, actual)
 
 
 def read_data(name: str, data: bool = True) -> Tuple[str, str]:
     """read_data('test_name') -> 'input', 'output'"""
     if not name.endswith((".py", ".pyi", ".out", ".diff")):
         name += ".py"
-    base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
+    base_dir = DATA_DIR if data else PROJECT_ROOT
     return read_data_from_file(base_dir / name)
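`assert_format` bundles three checks: the output must match the expected text; when `fast` is false and formatting changed something, the source and output must be AST-equivalent (skipped if the interpreter is older than `minimum_version`); and formatting the output again must not change it. Below is a stdlib-only sketch of that control flow, not Black's implementation. `toy_format` is a hypothetical formatter used in place of `black.format_str`, and comparing `ast.dump()` output is a simplified stand-in for `black.assert_equivalent`.

```python
import ast
import sys
from typing import Optional, Tuple


def toy_format(source: str) -> str:
    """Hypothetical formatter: strip trailing whitespace from every line."""
    return "".join(line.rstrip() + "\n" for line in source.splitlines())


def assert_format_sketch(
    source: str,
    expected: str,
    *,
    fast: bool = False,
    minimum_version: Optional[Tuple[int, int]] = None,
) -> None:
    actual = toy_format(source)
    assert actual == expected, f"{actual!r} != {expected!r}"
    # Safety checks only matter when formatting changed something.
    if not fast and source != expected:
        # Tuple comparison: skip the AST check when this interpreter is older
        # than the syntax used in the test data.
        if minimum_version is None or sys.version_info >= minimum_version:
            assert ast.dump(ast.parse(source)) == ast.dump(ast.parse(actual))
        # Stability: formatting the output again must not change it.
        assert toy_format(actual) == actual


assert_format_sketch("x = 1   \ny = 2\n", "x = 1\ny = 2\n")
assert_format_sketch("print((y := 1))  \n", "print((y := 1))\n", minimum_version=(3, 8))
```

As in the helper above, the stability check still runs when the version guard skips the AST comparison.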