Callable,
Collection,
Dict,
+ Generator,
Generic,
Iterable,
Iterator,
def get_future_imports(node: Node) -> Set[str]:
"""Return a set of __future__ imports in the file."""
- imports = set()
+ imports: Set[str] = set()
+
+ def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
+ for child in children:
+ if isinstance(child, Leaf):
+ if child.type == token.NAME:
+ yield child.value
+ elif child.type == syms.import_as_name:
+ orig_name = child.children[0]
+ assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
+ assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
+ yield orig_name.value
+ elif child.type == syms.import_as_names:
+ yield from get_imports_from_children(child.children)
+ else:
+ assert False, "Invalid syntax parsing imports"
+
for child in node.children:
if child.type != syms.simple_stmt:
break
module_name = first_child.children[1]
if not isinstance(module_name, Leaf) or module_name.value != "__future__":
break
- for import_from_child in first_child.children[3:]:
- if isinstance(import_from_child, Leaf):
- if import_from_child.type == token.NAME:
- imports.add(import_from_child.value)
- else:
- assert import_from_child.type == syms.import_as_names
- for leaf in import_from_child.children:
- if isinstance(leaf, Leaf) and leaf.type == token.NAME:
- imports.add(leaf.value)
+ imports |= set(get_imports_from_children(first_child.children[3:]))
else:
break
return imports
#!/usr/bin/env python2
-from __future__ import unicode_literals
+from __future__ import unicode_literals as _unicode_literals
+from __future__ import absolute_import
+from __future__ import print_function as lol, with_function
u'hello'
U"hello"
#!/usr/bin/env python2
-from __future__ import unicode_literals
+from __future__ import unicode_literals as _unicode_literals
+from __future__ import absolute_import
+from __future__ import print_function as lol, with_function
"hello"
"hello"
self.assertEqual(set(), black.get_future_imports(node))
node = black.lib2to3_parse("from some.module import black\n")
self.assertEqual(set(), black.get_future_imports(node))
+ node = black.lib2to3_parse(
+ "from __future__ import unicode_literals as _unicode_literals"
+ )
+ self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
+ node = black.lib2to3_parse(
+ "from __future__ import unicode_literals as _lol, print"
+ )
+ self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
def test_debug_visitor(self) -> None:
source, _ = read_data("debug_visitor.py")