git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

re-implement simple CORS middleware for blackd (#2500)
[etc/vim.git] / src / blackd / __init__.py
1 import asyncio
2 import logging
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from functools import partial
6 from multiprocessing import freeze_support
7 from typing import Set, Tuple
8
9 try:
10     from aiohttp import web
11     from .middlewares import cors
12 except ImportError as ie:
13     raise ImportError(
14         f"aiohttp dependency is not installed: {ie}. "
15         + "Please re-install black with the '[d]' extra install "
16         + "to obtain aiohttp_cors: `pip install black[d]`"
17     ) from None
18
19 import black
20 from black.concurrency import maybe_install_uvloop
21 import click
22
23 from _black_version import version as __version__
24
# Event used internally by the tests to shut the server down prematurely.
_stop_signal = asyncio.Event()

# Names of the request headers read by the request handler.
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
LINE_LENGTH_HEADER = "X-Line-Length"
PYTHON_VARIANT_HEADER = "X-Python-Variant"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
DIFF_HEADER = "X-Diff"

# All request headers above, in declaration order. make_app() passes this
# list (plus "Content-Type") to the CORS middleware as the allowed headers.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]

# Name of the response header that carries the Black version.
BLACK_VERSION_HEADER = "X-Black-Version"
49
50
class InvalidVariantHeader(Exception):
    """Raised when the value of the X-Python-Variant header cannot be parsed.

    The first positional argument is a human-readable explanation; the
    request handler embeds it in the body of its 400 response.
    """
53
54
# NOTE: documented with comments rather than a docstring on purpose: click
# uses a command function's docstring as its --help text, so adding one would
# change the CLI output.
#
# Command-line entry point: configure logging, build the aiohttp application,
# announce the bind address and serve until the process is signalled.
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host",
    type=str,
    help="Address to bind the server to.",
    default="localhost",
)
@click.option(
    "--bind-port",
    type=int,
    help="Port to listen on",
    default=45484,
)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    banner = (
        f"blackd version {black.__version__} listening on {bind_host}"
        f" port {bind_port}"
    )
    black.out(banner)
    # print=None silences aiohttp's own startup banner; ours was emitted above.
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        print=None,
    )
67
68
def make_app() -> web.Application:
    """Build the aiohttp application that serves formatting requests.

    The application gets the CORS middleware, configured to allow every
    header in BLACK_HEADERS plus "Content-Type", and a single ``POST /``
    route bound to ``handle`` with a fresh ProcessPoolExecutor.

    Returns:
        The configured ``web.Application``.
    """
    allowed_headers = (*BLACK_HEADERS, "Content-Type")
    application = web.Application(middlewares=[cors(allow_headers=allowed_headers)])
    # One process pool per application; handle() receives it through partial().
    pool = ProcessPoolExecutor()
    post_handler = partial(handle, executor=pool)
    application.add_routes([web.post("/", post_handler)])
    return application
76
77
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the request body with Black and return the result.

    Formatting options are taken from the request headers listed in
    BLACK_HEADERS. The formatting itself (and the optional diff) runs in
    ``executor`` so the event loop is not blocked.

    Every response built here, including the early error responses, carries
    the X-Black-Version header.

    Args:
        request: Incoming request; its body is the source code to format.
        executor: Executor used to run ``black.format_file_contents`` and
            ``black.diff``.

    Returns:
        200 with the formatted source (or a diff when the X-Diff header is
        set to a non-empty value), 204 when Black raises ``NothingChanged``,
        400 for an invalid header value or ``black.InvalidInput``, 501 for an
        unsupported protocol version, and 500 for any other exception (which
        is also logged).
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501,
                headers=headers,
                text="This server only supports protocol version 1",
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(
                status=400, headers=headers, text="Invalid line length header value"
            )

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    headers=headers,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # These flags are enabled by any non-empty header value: bool() of a
        # non-empty string is True regardless of its content.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_magic_trailing_comma = bool(
            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
        )
        # Only the exact value "fast" selects fast mode; anything else is safe.
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            string_normalization=not skip_string_normalization,
            magic_trailing_comma=not skip_magic_trailing_comma,
        )
        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        req_str = req_bytes.decode(charset)
        # Timestamp taken before formatting; used as the "In" time of the diff.
        then = datetime.utcnow()

        # The same loop serves both executor calls below.
        loop = asyncio.get_event_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.utcnow()
            src_name = f"In\t{then} +0000"
            dst_name = f"Out\t{now} +0000"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        # Last-resort boundary: log the traceback and report the error text.
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
156
157
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the X-Python-Variant header.

    Args:
        value: Either the literal ``"pyi"`` or a comma-separated list of
            Python versions such as ``"3.7"``, ``"py3.5"`` or ``"py27"``.
            The ``py`` prefix and the dot are both optional.

    Returns:
        A ``(is_pyi, target_versions)`` tuple. For ``"pyi"`` this is
        ``(True, set())``; otherwise ``is_pyi`` is False and the set holds
        one ``black.TargetVersion`` per listed version.

    Raises:
        InvalidVariantHeader: If any entry is empty, is not numeric, names a
            major version other than 2 or 3, names a Python 2 minor other
            than 7, or names a Python 3 minor that ``black.TargetVersion``
            does not define.
    """
    if value == "pyi":
        return True, set()
    else:
        versions = set()
        for version in value.split(","):
            if version.startswith("py"):
                version = version[len("py") :]
            if not version:
                # An empty entry ("", a bare "py", or a stray comma) would
                # make version[0] below raise IndexError, which callers that
                # only catch InvalidVariantHeader would not handle.
                raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
            if "." in version:
                major_str, *rest = version.split(".")
            else:
                # Dotless form such as "37": first character is the major
                # version, the remainder (if any) is the minor version.
                major_str = version[0]
                rest = [version[1:]] if len(version) > 1 else []
            try:
                major = int(major_str)
                if major not in (2, 3):
                    raise InvalidVariantHeader("major version must be 2 or 3")
                if len(rest) > 0:
                    minor = int(rest[0])
                    if major == 2 and minor != 7:
                        raise InvalidVariantHeader(
                            "minor version must be 7 for Python 2"
                        )
                else:
                    # Default to lowest supported minor version.
                    minor = 7 if major == 2 else 3
                version_str = f"PY{major}{minor}"
                if major == 3 and not hasattr(black.TargetVersion, version_str):
                    raise InvalidVariantHeader(f"3.{minor} is not supported")
                versions.add(black.TargetVersion[version_str])
            except (KeyError, ValueError):
                # Non-numeric parts (ValueError) or an unknown enum member
                # (KeyError) are reported uniformly as a malformed header.
                raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
        return False, versions
191
192
def patched_main() -> None:
    """Run the one-time process setup steps, then start the blackd CLI.

    The setup calls run in a fixed order before ``main()`` is invoked.
    """
    # Presumably switches asyncio to uvloop when it is available; see
    # black.concurrency for the exact behavior.
    maybe_install_uvloop()
    # multiprocessing hook required by frozen Windows executables; the server
    # formats code in a process pool.
    freeze_support()
    # Black's own adjustment to click; must run before the click command.
    black.patch_click()
    main()


# Allow running the module directly as a script.
if __name__ == "__main__":
    patched_main()