git.madduck.net Git - etc/vim.git/blobdiff - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Delay worker count determination
[etc/vim.git] / src / black / __init__.py
index b8a9d03189617036ea4b49413ce1b99fd9e65480..86a0b63744264c85dc23f8702768d333c047f221 100644 (file)
@@ -1,10 +1,7 @@
-import asyncio
 import io
 import json
-import os
 import platform
 import re
-import signal
 import sys
 import tokenize
 import traceback
@@ -13,10 +10,8 @@ from dataclasses import replace
 from datetime import datetime
 from enum import Enum
 from json.decoder import JSONDecodeError
-from multiprocessing import Manager, freeze_support
 from pathlib import Path
 from typing import (
-    TYPE_CHECKING,
     Any,
     Dict,
     Generator,
@@ -38,9 +33,8 @@ from mypy_extensions import mypyc_attr
 from pathspec.patterns.gitwildmatch import GitWildMatchPatternError
 
 from _black_version import version as __version__
-from black.cache import Cache, filter_cached, get_cache_info, read_cache, write_cache
+from black.cache import Cache, get_cache_info, read_cache, write_cache
 from black.comments import normalize_fmt_off
-from black.concurrency import cancel, maybe_install_uvloop, shutdown
 from black.const import (
     DEFAULT_EXCLUDES,
     DEFAULT_INCLUDES,
@@ -91,9 +85,6 @@ from black.trans import iter_fexpr_spans
 from blib2to3.pgen2 import token
 from blib2to3.pytree import Leaf, Node
 
-if TYPE_CHECKING:
-    from concurrent.futures import Executor
-
 COMPILED = Path(__file__).suffix in (".pyd", ".so")
 
 # types
@@ -125,8 +116,6 @@ class WriteBack(Enum):
 # Legacy name, left for integrations.
 FileMode = Mode
 
-DEFAULT_WORKERS = os.cpu_count()
-
 
 def read_pyproject_toml(
     ctx: click.Context, param: click.Parameter, value: Optional[str]
@@ -375,9 +364,8 @@ def validate_regex(
     "-W",
     "--workers",
     type=click.IntRange(min=1),
-    default=DEFAULT_WORKERS,
-    show_default=True,
-    help="Number of parallel workers",
+    default=None,
+    help="Number of parallel workers [default: number of CPUs in the system]",
 )
 @click.option(
     "-q",
@@ -452,7 +440,7 @@ def main(  # noqa: C901
     extend_exclude: Optional[Pattern[str]],
     force_exclude: Optional[Pattern[str]],
     stdin_filename: Optional[str],
-    workers: int,
+    workers: Optional[int],
     src: Tuple[str, ...],
     config: Optional[str],
 ) -> None:
@@ -469,7 +457,9 @@ def main(  # noqa: C901
         out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.")
         ctx.exit(1)
 
-    root, method = find_project_root(src) if code is None else (None, None)
+    root, method = (
+        find_project_root(src, stdin_filename) if code is None else (None, None)
+    )
     ctx.obj["root"] = root
 
     if verbose:
@@ -480,7 +470,9 @@ def main(  # noqa: C901
             )
 
             normalized = [
-                (normalize_path_maybe_ignore(Path(source), root), source)
+                (source, source)
+                if source == "-"
+                else (normalize_path_maybe_ignore(Path(source), root), source)
                 for source in src
             ]
             srcs_string = ", ".join(
@@ -588,6 +580,8 @@ def main(  # noqa: C901
                 report=report,
             )
         else:
+            from black.concurrency import reformat_many
+
             reformat_many(
                 sources=sources,
                 fast=fast,
@@ -621,12 +615,7 @@ def get_sources(
 ) -> Set[Path]:
     """Compute the set of files to be formatted."""
     sources: Set[Path] = set()
-
-    if exclude is None:
-        exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
-        gitignore = get_gitignore(ctx.obj["root"])
-    else:
-        gitignore = None
+    root = ctx.obj["root"]
 
     for s in src:
         if s == "-" and stdin_filename:
@@ -661,6 +650,11 @@ def get_sources(
 
             sources.add(p)
         elif p.is_dir():
+            if exclude is None:
+                exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES)
+                gitignore = get_gitignore(root)
+            else:
+                gitignore = None
             sources.update(
                 gen_python_files(
                     p.iterdir(),
@@ -772,132 +766,6 @@ def reformat_one(
         report.failed(src, str(exc))
 
 
-# diff-shades depends on being to monkeypatch this function to operate. I know it's
-# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
-@mypyc_attr(patchable=True)
-def reformat_many(
-    sources: Set[Path],
-    fast: bool,
-    write_back: WriteBack,
-    mode: Mode,
-    report: "Report",
-    workers: Optional[int],
-) -> None:
-    """Reformat multiple files using a ProcessPoolExecutor."""
-    from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor
-
-    executor: Executor
-    worker_count = workers if workers is not None else DEFAULT_WORKERS
-    if sys.platform == "win32":
-        # Work around https://bugs.python.org/issue26903
-        assert worker_count is not None
-        worker_count = min(worker_count, 60)
-    try:
-        executor = ProcessPoolExecutor(max_workers=worker_count)
-    except (ImportError, NotImplementedError, OSError):
-        # we arrive here if the underlying system does not support multi-processing
-        # like in AWS Lambda or Termux, in which case we gracefully fallback to
-        # a ThreadPoolExecutor with just a single worker (more workers would not do us
-        # any good due to the Global Interpreter Lock)
-        executor = ThreadPoolExecutor(max_workers=1)
-
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    try:
-        loop.run_until_complete(
-            schedule_formatting(
-                sources=sources,
-                fast=fast,
-                write_back=write_back,
-                mode=mode,
-                report=report,
-                loop=loop,
-                executor=executor,
-            )
-        )
-    finally:
-        try:
-            shutdown(loop)
-        finally:
-            asyncio.set_event_loop(None)
-        if executor is not None:
-            executor.shutdown()
-
-
-async def schedule_formatting(
-    sources: Set[Path],
-    fast: bool,
-    write_back: WriteBack,
-    mode: Mode,
-    report: "Report",
-    loop: asyncio.AbstractEventLoop,
-    executor: "Executor",
-) -> None:
-    """Run formatting of `sources` in parallel using the provided `executor`.
-
-    (Use ProcessPoolExecutors for actual parallelism.)
-
-    `write_back`, `fast`, and `mode` options are passed to
-    :func:`format_file_in_place`.
-    """
-    cache: Cache = {}
-    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
-        cache = read_cache(mode)
-        sources, cached = filter_cached(cache, sources)
-        for src in sorted(cached):
-            report.done(src, Changed.CACHED)
-    if not sources:
-        return
-
-    cancelled = []
-    sources_to_cache = []
-    lock = None
-    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
-        # For diff output, we need locks to ensure we don't interleave output
-        # from different processes.
-        manager = Manager()
-        lock = manager.Lock()
-    tasks = {
-        asyncio.ensure_future(
-            loop.run_in_executor(
-                executor, format_file_in_place, src, fast, mode, write_back, lock
-            )
-        ): src
-        for src in sorted(sources)
-    }
-    pending = tasks.keys()
-    try:
-        loop.add_signal_handler(signal.SIGINT, cancel, pending)
-        loop.add_signal_handler(signal.SIGTERM, cancel, pending)
-    except NotImplementedError:
-        # There are no good alternatives for these on Windows.
-        pass
-    while pending:
-        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
-        for task in done:
-            src = tasks.pop(task)
-            if task.cancelled():
-                cancelled.append(task)
-            elif task.exception():
-                report.failed(src, str(task.exception()))
-            else:
-                changed = Changed.YES if task.result() else Changed.NO
-                # If the file was written back or was successfully checked as
-                # well-formatted, store this information in the cache.
-                if write_back is WriteBack.YES or (
-                    write_back is WriteBack.CHECK and changed is Changed.NO
-                ):
-                    sources_to_cache.append(src)
-                report.done(src, changed)
-    if cancelled:
-        if sys.version_info >= (3, 7):
-            await asyncio.gather(*cancelled, return_exceptions=True)
-        else:
-            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
-    if sources_to_cache:
-        write_cache(cache, sources_to_cache, mode)
-
-
 def format_file_in_place(
     src: Path,
     fast: bool,
@@ -1502,8 +1370,11 @@ def patch_click() -> None:
 
 
 def patched_main() -> None:
-    maybe_install_uvloop()
-    freeze_support()
+    if sys.platform == "win32" and getattr(sys, "frozen", False):
+        from multiprocessing import freeze_support
+
+        freeze_support()
+
     patch_click()
     main()