X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/a5381ba7648f7308145c78c248e29118e18dc530..e73662ca7cfd6d4760e11a6ab489a1ec585d1cd4:/src/blackd/__init__.py

diff --git a/src/blackd/__init__.py b/src/blackd/__init__.py
index cc96640..6b0f3d3 100644
--- a/src/blackd/__init__.py
+++ b/src/blackd/__init__.py
@@ -1,13 +1,14 @@
 import asyncio
 import logging
 from concurrent.futures import Executor, ProcessPoolExecutor
-from datetime import datetime
+from datetime import datetime, timezone
 from functools import partial
 from multiprocessing import freeze_support
 from typing import Set, Tuple
 
 try:
     from aiohttp import web
+
     from .middlewares import cors
 except ImportError as ie:
     raise ImportError(
@@ -16,11 +17,11 @@ except ImportError as ie:
         + "to obtain aiohttp_cors: `pip install black[d]`"
     ) from None
 
-import black
-from black.concurrency import maybe_install_uvloop
 import click
 
+import black
 from _black_version import version as __version__
+from black.concurrency import maybe_install_uvloop
 
 # This is used internally by tests to shut down the server prematurely
 _stop_signal = asyncio.Event()
@@ -29,8 +30,10 @@ _stop_signal = asyncio.Event()
 PROTOCOL_VERSION_HEADER = "X-Protocol-Version"
 LINE_LENGTH_HEADER = "X-Line-Length"
 PYTHON_VARIANT_HEADER = "X-Python-Variant"
+SKIP_SOURCE_FIRST_LINE = "X-Skip-Source-First-Line"
 SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization"
 SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma"
+PREVIEW = "X-Preview"
 FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe"
 DIFF_HEADER = "X-Diff"
 
@@ -38,8 +41,10 @@ BLACK_HEADERS = [
     PROTOCOL_VERSION_HEADER,
     LINE_LENGTH_HEADER,
     PYTHON_VARIANT_HEADER,
+    SKIP_SOURCE_FIRST_LINE,
     SKIP_STRING_NORMALIZATION_HEADER,
     SKIP_MAGIC_TRAILING_COMMA,
+    PREVIEW,
     FAST_OR_SAFE_HEADER,
     DIFF_HEADER,
 ]
@@ -54,9 +59,15 @@ class InvalidVariantHeader(Exception):
 
 @click.command(context_settings={"help_option_names": ["-h", "--help"]})
 @click.option(
-    "--bind-host", type=str, help="Address to bind the server to.", default="localhost"
+    "--bind-host",
+    type=str,
+    help="Address to bind the server to.",
+    default="localhost",
+    show_default=True,
+)
+@click.option(
+    "--bind-port", type=int, help="Port to listen on", default=45484, show_default=True
 )
-@click.option("--bind-port", type=int, help="Port to listen on", default=45484)
 @click.version_option(version=black.__version__)
 def main(bind_host: str, bind_port: int) -> None:
     logging.basicConfig(level=logging.INFO)
@@ -108,6 +119,10 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
         skip_magic_trailing_comma = bool(
             request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False)
         )
+        skip_source_first_line = bool(
+            request.headers.get(SKIP_SOURCE_FIRST_LINE, False)
+        )
+        preview = bool(request.headers.get(PREVIEW, False))
         fast = False
         if request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast":
             fast = True
@@ -115,25 +130,45 @@ async def handle(request: web.Request, executor: Executor) -> web.Response:
             target_versions=versions,
             is_pyi=pyi,
             line_length=line_length,
+            skip_source_first_line=skip_source_first_line,
             string_normalization=not skip_string_normalization,
             magic_trailing_comma=not skip_magic_trailing_comma,
+            preview=preview,
         )
         req_bytes = await request.content.read()
         charset = request.charset if request.charset is not None else "utf8"
         req_str = req_bytes.decode(charset)
-        then = datetime.utcnow()
+        then = datetime.now(timezone.utc)
+
+        header = ""
+        if skip_source_first_line:
+            first_newline_position: int = req_str.find("\n") + 1
+            header = req_str[:first_newline_position]
+            req_str = req_str[first_newline_position:]
 
         loop = asyncio.get_event_loop()
         formatted_str = await loop.run_in_executor(
             executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode)
         )
 
+        # Preserve CRLF line endings
+        nl = req_str.find("\n")
+        if nl > 0 and req_str[nl - 1] == "\r":
+            formatted_str = formatted_str.replace("\n", "\r\n")
+            # If, after swapping line endings, nothing changed, then say so
+            if formatted_str == req_str:
+                raise black.NothingChanged
+
+        # Restore the skipped first line on both the source and formatted text
+        req_str = header + req_str
+        formatted_str = header + formatted_str
+
         # Only output the diff in the HTTP response
         only_diff = bool(request.headers.get(DIFF_HEADER, False))
         if only_diff:
-            now = datetime.utcnow()
-            src_name = f"In\t{then} +0000"
-            dst_name = f"Out\t{now} +0000"
+            now = datetime.now(timezone.utc)
+            src_name = f"In\t{then}"
+            dst_name = f"Out\t{now}"
             loop = asyncio.get_event_loop()
             formatted_str = await loop.run_in_executor(
                 executor,
@@ -174,10 +209,8 @@ def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersi
                     raise InvalidVariantHeader("major version must be 2 or 3")
                 if len(rest) > 0:
                     minor = int(rest[0])
-                    if major == 2 and minor != 7:
-                        raise InvalidVariantHeader(
-                            "minor version must be 7 for Python 2"
-                        )
+                    if major == 2:
+                        raise InvalidVariantHeader("Python 2 is not supported")
                 else:
                     # Default to lowest supported minor version.
                     minor = 7 if major == 2 else 3
@@ -193,7 +226,6 @@ def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersi
 def patched_main() -> None:
     maybe_install_uvloop()
     freeze_support()
-    black.patch_click()
     main()