git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Migrate mypy config to pyproject.toml (#3936)
[etc/vim.git] / src / blackd / __init__.py
1 import asyncio
2 import logging
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime, timezone
5 from functools import partial
6 from multiprocessing import freeze_support
7 from typing import Set, Tuple
8
9 try:
10     from aiohttp import web
11
12     from .middlewares import cors
13 except ImportError as ie:
14     raise ImportError(
15         f"aiohttp dependency is not installed: {ie}. "
16         + "Please re-install black with the '[d]' extra install "
17         + "to obtain aiohttp_cors: `pip install black[d]`"
18     ) from None
19
20 import click
21
22 import black
23 from _black_version import version as __version__
24 from black.concurrency import maybe_install_uvloop
25
# This is used internally by tests to shut down the server prematurely
_stop_signal = asyncio.Event()

# Request headers
# Only protocol version "1" is accepted; handle() answers anything else with 501.
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
# Integer line length; handle() falls back to black.DEFAULT_LINE_LENGTH.
LINE_LENGTH_HEADER = "X-Line-Length"
# "pyi" or a comma-separated list of versions; see parse_python_variant_header().
PYTHON_VARIANT_HEADER = "X-Python-Variant"
# The following four are boolean flags: handle() enables the option when the
# header is present with any non-empty value.
SKIP_SOURCE_FIRST_LINE = "X-Skip-Source-First-Line"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
PREVIEW = "X-Preview"
# The value "fast" makes handle() pass fast=True to black.format_file_contents;
# any other value (default "safe") passes fast=False.
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
# When set to a non-empty value, handle() returns a diff instead of the
# formatted source.
DIFF_HEADER = "X-Diff"

# Every request header above. make_app() passes these, plus "Content-Type",
# to the CORS middleware as the allowed headers.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_SOURCE_FIRST_LINE,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    PREVIEW,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]

# Response headers
# Reports the black version. handle() attaches it to its formatting responses,
# but not to its early protocol/line-length/variant rejections.
BLACK_VERSION_HEADER = "X-Black-Version"
54
55
class InvalidVariantHeader(Exception):
    """Signal that an ``X-Python-Variant`` header value could not be parsed.

    The first positional argument is a human-readable reason. The request
    handler echoes it back to the client in a 400 response.
    """
58
59
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host",
    type=str,
    help="Address to bind the server to.",
    default="localhost",
    show_default=True,
)
@click.option(
    "--bind-port", type=int, help="Port to listen on", default=45484, show_default=True
)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # Intentionally no docstring: click would display it as the `--help` text.
    # Set up logging, build the aiohttp app, announce where it listens, then
    # hand control to aiohttp's run_app.
    logging.basicConfig(level=logging.INFO)
    app = make_app()
    announcement = (
        f"blackd version {black.__version__} listening on {bind_host} port {bind_port}"
    )
    black.out(announcement)
    # aiohttp's annotation for run_app's `print` argument is incorrect, hence
    # the type: ignore; remove it once aiohttp releases the corrected code.
    web.run_app(app, host=bind_host, port=bind_port, handle_signals=True, print=None)  # type: ignore[arg-type]
80
81
def make_app() -> web.Application:
    """Build the blackd aiohttp application.

    The app serves ``POST /`` through :func:`handle`. Formatting work runs in a
    ``ProcessPoolExecutor`` shared by all requests to this app. CORS is enabled
    for every header in ``BLACK_HEADERS`` plus ``Content-Type``.

    The process pool is shut down when the application is cleaned up, so its
    worker processes do not outlive the app. This matters, for example, when
    many apps are built in one process, one per test.
    """
    app = web.Application(
        middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))]
    )
    executor = ProcessPoolExecutor()

    async def _shutdown_executor(_app: web.Application) -> None:
        # Release the pool's worker processes; blocks until any pending work
        # submitted to the pool has finished.
        executor.shutdown(wait=True)

    app.on_cleanup.append(_shutdown_executor)
    app.add_routes([web.post("/", partial(handle, executor=executor))])
    return app
89
90
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the Python source sent in the body of a blackd POST request.

    Formatting options come from the ``X-*`` request headers named by the
    module-level header constants. The body is decoded with the request's
    charset (``utf8`` if none is given). It is formatted in ``executor``, and
    the result is returned with the request's content type and charset. If the
    ``X-Diff`` header has a non-empty value, a diff is returned instead.

    Args:
        request: the incoming aiohttp request.
        executor: where the formatting and diffing work is run.

    Returns:
        200 with the formatted source (or diff); 204 if nothing changed;
        400 for an invalid line-length or python-variant header, or when black
        raises ``InvalidInput``; 500 (logged) for any other exception;
        501 for an unsupported protocol version. Responses built inside the
        formatting path carry the ``X-Black-Version`` header.
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501, text="This server only supports protocol version 1"
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(status=400, text="Invalid line length header value")

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # Boolean flags: any non-empty header value turns the option on.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_magic_trailing_comma = bool(
            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
        )
        skip_source_first_line = bool(
            request.headers.get(SKIP_SOURCE_FIRST_LINE, False)
        )
        preview = bool(request.headers.get(PREVIEW, False))
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            skip_source_first_line=skip_source_first_line,
            string_normalization=not skip_string_normalization,
            magic_trailing_comma=not skip_magic_trailing_comma,
            preview=preview,
        )
        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        req_str = req_bytes.decode(charset)
        then = datetime.now(timezone.utc)

        header = ""
        if skip_source_first_line:
            # Keep the first line, including its newline, verbatim and format
            # only the remainder. A body without any newline is entirely its
            # first line, so nothing is left to format in that case.
            first_line, newline, rest = req_str.partition("\n")
            header = first_line + newline
            req_str = rest

        loop = asyncio.get_running_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Preserve CRLF line endings: if the first line of the source being
        # formatted ends in "\r\n", convert the formatter's output to match.
        nl = req_str.find("\n")
        if nl > 0 and req_str[nl - 1] == "\r":
            formatted_str = formatted_str.replace("\n", "\r\n")
            # If, after swapping line endings, nothing changed, then say so
            if formatted_str == req_str:
                raise black.NothingChanged

        # Put the source first line back
        req_str = header + req_str
        formatted_str = header + formatted_str

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.now(timezone.utc)
            src_name = f"In\t{then}"
            dst_name = f"Out\t{now}"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
193
194
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the ``X-Python-Variant`` request header.

    ``"pyi"`` selects stub-file mode with no explicit target versions. Any
    other value is a comma-separated list of Python versions. Each entry may
    carry a ``py`` prefix and may be written dotted (``3.7``, ``py3.10``) or
    undotted (``37``, ``py310``; the first digit is the major version and the
    rest is the minor). A bare major version defaults to minor 3 (``3`` means
    3.3). Whitespace around each entry is ignored.

    Returns:
        ``(is_pyi, target_versions)``.

    Raises:
        InvalidVariantHeader: if an entry is empty or malformed, names Python 2,
            or names a Python 3 version with no matching ``black.TargetVersion``.
    """
    if value == "pyi":
        return True, set()
    versions: Set[black.TargetVersion] = set()
    for raw_version in value.split(","):
        version = raw_version.strip()
        if version.startswith("py"):
            version = version[len("py") :]
        if not version:
            # An empty entry (e.g. "", "py", or the trailing comma in "3.7,")
            # is malformed. Reject it here rather than failing on version[0].
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            major_str = version[0]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if major == 2:
                # Rejected whether or not a minor version was given.
                raise InvalidVariantHeader("Python 2 is not supported")
            # Default to lowest supported minor version.
            minor = int(rest[0]) if rest else 3
            version_str = f"PY{major}{minor}"
            if not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
226
227
def patched_main() -> None:
    """Console entry point: prepare the runtime, then run the click ``main`` command.

    Calls ``maybe_install_uvloop()`` and ``multiprocessing.freeze_support()``
    before ``main()``; the order is kept as-is. ``freeze_support`` is
    presumably here because requests are formatted in a ``ProcessPoolExecutor``
    (see ``make_app``). Per the multiprocessing docs it only has an effect in
    frozen Windows executables.
    """
    maybe_install_uvloop()
    freeze_support()
    main()


# Run the server when this file is executed as a script.
if __name__ == "__main__":
    patched_main()