X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/cb5aadad74c0a1c9c514a633c632c99b668c70ed..fef8c71cb708b5b94047e69baeca3264d695ac66:/tests/test_black.py?ds=sidebyside

diff --git a/tests/test_black.py b/tests/test_black.py
index f71f9b3..82e3f5a 100644
--- a/tests/test_black.py
+++ b/tests/test_black.py
@@ -1,14 +1,19 @@
 #!/usr/bin/env python3
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager
 from functools import partial
 from io import StringIO
 import os
 from pathlib import Path
 import sys
-from typing import Any, List, Tuple
+from tempfile import TemporaryDirectory
+from typing import Any, List, Tuple, Iterator
 import unittest
 from unittest.mock import patch
 
 from click import unstyle
+from click.testing import CliRunner
 
 import black
 
@@ -26,7 +31,7 @@ def dump_to_stderr(*output: str) -> str:
 
 def read_data(name: str) -> Tuple[str, str]:
     """read_data('test_name') -> 'input', 'output'"""
-    if not name.endswith((".py", ".out", ".diff")):
+    if not name.endswith((".py", ".pyi", ".out", ".diff")):
         name += ".py"
     _input: List[str] = []
     _output: List[str] = []
@@ -46,6 +51,31 @@ def read_data(name: str) -> Tuple[str, str]:
     return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
 
 
+@contextmanager
+def cache_dir(exists: bool = True) -> Iterator[Path]:
+    with TemporaryDirectory() as workspace:
+        cache_dir = Path(workspace)
+        if not exists:
+            cache_dir = cache_dir / "new"
+        with patch("black.CACHE_DIR", cache_dir):
+            yield cache_dir
+
+
+@contextmanager
+def event_loop(close: bool) -> Iterator[None]:
+    policy = asyncio.get_event_loop_policy()
+    old_loop = policy.get_event_loop()
+    loop = policy.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        yield
+
+    finally:
+        policy.set_event_loop(old_loop)
+        if close:
+            loop.close()
+
+
 class BlackTestCase(unittest.TestCase):
     maxDiff = None
 
@@ -137,6 +167,14 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, line_length=ll)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_function2(self) -> None:
+        source, expected = read_data("function2")
+        actual = fs(source)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, line_length=ll)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_expression(self) -> None:
         source, expected = read_data("expression")
@@ -150,7 +188,7 @@ class BlackTestCase(unittest.TestCase):
         tmp_file = Path(black.dump_to_file(source))
         try:
             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
-            with open(tmp_file) as f:
+            with open(tmp_file, encoding="utf8") as f:
                 actual = f.read()
         finally:
             os.unlink(tmp_file)
@@ -169,7 +207,7 @@ class BlackTestCase(unittest.TestCase):
             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
             sys.stdout.seek(0)
             actual = sys.stdout.read()
-            actual = actual.replace(tmp_file.name, "<stdin>")
+            actual = actual.replace(str(tmp_file), "<stdin>")
         finally:
             sys.stdout = hold_stdout
             os.unlink(tmp_file)
@@ -179,7 +217,7 @@ class BlackTestCase(unittest.TestCase):
             msg = (
                 f"Expected diff isn't equal to the actual. If you made changes "
                 f"to expression.py and this is an anticipated difference, "
-                f"overwrite tests/expression.diff with {dump}."
+                f"overwrite tests/expression.diff with {dump}"
             )
             self.assertEqual(expected, actual, msg)
 
@@ -199,6 +237,14 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, line_length=ll)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_slices(self) -> None:
+        source, expected = read_data("slices")
+        actual = fs(source)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, line_length=ll)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_comments(self) -> None:
         source, expected = read_data("comments")
@@ -231,6 +277,14 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, line_length=ll)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_comments5(self) -> None:
+        source, expected = read_data("comments5")
+        actual = fs(source)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, line_length=ll)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_cantfit(self) -> None:
         source, expected = read_data("cantfit")
@@ -263,6 +317,14 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, line_length=ll)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_string_prefixes(self) -> None:
+        source, expected = read_data("string_prefixes")
+        actual = fs(source)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, line_length=ll)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_python2(self) -> None:
         source, expected = read_data("python2")
@@ -271,6 +333,20 @@ class BlackTestCase(unittest.TestCase):
         # black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, line_length=ll)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_python2_unicode_literals(self) -> None:
+        source, expected = read_data("python2_unicode_literals")
+        actual = fs(source)
+        self.assertFormatEqual(expected, actual)
+        black.assert_stable(source, actual, line_length=ll)
+
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_stub(self) -> None:
+        source, expected = read_data("stub.pyi")
+        actual = fs(source, is_pyi=True)
+        self.assertFormatEqual(expected, actual)
+        black.assert_stable(source, actual, line_length=ll, is_pyi=True)
+
     @patch("black.dump_to_file", dump_to_stderr)
     def test_fmtonoff(self) -> None:
         source, expected = read_data("fmtonoff")
@@ -279,6 +355,14 @@ class BlackTestCase(unittest.TestCase):
         black.assert_equivalent(source, actual)
         black.assert_stable(source, actual, line_length=ll)
 
+    @patch("black.dump_to_file", dump_to_stderr)
+    def test_remove_empty_parentheses_after_class(self) -> None:
+        source, expected = read_data("class_blank_parentheses")
+        actual = fs(source)
+        self.assertFormatEqual(expected, actual)
+        black.assert_equivalent(source, actual)
+        black.assert_stable(source, actual, line_length=ll)
+
     def test_report(self) -> None:
         report = black.Report()
         out_lines = []
@@ -291,67 +375,76 @@ class BlackTestCase(unittest.TestCase):
             err_lines.append(msg)
 
         with patch("black.out", out), patch("black.err", err):
-            report.done(Path("f1"), changed=False)
+            report.done(Path("f1"), black.Changed.NO)
             self.assertEqual(len(out_lines), 1)
             self.assertEqual(len(err_lines), 0)
             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
             self.assertEqual(report.return_code, 0)
-            report.done(Path("f2"), changed=True)
+            report.done(Path("f2"), black.Changed.YES)
             self.assertEqual(len(out_lines), 2)
             self.assertEqual(len(err_lines), 0)
             self.assertEqual(out_lines[-1], "reformatted f2")
             self.assertEqual(
                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
             )
+            report.done(Path("f3"), black.Changed.CACHED)
+            self.assertEqual(len(out_lines), 3)
+            self.assertEqual(len(err_lines), 0)
+            self.assertEqual(
+                out_lines[-1], "f3 wasn't modified on disk since last run."
+            )
+            self.assertEqual(
+                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
+            )
             self.assertEqual(report.return_code, 0)
             report.check = True
             self.assertEqual(report.return_code, 1)
             report.check = False
             report.failed(Path("e1"), "boom")
-            self.assertEqual(len(out_lines), 2)
+            self.assertEqual(len(out_lines), 3)
             self.assertEqual(len(err_lines), 1)
             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
             self.assertEqual(
                 unstyle(str(report)),
-                "1 file reformatted, 1 file left unchanged, "
+                "1 file reformatted, 2 files left unchanged, "
                 "1 file failed to reformat.",
             )
             self.assertEqual(report.return_code, 123)
-            report.done(Path("f3"), changed=True)
-            self.assertEqual(len(out_lines), 3)
+            report.done(Path("f3"), black.Changed.YES)
+            self.assertEqual(len(out_lines), 4)
             self.assertEqual(len(err_lines), 1)
             self.assertEqual(out_lines[-1], "reformatted f3")
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 1 file left unchanged, "
+                "2 files reformatted, 2 files left unchanged, "
                 "1 file failed to reformat.",
             )
             self.assertEqual(report.return_code, 123)
             report.failed(Path("e2"), "boom")
-            self.assertEqual(len(out_lines), 3)
+            self.assertEqual(len(out_lines), 4)
             self.assertEqual(len(err_lines), 2)
             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 1 file left unchanged, "
+                "2 files reformatted, 2 files left unchanged, "
                 "2 files failed to reformat.",
             )
             self.assertEqual(report.return_code, 123)
-            report.done(Path("f4"), changed=False)
-            self.assertEqual(len(out_lines), 4)
+            report.done(Path("f4"), black.Changed.NO)
+            self.assertEqual(len(out_lines), 5)
             self.assertEqual(len(err_lines), 2)
             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files reformatted, 2 files left unchanged, "
+                "2 files reformatted, 3 files left unchanged, "
                 "2 files failed to reformat.",
             )
             self.assertEqual(report.return_code, 123)
             report.check = True
             self.assertEqual(
                 unstyle(str(report)),
-                "2 files would be reformatted, 2 files would be left unchanged, "
+                "2 files would be reformatted, 3 files would be left unchanged, "
                 "2 files would fail to reformat.",
             )
 
@@ -373,6 +466,28 @@ class BlackTestCase(unittest.TestCase):
         node = black.lib2to3_parse(expected)
         self.assertFalse(black.is_python36(node))
 
+    def test_get_future_imports(self) -> None:
+        node = black.lib2to3_parse("\n")
+        self.assertEqual(set(), black.get_future_imports(node))
+        node = black.lib2to3_parse("from __future__ import black\n")
+        self.assertEqual({"black"}, black.get_future_imports(node))
+        node = black.lib2to3_parse("from __future__ import multiple, imports\n")
+        self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
+        node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
+        self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
+        node = black.lib2to3_parse(
+            "from __future__ import multiple\nfrom __future__ import imports\n"
+        )
+        self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
+        node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
+        self.assertEqual({"black"}, black.get_future_imports(node))
+        node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
+        self.assertEqual({"black"}, black.get_future_imports(node))
+        node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
+        self.assertEqual(set(), black.get_future_imports(node))
+        node = black.lib2to3_parse("from some.module import black\n")
+        self.assertEqual(set(), black.get_future_imports(node))
+
     def test_debug_visitor(self) -> None:
         source, _ = read_data("debug_visitor.py")
         expected, _ = read_data("debug_visitor.out")
@@ -442,6 +557,168 @@ class BlackTestCase(unittest.TestCase):
         self.assertTrue("Actual tree:" in out_str)
         self.assertEqual("".join(err_lines), "")
 
+    def test_cache_broken_file(self) -> None:
+        with cache_dir() as workspace:
+            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
+            with cache_file.open("w") as fobj:
+                fobj.write("this is not a pickle")
+            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
+            src = (workspace / "test.py").resolve()
+            with src.open("w") as fobj:
+                fobj.write("print('hello')")
+            result = CliRunner().invoke(black.main, [str(src)])
+            self.assertEqual(result.exit_code, 0)
+            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
+            self.assertIn(src, cache)
+
+    def test_cache_single_file_already_cached(self) -> None:
+        with cache_dir() as workspace:
+            src = (workspace / "test.py").resolve()
+            with src.open("w") as fobj:
+                fobj.write("print('hello')")
+            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
+            result = CliRunner().invoke(black.main, [str(src)])
+            self.assertEqual(result.exit_code, 0)
+            with src.open("r") as fobj:
+                self.assertEqual(fobj.read(), "print('hello')")
+
+    @event_loop(close=False)
+    def test_cache_multiple_files(self) -> None:
+        with cache_dir() as workspace, patch(
+            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
+        ):
+            one = (workspace / "one.py").resolve()
+            with one.open("w") as fobj:
+                fobj.write("print('hello')")
+            two = (workspace / "two.py").resolve()
+            with two.open("w") as fobj:
+                fobj.write("print('hello')")
+            black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH)
+            result = CliRunner().invoke(black.main, [str(workspace)])
+            self.assertEqual(result.exit_code, 0)
+            with one.open("r") as fobj:
+                self.assertEqual(fobj.read(), "print('hello')")
+            with two.open("r") as fobj:
+                self.assertEqual(fobj.read(), 'print("hello")\n')
+            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
+            self.assertIn(one, cache)
+            self.assertIn(two, cache)
+
+    def test_no_cache_when_writeback_diff(self) -> None:
+        with cache_dir() as workspace:
+            src = (workspace / "test.py").resolve()
+            with src.open("w") as fobj:
+                fobj.write("print('hello')")
+            result = CliRunner().invoke(black.main, [str(src), "--diff"])
+            self.assertEqual(result.exit_code, 0)
+            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
+            self.assertFalse(cache_file.exists())
+
+    def test_no_cache_when_stdin(self) -> None:
+        with cache_dir():
+            result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
+            self.assertEqual(result.exit_code, 0)
+            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
+            self.assertFalse(cache_file.exists())
+
+    def test_read_cache_no_cachefile(self) -> None:
+        with cache_dir():
+            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
+
+    def test_write_cache_read_cache(self) -> None:
+        with cache_dir() as workspace:
+            src = (workspace / "test.py").resolve()
+            src.touch()
+            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
+            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
+            self.assertIn(src, cache)
+            self.assertEqual(cache[src], black.get_cache_info(src))
+
+    def test_filter_cached(self) -> None:
+        with TemporaryDirectory() as workspace:
+            path = Path(workspace)
+            uncached = (path / "uncached").resolve()
+            cached = (path / "cached").resolve()
+            cached_but_changed = (path / "changed").resolve()
+            uncached.touch()
+            cached.touch()
+            cached_but_changed.touch()
+            cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
+            todo, done = black.filter_cached(
+                cache, [uncached, cached, cached_but_changed]
+            )
+            self.assertEqual(todo, [uncached, cached_but_changed])
+            self.assertEqual(done, [cached])
+
+    def test_write_cache_creates_directory_if_needed(self) -> None:
+        with cache_dir(exists=False) as workspace:
+            self.assertFalse(workspace.exists())
+            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
+            self.assertTrue(workspace.exists())
+
+    @event_loop(close=False)
+    def test_failed_formatting_does_not_get_cached(self) -> None:
+        with cache_dir() as workspace, patch(
+            "black.ProcessPoolExecutor", new=ThreadPoolExecutor
+        ):
+            failing = (workspace / "failing.py").resolve()
+            with failing.open("w") as fobj:
+                fobj.write("not actually python")
+            clean = (workspace / "clean.py").resolve()
+            with clean.open("w") as fobj:
+                fobj.write('print("hello")\n')
+            result = CliRunner().invoke(black.main, [str(workspace)])
+            self.assertEqual(result.exit_code, 123)
+            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
+            self.assertNotIn(failing, cache)
+            self.assertIn(clean, cache)
+
+    def test_write_cache_write_fail(self) -> None:
+        with cache_dir(), patch.object(Path, "open") as mock:
+            mock.side_effect = OSError
+            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
+
+    def test_check_diff_use_together(self) -> None:
+        with cache_dir():
+            # Files which will be reformatted.
+            src1 = (THIS_DIR / "string_quotes.py").resolve()
+            result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
+            self.assertEqual(result.exit_code, 1)
+
+            # Files which will not be reformatted.
+            src2 = (THIS_DIR / "composition.py").resolve()
+            result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
+            self.assertEqual(result.exit_code, 0)
+
+            # Multi file command.
+            result = CliRunner().invoke(
+                black.main, [str(src1), str(src2), "--diff", "--check"]
+            )
+            self.assertEqual(result.exit_code, 1)
+
+    def test_no_files(self) -> None:
+        with cache_dir():
+            # Without an argument, black exits with error code 0.
+            result = CliRunner().invoke(black.main, [])
+            self.assertEqual(result.exit_code, 0)
+
+    def test_broken_symlink(self) -> None:
+        with cache_dir() as workspace:
+            symlink = workspace / "broken_link.py"
+            symlink.symlink_to("nonexistent.py")
+            result = CliRunner().invoke(black.main, [str(workspace.resolve())])
+            self.assertEqual(result.exit_code, 0)
+
+    def test_read_cache_line_lengths(self) -> None:
+        with cache_dir() as workspace:
+            path = (workspace / "file.py").resolve()
+            path.touch()
+            black.write_cache({}, [path], 1)
+            one = black.read_cache(1)
+            self.assertIn(path, one)
+            two = black.read_cache(2)
+            self.assertNotIn(path, two)
+
 
 if __name__ == "__main__":
     unittest.main()