git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Improve get_future_imports implementation.
[etc/vim.git] / black.py
index f1b1c3003e78f09b5a76a8d949cdc4e9965be777..36a180da702a3a003276dc3a8e1f25b96b4abd01 100644 (file)
--- a/black.py
+++ b/black.py
@@ -20,6 +20,7 @@ from typing import (
     Callable,
     Collection,
     Dict,
     Callable,
     Collection,
     Dict,
+    Generator,
     Generic,
     Iterable,
     Iterator,
     Generic,
     Iterable,
     Iterator,
@@ -46,7 +47,7 @@ from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
 
 from blib2to3.pgen2.parse import ParseError
 
 
-__version__ = "18.6b2"
+__version__ = "18.6b4"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_EXCLUDES = (
     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_EXCLUDES = (
     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
@@ -2608,7 +2609,7 @@ def convert_one_fmt_off_pair(node: Node) -> bool:
                 )
                 return True
 
                 )
                 return True
 
-            previous_consumed += comment.consumed
+            previous_consumed = comment.consumed
 
     return False
 
 
     return False
 
@@ -2910,7 +2911,23 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
 
 def get_future_imports(node: Node) -> Set[str]:
     """Return a set of __future__ imports in the file."""
 
 def get_future_imports(node: Node) -> Set[str]:
     """Return a set of __future__ imports in the file."""
-    imports = set()
+    imports: Set[str] = set()
+
+    def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
+        for child in children:
+            if isinstance(child, Leaf):
+                if child.type == token.NAME:
+                    yield child.value
+            elif child.type == syms.import_as_name:
+                orig_name = child.children[0]
+                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
+                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
+                yield orig_name.value
+            elif child.type == syms.import_as_names:
+                yield from get_imports_from_children(child.children)
+            else:
+                assert False, "Invalid syntax parsing imports"
+
     for child in node.children:
         if child.type != syms.simple_stmt:
             break
     for child in node.children:
         if child.type != syms.simple_stmt:
             break
@@ -2929,15 +2946,7 @@ def get_future_imports(node: Node) -> Set[str]:
             module_name = first_child.children[1]
             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
                 break
             module_name = first_child.children[1]
             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
                 break
-            for import_from_child in first_child.children[3:]:
-                if isinstance(import_from_child, Leaf):
-                    if import_from_child.type == token.NAME:
-                        imports.add(import_from_child.value)
-                else:
-                    assert import_from_child.type == syms.import_as_names
-                    for leaf in import_from_child.children:
-                        if isinstance(leaf, Leaf) and leaf.type == token.NAME:
-                            imports.add(leaf.value)
+            imports |= set(get_imports_from_children(first_child.children[3:]))
         else:
             break
     return imports
         else:
             break
     return imports