git.madduck.net Git - etc/vim.git/blob - src/black/cache.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Pickle raw tuples in FileData cache (#3877)
[etc/vim.git] / src / black / cache.py
1 """Caching of formatted files with feature-based invalidation."""
2 import hashlib
3 import os
4 import pickle
5 import sys
6 import tempfile
7 from dataclasses import dataclass, field
8 from pathlib import Path
9 from typing import Dict, Iterable, NamedTuple, Set, Tuple
10
11 from platformdirs import user_cache_dir
12
13 from _black_version import version as __version__
14 from black.mode import Mode
15
16 if sys.version_info >= (3, 11):
17     from typing import Self
18 else:
19     from typing_extensions import Self
20
21
class FileData(NamedTuple):
    """Metadata recorded for one source file when it was last formatted.

    Compared against the file's current state to decide whether it changed.
    """

    st_mtime: float  # modification time as reported by Path.stat()
    st_size: int  # size in bytes as reported by Path.stat()
    hash: str  # hex digest of the file contents
26
27
def get_cache_dir() -> Path:
    """Return the directory in which black keeps its cache files.

    The ``BLACK_CACHE_DIR`` environment variable overrides the location on
    every platform. Without it, the versioned per-user cache directory for
    the "black" application is used.

    The result is stored once in the module constant
    ``black.cache.CACHE_DIR`` so as to avoid repeated calls.
    """
    # NOTE: kept as a function mainly so tests can exercise the lookup.
    fallback = user_cache_dir("black", version=__version__)
    return Path(os.environ.get("BLACK_CACHE_DIR", fallback))
42
43
# Computed once at import time so every cache operation uses the same directory.
CACHE_DIR: Path = get_cache_dir()
45
46
def get_cache_file(mode: Mode) -> Path:
    """Return the pickle file that holds the cache for ``mode``.

    The file lives in ``CACHE_DIR`` and its name embeds
    ``mode.get_cache_key()``, giving each mode its own cache file.
    """
    file_name = f"cache.{mode.get_cache_key()}.pickle"
    return CACHE_DIR / file_name
49
50
@dataclass
class Cache:
    """Cache of per-file metadata for one formatting mode.

    ``file_data`` maps resolved source paths (as strings) to the ``FileData``
    recorded when the file was last written to the cache. The cache is
    persisted as a pickle at ``cache_file``.
    """

    mode: Mode
    cache_file: Path
    file_data: Dict[str, FileData] = field(default_factory=dict)

    @classmethod
    def read(cls, mode: Mode) -> Self:
        """Read the cache for ``mode`` if it exists and is well formed.

        A missing, unreadable or malformed cache file yields an empty cache;
        the next call to ``write`` replaces it with a valid one.
        """
        cache_file = get_cache_file(mode)
        try:
            with cache_file.open("rb") as fobj:
                # NOTE(review): pickle.load can execute arbitrary code, so this
                # is only safe while the cache directory itself is trusted.
                data: Dict[str, Tuple[float, int, str]] = pickle.load(fobj)
            file_data = {k: FileData(*v) for k, v in data.items()}
        except (
            # Missing (FileNotFoundError) or unreadable cache file.
            OSError,
            # Truncated/corrupt pickle: besides UnpicklingError, the pickle
            # docs state unpickling may raise EOFError, AttributeError,
            # ImportError and IndexError.
            pickle.UnpicklingError,
            EOFError,
            AttributeError,
            ImportError,
            IndexError,
            # Payload of an unexpected shape, e.g. entries that are not
            # 3-item sequences (FileData(*v) raises TypeError).
            TypeError,
            ValueError,
        ):
            return cls(mode, cache_file)

        return cls(mode, cache_file, file_data)

    @staticmethod
    def hash_digest(path: Path) -> str:
        """Return the SHA-256 hex digest of the contents of ``path``."""
        data = path.read_bytes()
        return hashlib.sha256(data).hexdigest()

    @staticmethod
    def get_file_data(path: Path) -> FileData:
        """Return the current mtime, size and content hash of ``path``."""
        stat = path.stat()
        hash = Cache.hash_digest(path)
        return FileData(stat.st_mtime, stat.st_size, hash)

    def is_changed(self, source: Path) -> bool:
        """Check if ``source`` has changed compared to the cached version.

        The file counts as unchanged only if it is cached, its size matches,
        and either its mtime matches exactly or its content hash still
        matches the cached hash.
        """
        res_src = source.resolve()
        old = self.file_data.get(str(res_src))
        if old is None:
            return True

        st = res_src.stat()
        if st.st_size != old.st_size:
            return True
        # Compare the full-precision mtime. Truncating to whole seconds would
        # treat a same-size edit made within the same second as unchanged
        # without ever checking the hash (a false cache hit).
        if st.st_mtime != old.st_mtime:
            new_hash = Cache.hash_digest(res_src)
            if new_hash != old.hash:
                return True
        return False

    def filtered_cached(self, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
        """Split an iterable of paths in `sources` into two sets.

        The first contains paths of files that were modified on disk or are not
        in the cache. The other contains paths to non-modified files.
        """
        changed: Set[Path] = set()
        done: Set[Path] = set()
        for src in sources:
            if self.is_changed(src):
                changed.add(src)
            else:
                done.add(src)
        return changed, done

    def write(self, sources: Iterable[Path]) -> None:
        """Update the cache file data and write a new cache file.

        Persisting is best effort: an ``OSError`` while creating the cache
        directory or writing the file is ignored, and any leftover temporary
        file is removed. Errors from reading the ``sources`` themselves still
        propagate.
        """
        self.file_data.update(
            {str(src.resolve()): Cache.get_file_data(src) for src in sources}
        )
        # We store raw tuples in the cache because pickling NamedTuples
        # doesn't work with mypyc on Python 3.8, and because it's faster.
        data: Dict[str, Tuple[float, int, str]] = {
            k: (*v,) for k, v in self.file_data.items()
        }
        # Create the directory the file is actually written into.
        cache_dir = self.cache_file.parent
        tmp_name = ""
        try:
            cache_dir.mkdir(parents=True, exist_ok=True)
            # Write to a temp file and atomically swap it in, so readers never
            # see a partially written cache.
            with tempfile.NamedTemporaryFile(dir=str(cache_dir), delete=False) as f:
                tmp_name = f.name
                pickle.dump(data, f, protocol=4)
            os.replace(tmp_name, self.cache_file)
        except OSError:
            # Failing to persist the cache must not fail the run, but don't
            # leave a stray temporary file behind (delete=False above).
            if tmp_name:
                try:
                    os.unlink(tmp_name)
                except OSError:
                    pass