X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/e401b6bb1e1c0ed534bba59d9dc908caf7ba898c..7af77d1cf1fdeb54a45ddae422e1ebc3329129fa:/tests/test_black.py diff --git a/tests/test_black.py b/tests/test_black.py index 5be4ae8..8adcaed 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -10,7 +10,7 @@ import sys import types import unittest from concurrent.futures import ThreadPoolExecutor -from contextlib import contextmanager +from contextlib import contextmanager, redirect_stderr from dataclasses import replace from io import BytesIO from pathlib import Path @@ -40,7 +40,7 @@ import black import black.files from black import Feature, TargetVersion from black import re_compile_maybe_verbose as compile_pattern -from black.cache import get_cache_file +from black.cache import get_cache_dir, get_cache_file from black.debug import DebugVisitor from black.output import color_diff, diff from black.report import Report @@ -60,9 +60,12 @@ from tests.util import ( ff, fs, read_data, + get_case_path, + read_data_from_file, ) THIS_FILE = Path(__file__) +EMPTY_CONFIG = THIS_DIR / "data" / "empty_pyproject.toml" PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS] DEFAULT_EXCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_EXCLUDES) DEFAULT_INCLUDE = black.re_compile_maybe_verbose(black.const.DEFAULT_INCLUDES) @@ -100,6 +103,8 @@ class FakeContext(click.Context): def __init__(self) -> None: self.default_map: Dict[str, Any] = {} + # Dummy root, since most of the tests don't care about it + self.obj: Dict[str, Any] = {"root": PROJECT_ROOT} class FakeParameter(click.Parameter): @@ -148,11 +153,21 @@ class BlackTestCase(BlackBaseTestCase): os.unlink(tmp_file) self.assertFormatEqual(expected, actual) + def test_experimental_string_processing_warns(self) -> None: + self.assertWarns( + black.mode.Deprecated, black.Mode, experimental_string_processing=True + ) + def test_piping(self) -> None: - source, expected = read_data("src/black/__init__", data=False) + source, expected = read_data_from_file(PROJECT_ROOT / "src/black/__init__.py") result = BlackRunner().invoke( black.main, - ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"], + [ + "-", + "--fast", + f"--line-length={black.DEFAULT_LINE_LENGTH}", + f"--config={EMPTY_CONFIG}", + ], input=BytesIO(source.encode("utf8")), ) self.assertEqual(result.exit_code, 0) @@ -166,15 +181,14 @@ class BlackTestCase(BlackBaseTestCase): r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d " r"\+\d\d\d\d" ) - source, _ = read_data("expression.py") - expected, _ = read_data("expression.diff") - config = THIS_DIR / "data" / "empty_pyproject.toml" + source, _ = read_data("simple_cases", "expression.py") + expected, _ = read_data("simple_cases", "expression.diff") args = [ "-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}", "--diff", - f"--config={config}", + f"--config={EMPTY_CONFIG}", ] result = BlackRunner().invoke( black.main, args, input=BytesIO(source.encode("utf8")) @@ -185,15 +199,14 @@ class BlackTestCase(BlackBaseTestCase): self.assertEqual(expected, actual) def test_piping_diff_with_color(self) -> None: - source, _ = read_data("expression.py") - config = THIS_DIR / "data" / "empty_pyproject.toml" + source, _ = read_data("simple_cases", "expression.py") args = [ "-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}", "--diff", "--color", - f"--config={config}", + f"--config={EMPTY_CONFIG}", ] result = BlackRunner().invoke( black.main, args, input=BytesIO(source.encode("utf8")) @@ -208,7 +221,7 @@ class BlackTestCase(BlackBaseTestCase): @patch("black.dump_to_file", dump_to_stderr) def _test_wip(self) -> None: - source, expected = read_data("wip") + source, expected = read_data("miscellaneous", "wip") sys.settrace(tracefunc) mode = replace( DEFAULT_MODE, @@ -221,47 +234,8 @@ class BlackTestCase(BlackBaseTestCase): black.assert_equivalent(source, actual) black.assert_stable(source, actual, black.FileMode()) - @unittest.expectedFailure - @patch("black.dump_to_file", dump_to_stderr) - def test_trailing_comma_optional_parens_stability1(self) -> None: - source, _expected = read_data("trailing_comma_optional_parens1") - actual = fs(source) - black.assert_stable(source, actual, DEFAULT_MODE) - - @unittest.expectedFailure - @patch("black.dump_to_file", dump_to_stderr) - def test_trailing_comma_optional_parens_stability2(self) -> None: - source, _expected = read_data("trailing_comma_optional_parens2") - actual = fs(source) - black.assert_stable(source, actual, DEFAULT_MODE) - - @unittest.expectedFailure - @patch("black.dump_to_file", dump_to_stderr) - def test_trailing_comma_optional_parens_stability3(self) -> None: - source, _expected = read_data("trailing_comma_optional_parens3") - actual = fs(source) - black.assert_stable(source, actual, DEFAULT_MODE) - - @patch("black.dump_to_file", dump_to_stderr) - def test_trailing_comma_optional_parens_stability1_pass2(self) -> None: - source, _expected = read_data("trailing_comma_optional_parens1") - actual = fs(fs(source)) # this is what `format_file_contents` does with --safe - black.assert_stable(source, actual, DEFAULT_MODE) - - @patch("black.dump_to_file", dump_to_stderr) - def test_trailing_comma_optional_parens_stability2_pass2(self) -> None: - source, _expected = read_data("trailing_comma_optional_parens2") - actual = fs(fs(source)) # this is what `format_file_contents` does with --safe - black.assert_stable(source, actual, DEFAULT_MODE) - - @patch("black.dump_to_file", dump_to_stderr) - def test_trailing_comma_optional_parens_stability3_pass2(self) -> None: - source, _expected = read_data("trailing_comma_optional_parens3") - actual = fs(fs(source)) # this is what `format_file_contents` does with --safe - black.assert_stable(source, actual, DEFAULT_MODE) - def test_pep_572_version_detection(self) -> None: - source, _ = read_data("pep_572") + source, _ = read_data("py_38", "pep_572") root = black.lib2to3_parse(source) features = black.get_features_used(root) self.assertIn(black.Feature.ASSIGNMENT_EXPRESSIONS, features) @@ -269,7 +243,7 @@ class BlackTestCase(BlackBaseTestCase): self.assertIn(black.TargetVersion.PY38, versions) def test_expression_ff(self) -> None: - source, expected = read_data("expression") + source, expected = read_data("simple_cases", "expression.py") tmp_file = Path(black.dump_to_file(source)) try: self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES)) @@ -283,9 +257,8 @@ class BlackTestCase(BlackBaseTestCase): black.assert_stable(source, actual, DEFAULT_MODE) def test_expression_diff(self) -> None: - source, _ = read_data("expression.py") - config = THIS_DIR / "data" / "empty_pyproject.toml" - expected, _ = read_data("expression.diff") + source, _ = read_data("simple_cases", "expression.py") + expected, _ = read_data("simple_cases", "expression.diff") tmp_file = Path(black.dump_to_file(source)) diff_header = re.compile( rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d " @@ -293,7 +266,7 @@ class BlackTestCase(BlackBaseTestCase): ) try: result = BlackRunner().invoke( - black.main, ["--diff", str(tmp_file), f"--config={config}"] + black.main, ["--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"] ) self.assertEqual(result.exit_code, 0) finally: @@ -310,13 +283,13 @@ class BlackTestCase(BlackBaseTestCase): self.assertEqual(expected, actual, msg) def test_expression_diff_with_color(self) -> None: - source, _ = read_data("expression.py") - config = THIS_DIR / "data" / "empty_pyproject.toml" - expected, _ = read_data("expression.diff") + source, _ = read_data("simple_cases", "expression.py") + expected, _ = read_data("simple_cases", "expression.diff") tmp_file = Path(black.dump_to_file(source)) try: result = BlackRunner().invoke( - black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"] + black.main, + ["--diff", "--color", str(tmp_file), f"--config={EMPTY_CONFIG}"], ) finally: os.unlink(tmp_file) @@ -330,7 +303,7 @@ class BlackTestCase(BlackBaseTestCase): self.assertIn("\033[0m", actual) def test_detect_pos_only_arguments(self) -> None: - source, _ = read_data("pep_570") + source, _ = read_data("py_38", "pep_570") root = black.lib2to3_parse(source) features = black.get_features_used(root) self.assertIn(black.Feature.POS_ONLY_ARGUMENTS, features) @@ -339,8 +312,8 @@ class BlackTestCase(BlackBaseTestCase): @patch("black.dump_to_file", dump_to_stderr) def test_string_quotes(self) -> None: - source, expected = read_data("string_quotes") - mode = black.Mode(experimental_string_processing=True) + source, expected = read_data("miscellaneous", "string_quotes") + mode = black.Mode(preview=True) assert_format(source, expected, mode) mode = replace(mode, string_normalization=False) not_normalized = fs(source, mode=mode) @@ -349,15 +322,19 @@ class BlackTestCase(BlackBaseTestCase): black.assert_stable(source, not_normalized, mode=mode) def test_skip_magic_trailing_comma(self) -> None: - source, _ = read_data("expression.py") - expected, _ = read_data("expression_skip_magic_trailing_comma.diff") + source, _ = read_data("simple_cases", "expression") + expected, _ = read_data( + "miscellaneous", "expression_skip_magic_trailing_comma.diff" + ) tmp_file = Path(black.dump_to_file(source)) diff_header = re.compile( rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d " r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d" ) try: - result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)]) + result = BlackRunner().invoke( + black.main, ["-C", "--diff", str(tmp_file), f"--config={EMPTY_CONFIG}"] + ) self.assertEqual(result.exit_code, 0) finally: os.unlink(tmp_file) @@ -375,8 +352,8 @@ class BlackTestCase(BlackBaseTestCase): @patch("black.dump_to_file", dump_to_stderr) def test_async_as_identifier(self) -> None: - source_path = (THIS_DIR / "data" / "async_as_identifier.py").resolve() - source, expected = read_data("async_as_identifier") + source_path = get_case_path("miscellaneous", "async_as_identifier") + source, expected = read_data_from_file(source_path) actual = fs(source) self.assertFormatEqual(expected, actual) major, minor = sys.version_info[:2] @@ -390,8 +367,8 @@ class BlackTestCase(BlackBaseTestCase): @patch("black.dump_to_file", dump_to_stderr) def test_python37(self) -> None: - source_path = (THIS_DIR / "data" / "python37.py").resolve() - source, expected = read_data("python37") + source_path = get_case_path("py_37", "python37") + source, expected = read_data_from_file(source_path) actual = fs(source) self.assertFormatEqual(expected, actual) major, minor = sys.version_info[:2] @@ -739,7 +716,7 @@ class BlackTestCase(BlackBaseTestCase): # since this makes some test cases of test_get_features_used() # fails if it fails, this is tested first so that a useful case # is identified - simples, relaxed = read_data("decorators") + simples, relaxed = read_data("miscellaneous", "decorators") # skip explanation comments at the top of the file for simple_test in simples.split("##")[1:]: node = black.lib2to3_parse(simple_test) @@ -782,7 +759,7 @@ class BlackTestCase(BlackBaseTestCase): self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES}) node = black.lib2to3_parse("123456\n") self.assertEqual(black.get_features_used(node), set()) - source, expected = read_data("function") + source, expected = read_data("simple_cases", "function") node = black.lib2to3_parse(source) expected_features = { Feature.TRAILING_COMMA_IN_CALL, @@ -792,7 +769,7 @@ class BlackTestCase(BlackBaseTestCase): self.assertEqual(black.get_features_used(node), expected_features) node = black.lib2to3_parse(expected) self.assertEqual(black.get_features_used(node), expected_features) - source, expected = read_data("expression") + source, expected = read_data("simple_cases", "expression") node = black.lib2to3_parse(source) self.assertEqual(black.get_features_used(node), set()) node = black.lib2to3_parse(expected) @@ -821,6 +798,18 @@ class BlackTestCase(BlackBaseTestCase): self.assertEqual( black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS} ) + node = black.lib2to3_parse("try: pass\nexcept Something: pass") + self.assertEqual(black.get_features_used(node), set()) + node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass") + self.assertEqual(black.get_features_used(node), set()) + node = black.lib2to3_parse("try: pass\nexcept *Group: pass") + self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR}) + node = black.lib2to3_parse("a[*b]") + self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS}) + node = black.lib2to3_parse("a[x, *y(), z] = t") + self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS}) + node = black.lib2to3_parse("def fn(*args: *T): pass") + self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS}) def test_get_features_used_for_future_flags(self) -> None: for src, features in [ @@ -872,8 +861,8 @@ class BlackTestCase(BlackBaseTestCase): @pytest.mark.incompatible_with_mypyc def test_debug_visitor(self) -> None: - source, _ = read_data("debug_visitor.py") - expected, _ = read_data("debug_visitor.out") + source, _ = read_data("miscellaneous", "debug_visitor") + expected, _ = read_data("miscellaneous", "debug_visitor.out") out_lines = [] err_lines = [] @@ -938,12 +927,12 @@ class BlackTestCase(BlackBaseTestCase): self.assertFormatEqual("j = [1, 2, 3]", "j = [1, 2, 3,]") out_str = "".join(out_lines) - self.assertTrue("Expected tree:" in out_str) - self.assertTrue("Actual tree:" in out_str) + self.assertIn("Expected tree:", out_str) + self.assertIn("Actual tree:", out_str) self.assertEqual("".join(err_lines), "") @event_loop() - @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError)) + @patch("concurrent.futures.ProcessPoolExecutor", MagicMock(side_effect=OSError)) def test_works_in_mono_process_only_environment(self) -> None: with cache_dir() as workspace: for f in [ @@ -957,18 +946,21 @@ class BlackTestCase(BlackBaseTestCase): def test_check_diff_use_together(self) -> None: with cache_dir(): # Files which will be reformatted. - src1 = (THIS_DIR / "data" / "string_quotes.py").resolve() + src1 = get_case_path("miscellaneous", "string_quotes") self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1) # Files which will not be reformatted. - src2 = (THIS_DIR / "data" / "composition.py").resolve() + src2 = get_case_path("simple_cases", "composition") self.invokeBlack([str(src2), "--diff", "--check"]) # Multi file command. self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1) - def test_no_files(self) -> None: + def test_no_src_fails(self) -> None: + with cache_dir(): + self.invokeBlack([], exit_code=1) + + def test_src_and_code_fails(self) -> None: with cache_dir(): - # Without an argument, black exits with error code 0. - self.invokeBlack([]) + self.invokeBlack([".", "-c", "0"], exit_code=1) def test_broken_symlink(self) -> None: with cache_dir() as workspace: @@ -981,7 +973,7 @@ class BlackTestCase(BlackBaseTestCase): def test_single_file_force_pyi(self) -> None: pyi_mode = replace(DEFAULT_MODE, is_pyi=True) - contents, expected = read_data("force_pyi") + contents, expected = read_data("miscellaneous", "force_pyi") with cache_dir() as workspace: path = (workspace / "file.py").resolve() with open(path, "w") as fh: @@ -1002,7 +994,7 @@ class BlackTestCase(BlackBaseTestCase): def test_multi_file_force_pyi(self) -> None: reg_mode = DEFAULT_MODE pyi_mode = replace(DEFAULT_MODE, is_pyi=True) - contents, expected = read_data("force_pyi") + contents, expected = read_data("miscellaneous", "force_pyi") with cache_dir() as workspace: paths = [ (workspace / "file1.py").resolve(), @@ -1024,7 +1016,7 @@ class BlackTestCase(BlackBaseTestCase): self.assertNotIn(str(path), normal_cache) def test_pipe_force_pyi(self) -> None: - source, expected = read_data("force_pyi") + source, expected = read_data("miscellaneous", "force_pyi") result = CliRunner().invoke( black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8")) ) @@ -1035,7 +1027,7 @@ class BlackTestCase(BlackBaseTestCase): def test_single_file_force_py36(self) -> None: reg_mode = DEFAULT_MODE py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS) - source, expected = read_data("force_py36") + source, expected = read_data("miscellaneous", "force_py36") with cache_dir() as workspace: path = (workspace / "file.py").resolve() with open(path, "w") as fh: @@ -1054,7 +1046,7 @@ class BlackTestCase(BlackBaseTestCase): def test_multi_file_force_py36(self) -> None: reg_mode = DEFAULT_MODE py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS) - source, expected = read_data("force_py36") + source, expected = read_data("miscellaneous", "force_py36") with cache_dir() as workspace: paths = [ (workspace / "file1.py").resolve(), @@ -1076,7 +1068,7 @@ class BlackTestCase(BlackBaseTestCase): self.assertNotIn(str(path), normal_cache) def test_pipe_force_py36(self) -> None: - source, expected = read_data("force_py36") + source, expected = read_data("miscellaneous", "force_py36") result = CliRunner().invoke( black.main, ["-", "-q", "--target-version=py36"], @@ -1186,7 +1178,7 @@ class BlackTestCase(BlackBaseTestCase): report = MagicMock() # Even with an existing file, since we are forcing stdin, black # should output to stdout and not modify the file inplace - p = Path(str(THIS_DIR / "data/collections.py")) + p = THIS_DIR / "data" / "simple_cases" / "collections.py" # Make sure is_file actually returns True self.assertTrue(p.is_file()) path = Path(f"__BLACK_STDIN_FILENAME__{p}") @@ -1222,14 +1214,33 @@ class BlackTestCase(BlackBaseTestCase): def test_required_version_matches_version(self) -> None: self.invokeBlack( - ["--required-version", black.__version__], exit_code=0, ignore_config=True + ["--required-version", black.__version__, "-c", "0"], + exit_code=0, + ignore_config=True, ) - def test_required_version_does_not_match_version(self) -> None: + def test_required_version_matches_partial_version(self) -> None: self.invokeBlack( - ["--required-version", "20.99b"], exit_code=1, ignore_config=True + ["--required-version", black.__version__.split(".")[0], "-c", "0"], + exit_code=0, + ignore_config=True, ) + def test_required_version_does_not_match_on_minor_version(self) -> None: + self.invokeBlack( + ["--required-version", black.__version__.split(".")[0] + ".999", "-c", "0"], + exit_code=1, + ignore_config=True, + ) + + def test_required_version_does_not_match_version(self) -> None: + result = BlackRunner().invoke( + black.main, + ["--required-version", "20.99b", "-c", "0"], + ) + self.assertEqual(result.exit_code, 1) + self.assertIn("required version", result.stderr) + def test_preserves_line_endings(self) -> None: with TemporaryDirectory() as workspace: test_file = Path(workspace) / "test.py" @@ -1261,23 +1272,25 @@ class BlackTestCase(BlackBaseTestCase): def test_shhh_click(self) -> None: try: - from click import _unicodefun - except ModuleNotFoundError: + from click import _unicodefun # type: ignore + except ImportError: self.skipTest("Incompatible Click version") - if not hasattr(_unicodefun, "_verify_python3_env"): + + if not hasattr(_unicodefun, "_verify_python_env"): self.skipTest("Incompatible Click version") + # First, let's see if Click is crashing with a preferred ASCII charset. with patch("locale.getpreferredencoding") as gpe: gpe.return_value = "ASCII" with self.assertRaises(RuntimeError): - _unicodefun._verify_python3_env() # type: ignore + _unicodefun._verify_python_env() # Now, let's silence Click... black.patch_click() # ...and confirm it's silent. with patch("locale.getpreferredencoding") as gpe: gpe.return_value = "ASCII" try: - _unicodefun._verify_python3_env() # type: ignore + _unicodefun._verify_python_env() except RuntimeError as re: self.fail(f"`patch_click()` failed, exception still raised: {re}") @@ -1315,6 +1328,7 @@ class BlackTestCase(BlackBaseTestCase): self.assertEqual(config["color"], True) self.assertEqual(config["line_length"], 79) self.assertEqual(config["target_version"], ["py36", "py37", "py38"]) + self.assertEqual(config["python_cell_magics"], ["custom1", "custom2"]) self.assertEqual(config["exclude"], r"\.pyi?$") self.assertEqual(config["include"], r"\.py?$") @@ -1350,10 +1364,32 @@ class BlackTestCase(BlackBaseTestCase): src_python.touch() self.assertEqual( - black.find_project_root((src_dir, test_dir)), root.resolve() + black.find_project_root((src_dir, test_dir)), + (root.resolve(), "pyproject.toml"), + ) + self.assertEqual( + black.find_project_root((src_dir,)), + (src_dir.resolve(), "pyproject.toml"), + ) + self.assertEqual( + black.find_project_root((src_python,)), + (src_dir.resolve(), "pyproject.toml"), ) - self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve()) - self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve()) + + @patch( + "black.files.find_user_pyproject_toml", + ) + def test_find_pyproject_toml(self, find_user_pyproject_toml: MagicMock) -> None: + find_user_pyproject_toml.side_effect = RuntimeError() + + with redirect_stderr(io.StringIO()) as stderr: + result = black.files.find_pyproject_toml( + path_search_start=(str(Path.cwd().root),) + ) + + assert result is None + err = stderr.getvalue() + assert "Ignoring user configuration" in err @patch( "black.files.find_user_pyproject_toml", @@ -1400,19 +1436,37 @@ class BlackTestCase(BlackBaseTestCase): normalized_path = black.normalize_path_maybe_ignore(path, root, report) self.assertEqual(normalized_path, "workspace/project") + def test_normalize_path_ignore_windows_junctions_outside_of_root(self) -> None: + if system() != "Windows": + return + + with TemporaryDirectory() as workspace: + root = Path(workspace) + junction_dir = root / "junction" + junction_target_outside_of_root = root / ".." + os.system(f"mklink /J {junction_dir} {junction_target_outside_of_root}") + + report = black.Report(verbose=True) + normalized_path = black.normalize_path_maybe_ignore( + junction_dir, root, report + ) + # Manually delete for Python < 3.8 + os.system(f"rmdir {junction_dir}") + + self.assertEqual(normalized_path, None) + def test_newline_comment_interaction(self) -> None: source = "class A:\\\r\n# type: ignore\n pass\n" output = black.format_str(source, mode=DEFAULT_MODE) black.assert_stable(source, output, mode=DEFAULT_MODE) def test_bpo_2142_workaround(self) -> None: - # https://bugs.python.org/issue2142 - source, _ = read_data("missing_final_newline.py") + source, _ = read_data("miscellaneous", "missing_final_newline") # read_data adds a trailing newline source = source.rstrip() - expected, _ = read_data("missing_final_newline.diff") + expected, _ = read_data("miscellaneous", "missing_final_newline.diff") tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False)) diff_header = re.compile( rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d " @@ -1586,6 +1640,33 @@ class BlackTestCase(BlackBaseTestCase): class TestCaching: + def test_get_cache_dir( + self, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + # Create multiple cache directories + workspace1 = tmp_path / "ws1" + workspace1.mkdir() + workspace2 = tmp_path / "ws2" + workspace2.mkdir() + + # Force user_cache_dir to use the temporary directory for easier assertions + patch_user_cache_dir = patch( + target="black.cache.user_cache_dir", + autospec=True, + return_value=str(workspace1), + ) + + # If BLACK_CACHE_DIR is not set, use user_cache_dir + monkeypatch.delenv("BLACK_CACHE_DIR", raising=False) + with patch_user_cache_dir: + assert get_cache_dir() == workspace1 + + # If it is set, use the path provided in the env var. + monkeypatch.setenv("BLACK_CACHE_DIR", str(workspace2)) + assert get_cache_dir() == workspace2 + def test_cache_broken_file(self) -> None: mode = DEFAULT_MODE with cache_dir() as workspace: @@ -1611,7 +1692,7 @@ class TestCaching: def test_cache_multiple_files(self) -> None: mode = DEFAULT_MODE with cache_dir() as workspace, patch( - "black.ProcessPoolExecutor", new=ThreadPoolExecutor + "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor ): one = (workspace / "one.py").resolve() with one.open("w") as fobj: @@ -1720,7 +1801,7 @@ class TestCaching: def test_failed_formatting_does_not_get_cached(self) -> None: mode = DEFAULT_MODE with cache_dir() as workspace, patch( - "black.ProcessPoolExecutor", new=ThreadPoolExecutor + "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor ): failing = (workspace / "failing.py").resolve() with failing.open("w") as fobj: @@ -1756,6 +1837,7 @@ def assert_collected_sources( src: Sequence[Union[str, Path]], expected: Sequence[Union[str, Path]], *, + ctx: Optional[FakeContext] = None, exclude: Optional[str] = None, include: Optional[str] = None, extend_exclude: Optional[str] = None, @@ -1771,7 +1853,7 @@ def assert_collected_sources( ) gs_force_exclude = None if force_exclude is None else compile_pattern(force_exclude) collected = black.get_sources( - ctx=FakeContext(), + ctx=ctx or FakeContext(), src=gs_src, quiet=False, verbose=False, @@ -1807,9 +1889,11 @@ class TestFileCollection: base / "b/.definitely_exclude/a.pyi", ] src = [base / "b/"] - assert_collected_sources(src, expected, extend_exclude=r"/exclude/") + ctx = FakeContext() + ctx.obj["root"] = base + assert_collected_sources(src, expected, ctx=ctx, extend_exclude=r"/exclude/") - @patch("black.find_project_root", lambda *args: THIS_DIR.resolve()) + @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None)) def test_exclude_for_issue_1572(self) -> None: # Exclude shouldn't touch files that were explicitly given to Black through the # CLI. Exclude is supposed to only apply to the recursive discovery of files. @@ -1946,7 +2030,6 @@ class TestFileCollection: path.iterdir.return_value = [child] child.resolve.return_value = Path("/a/b/c") child.as_posix.return_value = "/a/b/c" - child.is_symlink.return_value = True try: list( black.gen_python_files( @@ -1966,39 +2049,14 @@ class TestFileCollection: pytest.fail(f"`get_python_files_in_dir()` failed: {ve}") path.iterdir.assert_called_once() child.resolve.assert_called_once() - child.is_symlink.assert_called_once() - # `child` should behave like a strange file which resolved path is clearly - # outside of the `root` directory. - child.is_symlink.return_value = False - with pytest.raises(ValueError): - list( - black.gen_python_files( - path.iterdir(), - root, - include, - exclude, - None, - None, - report, - gitignore, - verbose=False, - quiet=False, - ) - ) - path.iterdir.assert_called() - assert path.iterdir.call_count == 2 - child.resolve.assert_called() - assert child.resolve.call_count == 2 - child.is_symlink.assert_called() - assert child.is_symlink.call_count == 2 - - @patch("black.find_project_root", lambda *args: THIS_DIR.resolve()) + + @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None)) def test_get_sources_with_stdin(self) -> None: src = ["-"] expected = ["-"] assert_collected_sources(src, expected, include="", exclude=r"/exclude/|a\.py") - @patch("black.find_project_root", lambda *args: THIS_DIR.resolve()) + @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None)) def test_get_sources_with_stdin_filename(self) -> None: src = ["-"] stdin_filename = str(THIS_DIR / "data/collections.py") @@ -2010,7 +2068,7 @@ class TestFileCollection: stdin_filename=stdin_filename, ) - @patch("black.find_project_root", lambda *args: THIS_DIR.resolve()) + @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None)) def test_get_sources_with_stdin_filename_and_exclude(self) -> None: # Exclude shouldn't exclude stdin_filename since it is mimicking the # file being passed directly. This is the same as @@ -2026,7 +2084,7 @@ class TestFileCollection: stdin_filename=stdin_filename, ) - @patch("black.find_project_root", lambda *args: THIS_DIR.resolve()) + @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None)) def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None: # Extend exclude shouldn't exclude stdin_filename since it is mimicking the # file being passed directly. This is the same as @@ -2042,7 +2100,7 @@ class TestFileCollection: stdin_filename=stdin_filename, ) - @patch("black.find_project_root", lambda *args: THIS_DIR.resolve()) + @patch("black.find_project_root", lambda *args: (THIS_DIR.resolve(), None)) def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None: # Force exclude should exclude the file when passing it through # stdin_filename