From: Zsolt Dollenstein Date: Mon, 2 Jul 2018 16:48:48 +0000 (+0100) Subject: Improve get_future_imports implementation. X-Git-Url: https://git.madduck.net/etc/vim.git/commitdiff_plain/dd8bde6d2fbfe8a7a11093e761a0cb5837efa96a?ds=sidebyside Improve get_future_imports implementation. Closes #389. --- diff --git a/black.py b/black.py index f49e6df..36a180d 100644 --- a/black.py +++ b/black.py @@ -20,6 +20,7 @@ from typing import ( Callable, Collection, Dict, + Generator, Generic, Iterable, Iterator, @@ -2910,7 +2911,23 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf def get_future_imports(node: Node) -> Set[str]: """Return a set of __future__ imports in the file.""" - imports = set() + imports: Set[str] = set() + + def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]: + for child in children: + if isinstance(child, Leaf): + if child.type == token.NAME: + yield child.value + elif child.type == syms.import_as_name: + orig_name = child.children[0] + assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports" + assert orig_name.type == token.NAME, "Invalid syntax parsing imports" + yield orig_name.value + elif child.type == syms.import_as_names: + yield from get_imports_from_children(child.children) + else: + assert False, "Invalid syntax parsing imports" + for child in node.children: if child.type != syms.simple_stmt: break @@ -2929,15 +2946,7 @@ def get_future_imports(node: Node) -> Set[str]: module_name = first_child.children[1] if not isinstance(module_name, Leaf) or module_name.value != "__future__": break - for import_from_child in first_child.children[3:]: - if isinstance(import_from_child, Leaf): - if import_from_child.type == token.NAME: - imports.add(import_from_child.value) - else: - assert import_from_child.type == syms.import_as_names - for leaf in import_from_child.children: - if isinstance(leaf, Leaf) and leaf.type == token.NAME: - imports.add(leaf.value) + imports |= set(get_imports_from_children(first_child.children[3:])) else: break return imports diff --git a/tests/data/python2_unicode_literals.py b/tests/data/python2_unicode_literals.py index ae27919..2fe7039 100644 --- a/tests/data/python2_unicode_literals.py +++ b/tests/data/python2_unicode_literals.py @@ -1,5 +1,7 @@ #!/usr/bin/env python2 -from __future__ import unicode_literals +from __future__ import unicode_literals as _unicode_literals +from __future__ import absolute_import +from __future__ import print_function as lol, with_function u'hello' U"hello" @@ -9,7 +11,9 @@ Ur"hello" #!/usr/bin/env python2 -from __future__ import unicode_literals +from __future__ import unicode_literals as _unicode_literals +from __future__ import absolute_import +from __future__ import print_function as lol, with_function "hello" "hello" diff --git a/tests/test_black.py b/tests/test_black.py index 8a37197..cc53aa6 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -735,6 +735,14 @@ class BlackTestCase(unittest.TestCase): self.assertEqual(set(), black.get_future_imports(node)) node = black.lib2to3_parse("from some.module import black\n") self.assertEqual(set(), black.get_future_imports(node)) + node = black.lib2to3_parse( + "from __future__ import unicode_literals as _unicode_literals" + ) + self.assertEqual({"unicode_literals"}, black.get_future_imports(node)) + node = black.lib2to3_parse( + "from __future__ import unicode_literals as _lol, print" + ) + self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node)) def test_debug_visitor(self) -> None: source, _ = read_data("debug_visitor.py")