git.madduck.net Git - etc/vim.git/blob - src/blackd/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix regular expression that black uses to identify f-expressions (#2287)
[etc/vim.git] / src / blackd / __init__.py
1 import asyncio
2 import logging
3 import sys
4 from concurrent.futures import Executor, ProcessPoolExecutor
5 from datetime import datetime
6 from functools import partial
7 from multiprocessing import freeze_support
8 from typing import Set, Tuple
9
10 try:
11     from aiohttp import web
12     import aiohttp_cors
13 except ImportError as ie:
14     print(
15         f"aiohttp dependency is not installed: {ie}. "
16         + "Please re-install black with the '[d]' extra install "
17         + "to obtain aiohttp_cors: `pip install black[d]`",
18         file=sys.stderr,
19     )
20     sys.exit(-1)
21
22 import black
23 import click
24
25 # If our environment has uvloop installed, let's use it
26 try:
27     import uvloop
28
29     uvloop.install()
30 except ImportError:
31     pass
32
33 from _black_version import version as __version__
34
# Test-only hook: the test-suite uses this event to shut the server down
# before it would normally exit.
_stop_signal = asyncio.Event()

# --- Request headers -------------------------------------------------------
# Protocol revision requested by the client; only "1" is accepted.
PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
# Maximum line length; the value is parsed with int().
LINE_LENGTH_HEADER = "X-Line-Length"
# "pyi" or a comma-separated list of target Python versions.
PYTHON_VARIANT_HEADER = "X-Python-Variant"
# Any non-empty value turns string normalization off.
SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
# Any non-empty value turns magic trailing comma handling off.
SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
# The value "fast" selects fast mode; every other value keeps safe mode.
FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
# Any non-empty value makes the response body a diff instead of the source.
DIFF_HEADER = "X-Diff"

# All request headers above, in declaration order; used for the CORS
# "allow headers" configuration.
BLACK_HEADERS = [
    PROTOCOL_VERSION_HEADER,
    LINE_LENGTH_HEADER,
    PYTHON_VARIANT_HEADER,
    SKIP_STRING_NORMALIZATION_HEADER,
    SKIP_MAGIC_TRAILING_COMMA,
    FAST_OR_SAFE_HEADER,
    DIFF_HEADER,
]

# --- Response headers ------------------------------------------------------
# Carries the Black version that produced the response.
BLACK_VERSION_HEADER = "X-Black-Version"
59
60
class InvalidVariantHeader(Exception):
    """Signal that an X-Python-Variant header value could not be parsed.

    The first positional argument is a human-readable reason that is echoed
    back to the client in the HTTP 400 response.
    """
63
64
# NOTE: documented with comments rather than a docstring on purpose -- click
# would turn a docstring into the command's --help text.
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
    "--bind-host", type=str, help="Address to bind the server to.", default="localhost"
)
@click.option("--bind-port", type=int, help="Port to listen on", default=45484)
@click.version_option(version=black.__version__)
def main(bind_host: str, bind_port: int) -> None:
    # Configure INFO-level logging, build the application, announce where the
    # server listens, then hand control to aiohttp until it is stopped.
    logging.basicConfig(level=logging.INFO)
    application = make_app()
    black.out(
        f"blackd version {black.__version__} listening on {bind_host} port {bind_port}"
    )
    # print=None silences aiohttp's own start-up banner; the line above
    # replaces it.
    web.run_app(
        application,
        host=bind_host,
        port=bind_port,
        handle_signals=True,
        print=None,
    )
77
78
def make_app() -> web.Application:
    """Create the aiohttp application that serves formatting requests.

    A single POST route on "/" is registered.  It dispatches to ``handle``
    with a freshly created ``ProcessPoolExecutor`` bound as its ``executor``
    argument.  The route is wrapped with aiohttp_cors options, registered
    under the "*" key, that allow the Black request headers plus
    "Content-Type" and expose all response headers.

    Returns:
        The configured application, ready to be passed to ``web.run_app``.
    """
    application = web.Application()
    pool = ProcessPoolExecutor()
    cors = aiohttp_cors.setup(application)

    # Register the root resource with CORS first, then attach the POST route.
    root = cors.add(application.router.add_resource("/"))
    post_route = root.add_route("POST", partial(handle, executor=pool))

    options = aiohttp_cors.ResourceOptions(
        allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*"
    )
    cors.add(post_route, {"*": options})
    return application
95
96
async def handle(request: web.Request, executor: Executor) -> web.Response:
    """Format the source code sent in the request body.

    Formatting options are taken from the request headers listed in
    ``BLACK_HEADERS``; the formatting itself (and the optional diff) runs in
    ``executor`` so the event loop is not blocked.

    Args:
        request: Incoming request whose body is the source to format.
        executor: Executor used to run ``black.format_file_contents`` and
            ``black.diff``.

    Returns:
        A response with one of these status codes:

        * 200 -- body is the formatted source, or a diff when the X-Diff
          header has a non-empty value.
        * 204 -- ``black.NothingChanged`` was raised.
        * 400 -- a header value was invalid or ``black.InvalidInput`` was
          raised.
        * 501 -- the requested protocol version is not "1".
        * 500 -- any other exception; it is logged with its traceback.

        Every response, including the early error responses, carries the
        X-Black-Version header.
    """
    headers = {BLACK_VERSION_HEADER: __version__}
    try:
        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
            return web.Response(
                status=501,
                headers=headers,
                text="This server only supports protocol version 1",
            )
        try:
            line_length = int(
                request.headers.get(LINE_LENGTH_HEADER, black.DEFAULT_LINE_LENGTH)
            )
        except ValueError:
            return web.Response(
                status=400, headers=headers, text="Invalid line length header value"
            )

        if PYTHON_VARIANT_HEADER in request.headers:
            value = request.headers[PYTHON_VARIANT_HEADER]
            try:
                pyi, versions = parse_python_variant_header(value)
            except InvalidVariantHeader as e:
                return web.Response(
                    status=400,
                    headers=headers,
                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
                )
        else:
            pyi = False
            versions = set()

        # Presence with any non-empty value enables the "skip" behaviour;
        # the actual header value is not interpreted.
        skip_string_normalization = bool(
            request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
        )
        skip_magic_trailing_comma = bool(
            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
        )
        # Only the exact value "fast" leaves safe mode.
        fast = request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast"
        mode = black.FileMode(
            target_versions=versions,
            is_pyi=pyi,
            line_length=line_length,
            string_normalization=not skip_string_normalization,
            magic_trailing_comma=not skip_magic_trailing_comma,
        )
        req_bytes = await request.content.read()
        charset = request.charset if request.charset is not None else "utf8"
        req_str = req_bytes.decode(charset)
        # Timestamp taken before formatting; used as the "In" time of a diff.
        then = datetime.utcnow()

        # One loop lookup serves both executor calls below.
        loop = asyncio.get_event_loop()
        formatted_str = await loop.run_in_executor(
            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
        )

        # Only output the diff in the HTTP response
        only_diff = bool(request.headers.get(DIFF_HEADER, False))
        if only_diff:
            now = datetime.utcnow()
            src_name = f"In\t{then} +0000"
            dst_name = f"Out\t{now} +0000"
            formatted_str = await loop.run_in_executor(
                executor,
                partial(black.diff, req_str, formatted_str, src_name, dst_name),
            )

        return web.Response(
            content_type=request.content_type,
            charset=charset,
            headers=headers,
            text=formatted_str,
        )
    except black.NothingChanged:
        return web.Response(status=204, headers=headers)
    except black.InvalidInput as e:
        return web.Response(status=400, headers=headers, text=str(e))
    except Exception as e:
        logging.exception("Exception during handling a request")
        return web.Response(status=500, headers=headers, text=str(e))
175
176
def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
    """Parse the value of the X-Python-Variant request header.

    Accepted values are the literal ``"pyi"`` or a comma-separated list of
    Python versions.  Each version may carry an optional ``py`` prefix and is
    written either dotted (``3.7``, ``py3.5``) or compact (``37``, ``py36``).
    A bare major version selects the lowest supported minor version (``2``
    means 2.7, ``3`` means 3.3).

    Args:
        value: Raw header value.

    Returns:
        A ``(is_pyi, target_versions)`` tuple.  For ``"pyi"`` this is
        ``(True, set())``; otherwise ``False`` and the set of parsed
        ``black.TargetVersion`` members.

    Raises:
        InvalidVariantHeader: If any list item is empty, is not a number, or
            names a version that is not supported.
    """
    if value == "pyi":
        return True, set()

    versions = set()
    for version in value.split(","):
        if version.startswith("py"):
            version = version[len("py") :]
        if not version:
            # An empty item (empty header, trailing comma, bare "py") would
            # otherwise fail on version[0] with an IndexError, which callers
            # do not treat as a client error.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
        if "." in version:
            major_str, *rest = version.split(".")
        else:
            # Compact form: first character is the major version, the
            # remainder (if any) is the minor version.
            major_str = version[0]
            rest = [version[1:]] if len(version) > 1 else []
        try:
            major = int(major_str)
            if major not in (2, 3):
                raise InvalidVariantHeader("major version must be 2 or 3")
            if len(rest) > 0:
                minor = int(rest[0])
                if major == 2 and minor != 7:
                    raise InvalidVariantHeader("minor version must be 7 for Python 2")
            else:
                # Default to lowest supported minor version.
                minor = 7 if major == 2 else 3
            version_str = f"PY{major}{minor}"
            if major == 3 and not hasattr(black.TargetVersion, version_str):
                raise InvalidVariantHeader(f"3.{minor} is not supported")
            versions.add(black.TargetVersion[version_str])
        except (KeyError, ValueError):
            # Suppress the low-level cause; the message is all callers need.
            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
    return False, versions
210
211
def patched_main() -> None:
    """Run ``main`` after the process-level preparation steps.

    Calls ``multiprocessing.freeze_support()`` and ``black.patch_click()``
    before invoking the click command.
    """
    # Order matters: both preparation calls must precede the click command.
    freeze_support()
    black.patch_click()
    main()


# Allow running the module directly as a script.
if __name__ == "__main__":
    patched_main()