X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/a82f1867875c906bedfe3ef675473b795d8b0440..8d6d92aa5b5248b5ff70ebf7977f8af5cbcb10b9:/blackd.py?ds=sidebyside diff --git a/blackd.py b/blackd.py index f2bbc8a..d79bfe7 100644 --- a/blackd.py +++ b/blackd.py @@ -1,20 +1,44 @@ import asyncio from concurrent.futures import Executor, ProcessPoolExecutor +from datetime import datetime from functools import partial import logging +from multiprocessing import freeze_support +from typing import Set, Tuple from aiohttp import web +import aiohttp_cors import black import click +from _black_version import version as __version__ + # This is used internally by tests to shut down the server prematurely _stop_signal = asyncio.Event() -VERSION_HEADER = "X-Protocol-Version" +# Request headers +PROTOCOL_VERSION_HEADER = "X-Protocol-Version" LINE_LENGTH_HEADER = "X-Line-Length" PYTHON_VARIANT_HEADER = "X-Python-Variant" SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization" FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe" +DIFF_HEADER = "X-Diff" + +BLACK_HEADERS = [ + PROTOCOL_VERSION_HEADER, + LINE_LENGTH_HEADER, + PYTHON_VARIANT_HEADER, + SKIP_STRING_NORMALIZATION_HEADER, + FAST_OR_SAFE_HEADER, + DIFF_HEADER, +] + +# Response headers +BLACK_VERSION_HEADER = "X-Black-Version" + + +class InvalidVariantHeader(Exception): + pass @click.command(context_settings={"help_option_names": ["-h", "--help"]}) @@ -34,13 +58,25 @@ def main(bind_host: str, bind_port: int) -> None: def make_app() -> web.Application: app = web.Application() executor = ProcessPoolExecutor() - app.add_routes([web.post("/", partial(handle, executor=executor))]) + + cors = aiohttp_cors.setup(app) + resource = cors.add(app.router.add_resource("/")) + cors.add( + resource.add_route("POST", partial(handle, executor=executor)), + { + "*": aiohttp_cors.ResourceOptions( + allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*" + ) + }, + ) + return app async def handle(request: web.Request, executor: Executor) -> web.Response: + headers = {BLACK_VERSION_HEADER: __version__} try: - if request.headers.get(VERSION_HEADER, "1") != "1": + if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1": return web.Response( status=501, text="This server only supports protocol version 1" ) @@ -50,57 +86,109 @@ async def handle(request: web.Request, executor: Executor) -> web.Response: ) except ValueError: return web.Response(status=400, text="Invalid line length header value") - py36 = False - pyi = False + if PYTHON_VARIANT_HEADER in request.headers: value = request.headers[PYTHON_VARIANT_HEADER] - if value == "pyi": - pyi = True - else: - try: - major, *rest = value.split(".") - if int(major) == 3 and len(rest) > 0: - if int(rest[0]) >= 6: - py36 = True - except ValueError: - return web.Response( - status=400, text=f"Invalid value for {PYTHON_VARIANT_HEADER}" - ) + try: + pyi, versions = parse_python_variant_header(value) + except InvalidVariantHeader as e: + return web.Response( + status=400, + text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}", + ) + else: + pyi = False + versions = set() + skip_string_normalization = bool( request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False) ) fast = False if request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast": fast = True - mode = black.FileMode.from_configuration( - py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization + mode = black.FileMode( + target_versions=versions, + is_pyi=pyi, + line_length=line_length, + string_normalization=not skip_string_normalization, ) req_bytes = await request.content.read() charset = request.charset if request.charset is not None else "utf8" req_str = req_bytes.decode(charset) + then = datetime.utcnow() + loop = asyncio.get_event_loop() formatted_str = await loop.run_in_executor( - executor, - partial( - black.format_file_contents, - req_str, - line_length=line_length, - fast=fast, - mode=mode, - ), + executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode) ) + + # Only output the diff in the HTTP response + only_diff = bool(request.headers.get(DIFF_HEADER, False)) + if only_diff: + now = datetime.utcnow() + src_name = f"In\t{then} +0000" + dst_name = f"Out\t{now} +0000" + loop = asyncio.get_event_loop() + formatted_str = await loop.run_in_executor( + executor, + partial(black.diff, req_str, formatted_str, src_name, dst_name), + ) + return web.Response( - content_type=request.content_type, charset=charset, text=formatted_str + content_type=request.content_type, + charset=charset, + headers=headers, + text=formatted_str, ) except black.NothingChanged: - return web.Response(status=204) + return web.Response(status=204, headers=headers) except black.InvalidInput as e: - return web.Response(status=400, text=str(e)) + return web.Response(status=400, headers=headers, text=str(e)) except Exception as e: logging.exception("Exception during handling a request") - return web.Response(status=500, text=str(e)) + return web.Response(status=500, headers=headers, text=str(e)) -if __name__ == "__main__": +def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]: + if value == "pyi": + return True, set() + else: + versions = set() + for version in value.split(","): + if version.startswith("py"): + version = version[len("py") :] + if "." in version: + major_str, *rest = version.split(".") + else: + major_str = version[0] + rest = [version[1:]] if len(version) > 1 else [] + try: + major = int(major_str) + if major not in (2, 3): + raise InvalidVariantHeader("major version must be 2 or 3") + if len(rest) > 0: + minor = int(rest[0]) + if major == 2 and minor != 7: + raise InvalidVariantHeader( + "minor version must be 7 for Python 2" + ) + else: + # Default to lowest supported minor version. + minor = 7 if major == 2 else 3 + version_str = f"PY{major}{minor}" + if major == 3 and not hasattr(black.TargetVersion, version_str): + raise InvalidVariantHeader(f"3.{minor} is not supported") + versions.add(black.TargetVersion[version_str]) + except (KeyError, ValueError): + raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") + return False, versions + + +def patched_main() -> None: + freeze_support() black.patch_click() main() + + +if __name__ == "__main__": + patched_main()