git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Adding Jupyter Notebook magic command (#200)
[etc/vim.git] / black.py
index c10eb39f10666ddd6622c47ee16a0dcca4f26d17..913fe8dfefaf496e57373e94d96b272caebf486d 100644 (file)
--- a/black.py
+++ b/black.py
@@ -41,7 +41,7 @@ from blib2to3 import pygram, pytree
 from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
-__version__ = "18.4a5"
+__version__ = "18.4a6"
 DEFAULT_LINE_LENGTH = 88
 
 # types
@@ -409,9 +409,10 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
     """
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
-    lines = LineGenerator()
-    elt = EmptyLineTracker()
+    future_imports = get_future_imports(src_node)
     py36 = is_python36(src_node)
+    lines = LineGenerator(remove_u_prefix=py36 or "unicode_literals" in future_imports)
+    elt = EmptyLineTracker()
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
@@ -553,19 +554,20 @@ COMPARATORS = {
     token.GREATEREQUAL,
 }
 MATH_OPERATORS = {
+    token.VBAR,
+    token.CIRCUMFLEX,
+    token.AMPER,
+    token.LEFTSHIFT,
+    token.RIGHTSHIFT,
     token.PLUS,
     token.MINUS,
     token.STAR,
     token.SLASH,
-    token.VBAR,
-    token.AMPER,
+    token.DOUBLESLASH,
     token.PERCENT,
-    token.CIRCUMFLEX,
+    token.AT,
     token.TILDE,
-    token.LEFTSHIFT,
-    token.RIGHTSHIFT,
     token.DOUBLESTAR,
-    token.DOUBLESLASH,
 }
 STARS = {token.STAR, token.DOUBLESTAR}
 VARARGS_PARENTS = {
@@ -598,13 +600,44 @@ TEST_DESCENDANTS = {
     syms.term,
     syms.power,
 }
+ASSIGNMENTS = {
+    "=",
+    "+=",
+    "-=",
+    "*=",
+    "@=",
+    "/=",
+    "%=",
+    "&=",
+    "|=",
+    "^=",
+    "<<=",
+    ">>=",
+    "**=",
+    "//=",
+}
 COMPREHENSION_PRIORITY = 20
-COMMA_PRIORITY = 10
-TERNARY_PRIORITY = 7
-LOGIC_PRIORITY = 5
-STRING_PRIORITY = 4
-COMPARATOR_PRIORITY = 3
-MATH_PRIORITY = 1
+COMMA_PRIORITY = 18
+TERNARY_PRIORITY = 16
+LOGIC_PRIORITY = 14
+STRING_PRIORITY = 12
+COMPARATOR_PRIORITY = 10
+MATH_PRIORITIES = {
+    token.VBAR: 8,
+    token.CIRCUMFLEX: 7,
+    token.AMPER: 6,
+    token.LEFTSHIFT: 5,
+    token.RIGHTSHIFT: 5,
+    token.PLUS: 4,
+    token.MINUS: 4,
+    token.STAR: 3,
+    token.SLASH: 3,
+    token.DOUBLESLASH: 3,
+    token.PERCENT: 3,
+    token.AT: 3,
+    token.TILDE: 2,
+    token.DOUBLESTAR: 1,
+}
 
 
 @dataclass
@@ -615,8 +648,8 @@ class BracketTracker:
     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
     delimiters: Dict[LeafID, Priority] = Factory(dict)
     previous: Optional[Leaf] = None
-    _for_loop_variable: bool = False
-    _lambda_arguments: bool = False
+    _for_loop_variable: int = 0
+    _lambda_arguments: int = 0
 
     def mark(self, leaf: Leaf) -> None:
         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
@@ -678,7 +711,7 @@ class BracketTracker:
         """
         if leaf.type == token.NAME and leaf.value == "for":
             self.depth += 1
-            self._for_loop_variable = True
+            self._for_loop_variable += 1
             return True
 
         return False
@@ -687,7 +720,7 @@ class BracketTracker:
         """See `maybe_increment_for_loop_variable` above for explanation."""
         if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
             self.depth -= 1
-            self._for_loop_variable = False
+            self._for_loop_variable -= 1
             return True
 
         return False
@@ -700,7 +733,7 @@ class BracketTracker:
         """
         if leaf.type == token.NAME and leaf.value == "lambda":
             self.depth += 1
-            self._lambda_arguments = True
+            self._lambda_arguments += 1
             return True
 
         return False
@@ -709,7 +742,7 @@ class BracketTracker:
         """See `maybe_increment_lambda_arguments` above for explanation."""
         if self._lambda_arguments and leaf.type == token.COLON:
             self.depth -= 1
-            self._lambda_arguments = False
+            self._lambda_arguments -= 1
             return True
 
         return False
@@ -978,8 +1011,9 @@ class Line:
             and subscript_start.type == syms.subscriptlist
         ):
             subscript_start = child_towards(subscript_start, leaf)
-        return subscript_start is not None and any(
-            n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
+        return (
+            subscript_start is not None
+            and any(n.type in TEST_DESCENDANTS for n in subscript_start.pre_order())
         )
 
     def __str__(self) -> str:
@@ -1138,6 +1172,7 @@ class LineGenerator(Visitor[Line]):
     in ways that will no longer stringify to valid Python code on the tree.
     """
     current_line: Line = Factory(Line)
+    remove_u_prefix: bool = False
 
     def line(self, indent: int = 0, type: Type[Line] = Line) -> Iterator[Line]:
         """Generate a line.
@@ -1205,6 +1240,7 @@ class LineGenerator(Visitor[Line]):
             else:
                 normalize_prefix(node, inside_brackets=any_open_brackets)
                 if node.type == token.STRING:
+                    normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                     normalize_string_quotes(node)
                 if node.type not in WHITESPACE:
                     self.current_line.append(node)
@@ -1236,14 +1272,13 @@ class LineGenerator(Visitor[Line]):
         """Visit a statement.
 
         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
-        `def`, `with`, `class`, and `assert`.
+        `def`, `with`, `class`, `assert`, `return`, and assignments.
 
         The relevant Python language `keywords` for a given statement will be
         NAME leaves within it. This methods puts those on a separate line.
 
-        `parens` holds pairs of nodes where invisible parentheses should be put.
-        Keys hold nodes after which opening parentheses should be put, values
-        hold nodes before which closing parentheses should be put.
+        `parens` holds a set of string leaf values immediately after which
+        invisible parens should be put.
         """
         normalize_invisible_parens(node, parens_after=parens)
         for child in node.children:
@@ -1343,7 +1378,9 @@ class LineGenerator(Visitor[Line]):
         v = self.visit_stmt
         Ø: Set[str] = set()
         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
-        self.visit_if_stmt = partial(v, keywords={"if", "else", "elif"}, parens={"if"})
+        self.visit_if_stmt = partial(
+            v, keywords={"if", "else", "elif"}, parens={"if", "elif"}
+        )
         self.visit_while_stmt = partial(v, keywords={"while", "else"}, parens={"while"})
         self.visit_for_stmt = partial(v, keywords={"for", "else"}, parens={"for", "in"})
         self.visit_try_stmt = partial(
@@ -1353,6 +1390,8 @@ class LineGenerator(Visitor[Line]):
         self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø)
         self.visit_funcdef = partial(v, keywords={"def"}, parens=Ø)
         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
+        self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
+        self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
@@ -1649,7 +1688,7 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
         and leaf.parent
         and leaf.parent.type not in {syms.factor, syms.star_expr}
     ):
-        return MATH_PRIORITY
+        return MATH_PRIORITIES[leaf.type]
 
     if leaf.type in COMPARATORS:
         return COMPARATOR_PRIORITY
@@ -1830,7 +1869,8 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
-    Prefer RHS otherwise.
+    Prefer RHS otherwise.  This is why this function is not symmetrical with
+    :func:`right_hand_split` which also handles optional parentheses.
     """
     head = Line(depth=line.depth)
     body = Line(depth=line.depth + 1, inside_brackets=True)
@@ -1870,7 +1910,10 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
 def right_hand_split(
     line: Line, py36: bool = False, omit: Collection[LeafID] = ()
 ) -> Iterator[Line]:
-    """Split line into many lines, starting with the last matching bracket pair."""
+    """Split line into many lines, starting with the last matching bracket pair.
+
+    If the split was by optional parentheses, attempt splitting without them, too.
+    """
     head = Line(depth=line.depth)
     body = Line(depth=line.depth + 1, inside_brackets=True)
     tail = Line(depth=line.depth)
@@ -1909,20 +1952,25 @@ def right_hand_split(
     bracket_split_succeeded_or_raise(head, body, tail)
     assert opening_bracket and closing_bracket
     if (
+        # the opening bracket is an optional paren
         opening_bracket.type == token.LPAR
         and not opening_bracket.value
+        # the closing bracket is an optional paren
         and closing_bracket.type == token.RPAR
         and not closing_bracket.value
+        # there are no delimiters or standalone comments in the body
+        and not body.bracket_tracker.delimiters
+        and not line.contains_standalone_comments(0)
+        # and it's not an import (optional parens are the only thing we can split
+        # on in this case; attempting a split without them is a waste of time)
+        and not line.is_import
     ):
-        # These parens were optional. If there aren't any delimiters or standalone
-        # comments in the body, they were unnecessary and another split without
-        # them should be attempted.
-        if not (
-            body.bracket_tracker.delimiters or line.contains_standalone_comments(0)
-        ):
-            omit = {id(closing_bracket), *omit}
+        omit = {id(closing_bracket), *omit}
+        try:
             yield from right_hand_split(line, py36=py36, omit=omit)
             return
+        except CannotSplit:
+            pass
 
     ensure_visible(opening_bracket)
     ensure_visible(closing_bracket)
@@ -2116,6 +2164,22 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
     leaf.prefix = ""
 
 
+def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
+    """Make all string prefixes lowercase.
+
+    If remove_u_prefix is given, also removes any u prefix from the string.
+
+    Note: Mutates its argument.
+    """
+    match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
+    assert match is not None, f"failed to match string {leaf.value!r}"
+    orig_prefix = match.group(1)
+    new_prefix = orig_prefix.lower()
+    if remove_u_prefix:
+        new_prefix = new_prefix.replace("u", "")
+    leaf.value = f"{new_prefix}{match.group(2)}"
+
+
 def normalize_string_quotes(leaf: Leaf) -> None:
     """Prefer double quotes but only if it doesn't cause more escaping.
 
@@ -2180,6 +2244,9 @@ def normalize_string_quotes(leaf: Leaf) -> None:
 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
     """Make existing optional parentheses invisible or create new ones.
 
+    `parens_after` is a set of string leaf values immediately after which parens
+    should be put.
+
     Standardizes on visible parentheses for single-element tuples, and keeps
     existing visible parentheses for other tuples and generator expressions.
     """
@@ -2210,6 +2277,7 @@ def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
         node.type != syms.atom
         or is_empty_tuple(node)
         or is_one_tuple(node)
+        or is_yield(node)
         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
     ):
         return False
@@ -2260,12 +2328,33 @@ def is_one_tuple(node: LN) -> bool:
     )
 
 
+def is_yield(node: LN) -> bool:
+    """Return True if `node` holds a `yield` or `yield from` expression."""
+    if node.type == syms.yield_expr:
+        return True
+
+    if node.type == token.NAME and node.value == "yield":  # type: ignore
+        return True
+
+    if node.type != syms.atom:
+        return False
+
+    if len(node.children) != 3:
+        return False
+
+    lpar, expr, rpar = node.children
+    if lpar.type == token.LPAR and rpar.type == token.RPAR:
+        return is_yield(expr)
+
+    return False
+
+
 def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
     """Return True if `leaf` is a star or double star in a vararg or kwarg.
 
     If `within` includes VARARGS_PARENTS, this applies to function signatures.
-    If `within` includes COLLECTION_LIBERALS_PARENTS, it applies to right
-    hand-side extended iterable unpacking (PEP 3132) and additional unpacking
+    If `within` includes UNPACKING_PARENTS, it applies to right hand-side
+    extended iterable unpacking (PEP 3132) and additional unpacking
     generalizations (PEP 448).
     """
     if leaf.type not in STARS or not leaf.parent:
@@ -2328,7 +2417,7 @@ def is_python36(node: Node) -> bool:
 
     Currently looking for:
     - f-strings; and
-    - trailing commas after * or ** in function signatures.
+    - trailing commas after * or ** in function signatures and calls.
     """
     for n in node.pre_order():
         if n.type == token.STRING:
@@ -2337,7 +2426,7 @@ def is_python36(node: Node) -> bool:
                 return True
 
         elif (
-            n.type == syms.typedargslist
+            n.type in {syms.typedargslist, syms.arglist}
             and n.children
             and n.children[-1].type == token.COMMA
         ):
@@ -2345,9 +2434,49 @@ def is_python36(node: Node) -> bool:
                 if ch.type in STARS:
                     return True
 
+                if ch.type == syms.argument:
+                    for argch in ch.children:
+                        if argch.type in STARS:
+                            return True
+
     return False
 
 
+def get_future_imports(node: Node) -> Set[str]:
+    """Return a set of __future__ imports in the file."""
+    imports = set()
+    for child in node.children:
+        if child.type != syms.simple_stmt:
+            break
+        first_child = child.children[0]
+        if isinstance(first_child, Leaf):
+            # Continue looking if we see a docstring; otherwise stop.
+            if (
+                len(child.children) == 2
+                and first_child.type == token.STRING
+                and child.children[1].type == token.NEWLINE
+            ):
+                continue
+            else:
+                break
+        elif first_child.type == syms.import_from:
+            module_name = first_child.children[1]
+            if not isinstance(module_name, Leaf) or module_name.value != "__future__":
+                break
+            for import_from_child in first_child.children[3:]:
+                if isinstance(import_from_child, Leaf):
+                    if import_from_child.type == token.NAME:
+                        imports.add(import_from_child.value)
+                else:
+                    assert import_from_child.type == syms.import_as_names
+                    for leaf in import_from_child.children:
+                        if isinstance(leaf, Leaf) and leaf.type == token.NAME:
+                            imports.add(leaf.value)
+        else:
+            break
+    return imports
+
+
 PYTHON_EXTENSIONS = {".py"}
 BLACKLISTED_DIRECTORIES = {
     "build", "buck-out", "dist", "_build", ".git", ".hg", ".mypy_cache", ".tox", ".venv"