git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Mention fix for #632 in the change log
[etc/vim.git] / black.py
index 57d5c8377a6a7357fdbb44cca369f6c3ba51b19e..31e629b90f9b84b3dbea1747a618b9ba04ac5620 100644 (file)
--- a/black.py
+++ b/black.py
@@ -2,7 +2,7 @@ import asyncio
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 from datetime import datetime
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 from datetime import datetime
-from enum import Enum, Flag
+from enum import Enum
 from functools import lru_cache, partial, wraps
 import io
 import itertools
 from functools import lru_cache, partial, wraps
 import io
 import itertools
@@ -14,6 +14,7 @@ import pickle
 import re
 import signal
 import sys
 import re
 import signal
 import sys
+import tempfile
 import tokenize
 from typing import (
     Any,
 import tokenize
 from typing import (
     Any,
@@ -36,7 +37,7 @@ from typing import (
 )
 
 from appdirs import user_cache_dir
 )
 
 from appdirs import user_cache_dir
-from attr import dataclass, Factory
+from attr import dataclass, evolve, Factory
 import click
 import toml
 
 import click
 import toml
 
@@ -44,6 +45,7 @@ import toml
 from blib2to3.pytree import Node, Leaf, type_repr
 from blib2to3 import pygram, pytree
 from blib2to3.pgen2 import driver, token
 from blib2to3.pytree import Node, Leaf, type_repr
 from blib2to3 import pygram, pytree
 from blib2to3.pgen2 import driver, token
+from blib2to3.pgen2.grammar import Grammar
 from blib2to3.pgen2.parse import ParseError
 
 
 from blib2to3.pgen2.parse import ParseError
 
 
@@ -110,32 +112,82 @@ class Changed(Enum):
     YES = 2
 
 
     YES = 2
 
 
-class FileMode(Flag):
-    AUTO_DETECT = 0
-    PYTHON36 = 1
-    PYI = 2
-    NO_STRING_NORMALIZATION = 4
-    NO_NUMERIC_UNDERSCORE_NORMALIZATION = 8
+class TargetVersion(Enum):
+    PY27 = 2
+    PY33 = 3
+    PY34 = 4
+    PY35 = 5
+    PY36 = 6
+    PY37 = 7
+    PY38 = 8
+
+    def is_python2(self) -> bool:
+        return self is TargetVersion.PY27
+
+
+PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
+
+
+class Feature(Enum):
+    # All string literals are unicode
+    UNICODE_LITERALS = 1
+    F_STRINGS = 2
+    NUMERIC_UNDERSCORES = 3
+    TRAILING_COMMA = 4
+
+
+VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
+    TargetVersion.PY27: set(),
+    TargetVersion.PY33: {Feature.UNICODE_LITERALS},
+    TargetVersion.PY34: {Feature.UNICODE_LITERALS},
+    TargetVersion.PY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
+    TargetVersion.PY36: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA,
+    },
+    TargetVersion.PY37: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA,
+    },
+    TargetVersion.PY38: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA,
+    },
+}
 
 
-    @classmethod
-    def from_configuration(
-        cls,
-        *,
-        py36: bool,
-        pyi: bool,
-        skip_string_normalization: bool,
-        skip_numeric_underscore_normalization: bool,
-    ) -> "FileMode":
-        mode = cls.AUTO_DETECT
-        if py36:
-            mode |= cls.PYTHON36
-        if pyi:
-            mode |= cls.PYI
-        if skip_string_normalization:
-            mode |= cls.NO_STRING_NORMALIZATION
-        if skip_numeric_underscore_normalization:
-            mode |= cls.NO_NUMERIC_UNDERSCORE_NORMALIZATION
-        return mode
+
+@dataclass
+class FileMode:
+    target_versions: Set[TargetVersion] = Factory(set)
+    line_length: int = DEFAULT_LINE_LENGTH
+    string_normalization: bool = True
+    is_pyi: bool = False
+
+    def get_cache_key(self) -> str:
+        if self.target_versions:
+            version_str = ",".join(
+                str(version.value)
+                for version in sorted(self.target_versions, key=lambda v: v.value)
+            )
+        else:
+            version_str = "-"
+        parts = [
+            version_str,
+            str(self.line_length),
+            str(int(self.string_normalization)),
+            str(int(self.is_pyi)),
+        ]
+        return ".".join(parts)
+
+
+def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
+    return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
 
 
 def read_pyproject_toml(
 
 
 def read_pyproject_toml(
@@ -184,12 +236,14 @@ def read_pyproject_toml(
     show_default=True,
 )
 @click.option(
     show_default=True,
 )
 @click.option(
-    "--py36",
-    is_flag=True,
+    "-t",
+    "--target-version",
+    type=click.Choice([v.name.lower() for v in TargetVersion]),
+    callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
+    multiple=True,
     help=(
     help=(
-        "Allow using Python 3.6-only syntax on all input files.  This will put "
-        "trailing commas in function signatures and calls also after *args and "
-        "**kwargs.  [default: per-file auto-detection]"
+        "Python versions that should be supported by Black's output. [default: "
+        "per-file auto-detection]"
     ),
 )
 @click.option(
     ),
 )
 @click.option(
@@ -206,12 +260,6 @@ def read_pyproject_toml(
     is_flag=True,
     help="Don't normalize string quotes or prefixes.",
 )
     is_flag=True,
     help="Don't normalize string quotes or prefixes.",
 )
-@click.option(
-    "-N",
-    "--skip-numeric-underscore-normalization",
-    is_flag=True,
-    help="Don't normalize underscores in numeric literals.",
-)
 @click.option(
     "--check",
     is_flag=True,
 @click.option(
     "--check",
     is_flag=True,
@@ -296,13 +344,12 @@ def read_pyproject_toml(
 def main(
     ctx: click.Context,
     line_length: int,
 def main(
     ctx: click.Context,
     line_length: int,
+    target_version: List[TargetVersion],
     check: bool,
     diff: bool,
     fast: bool,
     pyi: bool,
     check: bool,
     diff: bool,
     fast: bool,
     pyi: bool,
-    py36: bool,
     skip_string_normalization: bool,
     skip_string_normalization: bool,
-    skip_numeric_underscore_normalization: bool,
     quiet: bool,
     verbose: bool,
     include: str,
     quiet: bool,
     verbose: bool,
     include: str,
@@ -312,11 +359,16 @@ def main(
 ) -> None:
     """The uncompromising code formatter."""
     write_back = WriteBack.from_configuration(check=check, diff=diff)
 ) -> None:
     """The uncompromising code formatter."""
     write_back = WriteBack.from_configuration(check=check, diff=diff)
-    mode = FileMode.from_configuration(
-        py36=py36,
-        pyi=pyi,
-        skip_string_normalization=skip_string_normalization,
-        skip_numeric_underscore_normalization=skip_numeric_underscore_normalization,
+    if target_version:
+        versions = set(target_version)
+    else:
+        # We'll autodetect later.
+        versions = set()
+    mode = FileMode(
+        target_versions=versions,
+        line_length=line_length,
+        is_pyi=pyi,
+        string_normalization=not skip_string_normalization,
     )
     if config and verbose:
         out(f"Using configuration from {config}.", bold=False, fg="blue")
     )
     if config and verbose:
         out(f"Using configuration from {config}.", bold=False, fg="blue")
@@ -352,7 +404,6 @@ def main(
     if len(sources) == 1:
         reformat_one(
             src=sources.pop(),
     if len(sources) == 1:
         reformat_one(
             src=sources.pop(),
-            line_length=line_length,
             fast=fast,
             write_back=write_back,
             mode=mode,
             fast=fast,
             write_back=write_back,
             mode=mode,
@@ -365,7 +416,6 @@ def main(
             loop.run_until_complete(
                 schedule_formatting(
                     sources=sources,
             loop.run_until_complete(
                 schedule_formatting(
                     sources=sources,
-                    line_length=line_length,
                     fast=fast,
                     write_back=write_back,
                     mode=mode,
                     fast=fast,
                     write_back=write_back,
                     mode=mode,
@@ -384,12 +434,7 @@ def main(
 
 
 def reformat_one(
 
 
 def reformat_one(
-    src: Path,
-    line_length: int,
-    fast: bool,
-    write_back: WriteBack,
-    mode: FileMode,
-    report: "Report",
+    src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
@@ -400,29 +445,23 @@ def reformat_one(
     try:
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
     try:
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
-            if format_stdin_to_stdout(
-                line_length=line_length, fast=fast, write_back=write_back, mode=mode
-            ):
+            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
-                cache = read_cache(line_length, mode)
+                cache = read_cache(mode)
                 res_src = src.resolve()
                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
                 res_src = src.resolve()
                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
-                src,
-                line_length=line_length,
-                fast=fast,
-                write_back=write_back,
-                mode=mode,
+                src, fast=fast, write_back=write_back, mode=mode
             ):
                 changed = Changed.YES
             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
                 write_back is WriteBack.CHECK and changed is Changed.NO
             ):
             ):
                 changed = Changed.YES
             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
                 write_back is WriteBack.CHECK and changed is Changed.NO
             ):
-                write_cache(cache, [src], line_length, mode)
+                write_cache(cache, [src], mode)
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
@@ -430,7 +469,6 @@ def reformat_one(
 
 async def schedule_formatting(
     sources: Set[Path],
 
 async def schedule_formatting(
     sources: Set[Path],
-    line_length: int,
     fast: bool,
     write_back: WriteBack,
     mode: FileMode,
     fast: bool,
     write_back: WriteBack,
     mode: FileMode,
@@ -447,7 +485,7 @@ async def schedule_formatting(
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
-        cache = read_cache(line_length, mode)
+        cache = read_cache(mode)
         sources, cached = filter_cached(cache, sources)
         for src in sorted(cached):
             report.done(src, Changed.CACHED)
         sources, cached = filter_cached(cache, sources)
         for src in sorted(cached):
             report.done(src, Changed.CACHED)
@@ -464,14 +502,7 @@ async def schedule_formatting(
         lock = manager.Lock()
     tasks = {
         loop.run_in_executor(
         lock = manager.Lock()
     tasks = {
         loop.run_in_executor(
-            executor,
-            format_file_in_place,
-            src,
-            line_length,
-            fast,
-            write_back,
-            mode,
-            lock,
+            executor, format_file_in_place, src, fast, mode, write_back, lock
         ): src
         for src in sorted(sources)
     }
         ): src
         for src in sorted(sources)
     }
@@ -502,15 +533,14 @@ async def schedule_formatting(
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if sources_to_cache:
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if sources_to_cache:
-        write_cache(cache, sources_to_cache, line_length, mode)
+        write_cache(cache, sources_to_cache, mode)
 
 
 def format_file_in_place(
     src: Path,
 
 
 def format_file_in_place(
     src: Path,
-    line_length: int,
     fast: bool,
     fast: bool,
+    mode: FileMode,
     write_back: WriteBack = WriteBack.NO,
     write_back: WriteBack = WriteBack.NO,
-    mode: FileMode = FileMode.AUTO_DETECT,
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
@@ -520,15 +550,13 @@ def format_file_in_place(
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
-        mode |= FileMode.PYI
+        mode = evolve(mode, is_pyi=True)
 
     then = datetime.utcfromtimestamp(src.stat().st_mtime)
     with open(src, "rb") as buf:
         src_contents, encoding, newline = decode_bytes(buf.read())
     try:
 
     then = datetime.utcfromtimestamp(src.stat().st_mtime)
     with open(src, "rb") as buf:
         src_contents, encoding, newline = decode_bytes(buf.read())
     try:
-        dst_contents = format_file_contents(
-            src_contents, line_length=line_length, fast=fast, mode=mode
-        )
+        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
     except NothingChanged:
         return False
 
     except NothingChanged:
         return False
 
@@ -558,23 +586,19 @@ def format_file_in_place(
 
 
 def format_stdin_to_stdout(
 
 
 def format_stdin_to_stdout(
-    line_length: int,
-    fast: bool,
-    write_back: WriteBack = WriteBack.NO,
-    mode: FileMode = FileMode.AUTO_DETECT,
+    fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
-    write a diff to stdout.
-    `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
+    write a diff to stdout. The `mode` argument is passed to
     :func:`format_file_contents`.
     """
     then = datetime.utcnow()
     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
     dst = src
     try:
     :func:`format_file_contents`.
     """
     then = datetime.utcnow()
     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
     dst = src
     try:
-        dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
+        dst = format_file_contents(src, fast=fast, mode=mode)
         return True
 
     except NothingChanged:
         return True
 
     except NothingChanged:
@@ -595,11 +619,7 @@ def format_stdin_to_stdout(
 
 
 def format_file_contents(
 
 
 def format_file_contents(
-    src_contents: str,
-    *,
-    line_length: int,
-    fast: bool,
-    mode: FileMode = FileMode.AUTO_DETECT,
+    src_contents: str, *, fast: bool, mode: FileMode
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
@@ -610,38 +630,36 @@ def format_file_contents(
     if src_contents.strip() == "":
         raise NothingChanged
 
     if src_contents.strip() == "":
         raise NothingChanged
 
-    dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
+    dst_contents = format_str(src_contents, mode=mode)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
-        assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
+        assert_stable(src_contents, dst_contents, mode=mode)
     return dst_contents
 
 
     return dst_contents
 
 
-def format_str(
-    src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
-) -> FileContent:
+def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
     """Reformat a string and return new contents.
 
     `line_length` determines how many characters per line are allowed.
     """
     """Reformat a string and return new contents.
 
     `line_length` determines how many characters per line are allowed.
     """
-    src_node = lib2to3_parse(src_contents.lstrip())
+    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
     dst_contents = ""
     future_imports = get_future_imports(src_node)
     dst_contents = ""
     future_imports = get_future_imports(src_node)
-    is_pyi = bool(mode & FileMode.PYI)
-    py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
-    normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
+    if mode.target_versions:
+        versions = mode.target_versions
+    else:
+        versions = detect_target_versions(src_node)
     normalize_fmt_off(src_node)
     lines = LineGenerator(
     normalize_fmt_off(src_node)
     lines = LineGenerator(
-        remove_u_prefix=py36 or "unicode_literals" in future_imports,
-        is_pyi=is_pyi,
-        normalize_strings=normalize_strings,
-        allow_underscores=py36
-        and not bool(mode & FileMode.NO_NUMERIC_UNDERSCORE_NORMALIZATION),
+        remove_u_prefix="unicode_literals" in future_imports
+        or supports_feature(versions, Feature.UNICODE_LITERALS),
+        is_pyi=mode.is_pyi,
+        normalize_strings=mode.string_normalization,
     )
     )
-    elt = EmptyLineTracker(is_pyi=is_pyi)
+    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
@@ -650,7 +668,11 @@ def format_str(
         before, after = elt.maybe_empty_lines(current_line)
         for _ in range(before):
             dst_contents += str(empty_line)
         before, after = elt.maybe_empty_lines(current_line)
         for _ in range(before):
             dst_contents += str(empty_line)
-        for line in split_line(current_line, line_length=line_length, py36=py36):
+        for line in split_line(
+            current_line,
+            line_length=mode.line_length,
+            supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
+        ):
             dst_contents += str(line)
     return dst_contents
 
             dst_contents += str(line)
     return dst_contents
 
@@ -679,11 +701,25 @@ GRAMMARS = [
 ]
 
 
 ]
 
 
-def lib2to3_parse(src_txt: str) -> Node:
+def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
+    if not target_versions:
+        return GRAMMARS
+    elif all(not version.is_python2() for version in target_versions):
+        # No Python 2 targets, so only try the grammars without print statements.
+        return [
+            pygram.python_grammar_no_print_statement_no_exec_statement,
+            pygram.python_grammar_no_print_statement,
+        ]
+    else:
+        return [pygram.python_grammar]
+
+
+def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     if src_txt[-1:] != "\n":
         src_txt += "\n"
     """Given a string with source, return the lib2to3 Node."""
     if src_txt[-1:] != "\n":
         src_txt += "\n"
-    for grammar in GRAMMARS:
+
+    for grammar in get_grammars(set(target_versions)):
         drv = driver.Driver(grammar, pytree.convert)
         try:
             result = drv.parse_string(src_txt, True)
         drv = driver.Driver(grammar, pytree.convert)
         try:
             result = drv.parse_string(src_txt, True)
@@ -1025,9 +1061,7 @@ class Line:
 
     depth: int = 0
     leaves: List[Leaf] = Factory(list)
 
     depth: int = 0
     leaves: List[Leaf] = Factory(list)
-    # The LeafID keys of comments must remain ordered by the corresponding leaf's index
-    # in leaves
-    comments: Dict[LeafID, List[Leaf]] = Factory(dict)
+    comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
     bracket_tracker: BracketTracker = Factory(BracketTracker)
     inside_brackets: bool = False
     should_explode: bool = False
     bracket_tracker: BracketTracker = Factory(BracketTracker)
     inside_brackets: bool = False
     should_explode: bool = False
@@ -1158,6 +1192,29 @@ class Line:
             if leaf.type == STANDALONE_COMMENT:
                 if leaf.bracket_depth <= depth_limit:
                     return True
             if leaf.type == STANDALONE_COMMENT:
                 if leaf.bracket_depth <= depth_limit:
                     return True
+        return False
+
+    def contains_inner_type_comments(self) -> bool:
+        ignored_ids = set()
+        try:
+            last_leaf = self.leaves[-1]
+            ignored_ids.add(id(last_leaf))
+            if last_leaf.type == token.COMMA:
+                # When trailing commas are inserted by Black for consistency, comments
+                # after the previous last element are not moved (they don't have to,
+                # rendering will still be correct).  So we ignore trailing commas.
+                last_leaf = self.leaves[-2]
+                ignored_ids.add(id(last_leaf))
+        except IndexError:
+            return False
+
+        for leaf_id, comments in self.comments.items():
+            if leaf_id in ignored_ids:
+                continue
+
+            for comment in comments:
+                if is_type_comment(comment):
+                    return True
 
         return False
 
 
         return False
 
@@ -1239,13 +1296,8 @@ class Line:
             comment.prefix = ""
             return False
 
             comment.prefix = ""
             return False
 
-        else:
-            leaf_id = id(self.leaves[-1])
-            if leaf_id not in self.comments:
-                self.comments[leaf_id] = [comment]
-            else:
-                self.comments[leaf_id].append(comment)
-            return True
+        self.comments.setdefault(id(self.leaves[-1]), []).append(comment)
+        return True
 
     def comments_after(self, leaf: Leaf) -> List[Leaf]:
         """Generate comments that should appear directly after `leaf`."""
 
     def comments_after(self, leaf: Leaf) -> List[Leaf]:
         """Generate comments that should appear directly after `leaf`."""
@@ -1253,17 +1305,11 @@ class Line:
 
     def remove_trailing_comma(self) -> None:
         """Remove the trailing comma and moves the comments attached to it."""
 
     def remove_trailing_comma(self) -> None:
         """Remove the trailing comma and moves the comments attached to it."""
-        # Remember, the LeafID keys of self.comments are ordered by the
-        # corresponding leaf's index in self.leaves
-        # If id(self.leaves[-2]) is in self.comments, the order doesn't change.
-        # Otherwise, we insert it into self.comments, and it becomes the last entry.
-        # However, since we delete id(self.leaves[-1]) from self.comments, the invariant
-        # is maintained
-        self.comments.setdefault(id(self.leaves[-2]), []).extend(
-            self.comments.get(id(self.leaves[-1]), [])
+        trailing_comma = self.leaves.pop()
+        trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
+        self.comments.setdefault(id(self.leaves[-1]), []).extend(
+            trailing_comma_comments
         )
         )
-        self.comments.pop(id(self.leaves[-1]), None)
-        self.leaves.pop()
 
     def is_complex_subscript(self, leaf: Leaf) -> bool:
         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
 
     def is_complex_subscript(self, leaf: Leaf) -> bool:
         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
@@ -1426,7 +1472,6 @@ class LineGenerator(Visitor[Line]):
     normalize_strings: bool = True
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
     normalize_strings: bool = True
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
-    allow_underscores: bool = False
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
@@ -1469,7 +1514,7 @@ class LineGenerator(Visitor[Line]):
                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                 normalize_string_quotes(node)
             if node.type == token.NUMBER:
                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                 normalize_string_quotes(node)
             if node.type == token.NUMBER:
-                normalize_numeric_literal(node, self.allow_underscores)
+                normalize_numeric_literal(node)
             if node.type not in WHITESPACE:
                 self.current_line.append(node)
         yield from super().visit_default(node)
             if node.type not in WHITESPACE:
                 self.current_line.append(node)
         yield from super().visit_default(node)
@@ -1593,6 +1638,7 @@ class LineGenerator(Visitor[Line]):
         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
+        self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
@@ -1605,7 +1651,7 @@ BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
 
 
 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
 
 
-def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
+def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
     """Return whitespace prefix if needed for the given `leaf`.
 
     `complex_subscript` signals whether the given leaf is part of a subscription
     """Return whitespace prefix if needed for the given `leaf`.
 
     `complex_subscript` signals whether the given leaf is part of a subscription
@@ -2092,7 +2138,10 @@ def make_comment(content: str) -> str:
 
 
 def split_line(
 
 
 def split_line(
-    line: Line, line_length: int, inner: bool = False, py36: bool = False
+    line: Line,
+    line_length: int,
+    inner: bool = False,
+    supports_trailing_commas: bool = False,
 ) -> Iterator[Line]:
     """Split a `line` into potentially many lines.
 
 ) -> Iterator[Line]:
     """Split a `line` into potentially many lines.
 
@@ -2101,8 +2150,7 @@ def split_line(
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
 
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
 
-    If `py36` is True, splitting may generate syntax that is only compatible
-    with Python 3.6 and later.
+    If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
     """
     if line.is_comment:
         yield line
     """
     if line.is_comment:
         yield line
@@ -2110,16 +2158,8 @@ def split_line(
 
     line_str = str(line).strip("\n")
 
 
     line_str = str(line).strip("\n")
 
-    # we don't want to split special comments like type annotations
-    # https://github.com/python/typing/issues/186
-    has_special_comment = False
-    for leaf in line.leaves:
-        for comment in line.comments_after(leaf):
-            if leaf.type == token.COMMA and is_special_comment(comment):
-                has_special_comment = True
-
     if (
     if (
-        not has_special_comment
+        not line.contains_inner_type_comments()
         and not line.should_explode
         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
     ):
         and not line.should_explode
         and is_line_short_enough(line, line_length=line_length, line_str=line_str)
     ):
@@ -2131,9 +2171,13 @@ def split_line(
         split_funcs = [left_hand_split]
     else:
 
         split_funcs = [left_hand_split]
     else:
 
-        def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
+        def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
             for omit in generate_trailers_to_omit(line, line_length):
             for omit in generate_trailers_to_omit(line, line_length):
-                lines = list(right_hand_split(line, line_length, py36, omit=omit))
+                lines = list(
+                    right_hand_split(
+                        line, line_length, supports_trailing_commas, omit=omit
+                    )
+                )
                 if is_line_short_enough(lines[0], line_length=line_length):
                     yield from lines
                     return
                 if is_line_short_enough(lines[0], line_length=line_length):
                     yield from lines
                     return
@@ -2141,7 +2185,7 @@ def split_line(
             # All splits failed, best effort split with no omits.
             # This mostly happens to multiline strings that are by definition
             # reported as not fitting a single line.
             # All splits failed, best effort split with no omits.
             # This mostly happens to multiline strings that are by definition
             # reported as not fitting a single line.
-            yield from right_hand_split(line, py36)
+            yield from right_hand_split(line, line_length, supports_trailing_commas)
 
         if line.inside_brackets:
             split_funcs = [delimiter_split, standalone_comment_split, rhs]
 
         if line.inside_brackets:
             split_funcs = [delimiter_split, standalone_comment_split, rhs]
@@ -2153,12 +2197,17 @@ def split_line(
         # split altogether.
         result: List[Line] = []
         try:
         # split altogether.
         result: List[Line] = []
         try:
-            for l in split_func(line, py36):
+            for l in split_func(line, supports_trailing_commas):
                 if str(l).strip("\n") == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
                 result.extend(
                 if str(l).strip("\n") == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
                 result.extend(
-                    split_line(l, line_length=line_length, inner=True, py36=py36)
+                    split_line(
+                        l,
+                        line_length=line_length,
+                        inner=True,
+                        supports_trailing_commas=supports_trailing_commas,
+                    )
                 )
         except CannotSplit:
             continue
                 )
         except CannotSplit:
             continue
@@ -2171,7 +2220,9 @@ def split_line(
         yield line
 
 
         yield line
 
 
-def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def left_hand_split(
+    line: Line, supports_trailing_commas: bool = False
+) -> Iterator[Line]:
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
@@ -2208,7 +2259,10 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
 
 def right_hand_split(
 
 
 def right_hand_split(
-    line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
+    line: Line,
+    line_length: int,
+    supports_trailing_commas: bool = False,
+    omit: Collection[LeafID] = (),
 ) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair.
 
 ) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair.
 
@@ -2266,7 +2320,12 @@ def right_hand_split(
     ):
         omit = {id(closing_bracket), *omit}
         try:
     ):
         omit = {id(closing_bracket), *omit}
         try:
-            yield from right_hand_split(line, line_length, py36=py36, omit=omit)
+            yield from right_hand_split(
+                line,
+                line_length,
+                supports_trailing_commas=supports_trailing_commas,
+                omit=omit,
+            )
             return
 
         except CannotSplit:
             return
 
         except CannotSplit:
@@ -2355,8 +2414,10 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
     """
 
     @wraps(split_func)
     """
 
     @wraps(split_func)
-    def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
-        for l in split_func(line, py36):
+    def split_wrapper(
+        line: Line, supports_trailing_commas: bool = False
+    ) -> Iterator[Line]:
+        for l in split_func(line, supports_trailing_commas):
             normalize_prefix(l.leaves[0], inside_brackets=True)
             yield l
 
             normalize_prefix(l.leaves[0], inside_brackets=True)
             yield l
 
@@ -2364,11 +2425,13 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
 
 
 @dont_increase_indentation
 
 
 @dont_increase_indentation
-def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def delimiter_split(
+    line: Line, supports_trailing_commas: bool = False
+) -> Iterator[Line]:
     """Split according to delimiters of the highest priority.
 
     """Split according to delimiters of the highest priority.
 
-    If `py36` is True, the split will add trailing commas also in function
-    signatures that contain `*` and `**`.
+    If `supports_trailing_commas` is True, the split will add trailing commas
+    also in function signatures that contain `*` and `**`.
     """
     try:
         last_leaf = line.leaves[-1]
     """
     try:
         last_leaf = line.leaves[-1]
@@ -2410,7 +2473,7 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
         if leaf.bracket_depth == lowest_depth and is_vararg(
             leaf, within=VARARGS_PARENTS
         ):
         if leaf.bracket_depth == lowest_depth and is_vararg(
             leaf, within=VARARGS_PARENTS
         ):
-            trailing_comma_safe = trailing_comma_safe and py36
+            trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
         leaf_priority = bt.delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             yield current_line
         leaf_priority = bt.delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             yield current_line
@@ -2428,7 +2491,9 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
 
 @dont_increase_indentation
 
 
 @dont_increase_indentation
-def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def standalone_comment_split(
+    line: Line, supports_trailing_commas: bool = False
+) -> Iterator[Line]:
     """Split standalone comments from the rest of the line."""
     if not line.contains_standalone_comments(0):
         raise CannotSplit("Line does not have any standalone comments")
     """Split standalone comments from the rest of the line."""
     if not line.contains_standalone_comments(0):
         raise CannotSplit("Line does not have any standalone comments")
@@ -2470,14 +2535,12 @@ def is_import(leaf: Leaf) -> bool:
     )
 
 
     )
 
 
-def is_special_comment(leaf: Leaf) -> bool:
+def is_type_comment(leaf: Leaf) -> bool:
     """Return True if the given leaf is a special comment.
     Only returns true for type comments for now."""
     t = leaf.type
     v = leaf.value
     """Return True if the given leaf is a special comment.
     Only returns true for type comments for now."""
     t = leaf.type
     v = leaf.value
-    return bool(
-        (t == token.COMMENT or t == STANDALONE_COMMENT) and (v.startswith("# type:"))
-    )
+    return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:")
 
 
 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
 
 
 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
@@ -2581,11 +2644,11 @@ def normalize_string_quotes(leaf: Leaf) -> None:
     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
 
 
     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
 
 
-def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
+def normalize_numeric_literal(leaf: Leaf) -> None:
     """Normalizes numeric (float, int, and complex) literals.
 
     All letters used in the representation are normalized to lowercase (except
     """Normalizes numeric (float, int, and complex) literals.
 
     All letters used in the representation are normalized to lowercase (except
-    in Python 2 long literals), and long number literals are split using underscores.
+    in Python 2 long literals).
     """
     text = leaf.value.lower()
     if text.startswith(("0o", "0b")):
     """
     text = leaf.value.lower()
     if text.startswith(("0o", "0b")):
@@ -2603,8 +2666,7 @@ def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
             sign = "-"
         elif after.startswith("+"):
             after = after[1:]
             sign = "-"
         elif after.startswith("+"):
             after = after[1:]
-        before = format_float_or_int_string(before, allow_underscores)
-        after = format_int_string(after, allow_underscores)
+        before = format_float_or_int_string(before)
         text = f"{before}e{sign}{after}"
     elif text.endswith(("j", "l")):
         number = text[:-1]
         text = f"{before}e{sign}{after}"
     elif text.endswith(("j", "l")):
         number = text[:-1]
@@ -2612,50 +2674,19 @@ def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
         # Capitalize in "2L" because "l" looks too similar to "1".
         if suffix == "l":
             suffix = "L"
         # Capitalize in "2L" because "l" looks too similar to "1".
         if suffix == "l":
             suffix = "L"
-        text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
+        text = f"{format_float_or_int_string(number)}{suffix}"
     else:
     else:
-        text = format_float_or_int_string(text, allow_underscores)
+        text = format_float_or_int_string(text)
     leaf.value = text
 
 
     leaf.value = text
 
 
-def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
+def format_float_or_int_string(text: str) -> str:
     """Formats a float string like "1.0"."""
     if "." not in text:
     """Formats a float string like "1.0"."""
     if "." not in text:
-        return format_int_string(text, allow_underscores)
-
-    before, after = text.split(".")
-    before = format_int_string(before, allow_underscores) if before else "0"
-    if after:
-        after = format_int_string(after, allow_underscores, count_from_end=False)
-    else:
-        after = "0"
-    return f"{before}.{after}"
-
-
-def format_int_string(
-    text: str, allow_underscores: bool, count_from_end: bool = True
-) -> str:
-    """Normalizes underscores in a string to e.g. 1_000_000.
-
-    Input must be a string of digits and optional underscores.
-    If count_from_end is False, we add underscores after groups of three digits
-    counting from the beginning instead of the end of the strings. This is used
-    for the fractional part of float literals.
-    """
-    if not allow_underscores:
-        return text
-
-    text = text.replace("_", "")
-    if len(text) <= 5:
-        # No underscores for numbers <= 5 digits long.
         return text
 
         return text
 
-    if count_from_end:
-        # Avoid removing leading zeros, which are important if we're formatting
-        # part of a number like "0.001".
-        return format(int("1" + text), "3_")[1:].lstrip("_")
-    else:
-        return "_".join(text[i : i + 3] for i in range(0, len(text), 3))
+    before, after = text.split(".")
+    return f"{before or 0}.{after or 0}"
 
 
 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
 
 
 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
@@ -2987,23 +3018,24 @@ def should_explode(line: Line, opening_bracket: Leaf) -> bool:
     return max_priority == COMMA_PRIORITY
 
 
     return max_priority == COMMA_PRIORITY
 
 
-def is_python36(node: Node) -> bool:
-    """Return True if the current file is using Python 3.6+ features.
+def get_features_used(node: Node) -> Set[Feature]:
+    """Return a set of (relatively) new Python features used in this file.
 
     Currently looking for:
     - f-strings;
     - underscores in numeric literals; and
     - trailing commas after * or ** in function signatures and calls.
     """
 
     Currently looking for:
     - f-strings;
     - underscores in numeric literals; and
     - trailing commas after * or ** in function signatures and calls.
     """
+    features: Set[Feature] = set()
     for n in node.pre_order():
         if n.type == token.STRING:
             value_head = n.value[:2]  # type: ignore
             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
     for n in node.pre_order():
         if n.type == token.STRING:
             value_head = n.value[:2]  # type: ignore
             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
-                return True
+                features.add(Feature.F_STRINGS)
 
         elif n.type == token.NUMBER:
             if "_" in n.value:  # type: ignore
 
         elif n.type == token.NUMBER:
             if "_" in n.value:  # type: ignore
-                return True
+                features.add(Feature.NUMERIC_UNDERSCORES)
 
         elif (
             n.type in {syms.typedargslist, syms.arglist}
 
         elif (
             n.type in {syms.typedargslist, syms.arglist}
@@ -3012,14 +3044,22 @@ def is_python36(node: Node) -> bool:
         ):
             for ch in n.children:
                 if ch.type in STARS:
         ):
             for ch in n.children:
                 if ch.type in STARS:
-                    return True
+                    features.add(Feature.TRAILING_COMMA)
 
                 if ch.type == syms.argument:
                     for argch in ch.children:
                         if argch.type in STARS:
 
                 if ch.type == syms.argument:
                     for argch in ch.children:
                         if argch.type in STARS:
-                            return True
+                            features.add(Feature.TRAILING_COMMA)
+
+    return features
 
 
-    return False
+
+def detect_target_versions(node: Node) -> Set[TargetVersion]:
+    """Detect the version to target based on the nodes used."""
+    features = get_features_used(node)
+    return {
+        version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
+    }
 
 
 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
 
 
 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
@@ -3293,7 +3333,16 @@ def assert_equivalent(src: str, dst: str) -> None:
 
             if isinstance(value, list):
                 for item in value:
 
             if isinstance(value, list):
                 for item in value:
-                    if isinstance(item, ast.AST):
+                    # Ignore nested tuples within del statements, because we may insert
+                    # parentheses and they change the AST.
+                    if (
+                        field == "targets"
+                        and isinstance(node, ast.Delete)
+                        and isinstance(item, ast.Tuple)
+                    ):
+                        for item in item.elts:
+                            yield from _v(item, depth + 2)
+                    elif isinstance(item, ast.AST):
                         yield from _v(item, depth + 2)
 
             elif isinstance(value, ast.AST):
                         yield from _v(item, depth + 2)
 
             elif isinstance(value, ast.AST):
@@ -3336,11 +3385,9 @@ def assert_equivalent(src: str, dst: str) -> None:
         ) from None
 
 
         ) from None
 
 
-def assert_stable(
-    src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
-) -> None:
+def assert_stable(src: str, dst: str, mode: FileMode) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
     """Raise AssertionError if `dst` reformats differently the second time."""
-    newdst = format_str(dst, line_length=line_length, mode=mode)
+    newdst = format_str(dst, mode=mode)
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
@@ -3389,8 +3436,12 @@ def cancel(tasks: Iterable[asyncio.Task]) -> None:
 def shutdown(loop: BaseEventLoop) -> None:
     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
     try:
 def shutdown(loop: BaseEventLoop) -> None:
     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
     try:
+        if sys.version_info[:2] >= (3, 7):
+            all_tasks = asyncio.all_tasks
+        else:
+            all_tasks = asyncio.Task.all_tasks
         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
-        to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
+        to_cancel = [task for task in all_tasks(loop) if not task.done()]
         if not to_cancel:
             return
 
         if not to_cancel:
             return
 
@@ -3597,16 +3648,16 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     return False
 
 
     return False
 
 
-def get_cache_file(line_length: int, mode: FileMode) -> Path:
-    return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
+def get_cache_file(mode: FileMode) -> Path:
+    return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
 
 
 
 
-def read_cache(line_length: int, mode: FileMode) -> Cache:
+def read_cache(mode: FileMode) -> Cache:
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
-    cache_file = get_cache_file(line_length, mode)
+    cache_file = get_cache_file(mode)
     if not cache_file.exists():
         return {}
 
     if not cache_file.exists():
         return {}
 
@@ -3641,17 +3692,15 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set
     return todo, done
 
 
     return todo, done
 
 
-def write_cache(
-    cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
-) -> None:
+def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
     """Update the cache file."""
     """Update the cache file."""
-    cache_file = get_cache_file(line_length, mode)
+    cache_file = get_cache_file(mode)
     try:
     try:
-        if not CACHE_DIR.exists():
-            CACHE_DIR.mkdir(parents=True)
+        CACHE_DIR.mkdir(parents=True, exist_ok=True)
         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
-        with cache_file.open("wb") as fobj:
-            pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
+        with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
+            pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
+        os.replace(f.name, cache_file)
     except OSError:
         pass
 
     except OSError:
         pass