git.madduck.net Git - etc/vim.git/blobdiff - black.py

Fix long trivial assignments being wrapped in unnecessary parentheses
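
As an illustration of the behaviour being fixed (not part of the patch; the names below are made up): when the right-hand side of an assignment is a single atom that is too long to fit regardless, Black used to wrap it in optional parentheses that did not help. With the right_hand_split change below, which raises CannotSplit when the body is a single leaf that is still too long, such trivial assignments are left on one line.

    an_even_longer_single_name_on_the_right_that_still_exceeds_the_line_length = 1

    # Before the fix, Black reformatted the assignment like this:
    a_rather_long_target_name = (
        an_even_longer_single_name_on_the_right_that_still_exceeds_the_line_length
    )

    # After the fix, the redundant parentheses are no longer added:
    a_rather_long_target_name = an_even_longer_single_name_on_the_right_that_still_exceeds_the_line_length
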
[etc/vim.git] / black.py
index c8c381c2d30ff3fba658299c49de961542490d68..35af598e59032c62397bff766d06b006df41174c 100644
--- a/black.py
+++ b/black.py
@@ -1,7 +1,7 @@
 import asyncio
-import pickle
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
+from datetime import datetime
 from enum import Enum, Flag
 from functools import partial, wraps
 import io
@@ -10,10 +10,11 @@ import logging
 from multiprocessing import Manager
 import os
 from pathlib import Path
+import pickle
 import re
-import tokenize
 import signal
 import sys
+import tokenize
 from typing import (
     Any,
     Callable,
@@ -57,6 +58,7 @@ CACHE_DIR = Path(user_cache_dir("black", version=__version__))
 # types
 FileContent = str
 Encoding = str
+NewLine = str
 Depth = int
 NodeType = int
 LeafID = int
@@ -278,7 +280,7 @@ def main(
         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
     )
     report = Report(check=check, quiet=quiet, verbose=verbose)
-    sources: List[Path] = []
+    sources: Set[Path] = set()
     try:
         include_regex = re.compile(include)
     except re.error:
@@ -293,12 +295,12 @@ def main(
     for s in src:
         p = Path(s)
         if p.is_dir():
-            sources.extend(
+            sources.update(
                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
             )
         elif p.is_file() or s == "-":
             # if a file was explicitly given, we don't care about its extension
-            sources.append(p)
+            sources.add(p)
         else:
             err(f"invalid path: {s}")
     if len(sources) == 0:
@@ -309,7 +311,7 @@ def main(
 
     elif len(sources) == 1:
         reformat_one(
-            src=sources[0],
+            src=sources.pop(),
             line_length=line_length,
             fast=fast,
             write_back=write_back,
@@ -334,9 +336,9 @@ def main(
             )
         finally:
             shutdown(loop)
-        if verbose or not quiet:
-            out("All done! ✨ 🍰 ✨")
-            click.echo(str(report))
+    if verbose or not quiet:
+        out("All done! ✨ 🍰 ✨")
+        click.echo(str(report))
     ctx.exit(report.return_code)
 
 
@@ -384,7 +386,7 @@ def reformat_one(
 
 
 async def schedule_formatting(
-    sources: List[Path],
+    sources: Set[Path],
     line_length: int,
     fast: bool,
     write_back: WriteBack,
@@ -404,7 +406,7 @@ async def schedule_formatting(
     if write_back != WriteBack.DIFF:
         cache = read_cache(line_length, mode)
         sources, cached = filter_cached(cache, sources)
-        for src in cached:
+        for src in sorted(cached):
             report.done(src, Changed.CACHED)
     cancelled = []
     formatted = []
@@ -468,8 +470,9 @@ def format_file_in_place(
     if src.suffix == ".pyi":
         mode |= FileMode.PYI
 
+    then = datetime.utcfromtimestamp(src.stat().st_mtime)
     with open(src, "rb") as buf:
-        newline, encoding, src_contents = prepare_input(buf.read())
+        src_contents, encoding, newline = decode_bytes(buf.read())
     try:
         dst_contents = format_file_contents(
             src_contents, line_length=line_length, fast=fast, mode=mode
@@ -481,8 +484,9 @@ def format_file_in_place(
         with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back == write_back.DIFF:
-        src_name = f"{src}  (original)"
-        dst_name = f"{src}  (formatted)"
+        now = datetime.utcnow()
+        src_name = f"{src}\t{then} +0000"
+        dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
         if lock:
             lock.acquire()
@@ -513,7 +517,8 @@ def format_stdin_to_stdout(
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
-    newline, encoding, src = prepare_input(sys.stdin.buffer.read())
+    then = datetime.utcnow()
+    src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
     dst = src
     try:
         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
@@ -523,26 +528,17 @@ def format_stdin_to_stdout(
         return False
 
     finally:
+        f = io.TextIOWrapper(
+            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
+        )
         if write_back == WriteBack.YES:
-            f = io.TextIOWrapper(
-                sys.stdout.buffer,
-                encoding=encoding,
-                newline=newline,
-                write_through=True,
-            )
             f.write(dst)
-            f.detach()
         elif write_back == WriteBack.DIFF:
-            src_name = "<stdin>  (original)"
-            dst_name = "<stdin>  (formatted)"
-            f = io.TextIOWrapper(
-                sys.stdout.buffer,
-                encoding=encoding,
-                newline=newline,
-                write_through=True,
-            )
+            now = datetime.utcnow()
+            src_name = f"STDIN\t{then} +0000"
+            dst_name = f"STDOUT\t{now} +0000"
             f.write(diff(src, dst, src_name, dst_name))
-            f.detach()
+        f.detach()
 
 
 def format_file_contents(
@@ -603,17 +599,21 @@ def format_str(
     return dst_contents
 
 
-def prepare_input(src: bytes) -> Tuple[str, str, str]:
-    """Analyze `src` and return a tuple of (newline, encoding, decoded_contents)
+def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
+    """Return a tuple of (decoded_contents, encoding, newline).
 
-    Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
-    universal newlines (i.e. only LF).
+    `newline` is either CRLF or LF but `decoded_contents` is decoded with
+    universal newlines (i.e. only contains LF).
     """
     srcbuf = io.BytesIO(src)
     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
+    if not lines:
+        return "", encoding, "\n"
+
     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
     srcbuf.seek(0)
-    return newline, encoding, io.TextIOWrapper(srcbuf, encoding).read()
+    with io.TextIOWrapper(srcbuf, encoding) as tiow:
+        return tiow.read(), encoding, newline
 
 
 GRAMMARS = [
@@ -626,7 +626,7 @@ GRAMMARS = [
 def lib2to3_parse(src_txt: str) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
-    if src_txt[-1] != "\n":
+    if src_txt[-1:] != "\n":
         src_txt += "\n"
     for grammar in GRAMMARS:
         drv = driver.Driver(grammar, pytree.convert)
@@ -771,6 +771,7 @@ UNPACKING_PARENTS = {
     syms.dictsetmaker,
     syms.listmaker,
     syms.testlist_gexp,
+    syms.testlist_star_expr,
 }
 TEST_DESCENDANTS = {
     syms.test,
@@ -2212,11 +2213,14 @@ def right_hand_split(
             result.append(leaf, preformatted=True)
             for comment_after in line.comments_after(leaf):
                 result.append(comment_after, preformatted=True)
-    bracket_split_succeeded_or_raise(head, body, tail)
     assert opening_bracket and closing_bracket
+    body.should_explode = should_explode(body, opening_bracket)
+    bracket_split_succeeded_or_raise(head, body, tail)
     if (
+        # the body shouldn't be exploded
+        not body.should_explode
         # the opening bracket is an optional paren
-        opening_bracket.type == token.LPAR
+        and opening_bracket.type == token.LPAR
         and not opening_bracket.value
         # the closing bracket is an optional paren
         and closing_bracket.type == token.RPAR
@@ -2233,11 +2237,15 @@ def right_hand_split(
                 yield from right_hand_split(line, line_length, py36=py36, omit=omit)
                 return
             except CannotSplit:
-                pass
+                if len(body.leaves) == 1 and not is_line_short_enough(
+                    body, line_length=line_length
+                ):
+                    raise CannotSplit(
+                        "Splitting failed, body is still too long and can't be split."
+                    )
 
     ensure_visible(opening_bracket)
     ensure_visible(closing_bracket)
-    body.should_explode = should_explode(body, opening_bracket)
     for result in (head, body, tail):
         if result:
             yield result
@@ -3304,26 +3312,24 @@ def get_cache_info(path: Path) -> CacheInfo:
     return stat.st_mtime, stat.st_size
 
 
-def filter_cached(
-    cache: Cache, sources: Iterable[Path]
-) -> Tuple[List[Path], List[Path]]:
-    """Split a list of paths into two.
+def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
+    """Split an iterable of paths in `sources` into two sets.
 
-    The first list contains paths of files that modified on disk or are not in the
-    cache. The other list contains paths to non-modified files.
+    The first contains paths of files that modified on disk or are not in the
+    cache. The other contains paths to non-modified files.
     """
-    todo, done = [], []
+    todo, done = set(), set()
     for src in sources:
         src = src.resolve()
         if cache.get(src) != get_cache_info(src):
-            todo.append(src)
+            todo.add(src)
         else:
-            done.append(src)
+            done.add(src)
     return todo, done
 
 
 def write_cache(
-    cache: Cache, sources: List[Path], line_length: int, mode: FileMode
+    cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
 ) -> None:
     """Update the cache file."""
     cache_file = get_cache_file(line_length, mode)
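
As a usage note, a minimal sketch (not part of the patch; example.py is a placeholder) of the timestamped diff headers that format_file_in_place and format_stdin_to_stdout now emit, which give --diff output conventional "--- file<TAB>timestamp" headers:

    from datetime import datetime
    from pathlib import Path

    src = Path("example.py")                   # placeholder file
    src.write_text("x = 1\n")                  # ensure stat() has something to read
    then = datetime.utcfromtimestamp(src.stat().st_mtime)  # mtime of the original
    now = datetime.utcnow()                                # time of reformatting
    print(f"--- {src}\t{then} +0000")
    print(f"+++ {src}\t{now} +0000")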