From 8fef74cf527d7fa5f2da78fafc61152c8766d0ad Mon Sep 17 00:00:00 2001 From: Tal Amuyal Date: Tue, 3 Mar 2020 13:23:28 +0200 Subject: [PATCH] Teach the Vim plugin to respect pyproject.toml (issue 414) (#1273) Creates two separate functions: 1) abspath_pyproject_toml: find the absolute path to pyproject.toml 2) parse_pyproject_toml: finds black-specific toml config Co-authored-by: Samuel Roeca --- black.py | 31 ++++++++++++++++++++---------- plugin/black.vim | 50 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 67 insertions(+), 14 deletions(-) diff --git a/black.py b/black.py index 3897eba..69d24c5 100644 --- a/black.py +++ b/black.py @@ -213,6 +213,23 @@ def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> b return all(feature in VERSION_TO_FEATURES[version] for version in target_versions) +def find_pyproject_toml(path_search_start: str) -> Optional[str]: + """Find the absolute filepath to a pyproject.toml if it exists""" + path_project_root = find_project_root(path_search_start) + path_pyproject_toml = path_project_root / "pyproject.toml" + return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None + + +def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: + """Parse a pyproject toml file, pulling out relevant parts for Black + + If parsing fails, will raise a toml.TomlDecodeError + """ + pyproject_toml = toml.load(path_config) + config = pyproject_toml.get("tool", {}).get("black", {}) + return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} + + def read_pyproject_toml( ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None] ) -> Optional[str]: @@ -223,16 +240,12 @@ def read_pyproject_toml( """ assert not isinstance(value, (int, bool)), "Invalid parameter type passed" if not value: - root = find_project_root(ctx.params.get("src", ())) - path = root / "pyproject.toml" - if path.is_file(): - value = str(path) - else: + value = find_pyproject_toml(ctx.params.get("src", ())) + if value is None: return None try: - pyproject_toml = toml.load(value) - config = pyproject_toml.get("tool", {}).get("black", {}) + config = parse_pyproject_toml(value) except (toml.TomlDecodeError, OSError) as e: raise click.FileError( filename=value, hint=f"Error reading configuration file: {e}" @@ -243,9 +256,7 @@ def read_pyproject_toml( if ctx.default_map is None: ctx.default_map = {} - ctx.default_map.update( # type: ignore # bad types in .pyi - {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} - ) + ctx.default_map.update(config) # type: ignore # bad types in .pyi return value diff --git a/plugin/black.vim b/plugin/black.vim index b174e59..a4047d4 100644 --- a/plugin/black.vim +++ b/plugin/black.vim @@ -41,10 +41,34 @@ if !exists("g:black_skip_string_normalization") endif python3 << endpython3 +import collections import os import sys import vim + +class Flag(collections.namedtuple("FlagBase", "name, cast")): + @property + def var_name(self): + return self.name.replace("-", "_") + + @property + def vim_rc_name(self): + name = self.var_name + if name == "line_length": + name = name.replace("_", "") + if name == "string_normalization": + name = "skip_" + name + return "g:black_" + name + + +FLAGS = [ + Flag(name="line_length", cast=int), + Flag(name="fast", cast=bool), + Flag(name="string_normalization", cast=bool), +] + + def _get_python_binary(exec_prefix): try: default = vim.eval("g:pymode_python").strip() @@ -108,15 +132,20 @@ if _initialize_black_env(): def Black(): start = time.time() - fast = bool(int(vim.eval("g:black_fast"))) + configs = get_configs() mode = black.FileMode( - line_length=int(vim.eval("g:black_linelength")), - string_normalization=not bool(int(vim.eval("g:black_skip_string_normalization"))), + line_length=configs["line_length"], + string_normalization=configs["string_normalization"], is_pyi=vim.current.buffer.name.endswith('.pyi'), ) + buffer_str = '\n'.join(vim.current.buffer) + '\n' try: - new_buffer_str = black.format_file_contents(buffer_str, fast=fast, mode=mode) + new_buffer_str = black.format_file_contents( + buffer_str, + fast=configs["fast"], + mode=mode, + ) except black.NothingChanged: print(f'Already well formatted, good job. (took {time.time() - start:.4f}s)') except Exception as exc: @@ -138,6 +167,19 @@ def Black(): window.cursor = (len(window.buffer), 0) print(f'Reformatted in {time.time() - start:.4f}s.') +def get_configs(): + path_pyproject_toml = black.find_pyproject_toml(vim.eval("fnamemodify(getcwd(), ':t')")) + if path_pyproject_toml: + toml_config = black.parse_pyproject_toml(path_pyproject_toml) + else: + toml_config = {} + + return { + flag.var_name: toml_config.get(flag.name, flag.cast(vim.eval(flag.vim_rc_name))) + for flag in FLAGS + } + + def BlackUpgrade(): _initialize_black_env(upgrade=True) -- 2.39.5