git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Use nullcontext in case when lock is None. Shutdown pool after code formatting. ...
[etc/vim.git] / black.py
index 8318674a722f41fa70e4dfe95febe839dd93c87e..952fb0c659ff2a9a4969efc5506e37f504a31f43 100644 (file)
--- a/black.py
+++ b/black.py
@@ -1,6 +1,6 @@
 import asyncio
 import asyncio
-from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 from concurrent.futures import Executor, ProcessPoolExecutor
+from contextlib import contextmanager
 from datetime import datetime
 from enum import Enum
 from functools import lru_cache, partial, wraps
 from datetime import datetime
 from enum import Enum
 from functools import lru_cache, partial, wraps
@@ -524,6 +524,7 @@ def reformat_many(
         )
     finally:
         shutdown(loop)
         )
     finally:
         shutdown(loop)
+        executor.shutdown()
 
 
 async def schedule_formatting(
 
 
 async def schedule_formatting(
@@ -532,14 +533,14 @@ async def schedule_formatting(
     write_back: WriteBack,
     mode: FileMode,
     report: "Report",
     write_back: WriteBack,
     mode: FileMode,
     report: "Report",
-    loop: BaseEventLoop,
+    loop: asyncio.AbstractEventLoop,
     executor: Executor,
 ) -> None:
     """Run formatting of `sources` in parallel using the provided `executor`.
 
     (Use ProcessPoolExecutors for actual parallelism.)
 
     executor: Executor,
 ) -> None:
     """Run formatting of `sources` in parallel using the provided `executor`.
 
     (Use ProcessPoolExecutors for actual parallelism.)
 
-    `line_length`, `write_back`, `fast`, and `pyi` options are passed to
+    `write_back`, `fast`, and `mode` options are passed to
     :func:`format_file_in_place`.
     """
     cache: Cache = {}
     :func:`format_file_in_place`.
     """
     cache: Cache = {}
@@ -629,9 +630,8 @@ def format_file_in_place(
         src_name = f"{src}\t{then} +0000"
         dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
         src_name = f"{src}\t{then} +0000"
         dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
-        if lock:
-            lock.acquire()
-        try:
+
+        with lock or nullcontext():
             f = io.TextIOWrapper(
                 sys.stdout.buffer,
                 encoding=encoding,
             f = io.TextIOWrapper(
                 sys.stdout.buffer,
                 encoding=encoding,
@@ -640,9 +640,7 @@ def format_file_in_place(
             )
             f.write(diff_contents)
             f.detach()
             )
             f.write(diff_contents)
             f.detach()
-        finally:
-            if lock:
-                lock.release()
+
     return True
 
 
     return True
 
 
@@ -1282,10 +1280,13 @@ class Line:
         try:
             last_leaf = self.leaves[-1]
             ignored_ids.add(id(last_leaf))
         try:
             last_leaf = self.leaves[-1]
             ignored_ids.add(id(last_leaf))
-            if last_leaf.type == token.COMMA:
-                # When trailing commas are inserted by Black for consistency, comments
-                # after the previous last element are not moved (they don't have to,
-                # rendering will still be correct).  So we ignore trailing commas.
+            if last_leaf.type == token.COMMA or (
+                last_leaf.type == token.RPAR and not last_leaf.value
+            ):
+                # When trailing commas or optional parens are inserted by Black for
+                # consistency, comments after the previous last element are not moved
+                # (they don't have to, rendering will still be correct).  So we ignore
+                # trailing commas and invisible parens.
                 last_leaf = self.leaves[-2]
                 ignored_ids.add(id(last_leaf))
         except IndexError:
                 last_leaf = self.leaves[-2]
                 ignored_ids.add(id(last_leaf))
         except IndexError:
@@ -1382,7 +1383,23 @@ class Line:
             comment.prefix = ""
             return False
 
             comment.prefix = ""
             return False
 
-        self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
+        last_leaf = self.leaves[-1]
+        if (
+            last_leaf.type == token.RPAR
+            and not last_leaf.value
+            and last_leaf.parent
+            and len(list(last_leaf.parent.leaves())) <= 3
+            and not is_type_comment(comment)
+        ):
+            # Comments on an optional parens wrapping a single leaf should belong to
+            # the wrapped node except if it's a type comment. Pinning the comment like
+            # this avoids unstable formatting caused by comment migration.
+            if len(self.leaves) < 2:
+                comment.type = STANDALONE_COMMENT
+                comment.prefix = ""
+                return False
+            last_leaf = self.leaves[-2]
+        self.comments.setdefault(id(last_leaf), []).append(comment)
         return True
 
     def comments_after(self, leaf: Leaf) -> List[Leaf]:
         return True
 
     def comments_after(self, leaf: Leaf) -> List[Leaf]:
@@ -1625,6 +1642,19 @@ class LineGenerator(Visitor[Line]):
             node.children[2].value = ""
         yield from super().visit_default(node)
 
             node.children[2].value = ""
         yield from super().visit_default(node)
 
+    def visit_factor(self, node: Node) -> Iterator[Line]:
+        """Force parentheses between a unary op and a binary power:
+
+        -2 ** 8 -> -(2 ** 8)
+        """
+        child = node.children[1]
+        if child.type == syms.power and len(child.children) == 3:
+            lpar = Leaf(token.LPAR, "(")
+            rpar = Leaf(token.RPAR, ")")
+            index = child.remove() or 0
+            node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+        yield from self.visit_default(node)
+
     def visit_INDENT(self, node: Node) -> Iterator[Line]:
         """Increase indentation level, maybe yield a line."""
         # In blib2to3 INDENT never holds comments.
     def visit_INDENT(self, node: Node) -> Iterator[Line]:
         """Increase indentation level, maybe yield a line."""
         # In blib2to3 INDENT never holds comments.
@@ -3121,7 +3151,7 @@ def ensure_visible(leaf: Leaf) -> None:
     """Make sure parentheses are visible.
 
     They could be invisible as part of some statements (see
     """Make sure parentheses are visible.
 
     They could be invisible as part of some statements (see
-    :func:`normalize_invible_parens` and :func:`visit_import_from`).
+    :func:`normalize_invisible_parens` and :func:`visit_import_from`).
     """
     if leaf.type == token.LPAR:
         leaf.value = "("
     """
     if leaf.type == token.LPAR:
         leaf.value = "("
@@ -3562,6 +3592,13 @@ def dump_to_file(*output: str) -> str:
     return f.name
 
 
     return f.name
 
 
+@contextmanager
+def nullcontext() -> Iterator[None]:
+    """Return context manager that does nothing.
+    Similar to `nullcontext` from python 3.7"""
+    yield
+
+
 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     """Return a unified diff string between strings `a` and `b`."""
     import difflib
 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     """Return a unified diff string between strings `a` and `b`."""
     import difflib
@@ -3580,7 +3617,7 @@ def cancel(tasks: Iterable[asyncio.Task]) -> None:
         task.cancel()
 
 
         task.cancel()
 
 
-def shutdown(loop: BaseEventLoop) -> None:
+def shutdown(loop: asyncio.AbstractEventLoop) -> None:
     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
     try:
         if sys.version_info[:2] >= (3, 7):
     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
     try:
         if sys.version_info[:2] >= (3, 7):