#!/usr/bin/env python3
+import multiprocessing
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import replace
from functools import partial
+import inspect
from io import BytesIO, TextIOWrapper
import os
from pathlib import Path
+from platform import system
import regex as re
import sys
from tempfile import TemporaryDirectory
-from typing import Any, BinaryIO, Dict, Generator, List, Tuple, Iterator, TypeVar
+import types
+from typing import (
+ Any,
+ BinaryIO,
+ Callable,
+ Dict,
+ Generator,
+ List,
+ Tuple,
+ Iterator,
+ TypeVar,
+)
import unittest
from unittest.mock import patch, MagicMock
PROJECT_ROOT = THIS_DIR.parent
DETERMINISTIC_HEADER = "[Deterministic header]"
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
-PY36_ARGS = [
- f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
-]
+# TargetVersion members PY36 through PY39, listed locally; the tests below pass
+# this set as `target_versions` when building a mode.
+PY36_VERSIONS = {
+ TargetVersion.PY36,
+ TargetVersion.PY37,
+ TargetVersion.PY38,
+ TargetVersion.PY39,
+}
+# One `--target-version=<lowercased member name>` flag per member of PY36_VERSIONS.
+PY36_ARGS = [f"--target-version={version.name.lower()}" for version in PY36_VERSIONS]
T = TypeVar("T")
R = TypeVar("R")
class BlackTestCase(unittest.TestCase):
maxDiff = None
+ _diffThreshold = 2 ** 20
def assertFormatEqual(self, expected: str, actual: str) -> None:
if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
list(bdv.visit(exp_node))
except Exception as ve:
black.err(str(ve))
- self.assertEqual(expected, actual)
+ self.assertMultiLineEqual(expected, actual)
def invokeBlack(
self, args: List[str], exit_code: int = 0, ignore_config: bool = True
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, DEFAULT_MODE)
+    @patch("black.dump_to_file", dump_to_stderr)
+    def _test_wip(self) -> None:
+        """Scratch test: format the `wip` data file while tracing calls.
+
+        The name does not start with `test`, so unittest does not collect it
+        by default. `tracefunc` is installed through `sys.settrace()` for the
+        duration of the formatting call only.
+        """
+        source, expected = read_data("wip")
+        sys.settrace(tracefunc)
+        try:
+            mode = replace(
+                DEFAULT_MODE,
+                experimental_string_processing=False,
+                target_versions={black.TargetVersion.PY38},
+            )
+            actual = fs(source, mode=mode)
+        finally:
+            # Uninstall the tracer even when formatting raises; otherwise it
+            # stays active for everything that runs afterwards.
+            sys.settrace(None)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(source, actual)
+        # Check stability under the same mode that produced `actual`.
+        black.assert_stable(source, actual, mode)
+
@patch("black.dump_to_file", dump_to_stderr)
def test_function_trailing_comma(self) -> None:
source, expected = read_data("function_trailing_comma")
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, DEFAULT_MODE)
+    @unittest.expectedFailure
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_trailing_comma_optional_parens_stability1(self) -> None:
+        """Stability check on `trailing_comma_optional_parens1`.
+
+        Marked as an expected failure, so the test passes only while the
+        body keeps failing.
+        """
+        contents, _ = read_data("trailing_comma_optional_parens1")
+        formatted = fs(contents)
+        black.assert_stable(contents, formatted, DEFAULT_MODE)
+
+    @unittest.expectedFailure
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_trailing_comma_optional_parens_stability2(self) -> None:
+        """Stability check on `trailing_comma_optional_parens2`.
+
+        Marked as an expected failure, so the test passes only while the
+        body keeps failing.
+        """
+        contents, _ = read_data("trailing_comma_optional_parens2")
+        formatted = fs(contents)
+        black.assert_stable(contents, formatted, DEFAULT_MODE)
+
+    @unittest.expectedFailure
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_trailing_comma_optional_parens_stability3(self) -> None:
+        """Stability check on `trailing_comma_optional_parens3`.
+
+        Marked as an expected failure, so the test passes only while the
+        body keeps failing.
+        """
+        contents, _ = read_data("trailing_comma_optional_parens3")
+        formatted = fs(contents)
+        black.assert_stable(contents, formatted, DEFAULT_MODE)
+
@patch("black.dump_to_file", dump_to_stderr)
def test_expression(self) -> None:
source, expected = read_data("expression")
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, DEFAULT_MODE)
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_docstring_no_string_normalization(self) -> None:
+        """Like test_docstring, but formatted with `string_normalization=False`."""
+        unnormalized = replace(DEFAULT_MODE, string_normalization=False)
+        source, expected = read_data("docstring_no_string_normalization")
+        formatted = fs(source, mode=unnormalized)
+        self.assertFormatEqual(expected, formatted)
+        black.assert_equivalent(source, formatted)
+        # Stability is checked under the same non-default mode.
+        black.assert_stable(source, formatted, unnormalized)
+
def test_long_strings(self) -> None:
"""Tests for splitting long strings."""
source, expected = read_data("long_strings")
@patch("black.dump_to_file", dump_to_stderr)
def test_comments7(self) -> None:
source, expected = read_data("comments7")
- actual = fs(source)
+ mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
+ actual = fs(source, mode=mode)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, DEFAULT_MODE)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, DEFAULT_MODE)
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_composition_no_trailing_comma(self) -> None:
+        """Format `composition_no_trailing_comma` with a PY38-only target mode."""
+        source, expected = read_data("composition_no_trailing_comma")
+        mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
+        actual = fs(source, mode=mode)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(source, actual)
+        # Re-format with the mode that produced `actual`; checking stability
+        # under a different mode does not test the formatting exercised here.
+        black.assert_stable(source, actual, mode)
+
@patch("black.dump_to_file", dump_to_stderr)
def test_empty_lines(self) -> None:
source, expected = read_data("empty_lines")
@patch("black.dump_to_file", dump_to_stderr)
def test_numeric_literals(self) -> None:
source, expected = read_data("numeric_literals")
- mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
+ mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
actual = fs(source, mode=mode)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
@patch("black.dump_to_file", dump_to_stderr)
def test_numeric_literals_ignoring_underscores(self) -> None:
source, expected = read_data("numeric_literals_skip_underscores")
- mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
+ mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
actual = fs(source, mode=mode)
self.assertFormatEqual(expected, actual)
black.assert_equivalent(source, actual)
black.assert_equivalent(source, actual)
black.assert_stable(source, actual, DEFAULT_MODE)
+ @patch("black.dump_to_file", dump_to_stderr)
+ def test_python39(self) -> None:
+ # Format the `python39` data file and compare with its expected output.
+ source, expected = read_data("python39")
+ actual = fs(source)
+ self.assertFormatEqual(expected, actual)
+ # The condition below is true when the running interpreter is Python 3.9
+ # or newer; it gates at least the equivalence check.
+ # NOTE(review): confirm whether assert_stable is also meant to be gated.
+ major, minor = sys.version_info[:2]
+ if major > 3 or (major == 3 and minor >= 9):
+ black.assert_equivalent(source, actual)
+ black.assert_stable(source, actual, DEFAULT_MODE)
+
@patch("black.dump_to_file", dump_to_stderr)
def test_fmtonoff(self) -> None:
source, expected = read_data("fmtonoff")
black.lib2to3_parse(py3_only, {TargetVersion.PY36})
black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
+    def test_get_features_used_decorator(self) -> None:
+        """Check detection of `Feature.RELAXED_DECORATORS` on the `decorators` data.
+
+        Both halves of the data file are split on `##`. No case from the first
+        half may report the feature; every case from the second half must.
+        """
+        # Decorator detection gets its own test: some cases in
+        # test_get_features_used() also fail when it is broken, and a failure
+        # here points at the cause directly.
+        simples, relaxed = read_data("decorators")
+        # skip explanation comments at the top of the file
+        for simple_test in simples.split("##")[1:]:
+            node = black.lib2to3_parse(simple_test)
+            decorator = str(node.children[0].children[0]).strip()
+            self.assertNotIn(
+                Feature.RELAXED_DECORATORS,
+                black.get_features_used(node),
+                msg=(
+                    f"decorator '{decorator}' follows python<=3.8 syntax"
+                    " but is detected as 3.9+"
+                ),
+            )
+        # skip the '# output' comment at the top of the output part
+        for relaxed_test in relaxed.split("##")[1:]:
+            node = black.lib2to3_parse(relaxed_test)
+            decorator = str(node.children[0].children[0]).strip()
+            self.assertIn(
+                Feature.RELAXED_DECORATORS,
+                black.get_features_used(node),
+                msg=(
+                    f"decorator '{decorator}' uses python3.9+ syntax"
+                    " but is detected as python<=3.8"
+                ),
+            )
+
def test_get_features_used(self) -> None:
node = black.lib2to3_parse("def f(*, arg): ...\n")
self.assertEqual(black.get_features_used(node), set())
src = (workspace / "test.py").resolve()
with src.open("w") as fobj:
fobj.write("print('hello')")
- self.invokeBlack([str(src), "--diff"])
- cache_file = black.get_cache_file(mode)
- self.assertFalse(cache_file.exists())
+ with patch("black.read_cache") as read_cache, patch(
+ "black.write_cache"
+ ) as write_cache:
+ self.invokeBlack([str(src), "--diff"])
+ cache_file = black.get_cache_file(mode)
+ self.assertFalse(cache_file.exists())
+ write_cache.assert_not_called()
+ read_cache.assert_not_called()
+
+    def test_no_cache_when_writeback_color_diff(self) -> None:
+        """Run `--diff --color` on one file and assert the cache is untouched.
+
+        `black.read_cache` and `black.write_cache` are replaced with mocks;
+        neither may be called and no cache file may exist afterwards.
+        """
+        with cache_dir() as workspace:
+            target = (workspace / "test.py").resolve()
+            with target.open("w") as handle:
+                handle.write("print('hello')")
+            reader_patch = patch("black.read_cache")
+            writer_patch = patch("black.write_cache")
+            with reader_patch as read_cache, writer_patch as write_cache:
+                self.invokeBlack([str(target), "--diff", "--color"])
+                self.assertFalse(black.get_cache_file(DEFAULT_MODE).exists())
+                write_cache.assert_not_called()
+                read_cache.assert_not_called()
+
+    @event_loop()
+    def test_output_locking_when_writeback_diff(self) -> None:
+        """Run `--diff` over a directory of four files and expect `black.Manager` to be called."""
+        with cache_dir() as workspace:
+            for index in range(4):
+                script = (workspace / f"test{index}.py").resolve()
+                with script.open("w") as handle:
+                    handle.write("print('hello')")
+            manager_patch = patch("black.Manager", wraps=multiprocessing.Manager)
+            with manager_patch as mgr:
+                self.invokeBlack(["--diff", str(workspace)], exit_code=0)
+                # A weak check only: it does not prove the lock is used, but
+                # if Manager is never called the lock it provides cannot be
+                # in use at all.
+                mgr.assert_called()
+
+    @event_loop()
+    def test_output_locking_when_writeback_color_diff(self) -> None:
+        """Run `--diff --color` over a directory of four files and expect `black.Manager` to be called."""
+        with cache_dir() as workspace:
+            for index in range(4):
+                script = (workspace / f"test{index}.py").resolve()
+                with script.open("w") as handle:
+                    handle.write("print('hello')")
+            manager_patch = patch("black.Manager", wraps=multiprocessing.Manager)
+            with manager_patch as mgr:
+                self.invokeBlack(["--diff", "--color", str(workspace)], exit_code=0)
+                # A weak check only: it does not prove the lock is used, but
+                # if Manager is never called the lock it provides cannot be
+                # in use at all.
+                mgr.assert_called()
def test_no_cache_when_stdin(self) -> None:
mode = DEFAULT_MODE
black.assert_stable(source, actual, DEFAULT_MODE)
def test_single_file_force_pyi(self) -> None:
- reg_mode = DEFAULT_MODE
pyi_mode = replace(DEFAULT_MODE, is_pyi=True)
contents, expected = read_data("force_pyi")
with cache_dir() as workspace:
# verify cache with --pyi is separate
pyi_cache = black.read_cache(pyi_mode)
self.assertIn(path, pyi_cache)
- normal_cache = black.read_cache(reg_mode)
+ normal_cache = black.read_cache(DEFAULT_MODE)
self.assertNotIn(path, normal_cache)
- self.assertEqual(actual, expected)
+ self.assertFormatEqual(expected, actual)
+ black.assert_equivalent(contents, actual)
+ black.assert_stable(contents, actual, pyi_mode)
@event_loop()
def test_multi_file_force_pyi(self) -> None:
def test_single_file_force_py36(self) -> None:
reg_mode = DEFAULT_MODE
- py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
+ py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
source, expected = read_data("force_py36")
with cache_dir() as workspace:
path = (workspace / "file.py").resolve()
@event_loop()
def test_multi_file_force_py36(self) -> None:
reg_mode = DEFAULT_MODE
- py36_mode = replace(DEFAULT_MODE, target_versions=black.PY36_VERSIONS)
+ py36_mode = replace(DEFAULT_MODE, target_versions=PY36_VERSIONS)
source, expected = read_data("force_py36")
with cache_dir() as workspace:
paths = [
self.assertEqual(black.find_project_root((src_dir,)), src_dir.resolve())
self.assertEqual(black.find_project_root((src_python,)), src_dir.resolve())
+    def test_bpo_33660_workaround(self) -> None:
+        """Normalize a relative path while the working directory is `/`.
+
+        See https://bugs.python.org/issue33660. The test changes into the
+        root directory, normalizes `workspace/project` against that root and
+        expects the relative form back. The previous working directory is
+        always restored.
+        """
+        if system() == "Windows":
+            # Report a skip rather than a pass for a test that did not run.
+            self.skipTest("relies on the POSIX root directory '/'")
+
+        old_cwd = Path.cwd()
+        try:
+            root = Path("/")
+            os.chdir(str(root))
+            path = Path("workspace") / "project"
+            report = black.Report(verbose=True)
+            normalized_path = black.normalize_path_maybe_ignore(path, root, report)
+            self.assertEqual(normalized_path, "workspace/project")
+        finally:
+            os.chdir(str(old_cwd))
+
class BlackDTestCase(AioHTTPTestCase):
async def get_application(self) -> web.Application:
self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
+# Read the source file of the imported `black` module once, when this test
+# module is loaded; `tracefunc` below indexes into these lines by line number.
+with open(black.__file__, "r", encoding="utf-8") as _bf:
+ black_source_lines = _bf.readlines()
+
+
+def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
+    """Show function calls `from black/__init__.py` as they happen.
+
+    Register this with `sys.settrace()` in a test you're debugging.
+
+    For each "call" event whose code object comes from `black/__init__.py`,
+    print the frame's line number followed by the first source line at or
+    after it that does not start with `@`, indented according to the current
+    stack depth. Every other event or file is ignored. The function always
+    returns itself.
+    """
+    if event != "call":
+        return tracefunc
+
+    filename = frame.f_code.co_filename
+    if "black/__init__.py" not in filename:
+        # Line numbers of frames from other files do not index into
+        # `black_source_lines` (and may run past its end), so bail out before
+        # the lookup and before the `inspect.stack()` walk.
+        return tracefunc
+
+    # NOTE(review): 19 is presumably the stack depth of the surrounding test
+    # harness, subtracted so the output starts near column 0 -- confirm.
+    stack = len(inspect.stack()) - 19
+    stack *= 2
+    lineno = frame.f_lineno
+    func_sig_lineno = lineno - 1
+    funcname = black_source_lines[func_sig_lineno].strip()
+    # Step over decorator lines to reach the line shown as the signature.
+    while funcname.startswith("@"):
+        func_sig_lineno += 1
+        funcname = black_source_lines[func_sig_lineno].strip()
+    print(f"{' ' * stack}{lineno}:{funcname}")
+    return tracefunc
+
+
if __name__ == "__main__":
unittest.main(module="test_black")