From 21c8dc85f5fa6ca70b028027a03588e12f532636 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Wed, 9 May 2018 15:26:41 -0700 Subject: [PATCH 1/1] Remove u prefix if unicode_literals is present (#199) --- black.py | 58 +++++++++++++++++++++++++++++-- tests/python2_unicode_literals.py | 16 +++++++++ tests/string_prefixes.py | 14 ++++++++ tests/test_black.py | 37 ++++++++++++++++++++ 4 files changed, 123 insertions(+), 2 deletions(-) create mode 100644 tests/python2_unicode_literals.py create mode 100644 tests/string_prefixes.py diff --git a/black.py b/black.py index b2cf543..913fe8d 100644 --- a/black.py +++ b/black.py @@ -409,9 +409,10 @@ def format_str(src_contents: str, line_length: int) -> FileContent: """ src_node = lib2to3_parse(src_contents) dst_contents = "" - lines = LineGenerator() - elt = EmptyLineTracker() + future_imports = get_future_imports(src_node) py36 = is_python36(src_node) + lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports) + elt = EmptyLineTracker() empty_line = Line() after = 0 for current_line in lines.visit(src_node): @@ -1171,6 +1172,7 @@ class LineGenerator(Visitor[Line]): in ways that will no longer stringify to valid Python code on the tree. """ current_line: Line = Factory(Line) + remove_u_prefix: bool = False def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]: """Generate a line. @@ -1238,6 +1240,7 @@ class LineGenerator(Visitor[Line]): else: normalize_prefix(node, inside_brackets=any_open_brackets) if node.type == token.STRING: + normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix) normalize_string_quotes(node) if node.type not in WHITESPACE: self.current_line.append(node) @@ -2161,6 +2164,22 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: leaf.prefix = "" +def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None: + """Make all string prefixes lowercase. + + If remove_u_prefix is given, also removes any u prefix from the string. + + Note: Mutates its argument. + """ + match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL) + assert match is not None, f"failed to match string {leaf.value!r}" + orig_prefix = match.group(1) + new_prefix = orig_prefix.lower() + if remove_u_prefix: + new_prefix = new_prefix.replace("u", "") + leaf.value = f"{new_prefix}{match.group(2)}" + + def normalize_string_quotes(leaf: Leaf) -> None: """Prefer double quotes but only if it doesn't cause more escaping. @@ -2423,6 +2442,41 @@ def is_python36(node: Node) -> bool: return False +def get_future_imports(node: Node) -> Set[str]: + """Return a set of __future__ imports in the file.""" + imports = set() + for child in node.children: + if child.type != syms.simple_stmt: + break + first_child = child.children[0] + if isinstance(first_child, Leaf): + # Continue looking if we see a docstring; otherwise stop. + if ( + len(child.children) == 2 + and first_child.type == token.STRING + and child.children[1].type == token.NEWLINE + ): + continue + else: + break + elif first_child.type == syms.import_from: + module_name = first_child.children[1] + if not isinstance(module_name, Leaf) or module_name.value != "__future__": + break + for import_from_child in first_child.children[3:]: + if isinstance(import_from_child, Leaf): + if import_from_child.type == token.NAME: + imports.add(import_from_child.value) + else: + assert import_from_child.type == syms.import_as_names + for leaf in import_from_child.children: + if isinstance(leaf, Leaf) and leaf.type == token.NAME: + imports.add(leaf.value) + else: + break + return imports + + PYTHON_EXTENSIONS = {".py"} BLACKLISTED_DIRECTORIES = { "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv" diff --git a/tests/python2_unicode_literals.py b/tests/python2_unicode_literals.py new file mode 100644 index 0000000..ae27919 --- /dev/null +++ b/tests/python2_unicode_literals.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python2 +from __future__ import unicode_literals + +u'hello' +U"hello" +Ur"hello" + +# output + + +#!/usr/bin/env python2 +from __future__ import unicode_literals + +"hello" +"hello" +r"hello" diff --git a/tests/string_prefixes.py b/tests/string_prefixes.py new file mode 100644 index 0000000..fbad5e0 --- /dev/null +++ b/tests/string_prefixes.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3.6 + +name = R"Łukasz" +F"hello {name}" +B"hello" + +# output + + +#!/usr/bin/env python3.6 + +name = r"Łukasz" +f"hello {name}" +b"hello" diff --git a/tests/test_black.py b/tests/test_black.py index 5b84c3c..cf20945 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -317,6 +317,14 @@ class BlackTestCase(unittest.TestCase): black.assert_equivalent(source, actual) black.assert_stable(source, actual, line_length=ll) + @patch("black.dump_to_file", dump_to_stderr) + def test_string_prefixes(self) -> None: + source, expected = read_data("string_prefixes") + actual = fs(source) + self.assertFormatEqual(expected, actual) + black.assert_equivalent(source, actual) + black.assert_stable(source, actual, line_length=ll) + @patch("black.dump_to_file", dump_to_stderr) def test_python2(self) -> None: source, expected = read_data("python2") @@ -325,6 +333,13 @@ class BlackTestCase(unittest.TestCase): # black.assert_equivalent(source, actual) black.assert_stable(source, actual, line_length=ll) + @patch("black.dump_to_file", dump_to_stderr) + def test_python2_unicode_literals(self) -> None: + source, expected = read_data("python2_unicode_literals") + actual = fs(source) + self.assertFormatEqual(expected, actual) + black.assert_stable(source, actual, line_length=ll) + @patch("black.dump_to_file", dump_to_stderr) def test_fmtonoff(self) -> None: source, expected = read_data("fmtonoff") @@ -444,6 +459,28 @@ class BlackTestCase(unittest.TestCase): node = black.lib2to3_parse(expected) self.assertFalse(black.is_python36(node)) + def test_get_future_imports(self) -> None: + node = black.lib2to3_parse("\n") + self.assertEqual(set(), black.get_future_imports(node)) + node = black.lib2to3_parse("from __future__ import black\n") + self.assertEqual({"black"}, black.get_future_imports(node)) + node = black.lib2to3_parse("from __future__ import multiple, imports\n") + self.assertEqual({"multiple", "imports"}, black.get_future_imports(node)) + node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n") + self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node)) + node = black.lib2to3_parse( + "from __future__ import multiple\nfrom __future__ import imports\n" + ) + self.assertEqual({"multiple", "imports"}, black.get_future_imports(node)) + node = black.lib2to3_parse("# comment\nfrom __future__ import black\n") + self.assertEqual({"black"}, black.get_future_imports(node)) + node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n') + self.assertEqual({"black"}, black.get_future_imports(node)) + node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n") + self.assertEqual(set(), black.get_future_imports(node)) + node = black.lib2to3_parse("from some.module import black\n") + self.assertEqual(set(), black.get_future_imports(node)) + def test_debug_visitor(self) -> None: source, _ = read_data("debug_visitor.py") expected, _ = read_data("debug_visitor.out") -- 2.39.5