git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix crash on type comment with trailing space (#3773)
[etc/vim.git] / src / blackd / __init__.py
1 import asyncio
2 import logging
3 from concurrent.futures import Executor, ProcessPoolExecutor
4 from datetime import datetime, timezone
5 from functools import partial
6 from multiprocessing import freeze_support
7 from typing import Set, Tuple
8
9 try:
10     from aiohttp import web
11
12     from .middlewares import cors
13 except ImportError as ie:
14     raise ImportError(
15         f"aiohttp dependency is not installed: {ie}. "
16         + "Please re-install black with the '[d]' extra install "
17         + "to obtain aiohttp_cors: `pip install black[d]`"
18     ) from None
19
20 import click
21
22 import black
23 from _black_version import version as __version__
24 from black.concurrency import maybe_install_uvloop
25
# Response header through which every reply reports the server's black version.
BLACK_VERSION_HEADER = "X-Black-Version"

# Only the test-suite uses this event, to stop a running server ahead of time.
_stop_signal = asyncio.Event()

# Request headers that select the protocol version and formatting options.
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
LINE_LENGTH_HEADER = "X-Line-Length"
PYTHON_VARIANT_HEADER = "X-Python-Variant"
SKIP_SOURCE_FIRST_LINE = "X-Skip-Source-First-Line"
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
PREVIEW = "X-Preview"
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
DIFF_HEADER = "X-Diff"

# Every request header declared above, kept in declaration order.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_SOURCE_FIRST_LINE,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    PREVIEW,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]
55
class InvalidVariantHeader(Exception):
    """Signals a malformed X-Python-Variant value; ``args[0]`` holds the reason."""
58
59
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host",
    type=str,
    help="Address to bind the server to.",
    default="localhost",
    show_default=True,
)
@click.option(
    "--bind-port", type=int, help="Port to listen on", default=45484, show_default=True
)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # No docstring on purpose: click would turn it into the --help text.
    # Configure logging, build the aiohttp application, announce the address,
    # then block in aiohttp's runner (which handles signals itself).
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    black.out(
        f"blackd version {black.__version__} listening on {bind_host} port {bind_port}"
    )
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        print=None,
    )
78
79
def make_app() -> web.Application:
    """Build the blackd aiohttp application with a single POST "/" route.

    The cors middleware is told to allow every blackd request header plus
    Content-Type. Requests are served by ``handle`` bound to a fresh
    ProcessPoolExecutor.
    """
    allowed_headers = (*BLACK_HEADERS, "Content-Type")
    app = web.Application(middlewares=[cors(allow_headers=allowed_headers)])
    # One process pool per application, shared by all requests it serves.
    pool = ProcessPoolExecutor()
    route_handler = partial(handle, executor=pool)
    app.add_routes([web.post("/", route_handler)])
    return app
87
88
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the source in the request body according to the X-* request headers.

    Args:
        request: the incoming POST request. Its body is the source to format,
            decoded with the request charset (default "utf8"). The headers
            listed in BLACK_HEADERS select the formatting options.
        executor: executor on which black's formatting (and diffing) runs, so
            the event loop is not blocked while black works.

    Returns:
        The formatted source (or a diff when X-Diff is set) with the default
        status. Other outcomes:
        - 204 when black reports nothing changed.
        - 400 for an invalid X-Line-Length or X-Python-Variant, or for input
          black rejects.
        - 501 for an unsupported X-Protocol-Version.
        - 500 for any other error.
        Every response carries the X-Black-Version header.
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501,
                headers=headers,
                text="This server only supports protocol version 1",
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(
                status=400, headers=headers, text="Invalid line length header value"
            )

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    headers=headers,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # Flag headers are switched on by any non-empty value.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_magic_trailing_comma = bool(
            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
        )
        skip_source_first_line = bool(
            request.headers.get(SKIP_SOURCE_FIRST_LINE, False)
        )
        preview = bool(request.headers.get(PREVIEW, False))
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            skip_source_first_line=skip_source_first_line,
            string_normalization=not skip_string_normalization,
            magic_trailing_comma=not skip_magic_trailing_comma,
            preview=preview,
        )
        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        req_str = req_bytes.decode(charset)
        then = datetime.now(timezone.utc)

        # Set the first line aside so black never sees it; it is re-attached
        # unchanged after formatting.
        header = ""
        if skip_source_first_line:
            first_newline_position: int = req_str.find("\n") + 1
            header = req_str[:first_newline_position]
            req_str = req_str[first_newline_position:]

        # We are inside a coroutine, so the running loop is the right one.
        loop = asyncio.get_running_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Preserve CRLF line endings, judged by the first line break. A "\n" at
        # index 0 has no preceding "\r", and with no "\n" at all find() returns
        # -1. Without this guard, req_str[-2] would be inspected, which raises
        # IndexError on one-character input and checks an unrelated character
        # otherwise.
        first_newline = req_str.find("\n")
        if first_newline > 0 and req_str[first_newline - 1] == "\r":
            formatted_str = formatted_str.replace("\n", "\r\n")
            # If, after swapping line endings, nothing changed, then say so
            if formatted_str == req_str:
                raise black.NothingChanged

        # Put the source first line back
        req_str = header + req_str
        formatted_str = header + formatted_str

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.now(timezone.utc)
            src_name = f"In\t{then}"
            dst_name = f"Out\t{now}"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
190
191
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the X-Python-Variant request header.

    ``"pyi"`` selects stub-file formatting with no target versions. Otherwise
    the value is a comma-separated list of versions, each optionally prefixed
    with ``py``, such as ``3.7``, ``py3.10`` or ``38`` (compact form: first
    digit is the major version, the rest is the minor version). A bare major
    version defaults to minor 7 for Python 2 and minor 3 for Python 3.
    Whitespace around each entry is ignored.

    Returns:
        ``(is_pyi, target_versions)``.

    Raises:
        InvalidVariantHeader: if an entry is empty or malformed, or names a
            version that black.TargetVersion does not define.
    """
    if value == "pyi":
        return True, set()
    versions = set()
    for version in value.split(","):
        version = version.strip()
        if version.startswith("py"):
            version = version[len("py") :]
        try:
            if "." in version:
                major_str, *rest = version.split(".")
            else:
                # Raises IndexError for an empty entry (e.g. "", "py" or a
                # trailing comma); reported as an invalid header below.
                major_str = version[0]
                rest = [version[1:]] if len(version) > 1 else []
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if len(rest) > 0:
                minor = int(rest[0])
                if major == 2:
                    raise InvalidVariantHeader("Python 2 is not supported")
            else:
                # Default to lowest supported minor version.
                minor = 7 if major == 2 else 3
            version_str = f"PY{major}{minor}"
            if major == 3 and not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (IndexError, KeyError, ValueError):
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
223
224
def patched_main() -> None:
    """Run the blackd CLI after process-level setup.

    The call order matters: both setup calls must run before ``main()``
    starts the server, and with it the event loop and the process pool.
    """
    # Imported from black.concurrency. Presumably it switches asyncio to uvloop
    # when that package is available (TODO: confirm in black.concurrency).
    maybe_install_uvloop()
    # multiprocessing hook for frozen executables. It must run before any
    # worker processes are spawned (make_app() creates a ProcessPoolExecutor).
    freeze_support()
    main()
229
230
if __name__ == "__main__":
    # Direct execution goes through patched_main() so that maybe_install_uvloop()
    # and freeze_support() run before main().
    patched_main()