X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/f311d82569b9595d85c08cc8fcf5250de525e7a0..df50fee7fd85018f8db462774512a83031f00322:/src/blackd/__init__.py diff --git a/src/blackd/__init__.py b/src/blackd/__init__.py index f77a5e8..4f2d87d 100644 --- a/src/blackd/__init__.py +++ b/src/blackd/__init__.py @@ -1,28 +1,27 @@ import asyncio import logging -import sys from concurrent.futures import Executor, ProcessPoolExecutor -from datetime import datetime +from datetime import datetime, timezone from functools import partial from multiprocessing import freeze_support from typing import Set, Tuple try: from aiohttp import web - import aiohttp_cors + + from .middlewares import cors except ImportError as ie: - print( + raise ImportError( f"aiohttp dependency is not installed: {ie}. " + "Please re-install black with the '[d]' extra install " - + "to obtain aiohttp_cors: `pip install black[d]`", - file=sys.stderr, - ) - sys.exit(-1) + + "to obtain aiohttp_cors: `pip install black[d]`" + ) from None -import black import click +import black from _black_version import version as __version__ +from black.concurrency import maybe_install_uvloop # This is used internally by tests to shut down the server prematurely _stop_signal = asyncio.Event() @@ -31,7 +30,10 @@ _stop_signal = asyncio.Event() PROTOCOL_VERSION_HEADER = "X-Protocol-Version" LINE_LENGTH_HEADER = "X-Line-Length" PYTHON_VARIANT_HEADER = "X-Python-Variant" +SKIP_SOURCE_FIRST_LINE = "X-Skip-Source-First-Line" SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization" +SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma" +PREVIEW = "X-Preview" FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe" DIFF_HEADER = "X-Diff" @@ -39,7 +41,10 @@ BLACK_HEADERS = [ PROTOCOL_VERSION_HEADER, LINE_LENGTH_HEADER, PYTHON_VARIANT_HEADER, + SKIP_SOURCE_FIRST_LINE, SKIP_STRING_NORMALIZATION_HEADER, + SKIP_MAGIC_TRAILING_COMMA, + PREVIEW, FAST_OR_SAFE_HEADER, DIFF_HEADER, ] @@ -54,9 +59,15 @@ class InvalidVariantHeader(Exception): @click.command(context_settings={"help_option_names": ["-h", "--help"]}) @click.option( - "--bind-host", type=str, help="Address to bind the server to.", default="localhost" + "--bind-host", + type=str, + help="Address to bind the server to.", + default="localhost", + show_default=True, +) +@click.option( + "--bind-port", type=int, help="Port to listen on", default=45484, show_default=True ) -@click.option("--bind-port", type=int, help="Port to listen on", default=45484) @click.version_option(version=black.__version__) def main(bind_host: str, bind_port: int) -> None: logging.basicConfig(level=logging.INFO) @@ -67,20 +78,11 @@ def main(bind_host: str, bind_port: int) -> None: def make_app() -> web.Application: - app = web.Application() - executor = ProcessPoolExecutor() - - cors = aiohttp_cors.setup(app) - resource = cors.add(app.router.add_resource("/")) - cors.add( - resource.add_route("POST", partial(handle, executor=executor)), - { - "*": aiohttp_cors.ResourceOptions( - allow_headers=(*BLACK_HEADERS, "Content-Type"), expose_headers="*" - ) - }, + app = web.Application( + middlewares=[cors(allow_headers=(*BLACK_HEADERS, "Content-Type"))] ) - + executor = ProcessPoolExecutor() + app.add_routes([web.post("/", partial(handle, executor=executor))]) return app @@ -114,6 +116,13 @@ async def handle(request: web.Request, executor: Executor) -> web.Response: skip_string_normalization = bool( request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False) ) + skip_magic_trailing_comma = bool( + request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False) + ) + skip_source_first_line = bool( + request.headers.get(SKIP_SOURCE_FIRST_LINE, False) + ) + preview = bool(request.headers.get(PREVIEW, False)) fast = False if request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast": fast = True @@ -121,24 +130,44 @@ async def handle(request: web.Request, executor: Executor) -> web.Response: target_versions=versions, is_pyi=pyi, line_length=line_length, + skip_source_first_line=skip_source_first_line, string_normalization=not skip_string_normalization, + magic_trailing_comma=not skip_magic_trailing_comma, + preview=preview, ) req_bytes = await request.content.read() charset = request.charset if request.charset is not None else "utf8" req_str = req_bytes.decode(charset) - then = datetime.utcnow() + then = datetime.now(timezone.utc) + + header = "" + if skip_source_first_line: + first_newline_position: int = req_str.find("\n") + 1 + header = req_str[:first_newline_position] + req_str = req_str[first_newline_position:] loop = asyncio.get_event_loop() formatted_str = await loop.run_in_executor( executor, partial(black.format_file_contents, req_str, fast=fast, mode=mode) ) + # Preserve CRLF line endings + if req_str[req_str.find("\n") - 1] == "\r": + formatted_str = formatted_str.replace("\n", "\r\n") + # If, after swapping line endings, nothing changed, then say so + if formatted_str == req_str: + raise black.NothingChanged + + # Put the source first line back + req_str = header + req_str + formatted_str = header + formatted_str + # Only output the diff in the HTTP response only_diff = bool(request.headers.get(DIFF_HEADER, False)) if only_diff: - now = datetime.utcnow() - src_name = f"In\t{then} +0000" - dst_name = f"Out\t{now} +0000" + now = datetime.now(timezone.utc) + src_name = f"In\t{then}" + dst_name = f"Out\t{now}" loop = asyncio.get_event_loop() formatted_str = await loop.run_in_executor( executor, @@ -179,10 +208,8 @@ def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersi raise InvalidVariantHeader("major version must be 2 or 3") if len(rest) > 0: minor = int(rest[0]) - if major == 2 and minor != 7: - raise InvalidVariantHeader( - "minor version must be 7 for Python 2" - ) + if major == 2: + raise InvalidVariantHeader("Python 2 is not supported") else: # Default to lowest supported minor version. minor = 7 if major == 2 else 3 @@ -191,13 +218,13 @@ def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersi raise InvalidVariantHeader(f"3.{minor} is not supported") versions.add(black.TargetVersion[version_str]) except (KeyError, ValueError): - raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") + raise InvalidVariantHeader("expected e.g. '3.7', 'py3.5'") from None return False, versions def patched_main() -> None: + maybe_install_uvloop() freeze_support() - black.patch_click() main()