import asyncio
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
-from functools import partial
+from functools import partial, wraps
from io import BytesIO, TextIOWrapper
import os
from pathlib import Path
import re
import sys
from tempfile import TemporaryDirectory
-from typing import Any, BinaryIO, Generator, List, Tuple, Iterator
+from typing import (
+ Any,
+ BinaryIO,
+ Callable,
+ Coroutine,
+ Generator,
+ List,
+ Tuple,
+ Iterator,
+ TypeVar,
+)
import unittest
from unittest.mock import patch, MagicMock
import black
+try:
+ import blackd
+ from aiohttp.test_utils import TestClient, TestServer
+except ImportError:
+ has_blackd_deps = False
+else:
+ has_blackd_deps = True
+
ll = 88
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
+# Module-level type variables for generic helper signatures.
+# ``R`` is the result type of the coroutine function accepted by ``async_test``.
+# NOTE(review): ``T`` is not referenced in this excerpt; presumably used elsewhere
+# in the module -- verify, and drop it if it is unused.
+T = TypeVar("T")
+R = TypeVar("R")
def dump_to_stderr(*output: str) -> str:
loop.close()
+def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
+    """Turn the coroutine function *f* into a synchronous test callable.
+
+    The returned callable drives ``f(*args, **kwargs)`` to completion with
+    ``run_until_complete`` on the current event loop and discards the result.
+    It runs inside ``event_loop(close=True)`` (defined elsewhere in this
+    module), which presumably installs a fresh loop per call and closes it
+    afterwards -- confirm against ``event_loop``.  ``wraps`` copies *f*'s
+    metadata (name, docstring) onto the returned callable.
+    """
+
+    @event_loop(close=True)
+    @wraps(f)
+    def run_coroutine(*args: Any, **kwargs: Any) -> None:
+        # Look up the loop first, then build the coroutine, matching the
+        # evaluation order of the original one-liner.
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(f(*args, **kwargs))
+
+    return run_coroutine
+
+
class BlackRunner(CliRunner):
"""Modify CliRunner so that stderr is not merged with stdout.
sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
yield output
finally:
- self.stdout_bytes = sys.stdout.buffer.getvalue()
- self.stderr_bytes = sys.stderr.buffer.getvalue()
+ self.stdout_bytes = sys.stdout.buffer.getvalue() # type: ignore
+ self.stderr_bytes = sys.stderr.buffer.getvalue() # type: ignore
sys.stderr = hold_stderr
actual = black.format_file_contents(different, line_length=ll, fast=False)
self.assertEqual(expected, actual)
invalid = "return if you can"
- with self.assertRaises(ValueError) as e:
+ with self.assertRaises(black.InvalidInput) as e:
black.format_file_contents(invalid, line_length=ll, fast=False)
self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
result = runner.invoke(
black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
)
+ self.assertEqual(result.exit_code, 0)
output = runner.stdout_bytes
self.assertIn(nl.encode("utf8"), output)
if nl == "\n":
except RuntimeError as re:
self.fail(f"`patch_click()` failed, exception still raised: {re}")
+    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
+    @async_test
+    async def test_blackd_request_needs_formatting(self) -> None:
+        """Unformatted input comes back reformatted with status 200 and utf8."""
+        unformatted = b"print('hello world')"
+        async with TestClient(TestServer(blackd.make_app())) as client:
+            resp = await client.post("/", data=unformatted)
+            self.assertEqual(resp.status, 200)
+            self.assertEqual(resp.charset, "utf8")
+            body = await resp.read()
+            self.assertEqual(body, b'print("hello world")\n')
+
+    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
+    @async_test
+    async def test_blackd_request_no_change(self) -> None:
+        """Already-formatted input yields 204 with an empty body."""
+        already_formatted = b'print("hello world")\n'
+        async with TestClient(TestServer(blackd.make_app())) as client:
+            resp = await client.post("/", data=already_formatted)
+            self.assertEqual(resp.status, 204)
+            body = await resp.read()
+            self.assertEqual(body, b"")
+
+    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
+    @async_test
+    async def test_blackd_request_syntax_error(self) -> None:
+        """Unparseable input yields 400 with a "Cannot parse..." message."""
+        async with TestClient(TestServer(blackd.make_app())) as client:
+            resp = await client.post("/", data=b"what even ( is")
+            self.assertEqual(resp.status, 400)
+            error_text = await resp.text()
+            self.assertTrue(
+                error_text.startswith("Cannot parse"),
+                msg=f"Expected error to start with 'Cannot parse', got {repr(error_text)}",
+            )
+
+    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
+    @async_test
+    async def test_blackd_unsupported_version(self) -> None:
+        """Protocol version "2" in ``VERSION_HEADER`` is rejected with 501."""
+        headers = {blackd.VERSION_HEADER: "2"}
+        async with TestClient(TestServer(blackd.make_app())) as client:
+            resp = await client.post("/", data=b"what", headers=headers)
+            self.assertEqual(resp.status, 501)
+
+    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
+    @async_test
+    async def test_blackd_supported_version(self) -> None:
+        """Protocol version "1" in ``VERSION_HEADER`` is accepted with 200."""
+        headers = {blackd.VERSION_HEADER: "1"}
+        async with TestClient(TestServer(blackd.make_app())) as client:
+            resp = await client.post("/", data=b"what", headers=headers)
+            self.assertEqual(resp.status, 200)
+
+    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
+    @async_test
+    async def test_blackd_invalid_python_variant(self) -> None:
+        """An unrecognized ``PYTHON_VARIANT_HEADER`` value yields 400."""
+        headers = {blackd.PYTHON_VARIANT_HEADER: "lol"}
+        async with TestClient(TestServer(blackd.make_app())) as client:
+            resp = await client.post("/", data=b"what", headers=headers)
+            self.assertEqual(resp.status, 400)
+
+    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
+    @async_test
+    async def test_blackd_pyi(self) -> None:
+        """The "pyi" variant formats the stub.pyi fixture to its expected output."""
+        source, expected = read_data("stub.pyi")
+        headers = {blackd.PYTHON_VARIANT_HEADER: "pyi"}
+        async with TestClient(TestServer(blackd.make_app())) as client:
+            resp = await client.post("/", data=source, headers=headers)
+            self.assertEqual(resp.status, 200)
+            formatted = await resp.text()
+            self.assertEqual(formatted, expected)
+
+    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
+    @async_test
+    async def test_blackd_py36(self) -> None:
+        """The same exploded ``def`` is reformatted only for the "3.6" variant.
+
+        blackd answers 200 (reformatted) for ``PYTHON_VARIANT_HEADER: 3.6`` and
+        204 (unchanged) for "3.5" and "2" -- presumably because only the 3.6
+        target may add a trailing comma after ``**kwargs``.
+        """
+        # Keep the four-space indentation: black would presumably re-indent
+        # anything else for every variant, which would break the 204 cases.
+        source = (
+            "def f(\n"
+            "    and_has_a_bunch_of,\n"
+            "    very_long_arguments_too,\n"
+            "    and_lots_of_them_as_well_lol,\n"
+            "    **and_very_long_keyword_arguments\n"
+            "):\n"
+            "    pass\n"
+        )
+        expected_statuses = (("3.6", 200), ("3.5", 204), ("2", 204))
+        app = blackd.make_app()
+        async with TestClient(TestServer(app)) as client:
+            for variant, expected_status in expected_statuses:
+                response = await client.post(
+                    "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: variant}
+                )
+                # Name the variant so a failure says which target misbehaved.
+                self.assertEqual(
+                    response.status, expected_status, msg=f"variant {variant!r}"
+                )
+
+    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
+    @async_test
+    async def test_blackd_fast(self) -> None:
+        """``ur''`` source fails with 500 by default but succeeds in fast mode.
+
+        Without ``FAST_OR_SAFE_HEADER`` the response is 500 and mentions
+        "failed to parse source file"; with the header set to "fast" it is 200.
+        """
+        ur_source = b"ur'hello'"
+        async with TestClient(TestServer(blackd.make_app())) as client:
+            resp = await client.post("/", data=ur_source)
+            self.assertEqual(resp.status, 500)
+            error_text = await resp.text()
+            self.assertIn("failed to parse source file", error_text)
+            fast_headers = {blackd.FAST_OR_SAFE_HEADER: "fast"}
+            resp = await client.post("/", data=ur_source, headers=fast_headers)
+            self.assertEqual(resp.status, 200)
+
+    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
+    @async_test
+    async def test_blackd_line_length(self) -> None:
+        """A short ``LINE_LENGTH_HEADER`` forces a reformat (status 200)."""
+        headers = {blackd.LINE_LENGTH_HEADER: "7"}
+        async with TestClient(TestServer(blackd.make_app())) as client:
+            resp = await client.post("/", data=b'print("hello")\n', headers=headers)
+            self.assertEqual(resp.status, 200)
+
+    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
+    @async_test
+    async def test_blackd_invalid_line_length(self) -> None:
+        """A non-integer ``LINE_LENGTH_HEADER`` value yields 400."""
+        headers = {blackd.LINE_LENGTH_HEADER: "NaN"}
+        async with TestClient(TestServer(blackd.make_app())) as client:
+            resp = await client.post("/", data=b'print("hello")\n', headers=headers)
+            self.assertEqual(resp.status, 400)
+
+    @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
+    def test_blackd_main(self) -> None:
+        """``blackd.main`` exits with code 0 when ``web.run_app`` is patched out."""
+        with patch("blackd.web.run_app"):
+            outcome = CliRunner().invoke(blackd.main, [])
+        # Re-raise anything ``main`` raised so the real traceback is reported
+        # instead of a bare exit-code mismatch.
+        if outcome.exception is not None:
+            raise outcome.exception
+        self.assertEqual(outcome.exit_code, 0)
+
if __name__ == "__main__":
unittest.main(module="test_black")