git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Ignore empty bracket pairs while splitting
[etc/vim.git] / black.py
index d48fddbe6526d6dc2753a71c482e53afe8c2f9ee..b343da7b6d7363e52bc4a29e8715b247886a6a02 100644 (file)
--- a/black.py
+++ b/black.py
@@ -7,6 +7,7 @@ import keyword
 import os
 from pathlib import Path
 import tokenize
+import sys
 from typing import (
     Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union
 )
@@ -20,7 +21,7 @@ from blib2to3 import pygram, pytree
 from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
-__version__ = "18.3a1"
+__version__ = "18.3a2"
 DEFAULT_LINE_LENGTH = 88
 # types
 syms = pygram.python_symbols
@@ -55,6 +56,15 @@ class CannotSplit(Exception):
     help='How many character per line to allow.',
     show_default=True,
 )
+@click.option(
+    '--check',
+    is_flag=True,
+    help=(
+        "Don't write back the files, just return the status.  Return code 0 "
+        "means nothing changed.  Return code 1 means some files were "
+        "reformatted.  Return code 123 means there was an internal error."
+    ),
+)
 @click.option(
     '--fast/--safe',
     is_flag=True,
@@ -67,7 +77,9 @@ class CannotSplit(Exception):
     type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
 )
 @click.pass_context
-def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> None:
+def main(
+    ctx: click.Context, line_length: int, check: bool, fast: bool, src: List[str]
+) -> None:
     """The uncompromising code formatter."""
     sources: List[Path] = []
     for s in src:
@@ -85,7 +97,9 @@ def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> No
         p = sources[0]
         report = Report()
         try:
-            changed = format_file_in_place(p, line_length=line_length, fast=fast)
+            changed = format_file_in_place(
+                p, line_length=line_length, fast=fast, write_back=not check
+            )
             report.done(p, changed)
         except Exception as exc:
             report.failed(p, str(exc))
@@ -96,7 +110,9 @@ def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> No
         return_code = 1
         try:
             return_code = loop.run_until_complete(
-                schedule_formatting(sources, line_length, fast, loop, executor)
+                schedule_formatting(
+                    sources, line_length, not check, fast, loop, executor
+                )
             )
         finally:
             loop.close()
@@ -106,13 +122,14 @@ def main(ctx: click.Context, line_length: int, fast: bool, src: List[str]) -> No
 async def schedule_formatting(
     sources: List[Path],
     line_length: int,
+    write_back: bool,
     fast: bool,
     loop: BaseEventLoop,
     executor: Executor,
 ) -> int:
     tasks = {
         src: loop.run_in_executor(
-            executor, format_file_in_place, src, line_length, fast
+            executor, format_file_in_place, src, line_length, fast, write_back
         )
         for src in sources
     }
@@ -135,15 +152,18 @@ async def schedule_formatting(
     return report.return_code
 
 
-def format_file_in_place(src: Path, line_length: int, fast: bool) -> bool:
+def format_file_in_place(
+    src: Path, line_length: int, fast: bool, write_back: bool = False
+) -> bool:
     """Format the file and rewrite if changed. Return True if changed."""
     try:
         contents, encoding = format_file(src, line_length=line_length, fast=fast)
     except NothingChanged:
         return False
 
-    with open(src, "w", encoding=encoding) as f:
-        f.write(contents)
+    if write_back:
+        with open(src, "w", encoding=encoding) as f:
+            f.write(contents)
     return True
 
 
@@ -173,6 +193,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
     comments: List[Line] = []
     lines = LineGenerator()
     elt = EmptyLineTracker()
+    py36 = is_python36(src_node)
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
@@ -185,7 +206,7 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
             for comment in comments:
                 dst_contents += str(comment)
             comments = []
-            for line in split_line(current_line, line_length=line_length):
+            for line in split_line(current_line, line_length=line_length, py36=py36):
                 dst_contents += str(line)
         else:
             comments.append(current_line)
@@ -326,8 +347,8 @@ class BracketTracker:
         if leaf.type in CLOSING_BRACKETS:
             self.depth -= 1
             opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
-            leaf.opening_bracket = opening_bracket  # type: ignore
-        leaf.bracket_depth = self.depth  # type: ignore
+            leaf.opening_bracket = opening_bracket
+        leaf.bracket_depth = self.depth
         if self.depth == 0:
             delim = is_delimiter(leaf)
             if delim:
@@ -373,6 +394,8 @@ class Line:
     comments: Dict[LeafID, Leaf] = attrib(default=Factory(dict))
     bracket_tracker: BracketTracker = attrib(default=Factory(BracketTracker))
     inside_brackets: bool = attrib(default=False)
+    has_for: bool = attrib(default=False)
+    _for_loop_variable: bool = attrib(default=False, init=False)
 
     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
         has_value = leaf.value.strip()
@@ -384,8 +407,10 @@ class Line:
             # imports, for which we only preserve newlines.
             leaf.prefix += whitespace(leaf)
         if self.inside_brackets or not preformatted:
+            self.maybe_decrement_after_for_loop_variable(leaf)
             self.bracket_tracker.mark(leaf)
             self.maybe_remove_trailing_comma(leaf)
+            self.maybe_increment_for_loop_variable(leaf)
             if self.maybe_adapt_standalone_comment(leaf):
                 return
 
@@ -466,9 +491,9 @@ class Line:
         # For parens let's check if it's safe to remove the comma.  If the
         # trailing one is the only one, we might mistakenly change a tuple
         # into a different type by removing the comma.
-        depth = closing.bracket_depth + 1  # type: ignore
+        depth = closing.bracket_depth + 1
         commas = 0
-        opening = closing.opening_bracket  # type: ignore
+        opening = closing.opening_bracket
         for _opening_index, leaf in enumerate(self.leaves):
             if leaf is opening:
                 break
@@ -480,7 +505,7 @@ class Line:
             if leaf is closing:
                 break
 
-            bracket_depth = leaf.bracket_depth  # type: ignore
+            bracket_depth = leaf.bracket_depth
             if bracket_depth == depth and leaf.type == token.COMMA:
                 commas += 1
         if commas > 1:
@@ -489,6 +514,29 @@ class Line:
 
         return False
 
+    def maybe_increment_for_loop_variable(self, leaf: Leaf) -> bool:
+        """In a for loop, or comprehension, the variables are often unpacks.
+
+        To avoid splitting on the comma in this situation, we will increase
+        the depth of tokens between `for` and `in`.
+        """
+        if leaf.type == token.NAME and leaf.value == 'for':
+            self.has_for = True
+            self.bracket_tracker.depth += 1
+            self._for_loop_variable = True
+            return True
+
+        return False
+
+    def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
+        # See `maybe_increment_for_loop_variable` above for explanation.
+        if self._for_loop_variable and leaf.type == token.NAME and leaf.value == 'in':
+            self.bracket_tracker.depth -= 1
+            self._for_loop_variable = False
+            return True
+
+        return False
+
     def maybe_adapt_standalone_comment(self, comment: Leaf) -> bool:
         """Hack a standalone comment to act as a trailing comment for line splitting.
 
@@ -756,9 +804,10 @@ BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.R
 OPENING_BRACKETS = set(BRACKET.keys())
 CLOSING_BRACKETS = set(BRACKET.values())
 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
+ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, token.COLON, STANDALONE_COMMENT}
 
 
-def whitespace(leaf: Leaf) -> str:
+def whitespace(leaf: Leaf) -> str:  # noqa C901
     """Return whitespace prefix if needed for the given `leaf`."""
     NO = ''
     SPACE = ' '
@@ -766,24 +815,12 @@ def whitespace(leaf: Leaf) -> str:
     t = leaf.type
     p = leaf.parent
     v = leaf.value
-    if t == token.COLON:
-        return NO
-
-    if t == token.COMMA:
-        return NO
-
-    if t == token.RPAR:
+    if t in ALWAYS_NO_SPACE:
         return NO
 
     if t == token.COMMENT:
         return DOUBLESPACE
 
-    if t == STANDALONE_COMMENT:
-        return NO
-
-    if t in CLOSING_BRACKETS:
-        return NO
-
     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
     prev = leaf.prev_sibling
     if not prev:
@@ -815,7 +852,7 @@ def whitespace(leaf: Leaf) -> str:
             if prevp.parent and prevp.parent.type == syms.subscript:
                 return NO
 
-        elif prevp.parent and prevp.parent.type == syms.factor:
+        elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
             return NO
 
     elif prev.type in OPENING_BRACKETS:
@@ -912,7 +949,14 @@ def whitespace(leaf: Leaf) -> str:
 
     elif p.type == syms.subscript:
         # indexing
-        if not prev or prev.type == token.COLON:
+        if not prev:
+            assert p.parent is not None, "subscripts are always parented"
+            if p.parent.type == syms.subscriptlist:
+                return SPACE
+
+            return NO
+
+        elif prev.type == token.COLON:
             return NO
 
     elif p.type == syms.atom:
@@ -937,7 +981,7 @@ def whitespace(leaf: Leaf) -> str:
         if prev.type == token.DOUBLESTAR:
             return NO
 
-    elif p.type == syms.factor or p.type == syms.star_expr:
+    elif p.type in {syms.factor, syms.star_expr}:
         # unary ops
         if not prev:
             prevp = preceding_leaf(p)
@@ -1062,13 +1106,18 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
         yield Leaf(STANDALONE_COMMENT, line)
 
 
-def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Line]:
+def split_line(
+    line: Line, line_length: int, inner: bool = False, py36: bool = False
+) -> Iterator[Line]:
     """Splits a `line` into potentially many lines.
 
     They should fit in the allotted `line_length` but might not be able to.
     `inner` signifies that there were a pair of brackets somewhere around the
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
+
+    If `py36` is True, splitting may generate syntax that is only compatible
+    with Python 3.6 and later.
     """
     line_str = str(line).strip('\n')
     if len(line_str) <= line_length and '\n' not in line_str:
@@ -1091,11 +1140,13 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li
         # split altogether.
         result: List[Line] = []
         try:
-            for l in split_func(line):
+            for l in split_func(line, py36=py36):
                 if str(l).strip('\n') == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
-                result.extend(split_line(l, line_length=line_length, inner=True))
+                result.extend(
+                    split_line(l, line_length=line_length, inner=True, py36=py36)
+                )
         except CannotSplit as cs:
             continue
 
@@ -1107,7 +1158,7 @@ def split_line(line: Line, line_length: int, inner: bool = False) -> Iterator[Li
         yield line
 
 
-def left_hand_split(line: Line) -> Iterator[Line]:
+def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
@@ -1125,9 +1176,9 @@ def left_hand_split(line: Line) -> Iterator[Line]:
         if (
             current_leaves is body_leaves and
             leaf.type in CLOSING_BRACKETS and
-            leaf.opening_bracket is matching_bracket  # type: ignore
+            leaf.opening_bracket is matching_bracket
         ):
-            current_leaves = tail_leaves
+            current_leaves = tail_leaves if body_leaves else head_leaves
         current_leaves.append(leaf)
         if current_leaves is head_leaves:
             if leaf.type in OPENING_BRACKETS:
@@ -1145,24 +1196,13 @@ def left_hand_split(line: Line) -> Iterator[Line]:
             comment_after = line.comments.get(id(leaf))
             if comment_after:
                 result.append(comment_after, preformatted=True)
-    # Check if the split succeeded.
-    tail_len = len(str(tail))
-    if not body:
-        if tail_len == 0:
-            raise CannotSplit("Splitting brackets produced the same line")
-
-        elif tail_len < 3:
-            raise CannotSplit(
-                f"Splitting brackets on an empty body to save "
-                f"{tail_len} characters is not worth it"
-            )
-
+    split_succeeded_or_raise(head, body, tail)
     for result in (head, body, tail):
         if result:
             yield result
 
 
-def right_hand_split(line: Line) -> Iterator[Line]:
+def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair."""
     head = Line(depth=line.depth)
     body = Line(depth=line.depth + 1, inside_brackets=True)
@@ -1175,11 +1215,11 @@ def right_hand_split(line: Line) -> Iterator[Line]:
     for leaf in reversed(line.leaves):
         if current_leaves is body_leaves:
             if leaf is opening_bracket:
-                current_leaves = head_leaves
+                current_leaves = head_leaves if body_leaves else tail_leaves
         current_leaves.append(leaf)
         if current_leaves is tail_leaves:
             if leaf.type in CLOSING_BRACKETS:
-                opening_bracket = leaf.opening_bracket  # type: ignore
+                opening_bracket = leaf.opening_bracket
                 current_leaves = body_leaves
     tail_leaves.reverse()
     body_leaves.reverse()
@@ -1196,8 +1236,14 @@ def right_hand_split(line: Line) -> Iterator[Line]:
             comment_after = line.comments.get(id(leaf))
             if comment_after:
                 result.append(comment_after, preformatted=True)
-    # Check if the split succeeded.
-    tail_len = len(str(tail).strip('\n'))
+    split_succeeded_or_raise(head, body, tail)
+    for result in (head, body, tail):
+        if result:
+            yield result
+
+
+def split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None:
+    tail_len = len(str(tail).strip())
     if not body:
         if tail_len == 0:
             raise CannotSplit("Splitting brackets produced the same line")
@@ -1208,15 +1254,13 @@ def right_hand_split(line: Line) -> Iterator[Line]:
                 f"{tail_len} characters is not worth it"
             )
 
-    for result in (head, body, tail):
-        if result:
-            yield result
-
 
-def delimiter_split(line: Line) -> Iterator[Line]:
+def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
     """Split according to delimiters of the highest priority.
 
     This kind of split doesn't increase indentation.
+    If `py36` is True, the split will add trailing commas also in function
+    signatures that contain * and **.
     """
     try:
         last_leaf = line.leaves[-1]
@@ -1230,11 +1274,20 @@ def delimiter_split(line: Line) -> Iterator[Line]:
         raise CannotSplit("No delimiters found")
 
     current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+    lowest_depth = sys.maxsize
+    trailing_comma_safe = True
     for leaf in line.leaves:
         current_line.append(leaf, preformatted=True)
         comment_after = line.comments.get(id(leaf))
         if comment_after:
             current_line.append(comment_after, preformatted=True)
+        lowest_depth = min(lowest_depth, leaf.bracket_depth)
+        if (
+            leaf.bracket_depth == lowest_depth and
+            leaf.type == token.STAR or
+            leaf.type == token.DOUBLESTAR
+        ):
+            trailing_comma_safe = trailing_comma_safe and py36
         leaf_priority = delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             normalize_prefix(current_line.leaves[0])
@@ -1244,7 +1297,8 @@ def delimiter_split(line: Line) -> Iterator[Line]:
     if current_line:
         if (
             delimiter_priority == COMMA_PRIORITY and
-            current_line.leaves[-1].type != token.COMMA
+            current_line.leaves[-1].type != token.COMMA and
+            trailing_comma_safe
         ):
             current_line.append(Leaf(token.COMMA, ','))
         normalize_prefix(current_line.leaves[0])
@@ -1279,6 +1333,31 @@ def normalize_prefix(leaf: Leaf) -> None:
     leaf.prefix = ''
 
 
+def is_python36(node: Node) -> bool:
+    """Returns True if the current file is using Python 3.6+ features.
+
+    Currently looking for:
+    - f-strings; and
+    - trailing commas after * or ** in function signatures.
+    """
+    for n in node.pre_order():
+        if n.type == token.STRING:
+            value_head = n.value[:2]  # type: ignore
+            if value_head in {'f"', 'F"', "f'", "F'", 'rf', 'fr', 'RF', 'FR'}:
+                return True
+
+        elif (
+            n.type == syms.typedargslist and
+            n.children and
+            n.children[-1].type == token.COMMA
+        ):
+            for ch in n.children:
+                if ch.type == token.STAR or ch.type == token.DOUBLESTAR:
+                    return True
+
+    return False
+
+
 PYTHON_EXTENSIONS = {'.py'}
 BLACKLISTED_DIRECTORIES = {
     'build', 'buck-out', 'dist', '_build', '.git', '.hg', '.mypy_cache', '.tox', '.venv'
@@ -1321,7 +1400,15 @@ class Report:
     @property
     def return_code(self) -> int:
         """Which return code should the app use considering the current state."""
-        return 1 if self.failure_count else 0
+        # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
+        # 126 we have special returncodes reserved by the shell.
+        if self.failure_count:
+            return 123
+
+        elif self.change_count:
+            return 1
+
+        return 0
 
     def __str__(self) -> str:
         """A color report of the current state.