X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/f311d82569b9595d85c08cc8fcf5250de525e7a0..c4bd2e31ceeac84d68592986fe70920f3d3d0443:/src/blackd/__init__.py

diff --git a/src/blackd/__init__.py b/src/blackd/__init__.py
index f77a5e8..ba4750b 100644
--- a/src/blackd/__init__.py
+++ b/src/blackd/__init__.py
@@ -1,6 +1,5 @@
 import asyncio
 import logging
-import sys
 from concurrent.futures import Executor, ProcessPoolExecutor
 from datetime import datetime
 from functools import partial
@@ -9,20 +8,20 @@ from typing import Set, Tuple
 
 try:
     from aiohttp import web
-    import aiohttp_cors
+
+    from .middlewares import cors
 except ImportError as ie:
-    print(
+    raise ImportError(
         f"aiohttp dependency is not installed: {ie}. "
         + "Please re-install black with the '[d]' extra install "
-        + "to obtain aiohttp_cors: `pip install black[d]`",
-        file=sys.stderr,
-    )
-    sys.exit(-1)
+        + "to obtain aiohttp: `pip install black[d]`"
+    ) from None
 
-import black
 import click
 
+import black
 from _black_version import version as __version__
+from black.concurrency import maybe_install_uvloop
 
 # This is used internally by tests to shut down the server prematurely
 _stop_signal = asyncio.Event()
@@ -31,7 +30,10 @@ _stop_signal = asyncio.Event()
 PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
 LINE_LENGTH_HEADER = "X-Line-Length"
 PYTHON_VARIANT_HEADER = "X-Python-Variant"
+SKIP_SOURCE_FIRST_LINE = "X-Skip-Source-First-Line"
 SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
+SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
+PREVIEW = "X-Preview"
 FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
 DIFF_HEADER = "X-Diff"
 
@@ -39,7 +41,10 @@ BLACK_HEADERS = [
     PROTOCOL_VERSION_HEADER,
     LINE_LENGTH_HEADER,
     PYTHON_VARIANT_HEADER,
+    SKIP_SOURCE_FIRST_LINE,
     SKIP_STRING_NORMALIZATION_HEADER,
+    SKIP_MAGIC_TRAILING_COMMA,
+    PREVIEW,
     FAST_OR_SAFE_HEADER,
     DIFF_HEADER,
 ]
@@ -67,20 +72,11 @@ def main(bind_host: str, bind_port: int) -> None:
 
 
 def make_app() -> web.Application:
-    app = web.Application()
-    executor = ProcessPoolExecutor()
-
-    cors = aiohttp_cors.setup(app)
-    resource = cors.add(app.router.add_resource("/"))
-    cors.add(
-        resource.add_route("POST", partial(handle, executor=executor)),
-        {
-            "*": aiohttp_cors.ResourceOptions(
-                allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*"
-            )
-        },
+    app = web.Application(
+        middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))]
     )
-
+    executor = ProcessPoolExecutor()
+    app.add_routes([web.post("/", partial(handle, executor=executor))])
     return app
 
 
@@ -114,6 +110,13 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
         skip_string_normalization = bool(
             request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False)
         )
+        skip_magic_trailing_comma = bool(
+            request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
+        )
+        skip_source_first_line = bool(
+            request.headers.get(SKIP_SOURCE_FIRST_LINE, False)
+        )
+        preview = bool(request.headers.get(PREVIEW, False))
         fast = False
         if request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast":
             fast = True
@@ -121,18 +124,38 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
             target_versions=versions,
             is_pyi=pyi,
             line_length=line_length,
+            skip_source_first_line=skip_source_first_line,
             string_normalization=not skip_string_normalization,
+            magic_trailing_comma=not skip_magic_trailing_comma,
+            preview=preview,
         )
         req_bytes = await request.content.read()
         charset = request.charset if request.charset is not None else "utf8"
         req_str = req_bytes.decode(charset)
         then = datetime.utcnow()
 
+        header = ""
+        if skip_source_first_line:
+            # Like readline(): with no "\n", the whole input is the header
+            first_line, newline, req_str = req_str.partition("\n")
+            header = first_line + newline
+
         loop = asyncio.get_event_loop()
         formatted_str = await loop.run_in_executor(
             executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
         )
 
+        # Preserve CRLF line endings if the first line ends with "\r\n"
+        if req_str[: req_str.find("\n") + 1].endswith("\r\n"):
+            formatted_str = formatted_str.replace("\n", "\r\n")
+            # If, after swapping line endings, nothing changed, then say so
+            if formatted_str == req_str:
+                raise black.NothingChanged
+
+        # Put the source first line back
+        req_str = header + req_str
+        formatted_str = header + formatted_str
+
         # Only output the diff in the HTTP response
         only_diff = bool(request.headers.get(DIFF_HEADER, False))
         if only_diff:
@@ -179,10 +202,8 @@ def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersi
                     raise InvalidVariantHeader("major version must be 2 or 3")
                 if len(rest) > 0:
                     minor = int(rest[0])
-                    if major == 2 and minor != 7:
-                        raise InvalidVariantHeader(
-                            "minor version must be 7 for Python 2"
-                        )
+                    if major == 2:
+                        raise InvalidVariantHeader("Python 2 is not supported")
                 else:
                     # Default to lowest supported minor version.
                     minor = 7 if major == 2 else 3
@@ -191,11 +212,12 @@ def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersi
                     raise InvalidVariantHeader(f"3.{minor} is not supported")
                 versions.add(black.TargetVersion[version_str])
             except (KeyError, ValueError):
-                raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'")
+                raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None
         return False, versions
 
 
 def patched_main() -> None:
+    maybe_install_uvloop()
     freeze_support()
     black.patch_click()
     main()