git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Update the changelog
[etc/vim.git] / src / blackd / __init__.py
1 import asyncio
2 from concurrent.futures import Executor, ProcessPoolExecutor
3 from datetime import datetime
4 from functools import partial
5 import logging
6 from multiprocessing import freeze_support
7 from typing import Set, Tuple
8
9 from aiohttp import web
10 import aiohttp_cors
11 import black
12 import click
13
14 from _black_version import version as __version__
15
# Set by the test-suite to make the server stop before it otherwise would.
_stop_signal = asyncio.Event()

# ---- Request headers understood by blackd --------------------------------
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
LINE_LENGTH_HEADER = "X-Line-Length"
PYTHON_VARIANT_HEADER = "X-Python-Variant"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
DIFF_HEADER = "X-Diff"

# All of the request headers above, in declaration order.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER, LINE_LENGTH_HEADER, PYTHON_VARIANT_HEADER,
    SKIP_STRING_NORMALIZATION_HEADER, FAST_OR_SAFE_HEADER, DIFF_HEADER,
]

# ---- Response headers sent by blackd -------------------------------------
BLACK_VERSION_HEADER = "X-Black-Version"
38
39
class InvalidVariantHeader(Exception):
    """Signals that an ``X-Python-Variant`` header value could not be parsed.

    The first positional argument is a short human-readable reason, which the
    request handler echoes back to the client.
    """
42
43
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host", type=str, help="Address to bind the server to.", default="localhost"
)
@click.option("--bind-port", type=int, help="Port to listen on", default=45484)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # Intentionally no docstring: click would expose it as the --help text.
    # Configure logging, build the app, announce the address, then serve
    # until aiohttp's signal handling stops the server.
    logging.basicConfig(level=logging.INFO)
    server_app = make_app()
    black.out(
        f"blackd version {black.__version__} listening on {bind_host} port {bind_port}"
    )
    web.run_app(
        server_app,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        # aiohttp's own startup output is disabled; the banner above replaces it.
        print=None,
    )
56
57
def make_app() -> web.Application:
    """Build the blackd aiohttp application.

    The app exposes a single ``POST /`` route served by ``handle``. All
    requests share one ``ProcessPoolExecutor`` created here. CORS is enabled
    for every origin; it allows the blackd request headers plus
    ``Content-Type`` and exposes all response headers.
    """
    app = web.Application()
    pool = ProcessPoolExecutor()
    post_handler = partial(handle, executor=pool)

    cors = aiohttp_cors.setup(app)
    root_resource = cors.add(app.router.add_resource("/"))
    post_route = root_resource.add_route("POST", post_handler)
    any_origin = aiohttp_cors.ResourceOptions(
        allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*"
    )
    cors.add(post_route, {"*": any_origin})

    return app
74
75
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the Python source in the request body with black.

    Formatting options come from the request headers listed in
    ``BLACK_HEADERS``. The CPU-bound formatting, and the diff when one is
    requested, run in ``executor``. Every response carries the
    ``X-Black-Version`` header.

    Responses:
      * 200 -- the formatted source, or a unified diff of the change if the
        ``X-Diff`` header has a non-empty value.
      * 204 -- the source was already formatted (``black.NothingChanged``).
      * 400 -- an invalid header value, a body that cannot be decoded with
        the declared charset, or input black rejects (``black.InvalidInput``).
      * 501 -- a protocol version other than "1" was requested.
      * 500 -- any other exception; it is logged and its message returned.
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501,
                headers=headers,
                text="This server only supports protocol version 1",
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(
                status=400, headers=headers, text="Invalid line length header value"
            )

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    headers=headers,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # NOTE: any non-empty header value (even "false" or "0") enables this,
        # because the raw string is passed through bool().
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            string_normalization=not skip_string_normalization,
        )

        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        try:
            req_str = req_bytes.decode(charset)
        except (LookupError, UnicodeDecodeError) as e:
            # An unknown charset name (LookupError) or bytes that are invalid
            # in the declared charset are client errors, not server failures.
            # Previously they fell through to the generic handler below and
            # produced a logged 500.
            return web.Response(
                status=400,
                headers=headers,
                text=f"Invalid request body encoding: {e}",
            )
        then = datetime.utcnow()

        loop = asyncio.get_event_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.utcnow()
            # Naive UTC timestamps, so the offset is spelled out explicitly.
            src_name = f"In\t{then} +0000"
            dst_name = f"Out\t{now} +0000"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
150
151
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the ``X-Python-Variant`` request header.

    ``"pyi"`` selects stub-file mode with no explicit target versions.
    Otherwise the value is a comma-separated list of versions such as
    ``"3.7"``, ``"py36"``, ``"py3.8"`` or ``"3"``. Whitespace around each
    entry is ignored. A bare major version defaults to minor 7 for Python 2
    and minor 3 for Python 3.

    Returns:
        ``(is_pyi, target_versions)``.

    Raises:
        InvalidVariantHeader: an entry is empty or malformed, or its major
            version is not 2 or 3. Also raised for a Python 2 minor version
            other than 7, or a version missing from ``black.TargetVersion``.
    """
    if value == "pyi":
        return True, set()

    versions = set()
    for version in value.split(","):
        # Tolerate whitespace around entries, e.g. "py36, py37".
        version = version.strip()
        if version.startswith("py"):
            version = version[len("py") :]
        if not version:
            # An empty entry ("", "py", or a stray comma) would otherwise raise
            # IndexError at ``version[0]``. The caller only catches
            # InvalidVariantHeader, so that became a 500 instead of a 400.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            # Dotless form such as "37" or "310": the first digit is the major
            # version and everything after it is the minor version.
            major_str = version[0]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if len(rest) > 0:
                minor = int(rest[0])
                if major == 2 and minor != 7:
                    raise InvalidVariantHeader("minor version must be 7 for Python 2")
            else:
                # Default to lowest supported minor version.
                minor = 7 if major == 2 else 3
            version_str = f"PY{major}{minor}"
            if major == 3 and not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
185
186
def patched_main() -> None:
    """Start blackd after the required process-level setup.

    This is the function invoked by the ``__main__`` guard at the bottom of
    the module. It may also be registered as a console-script entry point
    (TODO: confirm in the packaging config).
    """
    # Must run first: make_app() creates a ProcessPoolExecutor, which relies on
    # multiprocessing, and freeze_support() is needed for frozen executables.
    freeze_support()
    # Apply black's click patches before the click command parses argv.
    black.patch_click()
    main()
191
192
# Allow running the module directly, e.g. ``python -m blackd``.
if __name__ == "__main__":
    patched_main()