git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

acks += miggaiowski
[etc/vim.git] / black.py
index 43b9bd17f6f2dcd3b52648908e54239620ddd4b3..7823ae0afe2e809bdb2f5ca208ec9cc0565b4ced 100644 (file)
--- a/black.py
+++ b/black.py
@@ -24,6 +24,7 @@ from typing import (
     List,
     Optional,
     Pattern,
+    Sequence,
     Set,
     Tuple,
     Type,
@@ -41,6 +42,7 @@ from blib2to3 import pygram, pytree
 from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
+
 __version__ = "18.4a6"
 DEFAULT_LINE_LENGTH = 88
 
@@ -1828,11 +1830,7 @@ def split_line(
         return
 
     line_str = str(line).strip("\n")
-    if (
-        len(line_str) <= line_length
-        and "\n" not in line_str  # multiline strings
-        and not line.contains_standalone_comments()
-    ):
+    if is_line_short_enough(line, line_length=line_length, line_str=line_str):
         yield line
         return
 
@@ -1841,10 +1839,22 @@ def split_line(
         split_funcs = [left_hand_split]
     elif line.is_import:
         split_funcs = [explode_split]
-    elif line.inside_brackets:
-        split_funcs = [delimiter_split, standalone_comment_split, right_hand_split]
     else:
-        split_funcs = [right_hand_split]
+
+        def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
+            for omit in generate_trailers_to_omit(line, line_length):  # growing omits
+                lines = list(right_hand_split(line, py36, omit=omit))
+                if is_line_short_enough(lines[0], line_length=line_length):  # head fits
+                    yield from lines
+                    return
+
+            # No omit set produced a short head; best-effort split with no omits.
+            yield from right_hand_split(line, py36)
+
+        if line.inside_brackets:
+            split_funcs = [delimiter_split, standalone_comment_split, rhs]
+        else:
+            split_funcs = [rhs]
     for split_func in split_funcs:
         # We are accumulating lines in `result` because we might want to abort
         # mission and return the original line in the end, or attempt a different
@@ -1917,6 +1927,8 @@ def right_hand_split(
     """Split line into many lines, starting with the last matching bracket pair.
 
     If the split was by optional parentheses, attempt splitting without them, too.
+    `omit` is a collection of closing bracket IDs that shouldn't be considered for
+    this split.
     """
     head = Line(depth=line.depth)
     body = Line(depth=line.depth + 1, inside_brackets=True)
@@ -2446,6 +2458,67 @@ def is_python36(node: Node) -> bool:
     return False
 
 
+def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
+    """Generate sets of closing bracket IDs that a right-hand split should omit.
+
+    Brackets can be omitted if the entire trailer up to and including
+    a preceding closing bracket fits in one line.
+
+    Yielded sets are cumulative (contain results of previous yields, too).  First
+    set is empty.
+    """
+
+    omit: Set[LeafID] = set()
+    yield omit  # first attempt: omit nothing
+
+    length = 4 * line.depth  # presumably the line's indentation width; TODO confirm
+    opening_bracket = None  # set while skipping the inside of a bracket pair
+    closing_bracket = None  # most recent candidate closing bracket
+    optional_brackets: Set[LeafID] = set()  # NOTE(review): written but never read
+    inner_brackets: Set[LeafID] = set()  # nested/empty brackets awaiting `omit`
+    for index, leaf in enumerate_reversed(line.leaves):  # walk right to left
+        length += len(leaf.prefix) + len(leaf.value)
+        if length > line_length:
+            break  # the trailer no longer fits on one line
+
+        comment: Optional[Leaf]
+        for comment in line.comments_after(leaf, index):
+            if "\n" in comment.prefix:
+                break  # Oops, standalone comment!
+
+            length += len(comment.value)
+        else:  # no standalone comment among this leaf's comments
+            comment = None
+        if comment is not None:
+            break  # There was a standalone comment, we can't continue.
+
+        optional_brackets.discard(id(leaf))
+        if opening_bracket:
+            if leaf is opening_bracket:
+                opening_bracket = None
+            elif leaf.type in CLOSING_BRACKETS:
+                inner_brackets.add(id(leaf))
+        elif leaf.type in CLOSING_BRACKETS:
+            if not leaf.value:  # empty-valued (invisible) closing bracket
+                optional_brackets.add(id(opening_bracket))  # NOTE(review): falsy here
+                continue
+
+            if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
+                # Empty brackets would fail a split so treat them as "inner"
+                # brackets (i.e. only add them to the `omit` set if another
+                # pair of brackets was good enough).
+                inner_brackets.add(id(leaf))
+                continue
+
+            opening_bracket = leaf.opening_bracket  # skip to the matching opener
+            if closing_bracket:
+                omit.add(id(closing_bracket))
+                omit.update(inner_brackets)
+                inner_brackets.clear()
+                yield omit  # same set object every time; it keeps growing
+            closing_bracket = leaf
+
+
 def get_future_imports(node: Node) -> Set[str]:
     """Return a set of __future__ imports in the file."""
     imports = set()
@@ -2723,6 +2796,28 @@ def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
     return regex.sub(replacement, regex.sub(replacement, original))
 
 
+def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
+    """Yield `(index, element)` pairs last-to-first, like a reversed `enumerate`."""
+    index = len(sequence) - 1  # index of the last element
+    for element in reversed(sequence):
+        yield (index, element)
+        index -= 1
+
+
+def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> bool:
+    """Return True if `line` fits `line_length` and has no newlines/standalone comments.
+
+    Uses the provided `line_str` rendering if non-empty, otherwise computes a new one.
+    """
+    if not line_str:
+        line_str = str(line).strip("\n")
+    return (
+        len(line_str) <= line_length
+        and "\n" not in line_str  # multiline strings
+        and not line.contains_standalone_comments()
+    )
+
+
 CACHE_DIR = Path(user_cache_dir("black", version=__version__))