git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Bump actions/checkout from 3 to 4 (#3883)
[etc/vim.git] / src / blackd / __init__.py
1 import asyncio
2 import logging
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime, timezone
5 from functools import partial
6 from multiprocessing import freeze_support
7 from typing import Set, Tuple
8
9 try:
10     from aiohttp import web
11
12     from .middlewares import cors
13 except ImportError as ie:
14     raise ImportError(
15         f"aiohttp dependency is not installed: {ie}. "
16         + "Please re-install black with the '[d]' extra install "
17         + "to obtain aiohttp_cors: `pip install black[d]`"
18     ) from None
19
20 import click
21
22 import black
23 from _black_version import version as __version__
24 from black.concurrency import maybe_install_uvloop
25
# This is used internally by tests to shut down the server prematurely
_stop_signal = asyncio.Event()

# Request headers understood by the ``handle`` coroutine.
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"  # only "1" is accepted
LINE_LENGTH_HEADER = "X-Line-Length"  # parsed with int()
PYTHON_VARIANT_HEADER = "X-Python-Variant"  # "pyi" or comma-separated versions
SKIP_SOURCE_FIRST_LINE = "X-Skip-Source-First-Line"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
PREVIEW = "X-Preview"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"  # "fast" or "safe" (the default)
DIFF_HEADER = "X-Diff"  # respond with a diff instead of the formatted code

# All of the request headers above, in a fixed order; make_app passes them to
# the cors middleware as allowed headers.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_SOURCE_FIRST_LINE,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    PREVIEW,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]

# Response header carrying the server's Black version.
BLACK_VERSION_HEADER = "X-Black-Version"
54
55
class InvalidVariantHeader(Exception):
    """Signal that an ``X-Python-Variant`` header value could not be parsed.

    The first positional argument is a human-readable reason; the request
    handler includes ``args[0]`` in its 400 response text.
    """
58
59
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host",
    type=str,
    help="Address to bind the server to.",
    default="localhost",
    show_default=True,
)
@click.option(
    "--bind-port",
    type=int,
    help="Port to listen on",
    default=45484,
    show_default=True,
)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # No docstring on purpose: click would show it as the command's --help text.
    # Configure INFO-level logging, build the application, print a banner with
    # the listening address, then hand control to aiohttp's run_app.
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    banner = (
        f"blackd version {black.__version__} "
        f"listening on {bind_host} port {bind_port}"
    )
    black.out(banner)
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        # None silences aiohttp's own startup message; the banner replaces it.
        print=None,
    )
78
79
def make_app() -> web.Application:
    """Build the blackd aiohttp application.

    The application serves ``POST /`` with ``handle`` and installs the local
    ``cors`` middleware, allowing every header in ``BLACK_HEADERS`` plus
    ``Content-Type``. All requests share one ``ProcessPoolExecutor``. That
    executor is shut down when the application is cleaned up, so its worker
    processes do not outlive the app.
    """
    app = web.Application(
        middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))]
    )
    executor = ProcessPoolExecutor()

    async def _shutdown_executor(_app: web.Application) -> None:
        # Release the worker pool when the app is torn down (e.g. each test app
        # used to leak its own pool). wait=False avoids blocking the event loop;
        # any still-pending futures finish before the resources are freed.
        executor.shutdown(wait=False)

    app.on_cleanup.append(_shutdown_executor)
    app.add_routes([web.post("/", partial(handle, executor=executor))])
    return app
87
88
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the Python source sent in the body of *request* with Black.

    Formatting options are taken from the ``X-*`` request headers listed in
    ``BLACK_HEADERS``. The formatting itself (and the diff, when requested)
    runs in *executor* so that the event loop is not blocked.

    Every response carries the ``X-Black-Version`` header. Status codes:

    * 200 -- the formatted source, or a diff of it when ``X-Diff`` is set;
    * 204 -- the source was already formatted (nothing changed);
    * 400 -- an invalid header value, a body that cannot be decoded with the
      request charset, or source that Black rejects as invalid input;
    * 501 -- an ``X-Protocol-Version`` other than ``"1"``;
    * 500 -- any other error (logged with its traceback).
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501,
                headers=headers,
                text="This server only supports protocol version 1",
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(
                status=400, headers=headers, text="Invalid line length header value"
            )

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    headers=headers,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # NOTE: the boolean option headers are switched on by any non-empty
        # value -- bool("false") and bool("0") are both True.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_magic_trailing_comma = bool(
            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
        )
        skip_source_first_line = bool(
            request.headers.get(SKIP_SOURCE_FIRST_LINE, False)
        )
        preview = bool(request.headers.get(PREVIEW, False))
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.Mode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            skip_source_first_line=skip_source_first_line,
            string_normalization=not skip_string_normalization,
            magic_trailing_comma=not skip_magic_trailing_comma,
            preview=preview,
        )

        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        try:
            req_str = req_bytes.decode(charset)
        except (LookupError, UnicodeDecodeError):
            # An unknown charset, or bytes that are invalid in it, is a client
            # error rather than an internal server failure.
            return web.Response(
                status=400,
                headers=headers,
                text=f"Request body could not be decoded as {charset}",
            )
        then = datetime.now(timezone.utc)

        header = ""
        if skip_source_first_line:
            # The first line, including its line terminator, is passed through
            # unformatted. A body without any newline consists of a single
            # (first) line, so all of it is passed through.
            first_newline_position = req_str.find("\n") + 1
            if first_newline_position == 0:
                first_newline_position = len(req_str)
            header = req_str[:first_newline_position]
            req_str = req_str[first_newline_position:]

        # Formatting is CPU-bound, so run it in the executor rather than on the
        # event loop. The same loop is reused below for the optional diff.
        loop = asyncio.get_running_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Preserve CRLF line endings: if the first line of the source ends in
        # "\r\n", convert the formatted output's "\n" to "\r\n".
        nl = req_str.find("\n")
        if nl > 0 and req_str[nl - 1] == "\r":
            formatted_str = formatted_str.replace("\n", "\r\n")
            # If, after swapping line endings, nothing changed, then say so
            if formatted_str == req_str:
                raise black.NothingChanged

        # Put the source first line back
        req_str = header + req_str
        formatted_str = header + formatted_str

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.now(timezone.utc)
            src_name = f"In\t{then}"
            dst_name = f"Out\t{now}"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        # Top-level boundary: log the traceback and report a server error.
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
191
192
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the ``X-Python-Variant`` request header.

    *value* is either exactly ``"pyi"`` (format as a type stub) or a
    comma-separated list of Python versions. Examples: ``"3.7"``,
    ``"py3.8"``, ``"py39"``, ``"3"``. Whitespace around each entry is
    ignored. A bare major version ``"3"`` defaults to minor version 3
    (i.e. 3.3).

    Returns:
        ``(is_pyi, target_versions)``; for ``"pyi"`` the set is empty.

    Raises:
        InvalidVariantHeader: if an entry is empty or malformed, refers to
            Python 2, or names a 3.x release with no matching
            ``black.TargetVersion`` member.
    """
    if value == "pyi":
        return True, set()

    versions: Set[black.TargetVersion] = set()
    for raw_version in value.split(","):
        version = raw_version.strip()
        if version.startswith("py"):
            version = version[len("py") :]
        if not version:
            # e.g. a trailing comma ("3.7,") or a bare "py". Indexing version[0]
            # below would otherwise raise IndexError and surface as HTTP 500.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            # Dot-less form such as "38" or "310": first digit is the major.
            major_str = version[0]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if major == 2:
                # Reject every Python 2 spelling ("2", "27", "2.7") the same way.
                raise InvalidVariantHeader("Python 2 is not supported")
            # Default to lowest supported minor version.
            minor = int(rest[0]) if rest else 3
            version_str = f"PY{major}{minor}"
            if not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
224
225
def patched_main() -> None:
    """Prepare the process, then run the ``main`` click command.

    ``maybe_install_uvloop`` is called first, then multiprocessing's
    ``freeze_support``. Per the stdlib docs, ``freeze_support`` enables
    multiprocessing in frozen Windows executables, which matters here because
    formatting runs in a process pool.
    """
    maybe_install_uvloop()
    freeze_support()
    main()


# Start the server when this module is executed directly as a script.
if __name__ == "__main__":
    patched_main()