git.madduck.net Git - etc/vim.git/blob - blackd.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add `--target-version` option to allow users to choose targeted Python versions ...
[etc/vim.git] / blackd.py
1 import asyncio
2 from concurrent.futures import Executor, ProcessPoolExecutor
3 from functools import partial
4 import logging
5 from multiprocessing import freeze_support
6 from typing import Set, Tuple
7
8 from aiohttp import web
9 import aiohttp_cors
10 import black
11 import click
12
# Module-level event that the test-suite uses to stop the server early.
_stop_signal = asyncio.Event()

# --- Request headers understood by blackd ---------------------------------
# Protocol version requested by the client; only "1" is served.
VERSION_HEADER = "X-Protocol-Version"
# Maximum line length passed to black (integer).
LINE_LENGTH_HEADER = "X-Line-Length"
# Either "pyi" or a comma-separated list of target Python versions.
PYTHON_VARIANT_HEADER = "X-Python-Variant"
# Non-empty value disables string normalization.
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
# Non-empty value disables numeric-underscore normalization.
SKIP_NUMERIC_UNDERSCORE_NORMALIZATION_HEADER = "X-Skip-Numeric-Underscore-Normalization"
# "fast" skips black's safety checks; anything else (default "safe") keeps them.
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"

# Every blackd-specific header, in declaration order. Browser clients must be
# allowed to send these through CORS.
BLACK_HEADERS = [
    VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_NUMERIC_UNDERSCORE_NORMALIZATION_HEADER,
    FAST_OR_SAFE_HEADER,
]
31
32
class InvalidVariantHeader(Exception):
    """Signal that an ``X-Python-Variant`` header value could not be parsed.

    The first positional argument is a short, human-readable reason, suitable
    for echoing back to the client in an error response.
    """
35
36
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host", type=str, help="Address to bind the server to.", default="localhost"
)
@click.option("--bind-port", type=int, help="Port to listen on", default=45484)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    """Run the blackd HTTP server on *bind_host*:*bind_port* until interrupted.

    Logging is configured at INFO level and a startup banner is printed via
    ``black.out``. aiohttp's own startup message is suppressed (``print=None``).
    """
    logging.basicConfig(level=logging.INFO)
    # Build the application before announcing it, as the server is created first.
    application = make_app()
    banner = (
        f"blackd version {black.__version__} "
        f"listening on {bind_host} port {bind_port}"
    )
    black.out(banner)
    web.run_app(
        application, host=bind_host, port=bind_port, handle_signals=True, print=None
    )
49
50
def make_app() -> web.Application:
    """Build the blackd aiohttp application.

    A single ``POST /`` route is registered and wrapped with CORS options under
    the ``"*"`` origin key. The options allow clients to send every blackd
    header plus ``Content-Type``, and expose all response headers.

    Formatting is offloaded to a ``ProcessPoolExecutor`` owned by this app. The
    executor is shut down in an ``on_cleanup`` hook, so its worker processes are
    released deterministically when the app is torn down. Without the hook they
    would be left to garbage collection or interpreter exit.
    """
    app = web.Application()
    executor = ProcessPoolExecutor()

    async def _shutdown_executor(_app: web.Application) -> None:
        # Runs once, when the application is cleaned up after serving stops.
        executor.shutdown(wait=True)

    app.on_cleanup.append(_shutdown_executor)

    cors = aiohttp_cors.setup(app)
    resource = cors.add(app.router.add_resource("/"))
    cors.add(
        resource.add_route("POST", partial(handle, executor=executor)),
        {
            "*": aiohttp_cors.ResourceOptions(
                allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*"
            )
        },
    )

    return app
67
68
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the Python source sent in the body of *request*.

    Formatting options are read from the blackd request headers: protocol
    version, line length, Python variant, string and numeric-underscore
    normalization, and fast/safe mode. The CPU-bound call to
    ``black.format_file_contents`` runs on *executor*, so the event loop is
    not blocked.

    Returns a response with one of these statuses:
      * 200 with the formatted source, using the request's content type and
        charset (default ``utf8``);
      * 204 if the source needs no changes (``black.NothingChanged``);
      * 400 for an invalid header value, a body that cannot be decoded with
        the declared charset, or source black rejects (``black.InvalidInput``);
      * 501 if a protocol version other than ``1`` is requested;
      * 500 for any other exception, which is also logged with a traceback.
    """
    try:
        if request.headers.get(VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501, text="This server only supports protocol version 1"
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(status=400, text="Invalid line length header value")

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # Presence with any non-empty value enables skipping; "" or absence does not.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_numeric_underscore_normalization = bool(
            request.headers.get(SKIP_NUMERIC_UNDERSCORE_NORMALIZATION_HEADER, False)
        )
        # Only the exact value "fast" disables black's safety checks.
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            string_normalization=not skip_string_normalization,
            numeric_underscore_normalization=not skip_numeric_underscore_normalization,
        )
        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        try:
            req_str = req_bytes.decode(charset)
        except (LookupError, UnicodeDecodeError) as e:
            # An unknown charset (LookupError) or bytes that are invalid in it
            # are client errors. Answer 400 rather than reaching the 500 handler.
            return web.Response(
                status=400, text=f"Cannot decode request body as {charset}: {e}"
            )
        loop = asyncio.get_event_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )
        return web.Response(
            content_type=request.content_type, charset=charset, text=formatted_str
        )
    except black.NothingChanged:
        return web.Response(status=204)
    except black.InvalidInput as e:
        return web.Response(status=400, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, text=str(e))
111         req_bytes = await request.content.read()
112         charset = request.charset if request.charset is not None else "utf8"
113         req_str = req_bytes.decode(charset)
114         loop = asyncio.get_event_loop()
115         formatted_str = await loop.run_in_executor(
116             executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
117         )
118         return web.Response(
119             content_type=request.content_type, charset=charset, text=formatted_str
120         )
121     except black.NothingChanged:
122         return web.Response(status=204)
123     except black.InvalidInput as e:
124         return web.Response(status=400, text=str(e))
125     except Exception as e:
126         logging.exception("Exception during handling a request")
127         return web.Response(status=500, text=str(e))
128
129
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the ``X-Python-Variant`` request header.

    *value* is either ``pyi``, meaning the source is a type stub, or a
    comma-separated list of versions such as ``3.7``, ``cpy3.6`` or
    ``pypy3.5``. Whitespace around each list item is ignored.

    A version without a ``cpy``/``pypy`` prefix is treated as CPython. A version
    without a minor part defaults to the lowest supported minor: 2.7 for
    Python 2 and 3.3 for Python 3. A PyPy version with no dedicated
    ``black.TargetVersion`` member maps to the matching CPython member.

    Returns:
        ``(is_pyi, target_versions)``; ``target_versions`` is empty for ``pyi``.

    Raises:
        InvalidVariantHeader: if any item cannot be mapped to a
            ``black.TargetVersion`` member.
    """
    value = value.strip()
    if value == "pyi":
        return True, set()

    versions = set()
    for item in value.split(","):
        # Allow "3.6, 3.7"-style lists with spaces, as in typical HTTP list headers.
        version = item.strip()
        tag = "cpy"
        if version.startswith("cpy"):
            version = version[len("cpy") :]
        elif version.startswith("pypy"):
            tag = "pypy"
            version = version[len("pypy") :]
        major_str, *rest = version.split(".")
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if rest:
                minor = int(rest[0])
                if major == 2 and minor != 7:
                    raise InvalidVariantHeader("minor version must be 7 for Python 2")
            else:
                # Default to lowest supported minor version.
                minor = 7 if major == 2 else 3
            version_str = f"{tag.upper()}{major}{minor}"
            # If PyPy is the same as CPython in this version, black has no
            # separate PYPY member, so use the corresponding CPython one.
            if tag == "pypy" and not hasattr(black.TargetVersion, version_str):
                version_str = f"CPY{major}{minor}"
            if major == 3 and not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError) as e:
            # InvalidVariantHeader raised above is not caught here; it propagates.
            raise InvalidVariantHeader("expected e.g. '3.7', 'pypy3.5'") from e
    return False, versions
167
168
def patched_main() -> None:
    """Entry point used when blackd is launched as a program.

    ``multiprocessing.freeze_support()`` is called first, because the server
    spawns ``ProcessPoolExecutor`` workers. It has no effect unless running as
    a frozen Windows executable. ``black.patch_click()`` is then applied before
    the click command ``main`` parses the command line.
    """
    freeze_support()  # no-op outside a frozen Windows executable
    black.patch_click()
    main()


if __name__ == "__main__":
    # Direct execution (``python blackd.py``) goes through the same entry point.
    patched_main()