git.madduck.net Git - etc/vim.git/blob - src/black/concurrency.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Update stable branch after publishing to PyPI (#3223)
[etc/vim.git] / src / black / concurrency.py
1 """
2 Formatting many files at once via multiprocessing. Contains entrypoint and utilities.
3
4 NOTE: this module is only imported if we need to format several files at once.
5 """
6
7 import asyncio
8 import logging
9 import os
10 import signal
11 import sys
12 from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor
13 from multiprocessing import Manager
14 from pathlib import Path
15 from typing import Any, Iterable, Optional, Set
16
17 from mypy_extensions import mypyc_attr
18
19 from black import WriteBack, format_file_in_place
20 from black.cache import Cache, filter_cached, read_cache, write_cache
21 from black.mode import Mode
22 from black.output import err
23 from black.report import Changed, Report
24
25
def maybe_install_uvloop() -> None:
    """Switch asyncio over to uvloop when the package is importable.

    Only command-line entry points call this, so that embedding Black as a
    library never changes the event loop setup of the host process.
    Missing uvloop is not an error: the default asyncio loop is kept.
    """
    try:
        import uvloop as fast_loop

        fast_loop.install()
    except ImportError:
        return
38
39
def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
    """Signal-handler callback: print "Aborted!" to stderr, then cancel each task.

    `tasks` may be a live view (e.g. dict keys) of whatever is still pending.
    """
    err("Aborted!")
    for pending_task in tasks:
        pending_task.cancel()
45
46
def shutdown(loop: asyncio.AbstractEventLoop) -> None:
    """Cancel whatever is still unfinished on `loop`, wait for it, then close `loop`.

    The loop is always closed (and the ``concurrent.futures`` logger silenced),
    including when there is nothing to cancel or when waiting raises.
    """
    try:
        # Modelled on the task cleanup in asyncio/runners.py from Python 3.7b2.
        if sys.version_info >= (3, 7):
            unfinished = [t for t in asyncio.all_tasks(loop) if not t.done()]
        else:
            unfinished = [t for t in asyncio.Task.all_tasks(loop) if not t.done()]
        if unfinished:
            for t in unfinished:
                t.cancel()
            if sys.version_info >= (3, 7):
                waiter = asyncio.gather(*unfinished, return_exceptions=True)
            else:
                waiter = asyncio.gather(*unfinished, loop=loop, return_exceptions=True)
            loop.run_until_complete(waiter)
    finally:
        # Executor futures that are already running cannot be cancelled, and
        # some may still be around at shutdown time; keep their logger quiet
        # about the event loop having been closed.
        logging.getLogger("concurrent.futures").setLevel(logging.CRITICAL)
        loop.close()
74
75
# diff-shades depends on being able to monkeypatch this function to operate. I know
# it's not ideal, but this shouldn't cause any issues ... hopefully. ~richard26
@mypyc_attr(patchable=True)
def reformat_many(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: Report,
    workers: Optional[int],
) -> None:
    """Reformat multiple files using a ProcessPoolExecutor.

    `workers` caps the size of the process pool; ``None`` means
    ``os.cpu_count()`` (or 1 when that is unknown), and on Windows the value is
    additionally capped at 60. If the platform cannot create worker processes,
    a single-worker ThreadPoolExecutor is used instead.

    A fresh event loop is installed for the run. Both the loop and the executor
    are always torn down afterwards, even if formatting or the loop teardown
    raises.
    """
    maybe_install_uvloop()

    executor: Executor
    if workers is None:
        workers = os.cpu_count() or 1
    if sys.platform == "win32":
        # Work around https://bugs.python.org/issue26903
        workers = min(workers, 60)
    try:
        executor = ProcessPoolExecutor(max_workers=workers)
    except (ImportError, NotImplementedError, OSError):
        # we arrive here if the underlying system does not support multi-processing
        # like in AWS Lambda or Termux, in which case we gracefully fallback to
        # a ThreadPoolExecutor with just a single worker (more workers would not do us
        # any good due to the Global Interpreter Lock)
        executor = ThreadPoolExecutor(max_workers=1)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(
            schedule_formatting(
                sources=sources,
                fast=fast,
                write_back=write_back,
                mode=mode,
                report=report,
                loop=loop,
                executor=executor,
            )
        )
    finally:
        try:
            shutdown(loop)
        finally:
            asyncio.set_event_loop(None)
            # Release the worker pool even when shutting the loop down raised;
            # otherwise the executor's workers would never be shut down.
            executor.shutdown()
126
127
async def schedule_formatting(
    sources: Set[Path],
    fast: bool,
    write_back: WriteBack,
    mode: Mode,
    report: "Report",
    loop: asyncio.AbstractEventLoop,
    executor: "Executor",
) -> None:
    """Run formatting of `sources` in parallel using the provided `executor`.

    (Use ProcessPoolExecutors for actual parallelism.)

    `write_back`, `fast`, and `mode` options are passed to
    :func:`format_file_in_place`.

    Unless a diff is requested, files the cache marks as already formatted are
    reported as ``Changed.CACHED`` without being scheduled. Files that were
    written back, or checked and found well-formatted, are added to the cache
    at the end. Failures are reported with the exception's message. SIGINT and
    SIGTERM cancel outstanding work where the loop supports signal handlers.
    """
    cache: Cache = {}
    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        cache = read_cache(mode)
        sources, cached = filter_cached(cache, sources)
        for src in sorted(cached):
            report.done(src, Changed.CACHED)
    if not sources:
        return

    cancelled = []
    sources_to_cache = []
    lock = None
    manager = None
    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
        # For diff output, we need locks to ensure we don't interleave output
        # from different processes.
        manager = Manager()
        lock = manager.Lock()

    try:
        tasks = {
            asyncio.ensure_future(
                loop.run_in_executor(
                    executor, format_file_in_place, src, fast, mode, write_back, lock
                )
            ): src
            for src in sorted(sources)
        }
        # A live view: it shrinks as finished tasks are popped below, so the
        # signal handlers only ever cancel tasks that are still outstanding.
        pending = tasks.keys()
        try:
            loop.add_signal_handler(signal.SIGINT, cancel, pending)
            loop.add_signal_handler(signal.SIGTERM, cancel, pending)
        except NotImplementedError:
            # There are no good alternatives for these on Windows.
            pass
        while pending:
            done, _ = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                src = tasks.pop(task)
                if task.cancelled():
                    cancelled.append(task)
                elif task.exception():
                    report.failed(src, str(task.exception()))
                else:
                    changed = Changed.YES if task.result() else Changed.NO
                    # If the file was written back or was successfully checked as
                    # well-formatted, store this information in the cache.
                    if write_back is WriteBack.YES or (
                        write_back is WriteBack.CHECK and changed is Changed.NO
                    ):
                        sources_to_cache.append(src)
                    report.done(src, changed)
        if cancelled:
            if sys.version_info >= (3, 7):
                await asyncio.gather(*cancelled, return_exceptions=True)
            else:
                await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
        if sources_to_cache:
            write_cache(cache, sources_to_cache, mode)
    finally:
        # Manager() starts a separate server process; stop it explicitly instead
        # of leaving it alive until the manager object is garbage-collected.
        if manager is not None:
            manager.shutdown()