All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you read over the Git project's submission guidelines and adhere to them,
I'd be especially grateful.
import re
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from unittest.mock import patch

import pytest
from click.testing import CliRunner

from tests.util import DETERMINISTIC_HEADER, read_data

try:
    from aiohttp import web
    from aiohttp.test_utils import AioHTTPTestCase

    import blackd
except ImportError as e:
    raise RuntimeError("Please install Black with the 'd' extra") from e
if TYPE_CHECKING:
    # Type checkers see a typed identity decorator, so the decorated async
    # test methods keep their exact signatures.
    F = TypeVar("F", bound=Callable[..., Any])

    unittest_run_loop: Callable[[F], F] = lambda x: x
else:
    try:
        from aiohttp.test_utils import unittest_run_loop
    except ImportError:
        # unittest_run_loop is unnecessary and a no-op since aiohttp 3.8, and
        # aiohttp 4 removed it. To maintain compatibility we can make our own
        # no-op decorator.
        def unittest_run_loop(func, *args, **kwargs):
            """Return *func* unchanged; any extra arguments are ignored."""
            return func
@pytest.mark.blackd
class BlackDTestCase(AioHTTPTestCase):  # type: ignore[misc]
    """End-to-end tests for the blackd HTTP server.

    Apart from ``test_blackd_main``, every test talks to the application built
    by ``get_application`` through ``self.client`` and checks the status code,
    headers and body of the response.
    """

    def test_blackd_main(self) -> None:
        """``blackd.main`` exits with status 0 when the server start is stubbed."""
        # Patch out run_app so the CLI entry point returns instead of serving.
        with patch("blackd.web.run_app"):
            result = CliRunner().invoke(blackd.main, [])
            # Surface any exception captured by CliRunner rather than only
            # checking the exit code.
            if result.exception is not None:
                raise result.exception
            self.assertEqual(result.exit_code, 0)

    async def get_application(self) -> web.Application:
        """Build the blackd application that ``self.client`` sends requests to."""
        return blackd.make_app()

    @unittest_run_loop
    async def test_blackd_request_needs_formatting(self) -> None:
        """Unformatted source comes back reformatted with status 200."""
        response = await self.client.post("/", data=b"print('hello world')")
        self.assertEqual(response.status, 200)
        self.assertEqual(response.charset, "utf8")
        self.assertEqual(await response.read(), b'print("hello world")\n')

    @unittest_run_loop
    async def test_blackd_request_no_change(self) -> None:
        """Already-formatted source yields 204 with an empty body."""
        response = await self.client.post("/", data=b'print("hello world")\n')
        self.assertEqual(response.status, 204)
        self.assertEqual(await response.read(), b"")

    @unittest_run_loop
    async def test_blackd_request_syntax_error(self) -> None:
        """Unparsable source yields 400 with a 'Cannot parse' error message."""
        response = await self.client.post("/", data=b"what even ( is")
        self.assertEqual(response.status, 400)
        content = await response.text()
        self.assertTrue(
            content.startswith("Cannot parse"),
            msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
        )

    @unittest_run_loop
    async def test_blackd_unsupported_version(self) -> None:
        """Protocol version 2 is rejected with 501."""
        response = await self.client.post(
            "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "2"}
        )
        self.assertEqual(response.status, 501)

    @unittest_run_loop
    async def test_blackd_supported_version(self) -> None:
        """Protocol version 1 is accepted and the body is formatted."""
        response = await self.client.post(
            "/", data=b"what", headers={blackd.PROTOCOL_VERSION_HEADER: "1"}
        )
        self.assertEqual(response.status, 200)

    @unittest_run_loop
    async def test_blackd_invalid_python_variant(self) -> None:
        """Malformed or unsupported Python-variant headers are rejected with 400."""

        async def check(header_value: str, expected_status: int = 400) -> None:
            response = await self.client.post(
                "/",
                data=b"what",
                headers={blackd.PYTHON_VARIANT_HEADER: header_value},
            )
            # Name the offending header value so a failure is easy to locate.
            self.assertEqual(
                response.status, expected_status, msg=f"header {header_value!r}"
            )

        await check("lol")
        await check("ruby3.5")
        await check("pyi3.6")
        await check("py1.5")
        await check("2")
        await check("2.7")
        await check("py2.7")
        await check("2.8")
        await check("py2.8")
        await check("3.0")
        await check("pypy3.0")
        await check("jython3.4")

    @unittest_run_loop
    async def test_blackd_pyi(self) -> None:
        """With the 'pyi' variant, the stub fixture is formatted as expected."""
        source, expected = read_data("miscellaneous", "stub.pyi")
        response = await self.client.post(
            "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
        )
        self.assertEqual(response.status, 200)
        self.assertEqual(await response.text(), expected)

    @unittest_run_loop
    async def test_blackd_diff(self) -> None:
        """With the diff header set, the response is a diff matching the fixture."""
        # The In/Out diff headers carry timestamps; replace them with a fixed
        # header before comparing against the fixture.
        diff_header = re.compile(
            r"(In|Out)\t\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d\+\d\d:\d\d"
        )

        source, _ = read_data("miscellaneous", "blackd_diff")
        expected, _ = read_data("miscellaneous", "blackd_diff.diff")

        response = await self.client.post(
            "/", data=source, headers={blackd.DIFF_HEADER: "true"}
        )
        self.assertEqual(response.status, 200)

        actual = await response.text()
        actual = diff_header.sub(DETERMINISTIC_HEADER, actual)
        self.assertEqual(actual, expected)

    @unittest_run_loop
    async def test_blackd_python_variant(self) -> None:
        """Target versions in the variant header change the formatting result.

        The snippet's parameter list ends in ``**kwargs`` with no trailing
        comma. Targets of 3.6+ get it reformatted (200), while 3.4 targets
        leave it unchanged (204), presumably because a trailing comma after
        ``**`` in a signature is only valid syntax from Python 3.6.
        """
        # Leading indentation inside these literals is significant: apart from
        # the missing trailing comma, the snippet must already be formatted.
        code = (
            "def f(\n"
            "    and_has_a_bunch_of,\n"
            "    very_long_arguments_too,\n"
            "    and_lots_of_them_as_well_lol,\n"
            "    **and_very_long_keyword_arguments\n"
            "):\n"
            "    pass\n"
        )

        async def check(header_value: str, expected_status: int) -> None:
            response = await self.client.post(
                "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
            )
            self.assertEqual(
                response.status, expected_status, msg=await response.text()
            )

        await check("3.6", 200)
        await check("py3.6", 200)
        await check("3.6,3.7", 200)
        await check("3.6,py3.7", 200)
        await check("py36,py37", 200)
        await check("36", 200)
        await check("3.6.4", 200)
        await check("3.4", 204)
        await check("py3.4", 204)
        await check("py34,py36", 204)
        await check("34", 204)

    @unittest_run_loop
    async def test_blackd_line_length(self) -> None:
        """A short line-length limit makes an already-formatted line change (200)."""
        response = await self.client.post(
            "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
        )
        self.assertEqual(response.status, 200)

    @unittest_run_loop
    async def test_blackd_invalid_line_length(self) -> None:
        """A non-numeric line-length header is rejected with 400."""
        response = await self.client.post(
            "/",
            data=b'print("hello")\n',
            headers={blackd.LINE_LENGTH_HEADER: "NaN"},
        )
        self.assertEqual(response.status, 400)

    @unittest_run_loop
    async def test_blackd_skip_first_source_line(self) -> None:
        """The skip-first-line header passes an unparsable first line through.

        Without the header the input fails to parse (400). With it, the first
        line, including its CRLF ending, is returned untouched and the rest
        of the source is formatted (200).
        """
        invalid_first_line = b"Header will be skipped\r\ni = [1,2,3]\nj = [1,2,3]\n"
        expected_result = b"Header will be skipped\r\ni = [1, 2, 3]\nj = [1, 2, 3]\n"
        response = await self.client.post("/", data=invalid_first_line)
        self.assertEqual(response.status, 400)
        response = await self.client.post(
            "/",
            data=invalid_first_line,
            headers={blackd.SKIP_SOURCE_FIRST_LINE: "true"},
        )
        self.assertEqual(response.status, 200)
        self.assertEqual(await response.read(), expected_result)

    @unittest_run_loop
    async def test_blackd_preview(self) -> None:
        """Preview mode leaves this already-formatted input unchanged (204)."""
        response = await self.client.post(
            "/", data=b'print("hello")\n', headers={blackd.PREVIEW: "true"}
        )
        self.assertEqual(response.status, 204)

    @unittest_run_loop
    async def test_blackd_response_black_version_header(self) -> None:
        """A request with no body still gets the Black version response header."""
        response = await self.client.post("/")
        self.assertIsNotNone(response.headers.get(blackd.BLACK_VERSION_HEADER))

    @unittest_run_loop
    async def test_cors_preflight(self) -> None:
        """An OPTIONS preflight request is answered with the CORS allow headers."""
        response = await self.client.options(
            "/",
            headers={
                "Access-Control-Request-Method": "POST",
                "Origin": "*",
                "Access-Control-Request-Headers": "Content-Type",
            },
        )
        self.assertEqual(response.status, 200)
        self.assertIsNotNone(response.headers.get("Access-Control-Allow-Origin"))
        self.assertIsNotNone(response.headers.get("Access-Control-Allow-Headers"))
        self.assertIsNotNone(response.headers.get("Access-Control-Allow-Methods"))

    @unittest_run_loop
    async def test_cors_headers_present(self) -> None:
        """A cross-origin POST gets the allow-origin and expose-headers headers."""
        response = await self.client.post("/", headers={"Origin": "*"})
        self.assertIsNotNone(response.headers.get("Access-Control-Allow-Origin"))
        self.assertIsNotNone(response.headers.get("Access-Control-Expose-Headers"))

    @unittest_run_loop
    async def test_preserves_line_endings(self) -> None:
        """Consistent CRLF or LF line endings survive formatting unchanged."""
        for data in (b"c\r\nc\r\n", b"l\nl\n"):
            # subTest reports which line-ending style failed.
            with self.subTest(data=data):
                # A trailing space forces a reformat; the line endings must
                # come back exactly as they were sent.
                response = await self.client.post("/", data=data + b" ")
                self.assertEqual(await response.text(), data.decode())
                # Already-formatted input is reported as unchanged.
                response = await self.client.post("/", data=data)
                self.assertEqual(response.status, 204)

    @unittest_run_loop
    async def test_normalizes_line_endings(self) -> None:
        """Mixed line endings are rewritten to match the first line's ending."""
        for data, expected in ((b"c\r\nc\n", "c\r\nc\r\n"), (b"l\nl\r\n", "l\nl\n")):
            with self.subTest(data=data):
                response = await self.client.post("/", data=data)
                self.assertEqual(await response.text(), expected)
                self.assertEqual(response.status, 200)