git.madduck.net Git - etc/vim.git/blobdiff - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Run trailing comma tests with TargetVersion.PY38
[etc/vim.git] / tests / test_black.py
index 686232a7f9c910677ec2703d06026018c2fce6c8..16002c0b728c9c82bad30726ceb9820ff1a2ee8d 100644 (file)
@@ -5,13 +5,25 @@ from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from dataclasses import replace
 from functools import partial
 from contextlib import contextmanager
 from dataclasses import replace
 from functools import partial
+import inspect
 from io import BytesIO, TextIOWrapper
 import os
 from pathlib import Path
 import regex as re
 import sys
 from tempfile import TemporaryDirectory
 from io import BytesIO, TextIOWrapper
 import os
 from pathlib import Path
 import regex as re
 import sys
 from tempfile import TemporaryDirectory
-from typing import Any, BinaryIO, Dict, Generator, List, Tuple, Iterator, TypeVar
+import types
+from typing import (
+    Any,
+    BinaryIO,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Tuple,
+    Iterator,
+    TypeVar,
+)
 import unittest
 from unittest.mock import patch, MagicMock
 
 import unittest
 from unittest.mock import patch, MagicMock
 
@@ -153,6 +165,7 @@ class BlackRunner(CliRunner):
 
 class BlackTestCase(unittest.TestCase):
     maxDiff = None
 
 class BlackTestCase(unittest.TestCase):
     maxDiff = None
+    _diffThreshold = 2 ** 20
 
     def assertFormatEqual(self, expected: str, actual: str) -> None:
         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
 
     def assertFormatEqual(self, expected: str, actual: str) -> None:
         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
@@ -171,7 +184,7 @@ class BlackTestCase(unittest.TestCase):
                 list(bdv.visit(exp_node))
             except Exception as ve:
                 black.err(str(ve))
                 list(bdv.visit(exp_node))
             except Exception as ve:
                 black.err(str(ve))
-        self.assertEqual(expected, actual)
+        self.assertMultiLineEqual(expected, actual)
 
     def invokeBlack(
         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
 
     def invokeBlack(
         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
@@ -332,6 +345,16 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, DEFAULT_MODE)
 
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, DEFAULT_MODE)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_function_trailing_comma_wip(self) -> None:
+        source, expected = read_data("function_trailing_comma_wip")
+        # sys.settrace(tracefunc)
+        actual = fs(source)
+        # sys.settrace(None)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, black.FileMode())
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_function_trailing_comma(self) -> None:
         source, expected = read_data("function_trailing_comma")
     @patch("black.dump_to_file", dump_to_stderr)
     def test_function_trailing_comma(self) -> None:
         source, expected = read_data("function_trailing_comma")
@@ -570,7 +593,8 @@ class BlackTestCase(unittest.TestCase):
     @patch("black.dump_to_file", dump_to_stderr)
     def test_comments7(self) -> None:
         source, expected = read_data("comments7")
     @patch("black.dump_to_file", dump_to_stderr)
     def test_comments7(self) -> None:
         source, expected = read_data("comments7")
-        actual = fs(source)
+        mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
+        actual = fs(source, mode=mode)
         self.assertFormatEqual(expected, actual)
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, DEFAULT_MODE)
         self.assertFormatEqual(expected, actual)
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, DEFAULT_MODE)
@@ -607,6 +631,15 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, DEFAULT_MODE)
 
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, DEFAULT_MODE)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_composition_no_trailing_comma(self) -> None:
+        source, expected = read_data("composition_no_trailing_comma")
+        mode = replace(DEFAULT_MODE, target_versions={black.TargetVersion.PY38})
+        actual = fs(source, mode=mode)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, DEFAULT_MODE)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_empty_lines(self) -> None:
         source, expected = read_data("empty_lines")
     @patch("black.dump_to_file", dump_to_stderr)
     def test_empty_lines(self) -> None:
         source, expected = read_data("empty_lines")
@@ -2039,5 +2072,30 @@ class BlackDTestCase(AioHTTPTestCase):
         self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
 
 
         self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
 
 
+with open(black.__file__, "r", encoding="utf-8") as _bf:
+    black_source_lines = _bf.readlines()
+
+
+def tracefunc(frame: types.FrameType, event: str, arg: Any) -> Callable:
+    """Show function calls `from black/__init__.py` as they happen.
+
+    Register this with `sys.settrace()` in a test you're debugging.
+    """
+    if event != "call":
+        return tracefunc
+
+    stack = len(inspect.stack()) - 19
+    filename = frame.f_code.co_filename
+    lineno = frame.f_lineno
+    func_sig_lineno = lineno - 1
+    funcname = black_source_lines[func_sig_lineno].strip()
+    while funcname.startswith("@"):
+        func_sig_lineno += 1
+        funcname = black_source_lines[func_sig_lineno].strip()
+    if "black/__init__.py" in filename:
+        print(f"{' ' * stack}{lineno}:{funcname}")
+    return tracefunc
+
+
 if __name__ == "__main__":
     unittest.main(module="test_black")
 if __name__ == "__main__":
     unittest.main(module="test_black")