git.madduck.net Git - etc/vim.git/blob - blackd.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix type: ignore line breaking when there is a destructuring assignment (#1065)
[etc/vim.git] / blackd.py
1 import asyncio
2 from concurrent.futures import Executor, ProcessPoolExecutor
3 from functools import partial
4 import logging
5 from multiprocessing import freeze_support
6 from typing import Set, Tuple
7
8 from aiohttp import web
9 import aiohttp_cors
10 import black
11 import click
12
13 from _version import version as __version__
14
# Used internally by the test-suite to shut a running server down early.
_stop_signal = asyncio.Event()

# --- Headers read from incoming requests ------------------------------------
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"  # only "1" is served (else 501)
LINE_LENGTH_HEADER = "X-Line-Length"  # parsed with int()
PYTHON_VARIANT_HEADER = "X-Python-Variant"  # "pyi" or comma-separated versions
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"  # "fast" -> fast=True, otherwise safe

# All request headers above, in declaration order. make_app() advertises
# these (plus Content-Type) as allowed CORS request headers.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_STRING_NORMALIZATION_HEADER,
    FAST_OR_SAFE_HEADER,
]

# --- Headers written on outgoing responses ----------------------------------
BLACK_VERSION_HEADER = "X-Black-Version"
35
36
class InvalidVariantHeader(Exception):
    """Signal that an X-Python-Variant header value could not be parsed.

    The first positional argument is a human-readable reason; the request
    handler echoes it back to the client in a 400 response.
    """
39
40
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host", type=str, help="Address to bind the server to.", default="localhost"
)
@click.option("--bind-port", type=int, help="Port to listen on", default=45484)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # NOTE: intentionally no docstring here -- click would surface it as the
    # command's --help text.
    #
    # Start the blackd HTTP server on bind_host:bind_port and run it until
    # aiohttp's run_app returns.

    # INFO level so request-handling exceptions logged by handle() appear.
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    black.out(
        f"blackd version {black.__version__} listening on {bind_host} port {bind_port}"
    )
    # print=None presumably suppresses aiohttp's own startup banner (the
    # black.out line above replaces it) -- see aiohttp run_app docs.
    web.run_app(
        application, host=bind_host, port=bind_port, handle_signals=True, print=None
    )
53
54
def make_app() -> web.Application:
    """Build the aiohttp application that serves black over HTTP.

    One ProcessPoolExecutor is created per application and bound into the
    POST handler for "/", so formatting runs in worker processes. CORS is
    configured under the "*" origin key, allowing every black request header
    plus Content-Type and exposing all response headers.
    """
    application = web.Application()
    pool = ProcessPoolExecutor()
    post_handler = partial(handle, executor=pool)

    cors_defaults = {
        "*": aiohttp_cors.ResourceOptions(
            allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*"
        )
    }
    cors = aiohttp_cors.setup(application)
    root = cors.add(application.router.add_resource("/"))
    cors.add(root.add_route("POST", post_handler), cors_defaults)

    return application
71
72
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the Python source sent in a POST body and return the result.

    Optional request headers select the formatting options:
      * X-Protocol-Version -- only "1" is served (default "1").
      * X-Line-Length -- parsed with int() (default black.DEFAULT_LINE_LENGTH).
      * X-Python-Variant -- "pyi" or a comma-separated version list, see
        parse_python_variant_header().
      * X-Skip-String-Normalization -- any non-empty value disables string
        normalization.
      * X-Fast-Or-Safe -- "fast" passes fast=True; anything else is safe.

    The body is decoded with the request charset (utf8 if none is given) and
    formatted via *executor* so the event loop is not blocked.

    Responses (every one carries the X-Black-Version header):
      * 200 -- the formatted source, echoing the request content type and
        charset.
      * 204 -- black raised NothingChanged.
      * 400 -- invalid header value, undecodable body, or black.InvalidInput.
      * 501 -- unsupported protocol version.
      * 500 -- any other exception (logged with traceback).
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501,
                headers=headers,
                text="This server only supports protocol version 1",
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(
                status=400, headers=headers, text="Invalid line length header value"
            )

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    headers=headers,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # bool() of the raw header string: any non-empty value (even "false"
        # or "0") means "skip string normalization".
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            string_normalization=not skip_string_normalization,
        )
        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        try:
            req_str = req_bytes.decode(charset)
        except (LookupError, UnicodeDecodeError) as e:
            # An unknown charset name (LookupError) or bytes that are invalid
            # in the declared charset are client errors, not server failures.
            return web.Response(
                status=400,
                headers=headers,
                text=f"Could not decode request body: {e}",
            )
        loop = asyncio.get_event_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )
        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
132
133
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse an X-Python-Variant header value.

    The literal value "pyi" yields ``(True, set())``. Any other value is a
    comma-separated list of versions, each optionally prefixed with "py" and
    written with or without a dot ("3.7", "py3.7", "37", "py27", "3"). Each
    component maps to a black.TargetVersion and the result is
    ``(False, versions)``. A missing minor version defaults to 7 for Python 2
    and 3 for Python 3; components after the minor ("3.7.1") are ignored.

    Raises:
        InvalidVariantHeader: if a component is empty (e.g. a trailing comma
            or a bare "py") or not numeric, the major version is not 2 or 3,
            a Python 2 minor is not 7, or black has no TargetVersion for the
            requested Python 3 minor.
    """
    if value == "pyi":
        return True, set()

    versions = set()
    for version in value.split(","):
        if version.startswith("py"):
            version = version[len("py") :]
        if not version:
            # An empty component ("3.7," or just "py") would otherwise crash
            # on version[0] below with an IndexError that is not caught here.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            # Dotless form: first character is the major, remainder the minor.
            major_str = version[0]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if len(rest) > 0:
                minor = int(rest[0])
                if major == 2 and minor != 7:
                    raise InvalidVariantHeader("minor version must be 7 for Python 2")
            else:
                # Default to lowest supported minor version.
                minor = 7 if major == 2 else 3
            version_str = f"PY{major}{minor}"
            if major == 3 and not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
167
168
def patched_main() -> None:
    """Run the blackd command-line interface; invoked when run as a script.

    multiprocessing.freeze_support() is called first, because make_app()
    spawns worker processes and frozen-executable builds need this call at
    the very start of the entry point. Then black.patch_click() is applied
    before the click command main() runs.
    """
    freeze_support()
    black.patch_click()
    main()


if __name__ == "__main__":
    patched_main()