git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Always show summary of reformatting
[etc/vim.git] / black.py
index 547751b789700fde7d7e1d88223b6a19806659e2..19a023cff52db2e92d5cc0950125e20c96d24dff 100644 (file)
--- a/black.py
+++ b/black.py
@@ -4,6 +4,7 @@ from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 from enum import Enum, Flag
 from functools import partial, wraps
 from concurrent.futures import Executor, ProcessPoolExecutor
 from enum import Enum, Flag
 from functools import partial, wraps
+import io
 import keyword
 import logging
 from multiprocessing import Manager
 import keyword
 import logging
 from multiprocessing import Manager
@@ -44,8 +45,12 @@ from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
 
 from blib2to3.pgen2.parse import ParseError
 
 
-__version__ = "18.5b0"
+__version__ = "18.5b1"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_LINE_LENGTH = 88
+DEFAULT_EXCLUDES = (
+    r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
+)
+DEFAULT_INCLUDES = r"\.pyi?$"
 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
 
 
 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
 
 
@@ -115,6 +120,13 @@ class WriteBack(Enum):
     YES = 1
     DIFF = 2
 
     YES = 1
     DIFF = 2
 
+    @classmethod
+    def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
+        if check and not diff:
+            return cls.NO
+
+        return cls.DIFF if diff else cls.YES
+
 
 class Changed(Enum):
     NO = 0
 
 class Changed(Enum):
     NO = 0
@@ -126,6 +138,20 @@ class FileMode(Flag):
     AUTO_DETECT = 0
     PYTHON36 = 1
     PYI = 2
     AUTO_DETECT = 0
     PYTHON36 = 1
     PYI = 2
+    NO_STRING_NORMALIZATION = 4
+
+    @classmethod
+    def from_configuration(
+        cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
+    ) -> "FileMode":
+        mode = cls.AUTO_DETECT
+        if py36:
+            mode |= cls.PYTHON36
+        if pyi:
+            mode |= cls.PYI
+        if skip_string_normalization:
+            mode |= cls.NO_STRING_NORMALIZATION
+        return mode
 
 
 @click.command()
 
 
 @click.command()
@@ -137,6 +163,29 @@ class FileMode(Flag):
     help="How many character per line to allow.",
     show_default=True,
 )
     help="How many character per line to allow.",
     show_default=True,
 )
+@click.option(
+    "--py36",
+    is_flag=True,
+    help=(
+        "Allow using Python 3.6-only syntax on all input files.  This will put "
+        "trailing commas in function signatures and calls also after *args and "
+        "**kwargs.  [default: per-file auto-detection]"
+    ),
+)
+@click.option(
+    "--pyi",
+    is_flag=True,
+    help=(
+        "Format all input files like typing stubs regardless of file extension "
+        "(useful when piping source on standard input)."
+    ),
+)
+@click.option(
+    "-S",
+    "--skip-string-normalization",
+    is_flag=True,
+    help="Don't normalize string quotes or prefixes.",
+)
 @click.option(
     "--check",
     is_flag=True,
 @click.option(
     "--check",
     is_flag=True,
@@ -157,29 +206,46 @@ class FileMode(Flag):
     help="If --fast given, skip temporary sanity checks. [default: --safe]",
 )
 @click.option(
     help="If --fast given, skip temporary sanity checks. [default: --safe]",
 )
 @click.option(
-    "-q",
-    "--quiet",
-    is_flag=True,
+    "--include",
+    type=str,
+    default=DEFAULT_INCLUDES,
     help=(
     help=(
-        "Don't emit non-error messages to stderr. Errors are still emitted, "
-        "silence those with 2>/dev/null."
+        "A regular expression that matches files and directories that should be "
+        "included on recursive searches.  An empty value means all files are "
+        "included regardless of the name.  Use forward slashes for directories on "
+        "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
+        "later."
     ),
     ),
+    show_default=True,
 )
 @click.option(
 )
 @click.option(
-    "--pyi",
+    "--exclude",
+    type=str,
+    default=DEFAULT_EXCLUDES,
+    help=(
+        "A regular expression that matches files and directories that should be "
+        "excluded on recursive searches.  An empty value means no paths are excluded. "
+        "Use forward slashes for directories on all platforms (Windows, too).  "
+        "Exclusions are calculated first, inclusions later."
+    ),
+    show_default=True,
+)
+@click.option(
+    "-q",
+    "--quiet",
     is_flag=True,
     help=(
     is_flag=True,
     help=(
-        "Consider all input files typing stubs regardless of file extension "
-        "(useful when piping source on standard input)."
+        "Don't emit non-error messages to stderr. Errors are still emitted, "
+        "silence those with 2>/dev/null."
     ),
 )
 @click.option(
     ),
 )
 @click.option(
-    "--py36",
+    "-v",
+    "--verbose",
     is_flag=True,
     help=(
     is_flag=True,
     help=(
-        "Allow using Python 3.6-only syntax on all input files.  This will put "
-        "trailing commas in function signatures and calls also after *args and "
-        "**kwargs.  [default: per-file auto-detection]"
+        "Also emit messages to stderr about files that were not changed or were "
+        "ignored due to --exclude=."
     ),
 )
 @click.version_option(version=__version__)
     ),
 )
 @click.version_option(version=__version__)
@@ -199,43 +265,51 @@ def main(
     fast: bool,
     pyi: bool,
     py36: bool,
     fast: bool,
     pyi: bool,
     py36: bool,
+    skip_string_normalization: bool,
     quiet: bool,
     quiet: bool,
+    verbose: bool,
+    include: str,
+    exclude: str,
     src: List[str],
 ) -> None:
     """The uncompromising code formatter."""
     src: List[str],
 ) -> None:
     """The uncompromising code formatter."""
-    sources: List[Path] = []
+    write_back = WriteBack.from_configuration(check=check, diff=diff)
+    mode = FileMode.from_configuration(
+        py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
+    )
+    report = Report(check=check, quiet=quiet, verbose=verbose)
+    sources: Set[Path] = set()
+    try:
+        include_regex = re.compile(include)
+    except re.error:
+        err(f"Invalid regular expression for include given: {include!r}")
+        ctx.exit(2)
+    try:
+        exclude_regex = re.compile(exclude)
+    except re.error:
+        err(f"Invalid regular expression for exclude given: {exclude!r}")
+        ctx.exit(2)
+    root = find_project_root(src)
     for s in src:
         p = Path(s)
         if p.is_dir():
     for s in src:
         p = Path(s)
         if p.is_dir():
-            sources.extend(gen_python_files_in_dir(p))
-        elif p.is_file():
+            sources.update(
+                gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
+            )
+        elif p.is_file() or s == "-":
             # if a file was explicitly given, we don't care about its extension
             # if a file was explicitly given, we don't care about its extension
-            sources.append(p)
-        elif s == "-":
-            sources.append(Path("-"))
+            sources.add(p)
         else:
             err(f"invalid path: {s}")
         else:
             err(f"invalid path: {s}")
-
-    if check and not diff:
-        write_back = WriteBack.NO
-    elif diff:
-        write_back = WriteBack.DIFF
-    else:
-        write_back = WriteBack.YES
-    mode = FileMode.AUTO_DETECT
-    if py36:
-        mode |= FileMode.PYTHON36
-    if pyi:
-        mode |= FileMode.PYI
-    report = Report(check=check, quiet=quiet)
     if len(sources) == 0:
     if len(sources) == 0:
-        out("No paths given. Nothing to do 😴")
+        if verbose or not quiet:
+            out("No paths given. Nothing to do 😴")
         ctx.exit(0)
         return
 
     elif len(sources) == 1:
         reformat_one(
         ctx.exit(0)
         return
 
     elif len(sources) == 1:
         reformat_one(
-            src=sources[0],
+            src=sources.pop(),
             line_length=line_length,
             fast=fast,
             write_back=write_back,
             line_length=line_length,
             fast=fast,
             write_back=write_back,
@@ -260,9 +334,9 @@ def main(
             )
         finally:
             shutdown(loop)
             )
         finally:
             shutdown(loop)
-        if not quiet:
-            out("All done! ✨ 🍰 ✨")
-            click.echo(str(report))
+    if verbose or not quiet:
+        out("All done! ✨ 🍰 ✨")
+        click.echo(str(report))
     ctx.exit(report.return_code)
 
 
     ctx.exit(report.return_code)
 
 
@@ -291,8 +365,8 @@ def reformat_one(
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
                 cache = read_cache(line_length, mode)
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
                 cache = read_cache(line_length, mode)
-                src = src.resolve()
-                if src in cache and cache[src] == get_cache_info(src):
+                res_src = src.resolve()
+                if res_src in cache and cache[res_src] == get_cache_info(res_src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
                 src,
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
                 src,
@@ -310,7 +384,7 @@ def reformat_one(
 
 
 async def schedule_formatting(
 
 
 async def schedule_formatting(
-    sources: List[Path],
+    sources: Set[Path],
     line_length: int,
     fast: bool,
     write_back: WriteBack,
     line_length: int,
     fast: bool,
     write_back: WriteBack,
@@ -330,7 +404,7 @@ async def schedule_formatting(
     if write_back != WriteBack.DIFF:
         cache = read_cache(line_length, mode)
         sources, cached = filter_cached(cache, sources)
     if write_back != WriteBack.DIFF:
         cache = read_cache(line_length, mode)
         sources, cached = filter_cached(cache, sources)
-        for src in cached:
+        for src in sorted(cached):
             report.done(src, Changed.CACHED)
     cancelled = []
     formatted = []
             report.done(src, Changed.CACHED)
     cancelled = []
     formatted = []
@@ -393,8 +467,9 @@ def format_file_in_place(
     """
     if src.suffix == ".pyi":
         mode |= FileMode.PYI
     """
     if src.suffix == ".pyi":
         mode |= FileMode.PYI
-    with tokenize.open(src) as src_buffer:
-        src_contents = src_buffer.read()
+
+    with open(src, "rb") as buf:
+        newline, encoding, src_contents = prepare_input(buf.read())
     try:
         dst_contents = format_file_contents(
             src_contents, line_length=line_length, fast=fast, mode=mode
     try:
         dst_contents = format_file_contents(
             src_contents, line_length=line_length, fast=fast, mode=mode
@@ -403,7 +478,7 @@ def format_file_in_place(
         return False
 
     if write_back == write_back.YES:
         return False
 
     if write_back == write_back.YES:
-        with open(src, "w", encoding=src_buffer.encoding) as f:
+        with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back == write_back.DIFF:
         src_name = f"{src}  (original)"
             f.write(dst_contents)
     elif write_back == write_back.DIFF:
         src_name = f"{src}  (original)"
@@ -412,7 +487,14 @@ def format_file_in_place(
         if lock:
             lock.acquire()
         try:
         if lock:
             lock.acquire()
         try:
-            sys.stdout.write(diff_contents)
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(diff_contents)
+            f.detach()
         finally:
             if lock:
                 lock.release()
         finally:
             if lock:
                 lock.release()
@@ -431,7 +513,7 @@ def format_stdin_to_stdout(
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
     `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
     :func:`format_file_contents`.
     """
-    src = sys.stdin.read()
+    newline, encoding, src = prepare_input(sys.stdin.buffer.read())
     dst = src
     try:
         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
     dst = src
     try:
         dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
@@ -442,11 +524,25 @@ def format_stdin_to_stdout(
 
     finally:
         if write_back == WriteBack.YES:
 
     finally:
         if write_back == WriteBack.YES:
-            sys.stdout.write(dst)
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(dst)
+            f.detach()
         elif write_back == WriteBack.DIFF:
             src_name = "<stdin>  (original)"
             dst_name = "<stdin>  (formatted)"
         elif write_back == WriteBack.DIFF:
             src_name = "<stdin>  (original)"
             dst_name = "<stdin>  (formatted)"
-            sys.stdout.write(diff(src, dst, src_name, dst_name))
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(diff(src, dst, src_name, dst_name))
+            f.detach()
 
 
 def format_file_contents(
 
 
 def format_file_contents(
@@ -487,8 +583,11 @@ def format_str(
     future_imports = get_future_imports(src_node)
     is_pyi = bool(mode & FileMode.PYI)
     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
     future_imports = get_future_imports(src_node)
     is_pyi = bool(mode & FileMode.PYI)
     py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
+    normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
     lines = LineGenerator(
     lines = LineGenerator(
-        remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
+        remove_u_prefix=py36 or "unicode_literals" in future_imports,
+        is_pyi=is_pyi,
+        normalize_strings=normalize_strings,
     )
     elt = EmptyLineTracker(is_pyi=is_pyi)
     empty_line = Line()
     )
     elt = EmptyLineTracker(is_pyi=is_pyi)
     empty_line = Line()
@@ -504,6 +603,19 @@ def format_str(
     return dst_contents
 
 
     return dst_contents
 
 
+def prepare_input(src: bytes) -> Tuple[str, str, str]:
+    """Analyze `src` and return a tuple of (newline, encoding, decoded_contents)
+
+    Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
+    universal newlines (i.e. only LF).
+    """
+    srcbuf = io.BytesIO(src)
+    encoding, lines = tokenize.detect_encoding(srcbuf.readline)
+    newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
+    srcbuf.seek(0)
+    return newline, encoding, io.TextIOWrapper(srcbuf, encoding).read()
+
+
 GRAMMARS = [
     pygram.python_grammar_no_print_statement_no_exec_statement,
     pygram.python_grammar_no_print_statement,
 GRAMMARS = [
     pygram.python_grammar_no_print_statement_no_exec_statement,
     pygram.python_grammar_no_print_statement,
@@ -515,8 +627,7 @@ def lib2to3_parse(src_txt: str) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
     if src_txt[-1] != "\n":
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
     if src_txt[-1] != "\n":
-        nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
-        src_txt += nl
+        src_txt += "\n"
     for grammar in GRAMMARS:
         drv = driver.Driver(grammar, pytree.convert)
         try:
     for grammar in GRAMMARS:
         drv = driver.Driver(grammar, pytree.convert)
         try:
@@ -1286,6 +1397,7 @@ class LineGenerator(Visitor[Line]):
     """
 
     is_pyi: bool = False
     """
 
     is_pyi: bool = False
+    normalize_strings: bool = True
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
 
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
 
@@ -1354,7 +1466,7 @@ class LineGenerator(Visitor[Line]):
 
             else:
                 normalize_prefix(node, inside_brackets=any_open_brackets)
 
             else:
                 normalize_prefix(node, inside_brackets=any_open_brackets)
-                if node.type == token.STRING:
+                if self.normalize_strings and node.type == token.STRING:
                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                     normalize_string_quotes(node)
                 if node.type not in WHITESPACE:
                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                     normalize_string_quotes(node)
                 if node.type not in WHITESPACE:
@@ -2736,33 +2848,64 @@ def get_future_imports(node: Node) -> Set[str]:
     return imports
 
 
     return imports
 
 
-PYTHON_EXTENSIONS = {".py", ".pyi"}
-BLACKLISTED_DIRECTORIES = {
-    "build",
-    "buck-out",
-    "dist",
-    "_build",
-    ".git",
-    ".hg",
-    ".mypy_cache",
-    ".tox",
-    ".venv",
-}
-
+def gen_python_files_in_dir(
+    path: Path,
+    root: Path,
+    include: Pattern[str],
+    exclude: Pattern[str],
+    report: "Report",
+) -> Iterator[Path]:
+    """Generate all files under `path` whose paths are not excluded by the
+    `exclude` regex, but are included by the `include` regex.
 
 
-def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
-    """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
-    and have one of the PYTHON_EXTENSIONS.
+    `report` is where output about exclusions goes.
     """
     """
+    assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
     for child in path.iterdir():
     for child in path.iterdir():
+        normalized_path = "/" + child.resolve().relative_to(root).as_posix()
         if child.is_dir():
         if child.is_dir():
-            if child.name in BLACKLISTED_DIRECTORIES:
-                continue
+            normalized_path += "/"
+        exclude_match = exclude.search(normalized_path)
+        if exclude_match and exclude_match.group(0):
+            report.path_ignored(child, f"matches --exclude={exclude.pattern}")
+            continue
+
+        if child.is_dir():
+            yield from gen_python_files_in_dir(child, root, include, exclude, report)
 
 
-            yield from gen_python_files_in_dir(child)
+        elif child.is_file():
+            include_match = include.search(normalized_path)
+            if include_match:
+                yield child
 
 
-        elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
-            yield child
+
+def find_project_root(srcs: List[str]) -> Path:
+    """Return a directory containing .git, .hg, or pyproject.toml.
+
+    That directory can be one of the directories passed in `srcs` or their
+    common parent.
+
+    If no directory in the tree contains a marker that would specify it's the
+    project root, the root of the file system is returned.
+    """
+    if not srcs:
+        return Path("/").resolve()
+
+    common_base = min(Path(src).resolve() for src in srcs)
+    if common_base.is_dir():
+        # Append a fake file so `parents` below returns `common_base_dir`, too.
+        common_base /= "fake-file"
+    for directory in common_base.parents:
+        if (directory / ".git").is_dir():
+            return directory
+
+        if (directory / ".hg").is_dir():
+            return directory
+
+        if (directory / "pyproject.toml").is_file():
+            return directory
+
+    return directory
 
 
 @dataclass
 
 
 @dataclass
@@ -2771,6 +2914,7 @@ class Report:
 
     check: bool = False
     quiet: bool = False
 
     check: bool = False
     quiet: bool = False
+    verbose: bool = False
     change_count: int = 0
     same_count: int = 0
     failure_count: int = 0
     change_count: int = 0
     same_count: int = 0
     failure_count: int = 0
@@ -2779,11 +2923,11 @@ class Report:
         """Increment the counter for successful reformatting. Write out a message."""
         if changed is Changed.YES:
             reformatted = "would reformat" if self.check else "reformatted"
         """Increment the counter for successful reformatting. Write out a message."""
         if changed is Changed.YES:
             reformatted = "would reformat" if self.check else "reformatted"
-            if not self.quiet:
+            if self.verbose or not self.quiet:
                 out(f"{reformatted} {src}")
             self.change_count += 1
         else:
                 out(f"{reformatted} {src}")
             self.change_count += 1
         else:
-            if not self.quiet:
+            if self.verbose:
                 if changed is Changed.NO:
                     msg = f"{src} already well formatted, good job."
                 else:
                 if changed is Changed.NO:
                     msg = f"{src} already well formatted, good job."
                 else:
@@ -2796,6 +2940,10 @@ class Report:
         err(f"error: cannot format {src}: {message}")
         self.failure_count += 1
 
         err(f"error: cannot format {src}: {message}")
         self.failure_count += 1
 
+    def path_ignored(self, path: Path, message: str) -> None:
+        if self.verbose:
+            out(f"{path} ignored: {message}", bold=False)
+
     @property
     def return_code(self) -> int:
         """Return the exit code that the app should use.
     @property
     def return_code(self) -> int:
         """Return the exit code that the app should use.
@@ -3156,26 +3304,24 @@ def get_cache_info(path: Path) -> CacheInfo:
     return stat.st_mtime, stat.st_size
 
 
     return stat.st_mtime, stat.st_size
 
 
-def filter_cached(
-    cache: Cache, sources: Iterable[Path]
-) -> Tuple[List[Path], List[Path]]:
-    """Split a list of paths into two.
+def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
+    """Split an iterable of paths in `sources` into two sets.
 
 
-    The first list contains paths of files that modified on disk or are not in the
-    cache. The other list contains paths to non-modified files.
+    The first contains paths of files that modified on disk or are not in the
+    cache. The other contains paths to non-modified files.
     """
     """
-    todo, done = [], []
+    todo, done = set(), set()
     for src in sources:
         src = src.resolve()
         if cache.get(src) != get_cache_info(src):
     for src in sources:
         src = src.resolve()
         if cache.get(src) != get_cache_info(src):
-            todo.append(src)
+            todo.add(src)
         else:
         else:
-            done.append(src)
+            done.add(src)
     return todo, done
 
 
 def write_cache(
     return todo, done
 
 
 def write_cache(
-    cache: Cache, sources: List[Path], line_length: int, mode: FileMode
+    cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
 ) -> None:
     """Update the cache file."""
     cache_file = get_cache_file(line_length, mode)
 ) -> None:
     """Update the cache file."""
     cache_file = get_cache_file(line_length, mode)