git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Use --no-implicit-optional for type checking (#3220)
[etc/vim.git] / src / blackd / __init__.py
1 import asyncio
2 import logging
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from functools import partial
6 from multiprocessing import freeze_support
7 from typing import Set, Tuple
8
try:
    from aiohttp import web

    from .middlewares import cors
except ImportError as ie:
    # The '[d]' extra is what pulls in aiohttp. CORS support comes from the
    # local .middlewares module above, not from the aiohttp_cors package, so
    # the hint names the dependency that is actually missing.
    raise ImportError(
        f"aiohttp dependency is not installed: {ie}. "
        "Please re-install black with the '[d]' extra install "
        "to obtain aiohttp: `pip install black[d]`"
    ) from None
19
20 import click
21
22 import black
23 from _black_version import version as __version__
24 from black.concurrency import maybe_install_uvloop
25
# This is used internally by tests to shut down the server prematurely
_stop_signal = asyncio.Event()

# Request headers
# Optional; handle() answers 501 unless the value is absent or exactly "1".
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
# Parsed with int(); falls back to black.DEFAULT_LINE_LENGTH when absent.
LINE_LENGTH_HEADER = "X-Line-Length"
# Either "pyi" or a comma-separated list of target versions, as parsed by
# parse_python_variant_header().
PYTHON_VARIANT_HEADER = "X-Python-Variant"
# handle() treats any non-empty value as "skip string normalization".
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
# handle() treats any non-empty value as "disable magic trailing comma".
# (Unlike its siblings, this constant's name has no _HEADER suffix.)
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
# The value "fast" passes fast=True to black.format_file_contents; any other
# value (default "safe") passes fast=False.
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
# handle() treats any non-empty value as "respond with a diff of the changes
# instead of the formatted source".
DIFF_HEADER = "X-Diff"

# Every request header blackd understands; make_app() passes these (plus
# Content-Type) to the CORS middleware as allowed headers.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]

# Response headers
# handle() sets this to the running black version on most of its responses.
BLACK_VERSION_HEADER = "X-Black-Version"
50
51
class InvalidVariantHeader(Exception):
    """Signal an unusable ``X-Python-Variant`` header value.

    The first positional argument is a short, human-readable reason; the
    request handler echoes it back to the client in a 400 response.
    """
54
55
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host", type=str, help="Address to bind the server to.", default="localhost"
)
@click.option("--bind-port", type=int, help="Port to listen on", default=45484)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # Intentionally no docstring: click would print it as the --help text.
    # Start the blackd HTTP server and block until it is stopped.
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    black.out(
        f"blackd version {black.__version__} listening on {bind_host} port {bind_port}"
    )
    # print=None silences aiohttp's own startup banner; the line above
    # replaces it.
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        print=None,
    )
68
69
def make_app() -> web.Application:
    """Build the blackd aiohttp application.

    The app has a CORS middleware that allows every header in
    ``BLACK_HEADERS`` plus ``Content-Type``. It has a single ``POST /``
    route served by ``handle``, which runs formatting in a dedicated
    ``ProcessPoolExecutor``.

    The executor is owned by the returned app and is shut down when the app
    is cleaned up. Without this, its worker processes would outlive the app.
    """
    app = web.Application(
        middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))]
    )
    executor = ProcessPoolExecutor()
    app.add_routes([web.post("/", partial(handle, executor=executor))])

    async def _shutdown_executor(_app: web.Application) -> None:
        # Runs during aiohttp's cleanup phase, after request handling stops;
        # waits for any in-flight formatting jobs, then reaps the workers.
        executor.shutdown(wait=True)

    app.on_cleanup.append(_shutdown_executor)
    return app
77
78
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the Python source sent in the request body.

    Formatting options are read from the request headers listed in
    ``BLACK_HEADERS``. Formatting and diffing are submitted to ``executor``
    via ``run_in_executor`` rather than run inline on the event loop.

    Returns:
        * 200 with the formatted source, or a diff of it when ``X-Diff`` is set;
        * 204 when black raises ``NothingChanged``;
        * 400 for an invalid line-length or variant header, for a body that
          cannot be decoded with the request charset, or when black raises
          ``InvalidInput``;
        * 501 when a protocol version other than "1" is requested;
        * 500 for any other exception, which is also logged.

    ``X-Black-Version`` is set on every response except the 501 and the
    header-validation 400s.
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501, text="This server only supports protocol version 1"
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(status=400, text="Invalid line length header value")

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # Any non-empty header value (even "false" or "0") enables these flags.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_magic_trailing_comma = bool(
            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
        )
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            string_normalization=not skip_string_normalization,
            magic_trailing_comma=not skip_magic_trailing_comma,
        )
        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        try:
            req_str = req_bytes.decode(charset)
        except (LookupError, UnicodeDecodeError) as e:
            # An unknown charset name (LookupError) or a body that is not
            # valid in that charset is the client's fault, so answer 400
            # instead of letting the catch-all below report a 500.
            return web.Response(
                status=400,
                headers=headers,
                text=f"Could not decode request body as {charset}: {e}",
            )
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # switching to an aware datetime would change the diff header text.
        then = datetime.utcnow()

        # We are inside a coroutine, so the running loop is the one to use.
        loop = asyncio.get_running_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.utcnow()
            src_name = f"In\t{then} +0000"
            dst_name = f"Out\t{now} +0000"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
157
158
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse an ``X-Python-Variant`` header value.

    Returns ``(is_pyi, target_versions)``:

    * the exact value ``"pyi"`` gives ``(True, set())``;
    * otherwise ``value`` is a comma-separated list of Python 3 versions. Each
      entry may carry a ``py`` prefix and may be written with or without a
      dot (``3.7``, ``py3.7``, ``37``, ``py37``). Whitespace around an entry
      is ignored. A bare major version (``3`` / ``py3``) means 3.3, and only
      the first component after the major version is used.

    Raises:
        InvalidVariantHeader: if an entry is empty or malformed, names
            Python 2 or a major version other than 3, or names a minor
            version with no matching ``black.TargetVersion`` member.
    """
    if value == "pyi":
        return True, set()

    versions: Set[black.TargetVersion] = set()
    for version in value.split(","):
        version = version.strip()
        if version.startswith("py"):
            version = version[len("py") :]
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            # Undotted form such as "37": the first character is the major
            # version. Slicing (not indexing) keeps an empty entry, e.g. from
            # "py" or a trailing comma, from raising IndexError; int("") then
            # fails with ValueError and is reported as a bad value below.
            major_str = version[:1]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if major == 2:
                # Reject every Python 2 spelling ("2", "py2", "2.7") the same way.
                raise InvalidVariantHeader("Python 2 is not supported")
            # Default to lowest supported minor version.
            minor = int(rest[0]) if rest else 3
            version_str = f"PY{major}{minor}"
            if not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
190
191
def patched_main() -> None:
    """Run process-level setup, then invoke the ``main`` click command.

    Used by the ``__main__`` guard below. It is presumably also the installed
    ``blackd`` console-script target; confirm in the packaging config.
    """
    # Helper from black.concurrency; presumably switches asyncio to uvloop
    # when it is available. See that module for the exact condition.
    maybe_install_uvloop()
    # Per the multiprocessing docs, this only matters for frozen Windows
    # executables. It must run before main() creates the ProcessPoolExecutor.
    freeze_support()
    # Applies black's click adjustments (see black.patch_click) before the
    # command is invoked.
    black.patch_click()
    main()


if __name__ == "__main__":
    patched_main()