git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

don't uvloop.install on import (#2303)
[etc/vim.git] / src / blackd / __init__.py
1 import asyncio
2 import logging
3 import sys
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from datetime import datetime
6 from functools import partial
7 from multiprocessing import freeze_support
8 from typing import Set, Tuple
9
10 try:
11     from aiohttp import web
12     import aiohttp_cors
13 except ImportError as ie:
14     print(
15         f"aiohttp dependency is not installed: {ie}. "
16         + "Please re-install black with the '[d]' extra install "
17         + "to obtain aiohttp_cors: `pip install black[d]`",
18         file=sys.stderr,
19     )
20     sys.exit(-1)
21
22 import black
23 from black.concurrency import maybe_install_uvloop
24 import click
25
26 from _black_version import version as __version__
27
# This is used internally by tests to shut down the server prematurely
# NOTE(review): on Python < 3.10, asyncio.Event() looks up the current event loop
# when it is constructed -- i.e. at import time here. Confirm this is intended if
# the server is ever run on a different loop.
_stop_signal = asyncio.Event()

# Request headers
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
LINE_LENGTH_HEADER = "X-Line-Length"
PYTHON_VARIANT_HEADER = "X-Python-Variant"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
# NOTE: unlike its siblings this name has no "_HEADER" suffix; kept as-is because
# other code may import it by this name.
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
DIFF_HEADER = "X-Diff"

# Every request header blackd reads. make_app() allows these (plus Content-Type)
# in its CORS configuration, so browser clients may send them cross-origin.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]

# Response headers
# Carries the running black version on handle()'s responses.
BLACK_VERSION_HEADER = "X-Black-Version"
52
53
class InvalidVariantHeader(Exception):
    """Signal that an ``X-Python-Variant`` header value could not be parsed.

    The first positional argument is a short human-readable reason; the request
    handler includes it in the text of the 400 response it sends back.
    """
56
57
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host", type=str, help="Address to bind the server to.", default="localhost"
)
@click.option("--bind-port", type=int, help="Port to listen on", default=45484)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    """Start the blackd HTTP server on ``bind_host``:``bind_port``.

    Configures INFO-level logging, builds the application via ``make_app()``,
    announces the listening address with ``black.out`` and then blocks inside
    ``web.run_app`` until the server is stopped.
    """
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    black.out(
        f"blackd version {black.__version__} listening on {bind_host} port {bind_port}"
    )
    # print=None turns off aiohttp's own startup message; the black.out line
    # above is the announcement instead.
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        print=None,
    )
70
71
def make_app() -> web.Application:
    """Build the blackd ``aiohttp`` application.

    The application exposes a single ``POST /`` route served by ``handle``.
    All requests share one ``ProcessPoolExecutor``, created here, for the
    CPU-bound formatting work. The route is registered with ``aiohttp_cors``
    using a ``"*"`` origin entry. That entry allows the blackd request headers
    (``BLACK_HEADERS``) plus ``Content-Type`` and exposes all response headers.

    The executor is shut down when the application is cleaned up. Its worker
    processes therefore do not outlive the app, for example when tests call
    ``make_app`` repeatedly.
    """
    app = web.Application()
    executor = ProcessPoolExecutor()

    async def _shutdown_executor(_app: web.Application) -> None:
        # wait=False returns immediately instead of blocking the event loop
        # during cleanup; the executor frees its resources once any pending
        # futures have finished.
        executor.shutdown(wait=False)

    app.on_cleanup.append(_shutdown_executor)

    cors = aiohttp_cors.setup(app)
    resource = cors.add(app.router.add_resource("/"))
    cors.add(
        resource.add_route("POST", partial(handle, executor=executor)),
        {
            "*": aiohttp_cors.ResourceOptions(
                allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*"
            )
        },
    )

    return app
88
89
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the Python source in the request body and return the result.

    Formatting options are read from the blackd request headers. The
    formatting, and the optional diff, run in ``executor`` via
    ``run_in_executor`` so the event loop is not blocked by that work.

    Responses:
        200 -- the formatted source, or a diff of input vs. output when the
               ``X-Diff`` header is set to a non-empty value.
        204 -- the input was already formatted (``black.NothingChanged``).
        400 -- invalid ``X-Line-Length`` / ``X-Python-Variant`` value, a body
               that cannot be decoded with the request charset, or
               ``black.InvalidInput``.
        500 -- any other exception; it is logged with its traceback.
        501 -- an ``X-Protocol-Version`` other than "1" was requested.

    Every response carries the ``X-Black-Version`` header.
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501,
                headers=headers,
                text="This server only supports protocol version 1",
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(
                status=400, headers=headers, text="Invalid line length header value"
            )

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    headers=headers,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # bool() of the raw header string: any non-empty value (even "0" or
        # "false") enables these options.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_magic_trailing_comma = bool(
            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
        )
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            string_normalization=not skip_string_normalization,
            magic_trailing_comma=not skip_magic_trailing_comma,
        )

        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        try:
            req_str = req_bytes.decode(charset)
        except (LookupError, UnicodeDecodeError) as e:
            # An unknown/non-text charset name (LookupError) or a body that is
            # not valid in the declared charset is a client error, not a
            # server failure.
            return web.Response(
                status=400,
                headers=headers,
                text=f"Unable to decode request body as {charset}: {e}",
            )
        then = datetime.utcnow()

        loop = asyncio.get_event_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.utcnow()
            src_name = f"In\t{then} +0000"
            dst_name = f"Out\t{now} +0000"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
168
169
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the ``X-Python-Variant`` request header.

    ``value`` is either the literal ``"pyi"`` or a comma-separated list of
    Python versions. Accepted forms include ``"3.7"``, ``"py3.8"``, ``"py38"``
    and ``"2"``. Whitespace around each entry is ignored. A bare major version
    defaults to its lowest supported minor version (2.7 or 3.3).

    Returns:
        A ``(is_pyi, target_versions)`` tuple. For ``"pyi"`` this is
        ``(True, set())``. Otherwise ``is_pyi`` is False and the set holds one
        ``black.TargetVersion`` per entry.

    Raises:
        InvalidVariantHeader: if an entry is empty or malformed, has a major
            version other than 2 or 3, uses a minor other than 7 for Python 2,
            or names a Python 3 version missing from ``black.TargetVersion``.
    """
    if value == "pyi":
        return True, set()

    versions: Set[black.TargetVersion] = set()
    for version in value.split(","):
        # Tolerate lists written as "3.7, py3.8".
        version = version.strip()
        if version.startswith("py"):
            version = version[len("py") :]
        if not version:
            # e.g. "", "py", "3.7," or "3.7,,3.8". Without this check,
            # version[0] below raised an uncaught IndexError, which surfaced as
            # an HTTP 500 instead of a 400.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            # Dot-less form such as "38" or "310": the first character is the
            # major version and the remainder (if any) is the minor version.
            major_str = version[0]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if len(rest) > 0:
                minor = int(rest[0])
                if major == 2 and minor != 7:
                    raise InvalidVariantHeader(
                        "minor version must be 7 for Python 2"
                    )
            else:
                # Default to lowest supported minor version.
                minor = 7 if major == 2 else 3
            version_str = f"PY{major}{minor}"
            if major == 3 and not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            # The internal parsing error adds nothing for the client; report
            # the expected format instead.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
203
204
def patched_main() -> None:
    """Console entry point for blackd: prepare the process, then run ``main``.

    The preparation calls run in a fixed order before the click command
    ``main()`` is invoked.
    """
    # Done here in the entry point rather than at module import time.
    # Presumably installs uvloop's event loop only when it is available --
    # confirm in black.concurrency.maybe_install_uvloop.
    maybe_install_uvloop()
    # multiprocessing support for frozen (e.g. PyInstaller) Windows executables;
    # has no effect otherwise. Needed because requests are formatted in a
    # ProcessPoolExecutor.
    freeze_support()
    # NOTE(review): adjusts click's behavior before the command runs; see
    # black.patch_click for exactly what is changed.
    black.patch_click()
    main()
210
211
# Run the blackd server when this file is executed as a script.
if __name__ == "__main__":
    patched_main()