git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix unnecessary parentheses when a line contains multiline strings
[etc/vim.git] / black.py
index da00525ca945e6443aa61ffe82a1f93ad0061e71..551d3c1dbdaf9c82f38a48729b78aed92c6c8079 100644 (file)
--- a/black.py
+++ b/black.py
@@ -1,18 +1,20 @@
 import asyncio
-import pickle
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
+from datetime import datetime
 from enum import Enum, Flag
 from functools import partial, wraps
+import io
 import keyword
 import logging
 from multiprocessing import Manager
 import os
 from pathlib import Path
+import pickle
 import re
-import tokenize
 import signal
 import sys
+import tokenize
 from typing import (
     Any,
     Callable,
@@ -56,6 +58,7 @@ CACHE_DIR = Path(user_cache_dir("black", version=__version__))
 # types
 FileContent = str
 Encoding = str
+NewLine = str
 Depth = int
 NodeType = int
 LeafID = int
@@ -119,6 +122,13 @@ class WriteBack(Enum):
     YES = 1
     DIFF = 2
 
+    @classmethod
+    def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
+        if check and not diff:
+            return cls.NO
+
+        return cls.DIFF if diff else cls.YES
+
 
 class Changed(Enum):
     NO = 0
@@ -132,6 +142,19 @@ class FileMode(Flag):
     PYI = 2
     NO_STRING_NORMALIZATION = 4
 
+    @classmethod
+    def from_configuration(
+        cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
+    ) -> "FileMode":
+        mode = cls.AUTO_DETECT
+        if py36:
+            mode |= cls.PYTHON36
+        if pyi:
+            mode |= cls.PYI
+        if skip_string_normalization:
+            mode |= cls.NO_STRING_NORMALIZATION
+        return mode
+
 
 @click.command()
 @click.option(
@@ -218,6 +241,15 @@ class FileMode(Flag):
         "silence those with 2>/dev/null."
     ),
 )
+@click.option(
+    "-v",
+    "--verbose",
+    is_flag=True,
+    help=(
+        "Also emit messages to stderr about files that were not changed or were "
+        "ignored due to --exclude=."
+    ),
+)
 @click.version_option(version=__version__)
 @click.argument(
     "src",
@@ -237,12 +269,18 @@ def main(
     py36: bool,
     skip_string_normalization: bool,
     quiet: bool,
+    verbose: bool,
     include: str,
     exclude: str,
     src: List[str],
 ) -> None:
     """The uncompromising code formatter."""
-    sources: List[Path] = []
+    write_back = WriteBack.from_configuration(check=check, diff=diff)
+    mode = FileMode.from_configuration(
+        py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
+    )
+    report = Report(check=check, quiet=quiet, verbose=verbose)
+    sources: Set[Path] = set()
     try:
         include_regex = re.compile(include)
     except re.error:
@@ -253,40 +291,27 @@ def main(
     except re.error:
         err(f"Invalid regular expression for exclude given: {exclude!r}")
         ctx.exit(2)
+    root = find_project_root(src)
     for s in src:
         p = Path(s)
         if p.is_dir():
-            sources.extend(gen_python_files_in_dir(p, include_regex, exclude_regex))
-        elif p.is_file():
+            sources.update(
+                gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
+            )
+        elif p.is_file() or s == "-":
             # if a file was explicitly given, we don't care about its extension
-            sources.append(p)
-        elif s == "-":
-            sources.append(Path("-"))
+            sources.add(p)
         else:
             err(f"invalid path: {s}")
-
-    if check and not diff:
-        write_back = WriteBack.NO
-    elif diff:
-        write_back = WriteBack.DIFF
-    else:
-        write_back = WriteBack.YES
-    mode = FileMode.AUTO_DETECT
-    if py36:
-        mode |= FileMode.PYTHON36
-    if pyi:
-        mode |= FileMode.PYI
-    if skip_string_normalization:
-        mode |= FileMode.NO_STRING_NORMALIZATION
-    report = Report(check=check, quiet=quiet)
     if len(sources) == 0:
-        out("No paths given. Nothing to do 😴")
+        if verbose or not quiet:
+            out("No paths given. Nothing to do 😴")
         ctx.exit(0)
         return
 
     elif len(sources) == 1:
         reformat_one(
-            src=sources[0],
+            src=sources.pop(),
             line_length=line_length,
             fast=fast,
             write_back=write_back,
@@ -311,9 +336,9 @@ def main(
             )
         finally:
             shutdown(loop)
-        if not quiet:
-            out("All done! ✨ 🍰 ✨")
-            click.echo(str(report))
+    if verbose or not quiet:
+        out("All done! ✨ 🍰 ✨")
+        click.echo(str(report))
     ctx.exit(report.return_code)
 
 
@@ -361,7 +386,7 @@ def reformat_one(
 
 
 async def schedule_formatting(
-    sources: List[Path],
+    sources: Set[Path],
     line_length: int,
     fast: bool,
     write_back: WriteBack,
@@ -381,7 +406,7 @@ async def schedule_formatting(
     if write_back != WriteBack.DIFF:
         cache = read_cache(line_length, mode)
         sources, cached = filter_cached(cache, sources)
-        for src in cached:
+        for src in sorted(cached):
             report.done(src, Changed.CACHED)
     cancelled = []
     formatted = []
@@ -444,8 +469,10 @@ def format_file_in_place(
     """
     if src.suffix == ".pyi":
         mode |= FileMode.PYI
-    with tokenize.open(src) as src_buffer:
-        src_contents = src_buffer.read()
+
+    then = datetime.utcfromtimestamp(src.stat().st_mtime)
+    with open(src, "rb") as buf:
+        src_contents, encoding, newline = decode_bytes(buf.read())
     try:
         dst_contents = format_file_contents(
             src_contents, line_length=line_length, fast=fast, mode=mode
@@ -454,16 +481,24 @@ def format_file_in_place(
         return False
 
     if write_back == write_back.YES:
-        with open(src, "w", encoding=src_buffer.encoding) as f:
+        with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back == write_back.DIFF:
-        src_name = f"{src}  (original)"
-        dst_name = f"{src}  (formatted)"
+        now = datetime.utcnow()
+        src_name = f"{src}\t{then} +0000"
+        dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
         if lock:
             lock.acquire()
         try:
-            sys.stdout.write(diff_contents)
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(diff_contents)
+            f.detach()
         finally:
             if lock:
                 lock.release()
@@ -482,7 +517,8 @@ def format_stdin_to_stdout(
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
-    src = sys.stdin.read()
+    then = datetime.utcnow()
+    src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
     dst = src
     try:
         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
@@ -492,12 +528,17 @@ def format_stdin_to_stdout(
         return False
 
     finally:
+        f = io.TextIOWrapper(
+            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
+        )
         if write_back == WriteBack.YES:
-            sys.stdout.write(dst)
+            f.write(dst)
         elif write_back == WriteBack.DIFF:
-            src_name = "<stdin>  (original)"
-            dst_name = "<stdin>  (formatted)"
-            sys.stdout.write(diff(src, dst, src_name, dst_name))
+            now = datetime.utcnow()
+            src_name = f"STDIN\t{then} +0000"
+            dst_name = f"STDOUT\t{now} +0000"
+            f.write(diff(src, dst, src_name, dst_name))
+        f.detach()
 
 
 def format_file_contents(
@@ -558,6 +599,23 @@ def format_str(
     return dst_contents
 
 
+def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
+    """Return a tuple of (decoded_contents, encoding, newline).
+
+    `newline` is either CRLF or LF but `decoded_contents` is decoded with
+    universal newlines (i.e. only contains LF).
+    """
+    srcbuf = io.BytesIO(src)
+    encoding, lines = tokenize.detect_encoding(srcbuf.readline)
+    if not lines:
+        return "", encoding, "\n"
+
+    newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
+    srcbuf.seek(0)
+    with io.TextIOWrapper(srcbuf, encoding) as tiow:
+        return tiow.read(), encoding, newline
+
+
 GRAMMARS = [
     pygram.python_grammar_no_print_statement_no_exec_statement,
     pygram.python_grammar_no_print_statement,
@@ -568,9 +626,8 @@ GRAMMARS = [
 def lib2to3_parse(src_txt: str) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
-    if src_txt[-1] != "\n":
-        nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
-        src_txt += nl
+    if src_txt[-1:] != "\n":
+        src_txt += "\n"
     for grammar in GRAMMARS:
         drv = driver.Driver(grammar, pytree.convert)
         try:
@@ -714,6 +771,7 @@ UNPACKING_PARENTS = {
     syms.dictsetmaker,
     syms.listmaker,
     syms.testlist_gexp,
+    syms.testlist_star_expr,
 }
 TEST_DESCENDANTS = {
     syms.test,
@@ -1036,6 +1094,13 @@ class Line:
 
         return False
 
+    def contains_multiline_strings(self) -> bool:
+        for leaf in self.leaves:
+            if is_multiline_string(leaf):
+                return True
+
+        return False
+
     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
         """Remove trailing comma if there is one and it's safe."""
         if not (
@@ -2155,32 +2220,50 @@ def right_hand_split(
             result.append(leaf, preformatted=True)
             for comment_after in line.comments_after(leaf):
                 result.append(comment_after, preformatted=True)
-    bracket_split_succeeded_or_raise(head, body, tail)
     assert opening_bracket and closing_bracket
+    body.should_explode = should_explode(body, opening_bracket)
+    bracket_split_succeeded_or_raise(head, body, tail)
     if (
+        # the body shouldn't be exploded
+        not body.should_explode
         # the opening bracket is an optional paren
-        opening_bracket.type == token.LPAR
+        and opening_bracket.type == token.LPAR
         and not opening_bracket.value
         # the closing bracket is an optional paren
         and closing_bracket.type == token.RPAR
         and not closing_bracket.value
-        # there are no standalone comments in the body
-        and not line.contains_standalone_comments(0)
-        # and it's not an import (optional parens are the only thing we can split
-        # on in this case; attempting a split without them is a waste of time)
+        # it's not an import (optional parens are the only thing we can split on
+        # in this case; attempting a split without them is a waste of time)
         and not line.is_import
+        # there are no standalone comments in the body
+        and not body.contains_standalone_comments(0)
+        # and we can actually remove the parens
+        and can_omit_invisible_parens(body, line_length)
     ):
         omit = {id(closing_bracket), *omit}
-        if can_omit_invisible_parens(body, line_length):
-            try:
-                yield from right_hand_split(line, line_length, py36=py36, omit=omit)
-                return
-            except CannotSplit:
-                pass
+        try:
+            yield from right_hand_split(line, line_length, py36=py36, omit=omit)
+            return
+
+        except CannotSplit:
+            if not (
+                can_be_split(body)
+                or is_line_short_enough(body, line_length=line_length)
+            ):
+                raise CannotSplit(
+                    "Splitting failed, body is still too long and can't be split."
+                )
+
+            elif head.contains_multiline_strings() or tail.contains_multiline_strings():
+                raise CannotSplit(
+                    "The current optional pair of parentheses is bound to fail to "
+                    "satisfy the splitting algorithm becase the head or the tail "
+                    "contains multiline strings which by definition never fit one "
+                    "line."
+                )
 
     ensure_visible(opening_bracket)
     ensure_visible(closing_bracket)
-    body.should_explode = should_explode(body, opening_bracket)
     for result in (head, body, tail):
         if result:
             yield result
@@ -2792,21 +2875,29 @@ def get_future_imports(node: Node) -> Set[str]:
 
 
 def gen_python_files_in_dir(
-    path: Path, include: Pattern[str], exclude: Pattern[str]
+    path: Path,
+    root: Path,
+    include: Pattern[str],
+    exclude: Pattern[str],
+    report: "Report",
 ) -> Iterator[Path]:
     """Generate all files under `path` whose paths are not excluded by the
     `exclude` regex, but are included by the `include` regex.
+
+    `report` is where output about exclusions goes.
     """
+    assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
     for child in path.iterdir():
-        normalized_path = child.resolve().as_posix()
+        normalized_path = "/" + child.resolve().relative_to(root).as_posix()
         if child.is_dir():
             normalized_path += "/"
         exclude_match = exclude.search(normalized_path)
         if exclude_match and exclude_match.group(0):
+            report.path_ignored(child, f"matches --exclude={exclude.pattern}")
             continue
 
         if child.is_dir():
-            yield from gen_python_files_in_dir(child, include, exclude)
+            yield from gen_python_files_in_dir(child, root, include, exclude, report)
 
         elif child.is_file():
             include_match = include.search(normalized_path)
@@ -2814,12 +2905,42 @@ def gen_python_files_in_dir(
                 yield child
 
 
+def find_project_root(srcs: List[str]) -> Path:
+    """Return a directory containing .git, .hg, or pyproject.toml.
+
+    That directory can be one of the directories passed in `srcs` or their
+    common parent.
+
+    If no directory in the tree contains a marker that would specify it's the
+    project root, the root of the file system is returned.
+    """
+    if not srcs:
+        return Path("/").resolve()
+
+    common_base = min(Path(src).resolve() for src in srcs)
+    if common_base.is_dir():
+        # Append a fake file so `parents` below returns `common_base_dir`, too.
+        common_base /= "fake-file"
+    for directory in common_base.parents:
+        if (directory / ".git").is_dir():
+            return directory
+
+        if (directory / ".hg").is_dir():
+            return directory
+
+        if (directory / "pyproject.toml").is_file():
+            return directory
+
+    return directory
+
+
 @dataclass
 class Report:
     """Provides a reformatting counter. Can be rendered with `str(report)`."""
 
     check: bool = False
     quiet: bool = False
+    verbose: bool = False
     change_count: int = 0
     same_count: int = 0
     failure_count: int = 0
@@ -2828,11 +2949,11 @@ class Report:
         """Increment the counter for successful reformatting. Write out a message."""
         if changed is Changed.YES:
             reformatted = "would reformat" if self.check else "reformatted"
-            if not self.quiet:
+            if self.verbose or not self.quiet:
                 out(f"{reformatted} {src}")
             self.change_count += 1
         else:
-            if not self.quiet:
+            if self.verbose:
                 if changed is Changed.NO:
                     msg = f"{src} already well formatted, good job."
                 else:
@@ -2845,6 +2966,10 @@ class Report:
         err(f"error: cannot format {src}: {message}")
         self.failure_count += 1
 
+    def path_ignored(self, path: Path, message: str) -> None:
+        if self.verbose:
+            out(f"{path} ignored: {message}", bold=False)
+
     @property
     def return_code(self) -> int:
         """Return the exit code that the app should use.
@@ -3083,6 +3208,42 @@ def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") ->
     )
 
 
+def can_be_split(line: Line) -> bool:
+    """Return False if the line cannot be split *for sure*.
+
+    This is not an exhaustive search but a cheap heuristic that we can use to
+    avoid some unfortunate formattings (mostly around wrapping unsplittable code
+    in unnecessary parentheses).
+    """
+    leaves = line.leaves
+    if len(leaves) < 2:
+        return False
+
+    if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
+        call_count = 0
+        dot_count = 0
+        next = leaves[-1]
+        for leaf in leaves[-2::-1]:
+            if leaf.type in OPENING_BRACKETS:
+                if next.type not in CLOSING_BRACKETS:
+                    return False
+
+                call_count += 1
+            elif leaf.type == token.DOT:
+                dot_count += 1
+            elif leaf.type == token.NAME:
+                if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
+                    return False
+
+            elif leaf.type not in CLOSING_BRACKETS:
+                return False
+
+            if dot_count > 1 and call_count > 1:
+                return False
+
+    return True
+
+
 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     """Does `line` have a shape safe to reformat without optional parens around it?
 
@@ -3205,26 +3366,24 @@ def get_cache_info(path: Path) -> CacheInfo:
     return stat.st_mtime, stat.st_size
 
 
-def filter_cached(
-    cache: Cache, sources: Iterable[Path]
-) -> Tuple[List[Path], List[Path]]:
-    """Split a list of paths into two.
+def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
+    """Split an iterable of paths in `sources` into two sets.
 
-    The first list contains paths of files that modified on disk or are not in the
-    cache. The other list contains paths to non-modified files.
+    The first contains paths of files that modified on disk or are not in the
+    cache. The other contains paths to non-modified files.
     """
-    todo, done = [], []
+    todo, done = set(), set()
     for src in sources:
         src = src.resolve()
         if cache.get(src) != get_cache_info(src):
-            todo.append(src)
+            todo.add(src)
         else:
-            done.append(src)
+            done.add(src)
     return todo, done
 
 
 def write_cache(
-    cache: Cache, sources: List[Path], line_length: int, mode: FileMode
+    cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
 ) -> None:
     """Update the cache file."""
     cache_file = get_cache_file(line_length, mode)