X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/4bcae4cf839aeba828bcdc4764369ae790f81d0d..cac18293d5a6bd6b34a953f9cb5413f9826e505f:/plugin/black.vim?ds=sidebyside

diff --git a/plugin/black.vim b/plugin/black.vim
index 9def61a..c5f0313 100644
--- a/plugin/black.vim
+++ b/plugin/black.vim
@@ -14,7 +14,12 @@
 "    - restore cursor/window position after formatting
 
 if v:version < 700 || !has('python3')
-    echo "This script requires vim7.0+ with Python 3.6 support."
+    function! __BLACK_MISSING()
+        echo "The black.vim plugin requires vim7.0+ with Python 3.6 support."
+    endfunction
+    command! Black call __BLACK_MISSING()
+    command! BlackUpgrade call __BLACK_MISSING()
+    command! BlackVersion call __BLACK_MISSING()
     finish
 endif
 
@@ -36,14 +41,45 @@ endif
 if !exists("g:black_linelength")
   let g:black_linelength = 88
 endif
-if !exists("g:black_skip_string_normalization")
-  let g:black_skip_string_normalization = 0
+if !exists("g:black_string_normalization")
+  " Default to normalizing; invert g:black_skip_string_normalization if it is set.
+  let g:black_string_normalization = 1
+  if exists("g:black_skip_string_normalization")
+    let g:black_string_normalization = !g:black_skip_string_normalization
+  endif
+endif
+if !exists("g:black_quiet")
+  let g:black_quiet = 0
 endif
 
-python3 << endpython3
+python3 << EndPython3
+import collections
 import os
 import sys
 import vim
+from distutils.util import strtobool
+
+
+class Flag(collections.namedtuple("FlagBase", "name, cast")):
+  @property
+  def var_name(self):
+    return self.name.replace("-", "_")
+
+  @property
+  def vim_rc_name(self):
+    name = self.var_name
+    if name == "line_length":
+      name = name.replace("_", "")
+    return "g:black_" + name
+
+
+FLAGS = [
+  Flag(name="line_length", cast=int),
+  Flag(name="fast", cast=strtobool),
+  Flag(name="string_normalization", cast=strtobool),
+  Flag(name="quiet", cast=strtobool),
+]
+
 
 def _get_python_binary(exec_prefix):
   try:
@@ -81,13 +117,23 @@ def _initialize_black_env(upgrade=False):
   if not virtualenv_path.is_dir():
     print('Please wait, one time setup for Black.')
     _executable = sys.executable
+    _base_executable = getattr(sys, "_base_executable", _executable)
     try:
-      sys.executable = str(_get_python_binary(Path(sys.exec_prefix)))
+      executable = str(_get_python_binary(Path(sys.exec_prefix)))
+      sys.executable = executable
+      sys._base_executable = executable
       print(f'Creating a virtualenv in {virtualenv_path}...')
       print('(this path can be customized in .vimrc by setting g:black_virtualenv)')
       venv.create(virtualenv_path, with_pip=True)
+    except Exception:
+      print('Encountered exception while creating virtualenv (see traceback below).')
+      print(f'Removing {virtualenv_path}...')
+      import shutil
+      shutil.rmtree(virtualenv_path)
+      raise
     finally:
       sys.executable = _executable
+      sys._base_executable = _base_executable
     first_install = True
   if first_install:
     print('Installing Black with pip...')
@@ -99,7 +145,7 @@ def _initialize_black_env(upgrade=False):
   if first_install:
     print('Pro-tip: to upgrade Black in the future, use the :BlackUpgrade command and restart Vim.\n')
   if virtualenv_site_packages not in sys.path:
-    sys.path.append(virtualenv_site_packages)
+    sys.path.insert(0, virtualenv_site_packages)
   return True
 
 if _initialize_black_env():
@@ -108,51 +154,56 @@ if _initialize_black_env():
 
 def Black():
   start = time.time()
-  fast = bool(int(vim.eval("g:black_fast")))
+  configs = get_configs()
   mode = black.FileMode(
-    line_length=int(vim.eval("g:black_linelength")),
-    string_normalization=not bool(int(vim.eval("g:black_skip_string_normalization"))),
+    line_length=configs["line_length"],
+    string_normalization=configs["string_normalization"],
     is_pyi=vim.current.buffer.name.endswith('.pyi'),
   )
-  (cursor_line, cursor_column) = vim.current.window.cursor
-  cb = vim.current.buffer[:]
-  cb_bc = cb[0:cursor_line]
-  # Format all code before the cursor.
-  # Detect unclosed blocks, close them with pass.
-  last_line = cb_bc[-1]
-  if last_line.rstrip().endswith(":"):
-      cb_bc[-1] = last_line + " pass"
-  # Determine old:new cursor location mapping
-  buffer_str_before = '\n'.join(cb_bc)+'\n'
-  try:
-    new_buffer_str_before = black.format_file_contents(buffer_str_before, fast=fast, mode=mode)
-    new_cb = new_buffer_str_before.split('\n')[:-1]
-    new_cursor_line = len(new_cb)
-    new_cursor = (new_cursor_line, cursor_column)
-  except black.NothingChanged:
-    new_cursor_line = cursor_line
-    new_cursor = (new_cursor_line, cursor_column)
-  except Exception as exc:
-    print(exc)
-  # Now we know where the cursor should be
-  # when we format the entire buffer. Do it:
-  buffer_str = '\n'.join(cb) + '\n'
+  quiet = configs["quiet"]
+
+  buffer_str = '\n'.join(vim.current.buffer) + '\n'
   try:
-    new_buffer_str = black.format_file_contents(buffer_str, fast=fast, mode=mode)
+    new_buffer_str = black.format_file_contents(
+      buffer_str,
+      fast=configs["fast"],
+      mode=mode,
+    )
   except black.NothingChanged:
-    print(f'Already well formatted, good job. (took {time.time() - start:.4f}s)')
+    if not quiet:
+      print(f'Already well formatted, good job. (took {time.time() - start:.4f}s)')
   except Exception as exc:
     print(exc)
   else:
-    # Replace the buffer
-    new_cb = new_buffer_str.split('\n')[:-1]
-    vim.current.buffer[:] = new_cb
-    # Restore the cursor to its rightful place
-    try:
-      vim.current.window.cursor = new_cursor
-    except vim.error:
-      vim.current.window.cursor = (len(vim.current.buffer), 0)
-    print(f'Reformatted in {time.time() - start:.4f}s.')
+    current_buffer = vim.current.window.buffer
+    cursors = []
+    for i, tabpage in enumerate(vim.tabpages):
+      if tabpage.valid:
+        for j, window in enumerate(tabpage.windows):
+          if window.valid and window.buffer == current_buffer:
+            cursors.append((i, j, window.cursor))
+    vim.current.buffer[:] = new_buffer_str.split('\n')[:-1]
+    for i, j, cursor in cursors:
+      window = vim.tabpages[i].windows[j]
+      try:
+        window.cursor = cursor
+      except vim.error:
+        window.cursor = (len(window.buffer), 0)
+    if not quiet:
+      print(f'Reformatted in {time.time() - start:.4f}s.')
+
+def get_configs():
+  path_pyproject_toml = black.find_pyproject_toml(vim.eval("fnamemodify(getcwd(), ':t')"))
+  if path_pyproject_toml:
+    toml_config = black.parse_pyproject_toml(path_pyproject_toml)
+  else:
+    toml_config = {}
+
+  return {
+    flag.var_name: flag.cast(toml_config.get(flag.name, vim.eval(flag.vim_rc_name)))
+    for flag in FLAGS
+  }
+
 
 def BlackUpgrade():
   _initialize_black_env(upgrade=True)
@@ -160,7 +211,7 @@ def BlackUpgrade():
 def BlackVersion():
   print(f'Black, version {black.__version__} on Python {sys.version}.')
 
-endpython3
+EndPython3
 
 command! Black :py3 Black()
 command! BlackUpgrade :py3 BlackUpgrade()