git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send patches to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Lazily import parallelized format modules
[etc/vim.git] / src / blackd / __init__.py
1 import asyncio
2 import logging
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from functools import partial
6 from multiprocessing import freeze_support
7 from typing import Set, Tuple
8
9 try:
10     from aiohttp import web
11
12     from .middlewares import cors
13 except ImportError as ie:
14     raise ImportError(
15         f"aiohttp dependency is not installed: {ie}. "
16         + "Please re-install black with the '[d]' extra install "
17         + "to obtain aiohttp_cors: `pip install black[d]`"
18     ) from None
19
20 import click
21
22 import black
23 from _black_version import version as __version__
24 from black.concurrency import maybe_install_uvloop
25
# Internal hook for the test-suite: lets tests shut the server down early.
_stop_signal = asyncio.Event()

# ---- Request headers understood by blackd --------------------------------
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
LINE_LENGTH_HEADER = "X-Line-Length"
PYTHON_VARIANT_HEADER = "X-Python-Variant"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
PREVIEW = "X-Preview"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
DIFF_HEADER = "X-Diff"

# Every request header above, in declaration order; make_app passes this
# list (plus "Content-Type") to the CORS middleware as the allowed headers.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    PREVIEW,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]

# ---- Response headers set by blackd --------------------------------------
BLACK_VERSION_HEADER = "X-Black-Version"
52
53
class InvalidVariantHeader(Exception):
    """Raised when the ``X-Python-Variant`` header value cannot be parsed.

    The first positional argument is a human-readable reason, which the
    request handler echoes back in its 400 response.
    """
56
57
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host", type=str, help="Address to bind the server to.", default="localhost"
)
@click.option("--bind-port", type=int, help="Port to listen on", default=45484)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # Deliberately no docstring: click would render it as the --help text.
    # Start the blackd HTTP server on bind_host:bind_port and block until it
    # is stopped.
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    black.out(
        f"blackd version {black.__version__} listening on {bind_host} port {bind_port}"
    )
    # print=None disables aiohttp's own startup output; the banner above
    # replaces it.
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        print=None,
    )
70
71
def make_app() -> web.Application:
    """Build the blackd aiohttp application.

    The application has:

    * CORS middleware that allows every blackd request header (see
      ``BLACK_HEADERS``) plus ``Content-Type``;
    * a single ``POST /`` route served by :func:`handle`;
    * a ``ProcessPoolExecutor`` that ``handle`` uses to run formatting off
      the event loop.

    The executor is shut down when the application is cleaned up, so its
    worker processes do not outlive the app. This matters when many apps
    are created in one process, e.g. one per test.
    """
    app = web.Application(
        middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))]
    )
    executor = ProcessPoolExecutor()

    async def _shutdown_executor(_app: web.Application) -> None:
        # aiohttp runs on_cleanup handlers while tearing the app down, after
        # on_shutdown. Waiting lets any worker still running finish cleanly.
        executor.shutdown(wait=True)

    app.on_cleanup.append(_shutdown_executor)
    app.add_routes([web.post("/", partial(handle, executor=executor))])
    return app
79
80
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the Python source in the request body and return the result.

    Formatting options come from the blackd request headers (see
    ``BLACK_HEADERS``). Formatting, and the optional diff, run in *executor*
    via ``run_in_executor`` so the event loop is not blocked.

    Every response carries the ``X-Black-Version`` header:
        200 -- the formatted source, or a diff of it when ``X-Diff`` is set;
        204 -- black raised ``NothingChanged``;
        400 -- an invalid header value, a request body that cannot be decoded
               with its declared charset, or black raised ``InvalidInput``;
        501 -- a protocol version other than "1" was requested;
        500 -- any other exception (it is also logged).
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501,
                headers=headers,
                text="This server only supports protocol version 1",
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(
                status=400, headers=headers, text="Invalid line length header value"
            )

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    headers=headers,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # Flag headers: any non-empty value enables the option. bool() of a
        # string means even "false" or "0" count as set.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_magic_trailing_comma = bool(
            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
        )
        preview = bool(request.headers.get(PREVIEW, False))
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            string_normalization=not skip_string_normalization,
            magic_trailing_comma=not skip_magic_trailing_comma,
            preview=preview,
        )
        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        try:
            req_str = req_bytes.decode(charset)
        except (LookupError, UnicodeDecodeError) as e:
            # An unknown charset name (LookupError) or bytes that are invalid
            # in the declared charset are client errors, not server failures.
            return web.Response(
                status=400,
                headers=headers,
                text=f"Could not decode request body: {e}",
            )
        then = datetime.utcnow()

        loop = asyncio.get_event_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.utcnow()
            src_name = f"In\t{then} +0000"
            dst_name = f"Out\t{now} +0000"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
161
162
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the ``X-Python-Variant`` request header.

    The value is either the literal ``"pyi"`` (format as a stub file) or a
    comma-separated list of Python versions. Each item may be written as
    ``"3.7"``, ``"py3.7"``, ``"37"``, ``"py37"`` or just ``"3"``; the ``"py"``
    prefix is optional and whitespace around each item is ignored. A bare
    major version maps to minor version 3 (``"3"`` -> ``PY33``).

    Returns:
        A ``(is_pyi, target_versions)`` tuple.

    Raises:
        InvalidVariantHeader: if an item is empty or malformed, names Python 2
            or any major version other than 3, or names a Python 3 minor
            version that ``black.TargetVersion`` does not define.
    """
    if value == "pyi":
        return True, set()
    versions: Set[black.TargetVersion] = set()
    for item in value.split(","):
        version = item.strip()
        if version.startswith("py"):
            version = version[len("py") :]
        if not version:
            # An empty header, a trailing comma or a bare "py". Without this
            # check, ``version[0]`` below raises IndexError, which escapes the
            # except clause and turns a client error into a 500 response.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            # Dotless form such as "37": first digit is major, rest is minor.
            major_str = version[0]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if major == 2:
                # Reject Python 2 the same way whether or not a minor version
                # was given ("py2" vs "py27").
                raise InvalidVariantHeader("Python 2 is not supported")
            # Default to lowest supported minor version.
            minor = int(rest[0]) if rest else 3
            version_str = f"PY{major}{minor}"
            if not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
194
195
def patched_main() -> None:
    """Console entry point: prepare the runtime, then run the click command.

    The preparation steps run in a fixed order before ``main()`` is invoked.
    ``maybe_install_uvloop`` presumably installs uvloop only when it is
    available (verify in ``black.concurrency``). ``freeze_support`` is the
    multiprocessing hook needed by frozen executables that use the process
    pool.
    """
    for prepare in (maybe_install_uvloop, freeze_support, black.patch_click):
        prepare()
    main()


if __name__ == "__main__":
    patched_main()