git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Restore ability to format code with legacy usage of `async` as a name
[etc/vim.git] / black.py
index b2b2db7e41192563f4eac9a37a3550584eeed437..74329d2389901323d2c743f9a55758c6107417f5 100644 (file)
--- a/black.py
+++ b/black.py
@@ -74,7 +74,9 @@ class CannotSplit(Exception):
 @click.argument(
     'src',
     nargs=-1,
-    type=click.Path(exists=True, file_okay=True, dir_okay=True, readable=True),
+    type=click.Path(
+        exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
+    ),
 )
 @click.pass_context
 def main(
@@ -89,6 +91,8 @@ def main(
         elif p.is_file():
             # if a file was explicitly given, we don't care about its extension
             sources.append(p)
+        elif s == '-':
+            sources.append(Path('-'))
         else:
             err(f'invalid path: {s}')
     if len(sources) == 0:
@@ -97,9 +101,12 @@ def main(
         p = sources[0]
         report = Report()
         try:
-            changed = format_file_in_place(
-                p, line_length=line_length, fast=fast, write_back=not check
-            )
+            if not p.is_file() and str(p) == '-':
+                changed = format_stdin_to_stdout(line_length=line_length, fast=fast)
+            else:
+                changed = format_file_in_place(
+                    p, line_length=line_length, fast=fast, write_back=not check
+                )
             report.done(p, changed)
         except Exception as exc:
             report.failed(p, str(exc))
@@ -156,34 +163,50 @@ def format_file_in_place(
     src: Path, line_length: int, fast: bool, write_back: bool = False
 ) -> bool:
     """Format the file and rewrite if changed. Return True if changed."""
+    with tokenize.open(src) as src_buffer:
+        src_contents = src_buffer.read()
     try:
-        contents, encoding = format_file(src, line_length=line_length, fast=fast)
+        contents = format_file_contents(
+            src_contents, line_length=line_length, fast=fast
+        )
     except NothingChanged:
         return False
 
     if write_back:
-        with open(src, "w", encoding=encoding) as f:
+        with open(src, "w", encoding=src_buffer.encoding) as f:
             f.write(contents)
     return True
 
 
-def format_file(
-    src: Path, line_length: int, fast: bool
-) -> Tuple[FileContent, Encoding]:
+def format_stdin_to_stdout(line_length: int, fast: bool) -> bool:
+    """Format file on stdin and pipe output to stdout. Return True if changed."""
+    contents = sys.stdin.read()
+    try:
+        contents = format_file_contents(contents, line_length=line_length, fast=fast)
+        return True
+
+    except NothingChanged:
+        return False
+
+    finally:
+        sys.stdout.write(contents)
+
+
+def format_file_contents(
+    src_contents: str, line_length: int, fast: bool
+) -> FileContent:
     """Reformats a file and returns its contents and encoding."""
-    with tokenize.open(src) as src_buffer:
-        src_contents = src_buffer.read()
     if src_contents.strip() == '':
-        raise NothingChanged(src)
+        raise NothingChanged
 
     dst_contents = format_str(src_contents, line_length=line_length)
     if src_contents == dst_contents:
-        raise NothingChanged(src)
+        raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
         assert_stable(src_contents, dst_contents, line_length=line_length)
-    return dst_contents, src_buffer.encoding
+    return dst_contents
 
 
 def format_str(src_contents: str, line_length: int) -> FileContent:
@@ -463,8 +486,7 @@ class Line:
         return (
             (first_leaf.type == token.NAME and first_leaf.value == 'def')
             or (
-                first_leaf.type == token.NAME
-                and first_leaf.value == 'async'
+                first_leaf.type == token.ASYNC
                 and second_leaf is not None
                 and second_leaf.type == token.NAME
                 and second_leaf.value == 'def'
@@ -793,7 +815,7 @@ class LineGenerator(Visitor[Line]):
         for child in children:
             yield from self.visit(child)
 
-            if child.type == token.NAME and child.value == 'async':  # type: ignore
+            if child.type == token.ASYNC:
                 break
 
         internal_stmt = next(children)