git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

6bbc7c520866421feefef37befca3e3dd224a11a
[etc/vim.git] / src / blackd / __init__.py
import asyncio
import logging
from concurrent.futures import Executor, ProcessPoolExecutor
from datetime import datetime, timezone
from functools import partial
from multiprocessing import freeze_support
from typing import Set, Tuple
8
9 try:
10     from aiohttp import web
11
12     from .middlewares import cors
13 except ImportError as ie:
14     raise ImportError(
15         f"aiohttp dependency is not installed: {ie}. "
16         + "Please re-install black with the '[d]' extra install "
17         + "to obtain aiohttp_cors: `pip install black[d]`"
18     ) from None
19
20 import click
21
22 import black
23 from _black_version import version as __version__
24 from black.concurrency import maybe_install_uvloop
25
# Set by the test-suite to stop the server before it would otherwise exit.
_stop_signal = asyncio.Event()

# --- Response headers ------------------------------------------------------
# Sent back with every formatting response; its value is the black version.
BLACK_VERSION_HEADER = "X-Black-Version"

# --- Request headers -------------------------------------------------------
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
LINE_LENGTH_HEADER = "X-Line-Length"
PYTHON_VARIANT_HEADER = "X-Python-Variant"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
PREVIEW = "X-Preview"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
DIFF_HEADER = "X-Diff"

# All request headers above, in declaration order. make_app() whitelists them
# (plus Content-Type) in its CORS middleware.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    PREVIEW,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]
52
53
class InvalidVariantHeader(Exception):
    """Signal that an ``X-Python-Variant`` header value could not be parsed.

    ``args[0]`` carries a human-readable reason; the request handler echoes it
    back to the client in a 400 response.
    """
56
57
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host",
    type=str,
    default="localhost",
    help="Address to bind the server to.",
)
@click.option("--bind-port", type=int, default=45484, help="Port to listen on")
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # Command-line entry point: set up logging, build the aiohttp application,
    # announce the listening address and hand control to aiohttp's run_app.
    # (Intentionally no docstring: click would use it as the --help text.)
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    black.out(
        f"blackd version {black.__version__} "
        f"listening on {bind_host} port {bind_port}"
    )
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        print=None,
    )
70
71
def make_app() -> web.Application:
    """Build the blackd aiohttp application.

    The application routes ``POST /`` to :func:`handle`, bound to a
    ``ProcessPoolExecutor`` in which ``handle`` runs black via
    ``run_in_executor``. The ``cors`` middleware is configured to allow every
    header in ``BLACK_HEADERS`` plus ``Content-Type``.

    Returns:
        The configured application. The process pool it owns is shut down
        when the application is cleaned up.
    """
    app = web.Application(
        middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))]
    )
    executor = ProcessPoolExecutor()

    async def _shutdown_executor(_app: web.Application) -> None:
        # Release the worker processes when aiohttp tears the app down (server
        # shutdown, end of a test). Otherwise every make_app() call leaves a
        # pool alive until interpreter exit.
        executor.shutdown(wait=True)

    app.add_routes([web.post("/", partial(handle, executor=executor))])
    app.on_cleanup.append(_shutdown_executor)
    return app
79
80
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the Python source in the request body and return the result.

    Formatting options are read from the blackd request headers (see
    ``BLACK_HEADERS``); black itself runs in ``executor`` so the event loop is
    not blocked while formatting.

    Args:
        request: the incoming request; its body is the source to format.
        executor: executor used for ``black.format_file_contents`` and
            ``black.diff``.

    Returns:
        * 200 with the formatted source, or a unified diff when ``X-Diff`` is
          set to a non-empty value;
        * 204 when the source is already formatted;
        * 400 for an invalid ``X-Line-Length``/``X-Python-Variant`` header, a
          body that cannot be decoded with its charset, or when black raises
          ``InvalidInput``;
        * 501 for an unsupported ``X-Protocol-Version``;
        * 500 for any other exception (which is also logged).
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501, text="This server only supports protocol version 1"
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(status=400, text="Invalid line length header value")

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # Flag-style headers are switched on by any non-empty value.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_magic_trailing_comma = bool(
            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
        )
        preview = bool(request.headers.get(PREVIEW, False))
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            string_normalization=not skip_string_normalization,
            magic_trailing_comma=not skip_magic_trailing_comma,
            preview=preview,
        )
        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        try:
            req_str = req_bytes.decode(charset)
        except (LookupError, UnicodeDecodeError) as e:
            # Unknown charset or bytes not valid in it: the client's fault,
            # not a server error.
            return web.Response(
                status=400, headers=headers, text=f"Cannot decode request body: {e}"
            )
        # Naive UTC timestamp (same rendering as the deprecated utcnow()),
        # used in the diff header below.
        then = datetime.now(timezone.utc).replace(tzinfo=None)

        # Inside a coroutine the running loop is the one to use.
        loop = asyncio.get_running_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Preserve CRLF line endings if the first line of the input uses them.
        # Guard the index: with no newline find() returns -1 (and req_str[-2]
        # raises IndexError on a one-character body); with a leading newline
        # find() returns 0 and [-1] would inspect the last character instead.
        first_newline = req_str.find("\n")
        if first_newline > 0 and req_str[first_newline - 1] == "\r":
            formatted_str = formatted_str.replace("\n", "\r\n")
            # If, after swapping line endings, nothing changed, then say so
            if formatted_str == req_str:
                raise black.NothingChanged

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.now(timezone.utc).replace(tzinfo=None)
            src_name = f"In\t{then} +0000"
            dst_name = f"Out\t{now} +0000"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
168
169
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the ``X-Python-Variant`` request header.

    ``value`` is either ``"pyi"`` (format as a stub file) or a comma-separated
    list of versions such as ``"3.7"``, ``"py3.8"``, ``"py39"`` or ``"3"``.
    Whitespace around each item is ignored. A bare major version defaults to
    minor version 3 (``"3"`` -> ``PY33``).

    Returns:
        ``(is_pyi, target_versions)``; ``target_versions`` is empty for
        ``"pyi"``.

    Raises:
        InvalidVariantHeader: if an item is empty or malformed, names a major
            version other than 3, or names a version that
            ``black.TargetVersion`` does not define.
    """
    if value == "pyi":
        return True, set()
    versions: Set[black.TargetVersion] = set()
    for item in value.split(","):
        # Tolerate the common "py37, py38" spelling.
        version = item.strip()
        if version.startswith("py"):
            version = version[len("py") :]
        if not version:
            # e.g. a trailing comma ("py37,") or a bare "py". Previously this
            # crashed on version[0] with IndexError and surfaced as a 500.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            # Dot-less form: first digit is the major, the remainder the minor.
            major_str = version[0]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if major == 2:
                # Report this for every 2.x spelling, including a bare "2".
                raise InvalidVariantHeader("Python 2 is not supported")
            # Default to lowest supported minor version.
            minor = int(rest[0]) if rest else 3
            version_str = f"PY{major}{minor}"
            if not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
201
202
def patched_main() -> None:
    """Run the ``main`` CLI after the process-level setup it relies on.

    This is what runs when the module is executed as a script.
    """
    # Order matters: uvloop (if available), then multiprocessing
    # freeze support, then black's click patching -- exactly as before.
    setup_steps = (maybe_install_uvloop, freeze_support, black.patch_click)
    for step in setup_steps:
        step()
    main()


if __name__ == "__main__":
    patched_main()