git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add piping from stdin to stdout with a - (#25)
[etc/vim.git] / black.py
index 89155f60103a32e2f89cd86ac1e6c92aa1bed22f..6bfef500d6d44ea79d8d1eca2030127c7433339d 100644 (file)
--- a/black.py
+++ b/black.py
@@ -74,7 +74,9 @@ class CannotSplit(Exception):
 @click.argument(
     'src',
     nargs=-1,
 @click.argument(
     'src',
     nargs=-1,
-    type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
+    type=click.Path(
+        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
+    ),
 )
 @click.pass_context
 def main(
 )
 @click.pass_context
 def main(
@@ -89,6 +91,8 @@ def main(
         elif p.is_file():
             # if a file was explicitly given, we don't care about its extension
             sources.append(p)
         elif p.is_file():
             # if a file was explicitly given, we don't care about its extension
             sources.append(p)
+        elif s == '-':
+            sources.append(Path('-'))
         else:
             err(f'invalid path: {s}')
     if len(sources) == 0:
         else:
             err(f'invalid path: {s}')
     if len(sources) == 0:
@@ -97,9 +101,12 @@ def main(
         p = sources[0]
         report = Report()
         try:
         p = sources[0]
         report = Report()
         try:
-            changed = format_file_in_place(
-                p, line_length=line_length, fast=fast, write_back=not check
-            )
+            if not p.is_file() and str(p) == '-':
+                changed = format_stdin_to_stdout(line_length=line_length, fast=fast)
+            else:
+                changed = format_file_in_place(
+                    p, line_length=line_length, fast=fast, write_back=not check
+                )
             report.done(p, changed)
         except Exception as exc:
             report.failed(p, str(exc))
             report.done(p, changed)
         except Exception as exc:
             report.failed(p, str(exc))
@@ -156,34 +163,50 @@ def format_file_in_place(
     src: Path, line_length: int, fast: bool, write_back: bool = False
 ) -> bool:
     """Format the file and rewrite if changed. Return True if changed."""
     src: Path, line_length: int, fast: bool, write_back: bool = False
 ) -> bool:
     """Format the file and rewrite if changed. Return True if changed."""
+    with tokenize.open(src) as src_buffer:
+        src_contents = src_buffer.read()
     try:
     try:
-        contents, encoding = format_file(src, line_length=line_length, fast=fast)
+        contents = format_file_contents(
+            src_contents, line_length=line_length, fast=fast
+        )
     except NothingChanged:
         return False
 
     if write_back:
     except NothingChanged:
         return False
 
     if write_back:
-        with open(src, "w", encoding=encoding) as f:
+        with open(src, "w", encoding=src_buffer.encoding) as f:
             f.write(contents)
     return True
 
 
             f.write(contents)
     return True
 
 
-def format_file(
-    src: Path, line_length: int, fast: bool
-) -> Tuple[FileContent, Encoding]:
+def format_stdin_to_stdout(line_length: int, fast: bool) -> bool:
+    """Format file on stdin and pipe output to stdout. Return True if changed."""
+    contents = sys.stdin.read()
+    try:
+        contents = format_file_contents(contents, line_length=line_length, fast=fast)
+        return True
+
+    except NothingChanged:
+        return False
+
+    finally:
+        sys.stdout.write(contents)
+
+
+def format_file_contents(
+    src_contents: str, line_length: int, fast: bool
+) -> FileContent:
     """Reformats a file and returns its contents and encoding."""
     """Reformats a file and returns its contents and encoding."""
-    with tokenize.open(src) as src_buffer:
-        src_contents = src_buffer.read()
     if src_contents.strip() == '':
     if src_contents.strip() == '':
-        raise NothingChanged(src)
+        raise NothingChanged
 
     dst_contents = format_str(src_contents, line_length=line_length)
     if src_contents == dst_contents:
 
     dst_contents = format_str(src_contents, line_length=line_length)
     if src_contents == dst_contents:
-        raise NothingChanged(src)
+        raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
         assert_stable(src_contents, dst_contents, line_length=line_length)
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
         assert_stable(src_contents, dst_contents, line_length=line_length)
-    return dst_contents, src_buffer.encoding
+    return dst_contents
 
 
 def format_str(src_contents: str, line_length: int) -> FileContent:
 
 
 def format_str(src_contents: str, line_length: int) -> FileContent:
@@ -831,7 +854,7 @@ BRACKET = {token.LPAR: token.RPAR, token.LSQB: token.RSQB, token.LBRACE: token.R
 OPENING_BRACKETS = set(BRACKET.keys())
 CLOSING_BRACKETS = set(BRACKET.values())
 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
 OPENING_BRACKETS = set(BRACKET.keys())
 CLOSING_BRACKETS = set(BRACKET.values())
 BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
-ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, token.COLON, STANDALONE_COMMENT}
+ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
 
 
 def whitespace(leaf: Leaf) -> str:  # noqa C901
 
 
 def whitespace(leaf: Leaf) -> str:  # noqa C901
@@ -849,12 +872,18 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
         return DOUBLESPACE
 
     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
         return DOUBLESPACE
 
     assert p is not None, f"INTERNAL ERROR: hand-made leaf without parent: {leaf!r}"
+    if t == token.COLON and p.type != syms.subscript:
+        return NO
+
     prev = leaf.prev_sibling
     if not prev:
         prevp = preceding_leaf(p)
         if not prevp or prevp.type in OPENING_BRACKETS:
             return NO
 
     prev = leaf.prev_sibling
     if not prev:
         prevp = preceding_leaf(p)
         if not prevp or prevp.type in OPENING_BRACKETS:
             return NO
 
+        if t == token.COLON:
+            return SPACE if prevp.type == token.COMMA else NO
+
         if prevp.type == token.EQUAL:
             if prevp.parent and prevp.parent.type in {
                 syms.typedargslist,
         if prevp.type == token.EQUAL:
             if prevp.parent and prevp.parent.type in {
                 syms.typedargslist,
@@ -983,7 +1012,7 @@ def whitespace(leaf: Leaf) -> str:  # noqa C901
 
             return NO
 
 
             return NO
 
-        elif prev.type == token.COLON:
+        else:
             return NO
 
     elif p.type == syms.atom:
             return NO
 
     elif p.type == syms.atom:
@@ -1115,7 +1144,7 @@ def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
     if content and (content[0] not in {' ', '!', '#'}):
         content = ' ' + content
     is_standalone_comment = (
     if content and (content[0] not in {' ', '!', '#'}):
         content = ' ' + content
     is_standalone_comment = (
-        '\n' in before_comment or '\n' in content or leaf.type == token.DEDENT
+        '\n' in before_comment or '\n' in content or leaf.type == token.ENDMARKER
     )
     if not is_standalone_comment:
         # simple trailing comment
     )
     if not is_standalone_comment:
         # simple trailing comment