X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/c0a7582e3d4cc8bec3b7f5a6c52b36880dcb57d7..48dfda084a8c854a73c11b6c91c67deae5f86ca3:/tests/test_black.py diff --git a/tests/test_black.py b/tests/test_black.py index a0e57dc..c603233 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -1,76 +1,67 @@ #!/usr/bin/env python3 +import multiprocessing import asyncio import logging from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -from functools import partial +from dataclasses import replace +import inspect from io import BytesIO, TextIOWrapper import os from pathlib import Path +from platform import system import regex as re import sys from tempfile import TemporaryDirectory -from typing import Any, BinaryIO, Generator, List, Tuple, Iterator, TypeVar +import types +from typing import ( + Any, + BinaryIO, + Callable, + Dict, + Generator, + List, + Iterator, + TypeVar, +) import unittest from unittest.mock import patch, MagicMock +import click from click import unstyle from click.testing import CliRunner import black from black import Feature, TargetVersion -try: - import blackd - from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop - from aiohttp import web -except ImportError: - has_blackd_deps = False -else: - has_blackd_deps = True - from pathspec import PathSpec -ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True) -fs = partial(black.format_str, mode=black.FileMode()) +# Import other test classes +from tests.util import ( + THIS_DIR, + read_data, + DETERMINISTIC_HEADER, + BlackBaseTestCase, + DEFAULT_MODE, + fs, + ff, + dump_to_stderr, +) +from .test_primer import PrimerCLITests # noqa: F401 + + THIS_FILE = Path(__file__) -THIS_DIR = THIS_FILE.parent -DETERMINISTIC_HEADER = "[Deterministic header]" -EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)" -PY36_ARGS = [ - f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS -] +PY36_VERSIONS = { + TargetVersion.PY36, + TargetVersion.PY37, + TargetVersion.PY38, + TargetVersion.PY39, +} +PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS] T = TypeVar("T") R = TypeVar("R") -def dump_to_stderr(*output: str) -> str: - return "\n" + "\n".join(output) + "\n" - - -def read_data(name: str, data: bool = True) -> Tuple[str, str]: - """read_data('test_name') -> 'input', 'output'""" - if not name.endswith((".py", ".pyi", ".out", ".diff")): - name += ".py" - _input: List[str] = [] - _output: List[str] = [] - base_dir = THIS_DIR / "data" if data else THIS_DIR - with open(base_dir / name, "r", encoding="utf8") as test: - lines = test.readlines() - result = _input - for line in lines: - line = line.replace(EMPTY_LINE, "") - if line.rstrip() == "# output": - result = _output - continue - - result.append(line) - if _input and not _output: - # If there's no output marker, treat the entire file as already pre-formatted. - _output = _input[:] - return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n" - - @contextmanager def cache_dir(exists: bool = True) -> Iterator[Path]: with TemporaryDirectory() as workspace: @@ -82,7 +73,7 @@ def cache_dir(exists: bool = True) -> Iterator[Path]: @contextmanager -def event_loop(close: bool) -> Iterator[None]: +def event_loop() -> Iterator[None]: policy = asyncio.get_event_loop_policy() loop = policy.new_event_loop() asyncio.set_event_loop(loop) @@ -90,19 +81,21 @@ def event_loop(close: bool) -> Iterator[None]: yield finally: - if close: - loop.close() + loop.close() -@contextmanager -def skip_if_exception(e: str) -> Iterator[None]: - try: - yield - except Exception as exc: - if exc.__class__.__name__ == e: - unittest.skip(f"Encountered expected exception {exc}, skipping") - else: - raise +class FakeContext(click.Context): + """A fake click Context for when calling functions that need it.""" + + def __init__(self) -> None: + self.default_map: Dict[str, Any] = {} + + +class FakeParameter(click.Parameter): + """A fake click Parameter for when calling functions that need it.""" + + def __init__(self) -> None: + pass class BlackRunner(CliRunner): @@ -130,46 +123,24 @@ class BlackRunner(CliRunner): sys.stderr = hold_stderr -class BlackTestCase(unittest.TestCase): - maxDiff = None - - def assertFormatEqual(self, expected: str, actual: str) -> None: - if actual != expected and not os.environ.get("SKIP_AST_PRINT"): - bdv: black.DebugVisitor[Any] - black.out("Expected tree:", fg="green") - try: - exp_node = black.lib2to3_parse(expected) - bdv = black.DebugVisitor() - list(bdv.visit(exp_node)) - except Exception as ve: - black.err(str(ve)) - black.out("Actual tree:", fg="red") - try: - exp_node = black.lib2to3_parse(actual) - bdv = black.DebugVisitor() - list(bdv.visit(exp_node)) - except Exception as ve: - black.err(str(ve)) - self.assertEqual(expected, actual) - +class BlackTestCase(BlackBaseTestCase): def invokeBlack( self, args: List[str], exit_code: int = 0, ignore_config: bool = True ) -> None: runner = BlackRunner() if ignore_config: - args = ["--config", str(THIS_DIR / "empty.toml"), *args] + args = ["--verbose", "--config", str(THIS_DIR / "empty.toml"), *args] result = runner.invoke(black.main, args) - self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode()) - - @patch("black.dump_to_file", dump_to_stderr) - def checkSourceFile(self, name: str) -> None: - path = THIS_DIR.parent / name - source, expected = read_data(str(path), data=False) - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - self.assertFalse(ff(path)) + self.assertEqual( + result.exit_code, + exit_code, + msg=( + f"Failed with args: {args}\n" + f"stdout: {runner.stdout_bytes.decode()!r}\n" + f"stderr: {runner.stderr_bytes.decode()!r}\n" + f"exception: {result.exception}" + ), + ) @patch("black.dump_to_file", dump_to_stderr) def test_empty(self) -> None: @@ -177,7 +148,7 @@ class BlackTestCase(unittest.TestCase): actual = fs(source) self.assertFormatEqual(expected, actual) black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + black.assert_stable(source, actual, DEFAULT_MODE) def test_empty_ff(self) -> None: expected = "" @@ -190,47 +161,8 @@ class BlackTestCase(unittest.TestCase): os.unlink(tmp_file) self.assertFormatEqual(expected, actual) - def test_self(self) -> None: - self.checkSourceFile("tests/test_black.py") - - def test_black(self) -> None: - self.checkSourceFile("black.py") - - def test_pygram(self) -> None: - self.checkSourceFile("blib2to3/pygram.py") - - def test_pytree(self) -> None: - self.checkSourceFile("blib2to3/pytree.py") - - def test_conv(self) -> None: - self.checkSourceFile("blib2to3/pgen2/conv.py") - - def test_driver(self) -> None: - self.checkSourceFile("blib2to3/pgen2/driver.py") - - def test_grammar(self) -> None: - self.checkSourceFile("blib2to3/pgen2/grammar.py") - - def test_literals(self) -> None: - self.checkSourceFile("blib2to3/pgen2/literals.py") - - def test_parse(self) -> None: - self.checkSourceFile("blib2to3/pgen2/parse.py") - - def test_pgen(self) -> None: - self.checkSourceFile("blib2to3/pgen2/pgen.py") - - def test_tokenize(self) -> None: - self.checkSourceFile("blib2to3/pgen2/tokenize.py") - - def test_token(self) -> None: - self.checkSourceFile("blib2to3/pgen2/token.py") - - def test_setup(self) -> None: - self.checkSourceFile("setup.py") - def test_piping(self) -> None: - source, expected = read_data("../black", data=False) + source, expected = read_data("src/black/__init__", data=False) result = BlackRunner().invoke( black.main, ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"], @@ -239,12 +171,12 @@ class BlackTestCase(unittest.TestCase): self.assertEqual(result.exit_code, 0) self.assertFormatEqual(expected, result.output) black.assert_equivalent(source, result.output) - black.assert_stable(source, result.output, black.FileMode()) + black.assert_stable(source, result.output, DEFAULT_MODE) def test_piping_diff(self) -> None: diff_header = re.compile( - rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d " - rf"\+\d\d\d\d" + r"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d " + r"\+\d\d\d\d" ) source, _ = read_data("expression.py") expected, _ = read_data("expression.diff") @@ -287,43 +219,56 @@ class BlackTestCase(unittest.TestCase): self.assertIn("\033[0m", actual) @patch("black.dump_to_file", dump_to_stderr) - def test_function(self) -> None: - source, expected = read_data("function") - actual = fs(source) + def _test_wip(self) -> None: + source, expected = read_data("wip") + sys.settrace(tracefunc) + mode = replace( + DEFAULT_MODE, + experimental_string_processing=False, + target_versions={black.TargetVersion.PY38}, + ) + actual = fs(source, mode=mode) + sys.settrace(None) self.assertFormatEqual(expected, actual) black.assert_equivalent(source, actual) black.assert_stable(source, actual, black.FileMode()) + @unittest.expectedFailure @patch("black.dump_to_file", dump_to_stderr) - def test_function2(self) -> None: - source, expected = read_data("function2") + def test_trailing_comma_optional_parens_stability1(self) -> None: + source, _expected = read_data("trailing_comma_optional_parens1") actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + black.assert_stable(source, actual, DEFAULT_MODE) + @unittest.expectedFailure @patch("black.dump_to_file", dump_to_stderr) - def test_function_trailing_comma(self) -> None: - source, expected = read_data("function_trailing_comma") + def test_trailing_comma_optional_parens_stability2(self) -> None: + source, _expected = read_data("trailing_comma_optional_parens2") actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + black.assert_stable(source, actual, DEFAULT_MODE) + @unittest.expectedFailure @patch("black.dump_to_file", dump_to_stderr) - def test_expression(self) -> None: - source, expected = read_data("expression") + def test_trailing_comma_optional_parens_stability3(self) -> None: + source, _expected = read_data("trailing_comma_optional_parens3") actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + black.assert_stable(source, actual, DEFAULT_MODE) @patch("black.dump_to_file", dump_to_stderr) def test_pep_572(self) -> None: source, expected = read_data("pep_572") actual = fs(source) self.assertFormatEqual(expected, actual) - black.assert_stable(source, actual, black.FileMode()) + black.assert_stable(source, actual, DEFAULT_MODE) + if sys.version_info >= (3, 8): + black.assert_equivalent(source, actual) + + @patch("black.dump_to_file", dump_to_stderr) + def test_pep_572_remove_parens(self) -> None: + source, expected = read_data("pep_572_remove_parens") + actual = fs(source) + self.assertFormatEqual(expected, actual) + black.assert_stable(source, actual, DEFAULT_MODE) if sys.version_info >= (3, 8): black.assert_equivalent(source, actual) @@ -347,10 +292,11 @@ class BlackTestCase(unittest.TestCase): self.assertFormatEqual(expected, actual) with patch("black.dump_to_file", dump_to_stderr): black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + black.assert_stable(source, actual, DEFAULT_MODE) def test_expression_diff(self) -> None: source, _ = read_data("expression.py") + config = THIS_DIR / "data" / "empty_pyproject.toml" expected, _ = read_data("expression.diff") tmp_file = Path(black.dump_to_file(source)) diff_header = re.compile( @@ -358,13 +304,14 @@ class BlackTestCase(unittest.TestCase): r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d" ) try: - result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)]) + result = BlackRunner().invoke( + black.main, ["--diff", str(tmp_file), f"--config={config}"] + ) self.assertEqual(result.exit_code, 0) finally: os.unlink(tmp_file) actual = result.output actual = diff_header.sub(DETERMINISTIC_HEADER, actual) - actual = actual.rstrip() + "\n" # the diff output has a trailing space if expected != actual: dump = black.dump_to_file(actual) msg = ( @@ -376,11 +323,12 @@ class BlackTestCase(unittest.TestCase): def test_expression_diff_with_color(self) -> None: source, _ = read_data("expression.py") + config = THIS_DIR / "data" / "empty_pyproject.toml" expected, _ = read_data("expression.diff") tmp_file = Path(black.dump_to_file(source)) try: result = BlackRunner().invoke( - black.main, ["--diff", "--color", str(tmp_file)] + black.main, ["--diff", "--color", str(tmp_file), f"--config={config}"] ) finally: os.unlink(tmp_file) @@ -393,20 +341,12 @@ class BlackTestCase(unittest.TestCase): self.assertIn("\033[31m", actual) self.assertIn("\033[0m", actual) - @patch("black.dump_to_file", dump_to_stderr) - def test_fstring(self) -> None: - source, expected = read_data("fstring") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - @patch("black.dump_to_file", dump_to_stderr) def test_pep_570(self) -> None: source, expected = read_data("pep_570") actual = fs(source) self.assertFormatEqual(expected, actual) - black.assert_stable(source, actual, black.FileMode()) + black.assert_stable(source, actual, DEFAULT_MODE) if sys.version_info >= (3, 8): black.assert_equivalent(source, actual) @@ -424,171 +364,35 @@ class BlackTestCase(unittest.TestCase): actual = fs(source) self.assertFormatEqual(expected, actual) black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - mode = black.FileMode(string_normalization=False) + black.assert_stable(source, actual, DEFAULT_MODE) + mode = replace(DEFAULT_MODE, string_normalization=False) not_normalized = fs(source, mode=mode) self.assertFormatEqual(source.replace("\\\n", ""), not_normalized) black.assert_equivalent(source, not_normalized) black.assert_stable(source, not_normalized, mode=mode) @patch("black.dump_to_file", dump_to_stderr) - def test_docstring(self) -> None: - source, expected = read_data("docstring") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - def test_long_strings(self) -> None: - """Tests for splitting long strings.""" - source, expected = read_data("long_strings") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_long_strings__edge_case(self) -> None: - """Edge-case tests for splitting long strings.""" - source, expected = read_data("long_strings__edge_case") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_long_strings__regression(self) -> None: - """Regression tests for splitting long strings.""" - source, expected = read_data("long_strings__regression") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_slices(self) -> None: - source, expected = read_data("slices") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_comments(self) -> None: - source, expected = read_data("comments") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_comments2(self) -> None: - source, expected = read_data("comments2") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_comments3(self) -> None: - source, expected = read_data("comments3") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_comments4(self) -> None: - source, expected = read_data("comments4") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_comments5(self) -> None: - source, expected = read_data("comments5") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_comments6(self) -> None: - source, expected = read_data("comments6") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_comments7(self) -> None: - source, expected = read_data("comments7") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_comment_after_escaped_newline(self) -> None: - source, expected = read_data("comment_after_escaped_newline") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_cantfit(self) -> None: - source, expected = read_data("cantfit") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_import_spacing(self) -> None: - source, expected = read_data("import_spacing") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_composition(self) -> None: - source, expected = read_data("composition") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_empty_lines(self) -> None: - source, expected = read_data("empty_lines") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_remove_parens(self) -> None: - source, expected = read_data("remove_parens") - actual = fs(source) + def test_docstring_no_string_normalization(self) -> None: + """Like test_docstring but with string normalization off.""" + source, expected = read_data("docstring_no_string_normalization") + mode = replace(DEFAULT_MODE, string_normalization=False) + actual = fs(source, mode=mode) self.assertFormatEqual(expected, actual) black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + black.assert_stable(source, actual, mode) - @patch("black.dump_to_file", dump_to_stderr) - def test_string_prefixes(self) -> None: - source, expected = read_data("string_prefixes") - actual = fs(source) + def test_long_strings_flag_disabled(self) -> None: + """Tests for turning off the string processing logic.""" + source, expected = read_data("long_strings_flag_disabled") + mode = replace(DEFAULT_MODE, experimental_string_processing=False) + actual = fs(source, mode=mode) self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + black.assert_stable(expected, actual, mode) @patch("black.dump_to_file", dump_to_stderr) def test_numeric_literals(self) -> None: source, expected = read_data("numeric_literals") - mode = black.FileMode(target_versions=black.PY36_VERSIONS) + mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS) actual = fs(source, mode=mode) self.assertFormatEqual(expected, actual) black.assert_equivalent(source, actual) @@ -597,47 +401,49 @@ class BlackTestCase(unittest.TestCase): @patch("black.dump_to_file", dump_to_stderr) def test_numeric_literals_ignoring_underscores(self) -> None: source, expected = read_data("numeric_literals_skip_underscores") - mode = black.FileMode(target_versions=black.PY36_VERSIONS) + mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS) actual = fs(source, mode=mode) self.assertFormatEqual(expected, actual) black.assert_equivalent(source, actual) black.assert_stable(source, actual, mode) - @patch("black.dump_to_file", dump_to_stderr) - def test_numeric_literals_py2(self) -> None: - source, expected = read_data("numeric_literals_py2") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_python2(self) -> None: - source, expected = read_data("python2") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + def test_skip_magic_trailing_comma(self) -> None: + source, _ = read_data("expression.py") + expected, _ = read_data("expression_skip_magic_trailing_comma.diff") + tmp_file = Path(black.dump_to_file(source)) + diff_header = re.compile( + rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d " + r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d" + ) + try: + result = BlackRunner().invoke(black.main, ["-C", "--diff", str(tmp_file)]) + self.assertEqual(result.exit_code, 0) + finally: + os.unlink(tmp_file) + actual = result.output + actual = diff_header.sub(DETERMINISTIC_HEADER, actual) + actual = actual.rstrip() + "\n" # the diff output has a trailing space + if expected != actual: + dump = black.dump_to_file(actual) + msg = ( + "Expected diff isn't equal to the actual. If you made changes to" + " expression.py and this is an anticipated difference, overwrite" + f" tests/data/expression_skip_magic_trailing_comma.diff with {dump}" + ) + self.assertEqual(expected, actual, msg) @patch("black.dump_to_file", dump_to_stderr) def test_python2_print_function(self) -> None: source, expected = read_data("python2_print_function") - mode = black.FileMode(target_versions={TargetVersion.PY27}) + mode = replace(DEFAULT_MODE, target_versions={TargetVersion.PY27}) actual = fs(source, mode=mode) self.assertFormatEqual(expected, actual) black.assert_equivalent(source, actual) black.assert_stable(source, actual, mode) - @patch("black.dump_to_file", dump_to_stderr) - def test_python2_unicode_literals(self) -> None: - source, expected = read_data("python2_unicode_literals") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - @patch("black.dump_to_file", dump_to_stderr) def test_stub(self) -> None: - mode = black.FileMode(is_pyi=True) + mode = replace(DEFAULT_MODE, is_pyi=True) source, expected = read_data("stub.pyi") actual = fs(source, mode=mode) self.assertFormatEqual(expected, actual) @@ -652,7 +458,7 @@ class BlackTestCase(unittest.TestCase): major, minor = sys.version_info[:2] if major < 3 or (major <= 3 and minor < 7): black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + black.assert_stable(source, actual, DEFAULT_MODE) # ensure black can parse this when the target is 3.6 self.invokeBlack([str(source_path), "--target-version", "py36"]) # but not on 3.7, because async/await is no longer an identifier @@ -667,7 +473,7 @@ class BlackTestCase(unittest.TestCase): major, minor = sys.version_info[:2] if major > 3 or (major == 3 and minor >= 7): black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + black.assert_stable(source, actual, DEFAULT_MODE) # ensure black can parse this when the target is 3.7 self.invokeBlack([str(source_path), "--target-version", "py37"]) # but not on 3.6, because we use async as a reserved keyword @@ -681,79 +487,17 @@ class BlackTestCase(unittest.TestCase): major, minor = sys.version_info[:2] if major > 3 or (major == 3 and minor >= 8): black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_fmtonoff(self) -> None: - source, expected = read_data("fmtonoff") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_fmtonoff2(self) -> None: - source, expected = read_data("fmtonoff2") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_fmtonoff3(self) -> None: - source, expected = read_data("fmtonoff3") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_fmtonoff4(self) -> None: - source, expected = read_data("fmtonoff4") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_remove_empty_parentheses_after_class(self) -> None: - source, expected = read_data("class_blank_parentheses") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_new_line_between_class_and_code(self) -> None: - source, expected = read_data("class_methods_new_line") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + black.assert_stable(source, actual, DEFAULT_MODE) @patch("black.dump_to_file", dump_to_stderr) - def test_bracket_match(self) -> None: - source, expected = read_data("bracketmatch") + def test_python39(self) -> None: + source, expected = read_data("python39") actual = fs(source) self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_tuple_assign(self) -> None: - source, expected = read_data("tupleassign") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) - - @patch("black.dump_to_file", dump_to_stderr) - def test_beginning_backslash(self) -> None: - source, expected = read_data("beginning_backslash") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + major, minor = sys.version_info[:2] + if major > 3 or (major == 3 and minor >= 9): + black.assert_equivalent(source, actual) + black.assert_stable(source, actual, DEFAULT_MODE) def test_tab_comment_indentation(self) -> None: contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n" @@ -1095,6 +839,39 @@ class BlackTestCase(unittest.TestCase): black.lib2to3_parse(py3_only, {TargetVersion.PY36}) black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36}) + def test_get_features_used_decorator(self) -> None: + # Test the feature detection of new decorator syntax + # since this makes some test cases of test_get_features_used() + # fails if it fails, this is tested first so that a useful case + # is identified + simples, relaxed = read_data("decorators") + # skip explanation comments at the top of the file + for simple_test in simples.split("##")[1:]: + node = black.lib2to3_parse(simple_test) + decorator = str(node.children[0].children[0]).strip() + self.assertNotIn( + Feature.RELAXED_DECORATORS, + black.get_features_used(node), + msg=( + f"decorator '{decorator}' follows python<=3.8 syntax" + "but is detected as 3.9+" + # f"The full node is\n{node!r}" + ), + ) + # skip the '# output' comment at the top of the output part + for relaxed_test in relaxed.split("##")[1:]: + node = black.lib2to3_parse(relaxed_test) + decorator = str(node.children[0].children[0]).strip() + self.assertIn( + Feature.RELAXED_DECORATORS, + black.get_features_used(node), + msg=( + f"decorator '{decorator}' uses python3.9+ syntax" + "but is detected as python<=3.8" + # f"The full node is\n{node!r}" + ), + ) + def test_get_features_used(self) -> None: node = black.lib2to3_parse("def f(*, arg): ...\n") self.assertEqual(black.get_features_used(node), set()) @@ -1182,7 +959,7 @@ class BlackTestCase(unittest.TestCase): def test_format_file_contents(self) -> None: empty = "" - mode = black.FileMode() + mode = DEFAULT_MODE with self.assertRaises(black.NothingChanged): black.format_file_contents(empty, mode=mode, fast=False) just_nl = "\n" @@ -1227,7 +1004,7 @@ class BlackTestCase(unittest.TestCase): self.assertEqual("".join(err_lines), "") def test_cache_broken_file(self) -> None: - mode = black.FileMode() + mode = DEFAULT_MODE with cache_dir() as workspace: cache_file = black.get_cache_file(mode) with cache_file.open("w") as fobj: @@ -1238,10 +1015,10 @@ class BlackTestCase(unittest.TestCase): fobj.write("print('hello')") self.invokeBlack([str(src)]) cache = black.read_cache(mode) - self.assertIn(src, cache) + self.assertIn(str(src), cache) def test_cache_single_file_already_cached(self) -> None: - mode = black.FileMode() + mode = DEFAULT_MODE with cache_dir() as workspace: src = (workspace / "test.py").resolve() with src.open("w") as fobj: @@ -1251,9 +1028,9 @@ class BlackTestCase(unittest.TestCase): with src.open("r") as fobj: self.assertEqual(fobj.read(), "print('hello')") - @event_loop(close=False) + @event_loop() def test_cache_multiple_files(self) -> None: - mode = black.FileMode() + mode = DEFAULT_MODE with cache_dir() as workspace, patch( "black.ProcessPoolExecutor", new=ThreadPoolExecutor ): @@ -1270,42 +1047,67 @@ class BlackTestCase(unittest.TestCase): with two.open("r") as fobj: self.assertEqual(fobj.read(), 'print("hello")\n') cache = black.read_cache(mode) - self.assertIn(one, cache) - self.assertIn(two, cache) + self.assertIn(str(one), cache) + self.assertIn(str(two), cache) - @patch("black.ProcessPoolExecutor", autospec=True) - def test_works_in_mono_process_only_environment(self, mock_executor) -> None: - mock_executor.side_effect = OSError() - mode = black.FileMode() + def test_no_cache_when_writeback_diff(self) -> None: + mode = DEFAULT_MODE with cache_dir() as workspace: - one = (workspace / "one.py").resolve() - with one.open("w") as fobj: - fobj.write("print('hello')") - two = (workspace / "two.py").resolve() - with two.open("w") as fobj: + src = (workspace / "test.py").resolve() + with src.open("w") as fobj: fobj.write("print('hello')") - black.write_cache({}, [one], mode) - self.invokeBlack([str(workspace)]) - with one.open("r") as fobj: - self.assertEqual(fobj.read(), "print('hello')") - with two.open("r") as fobj: - self.assertEqual(fobj.read(), 'print("hello")\n') - cache = black.read_cache(mode) - self.assertIn(one, cache) - self.assertIn(two, cache) - - def test_no_cache_when_writeback_diff(self) -> None: - mode = black.FileMode() + with patch("black.read_cache") as read_cache, patch( + "black.write_cache" + ) as write_cache: + self.invokeBlack([str(src), "--diff"]) + cache_file = black.get_cache_file(mode) + self.assertFalse(cache_file.exists()) + write_cache.assert_not_called() + read_cache.assert_not_called() + + def test_no_cache_when_writeback_color_diff(self) -> None: + mode = DEFAULT_MODE with cache_dir() as workspace: src = (workspace / "test.py").resolve() with src.open("w") as fobj: fobj.write("print('hello')") - self.invokeBlack([str(src), "--diff"]) - cache_file = black.get_cache_file(mode) - self.assertFalse(cache_file.exists()) + with patch("black.read_cache") as read_cache, patch( + "black.write_cache" + ) as write_cache: + self.invokeBlack([str(src), "--diff", "--color"]) + cache_file = black.get_cache_file(mode) + self.assertFalse(cache_file.exists()) + write_cache.assert_not_called() + read_cache.assert_not_called() + + @event_loop() + def test_output_locking_when_writeback_diff(self) -> None: + with cache_dir() as workspace: + for tag in range(0, 4): + src = (workspace / f"test{tag}.py").resolve() + with src.open("w") as fobj: + fobj.write("print('hello')") + with patch("black.Manager", wraps=multiprocessing.Manager) as mgr: + self.invokeBlack(["--diff", str(workspace)], exit_code=0) + # this isn't quite doing what we want, but if it _isn't_ + # called then we cannot be using the lock it provides + mgr.assert_called() + + @event_loop() + def test_output_locking_when_writeback_color_diff(self) -> None: + with cache_dir() as workspace: + for tag in range(0, 4): + src = (workspace / f"test{tag}.py").resolve() + with src.open("w") as fobj: + fobj.write("print('hello')") + with patch("black.Manager", wraps=multiprocessing.Manager) as mgr: + self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0) + # this isn't quite doing what we want, but if it _isn't_ + # called then we cannot be using the lock it provides + mgr.assert_called() def test_no_cache_when_stdin(self) -> None: - mode = black.FileMode() + mode = DEFAULT_MODE with cache_dir(): result = CliRunner().invoke( black.main, ["-"], input=BytesIO(b"print('hello')") @@ -1315,19 +1117,19 @@ class BlackTestCase(unittest.TestCase): self.assertFalse(cache_file.exists()) def test_read_cache_no_cachefile(self) -> None: - mode = black.FileMode() + mode = DEFAULT_MODE with cache_dir(): self.assertEqual(black.read_cache(mode), {}) def test_write_cache_read_cache(self) -> None: - mode = black.FileMode() + mode = DEFAULT_MODE with cache_dir() as workspace: src = (workspace / "test.py").resolve() src.touch() black.write_cache({}, [src], mode) cache = black.read_cache(mode) - self.assertIn(src, cache) - self.assertEqual(cache[src], black.get_cache_info(src)) + self.assertIn(str(src), cache) + self.assertEqual(cache[str(src)], black.get_cache_info(src)) def test_filter_cached(self) -> None: with TemporaryDirectory() as workspace: @@ -1338,7 +1140,10 @@ class BlackTestCase(unittest.TestCase): uncached.touch() cached.touch() cached_but_changed.touch() - cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)} + cache = { + str(cached): black.get_cache_info(cached), + str(cached_but_changed): (0.0, 0), + } todo, done = black.filter_cached( cache, {uncached, cached, cached_but_changed} ) @@ -1346,15 +1151,15 @@ class BlackTestCase(unittest.TestCase): self.assertEqual(done, {cached}) def test_write_cache_creates_directory_if_needed(self) -> None: - mode = black.FileMode() + mode = DEFAULT_MODE with cache_dir(exists=False) as workspace: self.assertFalse(workspace.exists()) black.write_cache({}, [], mode) self.assertTrue(workspace.exists()) - @event_loop(close=False) + @event_loop() def test_failed_formatting_does_not_get_cached(self) -> None: - mode = black.FileMode() + mode = DEFAULT_MODE with cache_dir() as workspace, patch( "black.ProcessPoolExecutor", new=ThreadPoolExecutor ): @@ -1366,16 +1171,27 @@ class BlackTestCase(unittest.TestCase): fobj.write('print("hello")\n') self.invokeBlack([str(workspace)], exit_code=123) cache = black.read_cache(mode) - self.assertNotIn(failing, cache) - self.assertIn(clean, cache) + self.assertNotIn(str(failing), cache) + self.assertIn(str(clean), cache) def test_write_cache_write_fail(self) -> None: - mode = black.FileMode() + mode = DEFAULT_MODE with cache_dir(), patch.object(Path, "open") as mock: mock.side_effect = OSError black.write_cache({}, [], mode) - @event_loop(close=False) + @event_loop() + @patch("black.ProcessPoolExecutor", MagicMock(side_effect=OSError)) + def test_works_in_mono_process_only_environment(self) -> None: + with cache_dir() as workspace: + for f in [ + (workspace / "one.py").resolve(), + (workspace / "two.py").resolve(), + ]: + f.write_text('print("hello")\n') + self.invokeBlack([str(workspace)]) + + @event_loop() def test_check_diff_use_together(self) -> None: with cache_dir(): # Files which will be reformatted. @@ -1402,27 +1218,19 @@ class BlackTestCase(unittest.TestCase): self.invokeBlack([str(workspace.resolve())]) def test_read_cache_line_lengths(self) -> None: - mode = black.FileMode() - short_mode = black.FileMode(line_length=1) + mode = DEFAULT_MODE + short_mode = replace(DEFAULT_MODE, line_length=1) with cache_dir() as workspace: path = (workspace / "file.py").resolve() path.touch() black.write_cache({}, [path], mode) one = black.read_cache(mode) - self.assertIn(path, one) + self.assertIn(str(path), one) two = black.read_cache(short_mode) - self.assertNotIn(path, two) - - def test_tricky_unicode_symbols(self) -> None: - source, expected = read_data("tricky_unicode_symbols") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + self.assertNotIn(str(path), two) def test_single_file_force_pyi(self) -> None: - reg_mode = black.FileMode() - pyi_mode = black.FileMode(is_pyi=True) + pyi_mode = replace(DEFAULT_MODE, is_pyi=True) contents, expected = read_data("force_pyi") with cache_dir() as workspace: path = (workspace / "file.py").resolve() @@ -1433,15 +1241,17 @@ class BlackTestCase(unittest.TestCase): actual = fh.read() # verify cache with --pyi is separate pyi_cache = black.read_cache(pyi_mode) - self.assertIn(path, pyi_cache) - normal_cache = black.read_cache(reg_mode) - self.assertNotIn(path, normal_cache) - self.assertEqual(actual, expected) + self.assertIn(str(path), pyi_cache) + normal_cache = black.read_cache(DEFAULT_MODE) + self.assertNotIn(str(path), normal_cache) + self.assertFormatEqual(expected, actual) + black.assert_equivalent(contents, actual) + black.assert_stable(contents, actual, pyi_mode) - @event_loop(close=False) + @event_loop() def test_multi_file_force_pyi(self) -> None: - reg_mode = black.FileMode() - pyi_mode = black.FileMode(is_pyi=True) + reg_mode = DEFAULT_MODE + pyi_mode = replace(DEFAULT_MODE, is_pyi=True) contents, expected = read_data("force_pyi") with cache_dir() as workspace: paths = [ @@ -1460,8 +1270,8 @@ class BlackTestCase(unittest.TestCase): pyi_cache = black.read_cache(pyi_mode) normal_cache = black.read_cache(reg_mode) for path in paths: - self.assertIn(path, pyi_cache) - self.assertNotIn(path, normal_cache) + self.assertIn(str(path), pyi_cache) + self.assertNotIn(str(path), normal_cache) def test_pipe_force_pyi(self) -> None: source, expected = read_data("force_pyi") @@ -1473,8 +1283,8 @@ class BlackTestCase(unittest.TestCase): self.assertFormatEqual(actual, expected) def test_single_file_force_py36(self) -> None: - reg_mode = black.FileMode() - py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS) + reg_mode = DEFAULT_MODE + py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS) source, expected = read_data("force_py36") with cache_dir() as workspace: path = (workspace / "file.py").resolve() @@ -1485,15 +1295,15 @@ class BlackTestCase(unittest.TestCase): actual = fh.read() # verify cache with --target-version is separate py36_cache = black.read_cache(py36_mode) - self.assertIn(path, py36_cache) + self.assertIn(str(path), py36_cache) normal_cache = black.read_cache(reg_mode) - self.assertNotIn(path, normal_cache) + self.assertNotIn(str(path), normal_cache) self.assertEqual(actual, expected) - @event_loop(close=False) + @event_loop() def test_multi_file_force_py36(self) -> None: - reg_mode = black.FileMode() - py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS) + reg_mode = DEFAULT_MODE + py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS) source, expected = read_data("force_py36") with cache_dir() as workspace: paths = [ @@ -1512,15 +1322,8 @@ class BlackTestCase(unittest.TestCase): pyi_cache = black.read_cache(py36_mode) normal_cache = black.read_cache(reg_mode) for path in paths: - self.assertIn(path, pyi_cache) - self.assertNotIn(path, normal_cache) - - def test_collections(self) -> None: - source, expected = read_data("collections") - actual = fs(source) - self.assertFormatEqual(expected, actual) - black.assert_equivalent(source, actual) - black.assert_stable(source, actual, black.FileMode()) + self.assertIn(str(path), pyi_cache) + self.assertNotIn(str(path), normal_cache) def test_pipe_force_py36(self) -> None: source, expected = read_data("force_py36") @@ -1546,12 +1349,236 @@ class BlackTestCase(unittest.TestCase): ] this_abs = THIS_DIR.resolve() sources.extend( - black.gen_python_files_in_dir( - path, this_abs, include, exclude, report, gitignore + black.gen_python_files( + path.iterdir(), + this_abs, + include, + exclude, + None, + None, + report, + gitignore, ) ) self.assertEqual(sorted(expected), sorted(sources)) + @patch("black.find_project_root", lambda *args: THIS_DIR.resolve()) + def test_exclude_for_issue_1572(self) -> None: + # Exclude shouldn't touch files that were explicitly given to Black through the + # CLI. Exclude is supposed to only apply to the recursive discovery of files. + # https://github.com/psf/black/issues/1572 + path = THIS_DIR / "data" / "include_exclude_tests" + include = "" + exclude = r"/exclude/|a\.py" + src = str(path / "b/exclude/a.py") + report = black.Report() + expected = [Path(path / "b/exclude/a.py")] + sources = list( + black.get_sources( + ctx=FakeContext(), + src=(src,), + quiet=True, + verbose=False, + include=re.compile(include), + exclude=re.compile(exclude), + extend_exclude=None, + force_exclude=None, + report=report, + stdin_filename=None, + ) + ) + self.assertEqual(sorted(expected), sorted(sources)) + + @patch("black.find_project_root", lambda *args: THIS_DIR.resolve()) + def test_get_sources_with_stdin(self) -> None: + include = "" + exclude = r"/exclude/|a\.py" + src = "-" + report = black.Report() + expected = [Path("-")] + sources = list( + black.get_sources( + ctx=FakeContext(), + src=(src,), + quiet=True, + verbose=False, + include=re.compile(include), + exclude=re.compile(exclude), + extend_exclude=None, + force_exclude=None, + report=report, + stdin_filename=None, + ) + ) + self.assertEqual(sorted(expected), sorted(sources)) + + @patch("black.find_project_root", lambda *args: THIS_DIR.resolve()) + def test_get_sources_with_stdin_filename(self) -> None: + include = "" + exclude = r"/exclude/|a\.py" + src = "-" + report = black.Report() + stdin_filename = str(THIS_DIR / "data/collections.py") + expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")] + sources = list( + black.get_sources( + ctx=FakeContext(), + src=(src,), + quiet=True, + verbose=False, + include=re.compile(include), + exclude=re.compile(exclude), + extend_exclude=None, + force_exclude=None, + report=report, + stdin_filename=stdin_filename, + ) + ) + self.assertEqual(sorted(expected), sorted(sources)) + + @patch("black.find_project_root", lambda *args: THIS_DIR.resolve()) + def test_get_sources_with_stdin_filename_and_exclude(self) -> None: + # Exclude shouldn't exclude stdin_filename since it is mimicing the + # file being passed directly. This is the same as + # test_exclude_for_issue_1572 + path = THIS_DIR / "data" / "include_exclude_tests" + include = "" + exclude = r"/exclude/|a\.py" + src = "-" + report = black.Report() + stdin_filename = str(path / "b/exclude/a.py") + expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")] + sources = list( + black.get_sources( + ctx=FakeContext(), + src=(src,), + quiet=True, + verbose=False, + include=re.compile(include), + exclude=re.compile(exclude), + extend_exclude=None, + force_exclude=None, + report=report, + stdin_filename=stdin_filename, + ) + ) + self.assertEqual(sorted(expected), sorted(sources)) + + @patch("black.find_project_root", lambda *args: THIS_DIR.resolve()) + def test_get_sources_with_stdin_filename_and_extend_exclude(self) -> None: + # Extend exclude shouldn't exclude stdin_filename since it is mimicking the + # file being passed directly. This is the same as + # test_exclude_for_issue_1572 + path = THIS_DIR / "data" / "include_exclude_tests" + include = "" + extend_exclude = r"/exclude/|a\.py" + src = "-" + report = black.Report() + stdin_filename = str(path / "b/exclude/a.py") + expected = [Path(f"__BLACK_STDIN_FILENAME__{stdin_filename}")] + sources = list( + black.get_sources( + ctx=FakeContext(), + src=(src,), + quiet=True, + verbose=False, + include=re.compile(include), + exclude=re.compile(""), + extend_exclude=re.compile(extend_exclude), + force_exclude=None, + report=report, + stdin_filename=stdin_filename, + ) + ) + self.assertEqual(sorted(expected), sorted(sources)) + + @patch("black.find_project_root", lambda *args: THIS_DIR.resolve()) + def test_get_sources_with_stdin_filename_and_force_exclude(self) -> None: + # Force exclude should exclude the file when passing it through + # stdin_filename + path = THIS_DIR / "data" / "include_exclude_tests" + include = "" + force_exclude = r"/exclude/|a\.py" + src = "-" + report = black.Report() + stdin_filename = str(path / "b/exclude/a.py") + sources = list( + black.get_sources( + ctx=FakeContext(), + src=(src,), + quiet=True, + verbose=False, + include=re.compile(include), + exclude=re.compile(""), + extend_exclude=None, + force_exclude=re.compile(force_exclude), + report=report, + stdin_filename=stdin_filename, + ) + ) + self.assertEqual([], sorted(sources)) + + def test_reformat_one_with_stdin(self) -> None: + with patch( + "black.format_stdin_to_stdout", + return_value=lambda *args, **kwargs: black.Changed.YES, + ) as fsts: + report = MagicMock() + path = Path("-") + black.reformat_one( + path, + fast=True, + write_back=black.WriteBack.YES, + mode=DEFAULT_MODE, + report=report, + ) + fsts.assert_called_once() + report.done.assert_called_with(path, black.Changed.YES) + + def test_reformat_one_with_stdin_filename(self) -> None: + with patch( + "black.format_stdin_to_stdout", + return_value=lambda *args, **kwargs: black.Changed.YES, + ) as fsts: + report = MagicMock() + p = "foo.py" + path = Path(f"__BLACK_STDIN_FILENAME__{p}") + expected = Path(p) + black.reformat_one( + path, + fast=True, + write_back=black.WriteBack.YES, + mode=DEFAULT_MODE, + report=report, + ) + fsts.assert_called_once() + # __BLACK_STDIN_FILENAME__ should have been striped + report.done.assert_called_with(expected, black.Changed.YES) + + def test_reformat_one_with_stdin_and_existing_path(self) -> None: + with patch( + "black.format_stdin_to_stdout", + return_value=lambda *args, **kwargs: black.Changed.YES, + ) as fsts: + report = MagicMock() + # Even with an existing file, since we are forcing stdin, black + # should output to stdout and not modify the file inplace + p = Path(str(THIS_DIR / "data/collections.py")) + # Make sure is_file actually returns True + self.assertTrue(p.is_file()) + path = Path(f"__BLACK_STDIN_FILENAME__{p}") + expected = Path(p) + black.reformat_one( + path, + fast=True, + write_back=black.WriteBack.YES, + mode=DEFAULT_MODE, + report=report, + ) + fsts.assert_called_once() + # __BLACK_STDIN_FILENAME__ should have been striped + report.done.assert_called_with(expected, black.Changed.YES) + def test_gitignore_exclude(self) -> None: path = THIS_DIR / "data" / "include_exclude_tests" include = re.compile(r"\.pyi?$") @@ -1567,8 +1594,15 @@ class BlackTestCase(unittest.TestCase): ] this_abs = THIS_DIR.resolve() sources.extend( - black.gen_python_files_in_dir( - path, this_abs, include, exclude, report, gitignore + black.gen_python_files( + path.iterdir(), + this_abs, + include, + exclude, + None, + None, + report, + gitignore, ) ) self.assertEqual(sorted(expected), sorted(sources)) @@ -1592,46 +1626,45 @@ class BlackTestCase(unittest.TestCase): ] this_abs = THIS_DIR.resolve() sources.extend( - black.gen_python_files_in_dir( - path, + black.gen_python_files( + path.iterdir(), this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), + None, + None, report, gitignore, ) ) self.assertEqual(sorted(expected), sorted(sources)) - def test_empty_exclude(self) -> None: + def test_extend_exclude(self) -> None: path = THIS_DIR / "data" / "include_exclude_tests" report = black.Report() gitignore = PathSpec.from_lines("gitwildmatch", []) - empty = re.compile(r"") sources: List[Path] = [] expected = [ - Path(path / "b/dont_exclude/a.py"), - Path(path / "b/dont_exclude/a.pyi"), Path(path / "b/exclude/a.py"), - Path(path / "b/exclude/a.pyi"), - Path(path / "b/.definitely_exclude/a.py"), - Path(path / "b/.definitely_exclude/a.pyi"), + Path(path / "b/dont_exclude/a.py"), ] this_abs = THIS_DIR.resolve() sources.extend( - black.gen_python_files_in_dir( - path, + black.gen_python_files( + path.iterdir(), this_abs, re.compile(black.DEFAULT_INCLUDES), - empty, + re.compile(r"\.pyi$"), + re.compile(r"\.definitely_exclude"), + None, report, gitignore, ) ) self.assertEqual(sorted(expected), sorted(sources)) - def test_invalid_include_exclude(self) -> None: - for option in ["--include", "--exclude"]: + def test_invalid_cli_regex(self) -> None: + for option in ["--include", "--exclude", "--extend-exclude", "--force-exclude"]: self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2) def test_preserves_line_endings(self) -> None: @@ -1665,7 +1698,7 @@ class BlackTestCase(unittest.TestCase): def test_symlink_out_of_root_directory(self) -> None: path = MagicMock() - root = THIS_DIR + root = THIS_DIR.resolve() child = MagicMock() include = re.compile(black.DEFAULT_INCLUDES) exclude = re.compile(black.DEFAULT_EXCLUDES) @@ -1679,8 +1712,15 @@ class BlackTestCase(unittest.TestCase): child.is_symlink.return_value = True try: list( - black.gen_python_files_in_dir( - path, root, include, exclude, report, gitignore + black.gen_python_files( + path.iterdir(), + root, + include, + exclude, + None, + None, + report, + gitignore, ) ) except ValueError as ve: @@ -1693,8 +1733,15 @@ class BlackTestCase(unittest.TestCase): child.is_symlink.return_value = False with self.assertRaises(ValueError): list( - black.gen_python_files_in_dir( - path, root, include, exclude, report, gitignore + black.gen_python_files( + path.iterdir(), + root, + include, + exclude, + None, + None, + report, + gitignore, ) ) path.iterdir.assert_called() @@ -1741,14 +1788,6 @@ class BlackTestCase(unittest.TestCase): ): ff(THIS_FILE) - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - def test_blackd_main(self) -> None: - with patch("blackd.web.run_app"): - result = CliRunner().invoke(blackd.main, []) - if result.exception is not None: - raise result.exception - self.assertEqual(result.exit_code, 0) - def test_invalid_config_return_code(self) -> None: tmp_file = Path(black.dump_to_file()) try: @@ -1759,173 +1798,151 @@ class BlackTestCase(unittest.TestCase): finally: tmp_file.unlink() + def test_parse_pyproject_toml(self) -> None: + test_toml_file = THIS_DIR / "test.toml" + config = black.parse_pyproject_toml(str(test_toml_file)) + self.assertEqual(config["verbose"], 1) + self.assertEqual(config["check"], "no") + self.assertEqual(config["diff"], "y") + self.assertEqual(config["color"], True) + self.assertEqual(config["line_length"], 79) + self.assertEqual(config["target_version"], ["py36", "py37", "py38"]) + self.assertEqual(config["exclude"], r"\.pyi?$") + self.assertEqual(config["include"], r"\.py?$") + + def test_read_pyproject_toml(self) -> None: + test_toml_file = THIS_DIR / "test.toml" + fake_ctx = FakeContext() + black.read_pyproject_toml(fake_ctx, FakeParameter(), str(test_toml_file)) + config = fake_ctx.default_map + self.assertEqual(config["verbose"], "1") + self.assertEqual(config["check"], "no") + self.assertEqual(config["diff"], "y") + self.assertEqual(config["color"], "True") + self.assertEqual(config["line_length"], "79") + self.assertEqual(config["target_version"], ["py36", "py37", "py38"]) + self.assertEqual(config["exclude"], r"\.pyi?$") + self.assertEqual(config["include"], r"\.py?$") + + def test_find_project_root(self) -> None: + with TemporaryDirectory() as workspace: + root = Path(workspace) + test_dir = root / "test" + test_dir.mkdir() -class BlackDTestCase(AioHTTPTestCase): - async def get_application(self) -> web.Application: - return blackd.make_app() - - # TODO: remove these decorators once the below is released - # https://github.com/aio-libs/aiohttp/pull/3727 - @skip_if_exception("ClientOSError") - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - @unittest_run_loop - async def test_blackd_request_needs_formatting(self) -> None: - response = await self.client.post("/", data=b"print('hello world')") - self.assertEqual(response.status, 200) - self.assertEqual(response.charset, "utf8") - self.assertEqual(await response.read(), b'print("hello world")\n') - - @skip_if_exception("ClientOSError") - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - @unittest_run_loop - async def test_blackd_request_no_change(self) -> None: - response = await self.client.post("/", data=b'print("hello world")\n') - self.assertEqual(response.status, 204) - self.assertEqual(await response.read(), b"") - - @skip_if_exception("ClientOSError") - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - @unittest_run_loop - async def test_blackd_request_syntax_error(self) -> None: - response = await self.client.post("/", data=b"what even ( is") - self.assertEqual(response.status, 400) - content = await response.text() - self.assertTrue( - content.startswith("Cannot parse"), - msg=f"Expected error to start with 'Cannot parse', got {repr(content)}", - ) + src_dir = root / "src" + src_dir.mkdir() - @skip_if_exception("ClientOSError") - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - @unittest_run_loop - async def test_blackd_unsupported_version(self) -> None: - response = await self.client.post( - "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"} - ) - self.assertEqual(response.status, 501) - - @skip_if_exception("ClientOSError") - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - @unittest_run_loop - async def test_blackd_supported_version(self) -> None: - response = await self.client.post( - "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"} - ) - self.assertEqual(response.status, 200) - - @skip_if_exception("ClientOSError") - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - @unittest_run_loop - async def test_blackd_invalid_python_variant(self) -> None: - async def check(header_value: str, expected_status: int = 400) -> None: - response = await self.client.post( - "/", data=b"what", headers={blackd.PYTHON_VARIANT_HEADER: header_value} + root_pyproject = root / "pyproject.toml" + root_pyproject.touch() + src_pyproject = src_dir / "pyproject.toml" + src_pyproject.touch() + src_python = src_dir / "foo.py" + src_python.touch() + + self.assertEqual( + black.find_project_root((src_dir, test_dir)), root.resolve() ) - self.assertEqual(response.status, expected_status) - - await check("lol") - await check("ruby3.5") - await check("pyi3.6") - await check("py1.5") - await check("2.8") - await check("py2.8") - await check("3.0") - await check("pypy3.0") - await check("jython3.4") - - @skip_if_exception("ClientOSError") - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - @unittest_run_loop - async def test_blackd_pyi(self) -> None: - source, expected = read_data("stub.pyi") - response = await self.client.post( - "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"} - ) - self.assertEqual(response.status, 200) - self.assertEqual(await response.text(), expected) + self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve()) + self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve()) - @skip_if_exception("ClientOSError") - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - @unittest_run_loop - async def test_blackd_diff(self) -> None: - diff_header = re.compile( - rf"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d" - ) + @patch("black.find_user_pyproject_toml", black.find_user_pyproject_toml.__wrapped__) + def test_find_user_pyproject_toml_linux(self) -> None: + if system() == "Windows": + return - source, _ = read_data("blackd_diff.py") - expected, _ = read_data("blackd_diff.diff") + # Test if XDG_CONFIG_HOME is checked + with TemporaryDirectory() as workspace: + tmp_user_config = Path(workspace) / "black" + with patch.dict("os.environ", {"XDG_CONFIG_HOME": workspace}): + self.assertEqual( + black.find_user_pyproject_toml(), tmp_user_config.resolve() + ) - response = await self.client.post( - "/", data=source, headers={blackd.DIFF_HEADER: "true"} - ) - self.assertEqual(response.status, 200) + # Test fallback for XDG_CONFIG_HOME + with patch.dict("os.environ"): + os.environ.pop("XDG_CONFIG_HOME", None) + fallback_user_config = Path("~/.config").expanduser() / "black" + self.assertEqual( + black.find_user_pyproject_toml(), fallback_user_config.resolve() + ) + + def test_find_user_pyproject_toml_windows(self) -> None: + if system() != "Windows": + return + + user_config_path = Path.home() / ".black" + self.assertEqual(black.find_user_pyproject_toml(), user_config_path.resolve()) + + def test_bpo_33660_workaround(self) -> None: + if system() == "Windows": + return - actual = await response.text() + # https://bugs.python.org/issue33660 + + old_cwd = Path.cwd() + try: + root = Path("/") + os.chdir(str(root)) + path = Path("workspace") / "project" + report = black.Report(verbose=True) + normalized_path = black.normalize_path_maybe_ignore(path, root, report) + self.assertEqual(normalized_path, "workspace/project") + finally: + os.chdir(str(old_cwd)) + + def test_newline_comment_interaction(self) -> None: + source = "class A:\\\r\n# type: ignore\n pass\n" + output = black.format_str(source, mode=DEFAULT_MODE) + black.assert_stable(source, output, mode=DEFAULT_MODE) + + def test_bpo_2142_workaround(self) -> None: + + # https://bugs.python.org/issue2142 + + source, _ = read_data("missing_final_newline.py") + # read_data adds a trailing newline + source = source.rstrip() + expected, _ = read_data("missing_final_newline.diff") + tmp_file = Path(black.dump_to_file(source, ensure_final_newline=False)) + diff_header = re.compile( + rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d " + r"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d" + ) + try: + result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)]) + self.assertEqual(result.exit_code, 0) + finally: + os.unlink(tmp_file) + actual = result.output actual = diff_header.sub(DETERMINISTIC_HEADER, actual) self.assertEqual(actual, expected) - @skip_if_exception("ClientOSError") - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - @unittest_run_loop - async def test_blackd_python_variant(self) -> None: - code = ( - "def f(\n" - " and_has_a_bunch_of,\n" - " very_long_arguments_too,\n" - " and_lots_of_them_as_well_lol,\n" - " **and_very_long_keyword_arguments\n" - "):\n" - " pass\n" - ) - async def check(header_value: str, expected_status: int) -> None: - response = await self.client.post( - "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value} - ) - self.assertEqual( - response.status, expected_status, msg=await response.text() - ) +with open(black.__file__, "r", encoding="utf-8") as _bf: + black_source_lines = _bf.readlines() - await check("3.6", 200) - await check("py3.6", 200) - await check("3.6,3.7", 200) - await check("3.6,py3.7", 200) - await check("py36,py37", 200) - await check("36", 200) - await check("3.6.4", 200) - - await check("2", 204) - await check("2.7", 204) - await check("py2.7", 204) - await check("3.4", 204) - await check("py3.4", 204) - await check("py34,py36", 204) - await check("34", 204) - - @skip_if_exception("ClientOSError") - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - @unittest_run_loop - async def test_blackd_line_length(self) -> None: - response = await self.client.post( - "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"} - ) - self.assertEqual(response.status, 200) - - @skip_if_exception("ClientOSError") - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - @unittest_run_loop - async def test_blackd_invalid_line_length(self) -> None: - response = await self.client.post( - "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "NaN"} - ) - self.assertEqual(response.status, 400) - - @skip_if_exception("ClientOSError") - @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed") - @unittest_run_loop - async def test_blackd_response_black_version_header(self) -> None: - response = await self.client.post("/") - self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER)) + +def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable: + """Show function calls `from black/__init__.py` as they happen. + + Register this with `sys.settrace()` in a test you're debugging. + """ + if event != "call": + return tracefunc + + stack = len(inspect.stack()) - 19 + stack *= 2 + filename = frame.f_code.co_filename + lineno = frame.f_lineno + func_sig_lineno = lineno - 1 + funcname = black_source_lines[func_sig_lineno].strip() + while funcname.startswith("@"): + func_sig_lineno += 1 + funcname = black_source_lines[func_sig_lineno].strip() + if "black/__init__.py" in filename: + print(f"{' ' * stack}{lineno}:{funcname}") + return tracefunc if __name__ == "__main__":