git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix invalid code on stars in long from-imports being wrapped in parentheses
[etc/vim.git] / black.py
index e8af3f05b944f5b9f50179ce6ae1186f3886f65a..afc37d99fe3ac547383c1fa3c59cfba763aefc51 100644 (file)
--- a/black.py
+++ b/black.py
@@ -30,6 +30,7 @@ from typing import (
     Type,
     TypeVar,
     Union,
     Type,
     TypeVar,
     Union,
+    cast,
 )
 
 from appdirs import user_cache_dir
 )
 
 from appdirs import user_cache_dir
@@ -43,11 +44,12 @@ from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
 
 from blib2to3.pgen2.parse import ParseError
 
 
-__version__ = "18.4a6"
+__version__ = "18.5b0"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_LINE_LENGTH = 88
+CACHE_DIR = Path(user_cache_dir("black", version=__version__))
+
 
 # types
 
 # types
-syms = pygram.python_symbols
 FileContent = str
 Encoding = str
 Depth = int
 FileContent = str
 Encoding = str
 Depth = int
@@ -64,6 +66,9 @@ Cache = Dict[Path, CacheInfo]
 out = partial(click.secho, bold=True, err=True)
 err = partial(click.secho, fg="red", err=True)
 
 out = partial(click.secho, bold=True, err=True)
 err = partial(click.secho, fg="red", err=True)
 
+pygram.initialize(CACHE_DIR)
+syms = pygram.python_symbols
+
 
 class NothingChanged(UserWarning):
     """Raised by :func:`format_file` when reformatted code is the same as source."""
 
 class NothingChanged(UserWarning):
     """Raised by :func:`format_file` when reformatted code is the same as source."""
@@ -1370,32 +1375,6 @@ class LineGenerator(Visitor[Line]):
             yield from self.line()
             yield from self.visit(child)
 
             yield from self.line()
             yield from self.visit(child)
 
-    def visit_import_from(self, node: Node) -> Iterator[Line]:
-        """Visit import_from and maybe put invisible parentheses.
-
-        This is separate from `visit_stmt` because import statements don't
-        support arbitrary atoms and thus handling of parentheses is custom.
-        """
-        check_lpar = False
-        for index, child in enumerate(node.children):
-            if check_lpar:
-                if child.type == token.LPAR:
-                    # make parentheses invisible
-                    child.value = ""  # type: ignore
-                    node.children[-1].value = ""  # type: ignore
-                else:
-                    # insert invisible parentheses
-                    node.insert_child(index, Leaf(token.LPAR, ""))
-                    node.append_child(Leaf(token.RPAR, ""))
-                break
-
-            check_lpar = (
-                child.type == token.NAME and child.value == "import"  # type: ignore
-            )
-
-        for child in node.children:
-            yield from self.visit(child)
-
     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
         """Remove a semicolon and put the other statement on a separate line."""
         yield from self.line()
     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
         """Remove a semicolon and put the other statement on a separate line."""
         yield from self.line()
@@ -1442,6 +1421,7 @@ class LineGenerator(Visitor[Line]):
         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
+        self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
@@ -1815,7 +1795,7 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
     return 0
 
 
     return 0
 
 
-def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
+def generate_comments(leaf: LN) -> Iterator[Leaf]:
     """Clean the prefix of the `leaf` and generate comments from it, if any.
 
     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
     """Clean the prefix of the `leaf` and generate comments from it, if any.
 
     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
@@ -1922,12 +1902,14 @@ def split_line(
 
         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
             for omit in generate_trailers_to_omit(line, line_length):
 
         def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
             for omit in generate_trailers_to_omit(line, line_length):
-                lines = list(right_hand_split(line, py36, omit=omit))
+                lines = list(right_hand_split(line, line_length, py36, omit=omit))
                 if is_line_short_enough(lines[0], line_length=line_length):
                     yield from lines
                     return
 
             # All splits failed, best effort split with no omits.
                 if is_line_short_enough(lines[0], line_length=line_length):
                     yield from lines
                     return
 
             # All splits failed, best effort split with no omits.
+            # This mostly happens to multiline strings that are by definition
+            # reported as not fitting a single line.
             yield from right_hand_split(line, py36)
 
         if line.inside_brackets:
             yield from right_hand_split(line, py36)
 
         if line.inside_brackets:
@@ -2001,7 +1983,7 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
 
 def right_hand_split(
 
 
 def right_hand_split(
-    line: Line, py36: bool = False, omit: Collection[LeafID] = ()
+    line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
 ) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair.
 
 ) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair.
 
@@ -2063,27 +2045,9 @@ def right_hand_split(
         and not line.is_import
     ):
         omit = {id(closing_bracket), *omit}
         and not line.is_import
     ):
         omit = {id(closing_bracket), *omit}
-        delimiter_count = body.bracket_tracker.delimiter_count_with_priority()
-        first = body.leaves[0]
-        last = body.leaves[-1]
-        if (
-            delimiter_count == 0
-            or delimiter_count == 1
-            and (
-                first.type in OPENING_BRACKETS
-                or last.type == token.RPAR
-                or last.type == token.RBRACE
-                or (
-                    # don't use indexing for omitting optional parentheses;
-                    # it looks weird
-                    last.type == token.RSQB
-                    and last.parent
-                    and last.parent.type != syms.trailer
-                )
-            )
-        ):
+        if can_omit_invisible_parens(body, line_length):
             try:
             try:
-                yield from right_hand_split(line, py36=py36, omit=omit)
+                yield from right_hand_split(line, line_length, py36=py36, omit=omit)
                 return
             except CannotSplit:
                 pass
                 return
             except CannotSplit:
                 pass
@@ -2348,8 +2312,13 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
     Standardizes on visible parentheses for single-element tuples, and keeps
     existing visible parentheses for other tuples and generator expressions.
     """
     Standardizes on visible parentheses for single-element tuples, and keeps
     existing visible parentheses for other tuples and generator expressions.
     """
+    try:
+        list(generate_comments(node))
+    except FormatOff:
+        return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
+
     check_lpar = False
     check_lpar = False
-    for child in list(node.children):
+    for index, child in enumerate(list(node.children)):
         if check_lpar:
             if child.type == syms.atom:
                 maybe_make_parens_invisible_in_atom(child)
         if check_lpar:
             if child.type == syms.atom:
                 maybe_make_parens_invisible_in_atom(child)
@@ -2357,8 +2326,21 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
                 # wrap child in visible parentheses
                 lpar = Leaf(token.LPAR, "(")
                 rpar = Leaf(token.RPAR, ")")
                 # wrap child in visible parentheses
                 lpar = Leaf(token.LPAR, "(")
                 rpar = Leaf(token.RPAR, ")")
-                index = child.remove() or 0
+                child.remove()
                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+            elif node.type == syms.import_from:
+                # "import from" nodes store parentheses directly as part of
+                # the statement
+                if child.type == token.LPAR:
+                    # make parentheses invisible
+                    child.value = ""  # type: ignore
+                    node.children[-1].value = ""  # type: ignore
+                elif child.type != token.STAR:
+                    # insert invisible parentheses
+                    node.insert_child(index, Leaf(token.LPAR, ""))
+                    node.append_child(Leaf(token.RPAR, ""))
+                break
+
             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
                 # wrap child in invisible parentheses
                 lpar = Leaf(token.LPAR, "")
             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
                 # wrap child in invisible parentheses
                 lpar = Leaf(token.LPAR, "")
@@ -2547,13 +2529,21 @@ def ensure_visible(leaf: Leaf) -> None:
 
 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
 
 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
-    return bool(
+    if not (
         opening_bracket.parent
         and opening_bracket.parent.type in {syms.atom, syms.import_from}
         and opening_bracket.value in "[{("
         opening_bracket.parent
         and opening_bracket.parent.type in {syms.atom, syms.import_from}
         and opening_bracket.value in "[{("
-        and line.bracket_tracker.delimiters
-        and line.bracket_tracker.max_delimiter_priority() == COMMA_PRIORITY
-    )
+    ):
+        return False
+
+    try:
+        last_leaf = line.leaves[-1]
+        exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set()
+        max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
+    except (IndexError, ValueError):
+        return False
+
+    return max_priority == COMMA_PRIORITY
 
 
 def is_python36(node: Node) -> bool:
 
 
 def is_python36(node: Node) -> bool:
@@ -2604,21 +2594,13 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
     closing_bracket = None
     optional_brackets: Set[LeafID] = set()
     inner_brackets: Set[LeafID] = set()
     closing_bracket = None
     optional_brackets: Set[LeafID] = set()
     inner_brackets: Set[LeafID] = set()
-    for index, leaf in enumerate_reversed(line.leaves):
-        length += len(leaf.prefix) + len(leaf.value)
+    for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
+        length += leaf_length
         if length > line_length:
             break
 
         if length > line_length:
             break
 
-        comment: Optional[Leaf]
-        for comment in line.comments_after(leaf, index):
-            if "\n" in comment.prefix:
-                break  # Oops, standalone comment!
-
-            length += len(comment.value)
-        else:
-            comment = None
-        if comment is not None:
-            break  # There was a standalone comment, we can't continue.
+        if leaf.type == STANDALONE_COMMENT:
+            break
 
         optional_brackets.discard(id(leaf))
         if opening_bracket:
 
         optional_brackets.discard(id(leaf))
         if opening_bracket:
@@ -2940,6 +2922,32 @@ def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
         index -= 1
 
 
         index -= 1
 
 
+def enumerate_with_length(
+    line: Line, reversed: bool = False
+) -> Iterator[Tuple[Index, Leaf, int]]:
+    """Return an enumeration of leaves with their length.
+
+    Stops prematurely on multiline strings and standalone comments.
+    """
+    op = cast(
+        Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]],
+        enumerate_reversed if reversed else enumerate,
+    )
+    for index, leaf in op(line.leaves):
+        length = len(leaf.prefix) + len(leaf.value)
+        if "\n" in leaf.value:
+            return  # Multiline strings, we can't continue.
+
+        comment: Optional[Leaf]
+        for comment in line.comments_after(leaf, index):
+            if "\n" in comment.prefix:
+                return  # Oops, standalone comment!
+
+            length += len(comment.value)
+
+        yield index, leaf, length
+
+
 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
     """Return True if `line` is no longer than `line_length`.
 
 def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
     """Return True if `line` is no longer than `line_length`.
 
@@ -2954,7 +2962,93 @@ def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") ->
     )
 
 
     )
 
 
-CACHE_DIR = Path(user_cache_dir("black", version=__version__))
+def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
+    """Does `line` have a shape safe to reformat without optional parens around it?
+
+    Returns True for only a subset of potentially nice looking formattings but
+    the point is to not return false positives that end up producing lines that
+    are too long.
+    """
+    bt = line.bracket_tracker
+    if not bt.delimiters:
+        # Without delimiters the optional parentheses are useless.
+        return True
+
+    max_priority = bt.max_delimiter_priority()
+    if bt.delimiter_count_with_priority(max_priority) > 1:
+        # With more than one delimiter of a kind the optional parentheses read better.
+        return False
+
+    if max_priority == DOT_PRIORITY:
+        # A single stranded method call doesn't require optional parentheses.
+        return True
+
+    assert len(line.leaves) >= 2, "Stranded delimiter"
+
+    first = line.leaves[0]
+    second = line.leaves[1]
+    penultimate = line.leaves[-2]
+    last = line.leaves[-1]
+
+    # With a single delimiter, omit if the expression starts or ends with
+    # a bracket.
+    if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS:
+        remainder = False
+        length = 4 * line.depth
+        for _index, leaf, leaf_length in enumerate_with_length(line):
+            if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first:
+                remainder = True
+            if remainder:
+                length += leaf_length
+                if length > line_length:
+                    break
+
+                if leaf.type in OPENING_BRACKETS:
+                    # There are brackets we can further split on.
+                    remainder = False
+
+        else:
+            # checked the entire string and line length wasn't exceeded
+            if len(line.leaves) == _index + 1:
+                return True
+
+        # Note: we are not returning False here because a line might have *both*
+        # a leading opening bracket and a trailing closing bracket.  If the
+        # opening bracket doesn't match our rule, maybe the closing will.
+
+    if (
+        last.type == token.RPAR
+        or last.type == token.RBRACE
+        or (
+            # don't use indexing for omitting optional parentheses;
+            # it looks weird
+            last.type == token.RSQB
+            and last.parent
+            and last.parent.type != syms.trailer
+        )
+    ):
+        if penultimate.type in OPENING_BRACKETS:
+            # Empty brackets don't help.
+            return False
+
+        if is_multiline_string(first):
+            # Additional wrapping of a multiline string in this situation is
+            # unnecessary.
+            return True
+
+        length = 4 * line.depth
+        seen_other_brackets = False
+        for _index, leaf, leaf_length in enumerate_with_length(line):
+            length += leaf_length
+            if leaf is last.opening_bracket:
+                if seen_other_brackets or length <= line_length:
+                    return True
+
+            elif leaf.type in OPENING_BRACKETS:
+                # There are brackets we can further split on.
+                seen_other_brackets = True
+
+    return False
 
 
 def get_cache_file(line_length: int) -> Path:
 
 
 def get_cache_file(line_length: int) -> Path: