git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Cache generated comments
[etc/vim.git] / black.py
index cc5075c6666ffca0cc637b652a0550f02aa5a6e1..f51b10ef7d0a282e60a8064f2d871c43739cd90d 100644 (file)
--- a/black.py
+++ b/black.py
@@ -3,7 +3,7 @@ from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 from datetime import datetime
 from enum import Enum, Flag
-from functools import partial, wraps
+from functools import lru_cache, partial, wraps
 import io
 import keyword
 import logging
@@ -38,6 +38,7 @@ from typing import (
 from appdirs import user_cache_dir
 from attr import dataclass, Factory
 import click
+import toml
 
 # lib2to3 fork
 from blib2to3.pytree import Node, Leaf, type_repr
@@ -46,7 +47,7 @@ from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
 
-__version__ = "18.6b1"
+__version__ = "18.6b2"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_EXCLUDES = (
     r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
@@ -156,7 +157,41 @@ class FileMode(Flag):
         return mode
 
 
-@click.command()
+def read_pyproject_toml(
+    ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
+) -> Optional[str]:
+    """Inject Black configuration from "pyproject.toml" into defaults in `ctx`.
+
+    Returns the path to a successfully found and read configuration file, None
+    otherwise.
+    """
+    assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
+    if not value:
+        root = find_project_root(ctx.params.get("src", ()))
+        path = root / "pyproject.toml"
+        if path.is_file():
+            value = str(path)
+        else:
+            return None
+
+    try:
+        pyproject_toml = toml.load(value)
+        config = pyproject_toml.get("tool", {}).get("black", {})
+    except (toml.TomlDecodeError, OSError) as e:
+        raise click.BadOptionUsage(f"Error reading configuration file: {e}", ctx)
+
+    if not config:
+        return None
+
+    if ctx.default_map is None:
+        ctx.default_map = {}
+    ctx.default_map.update(  # type: ignore  # bad types in .pyi
+        {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
+    )
+    return value
+
+
+@click.command(context_settings=dict(help_option_names=["-h", "--help"]))
 @click.option(
     "-l",
     "--line-length",
@@ -257,6 +292,16 @@ class FileMode(Flag):
     type=click.Path(
         exists=True, file_okay=True, dir_okay=True, readable=True, allow_dash=True
     ),
+    is_eager=True,
+)
+@click.option(
+    "--config",
+    type=click.Path(
+        exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False
+    ),
+    is_eager=True,
+    callback=read_pyproject_toml,
+    help="Read configuration from PATH.",
 )
 @click.pass_context
 def main(
@@ -272,26 +317,29 @@ def main(
     verbose: bool,
     include: str,
     exclude: str,
-    src: List[str],
+    src: Tuple[str],
+    config: Optional[str],
 ) -> None:
     """The uncompromising code formatter."""
     write_back = WriteBack.from_configuration(check=check, diff=diff)
     mode = FileMode.from_configuration(
         py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
     )
-    report = Report(check=check, quiet=quiet, verbose=verbose)
-    sources: Set[Path] = set()
+    if config and verbose:
+        out(f"Using configuration from {config}.", bold=False, fg="blue")
     try:
-        include_regex = re.compile(include)
+        include_regex = re_compile_maybe_verbose(include)
     except re.error:
         err(f"Invalid regular expression for include given: {include!r}")
         ctx.exit(2)
     try:
-        exclude_regex = re.compile(exclude)
+        exclude_regex = re_compile_maybe_verbose(exclude)
     except re.error:
         err(f"Invalid regular expression for exclude given: {exclude!r}")
         ctx.exit(2)
+    report = Report(check=check, quiet=quiet, verbose=verbose)
     root = find_project_root(src)
+    sources: Set[Path] = set()
     for s in src:
         p = Path(s)
         if p.is_dir():
@@ -307,9 +355,8 @@ def main(
         if verbose or not quiet:
             out("No paths given. Nothing to do 😴")
         ctx.exit(0)
-        return
 
-    elif len(sources) == 1:
+    if len(sources) == 1:
         reformat_one(
             src=sources.pop(),
             line_length=line_length,
@@ -1183,6 +1230,9 @@ class Line:
 
         Provide a non-negative leaf `_index` to speed up the function.
         """
+        if not self.comments:
+            return
+
         if _index == -1:
             for _index, _leaf in enumerate(self.leaves):
                 if leaf is _leaf:
@@ -1206,18 +1256,18 @@ class Line:
 
     def is_complex_subscript(self, leaf: Leaf) -> bool:
         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
-        open_lsqb = (
-            leaf if leaf.type == token.LSQB else self.bracket_tracker.get_open_lsqb()
-        )
+        open_lsqb = self.bracket_tracker.get_open_lsqb()
         if open_lsqb is None:
             return False
 
         subscript_start = open_lsqb.next_sibling
-        if (
-            isinstance(subscript_start, Node)
-            and subscript_start.type == syms.subscriptlist
-        ):
-            subscript_start = child_towards(subscript_start, leaf)
+
+        if isinstance(subscript_start, Node):
+            if subscript_start.type == syms.listmaker:
+                return False
+
+            if subscript_start.type == syms.subscriptlist:
+                subscript_start = child_towards(subscript_start, leaf)
         return subscript_start is not None and any(
             n.type in TEST_DESCENDANTS for n in subscript_start.pre_order()
         )
@@ -1340,44 +1390,8 @@ class EmptyLineTracker:
                 before = 0 if depth else 1
             else:
                 before = 1 if depth else 2
-        is_decorator = current_line.is_decorator
-        if is_decorator or current_line.is_def or current_line.is_class:
-            if not is_decorator:
-                self.previous_defs.append(depth)
-            if self.previous_line is None:
-                # Don't insert empty lines before the first line in the file.
-                return 0, 0
-
-            if self.previous_line.is_decorator:
-                return 0, 0
-
-            if self.previous_line.depth < current_line.depth and (
-                self.previous_line.is_class or self.previous_line.is_def
-            ):
-                return 0, 0
-
-            if (
-                self.previous_line.is_comment
-                and self.previous_line.depth == current_line.depth
-                and before == 0
-            ):
-                return 0, 0
-
-            if self.is_pyi:
-                if self.previous_line.depth > current_line.depth:
-                    newlines = 1
-                elif current_line.is_class or self.previous_line.is_class:
-                    if current_line.is_stub_class and self.previous_line.is_stub_class:
-                        newlines = 0
-                    else:
-                        newlines = 1
-                else:
-                    newlines = 0
-            else:
-                newlines = 2
-            if current_line.depth and newlines:
-                newlines -= 1
-            return newlines, 0
+        if current_line.is_decorator or current_line.is_def or current_line.is_class:
+            return self._maybe_empty_lines_for_class_or_def(current_line, before)
 
         if (
             self.previous_line
@@ -1396,6 +1410,50 @@ class EmptyLineTracker:
 
         return before, 0
 
+    def _maybe_empty_lines_for_class_or_def(
+        self, current_line: Line, before: int
+    ) -> Tuple[int, int]:
+        if not current_line.is_decorator:
+            self.previous_defs.append(current_line.depth)
+        if self.previous_line is None:
+            # Don't insert empty lines before the first line in the file.
+            return 0, 0
+
+        if self.previous_line.is_decorator:
+            return 0, 0
+
+        if self.previous_line.depth < current_line.depth and (
+            self.previous_line.is_class or self.previous_line.is_def
+        ):
+            return 0, 0
+
+        if (
+            self.previous_line.is_comment
+            and self.previous_line.depth == current_line.depth
+            and before == 0
+        ):
+            return 0, 0
+
+        if self.is_pyi:
+            if self.previous_line.depth > current_line.depth:
+                newlines = 1
+            elif current_line.is_class or self.previous_line.is_class:
+                if current_line.is_stub_class and self.previous_line.is_stub_class:
+                    # No blank line between classes with an emty body
+                    newlines = 0
+                else:
+                    newlines = 1
+            elif current_line.is_def and not self.previous_line.is_def:
+                # Blank line between a block of functions and a block of non-functions
+                newlines = 1
+            else:
+                newlines = 0
+        else:
+            newlines = 2
+        if current_line.depth and newlines:
+            newlines -= 1
+        return newlines, 0
+
 
 @dataclass
 class LineGenerator(Visitor[Line]):
@@ -1835,7 +1893,7 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
             elif prevp.type == token.EQUAL and prevp_parent.type == syms.argument:
                 return NO
 
-        elif t == token.NAME or t == token.NUMBER:
+        elif t in {token.NAME, token.NUMBER, token.STRING}:
             return NO
 
     elif p.type == syms.import_from:
@@ -1987,6 +2045,10 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
     return 0
 
 
+FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
+FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
+
+
 def generate_comments(leaf: LN) -> Iterator[Leaf]:
     """Clean the prefix of the `leaf` and generate comments from it, if any.
 
@@ -2006,16 +2068,37 @@ def generate_comments(leaf: LN) -> Iterator[Leaf]:
     Inline comments are emitted as regular token.COMMENT leaves.  Standalone
     are emitted with a fake STANDALONE_COMMENT token identifier.
     """
-    p = leaf.prefix
-    if not p:
-        return
+    for pc in list_comments(leaf.prefix, is_endmarker=leaf.type == token.ENDMARKER):
+        yield Leaf(pc.type, pc.value, prefix="\n" * pc.newlines)
+        if pc.value in FMT_ON:
+            raise FormatOn(pc.consumed)
 
-    if "#" not in p:
-        return
+        if pc.value in FMT_OFF:
+            if pc.type == STANDALONE_COMMENT:
+                raise FormatOff(pc.consumed)
+
+            prev = preceding_leaf(leaf)
+            if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
+                raise FormatOff(pc.consumed)
+
+
+@dataclass
+class ProtoComment:
+    type: int  # token.COMMENT or STANDALONE_COMMENT
+    value: str  # content of the comment
+    newlines: int  # how many newlines before the comment
+    consumed: int  # how many characters of the original leaf's prefix did we consume
+
+
+@lru_cache(maxsize=4096)
+def list_comments(prefix: str, is_endmarker: bool) -> List[ProtoComment]:
+    result: List[ProtoComment] = []
+    if not prefix or "#" not in prefix:
+        return result
 
     consumed = 0
     nlines = 0
-    for index, line in enumerate(p.split("\n")):
+    for index, line in enumerate(prefix.split("\n")):
         consumed += len(line) + 1  # adding the length of the split '\n'
         line = line.lstrip()
         if not line:
@@ -2023,25 +2106,18 @@ def generate_comments(leaf: LN) -> Iterator[Leaf]:
         if not line.startswith("#"):
             continue
 
-        if index == 0 and leaf.type != token.ENDMARKER:
+        if index == 0 and not is_endmarker:
             comment_type = token.COMMENT  # simple trailing comment
         else:
             comment_type = STANDALONE_COMMENT
         comment = make_comment(line)
-        yield Leaf(comment_type, comment, prefix="\n" * nlines)
-
-        if comment in {"# fmt: on", "# yapf: enable"}:
-            raise FormatOn(consumed)
-
-        if comment in {"# fmt: off", "# yapf: disable"}:
-            if comment_type == STANDALONE_COMMENT:
-                raise FormatOff(consumed)
-
-            prev = preceding_leaf(leaf)
-            if not prev or prev.type in WHITESPACE:  # standalone comment in disguise
-                raise FormatOff(consumed)
-
+        result.append(
+            ProtoComment(
+                type=comment_type, value=comment, newlines=nlines, consumed=consumed
+            )
+        )
         nlines = 0
+    return result
 
 
 def make_comment(content: str) -> str:
@@ -2258,7 +2334,7 @@ def right_hand_split(
             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
                 raise CannotSplit(
                     "The current optional pair of parentheses is bound to fail to "
-                    "satisfy the splitting algorithm becase the head or the tail "
+                    "satisfy the splitting algorithm because the head or the tail "
                     "contains multiline strings which by definition never fit one "
                     "line."
                 )
@@ -2479,8 +2555,8 @@ def normalize_string_quotes(leaf: Leaf) -> None:
 
     prefix = leaf.value[:first_quote_pos]
     unescaped_new_quote = re.compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
-    escaped_new_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{new_quote}")
-    escaped_orig_quote = re.compile(rf"([^\\]|^)\\(\\\\)*{orig_quote}")
+    escaped_new_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
+    escaped_orig_quote = re.compile(rf"([^\\]|^)\\((?:\\\\)*){orig_quote}")
     body = leaf.value[first_quote_pos + len(orig_quote) : -len(orig_quote)]
     if "r" in prefix.casefold():
         if unescaped_new_quote.search(body):
@@ -2491,15 +2567,21 @@ def normalize_string_quotes(leaf: Leaf) -> None:
         # Do not introduce or remove backslashes in raw strings
         new_body = body
     else:
-        # remove unnecessary quotes
+        # remove unnecessary escapes
         new_body = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", body)
         if body != new_body:
-            # Consider the string without unnecessary quotes as the original
+            # Consider the string without unnecessary escapes as the original
             body = new_body
             leaf.value = f"{prefix}{orig_quote}{body}{orig_quote}"
         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
-    if new_quote == '"""' and new_body[-1] == '"':
+    if "f" in prefix.casefold():
+        matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
+        for m in matches:
+            if "\\" in str(m):
+                # Do not introduce backslashes in interpolated expressions
+                return
+    if new_quote == '"""' and new_body[-1:] == '"':
         # edge case:
         new_body = new_body[:-1] + '\\"'
     orig_escape_count = body.count("\\")
@@ -2522,10 +2604,10 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
     Standardizes on visible parentheses for single-element tuples, and keeps
     existing visible parentheses for other tuples and generator expressions.
     """
-    try:
-        list(generate_comments(node))
-    except FormatOff:
-        return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
+    for pc in list_comments(node.prefix, is_endmarker=False):
+        if pc.value in FMT_OFF:
+            # This `node` has a prefix with `# fmt: off`, don't mess with parens.
+            return
 
     check_lpar = False
     for index, child in enumerate(list(node.children)):
@@ -2562,7 +2644,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
 
 
 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
-    """If it's safe, make the parens in the atom `node` invisible, recusively."""
+    """If it's safe, make the parens in the atom `node` invisible, recursively."""
     if (
         node.type != syms.atom
         or is_empty_tuple(node)
@@ -2885,16 +2967,29 @@ def gen_python_files_in_dir(
     """Generate all files under `path` whose paths are not excluded by the
     `exclude` regex, but are included by the `include` regex.
 
+    Symbolic links pointing outside of the root directory are ignored.
+
     `report` is where output about exclusions goes.
     """
     assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
     for child in path.iterdir():
-        normalized_path = "/" + child.resolve().relative_to(root).as_posix()
+        try:
+            normalized_path = "/" + child.resolve().relative_to(root).as_posix()
+        except ValueError:
+            if child.is_symlink():
+                report.path_ignored(
+                    child,
+                    "is a symbolic link that points outside of the root directory",
+                )
+                continue
+
+            raise
+
         if child.is_dir():
             normalized_path += "/"
         exclude_match = exclude.search(normalized_path)
         if exclude_match and exclude_match.group(0):
-            report.path_ignored(child, f"matches --exclude={exclude.pattern}")
+            report.path_ignored(child, f"matches the --exclude regular expression")
             continue
 
         if child.is_dir():
@@ -2906,7 +3001,8 @@ def gen_python_files_in_dir(
                 yield child
 
 
-def find_project_root(srcs: List[str]) -> Path:
+@lru_cache()
+def find_project_root(srcs: Iterable[str]) -> Path:
     """Return a directory containing .git, .hg, or pyproject.toml.
 
     That directory can be one of the directories passed in `srcs` or their
@@ -3164,6 +3260,16 @@ def sub_twice(regex: Pattern[str], replacement: str, original: str) -> str:
     return regex.sub(replacement, regex.sub(replacement, original))
 
 
+def re_compile_maybe_verbose(regex: str) -> Pattern[str]:
+    """Compile a regular expression string in `regex`.
+
+    If it contains newlines, use verbose mode.
+    """
+    if "\n" in regex:
+        regex = "(?x)" + regex
+    return re.compile(regex)
+
+
 def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]:
     """Like `reversed(enumerate(sequence))` if that were possible."""
     index = len(sequence) - 1
@@ -3335,12 +3441,7 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
 
 
 def get_cache_file(line_length: int, mode: FileMode) -> Path:
-    pyi = bool(mode & FileMode.PYI)
-    py36 = bool(mode & FileMode.PYTHON36)
-    return (
-        CACHE_DIR
-        / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
-    )
+    return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
 
 
 def read_cache(line_length: int, mode: FileMode) -> Cache: