X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/ce14fa8b497bae2b50ec48b3bd7022573a59cdb1..8fef74cf527d7fa5f2da78fafc61152c8766d0ad:/black.py diff --git a/black.py b/black.py index 3897eba..69d24c5 100644 --- a/black.py +++ b/black.py @@ -213,6 +213,23 @@ def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> b return all(feature in VERSION_TO_FEATURES[version] for version in target_versions) +def find_pyproject_toml(path_search_start: str) -> Optional[str]: + """Find the absolute filepath to a pyproject.toml if it exists""" + path_project_root = find_project_root(path_search_start) + path_pyproject_toml = path_project_root / "pyproject.toml" + return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None + + +def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: + """Parse a pyproject toml file, pulling out relevant parts for Black + + If parsing fails, will raise a toml.TomlDecodeError + """ + pyproject_toml = toml.load(path_config) + config = pyproject_toml.get("tool", {}).get("black", {}) + return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} + + def read_pyproject_toml( ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None] ) -> Optional[str]: @@ -223,16 +240,12 @@ def read_pyproject_toml( """ assert not isinstance(value, (int, bool)), "Invalid parameter type passed" if not value: - root = find_project_root(ctx.params.get("src", ())) - path = root / "pyproject.toml" - if path.is_file(): - value = str(path) - else: + value = find_pyproject_toml(ctx.params.get("src", ())) + if value is None: return None try: - pyproject_toml = toml.load(value) - config = pyproject_toml.get("tool", {}).get("black", {}) + config = parse_pyproject_toml(value) except (toml.TomlDecodeError, OSError) as e: raise click.FileError( filename=value, hint=f"Error reading configuration file: {e}" @@ -243,9 +256,7 @@ def read_pyproject_toml( if ctx.default_map is None: ctx.default_map = {} - ctx.default_map.update( # type: ignore # bad types in .pyi - {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} - ) + ctx.default_map.update(config) # type: ignore # bad types in .pyi return value