All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you read over the Git project's submission guidelines and adhere to them,
I'd be especially grateful.
# NOTE(review): this listing carries the original file's line numbers as
# prefixes, has lost its indentation, and omits several lines (see the gaps in
# the numbering, e.g. 1-2, 5-6, 14, 18, 22, 26-27, 29, 32, 34-36). For example,
# `sys` is used on line 11 but its import is not among the visible lines. It is
# not runnable Python as shown.
3 from typing import TYPE_CHECKING, Any, Callable, TypeVar
4 from unittest.mock import patch
7 from click.testing import CliRunner
9 from tests.util import DETERMINISTIC_HEADER, read_data
# True when the running interpreter is older than Python 3.11.
11 LESS_THAN_311 = sys.version_info < (3, 11)
# The aiohttp/blackd setup below is gated on LESS_THAN_311. The reason for this
# version gate is not visible here -- NOTE(review): confirm against the project.
13 if LESS_THAN_311: # noqa: C901
15 from aiohttp import web
16 from aiohttp.test_utils import AioHTTPTestCase
# An ImportError from the imports above is re-raised as a RuntimeError that
# tells the user to install Black's 'd' extra. `from e` keeps the original error
# as the cause. (The enclosing `try:` line is among the omitted lines.)
19 except ImportError as e:
20 raise RuntimeError("Please install Black with the 'd' extra") from e
# Identity stand-in for unittest_run_loop. It is typed so the decorated callable
# keeps its own type. Presumably this branch runs only under TYPE_CHECKING
# (imported above; the guarding line is not visible) -- confirm.
23 F = TypeVar("F", bound=Callable[..., Any])
25 unittest_run_loop: Callable[[F], F] = lambda x: x
# Runtime path (presumably the non-TYPE_CHECKING branch): import aiohttp's own
# decorator. Per the original comment below, a local replacement is defined for
# aiohttp versions that no longer provide it. The try/except lines around this
# import and the fallback's body (original line 34) are omitted from this
# listing. Presumably the fallback returns `func` unchanged -- confirm.
28 from aiohttp.test_utils import unittest_run_loop
30 # unittest_run_loop is unnecessary and a no-op since aiohttp 3.8, and
31 # aiohttp 4 removed it. To maintain compatibility we can make our own
33 def unittest_run_loop(func, *args, **kwargs):
# NOTE(review): as in the rest of this listing, lines are prefixed with their
# original line numbers, and several lines are omitted. These include blank
# lines, the lines just before each async test (likely decorators), and some
# call openers/closers such as the closing `)` of multi-line client.post(...)
# calls. The comments below describe only what the visible lines show.
# aiohttp-based test case exercising the blackd HTTP server and its CLI entry point.
37 class BlackDTestCase(AioHTTPTestCase): # type: ignore[misc]
# Invokes blackd's CLI entry point with web.run_app patched out (presumably so
# no server is actually started) and asserts a zero exit code. Any exception
# captured by CliRunner is re-raised so the real failure is reported.
38 def test_blackd_main(self) -> None:
39 with patch("blackd.web.run_app"):
40 result = CliRunner().invoke(blackd.main, [])
41 if result.exception is not None:
42 raise result.exception
43 self.assertEqual(result.exit_code, 0)
# AioHTTPTestCase hook: the application under test is blackd.make_app().
45 async def get_application(self) -> web.Application:
46 return blackd.make_app()
# Input that needs formatting -> 200, utf8 charset, reformatted source as the body.
49 async def test_blackd_request_needs_formatting(self) -> None:
50 response = await self.client.post("/", data=b"print('hello world')")
51 self.assertEqual(response.status, 200)
52 self.assertEqual(response.charset, "utf8")
53 self.assertEqual(await response.read(), b'print("hello world")\n')
# Already-formatted input -> 204 with an empty body.
56 async def test_blackd_request_no_change(self) -> None:
57 response = await self.client.post("/", data=b'print("hello world")\n')
58 self.assertEqual(response.status, 204)
59 self.assertEqual(await response.read(), b"")
# Unparseable input -> 400, with an error text expected to start with
# "Cannot parse". (The line opening that assertion, original line 66, is
# omitted from this listing.)
62 async def test_blackd_request_syntax_error(self) -> None:
63 response = await self.client.post("/", data=b"what even ( is")
64 self.assertEqual(response.status, 400)
65 content = await response.text()
67 content.startswith("Cannot parse"),
68 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
# PROTOCOL_VERSION_HEADER "2" -> 501.
72 async def test_blackd_unsupported_version(self) -> None:
73 response = await self.client.post(
74 "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
76 self.assertEqual(response.status, 501)
# PROTOCOL_VERSION_HEADER "1" -> 200.
79 async def test_blackd_supported_version(self) -> None:
80 response = await self.client.post(
81 "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
83 self.assertEqual(response.status, 200)
# Unrecognised PYTHON_VARIANT_HEADER values are rejected. The helper defaults
# to expecting 400. The request-body lines (89-90) and several check(...)
# calls (original lines 97-104) are omitted from this listing.
86 async def test_blackd_invalid_python_variant(self) -> None:
87 async def check(header_value: str, expected_status: int = 400) -> None:
88 response = await self.client.post(
91 headers={blackd.PYTHON_VARIANT_HEADER: header_value},
93 self.assertEqual(response.status, expected_status)
96 await check("ruby3.5")
105 await check("pypy3.0")
106 await check("jython3.4")
# "pyi" variant: the stub.pyi fixture is formatted to its expected output (200).
109 async def test_blackd_pyi(self) -> None:
110 source, expected = read_data("miscellaneous", "stub.pyi")
111 response = await self.client.post(
112 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
114 self.assertEqual(response.status, 200)
115 self.assertEqual(await response.text(), expected)
# DIFF_HEADER "true" -> 200 with a diff. The timestamped "In"/"Out" header
# lines are replaced with DETERMINISTIC_HEADER so the result can be compared
# with the stored blackd_diff.diff fixture.
118 async def test_blackd_diff(self) -> None:
119 diff_header = re.compile(
120 r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
123 source, _ = read_data("miscellaneous", "blackd_diff")
124 expected, _ = read_data("miscellaneous", "blackd_diff.diff")
126 response = await self.client.post(
127 "/", data=source, headers={blackd.DIFF_HEADER: "true"}
129 self.assertEqual(response.status, 200)
131 actual = await response.text()
132 actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
133 self.assertEqual(actual, expected)
# Target-version parsing of PYTHON_VARIANT_HEADER:
# - 3.6-style variants ("3.6", "py36", "3.6.4", lists) expect 200 (reformatted).
# - 3.4-style variants expect 204 (left unchanged).
# The response text is used as the assertion message.
# NOTE(review): presumably the difference is that reformatting would add a
# trailing comma after **kwargs, which Black only emits for 3.6+ targets --
# confirm. Parts of `code` (original lines 137-138 and 143-146) are omitted.
136 async def test_blackd_python_variant(self) -> None:
139 " and_has_a_bunch_of,\n"
140 " very_long_arguments_too,\n"
141 " and_lots_of_them_as_well_lol,\n"
142 " **and_very_long_keyword_arguments\n"
147 async def check(header_value: str, expected_status: int) -> None:
148 response = await self.client.post(
149 "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
152 response.status, expected_status, msg=await response.text()
155 await check("3.6", 200)
156 await check("py3.6", 200)
157 await check("3.6,3.7", 200)
158 await check("3.6,py3.7", 200)
159 await check("py36,py37", 200)
160 await check("36", 200)
161 await check("3.6.4", 200)
162 await check("3.4", 204)
163 await check("py3.4", 204)
164 await check("py34,py36", 204)
165 await check("34", 204)
# LINE_LENGTH_HEADER "7" is accepted and the request expects 200.
168 async def test_blackd_line_length(self) -> None:
169 response = await self.client.post(
170 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
172 self.assertEqual(response.status, 200)
# Non-numeric LINE_LENGTH_HEADER ("NaN") -> 400. (The URL argument line,
# original line 177, is omitted from this listing.)
175 async def test_blackd_invalid_line_length(self) -> None:
176 response = await self.client.post(
178 data=b'print("hello")\n',
179 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
181 self.assertEqual(response.status, 400)
# PREVIEW header "true" on already-formatted input -> 204.
184 async def test_blackd_preview(self) -> None:
185 response = await self.client.post(
186 "/", data=b'print("hello")\n', headers={blackd.PREVIEW: "true"}
188 self.assertEqual(response.status, 204)
# Even a POST with no body must carry the BLACK_VERSION_HEADER response header.
191 async def test_blackd_response_black_version_header(self) -> None:
192 response = await self.client.post("/")
193 self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))
# CORS preflight: an OPTIONS request asking for POST with a Content-Type header
# expects 200 plus Allow-Origin/Allow-Headers/Allow-Methods response headers.
# (Some request lines, original 198-199, 201 and 203-204, are omitted.)
196 async def test_cors_preflight(self) -> None:
197 response = await self.client.options(
200 "Access-Control-Request-Method": "POST",
202 "Access-Control-Request-Headers": "Content-Type",
205 self.assertEqual(response.status, 200)
206 self.assertIsNotNone(response.headers.get("Access-Control-Allow-Origin"))
207 self.assertIsNotNone(response.headers.get("Access-Control-Allow-Headers"))
208 self.assertIsNotNone(response.headers.get("Access-Control-Allow-Methods"))
# A regular POST carrying an Origin header expects the Allow-Origin and
# Expose-Headers CORS response headers.
211 async def test_cors_headers_present(self) -> None:
212 response = await self.client.post("/", headers={"Origin": "*"})
213 self.assertIsNotNone(response.headers.get("Access-Control-Allow-Origin"))
214 self.assertIsNotNone(response.headers.get("Access-Control-Expose-Headers"))