git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Remove whitespaces of whitespace-only files (#3348)
[etc/vim.git] / src / blackd / __init__.py
1 import asyncio
2 import logging
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime
5 from functools import partial
6 from multiprocessing import freeze_support
7 from typing import Set, Tuple
8
9 try:
10     from aiohttp import web
11
12     from .middlewares import cors
13 except ImportError as ie:
14     raise ImportError(
15         f"aiohttp dependency is not installed: {ie}. "
16         + "Please re-install black with the '[d]' extra install "
17         + "to obtain aiohttp_cors: `pip install black[d]`"
18     ) from None
19
20 import click
21
22 import black
23 from _black_version import version as __version__
24 from black.concurrency import maybe_install_uvloop
25
# Test-only hook: lets the test suite shut the server down prematurely.
_stop_signal = asyncio.Event()

# --- Names of the request headers understood by the server ---
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
LINE_LENGTH_HEADER = "X-Line-Length"
PYTHON_VARIANT_HEADER = "X-Python-Variant"
SKIP_SOURCE_FIRST_LINE = "X-Skip-Source-First-Line"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
PREVIEW = "X-Preview"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
DIFF_HEADER = "X-Diff"

# Every request header above, in declaration order. The application factory
# passes this list (plus "Content-Type") to the CORS middleware as the set of
# allowed headers.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_SOURCE_FIRST_LINE,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    PREVIEW,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]

# --- Names of the response headers emitted by the server ---
BLACK_VERSION_HEADER = "X-Black-Version"
54
55
class InvalidVariantHeader(Exception):
    """Signals an X-Python-Variant header value that cannot be accepted.

    The first positional argument is a human-readable reason; the request
    handler echoes it back to the client in a 400 response.
    """
58
59
# Command-line entry point of the formatting server. Documented with comments
# rather than a docstring because click uses a command's docstring as its
# --help text, which would change the program's output.
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host", type=str, help="Address to bind the server to.", default="localhost"
)
@click.option("--bind-port", type=int, help="Port to listen on", default=45484)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # bind_host / bind_port: address and TCP port the HTTP server listens on.
    logging.basicConfig(level=logging.INFO)

    # Build the application before announcing anything, so a failure while
    # constructing it is not preceded by a misleading "listening" message.
    server_app = make_app()

    running_version = black.__version__
    black.out(
        f"blackd version {running_version} listening on {bind_host} port {bind_port}"
    )

    # print=None silences aiohttp's own startup banner; the line above is the
    # only announcement.
    web.run_app(
        server_app,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        print=None,
    )
72
73
def make_app() -> web.Application:
    """Build the aiohttp application that serves formatting requests.

    The application has a single route, ``POST /``, served by ``handle``, and
    a CORS middleware that allows every header in ``BLACK_HEADERS`` plus
    ``Content-Type``. A fresh ``ProcessPoolExecutor`` is created and bound to
    the handler.
    """
    allowed_headers = (*BLACK_HEADERS, "Content-Type")
    application = web.Application(middlewares=[cors(allow_headers=allowed_headers)])

    # The pool is bound into the handler once, here, so that every request
    # handled by this application shares the same executor.
    pool = ProcessPoolExecutor()
    post_handler = partial(handle, executor=pool)
    application.add_routes([web.post("/", post_handler)])
    return application
81
82
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the source code in the request body and return the result.

    Formatting options are read from the request headers:

    - ``X-Protocol-Version``: must be ``"1"`` when present, otherwise 501.
    - ``X-Line-Length``: integer line length; a non-integer value yields 400.
    - ``X-Python-Variant``: parsed by ``parse_python_variant_header``; an
      invalid value yields 400.
    - ``X-Skip-String-Normalization``, ``X-Skip-Magic-Trailing-Comma``,
      ``X-Skip-Source-First-Line``, ``X-Preview``, ``X-Diff``: enabled when
      the header is present with a non-empty value (the value is not parsed).
    - ``X-Fast-Or-Safe``: ``"fast"`` selects fast mode; anything else is safe.

    Args:
        request: The incoming HTTP request; its body is the source to format.
        executor: Executor on which the formatting (and diffing) work is run.

    Returns:
        - 200 with the formatted source (or a diff when ``X-Diff`` is set),
        - 204 when formatting changes nothing,
        - 400 when ``black.InvalidInput`` is raised or a header is invalid,
        - 500, with the exception text, for any other error.
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501, text="This server only supports protocol version 1"
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(status=400, text="Invalid line length header value")

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # These flags are presence-based: any non-empty header value enables
        # them, because bool() of a non-empty string is True.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_magic_trailing_comma = bool(
            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
        )
        skip_source_first_line = bool(
            request.headers.get(SKIP_SOURCE_FIRST_LINE, False)
        )
        preview = bool(request.headers.get(PREVIEW, False))
        fast = False
        if request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast":
            fast = True
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            skip_source_first_line=skip_source_first_line,
            string_normalization=not skip_string_normalization,
            magic_trailing_comma=not skip_magic_trailing_comma,
            preview=preview,
        )
        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        req_str = req_bytes.decode(charset)
        # Naive UTC timestamp; the diff labels below append "+0000" explicitly.
        then = datetime.utcnow()

        # Split off the first line (including its newline) so it is excluded
        # from formatting; it is re-attached verbatim further down. When the
        # body has no newline, find() returns -1 and the header stays empty.
        header = ""
        if skip_source_first_line:
            first_newline_position: int = req_str.find("\n") + 1
            header = req_str[:first_newline_position]
            req_str = req_str[first_newline_position:]

        loop = asyncio.get_event_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Preserve CRLF line endings. The first newline decides: it must
        # exist and be preceded by "\r". Requiring a position > 0 avoids
        # indexing with -2 (no newline at all, which raises IndexError on a
        # one-character body) or -1 (body starts with "\n", which would
        # inspect the last character instead).
        newline_position = req_str.find("\n")
        if newline_position > 0 and req_str[newline_position - 1] == "\r":
            formatted_str = formatted_str.replace("\n", "\r\n")
            # If, after swapping line endings, nothing changed, then say so
            if formatted_str == req_str:
                raise black.NothingChanged

        # Put the source first line back
        req_str = header + req_str
        formatted_str = header + formatted_str

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.utcnow()
            src_name = f"In\t{then} +0000"
            dst_name = f"Out\t{now} +0000"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        # Last-resort boundary: log the traceback and report the error text.
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
184
185
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the ``X-Python-Variant`` request header.

    ``value`` is either the literal ``"pyi"`` or a comma-separated list of
    Python versions. Each entry may start with an optional ``py`` prefix and
    may be written with a dot (``3.7``) or without one (``37``, where the
    first character is the major version). A bare major version ``3``
    defaults to minor version 3.

    Returns:
        A ``(is_pyi, target_versions)`` tuple. For ``"pyi"`` the set is
        empty; otherwise ``is_pyi`` is False and the set holds one
        ``black.TargetVersion`` member per entry.

    Raises:
        InvalidVariantHeader: If an entry is malformed or names a version
            that is not supported. This includes empty entries (an empty
            header, a bare ``py``, or a stray/trailing comma).
    """
    if value == "pyi":
        return True, set()
    else:
        versions = set()
        for version in value.split(","):
            if version.startswith("py"):
                version = version[len("py") :]
            if "." in version:
                major_str, *rest = version.split(".")
            else:
                # Slice rather than index: for an empty entry this yields ""
                # instead of raising IndexError, so int("") below raises
                # ValueError and the entry is rejected like any other
                # malformed value.
                major_str = version[:1]
                rest = [version[1:]] if len(version) > 1 else []
            try:
                major = int(major_str)
                if major not in (2, 3):
                    raise InvalidVariantHeader("major version must be 2 or 3")
                if len(rest) > 0:
                    minor = int(rest[0])
                    if major == 2:
                        raise InvalidVariantHeader("Python 2 is not supported")
                else:
                    # Default to lowest supported minor version.
                    minor = 7 if major == 2 else 3
                version_str = f"PY{major}{minor}"
                if major == 3 and not hasattr(black.TargetVersion, version_str):
                    raise InvalidVariantHeader(f"3.{minor} is not supported")
                # An unknown member name raises KeyError, handled below.
                versions.add(black.TargetVersion[version_str])
            except (KeyError, ValueError):
                raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
        return False, versions
217
218
def patched_main() -> None:
    """Prepare the process and then run the ``main`` click command.

    Runs, in order: ``maybe_install_uvloop()``, ``freeze_support()`` and
    ``black.patch_click()``, before handing control to ``main()``.
    """
    # The setup calls run strictly before main(), in this order.
    maybe_install_uvloop()
    freeze_support()
    black.patch_click()
    main()


# Running this module as a script starts the server.
if __name__ == "__main__":
    patched_main()