git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

blib2to3: support unparenthesized walruses in more places (#2447)
[etc/vim.git] / src / blackd / __init__.py
1 import asyncio
2 import logging
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from functools import partial
6 from multiprocessing import freeze_support
7 from typing import Set, Tuple
8
9 try:
10     from aiohttp import web
11     import aiohttp_cors
12 except ImportError as ie:
13     raise ImportError(
14         f"aiohttp dependency is not installed: {ie}. "
15         + "Please re-install black with the '[d]' extra install "
16         + "to obtain aiohttp_cors: `pip install black[d]`"
17     ) from None
18
19 import black
20 from black.concurrency import maybe_install_uvloop
21 import click
22
23 from _black_version import version as __version__
24
# Event object kept at module level; tests use it internally to shut the
# server down prematurely.
_stop_signal = asyncio.Event()

# Names of the request headers that blackd reads from an incoming request.
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
LINE_LENGTH_HEADER = "X-Line-Length"
PYTHON_VARIANT_HEADER = "X-Python-Variant"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
DIFF_HEADER = "X-Diff"

# Every request header listed above, collected so the CORS configuration can
# allow all of them in one go.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]

# Name of the response header that carries the version string.
BLACK_VERSION_HEADER = "X-Black-Version"
50
class InvalidVariantHeader(Exception):
    """Raised when the Python-variant header value cannot be parsed.

    The first positional argument is a short human-readable explanation.
    """
53
54
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host",
    type=str,
    help="Address to bind the server to.",
    default="localhost",
)
@click.option(
    "--bind-port",
    type=int,
    help="Port to listen on",
    default=45484,
)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # Command-line entry point: configure logging, build the aiohttp
    # application, announce the listening address and run the server.
    #
    # This is intentionally a comment and not a docstring: click uses a
    # command function's docstring as its help text, so adding one would
    # change what `--help` prints.
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    black.out(
        f"blackd version {black.__version__} "
        f"listening on {bind_host} port {bind_port}"
    )
    # print=None silences aiohttp's own startup banner; the line above is the
    # only announcement.
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        print=None,
    )
67
68
def make_app() -> web.Application:
    """Build the aiohttp application that serves formatting requests.

    A single ``POST /`` route is registered. It dispatches to ``handle`` with
    a freshly created ``ProcessPoolExecutor`` bound as its ``executor``
    argument. CORS is configured for every origin ("*"), allowing the blackd
    request headers plus ``Content-Type`` and exposing all response headers.
    """
    application = web.Application()
    pool = ProcessPoolExecutor()

    cors_config = aiohttp_cors.setup(application)
    root = cors_config.add(application.router.add_resource("/"))

    post_route = root.add_route("POST", partial(handle, executor=pool))
    options_for_all_origins = aiohttp_cors.ResourceOptions(
        allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*"
    )
    cors_config.add(post_route, {"*": options_for_all_origins})

    return application
85
86
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the body of *request* and return the result as a response.

    Formatting options are read from the blackd request headers. The work
    itself is handed to *executor* through the event loop's
    ``run_in_executor``.

    Responses:
      * 501 if a protocol version other than "1" is requested;
      * 400 if the line-length or Python-variant header is invalid, or if
        ``black.InvalidInput`` is raised;
      * 204 if ``black.NothingChanged`` is raised;
      * 500 (after logging the traceback) for any other exception;
      * otherwise the formatted source, or a diff of it when the diff header
        is present with a non-empty value.

    Every response except the early 501/400 header-validation ones carries
    the version header.
    """
    response_headers = {BLACK_VERSION_HEADER: __version__}
    try:
        req_headers = request.headers

        if req_headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501, text="This server only supports protocol version 1"
            )

        try:
            line_length = int(
                req_headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(status=400, text="Invalid line length header value")

        # Defaults used when no variant header was sent.
        pyi = False
        versions = set()
        if PYTHON_VARIANT_HEADER in req_headers:
            try:
                pyi, versions = parse_python_variant_header(
                    req_headers[PYTHON_VARIANT_HEADER]
                )
            except InvalidVariantHeader as err:
                return web.Response(
                    status=400,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {err.args[0]}",
                )

        # These flags are switched on by the mere presence of the header with
        # any non-empty value; the value itself is not interpreted.
        keep_string_quotes = bool(
            req_headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        ignore_magic_comma = bool(req_headers.get(SKIP_MAGIC_TRAILING_COMMA, False))
        fast = req_headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"

        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            string_normalization=not keep_string_quotes,
            magic_trailing_comma=not ignore_magic_comma,
        )

        body = await request.content.read()
        charset = request.charset
        if charset is None:
            charset = "utf8"
        source = body.decode(charset)
        started_at = datetime.utcnow()

        format_job = partial(
            black.format_file_contents, source, fast=fast, mode=mode
        )
        output = await asyncio.get_event_loop().run_in_executor(executor, format_job)

        # When the diff header is set, replace the formatted source with a
        # diff between the request body and the formatted result.
        if bool(req_headers.get(DIFF_HEADER, False)):
            finished_at = datetime.utcnow()
            diff_job = partial(
                black.diff,
                source,
                output,
                f"In\t{started_at} +0000",
                f"Out\t{finished_at} +0000",
            )
            output = await asyncio.get_event_loop().run_in_executor(
                executor, diff_job
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=response_headers,
            text=output,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=response_headers)
    except black.InvalidInput as err:
        return web.Response(status=400, headers=response_headers, text=str(err))
    except Exception as err:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=response_headers, text=str(err))
165
166
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the Python-variant request header.

    *value* is either the literal ``"pyi"`` or a comma-separated list of
    versions. Each version may carry an optional ``py`` prefix and is written
    either dotted (``3.7``) or compact (``37``). A bare major version selects
    the lowest supported minor: 7 for Python 2, 3 for Python 3.

    Returns a ``(is_pyi, target_versions)`` tuple. For ``"pyi"`` the set is
    empty.

    Raises:
        InvalidVariantHeader: if any entry is empty, is not numeric, names a
            major version other than 2 or 3, a Python 2 minor other than 7,
            or a Python 3 minor with no matching ``black.TargetVersion``.
    """
    if value == "pyi":
        return True, set()

    versions = set()
    for version in value.split(","):
        if version.startswith("py"):
            version = version[len("py") :]
        if not version:
            # An empty entry (value "", "py", or a stray comma as in "3.7,")
            # would otherwise reach ``version[0]`` below and escape as an
            # IndexError instead of the documented InvalidVariantHeader.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            # Compact form: the first character is the major version, the
            # remainder (if any) is the minor version.
            major_str = version[0]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if len(rest) > 0:
                minor = int(rest[0])
                if major == 2 and minor != 7:
                    raise InvalidVariantHeader(
                        "minor version must be 7 for Python 2"
                    )
            else:
                # Default to lowest supported minor version.
                minor = 7 if major == 2 else 3
            version_str = f"PY{major}{minor}"
            if major == 3 and not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            # ``from None``: the int()/enum lookup failure is an
            # implementation detail, not useful context for the caller.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
200
201
def patched_main() -> None:
    """Run the set-up hooks and then start the ``main`` command.

    Calls, in order: ``maybe_install_uvloop()``, ``freeze_support()`` from
    ``multiprocessing``, ``black.patch_click()``, and finally ``main()``.
    """
    # The three preparation steps must all run before the click command.
    maybe_install_uvloop()
    freeze_support()
    black.patch_click()

    main()
207
208
209 if __name__ == "__main__":
210     patched_main()