X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/5446a92f0161e398de765bf9532d8c76c5652333..f6c139c5215ce04fd3e73a900f1372942d58eca0:/autoload/black.vim?ds=inline

diff --git a/autoload/black.vim b/autoload/black.vim
index f0357b0..6c381b4 100644
--- a/autoload/black.vim
+++ b/autoload/black.vim
@@ -3,8 +3,13 @@ import collections
 import os
 import sys
 import vim
-from distutils.util import strtobool
 
+def strtobool(text):
+  if text.lower() in ['y', 'yes', 't', 'true', 'on', '1']:
+    return True
+  if text.lower() in ['n', 'no', 'f', 'false', 'off', '0']:
+    return False
+  raise ValueError(f"{text} is not convertable to boolean")
 
 class Flag(collections.namedtuple("FlagBase", "name, cast")):
   @property
@@ -22,8 +27,9 @@ class Flag(collections.namedtuple("FlagBase", "name, cast")):
 FLAGS = [
   Flag(name="line_length", cast=int),
   Flag(name="fast", cast=strtobool),
-  Flag(name="string_normalization", cast=strtobool),
+  Flag(name="skip_string_normalization", cast=strtobool),
   Flag(name="quiet", cast=strtobool),
+  Flag(name="skip_magic_trailing_comma", cast=strtobool),
 ]
 
 
@@ -98,13 +104,48 @@ if _initialize_black_env():
   import black
   import time
 
-def Black():
+def get_target_version(tv):
+  if isinstance(tv, black.TargetVersion):
+    return tv
+  ret = None
+  try:
+    ret = black.TargetVersion[tv.upper()]
+  except KeyError:
+    print(f"WARNING: Target version {tv!r} not recognized by Black, using default target")
+  return ret
+
+def Black(**kwargs):
+  """
+  kwargs allows you to override ``target_versions`` argument of
+  ``black.FileMode``.
+
+  ``target_version`` needs to be cleaned because ``black.FileMode``
+  expects the ``target_versions`` argument to be a set of TargetVersion enums.
+
+  Allow kwargs["target_version"] to be a string to allow
+  to type it more quickly.
+
+  Using also target_version instead of target_versions to remain
+  consistent to Black's documentation of the structure of pyproject.toml.
+  """
   start = time.time()
   configs = get_configs()
+
+  black_kwargs = {}
+  if "target_version" in kwargs:
+    target_version = kwargs["target_version"]
+
+    if not isinstance(target_version, (list, set)):
+      target_version = [target_version]
+    target_version = set(filter(lambda x: x, map(lambda tv: get_target_version(tv), target_version)))
+    black_kwargs["target_versions"] = target_version
+
   mode = black.FileMode(
     line_length=configs["line_length"],
-    string_normalization=configs["string_normalization"],
+    string_normalization=not configs["skip_string_normalization"],
     is_pyi=vim.current.buffer.name.endswith('.pyi'),
+    magic_trailing_comma=not configs["skip_magic_trailing_comma"],
+    **black_kwargs,
   )
   quiet = configs["quiet"]
 
@@ -139,14 +180,15 @@ def Black():
       print(f'Reformatted in {time.time() - start:.4f}s.')
 
 def get_configs():
-  path_pyproject_toml = black.find_pyproject_toml(vim.eval("fnamemodify(getcwd(), ':t')"))
+  filename = vim.eval("@%")
+  path_pyproject_toml = black.find_pyproject_toml((filename,))
   if path_pyproject_toml:
     toml_config = black.parse_pyproject_toml(path_pyproject_toml)
   else:
     toml_config = {}
 
   return {
-    flag.var_name: flag.cast(toml_config.get(flag.name, vim.eval(flag.vim_rc_name)))
+    flag.var_name: toml_config.get(flag.name, flag.cast(vim.eval(flag.vim_rc_name)))
     for flag in FLAGS
   }
 
@@ -159,8 +201,17 @@ def BlackVersion():
 
 EndPython3
 
-function black#Black()
-  :py3 Black()
+function black#Black(...)
+    let kwargs = {}
+    for arg in a:000
+        let arg_list = split(arg, '=')
+        let kwargs[arg_list[0]] = arg_list[1]
+    endfor
+python3 << EOF
+import vim
+kwargs = vim.eval("kwargs")
+EOF
+  :py3 Black(**kwargs)
 endfunction
 
 function black#BlackUpgrade()