]> git.madduck.net Git - etc/vim.git/blob - src/black/cache.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Move coverage configurations to `pyproject.toml` (#3858)
[etc/vim.git] / src / black / cache.py
1 """Caching of formatted files with feature-based invalidation."""
2 import hashlib
3 import os
4 import pickle
5 import sys
6 import tempfile
7 from dataclasses import dataclass, field
8 from pathlib import Path
9 from typing import Dict, Iterable, NamedTuple, Set, Tuple
10
11 from platformdirs import user_cache_dir
12
13 from _black_version import version as __version__
14 from black.mode import Mode
15
16 if sys.version_info >= (3, 11):
17     from typing import Self
18 else:
19     from typing_extensions import Self
20
21
class FileData(NamedTuple):
    """Snapshot of a file's on-disk state as recorded in the cache."""

    st_mtime: float  # modification time, taken from ``Path.stat()``
    st_size: int  # file size, taken from ``Path.stat()``
    hash: str  # hex digest of the file contents
26
27
def get_cache_dir() -> Path:
    """Get the cache directory used by black.

    Users can customize the base directory on all systems using the
    `BLACK_CACHE_DIR` environment variable. By default, the base is the user cache
    directory for the black application.

    In both cases the returned directory is a subdirectory of that base named after
    the running black version. The cache file name itself is derived only from the
    formatting mode, so without this versioning, caches written by different black
    versions would be shared and could wrongly mark files as already formatted.

    This result is immediately set to a constant `black.cache.CACHE_DIR` as to avoid
    repeated calls.
    """
    # NOTE: Function mostly exists as a clean way to test getting the cache directory.
    default_cache_dir = user_cache_dir("black")
    cache_dir = Path(os.environ.get("BLACK_CACHE_DIR", default_cache_dir))
    # Version the directory even when the user overrides the base location; for the
    # default location this yields the same path as passing `version=` to
    # `user_cache_dir`, which appends the version as the last path component.
    return cache_dir / __version__
42
43
# Resolved once, when this module is imported; changes to the environment made
# afterwards (e.g. setting BLACK_CACHE_DIR) are not reflected in this constant.
CACHE_DIR: Path = get_cache_dir()
45
46
def get_cache_file(mode: Mode) -> Path:
    """Return the path of the pickle file holding cached results for `mode`.

    Every distinct `mode.get_cache_key()` maps to its own file inside `CACHE_DIR`.
    """
    file_name = f"cache.{mode.get_cache_key()}.pickle"
    return CACHE_DIR.joinpath(file_name)
49
50
@dataclass
class Cache:
    """On-disk record of source files already processed under a given `Mode`.

    `file_data` maps the resolved path of each source (as a string) to the
    `FileData` captured the last time that source was written to the cache.
    """

    mode: Mode
    cache_file: Path
    file_data: Dict[str, FileData] = field(default_factory=dict)

    @classmethod
    def read(cls, mode: Mode) -> Self:
        """Read the cache for `mode` if it exists and is well formed.

        A missing, unreadable, empty, truncated or otherwise malformed cache file
        yields an empty cache instead of an error; the next call to `write`
        replaces the file.

        NOTE(review): `pickle.load` can execute arbitrary code from a crafted file.
        This assumes the cache directory is writable only by the current user; a
        shared `BLACK_CACHE_DIR` would break that assumption.
        """
        cache_file = get_cache_file(mode)
        try:
            with cache_file.open("rb") as fobj:
                raw = pickle.load(fobj)
            # Rebuild every entry as a FileData so that unexpected payloads (a
            # non-dict object, tuples with the wrong number of fields, ...) are
            # rejected here rather than failing later in `is_changed`.
            file_data: Dict[str, FileData] = {
                key: FileData(*value) for key, value in raw.items()
            }
        except (
            OSError,  # missing file, permission problems, etc.
            EOFError,  # empty or truncated file
            pickle.UnpicklingError,
            ValueError,
            IndexError,
            TypeError,  # entries of the wrong shape
            AttributeError,  # payload is not a dict / pickled class not found
            ImportError,  # pickled class lives in a module that no longer exists
        ):
            return cls(mode, cache_file)
        return cls(mode, cache_file, file_data)

    @staticmethod
    def hash_digest(path: Path) -> str:
        """Return the SHA-256 hex digest of the contents of `path`."""
        data = path.read_bytes()
        return hashlib.sha256(data).hexdigest()

    @staticmethod
    def get_file_data(path: Path) -> FileData:
        """Return the current mtime, size and content hash of `path`."""
        stat = path.stat()
        digest = Cache.hash_digest(path)
        return FileData(stat.st_mtime, stat.st_size, digest)

    def is_changed(self, source: Path) -> bool:
        """Check if `source` has changed compared to its cached version.

        A file counts as changed when it is not in the cache, when its size
        differs, or when its mtime (truncated to whole seconds) differs *and* its
        content hash differs. An mtime change with identical content is therefore
        not reported as a change.
        """
        res_src = source.resolve()
        old = self.file_data.get(str(res_src))
        if old is None:
            return True

        st = res_src.stat()
        if st.st_size != old.st_size:
            return True
        # Only compare whole seconds; the (comparatively expensive) hash is
        # computed only when the timestamps disagree.
        if int(st.st_mtime) != int(old.st_mtime):
            new_hash = Cache.hash_digest(res_src)
            if new_hash != old.hash:
                return True
        return False

    def filtered_cached(self, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
        """Split an iterable of paths in `sources` into two sets.

        The first contains paths of files that were modified on disk or are not in
        the cache. The second contains paths of unmodified files.
        """
        changed: Set[Path] = set()
        done: Set[Path] = set()
        for src in sources:
            if self.is_changed(src):
                changed.add(src)
            else:
                done.add(src)
        return changed, done

    def write(self, sources: Iterable[Path]) -> None:
        """Update the cache file data for `sources` and write a new cache file.

        The data is pickled into a temporary file in the cache file's directory
        and then moved into place with `os.replace`, so the cache file is never
        left half-written. Writing is best-effort: any `OSError` is swallowed, and
        the temporary file is removed if it could not be moved into place.
        """
        self.file_data.update(
            {str(src.resolve()): Cache.get_file_data(src) for src in sources}
        )
        tmp_name = ""
        try:
            # Create the directory we actually write into (the temp file is placed
            # next to `self.cache_file` so that `os.replace` stays on one filesystem).
            self.cache_file.parent.mkdir(parents=True, exist_ok=True)
            with tempfile.NamedTemporaryFile(
                dir=str(self.cache_file.parent), delete=False
            ) as f:
                tmp_name = f.name
                pickle.dump(self.file_data, f, protocol=4)
            os.replace(tmp_name, self.cache_file)
        except OSError:
            # Caching is an optimization, so failures are ignored; but do not leave
            # an orphaned temporary file behind (delete=False keeps it on disk).
            if tmp_name:
                try:
                    os.unlink(tmp_name)
                except OSError:
                    pass
119                 done.add(src)
120         return changed, done
121
122     def write(self, sources: Iterable[Path]) -> None:
123         """Update the cache file data and write a new cache file."""
124         self.file_data.update(
125             **{str(src.resolve()): Cache.get_file_data(src) for src in sources}
126         )
127         try:
128             CACHE_DIR.mkdir(parents=True, exist_ok=True)
129             with tempfile.NamedTemporaryFile(
130                 dir=str(self.cache_file.parent), delete=False
131             ) as f:
132                 pickle.dump(self.file_data, f, protocol=4)
133             os.replace(f.name, self.cache_file)
134         except OSError:
135             pass