X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/591bedc2be0cec92c5f253fd473864c876233114..a9f50cd0b58259a11a1c851bde7b4f11321e5b3b:/tests/test_black.py

diff --git a/tests/test_black.py b/tests/test_black.py
index ec7a883..759bda5 100644
--- a/tests/test_black.py
+++ b/tests/test_black.py
@@ -17,6 +17,7 @@ ff = partial(black.format_file_in_place, line_length=ll, fast=True)
 fs = partial(black.format_str, line_length=ll)
 THIS_FILE = Path(__file__)
 THIS_DIR = THIS_FILE.parent
+EMPTY_LINE = '# EMPTY LINE WITH WHITESPACE' + ' (this comment will be removed)'
 
 
 def dump_to_stderr(*output: str) -> str:
@@ -25,7 +26,7 @@ def dump_to_stderr(*output: str) -> str:
 
 def read_data(name: str) -> Tuple[str, str]:
     """read_data('test_name') -> 'input', 'output'"""
-    if not name.endswith('.py'):
+    if not name.endswith(('.py', '.out')):
         name += '.py'
     _input: List[str] = []
     _output: List[str] = []
@@ -33,6 +34,7 @@ def read_data(name: str) -> Tuple[str, str]:
         lines = test.readlines()
     result = _input
     for line in lines:
+        line = line.replace(EMPTY_LINE, '')
         if line.rstrip() == '# output':
             result = _output
             continue
@@ -180,6 +182,22 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, line_length=ll)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_python2(self) -> None:
+        source, expected = read_data('python2')
+        actual = fs(source)
+        self.assertFormatEqual(expected, actual)
+        # black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, line_length=ll)
+
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_fmtonoff(self) -> None:
+        source, expected = read_data('fmtonoff')
+        actual = fs(source)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, line_length=ll)
+
     def test_report(self) -> None:
         report = black.Report()
         out_lines = []
@@ -205,7 +223,10 @@ class BlackTestCase(unittest.TestCase):
             self.assertEqual(
                 unstyle(str(report)), '1 file reformatted, 1 file left unchanged.'
             )
+            self.assertEqual(report.return_code, 0)
+            report.check = True
             self.assertEqual(report.return_code, 1)
+            report.check = False
             report.failed(Path('e1'), 'boom')
             self.assertEqual(len(out_lines), 2)
             self.assertEqual(len(err_lines), 1)
@@ -246,6 +267,12 @@ class BlackTestCase(unittest.TestCase):
                 '2 files failed to reformat.',
             )
             self.assertEqual(report.return_code, 123)
+            report.check = True
+            self.assertEqual(
+                unstyle(str(report)),
+                '2 files would be reformatted, 2 files would be left unchanged, '
+                '2 files would fail to reformat.',
+            )
 
     def test_is_python36(self) -> None:
         node = black.lib2to3_parse("def f(*, arg): ...\n")
@@ -265,6 +292,30 @@ class BlackTestCase(unittest.TestCase):
         node = black.lib2to3_parse(expected)
         self.assertFalse(black.is_python36(node))
 
+    def test_debug_visitor(self) -> None:
+        source, _ = read_data('debug_visitor.py')
+        expected, _ = read_data('debug_visitor.out')
+        out_lines = []
+        err_lines = []
+
+        def out(msg: str, **kwargs: Any) -> None:
+            out_lines.append(msg)
+
+        def err(msg: str, **kwargs: Any) -> None:
+            err_lines.append(msg)
+
+        with patch("black.out", out), patch("black.err", err):
+            black.DebugVisitor.show(source)
+        actual = '\n'.join(out_lines) + '\n'
+        log_name = ''
+        if expected != actual:
+            log_name = black.dump_to_file(*out_lines)
+        self.assertEqual(
+            expected,
+            actual,
+            f"AST print out is different. Actual version dumped to {log_name}",
+        )
+
 
 if __name__ == '__main__':
     unittest.main()