git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Put cursor in last line if old position is invalid (#641)
[etc/vim.git] / black.py
index 9ac16c02c2980c5e9aa48764f2f8c372ebb0c393..2850ae1a19cb0038ce5acd4bbb86837db171cda0 100644 (file)
--- a/black.py
+++ b/black.py
@@ -2,18 +2,19 @@ import asyncio
 from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 from datetime import datetime
-from enum import Enum, Flag
+from enum import Enum
 from functools import lru_cache, partial, wraps
 import io
-import keyword
+import itertools
 import logging
-from multiprocessing import Manager
+from multiprocessing import Manager, freeze_support
 import os
 from pathlib import Path
 import pickle
 import re
 import signal
 import sys
+import tempfile
 import tokenize
 from typing import (
     Any,
@@ -36,7 +37,7 @@ from typing import (
 )
 
 from appdirs import user_cache_dir
-from attr import dataclass, Factory
+from attr import dataclass, evolve, Factory
 import click
 import toml
 
@@ -44,13 +45,14 @@ import toml
 from blib2to3.pytree import Node, Leaf, type_repr
 from blib2to3 import pygram, pytree
 from blib2to3.pgen2 import driver, token
+from blib2to3.pgen2.grammar import Grammar
 from blib2to3.pgen2.parse import ParseError
 
 
 __version__ = "18.9b0"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_EXCLUDES = (
-    r"/(\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
+    r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
 )
 DEFAULT_INCLUDES = r"\.pyi?$"
 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
@@ -110,32 +112,84 @@ class Changed(Enum):
     YES = 2
 
 
-class FileMode(Flag):
-    AUTO_DETECT = 0
-    PYTHON36 = 1
-    PYI = 2
-    NO_STRING_NORMALIZATION = 4
-    NO_NUMERIC_UNDERSCORE_NORMALIZATION = 8
+class TargetVersion(Enum):
+    PYPY35 = 1
+    CPY27 = 2
+    CPY33 = 3
+    CPY34 = 4
+    CPY35 = 5
+    CPY36 = 6
+    CPY37 = 7
+    CPY38 = 8
+
+    def is_python2(self) -> bool:
+        return self is TargetVersion.CPY27
+
+
+PY36_VERSIONS = {TargetVersion.CPY36, TargetVersion.CPY37, TargetVersion.CPY38}
+
+
+class Feature(Enum):
+    # All string literals are unicode
+    UNICODE_LITERALS = 1
+    F_STRINGS = 2
+    NUMERIC_UNDERSCORES = 3
+    TRAILING_COMMA = 4
+
+
+VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
+    TargetVersion.CPY27: set(),
+    TargetVersion.PYPY35: {Feature.UNICODE_LITERALS, Feature.F_STRINGS},
+    TargetVersion.CPY33: {Feature.UNICODE_LITERALS},
+    TargetVersion.CPY34: {Feature.UNICODE_LITERALS},
+    TargetVersion.CPY35: {Feature.UNICODE_LITERALS, Feature.TRAILING_COMMA},
+    TargetVersion.CPY36: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA,
+    },
+    TargetVersion.CPY37: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA,
+    },
+    TargetVersion.CPY38: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA,
+    },
+}
 
-    @classmethod
-    def from_configuration(
-        cls,
-        *,
-        py36: bool,
-        pyi: bool,
-        skip_string_normalization: bool,
-        skip_numeric_underscore_normalization: bool,
-    ) -> "FileMode":
-        mode = cls.AUTO_DETECT
-        if py36:
-            mode |= cls.PYTHON36
-        if pyi:
-            mode |= cls.PYI
-        if skip_string_normalization:
-            mode |= cls.NO_STRING_NORMALIZATION
-        if skip_numeric_underscore_normalization:
-            mode |= cls.NO_NUMERIC_UNDERSCORE_NORMALIZATION
-        return mode
+
+@dataclass
+class FileMode:
+    target_versions: Set[TargetVersion] = Factory(set)
+    line_length: int = DEFAULT_LINE_LENGTH
+    string_normalization: bool = True
+    is_pyi: bool = False
+
+    def get_cache_key(self) -> str:
+        if self.target_versions:
+            version_str = ",".join(
+                str(version.value)
+                for version in sorted(self.target_versions, key=lambda v: v.value)
+            )
+        else:
+            version_str = "-"
+        parts = [
+            version_str,
+            str(self.line_length),
+            str(int(self.string_normalization)),
+            str(int(self.is_pyi)),
+        ]
+        return ".".join(parts)
+
+
+def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
+    return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
 
 
 def read_pyproject_toml(
@@ -159,7 +213,9 @@ def read_pyproject_toml(
         pyproject_toml = toml.load(value)
         config = pyproject_toml.get("tool", {}).get("black", {})
     except (toml.TomlDecodeError, OSError) as e:
-        raise click.BadOptionUsage(f"Error reading configuration file: {e}", ctx)
+        raise click.FileError(
+            filename=value, hint=f"Error reading configuration file: {e}"
+        )
 
     if not config:
         return None
@@ -182,12 +238,14 @@ def read_pyproject_toml(
     show_default=True,
 )
 @click.option(
-    "--py36",
-    is_flag=True,
+    "-t",
+    "--target-version",
+    type=click.Choice([v.name.lower() for v in TargetVersion]),
+    callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
+    multiple=True,
     help=(
-        "Allow using Python 3.6-only syntax on all input files.  This will put "
-        "trailing commas in function signatures and calls also after *args and "
-        "**kwargs.  [default: per-file auto-detection]"
+        "Python versions that should be supported by Black's output. [default: "
+        "per-file auto-detection]"
     ),
 )
 @click.option(
@@ -204,12 +262,6 @@ def read_pyproject_toml(
     is_flag=True,
     help="Don't normalize string quotes or prefixes.",
 )
-@click.option(
-    "-N",
-    "--skip-numeric-underscore-normalization",
-    is_flag=True,
-    help="Don't normalize underscores in numeric literals.",
-)
 @click.option(
     "--check",
     is_flag=True,
@@ -294,13 +346,12 @@ def read_pyproject_toml(
 def main(
     ctx: click.Context,
     line_length: int,
+    target_version: List[TargetVersion],
     check: bool,
     diff: bool,
     fast: bool,
     pyi: bool,
-    py36: bool,
     skip_string_normalization: bool,
-    skip_numeric_underscore_normalization: bool,
     quiet: bool,
     verbose: bool,
     include: str,
@@ -310,11 +361,16 @@ def main(
 ) -> None:
     """The uncompromising code formatter."""
     write_back = WriteBack.from_configuration(check=check, diff=diff)
-    mode = FileMode.from_configuration(
-        py36=py36,
-        pyi=pyi,
-        skip_string_normalization=skip_string_normalization,
-        skip_numeric_underscore_normalization=skip_numeric_underscore_normalization,
+    if target_version:
+        versions = set(target_version)
+    else:
+        # We'll autodetect later.
+        versions = set()
+    mode = FileMode(
+        target_versions=versions,
+        line_length=line_length,
+        is_pyi=pyi,
+        string_normalization=not skip_string_normalization,
     )
     if config and verbose:
         out(f"Using configuration from {config}.", bold=False, fg="blue")
@@ -350,7 +406,6 @@ def main(
     if len(sources) == 1:
         reformat_one(
             src=sources.pop(),
-            line_length=line_length,
             fast=fast,
             write_back=write_back,
             mode=mode,
@@ -363,7 +418,6 @@ def main(
             loop.run_until_complete(
                 schedule_formatting(
                     sources=sources,
-                    line_length=line_length,
                     fast=fast,
                     write_back=write_back,
                     mode=mode,
@@ -382,12 +436,7 @@ def main(
 
 
 def reformat_one(
-    src: Path,
-    line_length: int,
-    fast: bool,
-    write_back: WriteBack,
-    mode: FileMode,
-    report: "Report",
+    src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
@@ -398,29 +447,23 @@ def reformat_one(
     try:
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
-            if format_stdin_to_stdout(
-                line_length=line_length, fast=fast, write_back=write_back, mode=mode
-            ):
+            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
-                cache = read_cache(line_length, mode)
+                cache = read_cache(mode)
                 res_src = src.resolve()
                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
-                src,
-                line_length=line_length,
-                fast=fast,
-                write_back=write_back,
-                mode=mode,
+                src, fast=fast, write_back=write_back, mode=mode
             ):
                 changed = Changed.YES
             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
                 write_back is WriteBack.CHECK and changed is Changed.NO
             ):
-                write_cache(cache, [src], line_length, mode)
+                write_cache(cache, [src], mode)
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
@@ -428,7 +471,6 @@ def reformat_one(
 
 async def schedule_formatting(
     sources: Set[Path],
-    line_length: int,
     fast: bool,
     write_back: WriteBack,
     mode: FileMode,
@@ -445,7 +487,7 @@ async def schedule_formatting(
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
-        cache = read_cache(line_length, mode)
+        cache = read_cache(mode)
         sources, cached = filter_cached(cache, sources)
         for src in sorted(cached):
             report.done(src, Changed.CACHED)
@@ -462,14 +504,7 @@ async def schedule_formatting(
         lock = manager.Lock()
     tasks = {
         loop.run_in_executor(
-            executor,
-            format_file_in_place,
-            src,
-            line_length,
-            fast,
-            write_back,
-            mode,
-            lock,
+            executor, format_file_in_place, src, fast, mode, write_back, lock
         ): src
         for src in sorted(sources)
     }
@@ -500,15 +535,14 @@ async def schedule_formatting(
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if sources_to_cache:
-        write_cache(cache, sources_to_cache, line_length, mode)
+        write_cache(cache, sources_to_cache, mode)
 
 
 def format_file_in_place(
     src: Path,
-    line_length: int,
     fast: bool,
+    mode: FileMode,
     write_back: WriteBack = WriteBack.NO,
-    mode: FileMode = FileMode.AUTO_DETECT,
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
@@ -518,15 +552,13 @@ def format_file_in_place(
     `line_length` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
-        mode |= FileMode.PYI
+        mode = evolve(mode, is_pyi=True)
 
     then = datetime.utcfromtimestamp(src.stat().st_mtime)
     with open(src, "rb") as buf:
         src_contents, encoding, newline = decode_bytes(buf.read())
     try:
-        dst_contents = format_file_contents(
-            src_contents, line_length=line_length, fast=fast, mode=mode
-        )
+        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
     except NothingChanged:
         return False
 
@@ -556,23 +588,19 @@ def format_file_in_place(
 
 
 def format_stdin_to_stdout(
-    line_length: int,
-    fast: bool,
-    write_back: WriteBack = WriteBack.NO,
-    mode: FileMode = FileMode.AUTO_DETECT,
+    fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
-    write a diff to stdout.
-    `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
+    write a diff to stdout. The `mode` argument is passed to
     :func:`format_file_contents`.
     """
     then = datetime.utcnow()
     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
     dst = src
     try:
-        dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
+        dst = format_file_contents(src, fast=fast, mode=mode)
         return True
 
     except NothingChanged:
@@ -593,11 +621,7 @@ def format_stdin_to_stdout(
 
 
 def format_file_contents(
-    src_contents: str,
-    *,
-    line_length: int,
-    fast: bool,
-    mode: FileMode = FileMode.AUTO_DETECT,
+    src_contents: str, *, fast: bool, mode: FileMode
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
@@ -608,38 +632,36 @@ def format_file_contents(
     if src_contents.strip() == "":
         raise NothingChanged
 
-    dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
+    dst_contents = format_str(src_contents, mode=mode)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
-        assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
+        assert_stable(src_contents, dst_contents, mode=mode)
     return dst_contents
 
 
-def format_str(
-    src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
-) -> FileContent:
+def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
     """Reformat a string and return new contents.
 
     `line_length` determines how many characters per line are allowed.
     """
-    src_node = lib2to3_parse(src_contents.lstrip())
+    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
     dst_contents = ""
     future_imports = get_future_imports(src_node)
-    is_pyi = bool(mode & FileMode.PYI)
-    py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
-    normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
+    if mode.target_versions:
+        versions = mode.target_versions
+    else:
+        versions = detect_target_versions(src_node)
     normalize_fmt_off(src_node)
     lines = LineGenerator(
-        remove_u_prefix=py36 or "unicode_literals" in future_imports,
-        is_pyi=is_pyi,
-        normalize_strings=normalize_strings,
-        allow_underscores=py36
-        and not bool(mode & FileMode.NO_NUMERIC_UNDERSCORE_NORMALIZATION),
+        remove_u_prefix="unicode_literals" in future_imports
+        or supports_feature(versions, Feature.UNICODE_LITERALS),
+        is_pyi=mode.is_pyi,
+        normalize_strings=mode.string_normalization,
     )
-    elt = EmptyLineTracker(is_pyi=is_pyi)
+    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
     empty_line = Line()
     after = 0
     for current_line in lines.visit(src_node):
@@ -648,7 +670,11 @@ def format_str(
         before, after = elt.maybe_empty_lines(current_line)
         for _ in range(before):
             dst_contents += str(empty_line)
-        for line in split_line(current_line, line_length=line_length, py36=py36):
+        for line in split_line(
+            current_line,
+            line_length=mode.line_length,
+            supports_trailing_commas=supports_feature(versions, Feature.TRAILING_COMMA),
+        ):
             dst_contents += str(line)
     return dst_contents
 
@@ -677,11 +703,25 @@ GRAMMARS = [
 ]
 
 
-def lib2to3_parse(src_txt: str) -> Node:
+def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
+    if not target_versions:
+        return GRAMMARS
+    elif all(not version.is_python2() for version in target_versions):
+        # No Python 2 targets, so skip the grammar with the print statement.
+        return [
+            pygram.python_grammar_no_print_statement_no_exec_statement,
+            pygram.python_grammar_no_print_statement,
+        ]
+    else:
+        return [pygram.python_grammar]
+
+
+def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     if src_txt[-1:] != "\n":
         src_txt += "\n"
-    for grammar in GRAMMARS:
+
+    for grammar in get_grammars(set(target_versions)):
         drv = driver.Driver(grammar, pytree.convert)
         try:
             result = drv.parse_string(src_txt, True)
@@ -774,9 +814,7 @@ class DebugVisitor(Visitor[T]):
         list(v.visit(code))
 
 
-KEYWORDS = set(keyword.kwlist)
 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
-FLOW_CONTROL = {"return", "raise", "break", "continue"}
 STATEMENT = {
     syms.if_stmt,
     syms.while_stmt,
@@ -1025,7 +1063,9 @@ class Line:
 
     depth: int = 0
     leaves: List[Leaf] = Factory(list)
-    comments: List[Tuple[Index, Leaf]] = Factory(list)
+    # The LeafID keys of comments must remain ordered by the corresponding leaf's index
+    # in leaves
+    comments: Dict[LeafID, List[Leaf]] = Factory(dict)
     bracket_tracker: BracketTracker = Factory(BracketTracker)
     inside_brackets: bool = False
     should_explode: bool = False
@@ -1232,43 +1272,35 @@ class Line:
         if comment.type != token.COMMENT:
             return False
 
-        after = len(self.leaves) - 1
-        if after == -1:
+        if not self.leaves:
             comment.type = STANDALONE_COMMENT
             comment.prefix = ""
             return False
 
         else:
-            self.comments.append((after, comment))
-            return True
-
-    def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
-        """Generate comments that should appear directly after `leaf`.
-
-        Provide a non-negative leaf `_index` to speed up the function.
-        """
-        if not self.comments:
-            return
-
-        if _index == -1:
-            for _index, _leaf in enumerate(self.leaves):
-                if leaf is _leaf:
-                    break
-
+            leaf_id = id(self.leaves[-1])
+            if leaf_id not in self.comments:
+                self.comments[leaf_id] = [comment]
             else:
-                return
+                self.comments[leaf_id].append(comment)
+            return True
 
-        for index, comment_after in self.comments:
-            if _index == index:
-                yield comment_after
+    def comments_after(self, leaf: Leaf) -> List[Leaf]:
+        """Generate comments that should appear directly after `leaf`."""
+        return self.comments.get(id(leaf), [])
 
     def remove_trailing_comma(self) -> None:
         """Remove the trailing comma and moves the comments attached to it."""
-        comma_index = len(self.leaves) - 1
-        for i in range(len(self.comments)):
-            comment_index, comment = self.comments[i]
-            if comment_index == comma_index:
-                self.comments[i] = (comma_index - 1, comment)
+        # Remember, the LeafID keys of self.comments are ordered by the
+        # corresponding leaf's index in self.leaves
+        # If id(self.leaves[-2]) is in self.comments, the order doesn't change.
+        # Otherwise, we insert it into self.comments, and it becomes the last entry.
+        # However, since we delete id(self.leaves[-1]) from self.comments, the invariant
+        # is maintained
+        self.comments.setdefault(id(self.leaves[-2]), []).extend(
+            self.comments.get(id(self.leaves[-1]), [])
+        )
+        self.comments.pop(id(self.leaves[-1]), None)
         self.leaves.pop()
 
     def is_complex_subscript(self, leaf: Leaf) -> bool:
@@ -1300,7 +1332,7 @@ class Line:
         res = f"{first.prefix}{indent}{first.value}"
         for leaf in leaves:
             res += str(leaf)
-        for _, comment in self.comments:
+        for comment in itertools.chain.from_iterable(self.comments.values()):
             res += str(comment)
         return res + "\n"
 
@@ -1432,7 +1464,6 @@ class LineGenerator(Visitor[Line]):
     normalize_strings: bool = True
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
-    allow_underscores: bool = False
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
@@ -1475,7 +1506,7 @@ class LineGenerator(Visitor[Line]):
                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                 normalize_string_quotes(node)
             if node.type == token.NUMBER:
-                normalize_numeric_literal(node, self.allow_underscores)
+                normalize_numeric_literal(node)
             if node.type not in WHITESPACE:
                 self.current_line.append(node)
         yield from super().visit_default(node)
@@ -1599,6 +1630,7 @@ class LineGenerator(Visitor[Line]):
         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
+        self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
@@ -1611,7 +1643,7 @@ BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
 
 
-def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
+def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
     """Return whitespace prefix if needed for the given `leaf`.
 
     `complex_subscript` signals whether the given leaf is part of a subscription
@@ -1892,7 +1924,7 @@ def container_of(leaf: Leaf) -> LN:
     return container
 
 
-def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
+def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
     """Return the priority of the `leaf` delimiter, given a line break after it.
 
     The delimiter priorities returned here are from those delimiters that would
@@ -1906,7 +1938,7 @@ def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
     return 0
 
 
-def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
+def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> int:
     """Return the priority of the `leaf` delimiter, given a line break before it.
 
     The delimiter priorities returned here are from those delimiters that would
@@ -2098,7 +2130,10 @@ def make_comment(content: str) -> str:
 
 
 def split_line(
-    line: Line, line_length: int, inner: bool = False, py36: bool = False
+    line: Line,
+    line_length: int,
+    inner: bool = False,
+    supports_trailing_commas: bool = False,
 ) -> Iterator[Line]:
     """Split a `line` into potentially many lines.
 
@@ -2107,16 +2142,26 @@ def split_line(
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
 
-    If `py36` is True, splitting may generate syntax that is only compatible
-    with Python 3.6 and later.
+    If `supports_trailing_commas` is True, splitting may use the TRAILING_COMMA feature.
     """
     if line.is_comment:
         yield line
         return
 
     line_str = str(line).strip("\n")
-    if not line.should_explode and is_line_short_enough(
-        line, line_length=line_length, line_str=line_str
+
+    # we don't want to split special comments like type annotations
+    # https://github.com/python/typing/issues/186
+    has_special_comment = False
+    for leaf in line.leaves:
+        for comment in line.comments_after(leaf):
+            if leaf.type == token.COMMA and is_special_comment(comment):
+                has_special_comment = True
+
+    if (
+        not has_special_comment
+        and not line.should_explode
+        and is_line_short_enough(line, line_length=line_length, line_str=line_str)
     ):
         yield line
         return
@@ -2126,9 +2171,13 @@ def split_line(
         split_funcs = [left_hand_split]
     else:
 
-        def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
+        def rhs(line: Line, supports_trailing_commas: bool = False) -> Iterator[Line]:
             for omit in generate_trailers_to_omit(line, line_length):
-                lines = list(right_hand_split(line, line_length, py36, omit=omit))
+                lines = list(
+                    right_hand_split(
+                        line, line_length, supports_trailing_commas, omit=omit
+                    )
+                )
                 if is_line_short_enough(lines[0], line_length=line_length):
                     yield from lines
                     return
@@ -2136,7 +2185,7 @@ def split_line(
             # All splits failed, best effort split with no omits.
             # This mostly happens to multiline strings that are by definition
             # reported as not fitting a single line.
-            yield from right_hand_split(line, py36)
+            yield from right_hand_split(line, supports_trailing_commas)
 
         if line.inside_brackets:
             split_funcs = [delimiter_split, standalone_comment_split, rhs]
@@ -2148,12 +2197,17 @@ def split_line(
         # split altogether.
         result: List[Line] = []
         try:
-            for l in split_func(line, py36):
+            for l in split_func(line, supports_trailing_commas):
                 if str(l).strip("\n") == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
                 result.extend(
-                    split_line(l, line_length=line_length, inner=True, py36=py36)
+                    split_line(
+                        l,
+                        line_length=line_length,
+                        inner=True,
+                        supports_trailing_commas=supports_trailing_commas,
+                    )
                 )
         except CannotSplit:
             continue
@@ -2166,7 +2220,9 @@ def split_line(
         yield line
 
 
-def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def left_hand_split(
+    line: Line, supports_trailing_commas: bool = False
+) -> Iterator[Line]:
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
@@ -2203,7 +2259,10 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
 
 def right_hand_split(
-    line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
+    line: Line,
+    line_length: int,
+    supports_trailing_commas: bool = False,
+    omit: Collection[LeafID] = (),
 ) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair.
 
@@ -2261,7 +2320,12 @@ def right_hand_split(
     ):
         omit = {id(closing_bracket), *omit}
         try:
-            yield from right_hand_split(line, line_length, py36=py36, omit=omit)
+            yield from right_hand_split(
+                line,
+                line_length,
+                supports_trailing_commas=supports_trailing_commas,
+                omit=omit,
+            )
             return
 
         except CannotSplit:
@@ -2350,8 +2414,10 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
     """
 
     @wraps(split_func)
-    def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
-        for l in split_func(line, py36):
+    def split_wrapper(
+        line: Line, supports_trailing_commas: bool = False
+    ) -> Iterator[Line]:
+        for l in split_func(line, supports_trailing_commas):
             normalize_prefix(l.leaves[0], inside_brackets=True)
             yield l
 
@@ -2359,11 +2425,13 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
 
 
 @dont_increase_indentation
-def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def delimiter_split(
+    line: Line, supports_trailing_commas: bool = False
+) -> Iterator[Line]:
     """Split according to delimiters of the highest priority.
 
-    If `py36` is True, the split will add trailing commas also in function
-    signatures that contain `*` and `**`.
+    If `supports_trailing_commas` is True, the split will add trailing commas
+    also in function signatures that contain `*` and `**`.
     """
     try:
         last_leaf = line.leaves[-1]
@@ -2395,17 +2463,17 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
             current_line.append(leaf)
 
-    for index, leaf in enumerate(line.leaves):
+    for leaf in line.leaves:
         yield from append_to_line(leaf)
 
-        for comment_after in line.comments_after(leaf, index):
+        for comment_after in line.comments_after(leaf):
             yield from append_to_line(comment_after)
 
         lowest_depth = min(lowest_depth, leaf.bracket_depth)
         if leaf.bracket_depth == lowest_depth and is_vararg(
             leaf, within=VARARGS_PARENTS
         ):
-            trailing_comma_safe = trailing_comma_safe and py36
+            trailing_comma_safe = trailing_comma_safe and supports_trailing_commas
         leaf_priority = bt.delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             yield current_line
@@ -2423,7 +2491,9 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
 
 @dont_increase_indentation
-def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def standalone_comment_split(
+    line: Line, supports_trailing_commas: bool = False
+) -> Iterator[Line]:
     """Split standalone comments from the rest of the line."""
     if not line.contains_standalone_comments(0):
         raise CannotSplit("Line does not have any standalone comments")
@@ -2441,10 +2511,10 @@ def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
             current_line.append(leaf)
 
-    for index, leaf in enumerate(line.leaves):
+    for leaf in line.leaves:
         yield from append_to_line(leaf)
 
-        for comment_after in line.comments_after(leaf, index):
+        for comment_after in line.comments_after(leaf):
             yield from append_to_line(comment_after)
 
     if current_line:
@@ -2465,6 +2535,16 @@ def is_import(leaf: Leaf) -> bool:
     )
 
 
+def is_special_comment(leaf: Leaf) -> bool:
+    """Return True if the given leaf is a special comment.
+    Only returns true for type comments for now."""
+    t = leaf.type
+    v = leaf.value
+    return bool(
+        (t == token.COMMENT or t == STANDALONE_COMMENT) and (v.startswith("# type:"))
+    )
+
+
 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
     """Leave existing extra newlines if not `inside_brackets`. Remove everything
     else.
@@ -2566,11 +2646,11 @@ def normalize_string_quotes(leaf: Leaf) -> None:
     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
 
 
-def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
+def normalize_numeric_literal(leaf: Leaf) -> None:
     """Normalizes numeric (float, int, and complex) literals.
 
     All letters used in the representation are normalized to lowercase (except
-    in Python 2 long literals), and long number literals are split using underscores.
+    in Python 2 long literals).
     """
     text = leaf.value.lower()
     if text.startswith(("0o", "0b")):
@@ -2588,8 +2668,7 @@ def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
             sign = "-"
         elif after.startswith("+"):
             after = after[1:]
-        before = format_float_or_int_string(before, allow_underscores)
-        after = format_int_string(after, allow_underscores)
+        before = format_float_or_int_string(before)
         text = f"{before}e{sign}{after}"
     elif text.endswith(("j", "l")):
         number = text[:-1]
@@ -2597,50 +2676,19 @@ def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
         # Capitalize in "2L" because "l" looks too similar to "1".
         if suffix == "l":
             suffix = "L"
-        text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
+        text = f"{format_float_or_int_string(number)}{suffix}"
     else:
-        text = format_float_or_int_string(text, allow_underscores)
+        text = format_float_or_int_string(text)
     leaf.value = text
 
 
-def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
+def format_float_or_int_string(text: str) -> str:
     """Formats a float string like "1.0"."""
     if "." not in text:
-        return format_int_string(text, allow_underscores)
-
-    before, after = text.split(".")
-    before = format_int_string(before, allow_underscores) if before else "0"
-    if after:
-        after = format_int_string(after, allow_underscores, count_from_end=False)
-    else:
-        after = "0"
-    return f"{before}.{after}"
-
-
-def format_int_string(
-    text: str, allow_underscores: bool, count_from_end: bool = True
-) -> str:
-    """Normalizes underscores in a string to e.g. 1_000_000.
-
-    Input must be a string of digits and optional underscores.
-    If count_from_end is False, we add underscores after groups of three digits
-    counting from the beginning instead of the end of the strings. This is used
-    for the fractional part of float literals.
-    """
-    if not allow_underscores:
         return text
 
-    text = text.replace("_", "")
-    if len(text) <= 5:
-        # No underscores for numbers <= 5 digits long.
-        return text
-
-    if count_from_end:
-        # Avoid removing leading zeros, which are important if we're formatting
-        # part of a number like "0.001".
-        return format(int("1" + text), "3_")[1:].lstrip("_")
-    else:
-        return "_".join(text[i : i + 3] for i in range(0, len(text), 3))
+    before, after = text.split(".")
+    return f"{before or 0}.{after or 0}"
 
 
 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
@@ -2954,6 +3002,7 @@ def ensure_visible(leaf: Leaf) -> None:
 
 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
+
     if not (
         opening_bracket.parent
         and opening_bracket.parent.type in {syms.atom, syms.import_from}
@@ -2971,23 +3020,24 @@ def should_explode(line: Line, opening_bracket: Leaf) -> bool:
     return max_priority == COMMA_PRIORITY
 
 
-def is_python36(node: Node) -> bool:
-    """Return True if the current file is using Python 3.6+ features.
+def get_features_used(node: Node) -> Set[Feature]:
+    """Return a set of (relatively) new Python features used in this file.
 
     Currently looking for:
     - f-strings;
     - underscores in numeric literals; and
     - trailing commas after * or ** in function signatures and calls.
     """
+    features: Set[Feature] = set()
     for n in node.pre_order():
         if n.type == token.STRING:
             value_head = n.value[:2]  # type: ignore
             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
-                return True
+                features.add(Feature.F_STRINGS)
 
         elif n.type == token.NUMBER:
             if "_" in n.value:  # type: ignore
-                return True
+                features.add(Feature.NUMERIC_UNDERSCORES)
 
         elif (
             n.type in {syms.typedargslist, syms.arglist}
@@ -2996,14 +3046,22 @@ def is_python36(node: Node) -> bool:
         ):
             for ch in n.children:
                 if ch.type in STARS:
-                    return True
+                    features.add(Feature.TRAILING_COMMA)
 
                 if ch.type == syms.argument:
                     for argch in ch.children:
                         if argch.type in STARS:
-                            return True
+                            features.add(Feature.TRAILING_COMMA)
 
-    return False
+    return features
+
+
+def detect_target_versions(node: Node) -> Set[TargetVersion]:
+    """Detect the version to target based on the nodes used."""
+    features = get_features_used(node)
+    return {
+        version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
+    }
 
 
 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
@@ -3277,7 +3335,16 @@ def assert_equivalent(src: str, dst: str) -> None:
 
             if isinstance(value, list):
                 for item in value:
-                    if isinstance(item, ast.AST):
+                    # Ignore nested tuples within del statements, because we may insert
+                    # parentheses and they change the AST.
+                    if (
+                        field == "targets"
+                        and isinstance(node, ast.Delete)
+                        and isinstance(item, ast.Tuple)
+                    ):
+                        for item in item.elts:
+                            yield from _v(item, depth + 2)
+                    elif isinstance(item, ast.AST):
                         yield from _v(item, depth + 2)
 
             elif isinstance(value, ast.AST):
@@ -3320,11 +3387,9 @@ def assert_equivalent(src: str, dst: str) -> None:
         ) from None
 
 
-def assert_stable(
-    src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
-) -> None:
+def assert_stable(src: str, dst: str, mode: FileMode) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
-    newdst = format_str(dst, line_length=line_length, mode=mode)
+    newdst = format_str(dst, mode=mode)
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
@@ -3436,7 +3501,7 @@ def enumerate_with_length(
             return  # Multiline strings, we can't continue.
 
         comment: Optional[Leaf]
-        for comment in line.comments_after(leaf, index):
+        for comment in line.comments_after(leaf):
             length += len(comment.value)
 
         yield index, leaf, length
@@ -3581,16 +3646,16 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     return False
 
 
-def get_cache_file(line_length: int, mode: FileMode) -> Path:
-    return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
+def get_cache_file(mode: FileMode) -> Path:
+    return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
 
 
-def read_cache(line_length: int, mode: FileMode) -> Cache:
+def read_cache(mode: FileMode) -> Cache:
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
-    cache_file = get_cache_file(line_length, mode)
+    cache_file = get_cache_file(mode)
     if not cache_file.exists():
         return {}
 
@@ -3625,17 +3690,15 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set
     return todo, done
 
 
-def write_cache(
-    cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
-) -> None:
+def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
     """Update the cache file."""
-    cache_file = get_cache_file(line_length, mode)
+    cache_file = get_cache_file(mode)
     try:
-        if not CACHE_DIR.exists():
-            CACHE_DIR.mkdir(parents=True)
+        CACHE_DIR.mkdir(parents=True, exist_ok=True)
         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
-        with cache_file.open("wb") as fobj:
-            pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
+        with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
+            pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
+        os.replace(f.name, cache_file)
     except OSError:
         pass
 
@@ -3663,6 +3726,7 @@ def patch_click() -> None:
 
 
 def patched_main() -> None:
+    freeze_support()
     patch_click()
     main()