X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/e5452a6b676c161d01ae0ac6cbb5a7cc4c395745..fac4cf995c16441dc2e37bc002484d85329f06d6:/black.py?ds=sidebyside

diff --git a/black.py b/black.py
index c8c381c..dd7fe39 100644
--- a/black.py
+++ b/black.py
@@ -1,19 +1,20 @@
 import asyncio
-import pickle
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
+from datetime import datetime
 from enum import Enum, Flag
-from functools import partial, wraps
+from functools import lru_cache, partial, wraps
 import io
 import keyword
 import logging
 from multiprocessing import Manager
 import os
 from pathlib import Path
+import pickle
 import re
-import tokenize
 import signal
 import sys
+import tokenize
 from typing import (
     Any,
     Callable,
@@ -37,6 +38,7 @@ from typing import (
 from appdirs import user_cache_dir
 from attr import dataclass, Factory
 import click
+import toml
 
 # lib2to3 fork
 from blib2to3.pytree import Node, Leaf, type_repr
@@ -45,7 +47,7 @@ from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
 
-__version__ = "18.5b1"
+__version__ = "18.6b1"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_EXCLUDES = (
     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
@@ -57,6 +59,7 @@ CACHE_DIR = Path(user_cache_dir("black", version=__version__))
 # types
 FileContent = str
 Encoding = str
+NewLine = str
 Depth = int
 NodeType = int
 LeafID = int
@@ -154,6 +157,40 @@ class FileMode(Flag):
         return mode
 
 
+def read_pyproject_toml(
+    ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
+) -> Optional[str]:
+    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
+
+    Returns the path to a successfully found and read configuration file, None
+    otherwise.
+    """
+    assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
+    if not value:
+        root = find_project_root(ctx.params.get("src", ()))
+        path = root / "pyproject.toml"
+        if path.is_file():
+            value = str(path)
+        else:
+            return None
+
+    try:
+        pyproject_toml = toml.load(value)
+        config = pyproject_toml.get("tool", {}).get("black", {})
+    except (toml.TomlDecodeError, OSError) as e:
+        raise click.BadOptionUsage(f"Error reading configuration file: {e}", ctx)
+
+    if not config:
+        return None
+
+    if ctx.default_map is None:
+        ctx.default_map = {}
+    ctx.default_map.update(  # type: ignore  # bad types in .pyi
+        {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
+    )
+    return value
+
+
 @click.command()
 @click.option(
     "-l",
@@ -255,6 +292,16 @@ class FileMode(Flag):
     type=click.Path(
         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
     ),
+    is_eager=True,
+)
+@click.option(
+    "--config",
+    type=click.Path(
+        exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
+    ),
+    is_eager=True,
+    callback=read_pyproject_toml,
+    help="Read configuration from PATH.",
 )
 @click.pass_context
 def main(
@@ -270,46 +317,48 @@ def main(
     verbose: bool,
     include: str,
     exclude: str,
-    src: List[str],
+    src: Tuple[str],
+    config: Optional[str],
 ) -> None:
     """The uncompromising code formatter."""
     write_back = WriteBack.from_configuration(check=check, diff=diff)
     mode = FileMode.from_configuration(
         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
     )
-    report = Report(check=check, quiet=quiet, verbose=verbose)
-    sources: List[Path] = []
+    if config and verbose:
+        out(f"Using configuration from {config}.", bold=False, fg="blue")
     try:
-        include_regex = re.compile(include)
+        include_regex = re_compile_maybe_verbose(include)
     except re.error:
         err(f"Invalid regular expression for include given: {include!r}")
         ctx.exit(2)
     try:
-        exclude_regex = re.compile(exclude)
+        exclude_regex = re_compile_maybe_verbose(exclude)
     except re.error:
         err(f"Invalid regular expression for exclude given: {exclude!r}")
         ctx.exit(2)
+    report = Report(check=check, quiet=quiet, verbose=verbose)
     root = find_project_root(src)
+    sources: Set[Path] = set()
     for s in src:
         p = Path(s)
         if p.is_dir():
-            sources.extend(
+            sources.update(
                 gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
             )
         elif p.is_file() or s == "-":
             # if a file was explicitly given, we don't care about its extension
-            sources.append(p)
+            sources.add(p)
         else:
             err(f"invalid path: {s}")
     if len(sources) == 0:
         if verbose or not quiet:
             out("No paths given. Nothing to do 😴")
         ctx.exit(0)
-        return
 
-    elif len(sources) == 1:
+    if len(sources) == 1:
         reformat_one(
-            src=sources[0],
+            src=sources.pop(),
             line_length=line_length,
             fast=fast,
             write_back=write_back,
@@ -334,9 +383,10 @@ def main(
             )
         finally:
             shutdown(loop)
-        if verbose or not quiet:
-            out("All done! ✨ 🍰 ✨")
-            click.echo(str(report))
+    if verbose or not quiet:
+        bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
+        out(f"All done! {bang}")
+        click.secho(str(report), err=True)
     ctx.exit(report.return_code)
 
 
@@ -384,7 +434,7 @@ def reformat_one(
 
 
 async def schedule_formatting(
-    sources: List[Path],
+    sources: Set[Path],
     line_length: int,
     fast: bool,
     write_back: WriteBack,
@@ -404,7 +454,7 @@ async def schedule_formatting(
     if write_back != WriteBack.DIFF:
         cache = read_cache(line_length, mode)
         sources, cached = filter_cached(cache, sources)
-        for src in cached:
+        for src in sorted(cached):
             report.done(src, Changed.CACHED)
     cancelled = []
     formatted = []
@@ -468,8 +518,9 @@ def format_file_in_place(
     if src.suffix == ".pyi":
         mode |= FileMode.PYI
 
+    then = datetime.utcfromtimestamp(src.stat().st_mtime)
     with open(src, "rb") as buf:
-        newline, encoding, src_contents = prepare_input(buf.read())
+        src_contents, encoding, newline = decode_bytes(buf.read())
     try:
         dst_contents = format_file_contents(
             src_contents, line_length=line_length, fast=fast, mode=mode
@@ -481,8 +532,9 @@ def format_file_in_place(
         with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back == write_back.DIFF:
-        src_name = f"{src}  (original)"
-        dst_name = f"{src}  (formatted)"
+        now = datetime.utcnow()
+        src_name = f"{src}\t{then} +0000"
+        dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
         if lock:
             lock.acquire()
@@ -513,7 +565,8 @@ def format_stdin_to_stdout(
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
-    newline, encoding, src = prepare_input(sys.stdin.buffer.read())
+    then = datetime.utcnow()
+    src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
     dst = src
     try:
         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
@@ -523,26 +576,17 @@ def format_stdin_to_stdout(
         return False
 
     finally:
+        f = io.TextIOWrapper(
+            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
+        )
         if write_back == WriteBack.YES:
-            f = io.TextIOWrapper(
-                sys.stdout.buffer,
-                encoding=encoding,
-                newline=newline,
-                write_through=True,
-            )
             f.write(dst)
-            f.detach()
         elif write_back == WriteBack.DIFF:
-            src_name = "<stdin>  (original)"
-            dst_name = "<stdin>  (formatted)"
-            f = io.TextIOWrapper(
-                sys.stdout.buffer,
-                encoding=encoding,
-                newline=newline,
-                write_through=True,
-            )
+            now = datetime.utcnow()
+            src_name = f"STDIN\t{then} +0000"
+            dst_name = f"STDOUT\t{now} +0000"
             f.write(diff(src, dst, src_name, dst_name))
-            f.detach()
+        f.detach()
 
 
 def format_file_contents(
@@ -603,17 +647,21 @@ def format_str(
     return dst_contents
 
 
-def prepare_input(src: bytes) -> Tuple[str, str, str]:
-    """Analyze `src` and return a tuple of (newline, encoding, decoded_contents)
+def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
+    """Return a tuple of (decoded_contents, encoding, newline).
 
-    Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
-    universal newlines (i.e. only LF).
+    `newline` is either CRLF or LF but `decoded_contents` is decoded with
+    universal newlines (i.e. only contains LF).
     """
     srcbuf = io.BytesIO(src)
     encoding, lines = tokenize.detect_encoding(srcbuf.readline)
+    if not lines:
+        return "", encoding, "\n"
+
     newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
     srcbuf.seek(0)
-    return newline, encoding, io.TextIOWrapper(srcbuf, encoding).read()
+    with io.TextIOWrapper(srcbuf, encoding) as tiow:
+        return tiow.read(), encoding, newline
 
 
 GRAMMARS = [
@@ -626,7 +674,7 @@ GRAMMARS = [
 def lib2to3_parse(src_txt: str) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
-    if src_txt[-1] != "\n":
+    if src_txt[-1:] != "\n":
         src_txt += "\n"
     for grammar in GRAMMARS:
         drv = driver.Driver(grammar, pytree.convert)
@@ -771,6 +819,7 @@ UNPACKING_PARENTS = {
     syms.dictsetmaker,
     syms.listmaker,
     syms.testlist_gexp,
+    syms.testlist_star_expr,
 }
 TEST_DESCENDANTS = {
     syms.test,
@@ -1093,6 +1142,13 @@ class Line:
 
         return False
 
+    def contains_multiline_strings(self) -> bool:
+        for leaf in self.leaves:
+            if is_multiline_string(leaf):
+                return True
+
+        return False
+
     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
         """Remove trailing comma if there is one and it's safe."""
         if not (
@@ -1826,7 +1882,7 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
                 return NO
 
-        elif t == token.NAME or t == token.NUMBER:
+        elif t in {token.NAME, token.NUMBER, token.STRING}:
             return NO
 
     elif p.type == syms.import_from:
@@ -2212,32 +2268,50 @@ def right_hand_split(
             result.append(leaf, preformatted=True)
             for comment_after in line.comments_after(leaf):
                 result.append(comment_after, preformatted=True)
-    bracket_split_succeeded_or_raise(head, body, tail)
     assert opening_bracket and closing_bracket
+    body.should_explode = should_explode(body, opening_bracket)
+    bracket_split_succeeded_or_raise(head, body, tail)
     if (
+        # the body shouldn't be exploded
+        not body.should_explode
         # the opening bracket is an optional paren
-        opening_bracket.type == token.LPAR
+        and opening_bracket.type == token.LPAR
         and not opening_bracket.value
         # the closing bracket is an optional paren
         and closing_bracket.type == token.RPAR
         and not closing_bracket.value
-        # there are no standalone comments in the body
-        and not line.contains_standalone_comments(0)
-        # and it's not an import (optional parens are the only thing we can split
-        # on in this case; attempting a split without them is a waste of time)
+        # it's not an import (optional parens are the only thing we can split on
+        # in this case; attempting a split without them is a waste of time)
         and not line.is_import
+        # there are no standalone comments in the body
+        and not body.contains_standalone_comments(0)
+        # and we can actually remove the parens
+        and can_omit_invisible_parens(body, line_length)
     ):
         omit = {id(closing_bracket), *omit}
-        if can_omit_invisible_parens(body, line_length):
-            try:
-                yield from right_hand_split(line, line_length, py36=py36, omit=omit)
-                return
-            except CannotSplit:
-                pass
+        try:
+            yield from right_hand_split(line, line_length, py36=py36, omit=omit)
+            return
+
+        except CannotSplit:
+            if not (
+                can_be_split(body)
+                or is_line_short_enough(body, line_length=line_length)
+            ):
+                raise CannotSplit(
+                    "Splitting failed, body is still too long and can't be split."
+                )
+
+            elif head.contains_multiline_strings() or tail.contains_multiline_strings():
+                raise CannotSplit(
+                    "The current optional pair of parentheses is bound to fail to "
+                    "satisfy the splitting algorithm because the head or the tail "
+                    "contains multiline strings which by definition never fit one "
+                    "line."
+                )
 
     ensure_visible(opening_bracket)
     ensure_visible(closing_bracket)
-    body.should_explode = should_explode(body, opening_bracket)
     for result in (head, body, tail):
         if result:
             yield result
@@ -2535,7 +2609,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
 
 
 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
-    """If it's safe, make the parens in the atom `node` invisible, recusively."""
+    """If it's safe, make the parens in the atom `node` invisible, recursively."""
     if (
         node.type != syms.atom
         or is_empty_tuple(node)
@@ -2867,7 +2941,7 @@ def gen_python_files_in_dir(
             normalized_path += "/"
         exclude_match = exclude.search(normalized_path)
         if exclude_match and exclude_match.group(0):
-            report.path_ignored(child, f"matches --exclude={exclude.pattern}")
+            report.path_ignored(child, f"matches the --exclude regular expression")
             continue
 
         if child.is_dir():
@@ -2879,7 +2953,8 @@ def gen_python_files_in_dir(
                 yield child
 
 
-def find_project_root(srcs: List[str]) -> Path:
+@lru_cache()
+def find_project_root(srcs: Iterable[str]) -> Path:
     """Return a directory containing .git, .hg, or pyproject.toml.
 
     That directory can be one of the directories passed in `srcs` or their
@@ -3137,6 +3212,16 @@ def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
     return regex.sub(replacement, regex.sub(replacement, original))
 
 
+def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
+    """Compile a regular expression string in `regex`.
+
+    If it contains newlines, use verbose mode.
+    """
+    if "\n" in regex:
+        regex = "(?x)" + regex
+    return re.compile(regex)
+
+
 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
     """Like `reversed(enumerate(sequence))` if that were possible."""
     index = len(sequence) - 1
@@ -3182,6 +3267,42 @@ def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") ->
     )
 
 
+def can_be_split(line: Line) -> bool:
+    """Return False if the line cannot be split *for sure*.
+
+    This is not an exhaustive search but a cheap heuristic that we can use to
+    avoid some unfortunate formattings (mostly around wrapping unsplittable code
+    in unnecessary parentheses).
+    """
+    leaves = line.leaves
+    if len(leaves) < 2:
+        return False
+
+    if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
+        call_count = 0
+        dot_count = 0
+        next = leaves[-1]
+        for leaf in leaves[-2::-1]:
+            if leaf.type in OPENING_BRACKETS:
+                if next.type not in CLOSING_BRACKETS:
+                    return False
+
+                call_count += 1
+            elif leaf.type == token.DOT:
+                dot_count += 1
+            elif leaf.type == token.NAME:
+                if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
+                    return False
+
+            elif leaf.type not in CLOSING_BRACKETS:
+                return False
+
+            if dot_count > 1 and call_count > 1:
+                return False
+
+    return True
+
+
 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     """Does `line` have a shape safe to reformat without optional parens around it?
 
@@ -3272,12 +3393,7 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
 
 
 def get_cache_file(line_length: int, mode: FileMode) -> Path:
-    pyi = bool(mode & FileMode.PYI)
-    py36 = bool(mode & FileMode.PYTHON36)
-    return (
-        CACHE_DIR
-        / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
-    )
+    return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
 
 
 def read_cache(line_length: int, mode: FileMode) -> Cache:
@@ -3304,26 +3420,24 @@ def get_cache_info(path: Path) -> CacheInfo:
     return stat.st_mtime, stat.st_size
 
 
-def filter_cached(
-    cache: Cache, sources: Iterable[Path]
-) -> Tuple[List[Path], List[Path]]:
-    """Split a list of paths into two.
+def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
+    """Split an iterable of paths in `sources` into two sets.
 
-    The first list contains paths of files that modified on disk or are not in the
-    cache. The other list contains paths to non-modified files.
+    The first contains paths of files that modified on disk or are not in the
+    cache. The other contains paths to non-modified files.
     """
-    todo, done = [], []
+    todo, done = set(), set()
     for src in sources:
         src = src.resolve()
         if cache.get(src) != get_cache_info(src):
-            todo.append(src)
+            todo.add(src)
         else:
-            done.append(src)
+            done.add(src)
     return todo, done
 
 
 def write_cache(
-    cache: Cache, sources: List[Path], line_length: int, mode: FileMode
+    cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
 ) -> None:
     """Update the cache file."""
     cache_file = get_cache_file(line_length, mode)