git.madduck.net Git - etc/vim.git/blobdiff - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump mypy[c] from 0.971 to 0.991 (#3380)
[etc/vim.git] / tests / test_black.py
index 089e043d639a82f9614018dba2b300580544dc30..dda10555c970e3ebcb305bb454dc43d755313e3c 100644 (file)
@@ -25,6 +25,7 @@ from typing import (
     List,
     Optional,
     Sequence,
     List,
     Optional,
     Sequence,
+    Type,
     TypeVar,
     Union,
 )
     TypeVar,
     Union,
 )
@@ -153,6 +154,34 @@ class BlackTestCase(BlackBaseTestCase):
             os.unlink(tmp_file)
         self.assertFormatEqual(expected, actual)
 
             os.unlink(tmp_file)
         self.assertFormatEqual(expected, actual)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_one_empty_line(self) -> None:
+        mode = black.Mode(preview=True)
+        for nl in ["\n", "\r\n"]:
+            source = expected = nl
+            assert_format(source, expected, mode=mode)
+
+    def test_one_empty_line_ff(self) -> None:
+        mode = black.Mode(preview=True)
+        for nl in ["\n", "\r\n"]:
+            expected = nl
+            tmp_file = Path(black.dump_to_file(nl))
+            if system() == "Windows":
+                # Writing files in text mode automatically uses the system newline,
+                # but in this case we don't want this for testing reasons. See:
+                # https://github.com/psf/black/pull/3348
+                with open(tmp_file, "wb") as f:
+                    f.write(nl.encode("utf-8"))
+            try:
+                self.assertFalse(
+                    ff(tmp_file, mode=mode, write_back=black.WriteBack.YES)
+                )
+                with open(tmp_file, "rb") as f:
+                    actual = f.read().decode("utf8")
+            finally:
+                os.unlink(tmp_file)
+            self.assertFormatEqual(expected, actual)
+
     def test_experimental_string_processing_warns(self) -> None:
         self.assertWarns(
             black.mode.Deprecated, black.Mode, experimental_string_processing=True
     def test_experimental_string_processing_warns(self) -> None:
         self.assertWarns(
             black.mode.Deprecated, black.Mode, experimental_string_processing=True
@@ -341,6 +370,30 @@ class BlackTestCase(BlackBaseTestCase):
         black.assert_equivalent(source, not_normalized)
         black.assert_stable(source, not_normalized, mode=mode)
 
         black.assert_equivalent(source, not_normalized)
         black.assert_stable(source, not_normalized, mode=mode)
 
+    def test_skip_source_first_line(self) -> None:
+        source, _ = read_data("miscellaneous", "invalid_header")
+        tmp_file = Path(black.dump_to_file(source))
+        # Full source should fail (invalid syntax at header)
+        self.invokeBlack([str(tmp_file), "--diff", "--check"], exit_code=123)
+        # So, skipping the first line should work
+        result = BlackRunner().invoke(
+            black.main, [str(tmp_file), "-x", f"--config={EMPTY_CONFIG}"]
+        )
+        self.assertEqual(result.exit_code, 0)
+        with open(tmp_file, encoding="utf8") as f:
+            actual = f.read()
+        self.assertFormatEqual(source, actual)
+
+    def test_skip_source_first_line_when_mixing_newlines(self) -> None:
+        code_mixing_newlines = b"Header will be skipped\r\ni = [1,2,3]\nj = [1,2,3]\n"
+        expected = b"Header will be skipped\r\ni = [1, 2, 3]\nj = [1, 2, 3]\n"
+        with TemporaryDirectory() as workspace:
+            test_file = Path(workspace) / "skip_header.py"
+            test_file.write_bytes(code_mixing_newlines)
+            mode = replace(DEFAULT_MODE, skip_source_first_line=True)
+            ff(test_file, mode=mode, write_back=black.WriteBack.YES)
+            self.assertEqual(test_file.read_bytes(), expected)
+
     def test_skip_magic_trailing_comma(self) -> None:
         source, _ = read_data("simple_cases", "expression")
         expected, _ = read_data(
     def test_skip_magic_trailing_comma(self) -> None:
         source, _ = read_data("simple_cases", "expression")
         expected, _ = read_data(
@@ -466,8 +519,10 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
             self.assertEqual(
                 unstyle(str(report)),
-                "1 file reformatted, 2 files left unchanged, 1 file failed to"
-                " reformat.",
+                (
+                    "1 file reformatted, 2 files left unchanged, 1 file failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.done(Path("f3"), black.Changed.YES)
             )
             self.assertEqual(report.return_code, 123)
             report.done(Path("f3"), black.Changed.YES)
@@ -476,8 +531,10 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(out_lines[-1], "reformatted f3")
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(out_lines[-1], "reformatted f3")
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 2 files left unchanged, 1 file failed to"
-                " reformat.",
+                (
+                    "2 files reformatted, 2 files left unchanged, 1 file failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.failed(Path("e2"), "boom")
             )
             self.assertEqual(report.return_code, 123)
             report.failed(Path("e2"), "boom")
@@ -486,8 +543,10 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 2 files left unchanged, 2 files failed to"
-                " reformat.",
+                (
+                    "2 files reformatted, 2 files left unchanged, 2 files failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.path_ignored(Path("wat"), "no match")
             )
             self.assertEqual(report.return_code, 123)
             report.path_ignored(Path("wat"), "no match")
@@ -496,8 +555,10 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(out_lines[-1], "wat ignored: no match")
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(out_lines[-1], "wat ignored: no match")
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 2 files left unchanged, 2 files failed to"
-                " reformat.",
+                (
+                    "2 files reformatted, 2 files left unchanged, 2 files failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.done(Path("f4"), black.Changed.NO)
             )
             self.assertEqual(report.return_code, 123)
             report.done(Path("f4"), black.Changed.NO)
@@ -506,22 +567,28 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 3 files left unchanged, 2 files failed to"
-                " reformat.",
+                (
+                    "2 files reformatted, 3 files left unchanged, 2 files failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.check = True
             self.assertEqual(
                 unstyle(str(report)),
             )
             self.assertEqual(report.return_code, 123)
             report.check = True
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files would be reformatted, 3 files would be left unchanged, 2"
-                " files would fail to reformat.",
+                (
+                    "2 files would be reformatted, 3 files would be left unchanged, 2"
+                    " files would fail to reformat."
+                ),
             )
             report.check = False
             report.diff = True
             self.assertEqual(
                 unstyle(str(report)),
             )
             report.check = False
             report.diff = True
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files would be reformatted, 3 files would be left unchanged, 2"
-                " files would fail to reformat.",
+                (
+                    "2 files would be reformatted, 3 files would be left unchanged, 2"
+                    " files would fail to reformat."
+                ),
             )
 
     def test_report_quiet(self) -> None:
             )
 
     def test_report_quiet(self) -> None:
@@ -563,8 +630,10 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
             self.assertEqual(
                 unstyle(str(report)),
-                "1 file reformatted, 2 files left unchanged, 1 file failed to"
-                " reformat.",
+                (
+                    "1 file reformatted, 2 files left unchanged, 1 file failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.done(Path("f3"), black.Changed.YES)
             )
             self.assertEqual(report.return_code, 123)
             report.done(Path("f3"), black.Changed.YES)
@@ -572,8 +641,10 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(len(err_lines), 1)
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(len(err_lines), 1)
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 2 files left unchanged, 1 file failed to"
-                " reformat.",
+                (
+                    "2 files reformatted, 2 files left unchanged, 1 file failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.failed(Path("e2"), "boom")
             )
             self.assertEqual(report.return_code, 123)
             report.failed(Path("e2"), "boom")
@@ -582,8 +653,10 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 2 files left unchanged, 2 files failed to"
-                " reformat.",
+                (
+                    "2 files reformatted, 2 files left unchanged, 2 files failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.path_ignored(Path("wat"), "no match")
             )
             self.assertEqual(report.return_code, 123)
             report.path_ignored(Path("wat"), "no match")
@@ -591,8 +664,10 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(len(err_lines), 2)
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(len(err_lines), 2)
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 2 files left unchanged, 2 files failed to"
-                " reformat.",
+                (
+                    "2 files reformatted, 2 files left unchanged, 2 files failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.done(Path("f4"), black.Changed.NO)
             )
             self.assertEqual(report.return_code, 123)
             report.done(Path("f4"), black.Changed.NO)
@@ -600,22 +675,28 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(len(err_lines), 2)
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(len(err_lines), 2)
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 3 files left unchanged, 2 files failed to"
-                " reformat.",
+                (
+                    "2 files reformatted, 3 files left unchanged, 2 files failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.check = True
             self.assertEqual(
                 unstyle(str(report)),
             )
             self.assertEqual(report.return_code, 123)
             report.check = True
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files would be reformatted, 3 files would be left unchanged, 2"
-                " files would fail to reformat.",
+                (
+                    "2 files would be reformatted, 3 files would be left unchanged, 2"
+                    " files would fail to reformat."
+                ),
             )
             report.check = False
             report.diff = True
             self.assertEqual(
                 unstyle(str(report)),
             )
             report.check = False
             report.diff = True
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files would be reformatted, 3 files would be left unchanged, 2"
-                " files would fail to reformat.",
+                (
+                    "2 files would be reformatted, 3 files would be left unchanged, 2"
+                    " files would fail to reformat."
+                ),
             )
 
     def test_report_normal(self) -> None:
             )
 
     def test_report_normal(self) -> None:
@@ -659,8 +740,10 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
             self.assertEqual(
                 unstyle(str(report)),
-                "1 file reformatted, 2 files left unchanged, 1 file failed to"
-                " reformat.",
+                (
+                    "1 file reformatted, 2 files left unchanged, 1 file failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.done(Path("f3"), black.Changed.YES)
             )
             self.assertEqual(report.return_code, 123)
             report.done(Path("f3"), black.Changed.YES)
@@ -669,8 +752,10 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(out_lines[-1], "reformatted f3")
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(out_lines[-1], "reformatted f3")
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 2 files left unchanged, 1 file failed to"
-                " reformat.",
+                (
+                    "2 files reformatted, 2 files left unchanged, 1 file failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.failed(Path("e2"), "boom")
             )
             self.assertEqual(report.return_code, 123)
             report.failed(Path("e2"), "boom")
@@ -679,8 +764,10 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 2 files left unchanged, 2 files failed to"
-                " reformat.",
+                (
+                    "2 files reformatted, 2 files left unchanged, 2 files failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.path_ignored(Path("wat"), "no match")
             )
             self.assertEqual(report.return_code, 123)
             report.path_ignored(Path("wat"), "no match")
@@ -688,8 +775,10 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(len(err_lines), 2)
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(len(err_lines), 2)
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 2 files left unchanged, 2 files failed to"
-                " reformat.",
+                (
+                    "2 files reformatted, 2 files left unchanged, 2 files failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.done(Path("f4"), black.Changed.NO)
             )
             self.assertEqual(report.return_code, 123)
             report.done(Path("f4"), black.Changed.NO)
@@ -697,22 +786,28 @@ class BlackTestCase(BlackBaseTestCase):
             self.assertEqual(len(err_lines), 2)
             self.assertEqual(
                 unstyle(str(report)),
             self.assertEqual(len(err_lines), 2)
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 3 files left unchanged, 2 files failed to"
-                " reformat.",
+                (
+                    "2 files reformatted, 3 files left unchanged, 2 files failed to"
+                    " reformat."
+                ),
             )
             self.assertEqual(report.return_code, 123)
             report.check = True
             self.assertEqual(
                 unstyle(str(report)),
             )
             self.assertEqual(report.return_code, 123)
             report.check = True
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files would be reformatted, 3 files would be left unchanged, 2"
-                " files would fail to reformat.",
+                (
+                    "2 files would be reformatted, 3 files would be left unchanged, 2"
+                    " files would fail to reformat."
+                ),
             )
             report.check = False
             report.diff = True
             self.assertEqual(
                 unstyle(str(report)),
             )
             report.check = False
             report.diff = True
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files would be reformatted, 3 files would be left unchanged, 2"
-                " files would fail to reformat.",
+                (
+                    "2 files would be reformatted, 3 files would be left unchanged, 2"
+                    " files would fail to reformat."
+                ),
             )
 
     def test_lib2to3_parse(self) -> None:
             )
 
     def test_lib2to3_parse(self) -> None:
@@ -905,8 +1000,8 @@ class BlackTestCase(BlackBaseTestCase):
         )
 
     def test_format_file_contents(self) -> None:
         )
 
     def test_format_file_contents(self) -> None:
-        empty = ""
         mode = DEFAULT_MODE
         mode = DEFAULT_MODE
+        empty = ""
         with self.assertRaises(black.NothingChanged):
             black.format_file_contents(empty, mode=mode, fast=False)
         just_nl = "\n"
         with self.assertRaises(black.NothingChanged):
             black.format_file_contents(empty, mode=mode, fast=False)
         just_nl = "\n"
@@ -924,6 +1019,17 @@ class BlackTestCase(BlackBaseTestCase):
             black.format_file_contents(invalid, mode=mode, fast=False)
         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
 
             black.format_file_contents(invalid, mode=mode, fast=False)
         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
 
+        mode = black.Mode(preview=True)
+        just_crlf = "\r\n"
+        with self.assertRaises(black.NothingChanged):
+            black.format_file_contents(just_crlf, mode=mode, fast=False)
+        just_whitespace_nl = "\n\t\n \n\t \n \t\n\n"
+        actual = black.format_file_contents(just_whitespace_nl, mode=mode, fast=False)
+        self.assertEqual("\n", actual)
+        just_whitespace_crlf = "\r\n\t\r\n \r\n\t \r\n \t\r\n\r\n"
+        actual = black.format_file_contents(just_whitespace_crlf, mode=mode, fast=False)
+        self.assertEqual("\r\n", actual)
+
     def test_endmarker(self) -> None:
         n = black.lib2to3_parse("\n")
         self.assertEqual(n.type, black.syms.file_input)
     def test_endmarker(self) -> None:
         n = black.lib2to3_parse("\n")
         self.assertEqual(n.type, black.syms.file_input)
@@ -1215,8 +1321,51 @@ class BlackTestCase(BlackBaseTestCase):
             report.done.assert_called_with(expected, black.Changed.YES)
 
     def test_reformat_one_with_stdin_empty(self) -> None:
             report.done.assert_called_with(expected, black.Changed.YES)
 
     def test_reformat_one_with_stdin_empty(self) -> None:
+        cases = [
+            ("", ""),
+            ("\n", "\n"),
+            ("\r\n", "\r\n"),
+            (" \t", ""),
+            (" \t\n\t ", "\n"),
+            (" \t\r\n\t ", "\r\n"),
+        ]
+
+        def _new_wrapper(
+            output: io.StringIO, io_TextIOWrapper: Type[io.TextIOWrapper]
+        ) -> Callable[[Any, Any], io.TextIOWrapper]:
+            def get_output(*args: Any, **kwargs: Any) -> io.TextIOWrapper:
+                if args == (sys.stdout.buffer,):
+                    # It's `format_stdin_to_stdout()` calling `io.TextIOWrapper()`,
+                    # return our mock object.
+                    return output
+                # It's something else (i.e. `decode_bytes()`) calling
+                # `io.TextIOWrapper()`, pass through to the original implementation.
+                # See discussion in https://github.com/psf/black/pull/2489
+                return io_TextIOWrapper(*args, **kwargs)
+
+            return get_output
+
+        mode = black.Mode(preview=True)
+        for content, expected in cases:
+            output = io.StringIO()
+            io_TextIOWrapper = io.TextIOWrapper
+
+            with patch("io.TextIOWrapper", _new_wrapper(output, io_TextIOWrapper)):
+                try:
+                    black.format_stdin_to_stdout(
+                        fast=True,
+                        content=content,
+                        write_back=black.WriteBack.YES,
+                        mode=mode,
+                    )
+                except io.UnsupportedOperation:
+                    pass  # StringIO does not support detach
+                assert output.getvalue() == expected
+
+        # An empty string is the only test case for `preview=False`
         output = io.StringIO()
         output = io.StringIO()
-        with patch("io.TextIOWrapper", lambda *args, **kwargs: output):
+        io_TextIOWrapper = io.TextIOWrapper
+        with patch("io.TextIOWrapper", _new_wrapper(output, io_TextIOWrapper)):
             try:
                 black.format_stdin_to_stdout(
                     fast=True,
             try:
                 black.format_stdin_to_stdout(
                     fast=True,
@@ -1286,6 +1435,17 @@ class BlackTestCase(BlackBaseTestCase):
             if nl == "\n":
                 self.assertNotIn(b"\r\n", output)
 
             if nl == "\n":
                 self.assertNotIn(b"\r\n", output)
 
+    def test_normalize_line_endings(self) -> None:
+        with TemporaryDirectory() as workspace:
+            test_file = Path(workspace) / "test.py"
+            for data, expected in (
+                (b"c\r\nc\n ", b"c\r\nc\r\n"),
+                (b"l\nl\r\n ", b"l\nl\n"),
+            ):
+                test_file.write_bytes(data)
+                ff(test_file, write_back=black.WriteBack.YES)
+                self.assertEqual(test_file.read_bytes(), expected)
+
     def test_assert_equivalent_different_asts(self) -> None:
         with self.assertRaises(AssertionError):
             black.assert_equivalent("{}", "None")
     def test_assert_equivalent_different_asts(self) -> None:
         with self.assertRaises(AssertionError):
             black.assert_equivalent("{}", "None")
@@ -1921,6 +2081,17 @@ class TestFileCollection:
         ctx.obj["root"] = base
         assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")
 
         ctx.obj["root"] = base
         assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")
 
+    def test_gitignore_used_on_multiple_sources(self) -> None:
+        root = Path(DATA_DIR / "gitignore_used_on_multiple_sources")
+        expected = [
+            root / "dir1" / "b.py",
+            root / "dir2" / "b.py",
+        ]
+        ctx = FakeContext()
+        ctx.obj["root"] = root
+        src = [root / "dir1", root / "dir2"]
+        assert_collected_sources(src, expected, ctx=ctx)
+
     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_exclude_for_issue_1572(self) -> None:
         # Exclude shouldn't touch files that were explicitly given to Black through the
     @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
     def test_exclude_for_issue_1572(self) -> None:
         # Exclude shouldn't touch files that were explicitly given to Black through the
@@ -1954,7 +2125,7 @@ class TestFileCollection:
                 None,
                 None,
                 report,
                 None,
                 None,
                 report,
-                gitignore,
+                {path: gitignore},
                 verbose=False,
                 quiet=False,
             )
                 verbose=False,
                 quiet=False,
             )
@@ -1983,13 +2154,20 @@ class TestFileCollection:
                 None,
                 None,
                 report,
                 None,
                 None,
                 report,
-                root_gitignore,
+                {path: root_gitignore},
                 verbose=False,
                 quiet=False,
             )
         )
         assert sorted(expected) == sorted(sources)
 
                 verbose=False,
                 quiet=False,
             )
         )
         assert sorted(expected) == sorted(sources)
 
+    def test_nested_gitignore_directly_in_source_directory(self) -> None:
+        # https://github.com/psf/black/issues/2598
+        path = Path(DATA_DIR / "nested_gitignore_tests")
+        src = Path(path / "root" / "child")
+        expected = [src / "a.py", src / "c.py"]
+        assert_collected_sources([src], expected)
+
     def test_invalid_gitignore(self) -> None:
         path = THIS_DIR / "data" / "invalid_gitignore_tests"
         empty_config = path / "pyproject.toml"
     def test_invalid_gitignore(self) -> None:
         path = THIS_DIR / "data" / "invalid_gitignore_tests"
         empty_config = path / "pyproject.toml"
@@ -2014,6 +2192,32 @@ class TestFileCollection:
         gitignore = path / "a" / ".gitignore"
         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
 
         gitignore = path / "a" / ".gitignore"
         assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()
 
+    def test_gitignore_that_ignores_subfolders(self) -> None:
+        # If gitignore with */* is in root
+        root = Path(DATA_DIR / "ignore_subfolders_gitignore_tests" / "subdir")
+        expected = [root / "b.py"]
+        ctx = FakeContext()
+        ctx.obj["root"] = root
+        assert_collected_sources([root], expected, ctx=ctx)
+
+        # If .gitignore with */* is nested
+        root = Path(DATA_DIR / "ignore_subfolders_gitignore_tests")
+        expected = [
+            root / "a.py",
+            root / "subdir" / "b.py",
+        ]
+        ctx = FakeContext()
+        ctx.obj["root"] = root
+        assert_collected_sources([root], expected, ctx=ctx)
+
+        # If command is executed from outer dir
+        root = Path(DATA_DIR / "ignore_subfolders_gitignore_tests")
+        target = root / "subdir"
+        expected = [target / "b.py"]
+        ctx = FakeContext()
+        ctx.obj["root"] = root
+        assert_collected_sources([target], expected, ctx=ctx)
+
     def test_empty_include(self) -> None:
         path = DATA_DIR / "include_exclude_tests"
         src = [path]
     def test_empty_include(self) -> None:
         path = DATA_DIR / "include_exclude_tests"
         src = [path]
@@ -2068,7 +2272,7 @@ class TestFileCollection:
                     None,
                     None,
                     report,
                     None,
                     None,
                     report,
-                    gitignore,
+                    {path: gitignore},
                     verbose=False,
                     quiet=False,
                 )
                     verbose=False,
                     quiet=False,
                 )