git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix unstable formatting involving unwrapping multiple parentheses (#836) (#961)
[etc/vim.git] / black.py
index 8d53194680309ce664af0e0a77105c31c2343824..f178aceb4dc9820d8b3562a29b2c2295de1341af 100644 (file)
--- a/black.py
+++ b/black.py
@@ -1,20 +1,23 @@
+import ast
 import asyncio
 import asyncio
-from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
 from concurrent.futures import Executor, ProcessPoolExecutor
+from contextlib import contextmanager
 from datetime import datetime
 from datetime import datetime
-from enum import Enum, Flag
+from enum import Enum
 from functools import lru_cache, partial, wraps
 import io
 from functools import lru_cache, partial, wraps
 import io
-import keyword
+import itertools
 import logging
 import logging
-from multiprocessing import Manager
+from multiprocessing import Manager, freeze_support
 import os
 from pathlib import Path
 import pickle
 import re
 import signal
 import sys
 import os
 from pathlib import Path
 import pickle
 import re
 import signal
 import sys
+import tempfile
 import tokenize
 import tokenize
+import traceback
 from typing import (
     Any,
     Callable,
 from typing import (
     Any,
     Callable,
@@ -36,24 +39,30 @@ from typing import (
 )
 
 from appdirs import user_cache_dir
 )
 
 from appdirs import user_cache_dir
-from attr import dataclass, Factory
+from attr import dataclass, evolve, Factory
 import click
 import toml
 import click
 import toml
+from typed_ast import ast3, ast27
 
 # lib2to3 fork
 from blib2to3.pytree import Node, Leaf, type_repr
 from blib2to3 import pygram, pytree
 from blib2to3.pgen2 import driver, token
 
 # lib2to3 fork
 from blib2to3.pytree import Node, Leaf, type_repr
 from blib2to3 import pygram, pytree
 from blib2to3.pgen2 import driver, token
+from blib2to3.pgen2.grammar import Grammar
 from blib2to3.pgen2.parse import ParseError
 
 from blib2to3.pgen2.parse import ParseError
 
+from _version import get_versions
+
+v = get_versions()
+__version__ = v.get("closest-tag", v["version"])
+__git_version__ = v.get("full-revisionid")
 
 
-__version__ = "18.6b4"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_EXCLUDES = (
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_EXCLUDES = (
-    r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
+    r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
 )
 DEFAULT_INCLUDES = r"\.pyi?$"
 )
 DEFAULT_INCLUDES = r"\.pyi?$"
-CACHE_DIR = Path(user_cache_dir("black", version=__version__))
+CACHE_DIR = Path(user_cache_dir("black", version=__git_version__))
 
 
 # types
 
 
 # types
@@ -66,7 +75,7 @@ LeafID = int
 Priority = int
 Index = int
 LN = Union[Leaf, Node]
 Priority = int
 Index = int
 LN = Union[Leaf, Node]
-SplitFunc = Callable[["Line", bool], Iterator["Line"]]
+SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
 Timestamp = float
 FileSize = int
 CacheInfo = Tuple[Timestamp, FileSize]
 Timestamp = float
 FileSize = int
 CacheInfo = Tuple[Timestamp, FileSize]
@@ -79,15 +88,15 @@ syms = pygram.python_symbols
 
 
 class NothingChanged(UserWarning):
 
 
 class NothingChanged(UserWarning):
-    """Raised by :func:`format_file` when reformatted code is the same as source."""
+    """Raised when reformatted code is the same as source."""
 
 
 class CannotSplit(Exception):
 
 
 class CannotSplit(Exception):
-    """A readable split that fits the allotted line length is impossible.
+    """A readable split that fits the allotted line length is impossible."""
 
 
-    Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
-    :func:`delimiter_split`.
-    """
+
+class InvalidInput(ValueError):
+    """Raised when input source code fails all parse attempts."""
 
 
 class WriteBack(Enum):
 
 
 class WriteBack(Enum):
@@ -110,24 +119,101 @@ class Changed(Enum):
     YES = 2
 
 
     YES = 2
 
 
-class FileMode(Flag):
-    AUTO_DETECT = 0
-    PYTHON36 = 1
-    PYI = 2
-    NO_STRING_NORMALIZATION = 4
+class TargetVersion(Enum):
+    PY27 = 2
+    PY33 = 3
+    PY34 = 4
+    PY35 = 5
+    PY36 = 6
+    PY37 = 7
+    PY38 = 8
+
+    def is_python2(self) -> bool:
+        return self is TargetVersion.PY27
+
+
+PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
+
+
+class Feature(Enum):
+    # All string literals are unicode
+    UNICODE_LITERALS = 1
+    F_STRINGS = 2
+    NUMERIC_UNDERSCORES = 3
+    TRAILING_COMMA_IN_CALL = 4
+    TRAILING_COMMA_IN_DEF = 5
+    # The following two feature-flags are mutually exclusive, and exactly one should be
+    # set for every version of python.
+    ASYNC_IDENTIFIERS = 6
+    ASYNC_KEYWORDS = 7
+    ASSIGNMENT_EXPRESSIONS = 8
+    POS_ONLY_ARGUMENTS = 9
+
+
+VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
+    TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
+    TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
+    TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
+    TargetVersion.PY35: {
+        Feature.UNICODE_LITERALS,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.ASYNC_IDENTIFIERS,
+    },
+    TargetVersion.PY36: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
+        Feature.ASYNC_IDENTIFIERS,
+    },
+    TargetVersion.PY37: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
+        Feature.ASYNC_KEYWORDS,
+    },
+    TargetVersion.PY38: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
+        Feature.ASYNC_KEYWORDS,
+        Feature.ASSIGNMENT_EXPRESSIONS,
+        Feature.POS_ONLY_ARGUMENTS,
+    },
+}
+
 
 
-    @classmethod
-    def from_configuration(
-        cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
-    ) -> "FileMode":
-        mode = cls.AUTO_DETECT
-        if py36:
-            mode |= cls.PYTHON36
-        if pyi:
-            mode |= cls.PYI
-        if skip_string_normalization:
-            mode |= cls.NO_STRING_NORMALIZATION
-        return mode
+@dataclass
+class FileMode:
+    target_versions: Set[TargetVersion] = Factory(set)
+    line_length: int = DEFAULT_LINE_LENGTH
+    string_normalization: bool = True
+    is_pyi: bool = False
+
+    def get_cache_key(self) -> str:
+        if self.target_versions:
+            version_str = ",".join(
+                str(version.value)
+                for version in sorted(self.target_versions, key=lambda v: v.value)
+            )
+        else:
+            version_str = "-"
+        parts = [
+            version_str,
+            str(self.line_length),
+            str(int(self.string_normalization)),
+            str(int(self.is_pyi)),
+        ]
+        return ".".join(parts)
+
+
+def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
+    return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
 
 
 def read_pyproject_toml(
 
 
 def read_pyproject_toml(
@@ -151,7 +237,9 @@ def read_pyproject_toml(
         pyproject_toml = toml.load(value)
         config = pyproject_toml.get("tool", {}).get("black", {})
     except (toml.TomlDecodeError, OSError) as e:
         pyproject_toml = toml.load(value)
         config = pyproject_toml.get("tool", {}).get("black", {})
     except (toml.TomlDecodeError, OSError) as e:
-        raise click.BadOptionUsage(f"Error reading configuration file: {e}", ctx)
+        raise click.FileError(
+            filename=value, hint=f"Error reading configuration file: {e}"
+        )
 
     if not config:
         return None
 
     if not config:
         return None
@@ -165,6 +253,7 @@ def read_pyproject_toml(
 
 
 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
 
 
 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
+@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
 @click.option(
     "-l",
     "--line-length",
 @click.option(
     "-l",
     "--line-length",
@@ -173,13 +262,25 @@ def read_pyproject_toml(
     help="How many characters per line to allow.",
     show_default=True,
 )
     help="How many characters per line to allow.",
     show_default=True,
 )
+@click.option(
+    "-t",
+    "--target-version",
+    type=click.Choice([v.name.lower() for v in TargetVersion]),
+    callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
+    multiple=True,
+    help=(
+        "Python versions that should be supported by Black's output. [default: "
+        "per-file auto-detection]"
+    ),
+)
 @click.option(
     "--py36",
     is_flag=True,
     help=(
         "Allow using Python 3.6-only syntax on all input files.  This will put "
         "trailing commas in function signatures and calls also after *args and "
 @click.option(
     "--py36",
     is_flag=True,
     help=(
         "Allow using Python 3.6-only syntax on all input files.  This will put "
         "trailing commas in function signatures and calls also after *args and "
-        "**kwargs.  [default: per-file auto-detection]"
+        "**kwargs. Deprecated; use --target-version instead. "
+        "[default: per-file auto-detection]"
     ),
 )
 @click.option(
     ),
 )
 @click.option(
@@ -245,7 +346,7 @@ def read_pyproject_toml(
     "--quiet",
     is_flag=True,
     help=(
     "--quiet",
     is_flag=True,
     help=(
-        "Don't emit non-error messages to stderr. Errors are still emitted, "
+        "Don't emit non-error messages to stderr. Errors are still emitted; "
         "silence those with 2>/dev/null."
     ),
 )
         "silence those with 2>/dev/null."
     ),
 )
@@ -279,7 +380,9 @@ def read_pyproject_toml(
 @click.pass_context
 def main(
     ctx: click.Context,
 @click.pass_context
 def main(
     ctx: click.Context,
+    code: Optional[str],
     line_length: int,
     line_length: int,
+    target_version: List[TargetVersion],
     check: bool,
     diff: bool,
     fast: bool,
     check: bool,
     diff: bool,
     fast: bool,
@@ -295,11 +398,32 @@ def main(
 ) -> None:
     """The uncompromising code formatter."""
     write_back = WriteBack.from_configuration(check=check, diff=diff)
 ) -> None:
     """The uncompromising code formatter."""
     write_back = WriteBack.from_configuration(check=check, diff=diff)
-    mode = FileMode.from_configuration(
-        py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
+    if target_version:
+        if py36:
+            err(f"Cannot use both --target-version and --py36")
+            ctx.exit(2)
+        else:
+            versions = set(target_version)
+    elif py36:
+        err(
+            "--py36 is deprecated and will be removed in a future version. "
+            "Use --target-version py36 instead."
+        )
+        versions = PY36_VERSIONS
+    else:
+        # We'll autodetect later.
+        versions = set()
+    mode = FileMode(
+        target_versions=versions,
+        line_length=line_length,
+        is_pyi=pyi,
+        string_normalization=not skip_string_normalization,
     )
     if config and verbose:
         out(f"Using configuration from {config}.", bold=False, fg="blue")
     )
     if config and verbose:
         out(f"Using configuration from {config}.", bold=False, fg="blue")
+    if code is not None:
+        print(format_str(code, mode=mode))
+        ctx.exit(0)
     try:
         include_regex = re_compile_maybe_verbose(include)
     except re.error:
     try:
         include_regex = re_compile_maybe_verbose(include)
     except re.error:
@@ -332,102 +456,105 @@ def main(
     if len(sources) == 1:
         reformat_one(
             src=sources.pop(),
     if len(sources) == 1:
         reformat_one(
             src=sources.pop(),
-            line_length=line_length,
             fast=fast,
             write_back=write_back,
             mode=mode,
             report=report,
         )
     else:
             fast=fast,
             write_back=write_back,
             mode=mode,
             report=report,
         )
     else:
-        loop = asyncio.get_event_loop()
-        executor = ProcessPoolExecutor(max_workers=os.cpu_count())
-        try:
-            loop.run_until_complete(
-                schedule_formatting(
-                    sources=sources,
-                    line_length=line_length,
-                    fast=fast,
-                    write_back=write_back,
-                    mode=mode,
-                    report=report,
-                    loop=loop,
-                    executor=executor,
-                )
-            )
-        finally:
-            shutdown(loop)
+        reformat_many(
+            sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
+        )
+
     if verbose or not quiet:
     if verbose or not quiet:
-        bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
-        out(f"All done! {bang}")
+        out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
         click.secho(str(report), err=True)
     ctx.exit(report.return_code)
 
 
 def reformat_one(
         click.secho(str(report), err=True)
     ctx.exit(report.return_code)
 
 
 def reformat_one(
-    src: Path,
-    line_length: int,
-    fast: bool,
-    write_back: WriteBack,
-    mode: FileMode,
-    report: "Report",
+    src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
-    If `quiet` is True, non-error messages are not output. `line_length`,
-    `write_back`, `fast` and `pyi` options are passed to
+    `fast`, `write_back`, and `mode` options are passed to
     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
     """
     try:
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
     """
     try:
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
-            if format_stdin_to_stdout(
-                line_length=line_length, fast=fast, write_back=write_back, mode=mode
-            ):
+            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
-                cache = read_cache(line_length, mode)
+                cache = read_cache(mode)
                 res_src = src.resolve()
                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
                 res_src = src.resolve()
                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
-                src,
-                line_length=line_length,
-                fast=fast,
-                write_back=write_back,
-                mode=mode,
+                src, fast=fast, write_back=write_back, mode=mode
             ):
                 changed = Changed.YES
             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
                 write_back is WriteBack.CHECK and changed is Changed.NO
             ):
             ):
                 changed = Changed.YES
             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
                 write_back is WriteBack.CHECK and changed is Changed.NO
             ):
-                write_cache(cache, [src], line_length, mode)
+                write_cache(cache, [src], mode)
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
 
 
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
 
 
+def reformat_many(
+    sources: Set[Path],
+    fast: bool,
+    write_back: WriteBack,
+    mode: FileMode,
+    report: "Report",
+) -> None:
+    """Reformat multiple files using a ProcessPoolExecutor."""
+    loop = asyncio.get_event_loop()
+    worker_count = os.cpu_count()
+    if sys.platform == "win32":
+        # Work around https://bugs.python.org/issue26903
+        worker_count = min(worker_count, 61)
+    executor = ProcessPoolExecutor(max_workers=worker_count)
+    try:
+        loop.run_until_complete(
+            schedule_formatting(
+                sources=sources,
+                fast=fast,
+                write_back=write_back,
+                mode=mode,
+                report=report,
+                loop=loop,
+                executor=executor,
+            )
+        )
+    finally:
+        shutdown(loop)
+        executor.shutdown()
+
+
 async def schedule_formatting(
     sources: Set[Path],
 async def schedule_formatting(
     sources: Set[Path],
-    line_length: int,
     fast: bool,
     write_back: WriteBack,
     mode: FileMode,
     report: "Report",
     fast: bool,
     write_back: WriteBack,
     mode: FileMode,
     report: "Report",
-    loop: BaseEventLoop,
+    loop: asyncio.AbstractEventLoop,
     executor: Executor,
 ) -> None:
     """Run formatting of `sources` in parallel using the provided `executor`.
 
     (Use ProcessPoolExecutors for actual parallelism.)
 
     executor: Executor,
 ) -> None:
     """Run formatting of `sources` in parallel using the provided `executor`.
 
     (Use ProcessPoolExecutors for actual parallelism.)
 
-    `line_length`, `write_back`, `fast`, and `pyi` options are passed to
+    `write_back`, `fast`, and `mode` options are passed to
     :func:`format_file_in_place`.
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
     :func:`format_file_in_place`.
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
-        cache = read_cache(line_length, mode)
+        cache = read_cache(mode)
         sources, cached = filter_cached(cache, sources)
         for src in sorted(cached):
             report.done(src, Changed.CACHED)
         sources, cached = filter_cached(cache, sources)
         for src in sorted(cached):
             report.done(src, Changed.CACHED)
@@ -443,19 +570,14 @@ async def schedule_formatting(
         manager = Manager()
         lock = manager.Lock()
     tasks = {
         manager = Manager()
         lock = manager.Lock()
     tasks = {
-        loop.run_in_executor(
-            executor,
-            format_file_in_place,
-            src,
-            line_length,
-            fast,
-            write_back,
-            mode,
-            lock,
+        asyncio.ensure_future(
+            loop.run_in_executor(
+                executor, format_file_in_place, src, fast, mode, write_back, lock
+            )
         ): src
         for src in sorted(sources)
     }
         ): src
         for src in sorted(sources)
     }
-    pending: Iterable[asyncio.Task] = tasks.keys()
+    pending: Iterable[asyncio.Future] = tasks.keys()
     try:
         loop.add_signal_handler(signal.SIGINT, cancel, pending)
         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
     try:
         loop.add_signal_handler(signal.SIGINT, cancel, pending)
         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
@@ -482,33 +604,30 @@ async def schedule_formatting(
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if sources_to_cache:
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if sources_to_cache:
-        write_cache(cache, sources_to_cache, line_length, mode)
+        write_cache(cache, sources_to_cache, mode)
 
 
 def format_file_in_place(
     src: Path,
 
 
 def format_file_in_place(
     src: Path,
-    line_length: int,
     fast: bool,
     fast: bool,
+    mode: FileMode,
     write_back: WriteBack = WriteBack.NO,
     write_back: WriteBack = WriteBack.NO,
-    mode: FileMode = FileMode.AUTO_DETECT,
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
 
     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
     code to the file.
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
 
     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
     code to the file.
-    `line_length` and `fast` options are passed to :func:`format_file_contents`.
+    `mode` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
     """
     if src.suffix == ".pyi":
-        mode |= FileMode.PYI
+        mode = evolve(mode, is_pyi=True)
 
     then = datetime.utcfromtimestamp(src.stat().st_mtime)
     with open(src, "rb") as buf:
         src_contents, encoding, newline = decode_bytes(buf.read())
     try:
 
     then = datetime.utcfromtimestamp(src.stat().st_mtime)
     with open(src, "rb") as buf:
         src_contents, encoding, newline = decode_bytes(buf.read())
     try:
-        dst_contents = format_file_contents(
-            src_contents, line_length=line_length, fast=fast, mode=mode
-        )
+        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
     except NothingChanged:
         return False
 
     except NothingChanged:
         return False
 
@@ -520,9 +639,8 @@ def format_file_in_place(
         src_name = f"{src}\t{then} +0000"
         dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
         src_name = f"{src}\t{then} +0000"
         dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
-        if lock:
-            lock.acquire()
-        try:
+
+        with lock or nullcontext():
             f = io.TextIOWrapper(
                 sys.stdout.buffer,
                 encoding=encoding,
             f = io.TextIOWrapper(
                 sys.stdout.buffer,
                 encoding=encoding,
@@ -531,30 +649,24 @@ def format_file_in_place(
             )
             f.write(diff_contents)
             f.detach()
             )
             f.write(diff_contents)
             f.detach()
-        finally:
-            if lock:
-                lock.release()
+
     return True
 
 
 def format_stdin_to_stdout(
     return True
 
 
 def format_stdin_to_stdout(
-    line_length: int,
-    fast: bool,
-    write_back: WriteBack = WriteBack.NO,
-    mode: FileMode = FileMode.AUTO_DETECT,
+    fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
-    write a diff to stdout.
-    `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
+    write a diff to stdout. The `mode` argument is passed to
     :func:`format_file_contents`.
     """
     then = datetime.utcnow()
     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
     dst = src
     try:
     :func:`format_file_contents`.
     """
     then = datetime.utcnow()
     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
     dst = src
     try:
-        dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
+        dst = format_file_contents(src, fast=fast, mode=mode)
         return True
 
     except NothingChanged:
         return True
 
     except NothingChanged:
@@ -575,63 +687,66 @@ def format_stdin_to_stdout(
 
 
 def format_file_contents(
 
 
 def format_file_contents(
-    src_contents: str,
-    *,
-    line_length: int,
-    fast: bool,
-    mode: FileMode = FileMode.AUTO_DETECT,
+    src_contents: str, *, fast: bool, mode: FileMode
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
     If `fast` is False, additionally confirm that the reformatted code is
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
     If `fast` is False, additionally confirm that the reformatted code is
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
-    `line_length` is passed to :func:`format_str`.
+    `mode` is passed to :func:`format_str`.
     """
     if src_contents.strip() == "":
         raise NothingChanged
 
     """
     if src_contents.strip() == "":
         raise NothingChanged
 
-    dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
+    dst_contents = format_str(src_contents, mode=mode)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
-        assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
+        assert_stable(src_contents, dst_contents, mode=mode)
     return dst_contents
 
 
     return dst_contents
 
 
-def format_str(
-    src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
-) -> FileContent:
+def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
     """Reformat a string and return new contents.
 
     """Reformat a string and return new contents.
 
-    `line_length` determines how many characters per line are allowed.
+    `mode` determines formatting options, such as how many characters per line are
+    allowed.
     """
     """
-    src_node = lib2to3_parse(src_contents)
-    dst_contents = ""
+    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
+    dst_contents = []
     future_imports = get_future_imports(src_node)
     future_imports = get_future_imports(src_node)
-    is_pyi = bool(mode & FileMode.PYI)
-    py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
-    normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
+    if mode.target_versions:
+        versions = mode.target_versions
+    else:
+        versions = detect_target_versions(src_node)
     normalize_fmt_off(src_node)
     lines = LineGenerator(
     normalize_fmt_off(src_node)
     lines = LineGenerator(
-        remove_u_prefix=py36 or "unicode_literals" in future_imports,
-        is_pyi=is_pyi,
-        normalize_strings=normalize_strings,
-        allow_underscores=py36,
+        remove_u_prefix="unicode_literals" in future_imports
+        or supports_feature(versions, Feature.UNICODE_LITERALS),
+        is_pyi=mode.is_pyi,
+        normalize_strings=mode.string_normalization,
     )
     )
-    elt = EmptyLineTracker(is_pyi=is_pyi)
+    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
     empty_line = Line()
     after = 0
     empty_line = Line()
     after = 0
+    split_line_features = {
+        feature
+        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
+        if supports_feature(versions, feature)
+    }
     for current_line in lines.visit(src_node):
         for _ in range(after):
     for current_line in lines.visit(src_node):
         for _ in range(after):
-            dst_contents += str(empty_line)
+            dst_contents.append(str(empty_line))
         before, after = elt.maybe_empty_lines(current_line)
         for _ in range(before):
         before, after = elt.maybe_empty_lines(current_line)
         for _ in range(before):
-            dst_contents += str(empty_line)
-        for line in split_line(current_line, line_length=line_length, py36=py36):
-            dst_contents += str(line)
-    return dst_contents
+            dst_contents.append(str(empty_line))
+        for line in split_line(
+            current_line, line_length=mode.line_length, features=split_line_features
+        ):
+            dst_contents.append(str(line))
+    return "".join(dst_contents)
 
 
 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
 
 
 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
@@ -651,19 +766,50 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
         return tiow.read(), encoding, newline
 
 
         return tiow.read(), encoding, newline
 
 
-GRAMMARS = [
-    pygram.python_grammar_no_print_statement_no_exec_statement,
-    pygram.python_grammar_no_print_statement,
-    pygram.python_grammar,
-]
+def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
+    if not target_versions:
+        # No target_version specified, so try all grammars.
+        return [
+            # Python 3.7+
+            pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
+            # Python 3.0-3.6
+            pygram.python_grammar_no_print_statement_no_exec_statement,
+            # Python 2.7 with future print_function import
+            pygram.python_grammar_no_print_statement,
+            # Python 2.7
+            pygram.python_grammar,
+        ]
+    elif all(version.is_python2() for version in target_versions):
+        # Python 2-only code, so try Python 2 grammars.
+        return [
+            # Python 2.7 with future print_function import
+            pygram.python_grammar_no_print_statement,
+            # Python 2.7
+            pygram.python_grammar,
+        ]
+    else:
+        # Python 3-compatible code, so only try Python 3 grammar.
+        grammars = []
+        # If we have to parse both, try to parse async as a keyword first
+        if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
+            # Python 3.7+
+            grammars.append(
+                pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords  # noqa: B950
+            )
+        if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
+            # Python 3.0-3.6
+            grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
+        # At least one of the above branches must have been taken, because every Python
+        # version has exactly one of the two 'ASYNC_*' flags
+        return grammars
 
 
 
 
-def lib2to3_parse(src_txt: str) -> Node:
+def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
     """Given a string with source, return the lib2to3 Node."""
     """Given a string with source, return the lib2to3 Node."""
-    grammar = pygram.python_grammar_no_print_statement
     if src_txt[-1:] != "\n":
         src_txt += "\n"
     if src_txt[-1:] != "\n":
         src_txt += "\n"
-    for grammar in GRAMMARS:
+
+    for grammar in get_grammars(set(target_versions)):
         drv = driver.Driver(grammar, pytree.convert)
         try:
             result = drv.parse_string(src_txt, True)
         drv = driver.Driver(grammar, pytree.convert)
         try:
             result = drv.parse_string(src_txt, True)
@@ -676,7 +822,7 @@ def lib2to3_parse(src_txt: str) -> Node:
                 faulty_line = lines[lineno - 1]
             except IndexError:
                 faulty_line = "<line number missing in source>"
                 faulty_line = lines[lineno - 1]
             except IndexError:
                 faulty_line = "<line number missing in source>"
-            exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
+            exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
     else:
         raise exc from None
 
     else:
         raise exc from None
 
@@ -756,9 +902,7 @@ class DebugVisitor(Visitor[T]):
         list(v.visit(code))
 
 
         list(v.visit(code))
 
 
-KEYWORDS = set(keyword.kwlist)
 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
-FLOW_CONTROL = {"return", "raise", "break", "continue"}
 STATEMENT = {
     syms.if_stmt,
     syms.while_stmt,
 STATEMENT = {
     syms.if_stmt,
     syms.while_stmt,
@@ -797,6 +941,7 @@ MATH_OPERATORS = {
     token.DOUBLESTAR,
 }
 STARS = {token.STAR, token.DOUBLESTAR}
     token.DOUBLESTAR,
 }
 STARS = {token.STAR, token.DOUBLESTAR}
+VARARGS_SPECIALS = STARS | {token.SLASH}
 VARARGS_PARENTS = {
     syms.arglist,
     syms.argument,  # double star in arglist
 VARARGS_PARENTS = {
     syms.arglist,
     syms.argument,  # double star in arglist
@@ -924,7 +1069,7 @@ class BracketTracker:
         """Return True if there is an yet unmatched open bracket on the line."""
         return bool(self.bracket_match)
 
         """Return True if there is an yet unmatched open bracket on the line."""
         return bool(self.bracket_match)
 
-    def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
+    def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
         """Return the highest priority of a delimiter found on the line.
 
         Values are consistent with what `is_split_*_delimiter()` return.
         """Return the highest priority of a delimiter found on the line.
 
         Values are consistent with what `is_split_*_delimiter()` return.
@@ -932,7 +1077,7 @@ class BracketTracker:
         """
         return max(v for k, v in self.delimiters.items() if k not in exclude)
 
         """
         return max(v for k, v in self.delimiters.items() if k not in exclude)
 
-    def delimiter_count_with_priority(self, priority: int = 0) -> int:
+    def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
         """Return the number of delimiters with the given `priority`.
 
         If no `priority` is passed, defaults to max priority on the line.
         """Return the number of delimiters with the given `priority`.
 
         If no `priority` is passed, defaults to max priority on the line.
@@ -1007,7 +1152,7 @@ class Line:
 
     depth: int = 0
     leaves: List[Leaf] = Factory(list)
 
     depth: int = 0
     leaves: List[Leaf] = Factory(list)
-    comments: List[Tuple[Index, Leaf]] = Factory(list)
+    comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
     bracket_tracker: BracketTracker = Factory(BracketTracker)
     inside_brackets: bool = False
     should_explode: bool = False
     bracket_tracker: BracketTracker = Factory(BracketTracker)
     inside_brackets: bool = False
     should_explode: bool = False
@@ -1138,6 +1283,32 @@ class Line:
             if leaf.type == STANDALONE_COMMENT:
                 if leaf.bracket_depth <= depth_limit:
                     return True
             if leaf.type == STANDALONE_COMMENT:
                 if leaf.bracket_depth <= depth_limit:
                     return True
+        return False
+
+    def contains_inner_type_comments(self) -> bool:
+        ignored_ids = set()
+        try:
+            last_leaf = self.leaves[-1]
+            ignored_ids.add(id(last_leaf))
+            if last_leaf.type == token.COMMA or (
+                last_leaf.type == token.RPAR and not last_leaf.value
+            ):
+                # When trailing commas or optional parens are inserted by Black for
+                # consistency, comments after the previous last element are not moved
+                # (they don't have to be, rendering will still be correct).  So we
+                # ignore trailing commas and invisible parens.
+                last_leaf = self.leaves[-2]
+                ignored_ids.add(id(last_leaf))
+        except IndexError:
+            return False
+
+        for leaf_id, comments in self.comments.items():
+            if leaf_id in ignored_ids:
+                continue
+
+            for comment in comments:
+                if is_type_comment(comment):
+                    return True
 
         return False
 
 
         return False
 
@@ -1192,7 +1363,10 @@ class Line:
             bracket_depth = leaf.bracket_depth
             if bracket_depth == depth and leaf.type == token.COMMA:
                 commas += 1
             bracket_depth = leaf.bracket_depth
             if bracket_depth == depth and leaf.type == token.COMMA:
                 commas += 1
-                if leaf.parent and leaf.parent.type == syms.arglist:
+                if leaf.parent and leaf.parent.type in {
+                    syms.arglist,
+                    syms.typedargslist,
+                }:
                     commas += 1
                     break
 
                     commas += 1
                     break
 
@@ -1214,44 +1388,41 @@ class Line:
         if comment.type != token.COMMENT:
             return False
 
         if comment.type != token.COMMENT:
             return False
 
-        after = len(self.leaves) - 1
-        if after == -1:
+        if not self.leaves:
             comment.type = STANDALONE_COMMENT
             comment.prefix = ""
             return False
 
             comment.type = STANDALONE_COMMENT
             comment.prefix = ""
             return False
 
-        else:
-            self.comments.append((after, comment))
-            return True
-
-    def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
-        """Generate comments that should appear directly after `leaf`.
-
-        Provide a non-negative leaf `_index` to speed up the function.
-        """
-        if not self.comments:
-            return
-
-        if _index == -1:
-            for _index, _leaf in enumerate(self.leaves):
-                if leaf is _leaf:
-                    break
-
-            else:
-                return
+        last_leaf = self.leaves[-1]
+        if (
+            last_leaf.type == token.RPAR
+            and not last_leaf.value
+            and last_leaf.parent
+            and len(list(last_leaf.parent.leaves())) <= 3
+            and not is_type_comment(comment)
+        ):
+            # Comments on optional parens wrapping a single leaf should belong to
+            # the wrapped node except if it's a type comment. Pinning the comment like
+            # this avoids unstable formatting caused by comment migration.
+            if len(self.leaves) < 2:
+                comment.type = STANDALONE_COMMENT
+                comment.prefix = ""
+                return False
+            last_leaf = self.leaves[-2]
+        self.comments.setdefault(id(last_leaf), []).append(comment)
+        return True
 
 
-        for index, comment_after in self.comments:
-            if _index == index:
-                yield comment_after
+    def comments_after(self, leaf: Leaf) -> List[Leaf]:
+        """Return the comments that should appear directly after `leaf`."""
+        return self.comments.get(id(leaf), [])
 
     def remove_trailing_comma(self) -> None:
         """Remove the trailing comma and moves the comments attached to it."""
 
     def remove_trailing_comma(self) -> None:
         """Remove the trailing comma and moves the comments attached to it."""
-        comma_index = len(self.leaves) - 1
-        for i in range(len(self.comments)):
-            comment_index, comment = self.comments[i]
-            if comment_index == comma_index:
-                self.comments[i] = (comma_index - 1, comment)
-        self.leaves.pop()
+        trailing_comma = self.leaves.pop()
+        trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
+        self.comments.setdefault(id(self.leaves[-1]), []).extend(
+            trailing_comma_comments
+        )
 
     def is_complex_subscript(self, leaf: Leaf) -> bool:
         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
 
     def is_complex_subscript(self, leaf: Leaf) -> bool:
         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
@@ -1282,7 +1453,7 @@ class Line:
         res = f"{first.prefix}{indent}{first.value}"
         for leaf in leaves:
             res += str(leaf)
         res = f"{first.prefix}{indent}{first.value}"
         for leaf in leaves:
             res += str(leaf)
-        for _, comment in self.comments:
+        for comment in itertools.chain.from_iterable(self.comments.values()):
             res += str(comment)
         return res + "\n"
 
             res += str(comment)
         return res + "\n"
 
@@ -1313,7 +1484,13 @@ class EmptyLineTracker:
         lines (two on module-level).
         """
         before, after = self._maybe_empty_lines(current_line)
         lines (two on module-level).
         """
         before, after = self._maybe_empty_lines(current_line)
-        before -= self.previous_after
+        before = (
+            # Black should not insert empty lines at the beginning
+            # of the file
+            0
+            if self.previous_line is None
+            else before - self.previous_after
+        )
         self.previous_after = after
         self.previous_line = current_line
         return before, after
         self.previous_after = after
         self.previous_line = current_line
         return before, after
@@ -1414,7 +1591,6 @@ class LineGenerator(Visitor[Line]):
     normalize_strings: bool = True
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
     normalize_strings: bool = True
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
-    allow_underscores: bool = False
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
@@ -1457,11 +1633,44 @@ class LineGenerator(Visitor[Line]):
                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                 normalize_string_quotes(node)
             if node.type == token.NUMBER:
                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                 normalize_string_quotes(node)
             if node.type == token.NUMBER:
-                normalize_numeric_literal(node, self.allow_underscores)
+                normalize_numeric_literal(node)
             if node.type not in WHITESPACE:
                 self.current_line.append(node)
         yield from super().visit_default(node)
 
             if node.type not in WHITESPACE:
                 self.current_line.append(node)
         yield from super().visit_default(node)
 
+    def visit_atom(self, node: Node) -> Iterator[Line]:
+        # Always make parentheses invisible around a single node, because they should
+        # not be needed (except in the case of yield, where removing the parentheses
+        # produces a SyntaxError).
+        if (
+            len(node.children) == 3
+            and isinstance(node.children[0], Leaf)
+            and node.children[0].type == token.LPAR
+            and isinstance(node.children[2], Leaf)
+            and node.children[2].type == token.RPAR
+            and isinstance(node.children[1], Leaf)
+            and not (
+                node.children[1].type == token.NAME
+                and node.children[1].value == "yield"
+            )
+        ):
+            node.children[0].value = ""
+            node.children[2].value = ""
+        yield from super().visit_default(node)
+
+    def visit_factor(self, node: Node) -> Iterator[Line]:
+        """Force parentheses between a unary op and a binary power:
+
+        -2 ** 8 -> -(2 ** 8)
+        """
+        child = node.children[1]
+        if child.type == syms.power and len(child.children) == 3:
+            lpar = Leaf(token.LPAR, "(")
+            rpar = Leaf(token.RPAR, ")")
+            index = child.remove() or 0
+            node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+        yield from self.visit_default(node)
+
     def visit_INDENT(self, node: Node) -> Iterator[Line]:
         """Increase indentation level, maybe yield a line."""
         # In blib2to3 INDENT never holds comments.
     def visit_INDENT(self, node: Node) -> Iterator[Line]:
         """Increase indentation level, maybe yield a line."""
         # In blib2to3 INDENT never holds comments.
@@ -1581,6 +1790,7 @@ class LineGenerator(Visitor[Line]):
         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
+        self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
@@ -1593,7 +1803,7 @@ BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
 
 
 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
 
 
-def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
+def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
     """Return whitespace prefix if needed for the given `leaf`.
 
     `complex_subscript` signals whether the given leaf is part of a subscription
     """Return whitespace prefix if needed for the given `leaf`.
 
     `complex_subscript` signals whether the given leaf is part of a subscription
@@ -1650,7 +1860,7 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
                     # that, too.
                     return prevp.prefix
 
                     # that, too.
                     return prevp.prefix
 
-        elif prevp.type in STARS:
+        elif prevp.type in VARARGS_SPECIALS:
             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
                 return NO
 
             if is_vararg(prevp, within=VARARGS_PARENTS | UNPACKING_PARENTS):
                 return NO
 
@@ -1740,7 +1950,7 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
             if not prevp or prevp.type == token.LPAR:
                 return NO
 
             if not prevp or prevp.type == token.LPAR:
                 return NO
 
-        elif prev.type in {token.EQUAL} | STARS:
+        elif prev.type in {token.EQUAL} | VARARGS_SPECIALS:
             return NO
 
     elif p.type == syms.decorator:
             return NO
 
     elif p.type == syms.decorator:
@@ -1874,7 +2084,7 @@ def container_of(leaf: Leaf) -> LN:
     return container
 
 
     return container
 
 
-def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
+def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
     """Return the priority of the `leaf` delimiter, given a line break after it.
 
     The delimiter priorities returned here are from those delimiters that would
     """Return the priority of the `leaf` delimiter, given a line break after it.
 
     The delimiter priorities returned here are from those delimiters that would
@@ -1888,7 +2098,7 @@ def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
     return 0
 
 
     return 0
 
 
-def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
+def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
     """Return the priority of the `leaf` delimiter, given a line break before it.
 
     The delimiter priorities returned here are from those delimiters that would
     """Return the priority of the `leaf` delimiter, given a line break before it.
 
     The delimiter priorities returned here are from those delimiters that would
@@ -2013,6 +2223,16 @@ def generate_comments(leaf: LN) -> Iterator[Leaf]:
 
 @dataclass
 class ProtoComment:
 
 @dataclass
 class ProtoComment:
+    """Describes a piece of syntax that is a comment.
+
+    It's not a :class:`blib2to3.pytree.Leaf` so that:
+
+    * it can be cached (`Leaf` objects should not be reused more than once as
+      they store their lineno, column, prefix, and parent information);
+    * `newlines` and `consumed` fields are kept separate from the `value`. This
+      simplifies handling of special marker comments like ``# fmt: off/on``.
+    """
+
     type: int  # token.COMMENT or STANDALONE_COMMENT
     value: str  # content of the comment
     newlines: int  # how many newlines before the comment
     type: int  # token.COMMENT or STANDALONE_COMMENT
     value: str  # content of the comment
     newlines: int  # how many newlines before the comment
@@ -2021,21 +2241,28 @@ class ProtoComment:
 
 @lru_cache(maxsize=4096)
 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
 
 @lru_cache(maxsize=4096)
 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
+    """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
     result: List[ProtoComment] = []
     if not prefix or "#" not in prefix:
         return result
 
     consumed = 0
     nlines = 0
     result: List[ProtoComment] = []
     if not prefix or "#" not in prefix:
         return result
 
     consumed = 0
     nlines = 0
+    ignored_lines = 0
     for index, line in enumerate(prefix.split("\n")):
         consumed += len(line) + 1  # adding the length of the split '\n'
         line = line.lstrip()
         if not line:
             nlines += 1
         if not line.startswith("#"):
     for index, line in enumerate(prefix.split("\n")):
         consumed += len(line) + 1  # adding the length of the split '\n'
         line = line.lstrip()
         if not line:
             nlines += 1
         if not line.startswith("#"):
+            # Escaped newlines outside of a comment are not really newlines at
+            # all. We treat a single-line comment following an escaped newline
+            # as a simple trailing comment.
+            if line.endswith("\\"):
+                ignored_lines += 1
             continue
 
             continue
 
-        if index == 0 and not is_endmarker:
+        if index == ignored_lines and not is_endmarker:
             comment_type = token.COMMENT  # simple trailing comment
         else:
             comment_type = STANDALONE_COMMENT
             comment_type = token.COMMENT  # simple trailing comment
         else:
             comment_type = STANDALONE_COMMENT
@@ -2052,8 +2279,8 @@ def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
 def make_comment(content: str) -> str:
     """Return a consistently formatted comment from the given `content` string.
 
 def make_comment(content: str) -> str:
     """Return a consistently formatted comment from the given `content` string.
 
-    All comments (except for "##", "#!", "#:") should have a single space between
-    the hash sign and the content.
+    All comments (except for "##", "#!", "#:", "#'", "#%%") should have a single
+    space between the hash sign and the content.
 
     If `content` didn't start with a hash sign, one is provided.
     """
 
     If `content` didn't start with a hash sign, one is provided.
     """
@@ -2063,13 +2290,16 @@ def make_comment(content: str) -> str:
 
     if content[0] == "#":
         content = content[1:]
 
     if content[0] == "#":
         content = content[1:]
-    if content and content[0] not in " !:#":
+    if content and content[0] not in " !:#'%":
         content = " " + content
     return "#" + content
 
 
 def split_line(
         content = " " + content
     return "#" + content
 
 
 def split_line(
-    line: Line, line_length: int, inner: bool = False, py36: bool = False
+    line: Line,
+    line_length: int,
+    inner: bool = False,
+    features: Collection[Feature] = (),
 ) -> Iterator[Line]:
     """Split a `line` into potentially many lines.
 
 ) -> Iterator[Line]:
     """Split a `line` into potentially many lines.
 
@@ -2078,16 +2308,18 @@ def split_line(
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
 
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
 
-    If `py36` is True, splitting may generate syntax that is only compatible
-    with Python 3.6 and later.
+    `features` are syntactical features that may be used in the output.
     """
     if line.is_comment:
         yield line
         return
 
     line_str = str(line).strip("\n")
     """
     if line.is_comment:
         yield line
         return
 
     line_str = str(line).strip("\n")
-    if not line.should_explode and is_line_short_enough(
-        line, line_length=line_length, line_str=line_str
+
+    if (
+        not line.contains_inner_type_comments()
+        and not line.should_explode
+        and is_line_short_enough(line, line_length=line_length, line_str=line_str)
     ):
         yield line
         return
     ):
         yield line
         return
@@ -2097,9 +2329,9 @@ def split_line(
         split_funcs = [left_hand_split]
     else:
 
         split_funcs = [left_hand_split]
     else:
 
-        def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
+        def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
             for omit in generate_trailers_to_omit(line, line_length):
             for omit in generate_trailers_to_omit(line, line_length):
-                lines = list(right_hand_split(line, line_length, py36, omit=omit))
+                lines = list(right_hand_split(line, line_length, features, omit=omit))
                 if is_line_short_enough(lines[0], line_length=line_length):
                     yield from lines
                     return
                 if is_line_short_enough(lines[0], line_length=line_length):
                     yield from lines
                     return
@@ -2107,7 +2339,7 @@ def split_line(
             # All splits failed, best effort split with no omits.
             # This mostly happens to multiline strings that are by definition
             # reported as not fitting a single line.
             # All splits failed, best effort split with no omits.
             # This mostly happens to multiline strings that are by definition
             # reported as not fitting a single line.
-            yield from right_hand_split(line, py36)
+            yield from right_hand_split(line, line_length, features=features)
 
         if line.inside_brackets:
             split_funcs = [delimiter_split, standalone_comment_split, rhs]
 
         if line.inside_brackets:
             split_funcs = [delimiter_split, standalone_comment_split, rhs]
@@ -2119,14 +2351,16 @@ def split_line(
         # split altogether.
         result: List[Line] = []
         try:
         # split altogether.
         result: List[Line] = []
         try:
-            for l in split_func(line, py36):
+            for l in split_func(line, features):
                 if str(l).strip("\n") == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
                 result.extend(
                 if str(l).strip("\n") == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
                 result.extend(
-                    split_line(l, line_length=line_length, inner=True, py36=py36)
+                    split_line(
+                        l, line_length=line_length, inner=True, features=features
+                    )
                 )
                 )
-        except CannotSplit as cs:
+        except CannotSplit:
             continue
 
         else:
             continue
 
         else:
@@ -2137,16 +2371,13 @@ def split_line(
         yield line
 
 
         yield line
 
 
-def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
     Prefer RHS otherwise.  This is why this function is not symmetrical with
     :func:`right_hand_split` which also handles optional parentheses.
     """
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
     Prefer RHS otherwise.  This is why this function is not symmetrical with
     :func:`right_hand_split` which also handles optional parentheses.
     """
-    head = Line(depth=line.depth)
-    body = Line(depth=line.depth + 1, inside_brackets=True)
-    tail = Line(depth=line.depth)
     tail_leaves: List[Leaf] = []
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
     tail_leaves: List[Leaf] = []
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
@@ -2164,15 +2395,12 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
             if leaf.type in OPENING_BRACKETS:
                 matching_bracket = leaf
                 current_leaves = body_leaves
             if leaf.type in OPENING_BRACKETS:
                 matching_bracket = leaf
                 current_leaves = body_leaves
-    # Since body is a new indent level, remove spurious leading whitespace.
-    if body_leaves:
-        normalize_prefix(body_leaves[0], inside_brackets=True)
-    # Build the new lines.
-    for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
-        for leaf in leaves:
-            result.append(leaf, preformatted=True)
-            for comment_after in line.comments_after(leaf):
-                result.append(comment_after, preformatted=True)
+    if not matching_bracket:
+        raise CannotSplit("No brackets found")
+
+    head = bracket_split_build_line(head_leaves, line, matching_bracket)
+    body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
+    tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
     bracket_split_succeeded_or_raise(head, body, tail)
     for result in (head, body, tail):
         if result:
     bracket_split_succeeded_or_raise(head, body, tail)
     for result in (head, body, tail):
         if result:
@@ -2180,7 +2408,10 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
 
 def right_hand_split(
 
 
 def right_hand_split(
-    line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
+    line: Line,
+    line_length: int,
+    features: Collection[Feature] = (),
+    omit: Collection[LeafID] = (),
 ) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair.
 
 ) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair.
 
@@ -2190,9 +2421,6 @@ def right_hand_split(
 
     Note: running this function modifies `bracket_depth` on the leaves of `line`.
     """
 
     Note: running this function modifies `bracket_depth` on the leaves of `line`.
     """
-    head = Line(depth=line.depth)
-    body = Line(depth=line.depth + 1, inside_brackets=True)
-    tail = Line(depth=line.depth)
     tail_leaves: List[Leaf] = []
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
     tail_leaves: List[Leaf] = []
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
@@ -2209,25 +2437,18 @@ def right_hand_split(
                 opening_bracket = leaf.opening_bracket
                 closing_bracket = leaf
                 current_leaves = body_leaves
                 opening_bracket = leaf.opening_bracket
                 closing_bracket = leaf
                 current_leaves = body_leaves
-    tail_leaves.reverse()
-    body_leaves.reverse()
-    head_leaves.reverse()
-    # Since body is a new indent level, remove spurious leading whitespace.
-    if body_leaves:
-        normalize_prefix(body_leaves[0], inside_brackets=True)
-    if not head_leaves:
-        # No `head` means the split failed. Either `tail` has all content or
+    if not (opening_bracket and closing_bracket and head_leaves):
+        # If there is no opening or closing_bracket that means the split failed and
+        # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
         # the matching `opening_bracket` wasn't available on `line` anymore.
         raise CannotSplit("No brackets found")
 
         # the matching `opening_bracket` wasn't available on `line` anymore.
         raise CannotSplit("No brackets found")
 
-    # Build the new lines.
-    for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
-        for leaf in leaves:
-            result.append(leaf, preformatted=True)
-            for comment_after in line.comments_after(leaf):
-                result.append(comment_after, preformatted=True)
-    assert opening_bracket and closing_bracket
-    body.should_explode = should_explode(body, opening_bracket)
+    tail_leaves.reverse()
+    body_leaves.reverse()
+    head_leaves.reverse()
+    head = bracket_split_build_line(head_leaves, line, opening_bracket)
+    body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
+    tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
     bracket_split_succeeded_or_raise(head, body, tail)
     if (
         # the body shouldn't be exploded
     bracket_split_succeeded_or_raise(head, body, tail)
     if (
         # the body shouldn't be exploded
@@ -2248,7 +2469,7 @@ def right_hand_split(
     ):
         omit = {id(closing_bracket), *omit}
         try:
     ):
         omit = {id(closing_bracket), *omit}
         try:
-            yield from right_hand_split(line, line_length, py36=py36, omit=omit)
+            yield from right_hand_split(line, line_length, features=features, omit=omit)
             return
 
         except CannotSplit:
             return
 
         except CannotSplit:
@@ -2301,6 +2522,46 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None
             )
 
 
             )
 
 
+def bracket_split_build_line(
+    leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
+) -> Line:
+    """Build a new line from `leaves` plus their comments found in `original`.
+
+    The new line starts at the depth of `original`.  When `is_body` is True, the
+    line is flagged as being inside brackets and placed one level deeper, the
+    prefix of its first leaf is normalized, and a trailing comma may be inserted
+    into `leaves` (in place): this happens for imports, and for definitions whose
+    leaves hold no comma at all.  For body lines `should_explode` is computed
+    once the line has been populated.
+    """
+    result = Line(depth=original.depth)
+    if is_body:
+        result.inside_brackets = True
+        result.depth += 1
+    if is_body and leaves:
+        # The body opens a new indent level, so drop spurious leading whitespace.
+        normalize_prefix(leaves[0], inside_brackets=True)
+        lacks_commas = original.is_def and not any(
+            candidate.type == token.COMMA for candidate in leaves
+        )
+
+        if original.is_import or lacks_commas:
+            # Walk back over trailing standalone comments so that the comma
+            # lands after the last real leaf and never after a comment.
+            insert_at = len(leaves)
+            while insert_at > 0 and leaves[insert_at - 1].type == STANDALONE_COMMENT:
+                insert_at -= 1
+            # Nothing to do if only comments were found or a comma already ends
+            # the content.
+            if insert_at > 0 and leaves[insert_at - 1].type != token.COMMA:
+                leaves.insert(insert_at, Leaf(token.COMMA, ","))
+    # Copy every leaf, each followed directly by the comments attached to it.
+    for current in leaves:
+        result.append(current, preformatted=True)
+        for trailing_comment in original.comments_after(current):
+            result.append(trailing_comment, preformatted=True)
+    if is_body:
+        result.should_explode = should_explode(result, opening_bracket)
+    return result
+
+
 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
     """Normalize prefix of the first leaf in every line returned by `split_func`.
 
 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
     """Normalize prefix of the first leaf in every line returned by `split_func`.
 
@@ -2308,8 +2569,8 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
     """
 
     @wraps(split_func)
     """
 
     @wraps(split_func)
-    def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
-        for l in split_func(line, py36):
+    def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
+        for l in split_func(line, features):
             normalize_prefix(l.leaves[0], inside_brackets=True)
             yield l
 
             normalize_prefix(l.leaves[0], inside_brackets=True)
             yield l
 
@@ -2317,11 +2578,11 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
 
 
 @dont_increase_indentation
 
 
 @dont_increase_indentation
-def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
     """Split according to delimiters of the highest priority.
 
     """Split according to delimiters of the highest priority.
 
-    If `py36` is True, the split will add trailing commas also in function
-    signatures that contain `*` and `**`.
+    If the appropriate Features are given, the split will add trailing commas
+    also in function signatures and calls that contain `*` and `**`.
     """
     try:
         last_leaf = line.leaves[-1]
     """
     try:
         last_leaf = line.leaves[-1]
@@ -2347,23 +2608,29 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
         nonlocal current_line
         try:
             current_line.append_safe(leaf, preformatted=True)
         nonlocal current_line
         try:
             current_line.append_safe(leaf, preformatted=True)
-        except ValueError as ve:
+        except ValueError:
             yield current_line
 
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
             current_line.append(leaf)
 
             yield current_line
 
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
             current_line.append(leaf)
 
-    for index, leaf in enumerate(line.leaves):
+    for leaf in line.leaves:
         yield from append_to_line(leaf)
 
         yield from append_to_line(leaf)
 
-        for comment_after in line.comments_after(leaf, index):
+        for comment_after in line.comments_after(leaf):
             yield from append_to_line(comment_after)
 
         lowest_depth = min(lowest_depth, leaf.bracket_depth)
             yield from append_to_line(comment_after)
 
         lowest_depth = min(lowest_depth, leaf.bracket_depth)
-        if leaf.bracket_depth == lowest_depth and is_vararg(
-            leaf, within=VARARGS_PARENTS
-        ):
-            trailing_comma_safe = trailing_comma_safe and py36
+        if leaf.bracket_depth == lowest_depth:
+            if is_vararg(leaf, within={syms.typedargslist}):
+                trailing_comma_safe = (
+                    trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
+                )
+            elif is_vararg(leaf, within={syms.arglist, syms.argument}):
+                trailing_comma_safe = (
+                    trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
+                )
+
         leaf_priority = bt.delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             yield current_line
         leaf_priority = bt.delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             yield current_line
@@ -2381,7 +2648,9 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
 
 @dont_increase_indentation
 
 
 @dont_increase_indentation
-def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def standalone_comment_split(
+    line: Line, features: Collection[Feature] = ()
+) -> Iterator[Line]:
     """Split standalone comments from the rest of the line."""
     if not line.contains_standalone_comments(0):
         raise CannotSplit("Line does not have any standalone comments")
     """Split standalone comments from the rest of the line."""
     if not line.contains_standalone_comments(0):
         raise CannotSplit("Line does not have any standalone comments")
@@ -2393,16 +2662,16 @@ def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
         nonlocal current_line
         try:
             current_line.append_safe(leaf, preformatted=True)
         nonlocal current_line
         try:
             current_line.append_safe(leaf, preformatted=True)
-        except ValueError as ve:
+        except ValueError:
             yield current_line
 
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
             current_line.append(leaf)
 
             yield current_line
 
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
             current_line.append(leaf)
 
-    for index, leaf in enumerate(line.leaves):
+    for leaf in line.leaves:
         yield from append_to_line(leaf)
 
         yield from append_to_line(leaf)
 
-        for comment_after in line.comments_after(leaf, index):
+        for comment_after in line.comments_after(leaf):
             yield from append_to_line(comment_after)
 
     if current_line:
             yield from append_to_line(comment_after)
 
     if current_line:
@@ -2423,6 +2692,14 @@ def is_import(leaf: Leaf) -> bool:
     )
 
 
     )
 
 
+def is_type_comment(leaf: Leaf) -> bool:
+    """Return True if the given leaf is a special comment.
+
+    Only type comments are treated as special for now: the leaf must be a
+    comment (regular or standalone) whose value starts with "# type:".
+    """
+    # Both comment token kinds must be listed as plain members of the set; a
+    # comparison such as `t == STANDALONE_COMMENT` inside the set literal would
+    # store a boolean there and standalone comments would never match.
+    is_comment = leaf.type in {token.COMMENT, STANDALONE_COMMENT}
+    return is_comment and leaf.value.startswith("# type:")
+
+
 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
     """Leave existing extra newlines if not `inside_brackets`. Remove everything
     else.
 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
     """Leave existing extra newlines if not `inside_brackets`. Remove everything
     else.
@@ -2505,7 +2782,15 @@ def normalize_string_quotes(leaf: Leaf) -> None:
         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
     if "f" in prefix.casefold():
         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
     if "f" in prefix.casefold():
-        matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
+        matches = re.findall(
+            r"""
+            (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
+                ([^{].*?)  # contents of the brackets except if begins with {{
+            \}(?:[^}]|$)  # A } followed by end of the string or a non-}
+            """,
+            new_body,
+            re.VERBOSE,
+        )
         for m in matches:
             if "\\" in str(m):
                 # Do not introduce backslashes in interpolated expressions
         for m in matches:
             if "\\" in str(m):
                 # Do not introduce backslashes in interpolated expressions
@@ -2524,16 +2809,20 @@ def normalize_string_quotes(leaf: Leaf) -> None:
     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
 
 
     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
 
 
-def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
+def normalize_numeric_literal(leaf: Leaf) -> None:
     """Normalizes numeric (float, int, and complex) literals.
 
     All letters used in the representation are normalized to lowercase (except
     """Normalizes numeric (float, int, and complex) literals.
 
     All letters used in the representation are normalized to lowercase (except
-    in Python 2 long literals), and long number literals are split using underscores.
+    in Python 2 long literals).
     """
     text = leaf.value.lower()
     """
     text = leaf.value.lower()
-    if text.startswith(("0o", "0x", "0b")):
-        # Leave octal, hex, and binary literals alone.
+    if text.startswith(("0o", "0b")):
+        # Leave octal and binary literals alone.
         pass
         pass
+    elif text.startswith("0x"):
+        # Change hex literals to upper case.
+        before, after = text[:2], text[2:]
+        text = f"{before}{after.upper()}"
     elif "e" in text:
         before, after = text.split("e")
         sign = ""
     elif "e" in text:
         before, after = text.split("e")
         sign = ""
@@ -2542,8 +2831,7 @@ def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
             sign = "-"
         elif after.startswith("+"):
             after = after[1:]
             sign = "-"
         elif after.startswith("+"):
             after = after[1:]
-        before = format_float_or_int_string(before, allow_underscores)
-        after = format_int_string(after, allow_underscores)
+        before = format_float_or_int_string(before)
         text = f"{before}e{sign}{after}"
     elif text.endswith(("j", "l")):
         number = text[:-1]
         text = f"{before}e{sign}{after}"
     elif text.endswith(("j", "l")):
         number = text[:-1]
@@ -2551,56 +2839,25 @@ def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
         # Capitalize in "2L" because "l" looks too similar to "1".
         if suffix == "l":
             suffix = "L"
         # Capitalize in "2L" because "l" looks too similar to "1".
         if suffix == "l":
             suffix = "L"
-        text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
+        text = f"{format_float_or_int_string(number)}{suffix}"
     else:
     else:
-        text = format_float_or_int_string(text, allow_underscores)
+        text = format_float_or_int_string(text)
     leaf.value = text
 
 
     leaf.value = text
 
 
-def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
+def format_float_or_int_string(text: str) -> str:
     """Formats a float string like "1.0"."""
     if "." not in text:
     """Formats a float string like "1.0"."""
     if "." not in text:
-        return format_int_string(text, allow_underscores)
-
-    before, after = text.split(".")
-    before = format_int_string(before, allow_underscores) if before else "0"
-    if after:
-        after = format_int_string(after, allow_underscores, count_from_end=False)
-    else:
-        after = "0"
-    return f"{before}.{after}"
-
-
-def format_int_string(
-    text: str, allow_underscores: bool, count_from_end: bool = True
-) -> str:
-    """Normalizes underscores in a string to e.g. 1_000_000.
-
-    Input must be a string of digits and optional underscores.
-    If count_from_end is False, we add underscores after groups of three digits
-    counting from the beginning instead of the end of the strings. This is used
-    for the fractional part of float literals.
-    """
-    if not allow_underscores:
         return text
 
         return text
 
-    text = text.replace("_", "")
-    if len(text) <= 6:
-        # No underscores for numbers <= 6 digits long.
-        return text
-
-    if count_from_end:
-        # Avoid removing leading zeros, which are important if we're formatting
-        # part of a number like "0.001".
-        return format(int("1" + text), "3_")[1:].lstrip("_")
-    else:
-        return "_".join(text[i : i + 3] for i in range(0, len(text), 3))
+    before, after = text.split(".")
+    return f"{before or 0}.{after or 0}"
 
 
 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
     """Make existing optional parentheses invisible or create new ones.
 
 
 
 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
     """Make existing optional parentheses invisible or create new ones.
 
-    `parens_after` is a set of string leaf values immeditely after which parens
+    `parens_after` is a set of string leaf values immediately after which parens
     should be put.
 
     Standardizes on visible parentheses for single-element tuples, and keeps
     should be put.
 
     Standardizes on visible parentheses for single-element tuples, and keeps
@@ -2613,9 +2870,27 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
 
     check_lpar = False
     for index, child in enumerate(list(node.children)):
 
     check_lpar = False
     for index, child in enumerate(list(node.children)):
+        # Add parentheses around long tuple unpacking in assignments.
+        if (
+            index == 0
+            and isinstance(child, Node)
+            and child.type == syms.testlist_star_expr
+        ):
+            check_lpar = True
+
         if check_lpar:
         if check_lpar:
+            if is_walrus_assignment(child):
+                continue
             if child.type == syms.atom:
             if child.type == syms.atom:
-                if maybe_make_parens_invisible_in_atom(child):
+                # Determines if the underlying atom should be surrounded with
+                # invisible parens - also makes parens invisible recursively
+                # within the atom and removes repeated invisible parens within
+                # the atom
+                should_surround_with_parens = maybe_make_parens_invisible_in_atom(
+                    child, parent=node
+                )
+
+                if should_surround_with_parens:
                     lpar = Leaf(token.LPAR, "")
                     rpar = Leaf(token.RPAR, "")
                     index = child.remove() or 0
                     lpar = Leaf(token.LPAR, "")
                     rpar = Leaf(token.RPAR, "")
                     index = child.remove() or 0
@@ -2644,7 +2919,11 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
                 lpar = Leaf(token.LPAR, "")
                 rpar = Leaf(token.RPAR, "")
                 index = child.remove() or 0
                 lpar = Leaf(token.LPAR, "")
                 rpar = Leaf(token.RPAR, "")
                 index = child.remove() or 0
-                node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+                prefix = child.prefix
+                child.prefix = ""
+                new_child = Node(syms.atom, [lpar, child, rpar])
+                new_child.prefix = prefix
+                node.insert_child(index, new_child)
 
         check_lpar = isinstance(child, Leaf) and child.value in parens_after
 
 
         check_lpar = isinstance(child, Leaf) and child.value in parens_after
 
@@ -2726,8 +3005,10 @@ def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
         container = container.next_sibling
 
 
         container = container.next_sibling
 
 
-def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
+def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
     """If it's safe, make the parens in the atom `node` invisible, recursively.
     """If it's safe, make the parens in the atom `node` invisible, recursively.
+    Additionally, remove repeated, adjacent invisible parens from the atom `node`
+    as they are redundant.
 
     Returns whether the node should itself be wrapped in invisible parentheses.
 
 
     Returns whether the node should itself be wrapped in invisible parentheses.
 
@@ -2736,7 +3017,7 @@ def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
         node.type != syms.atom
         or is_empty_tuple(node)
         or is_one_tuple(node)
         node.type != syms.atom
         or is_empty_tuple(node)
         or is_one_tuple(node)
-        or is_yield(node)
+        or (is_yield(node) and parent.type != syms.expr_stmt)
         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
     ):
         return False
         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
     ):
         return False
@@ -2744,16 +3025,40 @@ def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
     first = node.children[0]
     last = node.children[-1]
     if first.type == token.LPAR and last.type == token.RPAR:
     first = node.children[0]
     last = node.children[-1]
     if first.type == token.LPAR and last.type == token.RPAR:
+        middle = node.children[1]
         # make parentheses invisible
         first.value = ""  # type: ignore
         last.value = ""  # type: ignore
         # make parentheses invisible
         first.value = ""  # type: ignore
         last.value = ""  # type: ignore
-        if len(node.children) > 1:
-            maybe_make_parens_invisible_in_atom(node.children[1])
+        maybe_make_parens_invisible_in_atom(middle, parent=parent)
+
+        if is_atom_with_invisible_parens(middle):
+            # Strip the invisible parens from `middle` by replacing
+            # it with the child in-between the invisible parens
+            middle.replace(middle.children[1])
+
         return False
 
     return True
 
 
         return False
 
     return True
 
 
+def is_atom_with_invisible_parens(node: LN) -> bool:
+    """Tell whether `node` is an atom enclosed in a pair of invisible parens.
+
+    An invisible paren is an LPAR/RPAR leaf whose value is the empty string.
+    Useful when de-duplicating and normalizing parentheses.
+    """
+    if isinstance(node, Leaf):
+        return False
+
+    if node.type != syms.atom:
+        return False
+
+    opener = node.children[0]
+    closer = node.children[-1]
+    if not (isinstance(opener, Leaf) and isinstance(closer, Leaf)):
+        return False
+
+    opener_invisible = opener.type == token.LPAR and opener.value == ""
+    closer_invisible = closer.type == token.RPAR and closer.value == ""
+    return opener_invisible and closer_invisible
+
+
 def is_empty_tuple(node: LN) -> bool:
     """Return True if `node` holds an empty tuple."""
     return (
 def is_empty_tuple(node: LN) -> bool:
     """Return True if `node` holds an empty tuple."""
     return (
@@ -2764,18 +3069,24 @@ def is_empty_tuple(node: LN) -> bool:
     )
 
 
     )
 
 
+def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
+    """Return the middle child if `node` has the shape ( wrapped ).
+
+    The shape is matched on token types only (LPAR, anything, RPAR), so the
+    paren values are not inspected.  Returns None for any other shape.
+    """
+    children = node.children
+    if len(children) == 3:
+        opener, inner, closer = children
+        if opener.type == token.LPAR and closer.type == token.RPAR:
+            return inner
+
+    return None
+
+
 def is_one_tuple(node: LN) -> bool:
     """Return True if `node` holds a tuple with one element, with or without parens."""
     if node.type == syms.atom:
 def is_one_tuple(node: LN) -> bool:
     """Return True if `node` holds a tuple with one element, with or without parens."""
     if node.type == syms.atom:
-        if len(node.children) != 3:
-            return False
-
-        lpar, gexp, rpar = node.children
-        if not (
-            lpar.type == token.LPAR
-            and gexp.type == syms.testlist_gexp
-            and rpar.type == token.RPAR
-        ):
+        gexp = unwrap_singleton_parenthesis(node)
+        if gexp is None or gexp.type != syms.testlist_gexp:
             return False
 
         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
             return False
 
         return len(gexp.children) == 2 and gexp.children[1].type == token.COMMA
@@ -2787,6 +3098,12 @@ def is_one_tuple(node: LN) -> bool:
     )
 
 
     )
 
 
+def is_walrus_assignment(node: LN) -> bool:
+    """Return True iff `node` is a parenthesized assignment expression.
+
+    That is, `node` has the shape ( test := test ).
+    """
+    wrapped = unwrap_singleton_parenthesis(node)
+    if wrapped is None:
+        return False
+    return wrapped.type == syms.namedexpr_test
+
+
 def is_yield(node: LN) -> bool:
     """Return True if `node` holds a `yield` or `yield from` expression."""
     if node.type == syms.yield_expr:
 def is_yield(node: LN) -> bool:
     """Return True if `node` holds a `yield` or `yield from` expression."""
     if node.type == syms.yield_expr:
@@ -2816,7 +3133,7 @@ def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool:
     extended iterable unpacking (PEP 3132) and additional unpacking
     generalizations (PEP 448).
     """
     extended iterable unpacking (PEP 3132) and additional unpacking
     generalizations (PEP 448).
     """
-    if leaf.type not in STARS or not leaf.parent:
+    if leaf.type not in VARARGS_SPECIALS or not leaf.parent:
         return False
 
     p = leaf.parent
         return False
 
     p = leaf.parent
@@ -2866,7 +3183,7 @@ def is_stub_body(node: LN) -> bool:
     )
 
 
     )
 
 
-def max_delimiter_priority_in_atom(node: LN) -> int:
+def max_delimiter_priority_in_atom(node: LN) -> Priority:
     """Return maximum delimiter priority inside `node`.
 
     This is specific to atoms with contents contained in a pair of parentheses.
     """Return maximum delimiter priority inside `node`.
 
     This is specific to atoms with contents contained in a pair of parentheses.
@@ -2898,7 +3215,7 @@ def ensure_visible(leaf: Leaf) -> None:
     """Make sure parentheses are visible.
 
     They could be invisible as part of some statements (see
     """Make sure parentheses are visible.
 
     They could be invisible as part of some statements (see
-    :func:`normalize_invible_parens` and :func:`visit_import_from`).
+    :func:`normalize_invisible_parens` and :func:`visit_import_from`).
     """
     if leaf.type == token.LPAR:
         leaf.value = "("
     """
     if leaf.type == token.LPAR:
         leaf.value = "("
@@ -2908,6 +3225,7 @@ def ensure_visible(leaf: Leaf) -> None:
 
 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
 
 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
+
     if not (
         opening_bracket.parent
         and opening_bracket.parent.type in {syms.atom, syms.import_from}
     if not (
         opening_bracket.parent
         and opening_bracket.parent.type in {syms.atom, syms.import_from}
@@ -2925,39 +3243,61 @@ def should_explode(line: Line, opening_bracket: Leaf) -> bool:
     return max_priority == COMMA_PRIORITY
 
 
     return max_priority == COMMA_PRIORITY
 
 
-def is_python36(node: Node) -> bool:
-    """Return True if the current file is using Python 3.6+ features.
+def get_features_used(node: Node) -> Set[Feature]:
+    """Return a set of (relatively) new Python features used in this file.
 
     Currently looking for:
     - f-strings;
 
     Currently looking for:
     - f-strings;
-    - underscores in numeric literals; and
-    - trailing commas after * or ** in function signatures and calls.
+    - underscores in numeric literals;
+    - trailing commas after * or ** in function signatures and calls;
+    - positional only arguments in function signatures and lambdas;
     """
     """
+    features: Set[Feature] = set()
     for n in node.pre_order():
         if n.type == token.STRING:
             value_head = n.value[:2]  # type: ignore
             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
     for n in node.pre_order():
         if n.type == token.STRING:
             value_head = n.value[:2]  # type: ignore
             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
-                return True
+                features.add(Feature.F_STRINGS)
 
         elif n.type == token.NUMBER:
             if "_" in n.value:  # type: ignore
 
         elif n.type == token.NUMBER:
             if "_" in n.value:  # type: ignore
-                return True
+                features.add(Feature.NUMERIC_UNDERSCORES)
+
+        elif n.type == token.SLASH:
+            if n.parent and n.parent.type in {syms.typedargslist, syms.arglist}:
+                features.add(Feature.POS_ONLY_ARGUMENTS)
+
+        elif n.type == token.COLONEQUAL:
+            features.add(Feature.ASSIGNMENT_EXPRESSIONS)
 
         elif (
             n.type in {syms.typedargslist, syms.arglist}
             and n.children
             and n.children[-1].type == token.COMMA
         ):
 
         elif (
             n.type in {syms.typedargslist, syms.arglist}
             and n.children
             and n.children[-1].type == token.COMMA
         ):
+            if n.type == syms.typedargslist:
+                feature = Feature.TRAILING_COMMA_IN_DEF
+            else:
+                feature = Feature.TRAILING_COMMA_IN_CALL
+
             for ch in n.children:
                 if ch.type in STARS:
             for ch in n.children:
                 if ch.type in STARS:
-                    return True
+                    features.add(feature)
 
                 if ch.type == syms.argument:
                     for argch in ch.children:
                         if argch.type in STARS:
 
                 if ch.type == syms.argument:
                     for argch in ch.children:
                         if argch.type in STARS:
-                            return True
+                            features.add(feature)
 
 
-    return False
+    return features
+
+
+def detect_target_versions(node: Node) -> Set[TargetVersion]:
+    """Return every target version whose feature set covers the features in `node`.
+
+    A version qualifies when the features reported by `get_features_used()` are
+    a subset of that version's entry in `VERSION_TO_FEATURES`.
+    """
+    used = get_features_used(node)
+    compatible: Set[TargetVersion] = set()
+    for candidate in TargetVersion:
+        if used <= VERSION_TO_FEATURES[candidate]:
+            compatible.add(candidate)
+    return compatible
 
 
 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
 
 
 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
@@ -3027,7 +3367,7 @@ def get_future_imports(node: Node) -> Set[str]:
             elif child.type == syms.import_as_names:
                 yield from get_imports_from_children(child.children)
             else:
             elif child.type == syms.import_as_names:
                 yield from get_imports_from_children(child.children)
             else:
-                assert False, "Invalid syntax parsing imports"
+                raise AssertionError("Invalid syntax parsing imports")
 
     for child in node.children:
         if child.type != syms.simple_stmt:
 
     for child in node.children:
         if child.type != syms.simple_stmt:
@@ -3211,17 +3551,58 @@ class Report:
         return ", ".join(report) + "."
 
 
         return ", ".join(report) + "."
 
 
+def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]:
+    """Parse `src` into an AST, trying progressively older grammars.
+
+    On Python 3.8+ the builtin `ast` module is tried with feature versions from
+    the running minor version down to 3.5; on older interpreters `ast3` is tried
+    with feature versions 7 and 6.  If every attempt raises SyntaxError, the
+    source is handed to `ast27`, whose errors propagate to the caller.
+    """
+    filename = "<unknown>"
+    if sys.version_info >= (3, 8):
+        # TODO: support Python 4+ ;)
+        parse = ast.parse
+        candidates = [(3, minor) for minor in range(sys.version_info[1], 4, -1)]
+    else:
+        parse = ast3.parse
+        candidates = [7, 6]
+
+    for candidate in candidates:
+        try:
+            return parse(src, filename, feature_version=candidate)
+        except SyntaxError:
+            # This grammar rejected the source; fall through to an older one.
+            continue
+
+    return ast27.parse(src)
+
+
+def _fixup_ast_constants(
+    node: Union[ast.AST, ast3.AST, ast27.AST]
+) -> Union[ast.AST, ast3.AST, ast27.AST]:
+    """Map ast nodes deprecated in 3.8 to Constant.
+
+    String/bytes, number and name-constant nodes are replaced by an
+    `ast.Constant` carrying the same value; any other node is returned as is.
+    """
+    # casts are required until this is released:
+    # https://github.com/python/typeshed/pull/3142
+    if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
+        text_constant = ast.Constant(value=node.s)
+        return cast(ast.AST, text_constant)
+
+    if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
+        number_constant = ast.Constant(value=node.n)
+        return cast(ast.AST, number_constant)
+
+    if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
+        name_constant = ast.Constant(value=node.value)
+        return cast(ast.AST, name_constant)
+
+    return node
+
+
 def assert_equivalent(src: str, dst: str) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
 
 def assert_equivalent(src: str, dst: str) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
 
-    import ast
-    import traceback
-
-    def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
+    def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
         """Simple visitor generating strings to compare ASTs by content."""
         """Simple visitor generating strings to compare ASTs by content."""
+
+        node = _fixup_ast_constants(node)
+
         yield f"{'  ' * depth}{node.__class__.__name__}("
 
         for field in sorted(node._fields):
         yield f"{'  ' * depth}{node.__class__.__name__}("
 
         for field in sorted(node._fields):
+            # TypeIgnore has only one field 'lineno' which breaks this comparison
+            type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
+            if sys.version_info >= (3, 8):
+                type_ignore_classes += (ast.TypeIgnore,)
+            if isinstance(node, type_ignore_classes):
+                break
+
             try:
                 value = getattr(node, field)
             except AttributeError:
             try:
                 value = getattr(node, field)
             except AttributeError:
@@ -3231,10 +3612,19 @@ def assert_equivalent(src: str, dst: str) -> None:
 
             if isinstance(value, list):
                 for item in value:
 
             if isinstance(value, list):
                 for item in value:
-                    if isinstance(item, ast.AST):
+                    # Ignore nested tuples within del statements, because we may insert
+                    # parentheses and they change the AST.
+                    if (
+                        field == "targets"
+                        and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete))
+                        and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple))
+                    ):
+                        for item in item.elts:
+                            yield from _v(item, depth + 2)
+                    elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
                         yield from _v(item, depth + 2)
 
                         yield from _v(item, depth + 2)
 
-            elif isinstance(value, ast.AST):
+            elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)):
                 yield from _v(value, depth + 2)
 
             else:
                 yield from _v(value, depth + 2)
 
             else:
@@ -3243,22 +3633,20 @@ def assert_equivalent(src: str, dst: str) -> None:
         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
 
     try:
         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
 
     try:
-        src_ast = ast.parse(src)
+        src_ast = parse_ast(src)
     except Exception as exc:
     except Exception as exc:
-        major, minor = sys.version_info[:2]
         raise AssertionError(
         raise AssertionError(
-            f"cannot use --safe with this file; failed to parse source file "
-            f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
-            f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
+            f"cannot use --safe with this file; failed to parse source file.  "
+            f"AST error message: {exc}"
         )
 
     try:
         )
 
     try:
-        dst_ast = ast.parse(dst)
+        dst_ast = parse_ast(dst)
     except Exception as exc:
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(
             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
     except Exception as exc:
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(
             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
-            f"Please report a bug on https://github.com/ambv/black/issues.  "
+            f"Please report a bug on https://github.com/psf/black/issues.  "
             f"This invalid output might be helpful: {log}"
         ) from None
 
             f"This invalid output might be helpful: {log}"
         ) from None
 
@@ -3269,16 +3657,14 @@ def assert_equivalent(src: str, dst: str) -> None:
         raise AssertionError(
             f"INTERNAL ERROR: Black produced code that is not equivalent to "
             f"the source.  "
         raise AssertionError(
             f"INTERNAL ERROR: Black produced code that is not equivalent to "
             f"the source.  "
-            f"Please report a bug on https://github.com/ambv/black/issues.  "
+            f"Please report a bug on https://github.com/psf/black/issues.  "
             f"This diff might be helpful: {log}"
         ) from None
 
 
             f"This diff might be helpful: {log}"
         ) from None
 
 
-def assert_stable(
-    src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
-) -> None:
+def assert_stable(src: str, dst: str, mode: FileMode) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
     """Raise AssertionError if `dst` reformats differently the second time."""
-    newdst = format_str(dst, line_length=line_length, mode=mode)
+    newdst = format_str(dst, mode=mode)
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
@@ -3287,15 +3673,13 @@ def assert_stable(
         raise AssertionError(
             f"INTERNAL ERROR: Black produced different code on the second pass "
             f"of the formatter.  "
         raise AssertionError(
             f"INTERNAL ERROR: Black produced different code on the second pass "
             f"of the formatter.  "
-            f"Please report a bug on https://github.com/ambv/black/issues.  "
+            f"Please report a bug on https://github.com/psf/black/issues.  "
             f"This diff might be helpful: {log}"
         ) from None
 
 
 def dump_to_file(*output: str) -> str:
     """Dump `output` to a temporary file. Return path to the file."""
             f"This diff might be helpful: {log}"
         ) from None
 
 
 def dump_to_file(*output: str) -> str:
     """Dump `output` to a temporary file. Return path to the file."""
-    import tempfile
-
     with tempfile.NamedTemporaryFile(
         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
     ) as f:
     with tempfile.NamedTemporaryFile(
         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
     ) as f:
@@ -3306,6 +3690,13 @@ def dump_to_file(*output: str) -> str:
     return f.name
 
 
     return f.name
 
 
+@contextmanager
+def nullcontext() -> Iterator[None]:
+    """Return context manager that does nothing.
+    Similar to `nullcontext` from python 3.7"""
+    yield
+
+
 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     """Return a unified diff string between strings `a` and `b`."""
     import difflib
 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     """Return a unified diff string between strings `a` and `b`."""
     import difflib
@@ -3324,11 +3715,15 @@ def cancel(tasks: Iterable[asyncio.Task]) -> None:
         task.cancel()
 
 
         task.cancel()
 
 
-def shutdown(loop: BaseEventLoop) -> None:
+def shutdown(loop: asyncio.AbstractEventLoop) -> None:
     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
     try:
     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
     try:
+        if sys.version_info[:2] >= (3, 7):
+            all_tasks = asyncio.all_tasks
+        else:
+            all_tasks = asyncio.Task.all_tasks
         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
-        to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
+        to_cancel = [task for task in all_tasks(loop) if not task.done()]
         if not to_cancel:
             return
 
         if not to_cancel:
             return
 
@@ -3389,8 +3784,7 @@ def enumerate_with_length(
         if "\n" in leaf.value:
             return  # Multiline strings, we can't continue.
 
         if "\n" in leaf.value:
             return  # Multiline strings, we can't continue.
 
-        comment: Optional[Leaf]
-        for comment in line.comments_after(leaf, index):
+        for comment in line.comments_after(leaf):
             length += len(comment.value)
 
         yield index, leaf, length
             length += len(comment.value)
 
         yield index, leaf, length
@@ -3535,16 +3929,16 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     return False
 
 
     return False
 
 
-def get_cache_file(line_length: int, mode: FileMode) -> Path:
-    return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
+def get_cache_file(mode: FileMode) -> Path:
+    return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
 
 
 
 
-def read_cache(line_length: int, mode: FileMode) -> Cache:
+def read_cache(mode: FileMode) -> Cache:
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
-    cache_file = get_cache_file(line_length, mode)
+    cache_file = get_cache_file(mode)
     if not cache_file.exists():
         return {}
 
     if not cache_file.exists():
         return {}
 
@@ -3579,17 +3973,15 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set
     return todo, done
 
 
     return todo, done
 
 
-def write_cache(
-    cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
-) -> None:
+def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
     """Update the cache file."""
     """Update the cache file."""
-    cache_file = get_cache_file(line_length, mode)
+    cache_file = get_cache_file(mode)
     try:
     try:
-        if not CACHE_DIR.exists():
-            CACHE_DIR.mkdir(parents=True)
+        CACHE_DIR.mkdir(parents=True, exist_ok=True)
         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
-        with cache_file.open("wb") as fobj:
-            pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
+        with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
+            pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
+        os.replace(f.name, cache_file)
     except OSError:
         pass
 
     except OSError:
         pass
 
@@ -3616,6 +4008,11 @@ def patch_click() -> None:
             module._verify_python3_env = lambda: None
 
 
             module._verify_python3_env = lambda: None
 
 
-if __name__ == "__main__":
+def patched_main() -> None:
+    freeze_support()
     patch_click()
     main()
     patch_click()
     main()
+
+
+if __name__ == "__main__":
+    patched_main()