git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Sort default excludes, include the leading slash
[etc/vim.git] / black.py
index e1a71e844939d9b1ad671c76440421fb439d8f90..61b884ad262ccf3e08a0fc380f5cd9607839515c 100644 (file)
--- a/black.py
+++ b/black.py
@@ -2,7 +2,7 @@ import asyncio
 import pickle
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 import pickle
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
-from enum import Enum
+from enum import Enum, Flag
 from functools import partial, wraps
 import keyword
 import logging
 from functools import partial, wraps
 import keyword
 import logging
@@ -44,8 +44,12 @@ from blib2to3.pgen2 import driver, token
 from blib2to3.pgen2.parse import ParseError
 
 
 from blib2to3.pgen2.parse import ParseError
 
 
-__version__ = "18.5b0"
+__version__ = "18.5b1"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_LINE_LENGTH = 88
+DEFAULT_EXCLUDES = (
+    r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
+)
+DEFAULT_INCLUDES = r"\.pyi?$"
 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
 
 
 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
 
 
@@ -122,6 +126,13 @@ class Changed(Enum):
     YES = 2
 
 
     YES = 2
 
 
+class FileMode(Flag):
+    AUTO_DETECT = 0
+    PYTHON36 = 1
+    PYI = 2
+    NO_STRING_NORMALIZATION = 4
+
+
 @click.command()
 @click.option(
     "-l",
 @click.command()
 @click.option(
     "-l",
@@ -159,6 +170,51 @@ class Changed(Enum):
         "silence those with 2>/dev/null."
     ),
 )
         "silence those with 2>/dev/null."
     ),
 )
+@click.option(
+    "--pyi",
+    is_flag=True,
+    help=(
+        "Consider all input files typing stubs regardless of file extension "
+        "(useful when piping source on standard input)."
+    ),
+)
+@click.option(
+    "--py36",
+    is_flag=True,
+    help=(
+        "Allow using Python 3.6-only syntax on all input files.  This will put "
+        "trailing commas in function signatures and calls also after *args and "
+        "**kwargs.  [default: per-file auto-detection]"
+    ),
+)
+@click.option(
+    "-S",
+    "--skip-string-normalization",
+    is_flag=True,
+    help="Don't normalize string quotes or prefixes.",
+)
+@click.option(
+    "--include",
+    type=str,
+    default=DEFAULT_INCLUDES,
+    help=(
+        "A regular expression that matches files and directories that should be "
+        "included on recursive searches. On Windows, use forward slashes for "
+        "directories."
+    ),
+    show_default=True,
+)
+@click.option(
+    "--exclude",
+    type=str,
+    default=DEFAULT_EXCLUDES,
+    help=(
+        "A regular expression that matches files and directories that should be "
+        "excluded on recursive searches. On Windows, use forward slashes for "
+        "directories."
+    ),
+    show_default=True,
+)
 @click.version_option(version=__version__)
 @click.argument(
     "src",
 @click.version_option(version=__version__)
 @click.argument(
     "src",
@@ -174,15 +230,30 @@ def main(
     check: bool,
     diff: bool,
     fast: bool,
     check: bool,
     diff: bool,
     fast: bool,
+    pyi: bool,
+    py36: bool,
+    skip_string_normalization: bool,
     quiet: bool,
     quiet: bool,
+    include: str,
+    exclude: str,
     src: List[str],
 ) -> None:
     """The uncompromising code formatter."""
     sources: List[Path] = []
     src: List[str],
 ) -> None:
     """The uncompromising code formatter."""
     sources: List[Path] = []
+    try:
+        include_regex = re.compile(include)
+    except re.error:
+        err(f"Invalid regular expression for include given: {include!r}")
+        ctx.exit(2)
+    try:
+        exclude_regex = re.compile(exclude)
+    except re.error:
+        err(f"Invalid regular expression for exclude given: {exclude!r}")
+        ctx.exit(2)
     for s in src:
         p = Path(s)
         if p.is_dir():
     for s in src:
         p = Path(s)
         if p.is_dir():
-            sources.extend(gen_python_files_in_dir(p))
+            sources.extend(gen_python_files_in_dir(p, include_regex, exclude_regex))
         elif p.is_file():
             # if a file was explicitly given, we don't care about its extension
             sources.append(p)
         elif p.is_file():
             # if a file was explicitly given, we don't care about its extension
             sources.append(p)
@@ -197,6 +268,13 @@ def main(
         write_back = WriteBack.DIFF
     else:
         write_back = WriteBack.YES
         write_back = WriteBack.DIFF
     else:
         write_back = WriteBack.YES
+    mode = FileMode.AUTO_DETECT
+    if py36:
+        mode |= FileMode.PYTHON36
+    if pyi:
+        mode |= FileMode.PYI
+    if skip_string_normalization:
+        mode |= FileMode.NO_STRING_NORMALIZATION
     report = Report(check=check, quiet=quiet)
     if len(sources) == 0:
         out("No paths given. Nothing to do 😴")
     report = Report(check=check, quiet=quiet)
     if len(sources) == 0:
         out("No paths given. Nothing to do 😴")
@@ -204,14 +282,28 @@ def main(
         return
 
     elif len(sources) == 1:
         return
 
     elif len(sources) == 1:
-        reformat_one(sources[0], line_length, fast, write_back, report)
+        reformat_one(
+            src=sources[0],
+            line_length=line_length,
+            fast=fast,
+            write_back=write_back,
+            mode=mode,
+            report=report,
+        )
     else:
         loop = asyncio.get_event_loop()
         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
         try:
             loop.run_until_complete(
                 schedule_formatting(
     else:
         loop = asyncio.get_event_loop()
         executor = ProcessPoolExecutor(max_workers=os.cpu_count())
         try:
             loop.run_until_complete(
                 schedule_formatting(
-                    sources, line_length, fast, write_back, report, loop, executor
+                    sources=sources,
+                    line_length=line_length,
+                    fast=fast,
+                    write_back=write_back,
+                    mode=mode,
+                    report=report,
+                    loop=loop,
+                    executor=executor,
                 )
             )
         finally:
                 )
             )
         finally:
@@ -223,33 +315,43 @@ def main(
 
 
 def reformat_one(
 
 
 def reformat_one(
-    src: Path, line_length: int, fast: bool, write_back: WriteBack, report: "Report"
+    src: Path,
+    line_length: int,
+    fast: bool,
+    write_back: WriteBack,
+    mode: FileMode,
+    report: "Report",
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
     If `quiet` is True, non-error messages are not output. `line_length`,
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
     If `quiet` is True, non-error messages are not output. `line_length`,
-    `write_back`, and `fast` options are passed to :func:`format_file_in_place`.
+    `write_back`, `fast` and `pyi` options are passed to
+    :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
     """
     try:
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
             if format_stdin_to_stdout(
     """
     try:
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
             if format_stdin_to_stdout(
-                line_length=line_length, fast=fast, write_back=write_back
+                line_length=line_length, fast=fast, write_back=write_back, mode=mode
             ):
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
             ):
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
-                cache = read_cache(line_length)
+                cache = read_cache(line_length, mode)
                 src = src.resolve()
                 if src in cache and cache[src] == get_cache_info(src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
                 src = src.resolve()
                 if src in cache and cache[src] == get_cache_info(src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
-                src, line_length=line_length, fast=fast, write_back=write_back
+                src,
+                line_length=line_length,
+                fast=fast,
+                write_back=write_back,
+                mode=mode,
             ):
                 changed = Changed.YES
             if write_back == WriteBack.YES and changed is not Changed.NO:
             ):
                 changed = Changed.YES
             if write_back == WriteBack.YES and changed is not Changed.NO:
-                write_cache(cache, [src], line_length)
+                write_cache(cache, [src], line_length, mode)
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
@@ -260,6 +362,7 @@ async def schedule_formatting(
     line_length: int,
     fast: bool,
     write_back: WriteBack,
     line_length: int,
     fast: bool,
     write_back: WriteBack,
+    mode: FileMode,
     report: "Report",
     loop: BaseEventLoop,
     executor: Executor,
     report: "Report",
     loop: BaseEventLoop,
     executor: Executor,
@@ -268,12 +371,12 @@ async def schedule_formatting(
 
     (Use ProcessPoolExecutors for actual parallelism.)
 
 
     (Use ProcessPoolExecutors for actual parallelism.)
 
-    `line_length`, `write_back`, and `fast` options are passed to
+    `line_length`, `write_back`, `fast`, and `pyi` options are passed to
     :func:`format_file_in_place`.
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
     :func:`format_file_in_place`.
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
-        cache = read_cache(line_length)
+        cache = read_cache(line_length, mode)
         sources, cached = filter_cached(cache, sources)
         for src in cached:
             report.done(src, Changed.CACHED)
         sources, cached = filter_cached(cache, sources)
         for src in cached:
             report.done(src, Changed.CACHED)
@@ -288,7 +391,14 @@ async def schedule_formatting(
             lock = manager.Lock()
         tasks = {
             loop.run_in_executor(
             lock = manager.Lock()
         tasks = {
             loop.run_in_executor(
-                executor, format_file_in_place, src, line_length, fast, write_back, lock
+                executor,
+                format_file_in_place,
+                src,
+                line_length,
+                fast,
+                write_back,
+                mode,
+                lock,
             ): src
             for src in sorted(sources)
         }
             ): src
             for src in sorted(sources)
         }
@@ -313,7 +423,7 @@ async def schedule_formatting(
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if write_back == WriteBack.YES and formatted:
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if write_back == WriteBack.YES and formatted:
-        write_cache(cache, formatted, line_length)
+        write_cache(cache, formatted, line_length, mode)
 
 
 def format_file_in_place(
 
 
 def format_file_in_place(
@@ -321,6 +431,7 @@ def format_file_in_place(
     line_length: int,
     fast: bool,
     write_back: WriteBack = WriteBack.NO,
     line_length: int,
     fast: bool,
     write_back: WriteBack = WriteBack.NO,
+    mode: FileMode = FileMode.AUTO_DETECT,
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
@@ -328,13 +439,13 @@ def format_file_in_place(
     If `write_back` is True, write reformatted code back to stdout.
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
     If `write_back` is True, write reformatted code back to stdout.
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
-    is_pyi = src.suffix == ".pyi"
-
+    if src.suffix == ".pyi":
+        mode |= FileMode.PYI
     with tokenize.open(src) as src_buffer:
         src_contents = src_buffer.read()
     try:
         dst_contents = format_file_contents(
     with tokenize.open(src) as src_buffer:
         src_contents = src_buffer.read()
     try:
         dst_contents = format_file_contents(
-            src_contents, line_length=line_length, fast=fast, is_pyi=is_pyi
+            src_contents, line_length=line_length, fast=fast, mode=mode
         )
     except NothingChanged:
         return False
         )
     except NothingChanged:
         return False
@@ -357,17 +468,21 @@ def format_file_in_place(
 
 
 def format_stdin_to_stdout(
 
 
 def format_stdin_to_stdout(
-    line_length: int, fast: bool, write_back: WriteBack = WriteBack.NO
+    line_length: int,
+    fast: bool,
+    write_back: WriteBack = WriteBack.NO,
+    mode: FileMode = FileMode.AUTO_DETECT,
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
     If `write_back` is True, write reformatted code back to stdout.
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
     If `write_back` is True, write reformatted code back to stdout.
-    `line_length` and `fast` arguments are passed to :func:`format_file_contents`.
+    `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
+    :func:`format_file_contents`.
     """
     src = sys.stdin.read()
     dst = src
     try:
     """
     src = sys.stdin.read()
     dst = src
     try:
-        dst = format_file_contents(src, line_length=line_length, fast=fast)
+        dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
         return True
 
     except NothingChanged:
         return True
 
     except NothingChanged:
@@ -383,7 +498,11 @@ def format_stdin_to_stdout(
 
 
 def format_file_contents(
 
 
 def format_file_contents(
-    src_contents: str, *, line_length: int, fast: bool, is_pyi: bool = False
+    src_contents: str,
+    *,
+    line_length: int,
+    fast: bool,
+    mode: FileMode = FileMode.AUTO_DETECT,
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
@@ -394,20 +513,18 @@ def format_file_contents(
     if src_contents.strip() == "":
         raise NothingChanged
 
     if src_contents.strip() == "":
         raise NothingChanged
 
-    dst_contents = format_str(src_contents, line_length=line_length, is_pyi=is_pyi)
+    dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
-        assert_stable(
-            src_contents, dst_contents, line_length=line_length, is_pyi=is_pyi
-        )
+        assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
     return dst_contents
 
 
 def format_str(
     return dst_contents
 
 
 def format_str(
-    src_contents: str, line_length: int, *, is_pyi: bool = False
+    src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
 ) -> FileContent:
     """Reformat a string and return new contents.
 
 ) -> FileContent:
     """Reformat a string and return new contents.
 
@@ -416,11 +533,15 @@ def format_str(
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
     future_imports = get_future_imports(src_node)
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
     future_imports = get_future_imports(src_node)
-    elt = EmptyLineTracker(is_pyi=is_pyi)
-    py36 = is_python36(src_node)
+    is_pyi = bool(mode & FileMode.PYI)
+    py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
+    normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
     lines = LineGenerator(
     lines = LineGenerator(
-        remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
+        remove_u_prefix=py36 or "unicode_literals" in future_imports,
+        is_pyi=is_pyi,
+        normalize_strings=normalize_strings,
     )
     )
+    elt = EmptyLineTracker(is_pyi=is_pyi)
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
@@ -878,27 +999,6 @@ class Line:
             and second_leaf.value == "def"
         )
 
             and second_leaf.value == "def"
         )
 
-    @property
-    def is_flow_control(self) -> bool:
-        """Is this line a flow control statement?
-
-        Those are `return`, `raise`, `break`, and `continue`.
-        """
-        return (
-            bool(self)
-            and self.leaves[0].type == token.NAME
-            and self.leaves[0].value in FLOW_CONTROL
-        )
-
-    @property
-    def is_yield(self) -> bool:
-        """Is this line a yield statement?"""
-        return (
-            bool(self)
-            and self.leaves[0].type == token.NAME
-            and self.leaves[0].value == "yield"
-        )
-
     @property
     def is_class_paren_empty(self) -> bool:
         """Is this a class with no base classes but using parentheses?
     @property
     def is_class_paren_empty(self) -> bool:
         """Is this a class with no base classes but using parentheses?
@@ -915,6 +1015,15 @@ class Line:
             and self.leaves[3].value == ")"
         )
 
             and self.leaves[3].value == ")"
         )
 
+    @property
+    def is_triple_quoted_string(self) -> bool:
+        """Is the line a triple quoted string?"""
+        return (
+            bool(self)
+            and self.leaves[0].type == token.STRING
+            and self.leaves[0].value.startswith(('"""', "'''"))
+        )
+
     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
         """If so, needs to be split before emitting."""
         for leaf in self.leaves:
     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
         """If so, needs to be split before emitting."""
         for leaf in self.leaves:
@@ -1122,6 +1231,7 @@ class EmptyLineTracker:
     the prefix of the first leaf consists of optional newlines.  Those newlines
     are consumed by `maybe_empty_lines()` and included in the computation.
     """
     the prefix of the first leaf consists of optional newlines.  Those newlines
     are consumed by `maybe_empty_lines()` and included in the computation.
     """
+
     is_pyi: bool = False
     previous_line: Optional[Line] = None
     previous_after: int = 0
     is_pyi: bool = False
     previous_line: Optional[Line] = None
     previous_after: int = 0
@@ -1131,8 +1241,7 @@ class EmptyLineTracker:
         """Return the number of extra empty lines before and after the `current_line`.
 
         This is for separating `def`, `async def` and `class` with extra empty
         """Return the number of extra empty lines before and after the `current_line`.
 
         This is for separating `def`, `async def` and `class` with extra empty
-        lines (two on module-level), as well as providing an extra empty line
-        after flow control keywords to make them more prominent.
+        lines (two on module-level).
         """
         if isinstance(current_line, UnformattedLines):
             return 0, 0
         """
         if isinstance(current_line, UnformattedLines):
             return 0, 0
@@ -1173,6 +1282,11 @@ class EmptyLineTracker:
             if self.previous_line.is_decorator:
                 return 0, 0
 
             if self.previous_line.is_decorator:
                 return 0, 0
 
+            if self.previous_line.depth < current_line.depth and (
+                self.previous_line.is_class or self.previous_line.is_def
+            ):
+                return 0, 0
+
             if (
                 self.previous_line.is_comment
                 and self.previous_line.depth == current_line.depth
             if (
                 self.previous_line.is_comment
                 and self.previous_line.depth == current_line.depth
@@ -1204,6 +1318,13 @@ class EmptyLineTracker:
         ):
             return (before or 1), 0
 
         ):
             return (before or 1), 0
 
+        if (
+            self.previous_line
+            and self.previous_line.is_class
+            and current_line.is_triple_quoted_string
+        ):
+            return before, 1
+
         return before, 0
 
 
         return before, 0
 
 
@@ -1214,7 +1335,9 @@ class LineGenerator(Visitor[Line]):
     Note: destroys the tree it's visiting by mutating prefixes of its leaves
     in ways that will no longer stringify to valid Python code on the tree.
     """
     Note: destroys the tree it's visiting by mutating prefixes of its leaves
     in ways that will no longer stringify to valid Python code on the tree.
     """
+
     is_pyi: bool = False
     is_pyi: bool = False
+    normalize_strings: bool = True
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
 
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
 
@@ -1283,7 +1406,7 @@ class LineGenerator(Visitor[Line]):
 
             else:
                 normalize_prefix(node, inside_brackets=any_open_brackets)
 
             else:
                 normalize_prefix(node, inside_brackets=any_open_brackets)
-                if node.type == token.STRING:
+                if self.normalize_strings and node.type == token.STRING:
                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                     normalize_string_quotes(node)
                 if node.type not in WHITESPACE:
                     normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                     normalize_string_quotes(node)
                 if node.type not in WHITESPACE:
@@ -2665,38 +2788,41 @@ def get_future_imports(node: Node) -> Set[str]:
     return imports
 
 
     return imports
 
 
-PYTHON_EXTENSIONS = {".py", ".pyi"}
-BLACKLISTED_DIRECTORIES = {
-    "build",
-    "buck-out",
-    "dist",
-    "_build",
-    ".git",
-    ".hg",
-    ".mypy_cache",
-    ".tox",
-    ".venv",
-}
-
-
-def gen_python_files_in_dir(path: Path) -> Iterator[Path]:
-    """Generate all files under `path` which aren't under BLACKLISTED_DIRECTORIES
-    and have one of the PYTHON_EXTENSIONS.
+def gen_python_files_in_dir(
+    path: Path, include: Pattern[str], exclude: Pattern[str]
+) -> Iterator[Path]:
+    """Generate all files under `path` whose paths are not excluded by the
+    `exclude` regex, but are included by the `include` regex.
     """
     """
+
     for child in path.iterdir():
     for child in path.iterdir():
+        searchable_path = str(child.as_posix())
+        if Path(child.parts[0]).is_dir():
+            searchable_path = "/" + searchable_path
         if child.is_dir():
         if child.is_dir():
-            if child.name in BLACKLISTED_DIRECTORIES:
+            searchable_path = searchable_path + "/"
+            exclude_match = exclude.search(searchable_path)
+            if exclude_match and len(exclude_match.group()) > 0:
                 continue
 
                 continue
 
-            yield from gen_python_files_in_dir(child)
+            yield from gen_python_files_in_dir(child, include, exclude)
 
 
-        elif child.is_file() and child.suffix in PYTHON_EXTENSIONS:
-            yield child
+        else:
+            include_match = include.search(searchable_path)
+            exclude_match = exclude.search(searchable_path)
+            if (
+                child.is_file()
+                and include_match
+                and len(include_match.group()) > 0
+                and (not exclude_match or len(exclude_match.group()) == 0)
+            ):
+                yield child
 
 
 @dataclass
 class Report:
     """Provides a reformatting counter. Can be rendered with `str(report)`."""
 
 
 @dataclass
 class Report:
     """Provides a reformatting counter. Can be rendered with `str(report)`."""
+
     check: bool = False
     quiet: bool = False
     change_count: int = 0
     check: bool = False
     quiet: bool = False
     change_count: int = 0
@@ -2836,9 +2962,11 @@ def assert_equivalent(src: str, dst: str) -> None:
         ) from None
 
 
         ) from None
 
 
-def assert_stable(src: str, dst: str, line_length: int, is_pyi: bool = False) -> None:
+def assert_stable(
+    src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
+) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
     """Raise AssertionError if `dst` reformats differently the second time."""
-    newdst = format_str(dst, line_length=line_length, is_pyi=is_pyi)
+    newdst = format_str(dst, line_length=line_length, mode=mode)
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
@@ -3049,16 +3177,21 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     return False
 
 
     return False
 
 
-def get_cache_file(line_length: int) -> Path:
-    return CACHE_DIR / f"cache.{line_length}.pickle"
+def get_cache_file(line_length: int, mode: FileMode) -> Path:
+    pyi = bool(mode & FileMode.PYI)
+    py36 = bool(mode & FileMode.PYTHON36)
+    return (
+        CACHE_DIR
+        / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
+    )
 
 
 
 
-def read_cache(line_length: int) -> Cache:
+def read_cache(line_length: int, mode: FileMode) -> Cache:
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
-    cache_file = get_cache_file(line_length)
+    cache_file = get_cache_file(line_length, mode)
     if not cache_file.exists():
         return {}
 
     if not cache_file.exists():
         return {}
 
@@ -3095,9 +3228,11 @@ def filter_cached(
     return todo, done
 
 
     return todo, done
 
 
-def write_cache(cache: Cache, sources: List[Path], line_length: int) -> None:
+def write_cache(
+    cache: Cache, sources: List[Path], line_length: int, mode: FileMode
+) -> None:
     """Update the cache file."""
     """Update the cache file."""
-    cache_file = get_cache_file(line_length)
+    cache_file = get_cache_file(line_length, mode)
     try:
         if not CACHE_DIR.exists():
             CACHE_DIR.mkdir(parents=True)
     try:
         if not CACHE_DIR.exists():
             CACHE_DIR.mkdir(parents=True)