git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Check for broken symlinks before checking file data (#202)
[etc/vim.git] / black.py
index eb99a41b04217f3cdc0cc551be97014eee7fe62f..a1131693cb9b2b9ab17b1cd53e283c5169b7b3e1 100644 (file)
--- a/black.py
+++ b/black.py
@@ -409,9 +409,10 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
     """
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
     """
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
-    lines = LineGenerator()
-    elt = EmptyLineTracker()
+    future_imports = get_future_imports(src_node)
     py36 = is_python36(src_node)
     py36 = is_python36(src_node)
+    lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports)
+    elt = EmptyLineTracker()
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
@@ -599,6 +600,22 @@ TEST_DESCENDANTS = {
     syms.term,
     syms.power,
 }
     syms.term,
     syms.power,
 }
+ASSIGNMENTS = {
+    "=",
+    "+=",
+    "-=",
+    "*=",
+    "@=",
+    "/=",
+    "%=",
+    "&=",
+    "|=",
+    "^=",
+    "<<=",
+    ">>=",
+    "**=",
+    "//=",
+}
 COMPREHENSION_PRIORITY = 20
 COMMA_PRIORITY = 18
 TERNARY_PRIORITY = 16
 COMPREHENSION_PRIORITY = 20
 COMMA_PRIORITY = 18
 TERNARY_PRIORITY = 16
@@ -631,8 +648,8 @@ class BracketTracker:
     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
     delimiters: Dict[LeafID, Priority] = Factory(dict)
     previous: Optional[Leaf] = None
     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
     delimiters: Dict[LeafID, Priority] = Factory(dict)
     previous: Optional[Leaf] = None
-    _for_loop_variable: bool = False
-    _lambda_arguments: bool = False
+    _for_loop_variable: int = 0
+    _lambda_arguments: int = 0
 
     def mark(self, leaf: Leaf) -> None:
         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
 
     def mark(self, leaf: Leaf) -> None:
         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
@@ -694,7 +711,7 @@ class BracketTracker:
         """
         if leaf.type == token.NAME and leaf.value == "for":
             self.depth += 1
         """
         if leaf.type == token.NAME and leaf.value == "for":
             self.depth += 1
-            self._for_loop_variable = True
+            self._for_loop_variable += 1
             return True
 
         return False
             return True
 
         return False
@@ -703,7 +720,7 @@ class BracketTracker:
         """See `maybe_increment_for_loop_variable` above for explanation."""
         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
             self.depth -= 1
         """See `maybe_increment_for_loop_variable` above for explanation."""
         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
             self.depth -= 1
-            self._for_loop_variable = False
+            self._for_loop_variable -= 1
             return True
 
         return False
             return True
 
         return False
@@ -716,7 +733,7 @@ class BracketTracker:
         """
         if leaf.type == token.NAME and leaf.value == "lambda":
             self.depth += 1
         """
         if leaf.type == token.NAME and leaf.value == "lambda":
             self.depth += 1
-            self._lambda_arguments = True
+            self._lambda_arguments += 1
             return True
 
         return False
             return True
 
         return False
@@ -725,7 +742,7 @@ class BracketTracker:
         """See `maybe_increment_lambda_arguments` above for explanation."""
         if self._lambda_arguments and leaf.type == token.COLON:
             self.depth -= 1
         """See `maybe_increment_lambda_arguments` above for explanation."""
         if self._lambda_arguments and leaf.type == token.COLON:
             self.depth -= 1
-            self._lambda_arguments = False
+            self._lambda_arguments -= 1
             return True
 
         return False
             return True
 
         return False
@@ -994,8 +1011,9 @@ class Line:
             and subscript_start.type == syms.subscriptlist
         ):
             subscript_start = child_towards(subscript_start, leaf)
             and subscript_start.type == syms.subscriptlist
         ):
             subscript_start = child_towards(subscript_start, leaf)
-        return subscript_start is not None and any(
-            n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
+        return (
+            subscript_start is not None
+            and any(n.type in TEST_DESCENDANTS for n in subscript_start.pre_order())
         )
 
     def __str__(self) -> str:
         )
 
     def __str__(self) -> str:
@@ -1154,6 +1172,7 @@ class LineGenerator(Visitor[Line]):
     in ways that will no longer stringify to valid Python code on the tree.
     """
     current_line: Line = Factory(Line)
     in ways that will no longer stringify to valid Python code on the tree.
     """
     current_line: Line = Factory(Line)
+    remove_u_prefix: bool = False
 
     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
         """Generate a line.
 
     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
         """Generate a line.
@@ -1221,6 +1240,7 @@ class LineGenerator(Visitor[Line]):
             else:
                 normalize_prefix(node, inside_brackets=any_open_brackets)
                 if node.type == token.STRING:
             else:
                 normalize_prefix(node, inside_brackets=any_open_brackets)
                 if node.type == token.STRING:
+                    normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                     normalize_string_quotes(node)
                 if node.type not in WHITESPACE:
                     self.current_line.append(node)
                     normalize_string_quotes(node)
                 if node.type not in WHITESPACE:
                     self.current_line.append(node)
@@ -1252,7 +1272,7 @@ class LineGenerator(Visitor[Line]):
         """Visit a statement.
 
         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
         """Visit a statement.
 
         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
-        `def`, `with`, `class`, and `assert`.
+        `def`, `with`, `class`, `assert` and assignments.
 
         The relevant Python language `keywords` for a given statement will be
         NAME leaves within it. This methods puts those on a separate line.
 
         The relevant Python language `keywords` for a given statement will be
         NAME leaves within it. This methods puts those on a separate line.
@@ -1358,7 +1378,9 @@ class LineGenerator(Visitor[Line]):
         v = self.visit_stmt
         Ø: Set[str] = set()
         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
         v = self.visit_stmt
         Ø: Set[str] = set()
         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
-        self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
+        self.visit_if_stmt = partial(
+            v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
+        )
         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
         self.visit_try_stmt = partial(
         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
         self.visit_try_stmt = partial(
@@ -1368,6 +1390,8 @@ class LineGenerator(Visitor[Line]):
         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
+        self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
+        self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
@@ -2140,6 +2164,22 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
     leaf.prefix = ""
 
 
     leaf.prefix = ""
 
 
+def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
+    """Make all string prefixes lowercase.
+
+    If remove_u_prefix is given, also removes any u prefix from the string.
+
+    Note: Mutates its argument.
+    """
+    match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
+    assert match is not None, f"failed to match string {leaf.value!r}"
+    orig_prefix = match.group(1)
+    new_prefix = orig_prefix.lower()
+    if remove_u_prefix:
+        new_prefix = new_prefix.replace("u", "")
+    leaf.value = f"{new_prefix}{match.group(2)}"
+
+
 def normalize_string_quotes(leaf: Leaf) -> None:
     """Prefer double quotes but only if it doesn't cause more escaping.
 
 def normalize_string_quotes(leaf: Leaf) -> None:
     """Prefer double quotes but only if it doesn't cause more escaping.
 
@@ -2237,6 +2277,7 @@ def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
         node.type != syms.atom
         or is_empty_tuple(node)
         or is_one_tuple(node)
         node.type != syms.atom
         or is_empty_tuple(node)
         or is_one_tuple(node)
+        or is_yield(node)
         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
     ):
         return False
         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
     ):
         return False
@@ -2287,12 +2328,33 @@ def is_one_tuple(node: LN) -> bool:
     )
 
 
     )
 
 
+def is_yield(node: LN) -> bool:
+    """Return True if `node` holds a `yield` or `yield from` expression."""
+    if node.type == syms.yield_expr:
+        return True
+
+    if node.type == token.NAME and node.value == "yield":  # type: ignore
+        return True
+
+    if node.type != syms.atom:
+        return False
+
+    if len(node.children) != 3:
+        return False
+
+    lpar, expr, rpar = node.children
+    if lpar.type == token.LPAR and rpar.type == token.RPAR:
+        return is_yield(expr)
+
+    return False
+
+
 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
     """Return True if `leaf` is a star or double star in a vararg or kwarg.
 
     If `within` includes VARARGS_PARENTS, this applies to function signatures.
 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
     """Return True if `leaf` is a star or double star in a vararg or kwarg.
 
     If `within` includes VARARGS_PARENTS, this applies to function signatures.
-    If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
-    hand-side extended iterable unpacking (PEP 3132) and additional unpacking
+    If `within` includes UNPACKING_PARENTS, it applies to right hand-side
+    extended iterable unpacking (PEP 3132) and additional unpacking
     generalizations (PEP 448).
     """
     if leaf.type not in STARS or not leaf.parent:
     generalizations (PEP 448).
     """
     if leaf.type not in STARS or not leaf.parent:
@@ -2380,6 +2442,41 @@ def is_python36(node: Node) -> bool:
     return False
 
 
     return False
 
 
+def get_future_imports(node: Node) -> Set[str]:
+    """Return a set of __future__ imports in the file."""
+    imports = set()
+    for child in node.children:
+        if child.type != syms.simple_stmt:
+            break
+        first_child = child.children[0]
+        if isinstance(first_child, Leaf):
+            # Continue looking if we see a docstring; otherwise stop.
+            if (
+                len(child.children) == 2
+                and first_child.type == token.STRING
+                and child.children[1].type == token.NEWLINE
+            ):
+                continue
+            else:
+                break
+        elif first_child.type == syms.import_from:
+            module_name = first_child.children[1]
+            if not isinstance(module_name, Leaf) or module_name.value != "__future__":
+                break
+            for import_from_child in first_child.children[3:]:
+                if isinstance(import_from_child, Leaf):
+                    if import_from_child.type == token.NAME:
+                        imports.add(import_from_child.value)
+                else:
+                    assert import_from_child.type == syms.import_as_names
+                    for leaf in import_from_child.children:
+                        if isinstance(leaf, Leaf) and leaf.type == token.NAME:
+                            imports.add(leaf.value)
+        else:
+            break
+    return imports
+
+
 PYTHON_EXTENSIONS = {".py"}
 BLACKLISTED_DIRECTORIES = {
     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
 PYTHON_EXTENSIONS = {".py"}
 BLACKLISTED_DIRECTORIES = {
     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"
@@ -2397,7 +2494,7 @@ def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
 
             yield from gen_python_files_in_dir(child)
 
 
             yield from gen_python_files_in_dir(child)
 
-        elif child.suffix in PYTHON_EXTENSIONS:
+        elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
             yield child
 
 
             yield child