git.madduck.net Git - etc/vim.git/blobdiff - autoload/black.vim

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. I'd be especially grateful if you read the Git project's submission guidelines and adhered to them.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Return `NothingChanged` if non-Python cell magic is detected, to avoid tokenize error...
[etc/vim.git] / autoload / black.vim
index f0357b0712320f454da40fc7c9f606b1461bc735..66c5b9c2841a21475d68af3dc814f6730782ac13 100644 (file)
@@ -3,8 +3,13 @@ import collections
 import os
 import sys
 import vim
-from distutils.util import strtobool
 
+def strtobool(text):
+  """Convert a yes/no-style string to a bool.
+
+  Local stand-in for the ``distutils.util.strtobool`` import that this
+  change removes. Matching is case-insensitive.
+
+  Returns True for 'y', 'yes', 't', 'true', 'on', '1' and False for
+  'n', 'no', 'f', 'false', 'off', '0'.
+
+  Raises ValueError for any other string.
+  """
+  value = text.lower()
+  # Every entry needs its own comma: adjacent literals such as 'true' 'on'
+  # are concatenated by Python into the single string 'trueon'.
+  if value in ('y', 'yes', 't', 'true', 'on', '1'):
+    return True
+  if value in ('n', 'no', 'f', 'false', 'off', '0'):
+    return False
+  raise ValueError(f"{text} is not convertable to boolean")
 
 class Flag(collections.namedtuple("FlagBase", "name, cast")):
   @property
@@ -22,8 +27,9 @@ class Flag(collections.namedtuple("FlagBase", "name, cast")):
 FLAGS = [
   Flag(name="line_length", cast=int),
   Flag(name="fast", cast=strtobool),
-  Flag(name="string_normalization", cast=strtobool),
+  # The "skip_*" flags are negated in Black() before being passed to
+  # black.FileMode.
+  Flag(name="skip_string_normalization", cast=strtobool),
   Flag(name="quiet", cast=strtobool),
+  Flag(name="skip_magic_trailing_comma", cast=strtobool),
 ]
 
 
@@ -98,13 +104,48 @@ if _initialize_black_env():
   import black
   import time
 
-def Black():
+def get_target_version(tv):
+  """Resolve ``tv`` to a ``black.TargetVersion`` member.
+
+  A value that already is a ``black.TargetVersion`` is returned unchanged.
+  Anything else is upper-cased and looked up by name in the enum; when the
+  name is unknown a warning is printed and None is returned.
+  """
+  if isinstance(tv, black.TargetVersion):
+    return tv
+  try:
+    return black.TargetVersion[tv.upper()]
+  except KeyError:
+    print(f"WARNING: Target version {tv!r} not recognized by Black, using default target")
+    return None
+
+def Black(**kwargs):
+  """
+  kwargs allows you to override ``target_versions`` argument of
+  ``black.FileMode``.
+
+  ``target_version`` needs to be cleaned because ``black.FileMode``
+  expects the ``target_versions`` argument to be a set of TargetVersion enums.
+
+  Allow kwargs["target_version"] to be a string to allow
+  to type it more quickly.
+
+  Using also target_version instead of target_versions to remain
+  consistent to Black's documentation of the structure of pyproject.toml.
+  """
   start = time.time()
   configs = get_configs()
+
+  black_kwargs = {}
+  if "target_version" in kwargs:
+    target_version = kwargs["target_version"]
+
+    if not isinstance(target_version, (list, set)):
+      target_version = [target_version]
+    target_version = set(filter(lambda x: x, map(lambda tv: get_target_version(tv), target_version)))
+    black_kwargs["target_versions"] = target_version
+
   mode = black.FileMode(
     line_length=configs["line_length"],
-    string_normalization=configs["string_normalization"],
+    string_normalization=not configs["skip_string_normalization"],
     is_pyi=vim.current.buffer.name.endswith('.pyi'),
+    magic_trailing_comma=not configs["skip_magic_trailing_comma"],
+    **black_kwargs,
   )
   quiet = configs["quiet"]
 
@@ -139,14 +180,15 @@ def Black():
       print(f'Reformatted in {time.time() - start:.4f}s.')
 
 def get_configs():
+  """Return the effective value of every entry in FLAGS, keyed by flag.var_name.
+
+  A value found in the pyproject.toml that black.find_pyproject_toml reports
+  for the current file name takes precedence; otherwise the vim expression
+  given by flag.vim_rc_name is evaluated and passed through flag.cast.
+  """
-  path_pyproject_toml = black.find_pyproject_toml(vim.eval("fnamemodify(getcwd(), ':t')"))
+  # "@%" is the vim register holding the current file name.
+  filename = vim.eval("@%")
+  path_pyproject_toml = black.find_pyproject_toml((filename,))
   if path_pyproject_toml:
     toml_config = black.parse_pyproject_toml(path_pyproject_toml)
   else:
     toml_config = {}
 
   return {
-    flag.var_name: flag.cast(toml_config.get(flag.name, vim.eval(flag.vim_rc_name)))
+    # Values from pyproject.toml are used as parsed; only the vim fallback is
+    # cast. The fallback expression is still evaluated when the toml value
+    # exists, because it is an argument to dict.get().
+    flag.var_name: toml_config.get(flag.name, flag.cast(vim.eval(flag.vim_rc_name)))
     for flag in FLAGS
   }
 
@@ -159,8 +201,17 @@ def BlackVersion():
 
 EndPython3
 
-function black#Black()
-  :py3 Black()
+" Call the Python Black() defined above. Each argument is expected to look
+" like key=value and is forwarded to Black() as a keyword argument.
+function black#Black(...)
+    let kwargs = {}
+    for arg in a:000
+        " Only the first two '='-separated pieces are used: key and value.
+        let arg_list = split(arg, '=')
+        let kwargs[arg_list[0]] = arg_list[1]
+    endfor
+python3 << EOF
+import vim
+# Copy the vim dictionary built above into a Python variable.
+kwargs = vim.eval("kwargs")
+EOF
+  :py3 Black(**kwargs)
 endfunction
 
 function black#BlackUpgrade()