git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

f77a5e8e7be49c0a5b881e22fe4401b5a36f221e
[etc/vim.git] / src / blackd / __init__.py
1 import asyncio
2 import logging
3 import sys
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from datetime import datetime
6 from functools import partial
7 from multiprocessing import freeze_support
8 from typing import Set, Tuple
9
10 try:
11     from aiohttp import web
12     import aiohttp_cors
13 except ImportError as ie:
14     print(
15         f"aiohttp dependency is not installed: {ie}. "
16         + "Please re-install black with the '[d]' extra install "
17         + "to obtain aiohttp_cors: `pip install black[d]`",
18         file=sys.stderr,
19     )
20     sys.exit(-1)
21
22 import black
23 import click
24
25 from _black_version import version as __version__
26
# This is used internally by tests to shut down the server prematurely.
# (Not referenced anywhere else in this module.)
_stop_signal = asyncio.Event()

# --- Request headers understood by blackd ----------------------------------
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"  # only "1" is accepted
LINE_LENGTH_HEADER = "X-Line-Length"  # integer line length
PYTHON_VARIANT_HEADER = "X-Python-Variant"  # "pyi" or comma-separated versions
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"  # "fast" or "safe"
DIFF_HEADER = "X-Diff"  # non-empty value -> respond with a diff

# Every request header above, in declaration order; make_app() hands this list
# (plus "Content-Type") to CORS as the headers a browser client may send.
BLACK_HEADERS = list(
    (
        PROTOCOL_VERSION_HEADER,
        LINE_LENGTH_HEADER,
        PYTHON_VARIANT_HEADER,
        SKIP_STRING_NORMALIZATION_HEADER,
        FAST_OR_SAFE_HEADER,
        DIFF_HEADER,
    )
)

# --- Response headers -------------------------------------------------------
# handle() attaches this header, carrying the server's black version, to its
# responses.
BLACK_VERSION_HEADER = "X-Black-Version"
49
50
class InvalidVariantHeader(Exception):
    """Signal a malformed ``X-Python-Variant`` request header value.

    Raised by ``parse_python_variant_header``.  The first positional argument is
    a human-readable reason; ``handle`` echoes it back to the client in the body
    of a 400 response.
    """
53
54
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host",
    type=str,
    default="localhost",
    help="Address to bind the server to.",
)
@click.option(
    "--bind-port",
    type=int,
    default=45484,
    help="Port to listen on",
)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # Deliberately no docstring: click would turn it into the --help text.
    #
    # Start the blackd HTTP server on bind_host:bind_port and serve requests
    # until aiohttp's signal handling (handle_signals=True) stops it.
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    banner = (
        f"blackd version {black.__version__} "
        f"listening on {bind_host} port {bind_port}"
    )
    black.out(banner)
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        # Silence aiohttp's own startup message; the banner above replaces it.
        print=None,
    )
67
68
def make_app() -> web.Application:
    """Build the blackd aiohttp application.

    The app exposes a single ``POST /`` route served by ``handle``.  CORS is
    enabled for any origin ("*"): cross-origin clients may send the blackd
    request headers (``BLACK_HEADERS``) plus ``Content-Type``, and all response
    headers are exposed to them.

    Formatting work is offloaded to one ``ProcessPoolExecutor`` shared by all
    requests.  The executor is shut down when the application is cleaned up,
    so its worker processes do not outlive the app.  This matters when many
    apps are created in one process, e.g. by tests.
    """
    app = web.Application()
    executor = ProcessPoolExecutor()

    async def _shutdown_executor(_app: web.Application) -> None:
        # wait=False: do not block the event loop while workers wind down.
        executor.shutdown(wait=False)

    app.on_cleanup.append(_shutdown_executor)

    cors = aiohttp_cors.setup(app)
    resource = cors.add(app.router.add_resource("/"))
    cors.add(
        resource.add_route("POST", partial(handle, executor=executor)),
        {
            "*": aiohttp_cors.ResourceOptions(
                allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*"
            )
        },
    )

    return app
85
86
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the Python source in the request body with black.

    Optional request headers:
      * ``X-Protocol-Version``: must be "1" when present (otherwise 501).
      * ``X-Line-Length``: integer line length; defaults to
        ``black.DEFAULT_LINE_LENGTH``.
      * ``X-Python-Variant``: parsed by ``parse_python_variant_header``.
      * ``X-Skip-String-Normalization``: any non-empty value disables string
        normalization.
      * ``X-Fast-Or-Safe``: "fast" passes ``fast=True`` to
        ``black.format_file_contents``; any other value (or absence) is safe.
      * ``X-Diff``: any non-empty value returns ``black.diff`` of the input and
        the formatted output instead of the formatted code.

    Formatting (and diffing) run in *executor* so the event loop is not blocked.

    Responses, all carrying ``X-Black-Version``:
      * 200: formatted text, in the request's content type and charset.
      * 204: black reported nothing changed.
      * 400: bad header value, undecodable body, or ``black.InvalidInput``.
      * 500: any other exception (logged).
      * 501: unsupported protocol version.
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501,
                headers=headers,
                text="This server only supports protocol version 1",
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(
                status=400, headers=headers, text="Invalid line length header value"
            )

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    headers=headers,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # The header's value is not parsed: any non-empty string enables it.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            string_normalization=not skip_string_normalization,
        )
        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        try:
            req_str = req_bytes.decode(charset)
        except (LookupError, UnicodeDecodeError) as e:
            # An unknown charset (LookupError) or bytes that are invalid in it
            # are client errors.  They must not be reported as server failures.
            return web.Response(
                status=400,
                headers=headers,
                text=f"Cannot decode request body as {charset}: {e}",
            )
        # Naive UTC timestamps; "+0000" is appended by hand in the diff labels.
        then = datetime.utcnow()

        loop = asyncio.get_event_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.utcnow()
            src_name = f"In\t{then} +0000"
            dst_name = f"Out\t{now} +0000"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
161
162
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse an ``X-Python-Variant`` header value.

    ``"pyi"`` selects stub-file mode with no explicit target versions.
    Otherwise *value* is a comma-separated list of versions.  Each entry may
    carry an optional ``py`` prefix and surrounding whitespace: ``3.7``,
    ``py3.7``, ``py37``, ``3``.  A bare major version defaults to the lowest
    supported minor (2.7 or 3.3).

    Returns:
        ``(is_pyi, target_versions)``.

    Raises:
        InvalidVariantHeader: if any entry is empty, has a major version other
            than 2 or 3, is a Python 2 version other than 2.7, or names a
            Python 3 version black does not know.
    """
    if value == "pyi":
        return True, set()

    versions = set()
    for version in value.split(","):
        version = version.strip()
        if version.startswith("py"):
            version = version[len("py") :]
        if not version:
            # e.g. "", "py" or a trailing comma.  Without this check,
            # version[0] below raises IndexError, which escapes the handler
            # as a 500.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            # Dotless form such as "37": first digit is the major version.
            major_str = version[0]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if len(rest) > 0:
                minor = int(rest[0])
                if major == 2 and minor != 7:
                    raise InvalidVariantHeader("minor version must be 7 for Python 2")
            else:
                # Default to lowest supported minor version.
                minor = 7 if major == 2 else 3
            version_str = f"PY{major}{minor}"
            if major == 3 and not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            # Non-numeric parts (ValueError) or an unknown enum member
            # (KeyError) both mean the entry is malformed.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
196
197
def patched_main() -> None:
    """Run the ``main`` click command after process-level setup.

    This is what runs when the module is executed as a script; presumably it
    is also the installed console entry point (TODO confirm in packaging
    metadata).

    ``freeze_support()`` goes first so that the server's ProcessPoolExecutor
    workers can start when blackd is a frozen Windows executable; it has no
    effect otherwise.  ``black.patch_click()`` then applies black's click
    patch before ``main()`` parses the command line.
    """
    freeze_support()
    black.patch_click()
    main()


if __name__ == "__main__":
    patched_main()