git.madduck.net Git - etc/vim.git/blob - src/black/concurrency.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Lazily import parallelized format modules
[etc/vim.git] / src / black / concurrency.py
1 """
2 Formatting many files at once via multiprocessing. Contains entrypoint and utilities.
3
4 NOTE: this module is only imported if we need to format several files at once.
5 """
6
7 import asyncio
8 import logging
9 import signal
10 import sys
11 from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor
12 from multiprocessing import Manager
13 from pathlib import Path
14 from typing import Any, Iterable, Optional, Set
15
16 from mypy_extensions import mypyc_attr
17
18 from black import DEFAULT_WORKERS, WriteBack, format_file_in_place
19 from black.cache import Cache, filter_cached, read_cache, write_cache
20 from black.mode import Mode
21 from black.output import err
22 from black.report import Changed, Report
23
24
def maybe_install_uvloop() -> None:
    """Switch asyncio over to uvloop when that package can be imported.

    Only command-line entry points call this, so that a parent process
    using Black as a library keeps whatever event loop setup it already has.
    """
    try:
        import uvloop

        uvloop.install()
    except ImportError:
        # uvloop is optional: without it the stock asyncio loop is used.
        return
37
38
def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
    """asyncio signal handler: report the abort to stderr, then cancel `tasks`.

    `tasks` is iterated at the moment the handler fires, so a live view
    passed by the caller reflects whatever is still outstanding then.
    """
    err("Aborted!")
    for outstanding in tasks:
        outstanding.cancel()
44
45
def shutdown(loop: asyncio.AbstractEventLoop) -> None:
    """Cancel every unfinished task on `loop`, let them settle, and close `loop`.

    The loop is closed in a `finally` clause, so it is closed even when
    cancelling or draining the tasks raises.
    """
    try:
        # This part is borrowed from asyncio/runners.py in Python 3.7b2.
        if sys.version_info[:2] >= (3, 7):
            all_tasks = asyncio.all_tasks
        else:
            all_tasks = asyncio.Task.all_tasks
        unfinished = [task for task in all_tasks(loop) if not task.done()]
        if unfinished:
            for task in unfinished:
                task.cancel()
            # Run the loop until each cancelled task has actually finished;
            # exceptions (including the cancellation itself) are collected
            # rather than raised.
            if sys.version_info >= (3, 7):
                drained = asyncio.gather(*unfinished, return_exceptions=True)
            else:
                drained = asyncio.gather(
                    *unfinished, loop=loop, return_exceptions=True
                )
            loop.run_until_complete(drained)
    finally:
        # `concurrent.futures.Future` objects cannot be cancelled once they
        # are already running. There might be some when the `shutdown()` happened.
        # Silence their logger's spew about the event loop being closed.
        logging.getLogger("concurrent.futures").setLevel(logging.CRITICAL)
        loop.close()
73
74
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: Report,
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    A fresh event loop is created, installed as the current loop, and used to
    run :func:`schedule_formatting` to completion. Afterwards the loop is shut
    down, the current event loop is reset to None, and the executor is shut
    down, whether or not formatting raised.

    `workers` is the requested number of worker processes; when it is None,
    `DEFAULT_WORKERS` is used instead. `sources`, `fast`, `write_back`, `mode`
    and `report` are forwarded unchanged to :func:`schedule_formatting`.
    """
    maybe_install_uvloop()

    executor: Executor
    worker_count = workers if workers is not None else DEFAULT_WORKERS
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        # The count may still be None here (no explicit `workers` and no usable
        # default). Leave it as None so the executor picks its own default
        # instead of failing on `min(None, 60)`.
        if worker_count is not None:
            worker_count = min(worker_count, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=worker_count)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    # The outer try/finally guarantees the executor is shut down even when
    # creating, running, or shutting down the event loop raises.
    try:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(
                schedule_formatting(
                    sources=sources,
                    fast=fast,
                    write_back=write_back,
                    mode=mode,
                    report=report,
                    loop=loop,
                    executor=executor,
                )
            )
        finally:
            try:
                shutdown(loop)
            finally:
                asyncio.set_event_loop(None)
    finally:
        executor.shutdown()
125
126
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: "Executor",
) -> None:
    """Format `sources` concurrently by submitting one job per file to `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`. Outcomes are recorded on `report`. Unless
    diff output was requested, the cache is consulted first and updated at
    the end.
    """
    diff_output = write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF)

    cache: Cache = {}
    if not diff_output:
        # Files already recorded in the cache are reported as such and dropped
        # from the work list.
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    lock = None
    if diff_output:
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        # `manager` stays bound for the rest of this coroutine alongside `lock`.
        manager = Manager()
        lock = manager.Lock()

    # Map each in-flight future to the file it is formatting, in sorted order.
    tasks = {}
    for src in sorted(sources):
        job = loop.run_in_executor(
            executor, format_file_in_place, src, fast, mode, write_back, lock
        )
        tasks[asyncio.ensure_future(job)] = src

    # Live view: it shrinks as finished tasks are popped from `tasks`, so the
    # signal handlers and the loop below only ever see outstanding work.
    pending = tasks.keys()
    try:
        for signum in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(signum, cancel, pending)
    except NotImplementedError:
        # There are no good alternatives for these on Windows.
        pass

    cancelled = []
    sources_to_cache = []
    while pending:
        done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            src = tasks.pop(task)
            if task.cancelled():
                cancelled.append(task)
                continue
            failure = task.exception()
            if failure:
                report.failed(src, str(failure))
                continue
            changed = Changed.YES if task.result() else Changed.NO
            # If the file was written back or was successfully checked as
            # well-formatted, store this information in the cache.
            written_back = write_back is WriteBack.YES
            verified_clean = write_back is WriteBack.CHECK and changed is Changed.NO
            if written_back or verified_clean:
                sources_to_cache.append(src)
            report.done(src, changed)

    if cancelled:
        if sys.version_info >= (3, 7):
            await asyncio.gather(*cancelled, return_exceptions=True)
        else:
            await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
    if sources_to_cache:
        write_cache(cache, sources_to_cache, mode)