git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

fix typing issue around lru_cache arguments (#2098)
[etc/vim.git] / src / blackd / __init__.py
1 import asyncio
2 import logging
3 import sys
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from datetime import datetime
6 from functools import partial
7 from multiprocessing import freeze_support
8 from typing import Set, Tuple
9
10 try:
11     from aiohttp import web
12     import aiohttp_cors
13 except ImportError as ie:
14     print(
15         f"aiohttp dependency is not installed: {ie}. "
16         + "Please re-install black with the '[d]' extra install "
17         + "to obtain aiohttp_cors: `pip install black[d]`",
18         file=sys.stderr,
19     )
20     sys.exit(-1)
21
22 import black
23 import click
24
25 from _black_version import version as __version__
26
# This is used internally by tests to shut down the server prematurely
_stop_signal = asyncio.Event()

# Request headers
# Only the value "1" is accepted; any other value gets a 501 response.
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
# Parsed with int(); falls back to black.DEFAULT_LINE_LENGTH when absent.
LINE_LENGTH_HEADER = "X-Line-Length"
# Either "pyi" or a comma-separated list of versions (see
# parse_python_variant_header).
PYTHON_VARIANT_HEADER = "X-Python-Variant"
# Boolean-by-presence: any non-empty value disables string normalization.
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
# Boolean-by-presence: any non-empty value disables the magic trailing comma.
# (Unlike its siblings, this constant has no ``_HEADER`` suffix.)
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
# The value "fast" enables fast mode; anything else (default "safe") does not.
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
# Boolean-by-presence: any non-empty value returns a diff instead of the
# formatted source.
DIFF_HEADER = "X-Diff"

# Every request header blackd understands; make_app allows these (plus
# Content-Type) in its CORS configuration.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]

# Response headers
# Carries the running Black version; handle() attaches it to its responses.
BLACK_VERSION_HEADER = "X-Black-Version"
51
52
class InvalidVariantHeader(Exception):
    """Raised by parse_python_variant_header for a malformed X-Python-Variant.

    The first positional argument is a human-readable reason, which the
    request handler echoes back to the client in a 400 response.
    """
55
56
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host", type=str, help="Address to bind the server to.", default="localhost"
)
@click.option("--bind-port", type=int, help="Port to listen on", default=45484)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # Deliberately no docstring: click would render it as the --help text.
    # Set up INFO logging, build the application, announce the listening
    # address, then hand control to aiohttp's web.run_app.
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    black.out(
        f"blackd version {black.__version__} listening on {bind_host} port {bind_port}"
    )
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        print=None,  # suppress aiohttp's own startup banner; we printed ours
    )
69
70
def make_app() -> web.Application:
    """Create the blackd aiohttp application.

    A single POST route on ``/`` is served by :func:`handle`, bound to one
    ``ProcessPoolExecutor`` shared by every request. CORS is configured for
    the ``"*"`` origin key, allowing the Black request headers plus
    ``Content-Type`` and exposing all response headers.
    """
    application = web.Application()
    pool = ProcessPoolExecutor()

    cors = aiohttp_cors.setup(application)
    root = cors.add(application.router.add_resource("/"))
    format_route = root.add_route("POST", partial(handle, executor=pool))

    allowed = (*BLACK_HEADERS, "Content-Type")
    cors_options = {
        "*": aiohttp_cors.ResourceOptions(allow_headers=allowed, expose_headers="*")
    }
    cors.add(format_route, cors_options)

    return application
87
88
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the Python source in the request body and respond with it.

    Formatting options come from the ``X-*`` request headers (see the header
    constants near the top of this module). The CPU-bound formatting and
    diffing run on ``executor`` via ``run_in_executor`` so the event loop is
    not blocked by them.

    Args:
        request: The incoming request; its body is the source to format,
            decoded with the request charset (``utf8`` if none is given).
        executor: Executor that runs ``black.format_file_contents`` and,
            when a diff is requested, ``black.diff``.

    Returns:
        A response that always carries the ``X-Black-Version`` header, with
        status:

        * 200 and the formatted source (or a diff when ``X-Diff`` is
          non-empty), encoded with the request's charset;
        * 204 when Black raises ``NothingChanged``;
        * 400 for an invalid line-length or Python-variant header, a body
          that cannot be decoded, or ``black.InvalidInput``;
        * 501 for an unsupported protocol version;
        * 500 for any other exception, which is also logged.
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501,
                headers=headers,
                text="This server only supports protocol version 1",
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(
                status=400, headers=headers, text="Invalid line length header value"
            )

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    headers=headers,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # These flags are switched on by any non-empty header value; the value
        # itself is not interpreted (so even "false" enables them).
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_magic_trailing_comma = bool(
            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
        )
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            string_normalization=not skip_string_normalization,
            magic_trailing_comma=not skip_magic_trailing_comma,
        )

        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        try:
            req_str = req_bytes.decode(charset)
        except (LookupError, UnicodeDecodeError) as e:
            # An unknown charset or a body that does not decode with it is a
            # client error; report 400 instead of falling through to the
            # logged 500 handler below.
            return web.Response(
                status=400,
                headers=headers,
                text=f"Unable to decode request body as {charset}: {e}",
            )
        then = datetime.utcnow()

        loop = asyncio.get_event_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.utcnow()
            src_name = f"In\t{then} +0000"
            dst_name = f"Out\t{now} +0000"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
167
168
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the ``X-Python-Variant`` request header.

    ``value`` is either ``"pyi"`` (stub-file mode with no explicit target
    versions) or a comma-separated list of versions such as ``"3.7"``,
    ``"py3.8"``, ``"py36"`` or ``"2"``. Whitespace around each entry is
    ignored. A version without a minor part defaults to the lowest minor this
    function accepts: 2.7 for Python 2, 3.3 for Python 3.

    Returns:
        ``(is_pyi, target_versions)``.

    Raises:
        InvalidVariantHeader: if an entry is empty, is not numeric, has a
            major version other than 2 or 3, is a Python 2 version other than
            2.7, or has no matching ``black.TargetVersion`` member.
    """
    if value == "pyi":
        return True, set()

    versions: Set[black.TargetVersion] = set()
    for version in value.split(","):
        version = version.strip()
        if version.startswith("py"):
            version = version[len("py") :]
        if not version:
            # Empty entry (empty header, bare "py", trailing comma): without
            # this check version[0] below would raise an uncaught IndexError.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            # Dotless form such as "36": first digit is the major version.
            major_str = version[0]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if len(rest) > 0:
                minor = int(rest[0])
                if major == 2 and minor != 7:
                    raise InvalidVariantHeader(
                        "minor version must be 7 for Python 2"
                    )
            else:
                # Default to lowest supported minor version.
                minor = 7 if major == 2 else 3
            version_str = f"PY{major}{minor}"
            if major == 3 and not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
    return False, versions
202
203
def patched_main() -> None:
    """Console entry point for blackd.

    ``freeze_support()`` is called before anything else, as the
    multiprocessing documentation asks of programs that may be frozen; the
    server relies on multiprocessing through the ``ProcessPoolExecutor`` that
    ``make_app`` creates. Black's click patches (``black.patch_click``) are
    then applied before the ``main`` click command is invoked.
    """
    freeze_support()
    black.patch_click()
    main()


# Allow running this module directly as a script.
if __name__ == "__main__":
    patched_main()