X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/394edc388755eff01cc2ea155918bc4379ce933d..4eb822f20cafcc2a62ac88f4f26298d732423bf1:/plugin/black.vim?ds=inline

diff --git a/plugin/black.vim b/plugin/black.vim
index 8e05c2a..8106ea1 100644
--- a/plugin/black.vim
+++ b/plugin/black.vim
@@ -14,7 +14,12 @@
 "    - restore cursor/window position after formatting
 
 if v:version < 700 || !has('python3')
-    echo "This script requires vim7.0+ with Python 3.6 support."
+    function! __BLACK_MISSING()
+        echo "The black.vim plugin requires vim7.0+ with Python 3.6 support."
+    endfunction
+    command! BlackVersion call __BLACK_MISSING()
+    command! BlackUpgrade call __BLACK_MISSING()
+    command! Black call __BLACK_MISSING()
     finish
 endif
 
@@ -41,10 +46,34 @@ if !exists("g:black_skip_string_normalization")
 endif
 
 python3 << endpython3
+import collections
 import os
 import sys
 import vim
 
+
+class Flag(collections.namedtuple("FlagBase", "name, cast")):
+  """A black option: its `name` and the `cast` used to parse its g:black_* vim value."""
+
+  @property
+  def var_name(self):
+    return self.name.replace("-", "_")
+
+  @property
+  def vim_rc_name(self):
+    # Most options map to g:black_<name>; these two use different vim spellings.
+    renamed = {"line_length": "linelength", "string_normalization": "skip_string_normalization"}
+    key = self.var_name
+    return "g:black_" + renamed.get(key, key)
+
+
+FLAGS = [
+  Flag(name="line_length", cast=int),
+  Flag(name="fast", cast=lambda value: bool(int(value))),  # vim.eval() gives "0"/"1"; bool("0") is True
+  Flag(name="string_normalization", cast=lambda value: not int(value)),  # vim var is the skip_ inverse
+]
+
+
 def _get_python_binary(exec_prefix):
   try:
     default = vim.eval("g:pymode_python").strip()
@@ -108,28 +137,54 @@ if _initialize_black_env():
 
 def Black():
   start = time.time()
-  fast = bool(int(vim.eval("g:black_fast")))
+  configs = get_configs()  # line_length / fast / string_normalization, keyed as in FLAGS
   mode = black.FileMode(
-    line_length=int(vim.eval("g:black_linelength")),
-    string_normalization=not bool(int(vim.eval("g:black_skip_string_normalization"))),
+    line_length=configs["line_length"],
+    string_normalization=configs["string_normalization"],
     is_pyi=vim.current.buffer.name.endswith('.pyi'),
   )
+  # black formats one string: join the buffer lines and keep a trailing newline.
   buffer_str = '\n'.join(vim.current.buffer) + '\n'
   try:
-    new_buffer_str = black.format_file_contents(buffer_str, fast=fast, mode=mode)
+    new_buffer_str = black.format_file_contents(
+      buffer_str,
+      fast=configs["fast"],
+      mode=mode,
+    )
   except black.NothingChanged:
     print(f'Already well formatted, good job. (took {time.time() - start:.4f}s)')
   except Exception as exc:
     print(exc)
   else:
-    cursor = vim.current.window.cursor
+    current_buffer = vim.current.window.buffer
+    cursors = []  # (tabpage index, window index, cursor) for each window showing this buffer
+    for i, tabpage in enumerate(vim.tabpages):
+      if tabpage.valid:
+        for j, window in enumerate(tabpage.windows):
+          if window.valid and window.buffer == current_buffer:
+            cursors.append((i, j, window.cursor))
     vim.current.buffer[:] = new_buffer_str.split('\n')[:-1]
-    try:
-      vim.current.window.cursor = cursor
-    except vim.error:
-      vim.current.window.cursor = (len(vim.current.buffer), 0)
+    for i, j, cursor in cursors:
+      window = vim.tabpages[i].windows[j]  # assumes the rewrite did not add/close windows
+      try:
+        window.cursor = cursor
+      except vim.error:
+        window.cursor = (len(window.buffer), 0)  # saved spot rejected (presumably past the end): last line
     print(f'Reformatted in {time.time() - start:.4f}s.')
 
+def get_configs():
+  """Return black's options, preferring pyproject.toml over the g:black_* vim variables."""
+  search_start = vim.current.buffer.name or vim.eval("getcwd()")  # unnamed buffer -> cwd
+  path_pyproject_toml = black.find_pyproject_toml((search_start,))  # iterable of src paths, not a str
+  toml_config = black.parse_pyproject_toml(path_pyproject_toml) if path_pyproject_toml else {}
+  if "skip_string_normalization" in toml_config:  # toml spells this option inverted
+    toml_config["string_normalization"] = not toml_config["skip_string_normalization"]
+  return {
+    flag.var_name: toml_config.get(flag.name, flag.cast(vim.eval(flag.vim_rc_name)))
+    for flag in FLAGS
+  }
+
+
 def BlackUpgrade():
   _initialize_black_env(upgrade=True)