git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

List the Python extension for VS Code as an editor integration (#308)
[etc/vim.git] / black.py
index 6dda93b45c38f24024afad1d6a5b3251e559a677..d048162e3629cdabfc99a7d6ab60c1fe70a37b75 100644 (file)
--- a/black.py
+++ b/black.py
@@ -1,18 +1,20 @@
 import asyncio
 import asyncio
-import pickle
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
-from enum import Enum
+from datetime import datetime
+from enum import Enum, Flag
 from functools import partial, wraps
 from functools import partial, wraps
+import io
 import keyword
 import logging
 from multiprocessing import Manager
 import os
 from pathlib import Path
 import keyword
 import logging
 from multiprocessing import Manager
 import os
 from pathlib import Path
+import pickle
 import re
 import re
-import tokenize
 import signal
 import sys
 import signal
 import sys
+import tokenize
 from typing import (
     Any,
     Callable,
 from typing import (
     Any,
     Callable,
@@ -44,13 +46,19 @@ from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
 
 from blib2to3.pgen2.parse import ParseError
 
 
-__version__ = "18.4a6"
+__version__ = "18.6b1"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_LINE_LENGTH = 88
+DEFAULT_EXCLUDES = (
+    r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
+)
+DEFAULT_INCLUDES = r"\.pyi?$"
+CACHE_DIR = Path(user_cache_dir("black", version=__version__))
+
 
 # types
 
 # types
-syms = pygram.python_symbols
 FileContent = str
 Encoding = str
 FileContent = str
 Encoding = str
+NewLine = str
 Depth = int
 NodeType = int
 LeafID = int
 Depth = int
 NodeType = int
 LeafID = int
@@ -65,6 +73,9 @@ Cache = Dict[Path, CacheInfo]
 out = partial(click.secho, bold=True, err=True)
 err = partial(click.secho, fg="red", err=True)
 
 out = partial(click.secho, bold=True, err=True)
 err = partial(click.secho, fg="red", err=True)
 
+pygram.initialize(CACHE_DIR)
+syms = pygram.python_symbols
+
 
 class NothingChanged(UserWarning):
     """Raised by :func:`format_file` when reformatted code is the same as source."""
 
 class NothingChanged(UserWarning):
     """Raised by :func:`format_file` when reformatted code is the same as source."""
@@ -111,6 +122,13 @@ class WriteBack(Enum):
     YES = 1
     DIFF = 2
 
     YES = 1
     DIFF = 2
 
+    @classmethod
+    def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack":
+        if check and not diff:
+            return cls.NO
+
+        return cls.DIFF if diff else cls.YES
+
 
 class Changed(Enum):
     NO = 0
 
 class Changed(Enum):
     NO = 0
@@ -118,6 +136,26 @@ class Changed(Enum):
     YES = 2
 
 
     YES = 2
 
 
+class FileMode(Flag):
+    AUTO_DETECT = 0
+    PYTHON36 = 1
+    PYI = 2
+    NO_STRING_NORMALIZATION = 4
+
+    @classmethod
+    def from_configuration(
+        cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
+    ) -> "FileMode":
+        mode = cls.AUTO_DETECT
+        if py36:
+            mode |= cls.PYTHON36
+        if pyi:
+            mode |= cls.PYI
+        if skip_string_normalization:
+            mode |= cls.NO_STRING_NORMALIZATION
+        return mode
+
+
 @click.command()
 @click.option(
     "-l",
 @click.command()
 @click.option(
     "-l",
@@ -127,6 +165,29 @@ class Changed(Enum):
     help="How many character per line to allow.",
     show_default=True,
 )
     help="How many character per line to allow.",
     show_default=True,
 )
+@click.option(
+    "--py36",
+    is_flag=True,
+    help=(
+        "Allow using Python 3.6-only syntax on all input files.  This will put "
+        "trailing commas in function signatures and calls also after *args and "
+        "**kwargs.  [default: per-file auto-detection]"
+    ),
+)
+@click.option(
+    "--pyi",
+    is_flag=True,
+    help=(
+        "Format all input files like typing stubs regardless of file extension "
+        "(useful when piping source on standard input)."
+    ),
+)
+@click.option(
+    "-S",
+    "--skip-string-normalization",
+    is_flag=True,
+    help="Don't normalize string quotes or prefixes.",
+)
 @click.option(
     "--check",
     is_flag=True,
 @click.option(
     "--check",
     is_flag=True,
@@ -146,6 +207,31 @@ class Changed(Enum):
     is_flag=True,
     help="If --fast given, skip temporary sanity checks. [default: --safe]",
 )
     is_flag=True,
     help="If --fast given, skip temporary sanity checks. [default: --safe]",
 )
+@click.option(
+    "--include",
+    type=str,
+    default=DEFAULT_INCLUDES,
+    help=(
+        "A regular expression that matches files and directories that should be "
+        "included on recursive searches.  An empty value means all files are "
+        "included regardless of the name.  Use forward slashes for directories on "
+        "all platforms (Windows, too).  Exclusions are calculated first, inclusions "
+        "later."
+    ),
+    show_default=True,
+)
+@click.option(
+    "--exclude",
+    type=str,
+    default=DEFAULT_EXCLUDES,
+    help=(
+        "A regular expression that matches files and directories that should be "
+        "excluded on recursive searches.  An empty value means no paths are excluded. "
+        "Use forward slashes for directories on all platforms (Windows, too).  "
+        "Exclusions are calculated first, inclusions later."
+    ),
+    show_default=True,
+)
 @click.option(
     "-q",
     "--quiet",
 @click.option(
     "-q",
     "--quiet",
@@ -155,6 +241,15 @@ class Changed(Enum):
         "silence those with 2>/dev/null."
     ),
 )
         "silence those with 2>/dev/null."
     ),
 )
+@click.option(
+    "-v",
+    "--verbose",
+    is_flag=True,
+    help=(
+        "Also emit messages to stderr about files that were not changed or were "
+        "ignored due to --exclude=."
+    ),
+)
 @click.version_option(version=__version__)
 @click.argument(
     "src",
 @click.version_option(version=__version__)
 @click.argument(
     "src",
@@ -170,92 +265,133 @@ def main(
     check: bool,
     diff: bool,
     fast: bool,
     check: bool,
     diff: bool,
     fast: bool,
+    pyi: bool,
+    py36: bool,
+    skip_string_normalization: bool,
     quiet: bool,
     quiet: bool,
+    verbose: bool,
+    include: str,
+    exclude: str,
     src: List[str],
 ) -> None:
     """The uncompromising code formatter."""
     src: List[str],
 ) -> None:
     """The uncompromising code formatter."""
-    sources: List[Path] = []
+    write_back = WriteBack.from_configuration(check=check, diff=diff)
+    mode = FileMode.from_configuration(
+        py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
+    )
+    report = Report(check=check, quiet=quiet, verbose=verbose)
+    sources: Set[Path] = set()
+    try:
+        include_regex = re.compile(include)
+    except re.error:
+        err(f"Invalid regular expression for include given: {include!r}")
+        ctx.exit(2)
+    try:
+        exclude_regex = re.compile(exclude)
+    except re.error:
+        err(f"Invalid regular expression for exclude given: {exclude!r}")
+        ctx.exit(2)
+    root = find_project_root(src)
     for s in src:
         p = Path(s)
         if p.is_dir():
     for s in src:
         p = Path(s)
         if p.is_dir():
-            sources.extend(gen_python_files_in_dir(p))
-        elif p.is_file():
+            sources.update(
+                gen_python_files_in_dir(p, root, include_regex, exclude_regex, report)
+            )
+        elif p.is_file() or s == "-":
             # if a file was explicitly given, we don't care about its extension
             # if a file was explicitly given, we don't care about its extension
-            sources.append(p)
-        elif s == "-":
-            sources.append(Path("-"))
+            sources.add(p)
         else:
             err(f"invalid path: {s}")
         else:
             err(f"invalid path: {s}")
-
-    if check and not diff:
-        write_back = WriteBack.NO
-    elif diff:
-        write_back = WriteBack.DIFF
-    else:
-        write_back = WriteBack.YES
-    report = Report(check=check, quiet=quiet)
     if len(sources) == 0:
     if len(sources) == 0:
-        out("No paths given. Nothing to do 😴")
+        if verbose or not quiet:
+            out("No paths given. Nothing to do 😴")
         ctx.exit(0)
         return
 
     elif len(sources) == 1:
         ctx.exit(0)
         return
 
     elif len(sources) == 1:
-        reformat_one(sources[0], line_length, fast, write_back, report)
+        reformat_one(
+            src=sources.pop(),
+            line_length=line_length,
+            fast=fast,
+            write_back=write_back,
+            mode=mode,
+            report=report,
+        )
     else:
         loop = asyncio.get_event_loop()
         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
         try:
             loop.run_until_complete(
                 schedule_formatting(
     else:
         loop = asyncio.get_event_loop()
         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
         try:
             loop.run_until_complete(
                 schedule_formatting(
-                    sources, line_length, fast, write_back, report, loop, executor
+                    sources=sources,
+                    line_length=line_length,
+                    fast=fast,
+                    write_back=write_back,
+                    mode=mode,
+                    report=report,
+                    loop=loop,
+                    executor=executor,
                 )
             )
         finally:
             shutdown(loop)
                 )
             )
         finally:
             shutdown(loop)
-        if not quiet:
-            out("All done! ✨ 🍰 ✨")
-            click.echo(str(report))
+    if verbose or not quiet:
+        bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
+        out(f"All done! {bang}")
+        click.secho(str(report), err=True)
     ctx.exit(report.return_code)
 
 
 def reformat_one(
     ctx.exit(report.return_code)
 
 
 def reformat_one(
-    src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
+    src: Path,
+    line_length: int,
+    fast: bool,
+    write_back: WriteBack,
+    mode: FileMode,
+    report: "Report",
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
     If `quiet` is True, non-error messages are not output. `line_length`,
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
     If `quiet` is True, non-error messages are not output. `line_length`,
-    `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
+    `write_back`, `fast` and `pyi` options are passed to
+    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
     """
     try:
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
             if format_stdin_to_stdout(
     """
     try:
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
             if format_stdin_to_stdout(
-                line_length=line_length, fast=fast, write_back=write_back
+                line_length=line_length, fast=fast, write_back=write_back, mode=mode
             ):
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
             ):
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
-                cache = read_cache(line_length)
-                src = src.resolve()
-                if src in cache and cache[src] == get_cache_info(src):
+                cache = read_cache(line_length, mode)
+                res_src = src.resolve()
+                if res_src in cache and cache[res_src] == get_cache_info(res_src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
-                src, line_length=line_length, fast=fast, write_back=write_back
+                src,
+                line_length=line_length,
+                fast=fast,
+                write_back=write_back,
+                mode=mode,
             ):
                 changed = Changed.YES
             if write_back == WriteBack.YES and changed is not Changed.NO:
             ):
                 changed = Changed.YES
             if write_back == WriteBack.YES and changed is not Changed.NO:
-                write_cache(cache, [src], line_length)
+                write_cache(cache, [src], line_length, mode)
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
 
 
 async def schedule_formatting(
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
 
 
 async def schedule_formatting(
-    sources: List[Path],
+    sources: Set[Path],
     line_length: int,
     fast: bool,
     write_back: WriteBack,
     line_length: int,
     fast: bool,
     write_back: WriteBack,
+    mode: FileMode,
     report: "Report",
     loop: BaseEventLoop,
     executor: Executor,
     report: "Report",
     loop: BaseEventLoop,
     executor: Executor,
@@ -264,14 +400,14 @@ async def schedule_formatting(
 
     (Use ProcessPoolExecutors for actual parallelism.)
 
 
     (Use ProcessPoolExecutors for actual parallelism.)
 
-    `line_length`, `write_back`, and `fast` options are passed to
+    `line_length`, `write_back`, `fast`, and `pyi` options are passed to
     :func:`format_file_in_place`.
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
     :func:`format_file_in_place`.
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
-        cache = read_cache(line_length)
+        cache = read_cache(line_length, mode)
         sources, cached = filter_cached(cache, sources)
         sources, cached = filter_cached(cache, sources)
-        for src in cached:
+        for src in sorted(cached):
             report.done(src, Changed.CACHED)
     cancelled = []
     formatted = []
             report.done(src, Changed.CACHED)
     cancelled = []
     formatted = []
@@ -284,7 +420,14 @@ async def schedule_formatting(
             lock = manager.Lock()
         tasks = {
             loop.run_in_executor(
             lock = manager.Lock()
         tasks = {
             loop.run_in_executor(
-                executor, format_file_in_place, src, line_length, fast, write_back, lock
+                executor,
+                format_file_in_place,
+                src,
+                line_length,
+                fast,
+                write_back,
+                mode,
+                lock,
             ): src
             for src in sorted(sources)
         }
             ): src
             for src in sorted(sources)
         }
@@ -309,7 +452,7 @@ async def schedule_formatting(
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if write_back == WriteBack.YES and formatted:
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if write_back == WriteBack.YES and formatted:
-        write_cache(cache, formatted, line_length)
+        write_cache(cache, formatted, line_length, mode)
 
 
 def format_file_in_place(
 
 
 def format_file_in_place(
@@ -317,6 +460,7 @@ def format_file_in_place(
     line_length: int,
     fast: bool,
     write_back: WriteBack = WriteBack.NO,
     line_length: int,
     fast: bool,
     write_back: WriteBack = WriteBack.NO,
+    mode: FileMode = FileMode.AUTO_DETECT,
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
@@ -324,28 +468,38 @@ def format_file_in_place(
     If `write_back` is True, write reformatted code back to stdout.
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
     If `write_back` is True, write reformatted code back to stdout.
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
-    is_pyi = src.suffix == ".pyi"
+    if src.suffix == ".pyi":
+        mode |= FileMode.PYI
 
 
-    with tokenize.open(src) as src_buffer:
-        src_contents = src_buffer.read()
+    then = datetime.utcfromtimestamp(src.stat().st_mtime)
+    with open(src, "rb") as buf:
+        src_contents, encoding, newline = decode_bytes(buf.read())
     try:
         dst_contents = format_file_contents(
     try:
         dst_contents = format_file_contents(
-            src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
+            src_contents, line_length=line_length, fast=fast, mode=mode
         )
     except NothingChanged:
         return False
 
     if write_back == write_back.YES:
         )
     except NothingChanged:
         return False
 
     if write_back == write_back.YES:
-        with open(src, "w", encoding=src_buffer.encoding) as f:
+        with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
     elif write_back == write_back.DIFF:
             f.write(dst_contents)
     elif write_back == write_back.DIFF:
-        src_name = f"{src}  (original)"
-        dst_name = f"{src}  (formatted)"
+        now = datetime.utcnow()
+        src_name = f"{src}\t{then} +0000"
+        dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
         if lock:
             lock.acquire()
         try:
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
         if lock:
             lock.acquire()
         try:
-            sys.stdout.write(diff_contents)
+            f = io.TextIOWrapper(
+                sys.stdout.buffer,
+                encoding=encoding,
+                newline=newline,
+                write_through=True,
+            )
+            f.write(diff_contents)
+            f.detach()
         finally:
             if lock:
                 lock.release()
         finally:
             if lock:
                 lock.release()
@@ -353,33 +507,47 @@ def format_file_in_place(
 
 
 def format_stdin_to_stdout(
 
 
 def format_stdin_to_stdout(
-    line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
+    line_length: int,
+    fast: bool,
+    write_back: WriteBack = WriteBack.NO,
+    mode: FileMode = FileMode.AUTO_DETECT,
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
     If `write_back` is True, write reformatted code back to stdout.
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
     If `write_back` is True, write reformatted code back to stdout.
-    `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
+    `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
+    :func:`format_file_contents`.
     """
     """
-    src = sys.stdin.read()
+    then = datetime.utcnow()
+    src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
     dst = src
     try:
     dst = src
     try:
-        dst = format_file_contents(src, line_length=line_length, fast=fast)
+        dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
         return True
 
     except NothingChanged:
         return False
 
     finally:
         return True
 
     except NothingChanged:
         return False
 
     finally:
+        f = io.TextIOWrapper(
+            sys.stdout.buffer, encoding=encoding, newline=newline, write_through=True
+        )
         if write_back == WriteBack.YES:
         if write_back == WriteBack.YES:
-            sys.stdout.write(dst)
+            f.write(dst)
         elif write_back == WriteBack.DIFF:
         elif write_back == WriteBack.DIFF:
-            src_name = "<stdin>  (original)"
-            dst_name = "<stdin>  (formatted)"
-            sys.stdout.write(diff(src, dst, src_name, dst_name))
+            now = datetime.utcnow()
+            src_name = f"STDIN\t{then} +0000"
+            dst_name = f"STDOUT\t{now} +0000"
+            f.write(diff(src, dst, src_name, dst_name))
+        f.detach()
 
 
 def format_file_contents(
 
 
 def format_file_contents(
-    src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
+    src_contents: str,
+    *,
+    line_length: int,
+    fast: bool,
+    mode: FileMode = FileMode.AUTO_DETECT,
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
@@ -390,20 +558,18 @@ def format_file_contents(
     if src_contents.strip() == "":
         raise NothingChanged
 
     if src_contents.strip() == "":
         raise NothingChanged
 
-    dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
+    dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
-        assert_stable(
-            src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
-        )
+        assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
     return dst_contents
 
 
 def format_str(
     return dst_contents
 
 
 def format_str(
-    src_contents: str, line_length: int, *, is_pyi: bool = False
+    src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
 ) -> FileContent:
     """Reformat a string and return new contents.
 
 ) -> FileContent:
     """Reformat a string and return new contents.
 
@@ -412,11 +578,15 @@ def format_str(
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
     future_imports = get_future_imports(src_node)
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
     future_imports = get_future_imports(src_node)
-    elt = EmptyLineTracker(is_pyi=is_pyi)
-    py36 = is_python36(src_node)
+    is_pyi = bool(mode & FileMode.PYI)
+    py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
+    normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
     lines = LineGenerator(
     lines = LineGenerator(
-        remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
+        remove_u_prefix=py36 or "unicode_literals" in future_imports,
+        is_pyi=is_pyi,
+        normalize_strings=normalize_strings,
     )
     )
+    elt = EmptyLineTracker(is_pyi=is_pyi)
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
@@ -430,6 +600,23 @@ def format_str(
     return dst_contents
 
 
     return dst_contents
 
 
+def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
+    """Return a tuple of (decoded_contents, encoding, newline).
+
+    `newline` is either CRLF or LF but `decoded_contents` is decoded with
+    universal newlines (i.e. only contains LF).
+    """
+    srcbuf = io.BytesIO(src)
+    encoding, lines = tokenize.detect_encoding(srcbuf.readline)
+    if not lines:
+        return "", encoding, "\n"
+
+    newline = "\r\n" if b"\r\n" == lines[0][-2:] else "\n"
+    srcbuf.seek(0)
+    with io.TextIOWrapper(srcbuf, encoding) as tiow:
+        return tiow.read(), encoding, newline
+
+
 GRAMMARS = [
     pygram.python_grammar_no_print_statement_no_exec_statement,
     pygram.python_grammar_no_print_statement,
 GRAMMARS = [
     pygram.python_grammar_no_print_statement_no_exec_statement,
     pygram.python_grammar_no_print_statement,
@@ -440,9 +627,8 @@ GRAMMARS = [
 def lib2to3_parse(src_txt: str) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
 def lib2to3_parse(src_txt: str) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     grammar = pygram.python_grammar_no_print_statement
-    if src_txt[-1] != "\n":
-        nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
-        src_txt += nl
+    if src_txt[-1:] != "\n":
+        src_txt += "\n"
     for grammar in GRAMMARS:
         drv = driver.Driver(grammar, pytree.convert)
         try:
     for grammar in GRAMMARS:
         drv = driver.Driver(grammar, pytree.convert)
         try:
@@ -586,6 +772,7 @@ UNPACKING_PARENTS = {
     syms.dictsetmaker,
     syms.listmaker,
     syms.testlist_gexp,
     syms.dictsetmaker,
     syms.listmaker,
     syms.testlist_gexp,
+    syms.testlist_star_expr,
 }
 TEST_DESCENDANTS = {
     syms.test,
 }
 TEST_DESCENDANTS = {
     syms.test,
@@ -874,27 +1061,6 @@ class Line:
             and second_leaf.value == "def"
         )
 
             and second_leaf.value == "def"
         )
 
-    @property
-    def is_flow_control(self) -> bool:
-        """Is this line a flow control statement?
-
-        Those are `return`, `raise`, `break`, and `continue`.
-        """
-        return (
-            bool(self)
-            and self.leaves[0].type == token.NAME
-            and self.leaves[0].value in FLOW_CONTROL
-        )
-
-    @property
-    def is_yield(self) -> bool:
-        """Is this line a yield statement?"""
-        return (
-            bool(self)
-            and self.leaves[0].type == token.NAME
-            and self.leaves[0].value == "yield"
-        )
-
     @property
     def is_class_paren_empty(self) -> bool:
         """Is this a class with no base classes but using parentheses?
     @property
     def is_class_paren_empty(self) -> bool:
         """Is this a class with no base classes but using parentheses?
@@ -911,6 +1077,15 @@ class Line:
             and self.leaves[3].value == ")"
         )
 
             and self.leaves[3].value == ")"
         )
 
+    @property
+    def is_triple_quoted_string(self) -> bool:
+        """Is the line a triple quoted string?"""
+        return (
+            bool(self)
+            and self.leaves[0].type == token.STRING
+            and self.leaves[0].value.startswith(('"""', "'''"))
+        )
+
     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
         """If so, needs to be split before emitting."""
         for leaf in self.leaves:
     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
         """If so, needs to be split before emitting."""
         for leaf in self.leaves:
@@ -920,6 +1095,13 @@ class Line:
 
         return False
 
 
         return False
 
+    def contains_multiline_strings(self) -> bool:
+        for leaf in self.leaves:
+            if is_multiline_string(leaf):
+                return True
+
+        return False
+
     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
         """Remove trailing comma if there is one and it's safe."""
         if not (
     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
         """Remove trailing comma if there is one and it's safe."""
         if not (
@@ -1118,6 +1300,7 @@ class EmptyLineTracker:
     the prefix of the first leaf consists of optional newlines.  Those newlines
     are consumed by `maybe_empty_lines()` and included in the computation.
     """
     the prefix of the first leaf consists of optional newlines.  Those newlines
     are consumed by `maybe_empty_lines()` and included in the computation.
     """
+
     is_pyi: bool = False
     previous_line: Optional[Line] = None
     previous_after: int = 0
     is_pyi: bool = False
     previous_line: Optional[Line] = None
     previous_after: int = 0
@@ -1127,8 +1310,7 @@ class EmptyLineTracker:
         """Return the number of extra empty lines before and after the `current_line`.
 
         This is for separating `def`, `async def` and `class` with extra empty
         """Return the number of extra empty lines before and after the `current_line`.
 
         This is for separating `def`, `async def` and `class` with extra empty
-        lines (two on module-level), as well as providing an extra empty line
-        after flow control keywords to make them more prominent.
+        lines (two on module-level).
         """
         if isinstance(current_line, UnformattedLines):
             return 0, 0
         """
         if isinstance(current_line, UnformattedLines):
             return 0, 0
@@ -1169,6 +1351,11 @@ class EmptyLineTracker:
             if self.previous_line.is_decorator:
                 return 0, 0
 
             if self.previous_line.is_decorator:
                 return 0, 0
 
+            if self.previous_line.depth < current_line.depth and (
+                self.previous_line.is_class or self.previous_line.is_def
+            ):
+                return 0, 0
+
             if (
                 self.previous_line.is_comment
                 and self.previous_line.depth == current_line.depth
             if (
                 self.previous_line.is_comment
                 and self.previous_line.depth == current_line.depth
@@ -1200,6 +1387,13 @@ class EmptyLineTracker:
         ):
             return (before or 1), 0
 
         ):
             return (before or 1), 0
 
+        if (
+            self.previous_line
+            and self.previous_line.is_class
+            and current_line.is_triple_quoted_string
+        ):
+            return before, 1
+
         return before, 0
 
 
         return before, 0
 
 
@@ -1210,7 +1404,9 @@ class LineGenerator(Visitor[Line]):
     Note: destroys the tree it's visiting by mutating prefixes of its leaves
     in ways that will no longer stringify to valid Python code on the tree.
     """
     Note: destroys the tree it's visiting by mutating prefixes of its leaves
     in ways that will no longer stringify to valid Python code on the tree.
     """
+
     is_pyi: bool = False
     is_pyi: bool = False
+    normalize_strings: bool = True
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
 
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
 
@@ -1279,7 +1475,7 @@ class LineGenerator(Visitor[Line]):
 
             else:
                 normalize_prefix(node, inside_brackets=any_open_brackets)
 
             else:
                 normalize_prefix(node, inside_brackets=any_open_brackets)
-                if node.type == token.STRING:
+                if self.normalize_strings and node.type == token.STRING:
                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                     normalize_string_quotes(node)
                 if node.type not in WHITESPACE:
                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                     normalize_string_quotes(node)
                 if node.type not in WHITESPACE:
@@ -1371,32 +1567,6 @@ class LineGenerator(Visitor[Line]):
             yield from self.line()
             yield from self.visit(child)
 
             yield from self.line()
             yield from self.visit(child)
 
-    def visit_import_from(self, node: Node) -> Iterator[Line]:
-        """Visit import_from and maybe put invisible parentheses.
-
-        This is separate from `visit_stmt` because import statements don't
-        support arbitrary atoms and thus handling of parentheses is custom.
-        """
-        check_lpar = False
-        for index, child in enumerate(node.children):
-            if check_lpar:
-                if child.type == token.LPAR:
-                    # make parentheses invisible
-                    child.value = ""  # type: ignore
-                    node.children[-1].value = ""  # type: ignore
-                else:
-                    # insert invisible parentheses
-                    node.insert_child(index, Leaf(token.LPAR, ""))
-                    node.append_child(Leaf(token.RPAR, ""))
-                break
-
-            check_lpar = (
-                child.type == token.NAME and child.value == "import"  # type: ignore
-            )
-
-        for child in node.children:
-            yield from self.visit(child)
-
     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
         """Remove a semicolon and put the other statement on a separate line."""
         yield from self.line()
     def visit_SEMI(self, leaf: Leaf) -> Iterator[Line]:
         """Remove a semicolon and put the other statement on a separate line."""
         yield from self.line()
@@ -1443,6 +1613,7 @@ class LineGenerator(Visitor[Line]):
         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_classdef = partial(v, keywords={"class"}, parens=Ø)
         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
+        self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
@@ -1816,7 +1987,7 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
     return 0
 
 
     return 0
 
 
-def generate_comments(leaf: Leaf) -> Iterator[Leaf]:
+def generate_comments(leaf: LN) -> Iterator[Leaf]:
     """Clean the prefix of the `leaf` and generate comments from it, if any.
 
     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
     """Clean the prefix of the `leaf` and generate comments from it, if any.
 
     Comments in lib2to3 are shoved into the whitespace prefix.  This happens
@@ -2050,32 +2221,50 @@ def right_hand_split(
             result.append(leaf, preformatted=True)
             for comment_after in line.comments_after(leaf):
                 result.append(comment_after, preformatted=True)
             result.append(leaf, preformatted=True)
             for comment_after in line.comments_after(leaf):
                 result.append(comment_after, preformatted=True)
-    bracket_split_succeeded_or_raise(head, body, tail)
     assert opening_bracket and closing_bracket
     assert opening_bracket and closing_bracket
+    body.should_explode = should_explode(body, opening_bracket)
+    bracket_split_succeeded_or_raise(head, body, tail)
     if (
     if (
+        # the body shouldn't be exploded
+        not body.should_explode
         # the opening bracket is an optional paren
         # the opening bracket is an optional paren
-        opening_bracket.type == token.LPAR
+        and opening_bracket.type == token.LPAR
         and not opening_bracket.value
         # the closing bracket is an optional paren
         and closing_bracket.type == token.RPAR
         and not closing_bracket.value
         and not opening_bracket.value
         # the closing bracket is an optional paren
         and closing_bracket.type == token.RPAR
         and not closing_bracket.value
-        # there are no standalone comments in the body
-        and not line.contains_standalone_comments(0)
-        # and it's not an import (optional parens are the only thing we can split
-        # on in this case; attempting a split without them is a waste of time)
+        # it's not an import (optional parens are the only thing we can split on
+        # in this case; attempting a split without them is a waste of time)
         and not line.is_import
         and not line.is_import
+        # there are no standalone comments in the body
+        and not body.contains_standalone_comments(0)
+        # and we can actually remove the parens
+        and can_omit_invisible_parens(body, line_length)
     ):
         omit = {id(closing_bracket), *omit}
     ):
         omit = {id(closing_bracket), *omit}
-        if can_omit_invisible_parens(body, line_length):
-            try:
-                yield from right_hand_split(line, line_length, py36=py36, omit=omit)
-                return
-            except CannotSplit:
-                pass
+        try:
+            yield from right_hand_split(line, line_length, py36=py36, omit=omit)
+            return
+
+        except CannotSplit:
+            if not (
+                can_be_split(body)
+                or is_line_short_enough(body, line_length=line_length)
+            ):
+                raise CannotSplit(
+                    "Splitting failed, body is still too long and can't be split."
+                )
+
+            elif head.contains_multiline_strings() or tail.contains_multiline_strings():
+                raise CannotSplit(
+                    "The current optional pair of parentheses is bound to fail to "
+                    "satisfy the splitting algorithm because the head or the tail "
+                    "contains multiline strings which by definition never fit one "
+                    "line."
+                )
 
     ensure_visible(opening_bracket)
     ensure_visible(closing_bracket)
 
     ensure_visible(opening_bracket)
     ensure_visible(closing_bracket)
-    body.should_explode = should_explode(body, opening_bracket)
     for result in (head, body, tail):
         if result:
             yield result
     for result in (head, body, tail):
         if result:
             yield result
@@ -2333,8 +2522,13 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
     Standardizes on visible parentheses for single-element tuples, and keeps
     existing visible parentheses for other tuples and generator expressions.
     """
     Standardizes on visible parentheses for single-element tuples, and keeps
     existing visible parentheses for other tuples and generator expressions.
     """
+    try:
+        list(generate_comments(node))
+    except FormatOff:
+        return  # This `node` has a prefix with `# fmt: off`, don't mess with parens.
+
     check_lpar = False
     check_lpar = False
-    for child in list(node.children):
+    for index, child in enumerate(list(node.children)):
         if check_lpar:
             if child.type == syms.atom:
                 maybe_make_parens_invisible_in_atom(child)
         if check_lpar:
             if child.type == syms.atom:
                 maybe_make_parens_invisible_in_atom(child)
@@ -2342,8 +2536,21 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
                 # wrap child in visible parentheses
                 lpar = Leaf(token.LPAR, "(")
                 rpar = Leaf(token.RPAR, ")")
                 # wrap child in visible parentheses
                 lpar = Leaf(token.LPAR, "(")
                 rpar = Leaf(token.RPAR, ")")
-                index = child.remove() or 0
+                child.remove()
                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
                 node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+            elif node.type == syms.import_from:
+                # "import from" nodes store parentheses directly as part of
+                # the statement
+                if child.type == token.LPAR:
+                    # make parentheses invisible
+                    child.value = ""  # type: ignore
+                    node.children[-1].value = ""  # type: ignore
+                elif child.type != token.STAR:
+                    # insert invisible parentheses
+                    node.insert_child(index, Leaf(token.LPAR, ""))
+                    node.append_child(Leaf(token.RPAR, ""))
+                break
+
             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
                 # wrap child in invisible parentheses
                 lpar = Leaf(token.LPAR, "")
             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
                 # wrap child in invisible parentheses
                 lpar = Leaf(token.LPAR, "")
@@ -2355,7 +2562,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
 
 
 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
 
 
 def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
-    """If it's safe, make the parens in the atom `node` invisible, recusively."""
+    """If it's safe, make the parens in the atom `node` invisible, recursively."""
     if (
         node.type != syms.atom
         or is_empty_tuple(node)
     if (
         node.type != syms.atom
         or is_empty_tuple(node)
@@ -2602,6 +2809,10 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
         if length > line_length:
             break
 
         if length > line_length:
             break
 
+        has_inline_comment = leaf_length > len(leaf.value) + len(leaf.prefix)
+        if leaf.type == STANDALONE_COMMENT or has_inline_comment:
+            break
+
         optional_brackets.discard(id(leaf))
         if opening_bracket:
             if leaf is opening_bracket:
         optional_brackets.discard(id(leaf))
         if opening_bracket:
             if leaf is opening_bracket:
@@ -2664,40 +2875,73 @@ def get_future_imports(node: Node) -> Set[str]:
     return imports
 
 
     return imports
 
 
-PYTHON_EXTENSIONS = {".py", ".pyi"}
-BLACKLISTED_DIRECTORIES = {
-    "build",
-    "buck-out",
-    "dist",
-    "_build",
-    ".git",
-    ".hg",
-    ".mypy_cache",
-    ".tox",
-    ".venv",
-}
-
+def gen_python_files_in_dir(
+    path: Path,
+    root: Path,
+    include: Pattern[str],
+    exclude: Pattern[str],
+    report: "Report",
+) -> Iterator[Path]:
+    """Generate all files under `path` whose paths are not excluded by the
+    `exclude` regex, but are included by the `include` regex.
 
 
-def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
-    """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
-    and have one of the PYTHON_EXTENSIONS.
+    `report` is where output about exclusions goes.
     """
     """
+    assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
     for child in path.iterdir():
     for child in path.iterdir():
+        normalized_path = "/" + child.resolve().relative_to(root).as_posix()
         if child.is_dir():
         if child.is_dir():
-            if child.name in BLACKLISTED_DIRECTORIES:
-                continue
+            normalized_path += "/"
+        exclude_match = exclude.search(normalized_path)
+        if exclude_match and exclude_match.group(0):
+            report.path_ignored(child, f"matches --exclude={exclude.pattern}")
+            continue
+
+        if child.is_dir():
+            yield from gen_python_files_in_dir(child, root, include, exclude, report)
+
+        elif child.is_file():
+            include_match = include.search(normalized_path)
+            if include_match:
+                yield child
+
+
+def find_project_root(srcs: List[str]) -> Path:
+    """Return a directory containing .git, .hg, or pyproject.toml.
+
+    That directory can be one of the directories passed in `srcs` or their
+    common parent.
+
+    If no directory in the tree contains a marker that would specify it's the
+    project root, the root of the file system is returned.
+    """
+    if not srcs:
+        return Path("/").resolve()
+
+    common_base = min(Path(src).resolve() for src in srcs)
+    if common_base.is_dir():
+        # Append a fake file so `parents` below returns `common_base_dir`, too.
+        common_base /= "fake-file"
+    for directory in common_base.parents:
+        if (directory / ".git").is_dir():
+            return directory
 
 
-            yield from gen_python_files_in_dir(child)
+        if (directory / ".hg").is_dir():
+            return directory
 
 
-        elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
-            yield child
+        if (directory / "pyproject.toml").is_file():
+            return directory
+
+    return directory
 
 
 @dataclass
 class Report:
     """Provides a reformatting counter. Can be rendered with `str(report)`."""
 
 
 @dataclass
 class Report:
     """Provides a reformatting counter. Can be rendered with `str(report)`."""
+
     check: bool = False
     quiet: bool = False
     check: bool = False
     quiet: bool = False
+    verbose: bool = False
     change_count: int = 0
     same_count: int = 0
     failure_count: int = 0
     change_count: int = 0
     same_count: int = 0
     failure_count: int = 0
@@ -2706,11 +2950,11 @@ class Report:
         """Increment the counter for successful reformatting. Write out a message."""
         if changed is Changed.YES:
             reformatted = "would reformat" if self.check else "reformatted"
         """Increment the counter for successful reformatting. Write out a message."""
         if changed is Changed.YES:
             reformatted = "would reformat" if self.check else "reformatted"
-            if not self.quiet:
+            if self.verbose or not self.quiet:
                 out(f"{reformatted} {src}")
             self.change_count += 1
         else:
                 out(f"{reformatted} {src}")
             self.change_count += 1
         else:
-            if not self.quiet:
+            if self.verbose:
                 if changed is Changed.NO:
                     msg = f"{src} already well formatted, good job."
                 else:
                 if changed is Changed.NO:
                     msg = f"{src} already well formatted, good job."
                 else:
@@ -2723,6 +2967,10 @@ class Report:
         err(f"error: cannot format {src}: {message}")
         self.failure_count += 1
 
         err(f"error: cannot format {src}: {message}")
         self.failure_count += 1
 
+    def path_ignored(self, path: Path, message: str) -> None:
+        if self.verbose:
+            out(f"{path} ignored: {message}", bold=False)
+
     @property
     def return_code(self) -> int:
         """Return the exit code that the app should use.
     @property
     def return_code(self) -> int:
         """Return the exit code that the app should use.
@@ -2835,9 +3083,11 @@ def assert_equivalent(src: str, dst: str) -> None:
         ) from None
 
 
         ) from None
 
 
-def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
+def assert_stable(
+    src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
+) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
     """Raise AssertionError if `dst` reformats differently the second time."""
-    newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
+    newdst = format_str(dst, line_length=line_length, mode=mode)
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
@@ -2940,9 +3190,6 @@ def enumerate_with_length(
 
         comment: Optional[Leaf]
         for comment in line.comments_after(leaf, index):
 
         comment: Optional[Leaf]
         for comment in line.comments_after(leaf, index):
-            if "\n" in comment.prefix:
-                return  # Oops, standalone comment!
-
             length += len(comment.value)
 
         yield index, leaf, length
             length += len(comment.value)
 
         yield index, leaf, length
@@ -2962,6 +3209,42 @@ def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") ->
     )
 
 
     )
 
 
+def can_be_split(line: Line) -> bool:
+    """Return False if the line cannot be split *for sure*.
+
+    This is not an exhaustive search but a cheap heuristic that we can use to
+    avoid some unfortunate formattings (mostly around wrapping unsplittable code
+    in unnecessary parentheses).
+    """
+    leaves = line.leaves
+    if len(leaves) < 2:
+        return False
+
+    if leaves[0].type == token.STRING and leaves[1].type == token.DOT:
+        call_count = 0
+        dot_count = 0
+        next = leaves[-1]
+        for leaf in leaves[-2::-1]:
+            if leaf.type in OPENING_BRACKETS:
+                if next.type not in CLOSING_BRACKETS:
+                    return False
+
+                call_count += 1
+            elif leaf.type == token.DOT:
+                dot_count += 1
+            elif leaf.type == token.NAME:
+                if not (next.type == token.DOT or next.type in OPENING_BRACKETS):
+                    return False
+
+            elif leaf.type not in CLOSING_BRACKETS:
+                return False
+
+            if dot_count > 1 and call_count > 1:
+                return False
+
+    return True
+
+
 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     """Does `line` have a shape safe to reformat without optional parens around it?
 
 def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     """Does `line` have a shape safe to reformat without optional parens around it?
 
@@ -3051,19 +3334,21 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     return False
 
 
     return False
 
 
-CACHE_DIR = Path(user_cache_dir("black", version=__version__))
-
-
-def get_cache_file(line_length: int) -> Path:
-    return CACHE_DIR / f"cache.{line_length}.pickle"
+def get_cache_file(line_length: int, mode: FileMode) -> Path:
+    pyi = bool(mode & FileMode.PYI)
+    py36 = bool(mode & FileMode.PYTHON36)
+    return (
+        CACHE_DIR
+        / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
+    )
 
 
 
 
-def read_cache(line_length: int) -> Cache:
+def read_cache(line_length: int, mode: FileMode) -> Cache:
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
-    cache_file = get_cache_file(line_length)
+    cache_file = get_cache_file(line_length, mode)
     if not cache_file.exists():
         return {}
 
     if not cache_file.exists():
         return {}
 
@@ -3082,27 +3367,27 @@ def get_cache_info(path: Path) -> CacheInfo:
     return stat.st_mtime, stat.st_size
 
 
     return stat.st_mtime, stat.st_size
 
 
-def filter_cached(
-    cache: Cache, sources: Iterable[Path]
-) -> Tuple[List[Path], List[Path]]:
-    """Split a list of paths into two.
+def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
+    """Split an iterable of paths in `sources` into two sets.
 
 
-    The first list contains paths of files that modified on disk or are not in the
-    cache. The other list contains paths to non-modified files.
+    The first contains paths of files that modified on disk or are not in the
+    cache. The other contains paths to non-modified files.
     """
     """
-    todo, done = [], []
+    todo, done = set(), set()
     for src in sources:
         src = src.resolve()
         if cache.get(src) != get_cache_info(src):
     for src in sources:
         src = src.resolve()
         if cache.get(src) != get_cache_info(src):
-            todo.append(src)
+            todo.add(src)
         else:
         else:
-            done.append(src)
+            done.add(src)
     return todo, done
 
 
     return todo, done
 
 
-def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
+def write_cache(
+    cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
+) -> None:
     """Update the cache file."""
     """Update the cache file."""
-    cache_file = get_cache_file(line_length)
+    cache_file = get_cache_file(line_length, mode)
     try:
         if not CACHE_DIR.exists():
             CACHE_DIR.mkdir(parents=True)
     try:
         if not CACHE_DIR.exists():
             CACHE_DIR.mkdir(parents=True)