+ # If it is set, use the path provided in the env var.
+ monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2))
+ assert get_cache_dir() == workspace2
+
def test_cache_broken_file(self) -> None:
    """A garbage cache file reads as empty, and a later run still caches the file."""
    mode = DEFAULT_MODE
    with cache_dir() as workspace:
        # Replace the cache contents with text that is not a valid pickle.
        get_cache_file(mode).write_text("this is not a pickle", encoding="utf-8")
        assert black.read_cache(mode) == {}

        target = (workspace / "test.py").resolve()
        target.write_text("print('hello')", encoding="utf-8")
        invokeBlack([str(target)])
        assert str(target) in black.read_cache(mode)
+
def test_cache_single_file_already_cached(self) -> None:
    """A file that is already recorded in the cache is not reformatted."""
    mode = DEFAULT_MODE
    with cache_dir() as workspace:
        target = (workspace / "test.py").resolve()
        unformatted = "print('hello')"
        target.write_text(unformatted, encoding="utf-8")
        black.write_cache({}, [target], mode)

        invokeBlack([str(target)])
        # The single quotes survive because black skipped the cached file.
        assert target.read_text(encoding="utf-8") == unformatted
+
@event_loop()
def test_cache_multiple_files(self) -> None:
    """Only the uncached file is reformatted; afterwards both files are cached."""
    mode = DEFAULT_MODE
    # Run the formatting workers in threads instead of processes.
    with cache_dir() as workspace, patch(
        "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
    ):
        one, two = (
            (workspace / name).resolve() for name in ("one.py", "two.py")
        )
        for source in (one, two):
            source.write_text("print('hello')", encoding="utf-8")
        # Pre-seed the cache with ``one`` only.
        black.write_cache({}, [one], mode)

        invokeBlack([str(workspace)])

        assert one.read_text(encoding="utf-8") == "print('hello')"
        assert two.read_text(encoding="utf-8") == 'print("hello")\n'
        cache = black.read_cache(mode)
        for source in (one, two):
            assert str(source) in cache
+
@pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
def test_no_cache_when_writeback_diff(self, color: bool) -> None:
    """With ``--diff`` (optionally ``--color``) the cache is neither read nor written.

    Args:
        color: whether ``--color`` is appended to the command line.
    """
    mode = DEFAULT_MODE
    with cache_dir() as workspace:
        src = (workspace / "test.py").resolve()
        src.write_text("print('hello')", encoding="utf-8")
        with patch("black.read_cache") as read_cache, patch(
            "black.write_cache"
        ) as write_cache:
            cmd = [str(src), "--diff"]
            if color:
                cmd.append("--color")
            invokeBlack(cmd)
            cache_file = get_cache_file(mode)
            # Plain truthiness check instead of ``is False`` (PEP 8).
            assert not cache_file.exists()
            write_cache.assert_not_called()
            read_cache.assert_not_called()
+
@pytest.mark.parametrize("color", [False, True], ids=["no-color", "with-color"])
@event_loop()
def test_output_locking_when_writeback_diff(self, color: bool) -> None:
    """A multi-file ``--diff`` run goes through ``black.concurrency.Manager``."""
    with cache_dir() as workspace:
        for index in range(4):
            target = (workspace / f"test{index}.py").resolve()
            target.write_text("print('hello')", encoding="utf-8")

        args = ["--diff", str(workspace)]
        if color:
            args.append("--color")

        with patch(
            "black.concurrency.Manager", wraps=multiprocessing.Manager
        ) as manager:
            invokeBlack(args, exit_code=0)
            # A call is necessary but not sufficient evidence that the
            # manager's lock is what serialises the diff output.
            manager.assert_called()
+
def test_no_cache_when_stdin(self) -> None:
    """Formatting code read from stdin (``-``) does not create a cache file."""
    mode = DEFAULT_MODE
    with cache_dir():
        runner = CliRunner()
        stdin = BytesIO(b"print('hello')")
        result = runner.invoke(black.main, ["-"], input=stdin)
        assert not result.exit_code
        assert not get_cache_file(mode).exists()
+
def test_read_cache_no_cachefile(self) -> None:
    """Without a cache file on disk, ``read_cache`` returns an empty dict."""
    with cache_dir():
        cache = black.read_cache(DEFAULT_MODE)
        assert cache == {}
+
def test_write_cache_read_cache(self) -> None:
    """An entry written by ``write_cache`` is read back with the file's cache info."""
    mode = DEFAULT_MODE
    with cache_dir() as workspace:
        target = (workspace / "test.py").resolve()
        target.touch()
        black.write_cache({}, [target], mode)

        key = str(target)
        cache = black.read_cache(mode)
        assert key in cache
        assert cache[key] == black.get_cache_info(target)
+
def test_filter_cached(self) -> None:
    """``filter_cached`` splits paths into ``(todo, done)`` using the cache."""
    with TemporaryDirectory() as tmp:
        root = Path(tmp)
        uncached, cached, cached_but_changed = (
            (root / name).resolve() for name in ("uncached", "cached", "changed")
        )
        for target in (uncached, cached, cached_but_changed):
            target.touch()

        cache = {
            str(cached): black.get_cache_info(cached),
            # Bogus cache info, so this entry is expected to count as changed.
            str(cached_but_changed): (0.0, 0),
        }
        todo, done = black.cache.filter_cached(
            cache, {uncached, cached, cached_but_changed}
        )
        assert todo == {uncached, cached_but_changed}
        assert done == {cached}
+
def test_write_cache_creates_directory_if_needed(self) -> None:
    """``write_cache`` creates the cache directory when it does not exist yet."""
    with cache_dir(exists=False) as workspace:
        assert not workspace.exists()
        black.write_cache({}, [], DEFAULT_MODE)
        assert workspace.exists()
+
@event_loop()
def test_failed_formatting_does_not_get_cached(self) -> None:
    """A file black fails to format stays out of the cache; a clean one goes in."""
    mode = DEFAULT_MODE
    # Run the formatting workers in threads instead of processes.
    with cache_dir() as workspace, patch(
        "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor
    ):
        broken = (workspace / "failing.py").resolve()
        broken.write_text("not actually python", encoding="utf-8")
        good = (workspace / "clean.py").resolve()
        good.write_text('print("hello")\n', encoding="utf-8")

        # The run as a whole is expected to exit with code 123.
        invokeBlack([str(workspace)], exit_code=123)

        cache = black.read_cache(mode)
        assert str(broken) not in cache
        assert str(good) in cache
+
def test_write_cache_write_fail(self) -> None:
    """``write_cache`` does not propagate an ``OSError`` raised by ``Path.open``."""
    with cache_dir(), patch.object(Path, "open", side_effect=OSError):
        black.write_cache({}, [], DEFAULT_MODE)
+
def test_read_cache_line_lengths(self) -> None:
    """Cache entries are per mode: a different line length does not see them."""
    mode = DEFAULT_MODE
    short_mode = replace(mode, line_length=1)
    with cache_dir() as workspace:
        target = (workspace / "file.py").resolve()
        target.touch()
        black.write_cache({}, [target], mode)

        key = str(target)
        assert key in black.read_cache(mode)
        assert key not in black.read_cache(short_mode)
+
+
def assert_collected_sources(
    src: Sequence[Union[str, Path]],
    expected: Sequence[Union[str, Path]],
    *,
    ctx: Optional[FakeContext] = None,
    exclude: Optional[str] = None,
    include: Optional[str] = None,
    extend_exclude: Optional[str] = None,
    force_exclude: Optional[str] = None,
    stdin_filename: Optional[str] = None,
) -> None:
    """Assert that ``black.get_sources`` collects exactly ``expected`` from ``src``.

    Every pattern argument is a regex string compiled with ``compile_pattern``.
    ``None`` means "not given" and is passed through as ``None``. The exception
    is ``include``, which falls back to ``DEFAULT_INCLUDE``. A fresh
    ``FakeContext`` is used when ``ctx`` is not supplied. Both sides are sorted
    before comparing, so order does not matter.
    """
    src_strings = tuple(str(Path(entry)) for entry in src)
    expected_paths = [Path(entry) for entry in expected]

    exclude_re = compile_pattern(exclude) if exclude is not None else None
    include_re = (
        compile_pattern(include) if include is not None else DEFAULT_INCLUDE
    )
    extend_exclude_re = (
        compile_pattern(extend_exclude) if extend_exclude is not None else None
    )
    force_exclude_re = (
        compile_pattern(force_exclude) if force_exclude is not None else None
    )

    collected = black.get_sources(
        ctx=ctx or FakeContext(),
        src=src_strings,
        quiet=False,
        verbose=False,
        include=include_re,
        exclude=exclude_re,
        extend_exclude=extend_exclude_re,
        force_exclude=force_exclude_re,
        report=black.Report(),
        stdin_filename=stdin_filename,
    )
    assert sorted(collected) == sorted(expected_paths)
+
+
class TestFileCollection:
    """Source-discovery tests for ``black.get_sources`` and ``gen_python_files``.

    Covers the include / exclude / extend-exclude / force-exclude regexes,
    ``.gitignore`` handling (root, nested, invalid), a symlink resolving outside
    the root, and reading from stdin with and without ``stdin_filename``.
    """

    def test_include_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        assert_collected_sources(
            src,
            expected,
            include=r"\.pyi?$",
            exclude=r"/exclude/|/\.definitely_exclude/",
        )

    def test_gitignore_used_as_default(self) -> None:
        # No explicit ``exclude`` is passed, only ``extend_exclude``.
        base = Path(DATA_DIR / "include_exclude_tests")
        expected = [
            base / "b/.definitely_exclude/a.py",
            base / "b/.definitely_exclude/a.pyi",
        ]
        src = [base / "b/"]
        ctx = FakeContext()
        ctx.obj["root"] = base
        assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/")

    def test_gitignore_used_on_multiple_sources(self) -> None:
        root = Path(DATA_DIR / "gitignore_used_on_multiple_sources")
        expected = [
            root / "dir1" / "b.py",
            root / "dir2" / "b.py",
        ]
        ctx = FakeContext()
        ctx.obj["root"] = root
        src = [root / "dir1", root / "dir2"]
        assert_collected_sources(src, expected, ctx=ctx)

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_exclude_for_issue_1572(self) -> None:
        # Exclude shouldn't touch files that were explicitly given to Black through the
        # CLI. Exclude is supposed to only apply to the recursive discovery of files.
        # https://github.com/psf/black/issues/1572
        path = DATA_DIR / "include_exclude_tests"
        src = [path / "b/exclude/a.py"]
        expected = [path / "b/exclude/a.py"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    def test_gitignore_exclude(self) -> None:
        path = THIS_DIR / "data" / "include_exclude_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        report = black.Report()
        gitignore = PathSpec.from_lines(
            "gitwildmatch", ["exclude/", ".definitely_exclude"]
        )
        expected = [
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                {path: gitignore},
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "nested_gitignore_tests"
        include = re.compile(r"\.pyi?$")
        exclude = re.compile(r"")
        root_gitignore = black.files.get_gitignore(path)
        report = black.Report()
        expected: List[Path] = [
            path / "x.py",
            path / "root/b.py",
            path / "root/c.py",
            path / "root/child/c.py",
        ]
        this_abs = THIS_DIR.resolve()
        sources = list(
            black.gen_python_files(
                path.iterdir(),
                this_abs,
                include,
                exclude,
                None,
                None,
                report,
                {path: root_gitignore},
                verbose=False,
                quiet=False,
            )
        )
        assert sorted(expected) == sorted(sources)

    def test_nested_gitignore_directly_in_source_directory(self) -> None:
        # https://github.com/psf/black/issues/2598
        path = Path(DATA_DIR / "nested_gitignore_tests")
        src = path / "root" / "child"
        expected = [src / "a.py", src / "c.py"]
        assert_collected_sources([src], expected)

    def test_invalid_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_invalid_nested_gitignore(self) -> None:
        path = THIS_DIR / "data" / "invalid_nested_gitignore_tests"
        empty_config = path / "pyproject.toml"
        result = BlackRunner().invoke(
            black.main, ["--verbose", "--config", str(empty_config), str(path)]
        )
        assert result.exit_code == 1
        assert result.stderr_bytes is not None

        gitignore = path / "a" / ".gitignore"
        assert f"Could not parse {gitignore}" in result.stderr_bytes.decode()

    def test_gitignore_that_ignores_subfolders(self) -> None:
        # If gitignore with */* is in root
        root = Path(DATA_DIR / "ignore_subfolders_gitignore_tests" / "subdir")
        expected = [root / "b.py"]
        ctx = FakeContext()
        ctx.obj["root"] = root
        assert_collected_sources([root], expected, ctx=ctx)

        # If .gitignore with */* is nested
        root = Path(DATA_DIR / "ignore_subfolders_gitignore_tests")
        expected = [
            root / "a.py",
            root / "subdir" / "b.py",
        ]
        ctx = FakeContext()
        ctx.obj["root"] = root
        assert_collected_sources([root], expected, ctx=ctx)

        # If command is executed from outer dir
        root = Path(DATA_DIR / "ignore_subfolders_gitignore_tests")
        target = root / "subdir"
        expected = [target / "b.py"]
        ctx = FakeContext()
        ctx.obj["root"] = root
        assert_collected_sources([target], expected, ctx=ctx)

    def test_empty_include(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.pie",
            path / "b/exclude/a.py",
            path / "b/exclude/a.pyi",
            path / "b/dont_exclude/a.pie",
            path / "b/dont_exclude/a.py",
            path / "b/dont_exclude/a.pyi",
            path / "b/.definitely_exclude/a.pie",
            path / "b/.definitely_exclude/a.py",
            path / "b/.definitely_exclude/a.pyi",
            path / ".gitignore",
            path / "pyproject.toml",
        ]
        # Setting exclude explicitly to an empty string to block .gitignore usage.
        assert_collected_sources(src, expected, include="", exclude="")

    def test_extend_exclude(self) -> None:
        path = DATA_DIR / "include_exclude_tests"
        src = [path]
        expected = [
            path / "b/exclude/a.py",
            path / "b/dont_exclude/a.py",
        ]
        assert_collected_sources(
            src, expected, exclude=r"\.pyi$", extend_exclude=r"\.definitely_exclude"
        )

    @pytest.mark.incompatible_with_mypyc
    def test_symlink_out_of_root_directory(self) -> None:
        path = MagicMock()
        root = THIS_DIR.resolve()
        child = MagicMock()
        include = re.compile(black.DEFAULT_INCLUDES)
        exclude = re.compile(black.DEFAULT_EXCLUDES)
        report = black.Report()
        gitignore = PathSpec.from_lines("gitwildmatch", [])
        # `child` should behave like a symlink which resolved path is clearly
        # outside of the `root` directory.
        path.iterdir.return_value = [child]
        child.resolve.return_value = Path("/a/b/c")
        child.as_posix.return_value = "/a/b/c"
        try:
            list(
                black.gen_python_files(
                    path.iterdir(),
                    root,
                    include,
                    exclude,
                    None,
                    None,
                    report,
                    {path: gitignore},
                    verbose=False,
                    quiet=False,
                )
            )
        except ValueError as ve:
            # Name the function actually under test in the failure message.
            pytest.fail(f"`gen_python_files()` failed: {ve}")
        path.iterdir.assert_called_once()
        child.resolve.assert_called_once()

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin(self) -> None:
        src = ["-"]
        expected = ["-"]
        assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py")

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename(self) -> None:
        src = ["-"]
        stdin_filename = str(THIS_DIR / "data/collections.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_exclude(self) -> None:
        # Exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        path = DATA_DIR / "include_exclude_tests"
        src = ["-"]
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None:
        # Extend exclude shouldn't exclude stdin_filename since it is mimicking the
        # file being passed directly. This is the same as
        # test_exclude_for_issue_1572
        src = ["-"]
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        expected = [f"__BLACK_STDIN_FILENAME__{stdin_filename}"]
        assert_collected_sources(
            src,
            expected,
            extend_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )

    @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None))
    def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None:
        # Force exclude should exclude the file when passing it through
        # stdin_filename
        path = THIS_DIR / "data" / "include_exclude_tests"
        stdin_filename = str(path / "b/exclude/a.py")
        assert_collected_sources(
            src=["-"],
            expected=[],
            force_exclude=r"/exclude/|a\.py",
            stdin_filename=stdin_filename,
        )
+
+
# Source lines of black/__init__.py, used by ``tracefunc`` to print the
# signature line of each traced function. Bound up front so the name always
# exists: when black is compiled, its ``__file__`` cannot be decoded as UTF-8
# and the list simply stays empty.
black_source_lines: List[str] = []
try:
    with open(black.__file__, "r", encoding="utf-8") as _bf:
        black_source_lines = _bf.readlines()
except UnicodeDecodeError:
    # Only a compiled black is expected to have an undecodable ``__file__``;
    # anything else is a genuine error and must surface.
    if not black.COMPILED:
        raise
+
+
def tracefunc(
    frame: types.FrameType, event: str, arg: Any
) -> Callable[[types.FrameType, str, Any], Any]:
    """Show function calls `from black/__init__.py` as they happen.

    Register this with `sys.settrace()` in a test you're debugging.

    For each ``"call"`` event whose code lives in ``black/__init__.py``, print
    the reported line number and the stripped source of the function's
    signature line (leading decorator lines are skipped). The output is
    indented in proportion to the current stack depth. All other events and
    files are ignored.

    Args:
        frame: The frame being entered.
        event: The trace event name; only ``"call"`` is reported.
        arg: Event-specific argument; unused.

    Returns:
        ``tracefunc`` itself, so tracing continues in nested scopes.
    """
    if event != "call":
        return tracefunc

    # Normalize Windows separators so the path check works on every platform.
    filename = frame.f_code.co_filename.replace("\\", "/")
    # Filter by file *before* touching ``black_source_lines``. Line numbers of
    # frames from other files are meaningless there and may be out of range.
    if "black/__init__.py" not in filename:
        return tracefunc

    lineno = frame.f_lineno
    num_lines = len(black_source_lines)
    # f_lineno is 1-based; the list of source lines is 0-based.
    func_sig_lineno = lineno - 1
    # Skip decorator lines to reach the ``def`` line without running past the
    # end of the source.
    while (
        0 <= func_sig_lineno < num_lines
        and black_source_lines[func_sig_lineno].strip().startswith("@")
    ):
        func_sig_lineno += 1
    if 0 <= func_sig_lineno < num_lines:
        funcname = black_source_lines[func_sig_lineno].strip()
    else:
        funcname = "<source unavailable>"

    # NOTE(review): the fixed offset of 19 frames presumably discounts the
    # test-runner frames below the traced code — TODO confirm. Passing
    # ``context=0`` gives the same frame count without reading any source.
    stack = (len(inspect.stack(0)) - 19) * 2
    print(f"{' ' * stack}{lineno}:{funcname}")
    return tracefunc