X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/2082a325fdd14f0aabd88f7f12a20f9fb085c538..40fae18134916b8499bd992d8bef4ae23bcd2986:/src/blackd/__init__.py diff --git a/src/blackd/__init__.py b/src/blackd/__init__.py index d79bfe7..3e2a7e7 100644 --- a/src/blackd/__init__.py +++ b/src/blackd/__init__.py @@ -1,14 +1,26 @@ import asyncio +import logging +import sys from concurrent.futures import Executor, ProcessPoolExecutor from datetime import datetime from functools import partial -import logging from multiprocessing import freeze_support from typing import Set, Tuple -from aiohttp import web -import aiohttp_cors +try: + from aiohttp import web + import aiohttp_cors +except ImportError as ie: + print( + f"aiohttp dependency is not installed: {ie}. " + + "Please re-install black with the '[d]' extra install " + + "to obtain aiohttp_cors: `pip install black[d]`", + file=sys.stderr, + ) + sys.exit(-1) + import black +from black.concurrency import maybe_install_uvloop import click from _black_version import version as __version__ @@ -21,6 +33,7 @@ PROTOCOL_VERSION_HEADER = "X-Protocol-Version" LINE_LENGTH_HEADER = "X-Line-Length" PYTHON_VARIANT_HEADER = "X-Python-Variant" SKIP_STRING_NORMALIZATION_HEADER = "X-Skip-String-Normalization" +SKIP_MAGIC_TRAILING_COMMA = "X-Skip-Magic-Trailing-Comma" FAST_OR_SAFE_HEADER = "X-Fast-Or-Safe" DIFF_HEADER = "X-Diff" @@ -29,6 +42,7 @@ BLACK_HEADERS = [ LINE_LENGTH_HEADER, PYTHON_VARIANT_HEADER, SKIP_STRING_NORMALIZATION_HEADER, + SKIP_MAGIC_TRAILING_COMMA, FAST_OR_SAFE_HEADER, DIFF_HEADER, ] @@ -103,6 +117,9 @@ async def handle(request: web.Request, executor: Executor) -> web.Response: skip_string_normalization = bool( request.headers.get(SKIP_STRING_NORMALIZATION_HEADER, False) ) + skip_magic_trailing_comma = bool( + request.headers.get(SKIP_MAGIC_TRAILING_COMMA, False) + ) fast = False if request.headers.get(FAST_OR_SAFE_HEADER, "safe") == "fast": fast = True @@ -111,6 +128,7 @@ async def handle(request: web.Request, executor: Executor) -> web.Response: is_pyi=pyi, line_length=line_length, string_normalization=not skip_string_normalization, + magic_trailing_comma=not skip_magic_trailing_comma, ) req_bytes = await request.content.read() charset = request.charset if request.charset is not None else "utf8" @@ -185,6 +203,7 @@ def parse_python_variant_header(value: str) -> Tuple[bool, Set[black.TargetVersi def patched_main() -> None: + maybe_install_uvloop() freeze_support() black.patch_click() main()