git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

[trivial] Simplify stdin handling
[etc/vim.git] / black.py
index 7dc6ef86c1d4b12a4a46bc2c1cd6fe080a12af01..0dce397e768fee2973537f157b4903c3d858390d 100644 (file)
--- a/black.py
+++ b/black.py
@@ -2,7 +2,7 @@ import asyncio
 import pickle
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
-from enum import Enum
+from enum import Enum, Flag
 from functools import partial, wraps
 import keyword
 import logging
@@ -44,8 +44,12 @@ from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
 
-__version__ = "18.5b0"
+__version__ = "18.5b1"
 DEFAULT_LINE_LENGTH = 88
+DEFAULT_EXCLUDES = (
+    r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
+)
+DEFAULT_INCLUDES = r"\.pyi?$"
 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
 
 
@@ -122,6 +126,13 @@ class Changed(Enum):
     YES = 2
 
 
+class FileMode(Flag):
+    AUTO_DETECT = 0
+    PYTHON36 = 1
+    PYI = 2
+    NO_STRING_NORMALIZATION = 4
+
+
 @click.command()
 @click.option(
     "-l",
@@ -131,6 +142,29 @@ class Changed(Enum):
     help="How many character per line to allow.",
     show_default=True,
 )
+@click.option(
+    "--py36",
+    is_flag=True,
+    help=(
+        "Allow using Python 3.6-only syntax on all input files.  This will put "
+        "trailing commas in function signatures and calls also after *args and "
+        "**kwargs.  [default: per-file auto-detection]"
+    ),
+)
+@click.option(
+    "--pyi",
+    is_flag=True,
+    help=(
+        "Format all input files like typing stubs regardless of file extension "
+        "(useful when piping source on standard input)."
+    ),
+)
+@click.option(
+    "-S",
+    "--skip-string-normalization",
+    is_flag=True,
+    help="Don't normalize string quotes or prefixes.",
+)
 @click.option(
     "--check",
     is_flag=True,
@@ -151,29 +185,37 @@ class Changed(Enum):
     help="If --fast given, skip temporary sanity checks. [default: --safe]",
 )
 @click.option(
-    "-q",
-    "--quiet",
-    is_flag=True,
+    "--include",
+    type=str,
+    default=DEFAULT_INCLUDES,
     help=(
-        "Don't emit non-error messages to stderr. Errors are still emitted, "
-        "silence those with 2>/dev/null."
+        "A regular expression that matches files and directories that should be "
+        "included on recursive searches.  An empty value means all files are "
+        "included regardless of the name.  Use forward slashes for directories on "
+        "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
+        "later."
     ),
+    show_default=True,
 )
 @click.option(
-    "--pyi",
-    is_flag=True,
+    "--exclude",
+    type=str,
+    default=DEFAULT_EXCLUDES,
     help=(
-        "Consider all input files typing stubs regardless of file extension "
-        "(useful when piping source on standard input)."
+        "A regular expression that matches files and directories that should be "
+        "excluded on recursive searches.  An empty value means no paths are excluded. "
+        "Use forward slashes for directories on all platforms (Windows, too).  "
+        "Exclusions are calculated first, inclusions later."
     ),
+    show_default=True,
 )
 @click.option(
-    "--py36",
+    "-q",
+    "--quiet",
     is_flag=True,
     help=(
-        "Allow using Python 3.6-only syntax on all input files.  This will put "
-        "trailing commas in function signatures and calls also after *args and "
-        "**kwargs.  [default: per-file auto-detection]"
+        "Don't emit non-error messages to stderr. Errors are still emitted, "
+        "silence those with 2>/dev/null."
     ),
 )
 @click.version_option(version=__version__)
@@ -193,20 +235,34 @@ def main(
     fast: bool,
     pyi: bool,
     py36: bool,
+    skip_string_normalization: bool,
     quiet: bool,
+    include: str,
+    exclude: str,
     src: List[str],
 ) -> None:
     """The uncompromising code formatter."""
     sources: List[Path] = []
+    try:
+        include_regex = re.compile(include)
+    except re.error:
+        err(f"Invalid regular expression for include given: {include!r}")
+        ctx.exit(2)
+    try:
+        exclude_regex = re.compile(exclude)
+    except re.error:
+        err(f"Invalid regular expression for exclude given: {exclude!r}")
+        ctx.exit(2)
+    root = find_project_root(src)
     for s in src:
         p = Path(s)
         if p.is_dir():
-            sources.extend(gen_python_files_in_dir(p))
-        elif p.is_file():
+            sources.extend(
+                gen_python_files_in_dir(p, root, include_regex, exclude_regex)
+            )
+        elif p.is_file() or s == "-":
             # if a file was explicitly given, we don't care about its extension
             sources.append(p)
-        elif s == "-":
-            sources.append(Path("-"))
         else:
             err(f"invalid path: {s}")
 
@@ -216,6 +272,13 @@ def main(
         write_back = WriteBack.DIFF
     else:
         write_back = WriteBack.YES
+    mode = FileMode.AUTO_DETECT
+    if py36:
+        mode |= FileMode.PYTHON36
+    if pyi:
+        mode |= FileMode.PYI
+    if skip_string_normalization:
+        mode |= FileMode.NO_STRING_NORMALIZATION
     report = Report(check=check, quiet=quiet)
     if len(sources) == 0:
         out("No paths given. Nothing to do 😴")
@@ -227,9 +290,8 @@ def main(
             src=sources[0],
             line_length=line_length,
             fast=fast,
-            pyi=pyi,
-            py36=py36,
             write_back=write_back,
+            mode=mode,
             report=report,
         )
     else:
@@ -241,9 +303,8 @@ def main(
                     sources=sources,
                     line_length=line_length,
                     fast=fast,
-                    pyi=pyi,
-                    py36=py36,
                     write_back=write_back,
+                    mode=mode,
                     report=report,
                     loop=loop,
                     executor=executor,
@@ -261,9 +322,8 @@ def reformat_one(
     src: Path,
     line_length: int,
     fast: bool,
-    pyi: bool,
-    py36: bool,
     write_back: WriteBack,
+    mode: FileMode,
     report: "Report",
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
@@ -276,31 +336,26 @@ def reformat_one(
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
             if format_stdin_to_stdout(
-                line_length=line_length,
-                fast=fast,
-                is_pyi=pyi,
-                force_py36=py36,
-                write_back=write_back,
+                line_length=line_length, fast=fast, write_back=write_back, mode=mode
             ):
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
-                cache = read_cache(line_length, pyi, py36)
-                src = src.resolve()
-                if src in cache and cache[src] == get_cache_info(src):
+                cache = read_cache(line_length, mode)
+                res_src = src.resolve()
+                if res_src in cache and cache[res_src] == get_cache_info(res_src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
                 src,
                 line_length=line_length,
                 fast=fast,
-                force_pyi=pyi,
-                force_py36=py36,
                 write_back=write_back,
+                mode=mode,
             ):
                 changed = Changed.YES
             if write_back == WriteBack.YES and changed is not Changed.NO:
-                write_cache(cache, [src], line_length, pyi, py36)
+                write_cache(cache, [src], line_length, mode)
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
@@ -310,9 +365,8 @@ async def schedule_formatting(
     sources: List[Path],
     line_length: int,
     fast: bool,
-    pyi: bool,
-    py36: bool,
     write_back: WriteBack,
+    mode: FileMode,
     report: "Report",
     loop: BaseEventLoop,
     executor: Executor,
@@ -326,7 +380,7 @@ async def schedule_formatting(
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
-        cache = read_cache(line_length, pyi, py36)
+        cache = read_cache(line_length, mode)
         sources, cached = filter_cached(cache, sources)
         for src in cached:
             report.done(src, Changed.CACHED)
@@ -346,9 +400,8 @@ async def schedule_formatting(
                 src,
                 line_length,
                 fast,
-                pyi,
-                py36,
                 write_back,
+                mode,
                 lock,
             ): src
             for src in sorted(sources)
@@ -374,16 +427,15 @@ async def schedule_formatting(
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if write_back == WriteBack.YES and formatted:
-        write_cache(cache, formatted, line_length, pyi, py36)
+        write_cache(cache, formatted, line_length, mode)
 
 
 def format_file_in_place(
     src: Path,
     line_length: int,
     fast: bool,
-    force_pyi: bool = False,
-    force_py36: bool = False,
     write_back: WriteBack = WriteBack.NO,
+    mode: FileMode = FileMode.AUTO_DETECT,
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
@@ -391,17 +443,13 @@ def format_file_in_place(
     If `write_back` is True, write reformatted code back to stdout.
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
-    is_pyi = force_pyi or src.suffix == ".pyi"
-
+    if src.suffix == ".pyi":
+        mode |= FileMode.PYI
     with tokenize.open(src) as src_buffer:
         src_contents = src_buffer.read()
     try:
         dst_contents = format_file_contents(
-            src_contents,
-            line_length=line_length,
-            fast=fast,
-            is_pyi=is_pyi,
-            force_py36=force_py36,
+            src_contents, line_length=line_length, fast=fast, mode=mode
         )
     except NothingChanged:
         return False
@@ -426,9 +474,8 @@ def format_file_in_place(
 def format_stdin_to_stdout(
     line_length: int,
     fast: bool,
-    is_pyi: bool = False,
-    force_py36: bool = False,
     write_back: WriteBack = WriteBack.NO,
+    mode: FileMode = FileMode.AUTO_DETECT,
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
@@ -439,13 +486,7 @@ def format_stdin_to_stdout(
     src = sys.stdin.read()
     dst = src
     try:
-        dst = format_file_contents(
-            src,
-            line_length=line_length,
-            fast=fast,
-            is_pyi=is_pyi,
-            force_py36=force_py36,
-        )
+        dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
         return True
 
     except NothingChanged:
@@ -465,8 +506,7 @@ def format_file_contents(
     *,
     line_length: int,
     fast: bool,
-    is_pyi: bool = False,
-    force_py36: bool = False,
+    mode: FileMode = FileMode.AUTO_DETECT,
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
@@ -477,30 +517,18 @@ def format_file_contents(
     if src_contents.strip() == "":
         raise NothingChanged
 
-    dst_contents = format_str(
-        src_contents, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
-    )
+    dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
-        assert_stable(
-            src_contents,
-            dst_contents,
-            line_length=line_length,
-            is_pyi=is_pyi,
-            force_py36=force_py36,
-        )
+        assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
     return dst_contents
 
 
 def format_str(
-    src_contents: str,
-    line_length: int,
-    *,
-    is_pyi: bool = False,
-    force_py36: bool = False,
+    src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
 ) -> FileContent:
     """Reformat a string and return new contents.
 
@@ -509,11 +537,15 @@ def format_str(
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
     future_imports = get_future_imports(src_node)
-    elt = EmptyLineTracker(is_pyi=is_pyi)
-    py36 = force_py36 or is_python36(src_node)
+    is_pyi = bool(mode & FileMode.PYI)
+    py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
+    normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
     lines = LineGenerator(
-        remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
+        remove_u_prefix=py36 or "unicode_literals" in future_imports,
+        is_pyi=is_pyi,
+        normalize_strings=normalize_strings,
     )
+    elt = EmptyLineTracker(is_pyi=is_pyi)
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
@@ -989,14 +1021,11 @@ class Line:
 
     @property
     def is_triple_quoted_string(self) -> bool:
-        """Is the line a triple quoted docstring?"""
+        """Is the line a triple quoted string?"""
         return (
             bool(self)
             and self.leaves[0].type == token.STRING
-            and (
-                self.leaves[0].value.startswith('"""')
-                or self.leaves[0].value.startswith("'''")
-            )
+            and self.leaves[0].value.startswith(('"""', "'''"))
         )
 
     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
@@ -1257,9 +1286,8 @@ class EmptyLineTracker:
             if self.previous_line.is_decorator:
                 return 0, 0
 
-            if (
-                self.previous_line.is_class
-                and self.previous_line.depth != current_line.depth
+            if self.previous_line.depth < current_line.depth and (
+                self.previous_line.is_class or self.previous_line.is_def
             ):
                 return 0, 0
 
@@ -1313,6 +1341,7 @@ class LineGenerator(Visitor[Line]):
     """
 
     is_pyi: bool = False
+    normalize_strings: bool = True
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
 
@@ -1381,7 +1410,7 @@ class LineGenerator(Visitor[Line]):
 
             else:
                 normalize_prefix(node, inside_brackets=any_open_brackets)
-                if node.type == token.STRING:
+                if self.normalize_strings and node.type == token.STRING:
                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                     normalize_string_quotes(node)
                 if node.type not in WHITESPACE:
@@ -2763,33 +2792,57 @@ def get_future_imports(node: Node) -> Set[str]:
     return imports
 
 
-PYTHON_EXTENSIONS = {".py", ".pyi"}
-BLACKLISTED_DIRECTORIES = {
-    "build",
-    "buck-out",
-    "dist",
-    "_build",
-    ".git",
-    ".hg",
-    ".mypy_cache",
-    ".tox",
-    ".venv",
-}
-
-
-def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
-    """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
-    and have one of the PYTHON_EXTENSIONS.
+def gen_python_files_in_dir(
+    path: Path, root: Path, include: Pattern[str], exclude: Pattern[str]
+) -> Iterator[Path]:
+    """Generate all files under `path` whose paths are not excluded by the
+    `exclude` regex, but are included by the `include` regex.
     """
+    assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
     for child in path.iterdir():
+        normalized_path = child.resolve().relative_to(root).as_posix()
         if child.is_dir():
-            if child.name in BLACKLISTED_DIRECTORIES:
-                continue
+            normalized_path += "/"
+        exclude_match = exclude.search(normalized_path)
+        if exclude_match and exclude_match.group(0):
+            continue
+
+        if child.is_dir():
+            yield from gen_python_files_in_dir(child, root, include, exclude)
+
+        elif child.is_file():
+            include_match = include.search(normalized_path)
+            if include_match:
+                yield child
 
-            yield from gen_python_files_in_dir(child)
 
-        elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
-            yield child
+def find_project_root(srcs: List[str]) -> Path:
+    """Return a directory containing .git, .hg, or pyproject.toml.
+
+    That directory can be one of the directories passed in `srcs` or their
+    common parent.
+
+    If no directory in the tree contains a marker that would specify it's the
+    project root, the root of the file system is returned.
+    """
+    if not srcs:
+        return Path("/").resolve()
+
+    common_base = min(Path(src).resolve() for src in srcs)
+    if common_base.is_dir():
+        # Append a fake file so `parents` below returns `common_base_dir`, too.
+        common_base /= "fake-file"
+    for directory in common_base.parents:
+        if (directory / ".git").is_dir():
+            return directory
+
+        if (directory / ".hg").is_dir():
+            return directory
+
+        if (directory / "pyproject.toml").is_file():
+            return directory
+
+    return directory
 
 
 @dataclass
@@ -2936,12 +2989,10 @@ def assert_equivalent(src: str, dst: str) -> None:
 
 
 def assert_stable(
-    src: str, dst: str, line_length: int, is_pyi: bool = False, force_py36: bool = False
+    src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
 ) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
-    newdst = format_str(
-        dst, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
-    )
+    newdst = format_str(dst, line_length=line_length, mode=mode)
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
@@ -3152,19 +3203,21 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     return False
 
 
-def get_cache_file(line_length: int, pyi: bool = False, py36: bool = False) -> Path:
+def get_cache_file(line_length: int, mode: FileMode) -> Path:
+    pyi = bool(mode & FileMode.PYI)
+    py36 = bool(mode & FileMode.PYTHON36)
     return (
         CACHE_DIR
         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
     )
 
 
-def read_cache(line_length: int, pyi: bool = False, py36: bool = False) -> Cache:
+def read_cache(line_length: int, mode: FileMode) -> Cache:
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
-    cache_file = get_cache_file(line_length, pyi, py36)
+    cache_file = get_cache_file(line_length, mode)
     if not cache_file.exists():
         return {}
 
@@ -3202,14 +3255,10 @@ def filter_cached(
 
 
 def write_cache(
-    cache: Cache,
-    sources: List[Path],
-    line_length: int,
-    pyi: bool = False,
-    py36: bool = False,
+    cache: Cache, sources: List[Path], line_length: int, mode: FileMode
 ) -> None:
     """Update the cache file."""
-    cache_file = get_cache_file(line_length, pyi, py36)
+    cache_file = get_cache_file(line_length, mode)
     try:
         if not CACHE_DIR.exists():
             CACHE_DIR.mkdir(parents=True)