X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/3500e1cda5bef73ddc7eaf79be6c67c918738936..bf7a16254ec96b084a6caf3d435ec18f0f245cc7:/tests/util.py

diff --git a/tests/util.py b/tests/util.py
index 84e98bb..967d576 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -2,6 +2,7 @@ import os
 import sys
 import unittest
 from contextlib import contextmanager
+from dataclasses import replace
 from functools import partial
 from pathlib import Path
 from typing import Any, Iterator, List, Optional, Tuple
@@ -9,7 +10,10 @@ from typing import Any, Iterator, List, Optional, Tuple
 import black
 from black.debug import DebugVisitor
 from black.mode import TargetVersion
-from black.output import err, out
+from black.output import diff, err, out
+
+PYTHON_SUFFIX = ".py"  # default suffix for case names without an allowed one
+ALLOWED_SUFFIXES = (PYTHON_SUFFIX, ".pyi", ".out", ".diff", ".ipynb")  # kept as-is
 
 THIS_DIR = Path(__file__).parent
 DATA_DIR = THIS_DIR / "data"
@@ -47,9 +51,16 @@ def _assert_format_equal(expected: str, actual: str) -> None:
         except Exception as ve:
             err(str(ve))
 
+    if actual != expected:
+        out(diff(expected, actual, "expected", "actual"))
+
     assert actual == expected
 
 
+class FormatFailure(Exception):
+    """Raised by assert_format() when a case fails in one of its extra modes."""
+
+
 def assert_format(
     source: str,
     expected: str,
@@ -64,12 +75,65 @@ def assert_format(
     safety guards so they don't just crash with a SyntaxError. Please note this is
-    separate from TargetVerson Mode configuration.
+    separate from TargetVersion Mode configuration.
+
+    Also formats *source* with preview toggled, and with line length 1 in
+    non-preview mode, without comparing the output; any exception raised by
+    those extra runs is wrapped in a FormatFailure.
     """
+    _assert_format_inner(
+        source, expected, mode, fast=fast, minimum_version=minimum_version
+    )
+
+    # For both preview and non-preview tests, ensure that Black doesn't crash on
+    # this code, but don't pass "expected" because the precise output may differ.
+    try:
+        _assert_format_inner(
+            source,
+            None,
+            replace(mode, preview=not mode.preview),
+            fast=fast,
+            minimum_version=minimum_version,
+        )
+    except Exception as e:
+        text = "non-preview" if mode.preview else "preview"
+        raise FormatFailure(
+            f"Black crashed formatting this case in {text} mode."
+        ) from e
+    # Similarly, setting line length to 1 is a good way to catch
+    # stability bugs. But only in non-preview mode because preview mode
+    # currently has a lot of line length 1 bugs.
+    try:
+        _assert_format_inner(
+            source,
+            None,
+            replace(mode, preview=False, line_length=1),
+            fast=fast,
+            minimum_version=minimum_version,
+        )
+    except Exception as e:
+        raise FormatFailure(
+            "Black crashed formatting this case with line-length set to 1."
+        ) from e
+
+
+def _assert_format_inner(
+    source: str,
+    expected: Optional[str] = None,
+    mode: black.Mode = DEFAULT_MODE,
+    *,
+    fast: bool = False,
+    minimum_version: Optional[Tuple[int, int]] = None,
+) -> None:
+    """Format *source* in *mode* and check the result.
+
+    If *expected* is not None, the output must equal it. Unless *fast* is set,
+    safety checks also run whenever formatting changed the source.
+    """
     actual = black.format_str(source, mode=mode)
-    _assert_format_equal(expected, actual)
+    if expected is not None:
+        _assert_format_equal(expected, actual)
-    # It's not useful to run safety checks if we're expecting no changes anyway. The
-    # assertion right above will raise if reality does actually make changes. This just
-    # avoids wasted CPU cycles.
+    # Safety checks only matter when formatting changed something: if the output
+    # equals the input there is nothing to verify, so skip them to save CPU cycles.
-    if not fast and source != expected:
+    if not fast and source != actual:
         # Unfortunately the AST equivalence check relies on the built-in ast module
         # being able to parse the code being formatted. This doesn't always work out
         # when checking modern code on older versions.
@@ -87,12 +143,40 @@ class BlackBaseTestCase(unittest.TestCase):
         _assert_format_equal(expected, actual)
 
 
-def read_data(name: str, data: bool = True) -> Tuple[str, str]:
+def get_base_dir(data: bool) -> Path:
+    """Return DATA_DIR if *data* is true, otherwise PROJECT_ROOT."""
+    return DATA_DIR if data else PROJECT_ROOT
+
+
+def all_data_cases(subdir_name: str, data: bool = True) -> List[str]:
+    """Return the sorted stems of the case files in *subdir_name*."""
+    cases_dir = get_base_dir(data) / subdir_name
+    assert cases_dir.is_dir(), f"{cases_dir} is not a directory."
+    # Sort so that test collection order (and parametrize ids) is deterministic.
+    return sorted(
+        case_path.stem for case_path in cases_dir.iterdir() if case_path.is_file()
+    )
+
+
+def get_case_path(
+    subdir_name: str, name: str, data: bool = True, suffix: str = PYTHON_SUFFIX
+) -> Path:
+    """Return the path of test case *name* inside *subdir_name*.
+
+    *suffix* is appended unless *name* already ends with one of ALLOWED_SUFFIXES.
+    """
+    if not name.endswith(ALLOWED_SUFFIXES):
+        # Append instead of using Path.with_suffix(), which would replace any
+        # existing dotted part of the name (e.g. "foo.bar" -> "foo.py").
+        name += suffix
+    case_path = get_base_dir(data) / subdir_name / name
+    assert case_path.is_file(), f"{case_path} is not a file."
+    return case_path
+
+
+def read_data(subdir_name: str, name: str, data: bool = True) -> Tuple[str, str]:
-    """read_data('test_name') -> 'input', 'output'"""
+    """read_data('subdir', 'test_name') -> 'input', 'output'"""
-    if not name.endswith((".py", ".pyi", ".out", ".diff")):
-        name += ".py"
-    base_dir = DATA_DIR if data else PROJECT_ROOT
-    return read_data_from_file(base_dir / name)
+    return read_data_from_file(get_case_path(subdir_name, name, data))
 
 
 def read_data_from_file(file_name: Path) -> Tuple[str, str]:
@@ -114,6 +188,26 @@ def read_data_from_file(file_name: Path) -> Tuple[str, str]:
     return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
 
 
+def read_jupyter_notebook(subdir_name: str, name: str, data: bool = True) -> str:
+    """Return the text of notebook case *name* from *subdir_name*.
+
+    The case is looked up via get_case_path() with ".ipynb" as default suffix.
+    """
+    return read_jupyter_notebook_from_file(
+        get_case_path(subdir_name, name, data, suffix=".ipynb")
+    )
+
+
+def read_jupyter_notebook_from_file(file_name: Path) -> str:
+    """Return the contents of *file_name* decoded as UTF-8.
+
+    The file is read in binary mode so its newlines are preserved as-is.
+    """
+    with open(file_name, mode="rb") as fd:
+        content_bytes = fd.read()
+    return content_bytes.decode("utf-8")
+
+
 @contextmanager
 def change_directory(path: Path) -> Iterator[None]:
     """Context manager to temporarily chdir to a different directory."""