git.madduck.net Git - etc/vim.git/commitdiff

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Refactor --pyi and --py36 into FileMode
author Łukasz Langa <lukasz@langa.pl>
Tue, 29 May 2018 08:53:54 +0000 (01:53 -0700)
committer Łukasz Langa <lukasz@langa.pl>
Tue, 29 May 2018 08:53:54 +0000 (01:53 -0700)
black.py
tests/test_black.py

index dc67991f2c94d3f1444a56179409eb8c65cfbd9a..547751b789700fde7d7e1d88223b6a19806659e2 100644 (file)
--- a/black.py
+++ b/black.py
@@ -2,7 +2,7 @@ import asyncio
 import pickle
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 import pickle
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
-from enum import Enum
+from enum import Enum, Flag
 from functools import partial, wraps
 import keyword
 import logging
 from functools import partial, wraps
 import keyword
 import logging
@@ -122,6 +122,12 @@ class Changed(Enum):
     YES = 2
 
 
     YES = 2
 
 
+class FileMode(Flag):
+    AUTO_DETECT = 0
+    PYTHON36 = 1
+    PYI = 2
+
+
 @click.command()
 @click.option(
     "-l",
 @click.command()
 @click.option(
     "-l",
@@ -216,6 +222,11 @@ def main(
         write_back = WriteBack.DIFF
     else:
         write_back = WriteBack.YES
         write_back = WriteBack.DIFF
     else:
         write_back = WriteBack.YES
+    mode = FileMode.AUTO_DETECT
+    if py36:
+        mode |= FileMode.PYTHON36
+    if pyi:
+        mode |= FileMode.PYI
     report = Report(check=check, quiet=quiet)
     if len(sources) == 0:
         out("No paths given. Nothing to do 😴")
     report = Report(check=check, quiet=quiet)
     if len(sources) == 0:
         out("No paths given. Nothing to do 😴")
@@ -227,9 +238,8 @@ def main(
             src=sources[0],
             line_length=line_length,
             fast=fast,
             src=sources[0],
             line_length=line_length,
             fast=fast,
-            pyi=pyi,
-            py36=py36,
             write_back=write_back,
             write_back=write_back,
+            mode=mode,
             report=report,
         )
     else:
             report=report,
         )
     else:
@@ -241,9 +251,8 @@ def main(
                     sources=sources,
                     line_length=line_length,
                     fast=fast,
                     sources=sources,
                     line_length=line_length,
                     fast=fast,
-                    pyi=pyi,
-                    py36=py36,
                     write_back=write_back,
                     write_back=write_back,
+                    mode=mode,
                     report=report,
                     loop=loop,
                     executor=executor,
                     report=report,
                     loop=loop,
                     executor=executor,
@@ -261,9 +270,8 @@ def reformat_one(
     src: Path,
     line_length: int,
     fast: bool,
     src: Path,
     line_length: int,
     fast: bool,
-    pyi: bool,
-    py36: bool,
     write_back: WriteBack,
     write_back: WriteBack,
+    mode: FileMode,
     report: "Report",
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
     report: "Report",
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
@@ -276,17 +284,13 @@ def reformat_one(
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
             if format_stdin_to_stdout(
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
             if format_stdin_to_stdout(
-                line_length=line_length,
-                fast=fast,
-                is_pyi=pyi,
-                force_py36=py36,
-                write_back=write_back,
+                line_length=line_length, fast=fast, write_back=write_back, mode=mode
             ):
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
             ):
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
-                cache = read_cache(line_length, pyi, py36)
+                cache = read_cache(line_length, mode)
                 src = src.resolve()
                 if src in cache and cache[src] == get_cache_info(src):
                     changed = Changed.CACHED
                 src = src.resolve()
                 if src in cache and cache[src] == get_cache_info(src):
                     changed = Changed.CACHED
@@ -294,13 +298,12 @@ def reformat_one(
                 src,
                 line_length=line_length,
                 fast=fast,
                 src,
                 line_length=line_length,
                 fast=fast,
-                force_pyi=pyi,
-                force_py36=py36,
                 write_back=write_back,
                 write_back=write_back,
+                mode=mode,
             ):
                 changed = Changed.YES
             if write_back == WriteBack.YES and changed is not Changed.NO:
             ):
                 changed = Changed.YES
             if write_back == WriteBack.YES and changed is not Changed.NO:
-                write_cache(cache, [src], line_length, pyi, py36)
+                write_cache(cache, [src], line_length, mode)
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
@@ -310,9 +313,8 @@ async def schedule_formatting(
     sources: List[Path],
     line_length: int,
     fast: bool,
     sources: List[Path],
     line_length: int,
     fast: bool,
-    pyi: bool,
-    py36: bool,
     write_back: WriteBack,
     write_back: WriteBack,
+    mode: FileMode,
     report: "Report",
     loop: BaseEventLoop,
     executor: Executor,
     report: "Report",
     loop: BaseEventLoop,
     executor: Executor,
@@ -326,7 +328,7 @@ async def schedule_formatting(
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
-        cache = read_cache(line_length, pyi, py36)
+        cache = read_cache(line_length, mode)
         sources, cached = filter_cached(cache, sources)
         for src in cached:
             report.done(src, Changed.CACHED)
         sources, cached = filter_cached(cache, sources)
         for src in cached:
             report.done(src, Changed.CACHED)
@@ -346,9 +348,8 @@ async def schedule_formatting(
                 src,
                 line_length,
                 fast,
                 src,
                 line_length,
                 fast,
-                pyi,
-                py36,
                 write_back,
                 write_back,
+                mode,
                 lock,
             ): src
             for src in sorted(sources)
                 lock,
             ): src
             for src in sorted(sources)
@@ -374,16 +375,15 @@ async def schedule_formatting(
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if write_back == WriteBack.YES and formatted:
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if write_back == WriteBack.YES and formatted:
-        write_cache(cache, formatted, line_length, pyi, py36)
+        write_cache(cache, formatted, line_length, mode)
 
 
 def format_file_in_place(
     src: Path,
     line_length: int,
     fast: bool,
 
 
 def format_file_in_place(
     src: Path,
     line_length: int,
     fast: bool,
-    force_pyi: bool = False,
-    force_py36: bool = False,
     write_back: WriteBack = WriteBack.NO,
     write_back: WriteBack = WriteBack.NO,
+    mode: FileMode = FileMode.AUTO_DETECT,
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
@@ -391,17 +391,13 @@ def format_file_in_place(
     If `write_back` is True, write reformatted code back to stdout.
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
     If `write_back` is True, write reformatted code back to stdout.
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
-    is_pyi = force_pyi or src.suffix == ".pyi"
-
+    if src.suffix == ".pyi":
+        mode |= FileMode.PYI
     with tokenize.open(src) as src_buffer:
         src_contents = src_buffer.read()
     try:
         dst_contents = format_file_contents(
     with tokenize.open(src) as src_buffer:
         src_contents = src_buffer.read()
     try:
         dst_contents = format_file_contents(
-            src_contents,
-            line_length=line_length,
-            fast=fast,
-            is_pyi=is_pyi,
-            force_py36=force_py36,
+            src_contents, line_length=line_length, fast=fast, mode=mode
         )
     except NothingChanged:
         return False
         )
     except NothingChanged:
         return False
@@ -426,9 +422,8 @@ def format_file_in_place(
 def format_stdin_to_stdout(
     line_length: int,
     fast: bool,
 def format_stdin_to_stdout(
     line_length: int,
     fast: bool,
-    is_pyi: bool = False,
-    force_py36: bool = False,
     write_back: WriteBack = WriteBack.NO,
     write_back: WriteBack = WriteBack.NO,
+    mode: FileMode = FileMode.AUTO_DETECT,
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
@@ -439,13 +434,7 @@ def format_stdin_to_stdout(
     src = sys.stdin.read()
     dst = src
     try:
     src = sys.stdin.read()
     dst = src
     try:
-        dst = format_file_contents(
-            src,
-            line_length=line_length,
-            fast=fast,
-            is_pyi=is_pyi,
-            force_py36=force_py36,
-        )
+        dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
         return True
 
     except NothingChanged:
         return True
 
     except NothingChanged:
@@ -465,8 +454,7 @@ def format_file_contents(
     *,
     line_length: int,
     fast: bool,
     *,
     line_length: int,
     fast: bool,
-    is_pyi: bool = False,
-    force_py36: bool = False,
+    mode: FileMode = FileMode.AUTO_DETECT,
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
@@ -477,30 +465,18 @@ def format_file_contents(
     if src_contents.strip() == "":
         raise NothingChanged
 
     if src_contents.strip() == "":
         raise NothingChanged
 
-    dst_contents = format_str(
-        src_contents, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
-    )
+    dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
-        assert_stable(
-            src_contents,
-            dst_contents,
-            line_length=line_length,
-            is_pyi=is_pyi,
-            force_py36=force_py36,
-        )
+        assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
     return dst_contents
 
 
 def format_str(
     return dst_contents
 
 
 def format_str(
-    src_contents: str,
-    line_length: int,
-    *,
-    is_pyi: bool = False,
-    force_py36: bool = False,
+    src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
 ) -> FileContent:
     """Reformat a string and return new contents.
 
 ) -> FileContent:
     """Reformat a string and return new contents.
 
@@ -509,11 +485,12 @@ def format_str(
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
     future_imports = get_future_imports(src_node)
     src_node = lib2to3_parse(src_contents)
     dst_contents = ""
     future_imports = get_future_imports(src_node)
-    elt = EmptyLineTracker(is_pyi=is_pyi)
-    py36 = force_py36 or is_python36(src_node)
+    is_pyi = bool(mode & FileMode.PYI)
+    py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
     lines = LineGenerator(
         remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
     )
     lines = LineGenerator(
         remove_u_prefix=py36 or "unicode_literals" in future_imports, is_pyi=is_pyi
     )
+    elt = EmptyLineTracker(is_pyi=is_pyi)
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
@@ -2932,12 +2909,10 @@ def assert_equivalent(src: str, dst: str) -> None:
 
 
 def assert_stable(
 
 
 def assert_stable(
-    src: str, dst: str, line_length: int, is_pyi: bool = False, force_py36: bool = False
+    src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
 ) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
 ) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
-    newdst = format_str(
-        dst, line_length=line_length, is_pyi=is_pyi, force_py36=force_py36
-    )
+    newdst = format_str(dst, line_length=line_length, mode=mode)
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
@@ -3148,19 +3123,21 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     return False
 
 
     return False
 
 
-def get_cache_file(line_length: int, pyi: bool = False, py36: bool = False) -> Path:
+def get_cache_file(line_length: int, mode: FileMode) -> Path:
+    pyi = bool(mode & FileMode.PYI)
+    py36 = bool(mode & FileMode.PYTHON36)
     return (
         CACHE_DIR
         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
     )
 
 
     return (
         CACHE_DIR
         / f"cache.{line_length}{'.pyi' if pyi else ''}{'.py36' if py36 else ''}.pickle"
     )
 
 
-def read_cache(line_length: int, pyi: bool = False, py36: bool = False) -> Cache:
+def read_cache(line_length: int, mode: FileMode) -> Cache:
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
-    cache_file = get_cache_file(line_length, pyi, py36)
+    cache_file = get_cache_file(line_length, mode)
     if not cache_file.exists():
         return {}
 
     if not cache_file.exists():
         return {}
 
@@ -3198,14 +3175,10 @@ def filter_cached(
 
 
 def write_cache(
 
 
 def write_cache(
-    cache: Cache,
-    sources: List[Path],
-    line_length: int,
-    pyi: bool = False,
-    py36: bool = False,
+    cache: Cache, sources: List[Path], line_length: int, mode: FileMode
 ) -> None:
     """Update the cache file."""
 ) -> None:
     """Update the cache file."""
-    cache_file = get_cache_file(line_length, pyi, py36)
+    cache_file = get_cache_file(line_length, mode)
     try:
         if not CACHE_DIR.exists():
             CACHE_DIR.mkdir(parents=True)
     try:
         if not CACHE_DIR.exists():
             CACHE_DIR.mkdir(parents=True)
index f5114a7c0c3c54d1b5259b5d349514fd5bceee17..595d6cd35c4bd2df0eba2f25fc984f3b7080df63 100644 (file)
@@ -342,10 +342,11 @@ class BlackTestCase(unittest.TestCase):
 
     @patch("black.dump_to_file", dump_to_stderr)
     def test_stub(self) -> None:
 
     @patch("black.dump_to_file", dump_to_stderr)
     def test_stub(self) -> None:
+        mode = black.FileMode.PYI
         source, expected = read_data("stub.pyi")
         source, expected = read_data("stub.pyi")
-        actual = fs(source, is_pyi=True)
+        actual = fs(source, mode=mode)
         self.assertFormatEqual(expected, actual)
         self.assertFormatEqual(expected, actual)
-        black.assert_stable(source, actual, line_length=ll, is_pyi=True)
+        black.assert_stable(source, actual, line_length=ll, mode=mode)
 
     @patch("black.dump_to_file", dump_to_stderr)
     def test_fmtonoff(self) -> None:
 
     @patch("black.dump_to_file", dump_to_stderr)
     def test_fmtonoff(self) -> None:
@@ -566,25 +567,27 @@ class BlackTestCase(unittest.TestCase):
         self.assertEqual("".join(err_lines), "")
 
     def test_cache_broken_file(self) -> None:
         self.assertEqual("".join(err_lines), "")
 
     def test_cache_broken_file(self) -> None:
+        mode = black.FileMode.AUTO_DETECT
         with cache_dir() as workspace:
         with cache_dir() as workspace:
-            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
+            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
             with cache_file.open("w") as fobj:
                 fobj.write("this is not a pickle")
             with cache_file.open("w") as fobj:
                 fobj.write("this is not a pickle")
-            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
+            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
             src = (workspace / "test.py").resolve()
             with src.open("w") as fobj:
                 fobj.write("print('hello')")
             result = CliRunner().invoke(black.main, [str(src)])
             self.assertEqual(result.exit_code, 0)
             src = (workspace / "test.py").resolve()
             with src.open("w") as fobj:
                 fobj.write("print('hello')")
             result = CliRunner().invoke(black.main, [str(src)])
             self.assertEqual(result.exit_code, 0)
-            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
+            cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
             self.assertIn(src, cache)
 
     def test_cache_single_file_already_cached(self) -> None:
             self.assertIn(src, cache)
 
     def test_cache_single_file_already_cached(self) -> None:
+        mode = black.FileMode.AUTO_DETECT
         with cache_dir() as workspace:
             src = (workspace / "test.py").resolve()
             with src.open("w") as fobj:
                 fobj.write("print('hello')")
         with cache_dir() as workspace:
             src = (workspace / "test.py").resolve()
             with src.open("w") as fobj:
                 fobj.write("print('hello')")
-            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
+            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
             result = CliRunner().invoke(black.main, [str(src)])
             self.assertEqual(result.exit_code, 0)
             with src.open("r") as fobj:
             result = CliRunner().invoke(black.main, [str(src)])
             self.assertEqual(result.exit_code, 0)
             with src.open("r") as fobj:
@@ -592,6 +595,7 @@ class BlackTestCase(unittest.TestCase):
 
     @event_loop(close=False)
     def test_cache_multiple_files(self) -> None:
 
     @event_loop(close=False)
     def test_cache_multiple_files(self) -> None:
+        mode = black.FileMode.AUTO_DETECT
         with cache_dir() as workspace, patch(
             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
         ):
         with cache_dir() as workspace, patch(
             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
         ):
@@ -601,44 +605,48 @@ class BlackTestCase(unittest.TestCase):
             two = (workspace / "two.py").resolve()
             with two.open("w") as fobj:
                 fobj.write("print('hello')")
             two = (workspace / "two.py").resolve()
             with two.open("w") as fobj:
                 fobj.write("print('hello')")
-            black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH)
+            black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
             result = CliRunner().invoke(black.main, [str(workspace)])
             self.assertEqual(result.exit_code, 0)
             with one.open("r") as fobj:
                 self.assertEqual(fobj.read(), "print('hello')")
             with two.open("r") as fobj:
                 self.assertEqual(fobj.read(), 'print("hello")\n')
             result = CliRunner().invoke(black.main, [str(workspace)])
             self.assertEqual(result.exit_code, 0)
             with one.open("r") as fobj:
                 self.assertEqual(fobj.read(), "print('hello')")
             with two.open("r") as fobj:
                 self.assertEqual(fobj.read(), 'print("hello")\n')
-            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
+            cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
             self.assertIn(one, cache)
             self.assertIn(two, cache)
 
     def test_no_cache_when_writeback_diff(self) -> None:
             self.assertIn(one, cache)
             self.assertIn(two, cache)
 
     def test_no_cache_when_writeback_diff(self) -> None:
+        mode = black.FileMode.AUTO_DETECT
         with cache_dir() as workspace:
             src = (workspace / "test.py").resolve()
             with src.open("w") as fobj:
                 fobj.write("print('hello')")
             result = CliRunner().invoke(black.main, [str(src), "--diff"])
             self.assertEqual(result.exit_code, 0)
         with cache_dir() as workspace:
             src = (workspace / "test.py").resolve()
             with src.open("w") as fobj:
                 fobj.write("print('hello')")
             result = CliRunner().invoke(black.main, [str(src), "--diff"])
             self.assertEqual(result.exit_code, 0)
-            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
+            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
             self.assertFalse(cache_file.exists())
 
     def test_no_cache_when_stdin(self) -> None:
             self.assertFalse(cache_file.exists())
 
     def test_no_cache_when_stdin(self) -> None:
+        mode = black.FileMode.AUTO_DETECT
         with cache_dir():
             result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
             self.assertEqual(result.exit_code, 0)
         with cache_dir():
             result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
             self.assertEqual(result.exit_code, 0)
-            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH)
+            cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
             self.assertFalse(cache_file.exists())
 
     def test_read_cache_no_cachefile(self) -> None:
             self.assertFalse(cache_file.exists())
 
     def test_read_cache_no_cachefile(self) -> None:
+        mode = black.FileMode.AUTO_DETECT
         with cache_dir():
         with cache_dir():
-            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH), {})
+            self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
 
     def test_write_cache_read_cache(self) -> None:
 
     def test_write_cache_read_cache(self) -> None:
+        mode = black.FileMode.AUTO_DETECT
         with cache_dir() as workspace:
             src = (workspace / "test.py").resolve()
             src.touch()
         with cache_dir() as workspace:
             src = (workspace / "test.py").resolve()
             src.touch()
-            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH)
-            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
+            black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
+            cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
             self.assertIn(src, cache)
             self.assertEqual(cache[src], black.get_cache_info(src))
 
             self.assertIn(src, cache)
             self.assertEqual(cache[src], black.get_cache_info(src))
 
@@ -659,13 +667,15 @@ class BlackTestCase(unittest.TestCase):
             self.assertEqual(done, [cached])
 
     def test_write_cache_creates_directory_if_needed(self) -> None:
             self.assertEqual(done, [cached])
 
     def test_write_cache_creates_directory_if_needed(self) -> None:
+        mode = black.FileMode.AUTO_DETECT
         with cache_dir(exists=False) as workspace:
             self.assertFalse(workspace.exists())
         with cache_dir(exists=False) as workspace:
             self.assertFalse(workspace.exists())
-            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
+            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
             self.assertTrue(workspace.exists())
 
     @event_loop(close=False)
     def test_failed_formatting_does_not_get_cached(self) -> None:
             self.assertTrue(workspace.exists())
 
     @event_loop(close=False)
     def test_failed_formatting_does_not_get_cached(self) -> None:
+        mode = black.FileMode.AUTO_DETECT
         with cache_dir() as workspace, patch(
             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
         ):
         with cache_dir() as workspace, patch(
             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
         ):
@@ -677,14 +687,15 @@ class BlackTestCase(unittest.TestCase):
                 fobj.write('print("hello")\n')
             result = CliRunner().invoke(black.main, [str(workspace)])
             self.assertEqual(result.exit_code, 123)
                 fobj.write('print("hello")\n')
             result = CliRunner().invoke(black.main, [str(workspace)])
             self.assertEqual(result.exit_code, 123)
-            cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
+            cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
             self.assertNotIn(failing, cache)
             self.assertIn(clean, cache)
 
     def test_write_cache_write_fail(self) -> None:
             self.assertNotIn(failing, cache)
             self.assertIn(clean, cache)
 
     def test_write_cache_write_fail(self) -> None:
+        mode = black.FileMode.AUTO_DETECT
         with cache_dir(), patch.object(Path, "open") as mock:
             mock.side_effect = OSError
         with cache_dir(), patch.object(Path, "open") as mock:
             mock.side_effect = OSError
-            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH)
+            black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
 
     @event_loop(close=False)
     def test_check_diff_use_together(self) -> None:
 
     @event_loop(close=False)
     def test_check_diff_use_together(self) -> None:
@@ -719,16 +730,19 @@ class BlackTestCase(unittest.TestCase):
             self.assertEqual(result.exit_code, 0)
 
     def test_read_cache_line_lengths(self) -> None:
             self.assertEqual(result.exit_code, 0)
 
     def test_read_cache_line_lengths(self) -> None:
+        mode = black.FileMode.AUTO_DETECT
         with cache_dir() as workspace:
             path = (workspace / "file.py").resolve()
             path.touch()
         with cache_dir() as workspace:
             path = (workspace / "file.py").resolve()
             path.touch()
-            black.write_cache({}, [path], 1)
-            one = black.read_cache(1)
+            black.write_cache({}, [path], 1, mode)
+            one = black.read_cache(1, mode)
             self.assertIn(path, one)
             self.assertIn(path, one)
-            two = black.read_cache(2)
+            two = black.read_cache(2, mode)
             self.assertNotIn(path, two)
 
     def test_single_file_force_pyi(self) -> None:
             self.assertNotIn(path, two)
 
     def test_single_file_force_pyi(self) -> None:
+        reg_mode = black.FileMode.AUTO_DETECT
+        pyi_mode = black.FileMode.PYI
         contents, expected = read_data("force_pyi")
         with cache_dir() as workspace:
             path = (workspace / "file.py").resolve()
         contents, expected = read_data("force_pyi")
         with cache_dir() as workspace:
             path = (workspace / "file.py").resolve()
@@ -739,14 +753,16 @@ class BlackTestCase(unittest.TestCase):
             with open(path, "r") as fh:
                 actual = fh.read()
             # verify cache with --pyi is separate
             with open(path, "r") as fh:
                 actual = fh.read()
             # verify cache with --pyi is separate
-            pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi=True)
+            pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
             self.assertIn(path, pyi_cache)
             self.assertIn(path, pyi_cache)
-            normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
+            normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
             self.assertNotIn(path, normal_cache)
         self.assertEqual(actual, expected)
 
     @event_loop(close=False)
     def test_multi_file_force_pyi(self) -> None:
             self.assertNotIn(path, normal_cache)
         self.assertEqual(actual, expected)
 
     @event_loop(close=False)
     def test_multi_file_force_pyi(self) -> None:
+        reg_mode = black.FileMode.AUTO_DETECT
+        pyi_mode = black.FileMode.PYI
         contents, expected = read_data("force_pyi")
         with cache_dir() as workspace:
             paths = [
         contents, expected = read_data("force_pyi")
         with cache_dir() as workspace:
             paths = [
@@ -763,8 +779,8 @@ class BlackTestCase(unittest.TestCase):
                     actual = fh.read()
                 self.assertEqual(actual, expected)
             # verify cache with --pyi is separate
                     actual = fh.read()
                 self.assertEqual(actual, expected)
             # verify cache with --pyi is separate
-            pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi=True)
-            normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
+            pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
+            normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
             for path in paths:
                 self.assertIn(path, pyi_cache)
                 self.assertNotIn(path, normal_cache)
             for path in paths:
                 self.assertIn(path, pyi_cache)
                 self.assertNotIn(path, normal_cache)
@@ -777,6 +793,8 @@ class BlackTestCase(unittest.TestCase):
         self.assertFormatEqual(actual, expected)
 
     def test_single_file_force_py36(self) -> None:
         self.assertFormatEqual(actual, expected)
 
     def test_single_file_force_py36(self) -> None:
+        reg_mode = black.FileMode.AUTO_DETECT
+        py36_mode = black.FileMode.PYTHON36
         source, expected = read_data("force_py36")
         with cache_dir() as workspace:
             path = (workspace / "file.py").resolve()
         source, expected = read_data("force_py36")
         with cache_dir() as workspace:
             path = (workspace / "file.py").resolve()
@@ -787,14 +805,16 @@ class BlackTestCase(unittest.TestCase):
             with open(path, "r") as fh:
                 actual = fh.read()
             # verify cache with --py36 is separate
             with open(path, "r") as fh:
                 actual = fh.read()
             # verify cache with --py36 is separate
-            py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36=True)
+            py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
             self.assertIn(path, py36_cache)
             self.assertIn(path, py36_cache)
-            normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
+            normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
             self.assertNotIn(path, normal_cache)
         self.assertEqual(actual, expected)
 
     @event_loop(close=False)
     def test_multi_file_force_py36(self) -> None:
             self.assertNotIn(path, normal_cache)
         self.assertEqual(actual, expected)
 
     @event_loop(close=False)
     def test_multi_file_force_py36(self) -> None:
+        reg_mode = black.FileMode.AUTO_DETECT
+        py36_mode = black.FileMode.PYTHON36
         source, expected = read_data("force_py36")
         with cache_dir() as workspace:
             paths = [
         source, expected = read_data("force_py36")
         with cache_dir() as workspace:
             paths = [
@@ -813,8 +833,8 @@ class BlackTestCase(unittest.TestCase):
                     actual = fh.read()
                 self.assertEqual(actual, expected)
             # verify cache with --py36 is separate
                     actual = fh.read()
                 self.assertEqual(actual, expected)
             # verify cache with --py36 is separate
-            pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36=True)
-            normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH)
+            pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
+            normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
             for path in paths:
                 self.assertIn(path, pyi_cache)
                 self.assertNotIn(path, normal_cache)
             for path in paths:
                 self.assertIn(path, pyi_cache)
                 self.assertNotIn(path, normal_cache)