X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/5c2dd96a69a935cf45acbdf2ffabbd39b27d38fa..fa1163545fd8779633d27a45d81e0dfa6ebd61fa:/blackd.py?ds=inline

diff --git a/blackd.py b/blackd.py
index e1006a1..d79bfe7 100644
--- a/blackd.py
+++ b/blackd.py
@@ -1,21 +1,44 @@
 import asyncio
 from concurrent.futures import Executor, ProcessPoolExecutor
+from datetime import datetime
 from functools import partial
 import logging
+from multiprocessing import freeze_support
+from typing import Set, Tuple
 
 from aiohttp import web
+import aiohttp_cors
 import black
 import click
 
+from _black_version import version as __version__
+
 # This is used internally by tests to shut down the server prematurely
 _stop_signal = asyncio.Event()
 
-VERSION_HEADER = "X-Protocol-Version"
+# Request headers that handle() reads from each formatting request
+PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
 LINE_LENGTH_HEADER = "X-Line-Length"
 PYTHON_VARIANT_HEADER = "X-Python-Variant"
 SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
-SKIP_NUMERIC_UNDERSCORE_NORMALIZATION_HEADER = "X-Skip-Numeric-Underscore-Normalization"
 FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
+DIFF_HEADER = "X-Diff"  # any non-empty value: respond with a diff instead
+
+BLACK_HEADERS = [  # allowed as CORS request headers in make_app()
+    PROTOCOL_VERSION_HEADER,
+    LINE_LENGTH_HEADER,
+    PYTHON_VARIANT_HEADER,
+    SKIP_STRING_NORMALIZATION_HEADER,
+    FAST_OR_SAFE_HEADER,
+    DIFF_HEADER,
+]
+
+# Response headers added by handle()
+BLACK_VERSION_HEADER = "X-Black-Version"
+
+
+class InvalidVariantHeader(Exception):
+    pass  # raised by parse_python_variant_header() on a malformed value
 
 
 @click.command(context_settings={"help_option_names": ["-h", "--help"]})
@@ -35,13 +58,25 @@ def main(bind_host: str, bind_port: int) -> None:
 def make_app() -> web.Application:
     app = web.Application()
     executor = ProcessPoolExecutor()
-    app.add_routes([web.post("/", partial(handle, executor=executor))])
+    # Route POST / through aiohttp_cors so cross-origin browser clients can call it.
+    cors = aiohttp_cors.setup(app)
+    resource = cors.add(app.router.add_resource("/"))
+    cors.add(
+        resource.add_route("POST", partial(handle, executor=executor)),
+        {  # "*": these options apply to requests from any origin
+            "*": aiohttp_cors.ResourceOptions(
+                allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*"
+            )  # expose_headers="*": browser scripts may read X-Black-Version
+        },
+    )
+
     return app
 
 
 async def handle(request: web.Request, executor: Executor) -> web.Response:
+    headers = {BLACK_VERSION_HEADER: __version__}  # reports the black version in use
     try:
-        if request.headers.get(VERSION_HEADER, "1") != "1":
+        if request.headers.get(PROTOCOL_VERSION_HEADER, "1") != "1":
             return web.Response(
                 status=501, text="This server only supports protocol version 1"
             )
@@ -51,64 +86,106 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
             )
         except ValueError:
             return web.Response(status=400, text="Invalid line length header value")
-        py36 = False
-        pyi = False
+        # X-Python-Variant selects either .pyi mode or a set of target versions.
         if PYTHON_VARIANT_HEADER in request.headers:
             value = request.headers[PYTHON_VARIANT_HEADER]
-            if value == "pyi":
-                pyi = True
-            else:
-                try:
-                    major, *rest = value.split(".")
-                    if int(major) == 3 and len(rest) > 0:
-                        if int(rest[0]) >= 6:
-                            py36 = True
-                except ValueError:
-                    return web.Response(
-                        status=400, text=f"Invalid value for {PYTHON_VARIANT_HEADER}"
-                    )
+            try:
+                pyi, versions = parse_python_variant_header(value)
+            except InvalidVariantHeader as e:
+                return web.Response(
+                    status=400,
+                    text=f"Invalid value for {PYTHON_VARIANT_HEADER}: {e.args[0]}",
+                )
+        else:
+            pyi = False
+            versions = set()
+
         skip_string_normalization = bool(
             request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
         )
-        skip_numeric_underscore_normalization = bool(
-            request.headers.get(SKIP_NUMERIC_UNDERSCORE_NORMALIZATION_HEADER, False)
-        )
         fast = False
         if request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast":
             fast = True
-        mode = black.FileMode.from_configuration(
-            py36=py36,
-            pyi=pyi,
-            skip_string_normalization=skip_string_normalization,
-            skip_numeric_underscore_normalization=skip_numeric_underscore_normalization,
+        mode = black.FileMode(
+            target_versions=versions,
+            is_pyi=pyi,
+            line_length=line_length,
+            string_normalization=not skip_string_normalization,
         )
         req_bytes = await request.content.read()
         charset = request.charset if request.charset is not None else "utf8"
         req_str = req_bytes.decode(charset)
+        then = datetime.utcnow()  # labels the "In" side if a diff is requested
+
         loop = asyncio.get_event_loop()
         formatted_str = await loop.run_in_executor(
-            executor,
-            partial(
-                black.format_file_contents,
-                req_str,
-                line_length=line_length,
-                fast=fast,
-                mode=mode,
-            ),
+            executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
         )
+
+        # With X-Diff set (any non-empty value), return a diff, not formatted code.
+        only_diff = bool(request.headers.get(DIFF_HEADER, False))
+        if only_diff:
+            now = datetime.utcnow()  # naive UTC times, hence the literal +0000
+            src_name = f"In\t{then} +0000"
+            dst_name = f"Out\t{now} +0000"
+            loop = asyncio.get_event_loop()
+            formatted_str = await loop.run_in_executor(
+                executor,
+                partial(black.diff, req_str, formatted_str, src_name, dst_name),
+            )
+
         return web.Response(
-            content_type=request.content_type, charset=charset, text=formatted_str
+            content_type=request.content_type,
+            charset=charset,
+            headers=headers,
+            text=formatted_str,
         )
     except black.NothingChanged:
-        return web.Response(status=204)
+        return web.Response(status=204, headers=headers)
     except black.InvalidInput as e:
-        return web.Response(status=400, text=str(e))
+        return web.Response(status=400, headers=headers, text=str(e))
     except Exception as e:
         logging.exception("Exception during handling a request")
-        return web.Response(status=500, text=str(e))
+        return web.Response(status=500, headers=headers, text=str(e))
+
+
+def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersion]]:
+    """Parse X-Python-Variant into (is_pyi, versions); raise InvalidVariantHeader."""
+    if value == "pyi":
+        return True, set()
+    versions = set()
+    for version in value.split(","):
+        version = version.strip()
+        if version.startswith("py"):
+            version = version[len("py") :]
+        if "." in version:
+            major_str, *rest = version.split(".")
+        else:
+            # [:1], not [0]: empty entries (e.g. value "py") raise ValueError in int()
+            major_str = version[:1]
+            rest = [version[1:]] if len(version) > 1 else []
+        try:
+            major = int(major_str)
+            if major not in (2, 3):
+                raise InvalidVariantHeader("major version must be 2 or 3")
+            if len(rest) > 0:
+                minor = int(rest[0])
+                if major == 2 and minor != 7:
+                    raise InvalidVariantHeader("minor version must be 7 for Python 2")
+            else:
+                # Default to lowest supported minor version.
+                minor = 7 if major == 2 else 3
+            version_str = f"PY{major}{minor}"
+            if major == 3 and not hasattr(black.TargetVersion, version_str):
+                raise InvalidVariantHeader(f"3.{minor} is not supported")
+            versions.add(black.TargetVersion[version_str])
+        except (KeyError, ValueError):
+            raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
+    return False, versions
 
 
 def patched_main() -> None:
+    freeze_support()  # lets the process pool work in frozen Windows executables
     black.patch_click()
     main()