"""Caching of formatted files with feature-based invalidation."""
-
+import hashlib
import os
import pickle
-from pathlib import Path
+import sys
import tempfile
-from typing import Dict, Iterable, Set, Tuple
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Dict, Iterable, NamedTuple, Set, Tuple
from platformdirs import user_cache_dir
-from black.mode import Mode
-
from _black_version import version as __version__
+from black.mode import Mode
+if sys.version_info >= (3, 11):
+ from typing import Self
+else:
+ from typing_extensions import Self
-# types
-Timestamp = float
-FileSize = int
-CacheInfo = Tuple[Timestamp, FileSize]
-Cache = Dict[str, CacheInfo]
+class FileData(NamedTuple):
+ """Metadata recorded for a formatted file, used to detect later modifications."""
+ st_mtime: float  # modification time reported by stat() when the cache was written
+ st_size: int  # file size in bytes reported by stat()
+ hash: str  # SHA-256 hex digest of the file's contents
-CACHE_DIR = Path(user_cache_dir("black", version=__version__))
+def get_cache_dir() -> Path:
+ """Get the cache directory used by black.
-def read_cache(mode: Mode) -> Cache:
- """Read the cache if it exists and is well formed.
+ Users can customize this directory on all systems using the `BLACK_CACHE_DIR`
+ environment variable, whose value is used verbatim. By default, the cache
+ directory is the platform user cache directory for the "black" application,
+ qualified by the running black version (only the default is version-qualified).
- If it is not well formed, the call to write_cache later should resolve the issue.
+ The result is computed once and stored in the module constant
+ `black.cache.CACHE_DIR` to avoid repeated calls.
"""
- cache_file = get_cache_file(mode)
- if not cache_file.exists():
- return {}
+ # NOTE: Function mostly exists as a clean way to test getting the cache directory.
+ default_cache_dir = user_cache_dir("black", version=__version__)
+ # NOTE(review): an empty BLACK_CACHE_DIR yields Path(""), i.e. the current
+ # working directory -- confirm whether an empty value should use the default.
+ cache_dir = Path(os.environ.get("BLACK_CACHE_DIR", default_cache_dir))
+ return cache_dir
- with cache_file.open("rb") as fobj:
- try:
- cache: Cache = pickle.load(fobj)
- except (pickle.UnpicklingError, ValueError, IndexError):
- return {}
- return cache
+# Evaluated once at import time; later changes to BLACK_CACHE_DIR are not seen.
+CACHE_DIR = get_cache_dir()
def get_cache_file(mode: Mode) -> Path:
+ """Return the path of the pickle file that stores the cache for *mode*."""
return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
-def get_cache_info(path: Path) -> CacheInfo:
- """Return the information used to check if a file is already formatted or not."""
- stat = path.stat()
- return stat.st_mtime, stat.st_size
-
-
-def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
- """Split an iterable of paths in `sources` into two sets.
-
- The first contains paths of files that modified on disk or are not in the
- cache. The other contains paths to non-modified files.
- """
- todo, done = set(), set()
- for src in sources:
- res_src = src.resolve()
- if cache.get(str(res_src)) != get_cache_info(res_src):
- todo.add(src)
- else:
- done.add(src)
- return todo, done
-
-
-def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
- """Update the cache file."""
- cache_file = get_cache_file(mode)
- try:
- CACHE_DIR.mkdir(parents=True, exist_ok=True)
- new_cache = {
- **cache,
- **{str(src.resolve()): get_cache_info(src) for src in sources},
- }
- with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
- pickle.dump(new_cache, f, protocol=4)
- os.replace(f.name, cache_file)
- except OSError:
- pass
+@dataclass
+class Cache:
+    """Per-mode record of files black has already formatted.
+
+    `file_data` maps a resolved source path (as a string) to the `FileData`
+    captured when that file was last written to the cache, and is persisted
+    as a pickle at `cache_file`.
+    """
+
+    mode: Mode
+    cache_file: Path
+    file_data: Dict[str, FileData] = field(default_factory=dict)
+
+    @classmethod
+    def read(cls, mode: Mode) -> Self:
+        """Read the cache for *mode* if it exists and is well formed.
+
+        A missing, unreadable, truncated, or incompatible cache file (e.g. one
+        holding the older ``(mtime, size)`` tuple entries) yields an empty
+        cache; the next call to `write` replaces the bad file.
+        """
+        cache_file = get_cache_file(mode)
+        try:
+            # The cache file is black's own output; pickle must never be
+            # pointed at files from untrusted sources.
+            with cache_file.open("rb") as fobj:
+                raw = pickle.load(fobj)
+            # Re-wrap each entry so stale or foreign formats are rejected here
+            # rather than crashing later inside `is_changed`.
+            file_data = {key: FileData(*value) for key, value in raw.items()}
+        except (
+            OSError,  # includes FileNotFoundError when no cache exists yet
+            EOFError,  # empty or truncated pickle
+            pickle.UnpicklingError,
+            ValueError,
+            IndexError,
+            TypeError,  # entry with the wrong number of fields
+            AttributeError,  # top-level object is not a mapping
+        ):
+            return cls(mode, cache_file)
+        return cls(mode, cache_file, file_data)
+
+    @staticmethod
+    def hash_digest(path: Path) -> str:
+        """Return the SHA-256 hex digest of the contents of *path*."""
+        data = path.read_bytes()
+        return hashlib.sha256(data).hexdigest()
+
+    @staticmethod
+    def get_file_data(path: Path) -> FileData:
+        """Return the mtime, size, and content hash of *path*."""
+        stat = path.stat()
+        digest = Cache.hash_digest(path)
+        return FileData(stat.st_mtime, stat.st_size, digest)
+
+    def is_changed(self, source: Path) -> bool:
+        """Check if *source* has changed compared to its cached version.
+
+        A file missing from the cache, or whose size differs, has changed. When
+        only the modification time differs, the content hash decides, so a file
+        that was merely touched is still considered unchanged.
+        """
+        res_src = source.resolve()
+        old = self.file_data.get(str(res_src))
+        if old is None:
+            return True
+
+        st = res_src.stat()
+        if st.st_size != old.st_size:
+            return True
+        # Compare full-precision mtimes: truncating to whole seconds would let a
+        # same-size edit made within the same second skip the hash check and be
+        # wrongly reported as unchanged.
+        if st.st_mtime != old.st_mtime:
+            new_hash = Cache.hash_digest(res_src)
+            if new_hash != old.hash:
+                return True
+        return False
+
+    def filtered_cached(self, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
+        """Split an iterable of paths in `sources` into two sets.
+
+        The first contains paths of files that were modified on disk or are not
+        in the cache. The second contains paths of unmodified files.
+        """
+        changed: Set[Path] = set()
+        done: Set[Path] = set()
+        for src in sources:
+            if self.is_changed(src):
+                changed.add(src)
+            else:
+                done.add(src)
+        return changed, done
+
+    def write(self, sources: Iterable[Path]) -> None:
+        """Record *sources* in the cache and atomically rewrite the cache file.
+
+        Writing the cache is best-effort: any OSError (a source that vanished,
+        an unwritable cache directory, ...) leaves the file on disk untouched.
+        """
+        try:
+            # Build the whole mapping first so a failing stat/read leaves
+            # `file_data` unmodified.
+            new_data = {str(src.resolve()): Cache.get_file_data(src) for src in sources}
+            self.file_data.update(new_data)
+            cache_dir = self.cache_file.parent
+            cache_dir.mkdir(parents=True, exist_ok=True)
+            with tempfile.NamedTemporaryFile(dir=str(cache_dir), delete=False) as f:
+                pickle.dump(self.file_data, f, protocol=4)
+            # Replace only after the temp file is closed (required on Windows)
+            # so readers never observe a partially written cache.
+            os.replace(f.name, self.cache_file)
+        except OSError:
+            pass