git.madduck.net Git - etc/vim.git/commitdiff

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send patches to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Improve get_future_imports implementation.
author: Zsolt Dollenstein <zsol.zsol@gmail.com>
Mon, 2 Jul 2018 16:48:48 +0000 (17:48 +0100)
committer: Zsolt Dollenstein <zsol.zsol@gmail.com>
Mon, 2 Jul 2018 16:49:47 +0000 (17:49 +0100)
Closes #389.

black.py
tests/data/python2_unicode_literals.py
tests/test_black.py

index f49e6df423e438acb9896c3d60e5c4d81758e608..36a180da702a3a003276dc3a8e1f25b96b4abd01 100644 (file)
--- a/black.py
+++ b/black.py
@@ -20,6 +20,7 @@ from typing import (
     Callable,
     Collection,
     Dict,
+    Generator,
     Generic,
     Iterable,
     Iterator,
@@ -2910,7 +2911,23 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
 
 def get_future_imports(node: Node) -> Set[str]:
     """Return a set of __future__ imports in the file."""
-    imports = set()
+    imports: Set[str] = set()
+
+    def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
+        for child in children:
+            if isinstance(child, Leaf):
+                if child.type == token.NAME:
+                    yield child.value
+            elif child.type == syms.import_as_name:
+                orig_name = child.children[0]
+                assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
+                assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
+                yield orig_name.value
+            elif child.type == syms.import_as_names:
+                yield from get_imports_from_children(child.children)
+            else:
+                assert False, "Invalid syntax parsing imports"
+
     for child in node.children:
         if child.type != syms.simple_stmt:
             break
@@ -2929,15 +2946,7 @@ def get_future_imports(node: Node) -> Set[str]:
             module_name = first_child.children[1]
             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
                 break
-            for import_from_child in first_child.children[3:]:
-                if isinstance(import_from_child, Leaf):
-                    if import_from_child.type == token.NAME:
-                        imports.add(import_from_child.value)
-                else:
-                    assert import_from_child.type == syms.import_as_names
-                    for leaf in import_from_child.children:
-                        if isinstance(leaf, Leaf) and leaf.type == token.NAME:
-                            imports.add(leaf.value)
+            imports |= set(get_imports_from_children(first_child.children[3:]))
         else:
             break
     return imports
index ae27919afad8d1d7a03d488957725b2fec29a9ff..2fe70392af6f6174828efb6acc1f94779b313270 100644 (file)
@@ -1,5 +1,7 @@
 #!/usr/bin/env python2
-from __future__ import unicode_literals
+from __future__ import unicode_literals as _unicode_literals
+from __future__ import absolute_import
+from __future__ import print_function as lol, with_function
 
 u'hello'
 U"hello"
@@ -9,7 +11,9 @@ Ur"hello"
 
 
 #!/usr/bin/env python2
-from __future__ import unicode_literals
+from __future__ import unicode_literals as _unicode_literals
+from __future__ import absolute_import
+from __future__ import print_function as lol, with_function
 
 "hello"
 "hello"
index 8a371973b28c1a4e1192ec41846778e8f793ac2e..cc53aa61221e346fc144b2106a2a3077aca9f193 100644 (file)
@@ -735,6 +735,14 @@ class BlackTestCase(unittest.TestCase):
         self.assertEqual(set(), black.get_future_imports(node))
         node = black.lib2to3_parse("from some.module import black\n")
         self.assertEqual(set(), black.get_future_imports(node))
+        node = black.lib2to3_parse(
+            "from __future__ import unicode_literals as _unicode_literals"
+        )
+        self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
+        node = black.lib2to3_parse(
+            "from __future__ import unicode_literals as _lol, print"
+        )
+        self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
 
     def test_debug_visitor(self) -> None:
         source, _ = read_data("debug_visitor.py")