git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix string normalization eating all backslashes above 3
[etc/vim.git] / black.py
index 19a023cff52db2e92d5cc0950125e20c96d24dff..9d9bada604951bef5bc9afca90033a31411c260e 100644 (file)
--- a/black.py
+++ b/black.py
@@ -1,19 +1,20 @@
 import asyncio
-import pickle
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
+from datetime import datetime
 from enum import Enum, Flag
-from functools import partial, wraps
+from functools import lru_cache, partial, wraps
 import io
 import keyword
 import logging
 from multiprocessing import Manager
 import os
 from pathlib import Path
+import pickle
 import re
-import tokenize
 import signal
 import sys
+import tokenize
 from typing import (
     Any,
     Callable,
@@ -37,6 +38,7 @@ from typing import (
 from appdirs import user_cache_dir
 from attr import dataclass, Factory
 import click
+import toml
 
 # lib2to3 fork
 from blib2to3.pytree import Node, Leaf, type_repr
@@ -45,7 +47,7 @@ from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
 
-__version__ = "18.5b1"
+__version__ = "18.6b2"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_EXCLUDES = (
     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
@@ -57,6 +59,7 @@ CACHE_DIR = Path(user_cache_dir("black", version=__version__))
 # types
 FileContent = str
 Encoding = str
+NewLine = str
 Depth = int
 NodeType = int
 LeafID = int
@@ -154,7 +157,41 @@ class FileMode(Flag):
         return mode
 
 
-@click.command()
+def read_pyproject_toml(
+    ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
+) -> Optional[str]:
+    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
+
+    Returns the path to a successfully found and read configuration file, None
+    otherwise.
+    """
+    assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
+    if not value:
+        root = find_project_root(ctx.params.get("src", ()))
+        path = root / "pyproject.toml"
+        if path.is_file():
+            value = str(path)
+        else:
+            return None
+
+    try:
+        pyproject_toml = toml.load(value)
+        config = pyproject_toml.get("tool", {}).get("black", {})
+    except (toml.TomlDecodeError, OSError) as e:
+        raise click.BadOptionUsage(f"Error reading configuration file: {e}", ctx)
+
+    if not config:
+        return None
+
+    if ctx.default_map is None:
+        ctx.default_map = {}
+    ctx.default_map.update(  # type: ignore  # bad types in .pyi
+        {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
+    )
+    return value
+
+
+@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
 @click.option(
     "-l",
     "--line-length",
@@ -255,6 +292,16 @@ class FileMode(Flag):
     type=click.Path(
         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
     ),
+    is_eager=True,
+)
+@click.option(
+    "--config",
+    type=click.Path(
+        exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
+    ),
+    is_eager=True,
+    callback=read_pyproject_toml,
+    help="Read configuration from PATH.",
 )
 @click.pass_context
 def main(
@@ -270,26 +317,29 @@ def main(
     verbose: bool,
     include: str,
     exclude: str,
-    src: List[str],
+    src: Tuple[str],
+    config: Optional[str],
 ) -> None:
     """The uncompromising code formatter."""
     write_back = WriteBack.from_configuration(check=check, diff=diff)
     mode = FileMode.from_configuration(
         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
     )
-    report = Report(check=check, quiet=quiet, verbose=verbose)
-    sources: Set[Path] = set()
+    if config and verbose:
+        out(f"Using configuration from {config}.", bold=False, fg="blue")
     try:
-        include_regex = re.compile(include)
+        include_regex = re_compile_maybe_verbose(include)
     except re.error:
         err(f"Invalid regular expression for include given: {include!r}")
         ctx.exit(2)
     try:
-        exclude_regex = re.compile(exclude)
+        exclude_regex = re_compile_maybe_verbose(exclude)
     except re.error:
         err(f"Invalid regular expression for exclude given: {exclude!r}")
         ctx.exit(2)
+    report = Report(check=check, quiet=quiet, verbose=verbose)
     root = find_project_root(src)
+    sources: Set[Path] = set()
     for s in src:
         p = Path(s)
         if p.is_dir():
@@ -305,9 +355,8 @@ def main(
         if verbose or not quiet:
             out("No paths given. Nothing to do 😴")
         ctx.exit(0)
-        return
 
-    elif len(sources) == 1:
+    if len(sources) == 1:
         reformat_one(
             src=sources.pop(),
             line_length=line_length,
@@ -335,8 +384,9 @@ def main(
         finally:
             shutdown(loop)
     if verbose or not quiet:
-        out("All done! ✨ 🍰 ✨")
-        click.echo(str(report))
+        bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
+        out(f"All done! {bang}")
+        click.secho(str(report), err=True)
     ctx.exit(report.return_code)
 
 
@@ -468,8 +518,9 @@ def format_file_in_place(
     if src.suffix == ".pyi":
         mode |= FileMode.PYI
 
+    then = datetime.utcfromtimestamp(src.stat().st_mtime)
     with open(src, "rb") as buf:
-        newline, encoding, src_contents = prepare_input(buf.read())
+        src_contents, encoding, newline = decode_bytes(buf.read())
     try:
         dst_contents = format_file_contents(
             src_contents, line_length=line_length, fast=fast, mode=mode
@@ -481,8 +532,9 @@ def format_file_in_place(
         with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back == write_back.DIFF:
-        src_name = f"{src}  (original)"
-        dst_name = f"{src}  (formatted)"
+        now = datetime.utcnow()
+        src_name = f"{src}\t{then} +0000"
+        dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
         if lock:
             lock.acquire()
@@ -513,7 +565,8 @@ def format_stdin_to_stdout(
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
-    newline, encoding, src = prepare_input(sys.stdin.buffer.read())
+    then = datetime.utcnow()
+    src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
     dst = src
     try:
         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
@@ -523,26 +576,17 @@ def format_stdin_to_stdout(
         return False
 
     finally:
+        f = io.TextIOWrapper(
+            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
+        )
         if write_back == WriteBack.YES:
-            f = io.TextIOWrapper(
-                sys.stdout.buffer,
-                encoding=encoding,
-                newline=newline,
-                write_through=True,
-            )
             f.write(dst)
-            f.detach()
         elif write_back == WriteBack.DIFF:
-            src_name = "<stdin>  (original)"
-            dst_name = "<stdin>  (formatted)"
-            f = io.TextIOWrapper(
-                sys.stdout.buffer,
-                encoding=encoding,
-                newline=newline,
-                write_through=True,
-            )
+            now = datetime.utcnow()
+            src_name = f"STDIN\t{then} +0000"
+            dst_name = f"STDOUT\t{now} +0000"
             f.write(diff(src, dst, src_name, dst_name))
-            f.detach()
+        f.detach()
 
 
 def format_file_contents(
@@ -603,17 +647,21 @@ def format_str(
     return dst_contents
 
 
-def prepare_input(src: bytes) -> Tuple[str, str, str]:
-    """Analyze `src` and return a tuple of (newline, encoding, decoded_contents)
+def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
+    """Return a tuple of (decoded_contents, encoding, newline).
 
-    Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
-    universal newlines (i.e. only LF).
+    `newline` is either CRLF or LF but `decoded_contents` is decoded with
+    universal newlines (i.e. only contains LF).
     """
     srcbuf = io.BytesIO(src)
     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
+    if not lines:
+        return "", encoding, "\n"
+
     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
     srcbuf.seek(0)
-    return newline, encoding, io.TextIOWrapper(srcbuf, encoding).read()
+    with io.TextIOWrapper(srcbuf, encoding) as tiow:
+        return tiow.read(), encoding, newline
 
 
 GRAMMARS = [
@@ -626,7 +674,7 @@ GRAMMARS = [
 def lib2to3_parse(src_txt: str) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
-    if src_txt[-1] != "\n":
+    if src_txt[-1:] != "\n":
         src_txt += "\n"
     for grammar in GRAMMARS:
         drv = driver.Driver(grammar, pytree.convert)
@@ -771,6 +819,7 @@ UNPACKING_PARENTS = {
     syms.dictsetmaker,
     syms.listmaker,
     syms.testlist_gexp,
+    syms.testlist_star_expr,
 }
 TEST_DESCENDANTS = {
     syms.test,
@@ -1093,6 +1142,13 @@ class Line:
 
         return False
 
+    def contains_multiline_strings(self) -> bool:
+        for leaf in self.leaves:
+            if is_multiline_string(leaf):
+                return True
+
+        return False
+
     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
         """Remove trailing comma if there is one and it's safe."""
         if not (
@@ -1174,6 +1230,9 @@ class Line:
 
         Provide a non-negative leaf `_index` to speed up the function.
         """
+        if not self.comments:
+            return
+
         if _index == -1:
             for _index, _leaf in enumerate(self.leaves):
                 if leaf is _leaf:
@@ -1197,18 +1256,18 @@ class Line:
 
     def is_complex_subscript(self, leaf: Leaf) -> bool:
         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
-        open_lsqb = (
-            leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
-        )
+        open_lsqb = self.bracket_tracker.get_open_lsqb()
         if open_lsqb is None:
             return False
 
         subscript_start = open_lsqb.next_sibling
-        if (
-            isinstance(subscript_start, Node)
-            and subscript_start.type == syms.subscriptlist
-        ):
-            subscript_start = child_towards(subscript_start, leaf)
+
+        if isinstance(subscript_start, Node):
+            if subscript_start.type == syms.listmaker:
+                return False
+
+            if subscript_start.type == syms.subscriptlist:
+                subscript_start = child_towards(subscript_start, leaf)
         return subscript_start is not None and any(
             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
         )
@@ -1826,7 +1885,7 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
                 return NO
 
-        elif t == token.NAME or t == token.NUMBER:
+        elif t in {token.NAME, token.NUMBER, token.STRING}:
             return NO
 
     elif p.type == syms.import_from:
@@ -2212,32 +2271,50 @@ def right_hand_split(
             result.append(leaf, preformatted=True)
             for comment_after in line.comments_after(leaf):
                 result.append(comment_after, preformatted=True)
-    bracket_split_succeeded_or_raise(head, body, tail)
     assert opening_bracket and closing_bracket
+    body.should_explode = should_explode(body, opening_bracket)
+    bracket_split_succeeded_or_raise(head, body, tail)
     if (
+        # the body shouldn't be exploded
+        not body.should_explode
         # the opening bracket is an optional paren
-        opening_bracket.type == token.LPAR
+        and opening_bracket.type == token.LPAR
         and not opening_bracket.value
         # the closing bracket is an optional paren
         and closing_bracket.type == token.RPAR
         and not closing_bracket.value
-        # there are no standalone comments in the body
-        and not line.contains_standalone_comments(0)
-        # and it's not an import (optional parens are the only thing we can split
-        # on in this case; attempting a split without them is a waste of time)
+        # it's not an import (optional parens are the only thing we can split on
+        # in this case; attempting a split without them is a waste of time)
         and not line.is_import
+        # there are no standalone comments in the body
+        and not body.contains_standalone_comments(0)
+        # and we can actually remove the parens
+        and can_omit_invisible_parens(body, line_length)
     ):
         omit = {id(closing_bracket), *omit}
-        if can_omit_invisible_parens(body, line_length):
-            try:
-                yield from right_hand_split(line, line_length, py36=py36, omit=omit)
-                return
-            except CannotSplit:
-                pass
+        try:
+            yield from right_hand_split(line, line_length, py36=py36, omit=omit)
+            return
+
+        except CannotSplit:
+            if not (
+                can_be_split(body)
+                or is_line_short_enough(body, line_length=line_length)
+            ):
+                raise CannotSplit(
+                    "Splitting failed, body is still too long and can't be split."
+                )
+
+            elif head.contains_multiline_strings() or tail.contains_multiline_strings():
+                raise CannotSplit(
+                    "The current optional pair of parentheses is bound to fail to "
+                    "satisfy the splitting algorithm because the head or the tail "
+                    "contains multiline strings which by definition never fit one "
+                    "line."
+                )
 
     ensure_visible(opening_bracket)
     ensure_visible(closing_bracket)
-    body.should_explode = should_explode(body, opening_bracket)
     for result in (head, body, tail):
         if result:
             yield result
@@ -2452,8 +2529,8 @@ def normalize_string_quotes(leaf: Leaf) -> None:
 
     prefix = leaf.value[:first_quote_pos]
     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
-    escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
-    escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
+    escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
+    escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
     if "r" in prefix.casefold():
         if unescaped_new_quote.search(body):
@@ -2464,15 +2541,21 @@ def normalize_string_quotes(leaf: Leaf) -> None:
         # Do not introduce or remove backslashes in raw strings
         new_body = body
     else:
-        # remove unnecessary quotes
+        # remove unnecessary escapes
         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
         if body != new_body:
-            # Consider the string without unnecessary quotes as the original
+            # Consider the string without unnecessary escapes as the original
             body = new_body
             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
-    if new_quote == '"""' and new_body[-1] == '"':
+    if "f" in prefix.casefold():
+        matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
+        for m in matches:
+            if "\\" in str(m):
+                # Do not introduce backslashes in interpolated expressions
+                return
+    if new_quote == '"""' and new_body[-1:] == '"':
         # edge case:
         new_body = new_body[:-1] + '\\"'
     orig_escape_count = body.count("\\")
@@ -2535,7 +2618,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
 
 
 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
-    """If it's safe, make the parens in the atom `node` invisible, recusively."""
+    """If it's safe, make the parens in the atom `node` invisible, recursively."""
     if (
         node.type != syms.atom
         or is_empty_tuple(node)
@@ -2867,7 +2950,7 @@ def gen_python_files_in_dir(
             normalized_path += "/"
         exclude_match = exclude.search(normalized_path)
         if exclude_match and exclude_match.group(0):
-            report.path_ignored(child, f"matches --exclude={exclude.pattern}")
+            report.path_ignored(child, f"matches the --exclude regular expression")
             continue
 
         if child.is_dir():
@@ -2879,7 +2962,8 @@ def gen_python_files_in_dir(
                 yield child
 
 
-def find_project_root(srcs: List[str]) -> Path:
+@lru_cache()
+def find_project_root(srcs: Iterable[str]) -> Path:
     """Return a directory containing .git, .hg, or pyproject.toml.
 
     That directory can be one of the directories passed in `srcs` or their
@@ -3137,6 +3221,16 @@ def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
     return regex.sub(replacement, regex.sub(replacement, original))
 
 
+def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
+    """Compile a regular expression string in `regex`.
+
+    If it contains newlines, use verbose mode.
+    """
+    if "\n" in regex:
+        regex = "(?x)" + regex
+    return re.compile(regex)
+
+
 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
     """Like `reversed(enumerate(sequence))` if that were possible."""
     index = len(sequence) - 1
@@ -3182,6 +3276,42 @@ def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") ->
     )
 
 
+def can_be_split(line: Line) -> bool:
+    """Return False if the line cannot be split *for sure*.
+
+    This is not an exhaustive search but a cheap heuristic that we can use to
+    avoid some unfortunate formattings (mostly around wrapping unsplittable code
+    in unnecessary parentheses).
+    """
+    leaves = line.leaves
+    if len(leaves) < 2:
+        return False
+
+    if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
+        call_count = 0
+        dot_count = 0
+        next = leaves[-1]
+        for leaf in leaves[-2::-1]:
+            if leaf.type in OPENING_BRACKETS:
+                if next.type not in CLOSING_BRACKETS:
+                    return False
+
+                call_count += 1
+            elif leaf.type == token.DOT:
+                dot_count += 1
+            elif leaf.type == token.NAME:
+                if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
+                    return False
+
+            elif leaf.type not in CLOSING_BRACKETS:
+                return False
+
+            if dot_count > 1 and call_count > 1:
+                return False
+
+    return True
+
+
 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     """Does `line` have a shape safe to reformat without optional parens around it?
 
@@ -3272,12 +3402,7 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
 
 
 def get_cache_file(line_length: int, mode: FileMode) -> Path:
-    pyi = bool(mode & FileMode.PYI)
-    py36 = bool(mode & FileMode.PYTHON36)
-    return (
-        CACHE_DIR
-        / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
-    )
+    return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
 
 
 def read_cache(line_length: int, mode: FileMode) -> Cache: