]> git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Don't omit whitespace when the factor is not a math operator
[etc/vim.git] / black.py
index 89155f60103a32e2f89cd86ac1e6c92aa1bed22f..0dd763073622d2ec1d0234b46cf5cd4e42bfe607 100644 (file)
--- a/black.py
+++ b/black.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+
 import asyncio
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
@@ -21,7 +22,7 @@ from blib2to3 import pygram, pytree
 from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
-__version__ = "18.3a2"
+__version__ = "18.3a3"
 DEFAULT_LINE_LENGTH = 88
 # types
 syms = pygram.python_symbols
@@ -74,7 +75,9 @@ class CannotSplit(Exception):
 @click.argument(
     'src',
     nargs=-1,
-    type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
+    type=click.Path(
+        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
+    ),
 )
 @click.pass_context
 def main(
@@ -89,6 +92,8 @@ def main(
         elif p.is_file():
             # if a file was explicitly given, we don't care about its extension
             sources.append(p)
+        elif s == '-':
+            sources.append(Path('-'))
         else:
             err(f'invalid path: {s}')
     if len(sources) == 0:
@@ -97,9 +102,14 @@ def main(
         p = sources[0]
         report = Report()
         try:
-            changed = format_file_in_place(
-                p, line_length=line_length, fast=fast, write_back=not check
-            )
+            if not p.is_file() and str(p) == '-':
+                changed = format_stdin_to_stdout(
+                    line_length=line_length, fast=fast, write_back=not check
+                )
+            else:
+                changed = format_file_in_place(
+                    p, line_length=line_length, fast=fast, write_back=not check
+                )
             report.done(p, changed)
         except Exception as exc:
             report.failed(p, str(exc))
@@ -156,41 +166,59 @@ def format_file_in_place(
     src: Path, line_length: int, fast: bool, write_back: bool = False
 ) -> bool:
     """Format the file and rewrite if changed. Return True if changed."""
+    with tokenize.open(src) as src_buffer:
+        src_contents = src_buffer.read()
     try:
-        contents, encoding = format_file(src, line_length=line_length, fast=fast)
+        contents = format_file_contents(
+            src_contents, line_length=line_length, fast=fast
+        )
     except NothingChanged:
         return False
 
     if write_back:
-        with open(src, "w", encoding=encoding) as f:
+        with open(src, "w", encoding=src_buffer.encoding) as f:
             f.write(contents)
     return True
 
 
-def format_file(
-    src: Path, line_length: int, fast: bool
-) -> Tuple[FileContent, Encoding]:
+def format_stdin_to_stdout(
+    line_length: int, fast: bool, write_back: bool = False
+) -> bool:
+    """Format file on stdin and pipe output to stdout. Return True if changed."""
+    contents = sys.stdin.read()
+    try:
+        contents = format_file_contents(contents, line_length=line_length, fast=fast)
+        return True
+
+    except NothingChanged:
+        return False
+
+    finally:
+        if write_back:
+            sys.stdout.write(contents)
+
+
+def format_file_contents(
+    src_contents: str, line_length: int, fast: bool
+) -> FileContent:
     """Reformats a file and returns its contents and encoding."""
-    with tokenize.open(src) as src_buffer:
-        src_contents = src_buffer.read()
     if src_contents.strip() == '':
-        raise NothingChanged(src)
+        raise NothingChanged
 
     dst_contents = format_str(src_contents, line_length=line_length)
     if src_contents == dst_contents:
-        raise NothingChanged(src)
+        raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
         assert_stable(src_contents, dst_contents, line_length=line_length)
-    return dst_contents, src_buffer.encoding
+    return dst_contents
 
 
 def format_str(src_contents: str, line_length: int) -> FileContent:
     """Reformats a string and returns new contents."""
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
-    comments: List[Line] = []
     lines = LineGenerator()
     elt = EmptyLineTracker()
     py36 = is_python36(src_node)
@@ -202,21 +230,8 @@ def format_str(src_contents: str, line_length: int) -> FileContent:
         before, after = elt.maybe_empty_lines(current_line)
         for _ in range(before):
             dst_contents += str(empty_line)
-        if not current_line.is_comment:
-            for comment in comments:
-                dst_contents += str(comment)
-            comments = []
-            for line in split_line(current_line, line_length=line_length, py36=py36):
-                dst_contents += str(line)
-        else:
-            comments.append(current_line)
-    if comments:
-        if elt.previous_defs:
-            # Separate postscriptum comments from the last module-level def.
-            dst_contents += str(empty_line)
-            dst_contents += str(empty_line)
-        for comment in comments:
-            dst_contents += str(comment)
+        for line in split_line(current_line, line_length=line_length, py36=py36):
+            dst_contents += str(line)
     return dst_contents
 
 
@@ -325,6 +340,7 @@ MATH_OPERATORS = {
     token.AMPER,
     token.PERCENT,
     token.CIRCUMFLEX,
+    token.TILDE,
     token.LEFTSHIFT,
     token.RIGHTSHIFT,
     token.DOUBLESTAR,
@@ -463,8 +479,7 @@ class Line:
         return (
             (first_leaf.type == token.NAME and first_leaf.value == 'def')
             or (
-                first_leaf.type == token.NAME
-                and first_leaf.value == 'async'
+                first_leaf.type == token.ASYNC
                 and second_leaf is not None
                 and second_leaf.type == token.NAME
                 and second_leaf.value == 'def'
@@ -635,10 +650,6 @@ class EmptyLineTracker:
         (two on module-level), as well as providing an extra empty line after flow
         control keywords to make them more prominent.
         """
-        if current_line.is_comment:
-            # Don't count standalone comments towards previous empty lines.
-            return 0, 0
-
         before, after = self._maybe_empty_lines(current_line)
         before -= self.previous_after
         self.previous_after = after
@@ -646,10 +657,14 @@ class EmptyLineTracker:
         return before, after
 
     def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
+        max_allowed = 1
+        if current_line.is_comment and current_line.depth == 0:
+            max_allowed = 2
         if current_line.leaves:
             # Consume the first leaf's extra newlines.
             first_leaf = current_line.leaves[0]
-            before = int('\n' in first_leaf.prefix)
+            before = first_leaf.prefix.count('\n')
+            before = min(before, max(before, max_allowed))
             first_leaf.prefix = ''
         else:
             before = 0
@@ -703,7 +718,6 @@ class LineGenerator(Visitor[Line]):
     in ways that will no longer stringify to valid Python code on the tree.
     """
     current_line: Line = Factory(Line)
-    standalone_comments: List[Leaf] = Factory(list)
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
@@ -723,8 +737,9 @@ class LineGenerator(Visitor[Line]):
 
     def visit_default(self, node: LN) -> Iterator[Line]:
         if isinstance(node, Leaf):
+            any_open_brackets = self.current_line.bracket_tracker.any_open_brackets()
             for comment in generate_comments(node):
-                if self.current_line.bracket_tracker.any_open_brackets():
+                if any_open_brackets:
                     # any comment within brackets is subject to splitting
                     self.current_line.append(comment)
                 elif comment.type == token.COMMENT:
@@ -733,33 +748,22 @@ class LineGenerator(Visitor[Line]):
                     yield from self.line()
 
                 else:
-                    # regular standalone comment, to be processed later (see
-                    # docstring in `generate_comments()`
-                    self.standalone_comments.append(comment)
-            normalize_prefix(node)
-            if node.type not in WHITESPACE:
-                for comment in self.standalone_comments:
+                    # regular standalone comment
                     yield from self.line()
 
                     self.current_line.append(comment)
                     yield from self.line()
 
-                self.standalone_comments = []
+            normalize_prefix(node, inside_brackets=any_open_brackets)
+            if node.type not in WHITESPACE:
                 self.current_line.append(node)
         yield from super().visit_default(node)
 
-    def visit_suite(self, node: Node) -> Iterator[Line]:
-        """Body of a statement after a colon."""
-        children = iter(node.children)
-        # Process newline before indenting.  It might contain an inline
-        # comment that should go right after the colon.
-        newline = next(children)
-        yield from self.visit(newline)
+    def visit_INDENT(self, node: Node) -> Iterator[Line]:
         yield from self.line(+1)
+        yield from self.visit_default(node)
 
-        for child in children:
-            yield from self.visit(child)
-
+    def visit_DEDENT(self, node: Node) -> Iterator[Line]:
         yield from self.line(-1)
 
     def visit_stmt(self, node: Node, keywords: Set[str]) -> Iterator[Line]:
@@ -793,7 +797,7 @@ class LineGenerator(Visitor[Line]):
         for child in children:
             yield from self.visit(child)
 
-            if child.type == token.NAME and child.value == 'async':  # type: ignore
+            if child.type == token.ASYNC:
                 break
 
         internal_stmt = next(children)
@@ -831,7 +835,7 @@ BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.R
 OPENING_BRACKETS = set(BRACKET.keys())
 CLOSING_BRACKETS = set(BRACKET.values())
 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
-ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, token.COLON, STANDALONE_COMMENT}
+ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
 
 
 def whitespace(leaf: Leaf) -> str:  # noqa C901
@@ -849,12 +853,18 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
         return DOUBLESPACE
 
     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
+    if t == token.COLON and p.type not in {syms.subscript, syms.subscriptlist}:
+        return NO
+
     prev = leaf.prev_sibling
     if not prev:
         prevp = preceding_leaf(p)
         if not prevp or prevp.type in OPENING_BRACKETS:
             return NO
 
+        if t == token.COLON:
+            return SPACE if prevp.type == token.COMMA else NO
+
         if prevp.type == token.EQUAL:
             if prevp.parent and prevp.parent.type in {
                 syms.typedargslist,
@@ -876,10 +886,14 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
                 return NO
 
         elif prevp.type == token.COLON:
-            if prevp.parent and prevp.parent.type == syms.subscript:
+            if prevp.parent and prevp.parent.type in {syms.subscript, syms.sliceop}:
                 return NO
 
-        elif prevp.parent and prevp.parent.type in {syms.factor, syms.star_expr}:
+        elif (
+            prevp.parent
+            and prevp.parent.type in {syms.factor, syms.star_expr}
+            and prevp.type in MATH_OPERATORS
+        ):
             return NO
 
     elif prev.type in OPENING_BRACKETS:
@@ -983,7 +997,7 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
 
             return NO
 
-        elif prev.type == token.COLON:
+        else:
             return NO
 
     elif p.type == syms.atom:
@@ -1104,30 +1118,40 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
     are emitted with a fake STANDALONE_COMMENT token identifier.
     """
-    if not leaf.prefix:
+    p = leaf.prefix
+    if not p:
         return
 
-    if '#' not in leaf.prefix:
-        return
-
-    before_comment, content = leaf.prefix.split('#', 1)
-    content = content.rstrip()
-    if content and (content[0] not in {' ', '!', '#'}):
-        content = ' ' + content
-    is_standalone_comment = (
-        '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
-    )
-    if not is_standalone_comment:
-        # simple trailing comment
-        yield Leaf(token.COMMENT, value='#' + content)
+    if '#' not in p:
         return
 
-    for line in ('#' + content).split('\n'):
+    nlines = 0
+    for index, line in enumerate(p.split('\n')):
         line = line.lstrip()
+        if not line:
+            nlines += 1
         if not line.startswith('#'):
             continue
 
-        yield Leaf(STANDALONE_COMMENT, line)
+        if index == 0 and leaf.type != token.ENDMARKER:
+            comment_type = token.COMMENT  # simple trailing comment
+        else:
+            comment_type = STANDALONE_COMMENT
+        yield Leaf(comment_type, make_comment(line), prefix='\n' * nlines)
+
+        nlines = 0
+
+
+def make_comment(content: str) -> str:
+    content = content.rstrip()
+    if not content:
+        return '#'
+
+    if content[0] == '#':
+        content = content[1:]
+    if content and content[0] not in {' ', '!', '#'}:
+        content = ' ' + content
+    return '#' + content
 
 
 def split_line(
@@ -1210,7 +1234,7 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
                 current_leaves = body_leaves
     # Since body is a new indent level, remove spurious leading whitespace.
     if body_leaves:
-        normalize_prefix(body_leaves[0])
+        normalize_prefix(body_leaves[0], inside_brackets=True)
     # Build the new lines.
     for result, leaves in (
         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
@@ -1250,7 +1274,7 @@ def right_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
     head_leaves.reverse()
     # Since body is a new indent level, remove spurious leading whitespace.
     if body_leaves:
-        normalize_prefix(body_leaves[0])
+        normalize_prefix(body_leaves[0], inside_brackets=True)
     # Build the new lines.
     for result, leaves in (
         (head, head_leaves), (body, body_leaves), (tail, tail_leaves)
@@ -1314,7 +1338,7 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
             trailing_comma_safe = trailing_comma_safe and py36
         leaf_priority = delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
-            normalize_prefix(current_line.leaves[0])
+            normalize_prefix(current_line.leaves[0], inside_brackets=True)
             yield current_line
 
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
@@ -1325,7 +1349,7 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
             and trailing_comma_safe
         ):
             current_line.append(Leaf(token.COMMA, ','))
-        normalize_prefix(current_line.leaves[0])
+        normalize_prefix(current_line.leaves[0], inside_brackets=True)
         yield current_line
 
 
@@ -1343,13 +1367,20 @@ def is_import(leaf: Leaf) -> bool:
     )
 
 
-def normalize_prefix(leaf: Leaf) -> None:
-    """Leave existing extra newlines for imports.  Remove everything else."""
-    if is_import(leaf):
-        spl = leaf.prefix.split('#', 1)
-        nl_count = spl[0].count('\n')
-        leaf.prefix = '\n' * nl_count
-        return
+def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
+    """Leave existing extra newlines if not `inside_brackets`.
+
+    Remove everything else.  Note: don't use backslashes for formatting or
+    you'll lose your voting rights.
+    """
+    if not inside_brackets:
+        spl = leaf.prefix.split('#')
+        if '\\' not in spl[0]:
+            nl_count = spl[-1].count('\n')
+            if len(spl) > 1:
+                nl_count -= 1
+            leaf.prefix = '\n' * nl_count
+            return
 
     leaf.prefix = ''