git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Hello github.com/psf!
[etc/vim.git] / black.py
index 9b363d5041f2528c063fc70aadc14937f37bb781..180163c74c138090214b3da6e32399864a3fc7be 100644 (file)
--- a/black.py
+++ b/black.py
@@ -1,20 +1,22 @@
 import asyncio
-from asyncio.base_events import BaseEventLoop
 from concurrent.futures import Executor, ProcessPoolExecutor
+from contextlib import contextmanager
 from datetime import datetime
-from enum import Enum, Flag
+from enum import Enum
 from functools import lru_cache, partial, wraps
 import io
-import keyword
+import itertools
 import logging
-from multiprocessing import Manager
+from multiprocessing import Manager, freeze_support
 import os
 from pathlib import Path
 import pickle
 import re
 import signal
 import sys
+import tempfile
 import tokenize
+import traceback
 from typing import (
     Any,
     Callable,
@@ -36,21 +38,23 @@ from typing import (
 )
 
 from appdirs import user_cache_dir
-from attr import dataclass, Factory
+from attr import dataclass, evolve, Factory
 import click
 import toml
+from typed_ast import ast3, ast27
 
 # lib2to3 fork
 from blib2to3.pytree import Node, Leaf, type_repr
 from blib2to3 import pygram, pytree
 from blib2to3.pgen2 import driver, token
+from blib2to3.pgen2.grammar import Grammar
 from blib2to3.pgen2.parse import ParseError
 
 
-__version__ = "18.6b4"
+__version__ = "19.3b0"
 DEFAULT_LINE_LENGTH = 88
 DEFAULT_EXCLUDES = (
-    r"/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|_build|buck-out|build|dist)/"
+    r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|_build|buck-out|build|dist)/"
 )
 DEFAULT_INCLUDES = r"\.pyi?$"
 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
@@ -66,7 +70,7 @@ LeafID = int
 Priority = int
 Index = int
 LN = Union[Leaf, Node]
-SplitFunc = Callable[["Line", bool], Iterator["Line"]]
+SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
 Timestamp = float
 FileSize = int
 CacheInfo = Tuple[Timestamp, FileSize]
@@ -79,15 +83,15 @@ syms = pygram.python_symbols
 
 
 class NothingChanged(UserWarning):
-    """Raised by :func:`format_file` when reformatted code is the same as source."""
+    """Raised when reformatted code is the same as source."""
 
 
 class CannotSplit(Exception):
-    """A readable split that fits the allotted line length is impossible.
+    """A readable split that fits the allotted line length is impossible."""
 
-    Raised by :func:`left_hand_split`, :func:`right_hand_split`, and
-    :func:`delimiter_split`.
-    """
+
+class InvalidInput(ValueError):
+    """Raised when input source code fails all parse attempts."""
 
 
 class WriteBack(Enum):
@@ -110,24 +114,97 @@ class Changed(Enum):
     YES = 2
 
 
-class FileMode(Flag):
-    AUTO_DETECT = 0
-    PYTHON36 = 1
-    PYI = 2
-    NO_STRING_NORMALIZATION = 4
+class TargetVersion(Enum):
+    PY27 = 2
+    PY33 = 3
+    PY34 = 4
+    PY35 = 5
+    PY36 = 6
+    PY37 = 7
+    PY38 = 8
+
+    def is_python2(self) -> bool:
+        return self is TargetVersion.PY27
+
+
+PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
+
+
+class Feature(Enum):
+    # All string literals are unicode
+    UNICODE_LITERALS = 1
+    F_STRINGS = 2
+    NUMERIC_UNDERSCORES = 3
+    TRAILING_COMMA_IN_CALL = 4
+    TRAILING_COMMA_IN_DEF = 5
+    # The following two feature-flags are mutually exclusive, and exactly one should be
+    # set for every version of python.
+    ASYNC_IDENTIFIERS = 6
+    ASYNC_KEYWORDS = 7
+
+
+VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
+    TargetVersion.PY27: {Feature.ASYNC_IDENTIFIERS},
+    TargetVersion.PY33: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
+    TargetVersion.PY34: {Feature.UNICODE_LITERALS, Feature.ASYNC_IDENTIFIERS},
+    TargetVersion.PY35: {
+        Feature.UNICODE_LITERALS,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.ASYNC_IDENTIFIERS,
+    },
+    TargetVersion.PY36: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
+        Feature.ASYNC_IDENTIFIERS,
+    },
+    TargetVersion.PY37: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
+        Feature.ASYNC_KEYWORDS,
+    },
+    TargetVersion.PY38: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
+        Feature.ASYNC_KEYWORDS,
+    },
+}
 
-    @classmethod
-    def from_configuration(
-        cls, *, py36: bool, pyi: bool, skip_string_normalization: bool
-    ) -> "FileMode":
-        mode = cls.AUTO_DETECT
-        if py36:
-            mode |= cls.PYTHON36
-        if pyi:
-            mode |= cls.PYI
-        if skip_string_normalization:
-            mode |= cls.NO_STRING_NORMALIZATION
-        return mode
+
+@dataclass
+class FileMode:
+    target_versions: Set[TargetVersion] = Factory(set)
+    line_length: int = DEFAULT_LINE_LENGTH
+    string_normalization: bool = True
+    is_pyi: bool = False
+
+    def get_cache_key(self) -> str:
+        if self.target_versions:
+            version_str = ",".join(
+                str(version.value)
+                for version in sorted(self.target_versions, key=lambda v: v.value)
+            )
+        else:
+            version_str = "-"
+        parts = [
+            version_str,
+            str(self.line_length),
+            str(int(self.string_normalization)),
+            str(int(self.is_pyi)),
+        ]
+        return ".".join(parts)
+
+
+def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
+    return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
 
 
 def read_pyproject_toml(
@@ -151,7 +228,9 @@ def read_pyproject_toml(
         pyproject_toml = toml.load(value)
         config = pyproject_toml.get("tool", {}).get("black", {})
     except (toml.TomlDecodeError, OSError) as e:
-        raise click.BadOptionUsage(f"Error reading configuration file: {e}", ctx)
+        raise click.FileError(
+            filename=value, hint=f"Error reading configuration file: {e}"
+        )
 
     if not config:
         return None
@@ -165,21 +244,34 @@ def read_pyproject_toml(
 
 
 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
+@click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
 @click.option(
     "-l",
     "--line-length",
     type=int,
     default=DEFAULT_LINE_LENGTH,
-    help="How many character per line to allow.",
+    help="How many characters per line to allow.",
     show_default=True,
 )
+@click.option(
+    "-t",
+    "--target-version",
+    type=click.Choice([v.name.lower() for v in TargetVersion]),
+    callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
+    multiple=True,
+    help=(
+        "Python versions that should be supported by Black's output. [default: "
+        "per-file auto-detection]"
+    ),
+)
 @click.option(
     "--py36",
     is_flag=True,
     help=(
         "Allow using Python 3.6-only syntax on all input files.  This will put "
         "trailing commas in function signatures and calls also after *args and "
-        "**kwargs.  [default: per-file auto-detection]"
+        "**kwargs. Deprecated; use --target-version instead. "
+        "[default: per-file auto-detection]"
     ),
 )
 @click.option(
@@ -245,7 +337,7 @@ def read_pyproject_toml(
     "--quiet",
     is_flag=True,
     help=(
-        "Don't emit non-error messages to stderr. Errors are still emitted, "
+        "Don't emit non-error messages to stderr. Errors are still emitted; "
         "silence those with 2>/dev/null."
     ),
 )
@@ -279,7 +371,9 @@ def read_pyproject_toml(
 @click.pass_context
 def main(
     ctx: click.Context,
+    code: Optional[str],
     line_length: int,
+    target_version: List[TargetVersion],
     check: bool,
     diff: bool,
     fast: bool,
@@ -295,11 +389,32 @@ def main(
 ) -> None:
     """The uncompromising code formatter."""
     write_back = WriteBack.from_configuration(check=check, diff=diff)
-    mode = FileMode.from_configuration(
-        py36=py36, pyi=pyi, skip_string_normalization=skip_string_normalization
+    if target_version:
+        if py36:
+            err(f"Cannot use both --target-version and --py36")
+            ctx.exit(2)
+        else:
+            versions = set(target_version)
+    elif py36:
+        err(
+            "--py36 is deprecated and will be removed in a future version. "
+            "Use --target-version py36 instead."
+        )
+        versions = PY36_VERSIONS
+    else:
+        # We'll autodetect later.
+        versions = set()
+    mode = FileMode(
+        target_versions=versions,
+        line_length=line_length,
+        is_pyi=pyi,
+        string_normalization=not skip_string_normalization,
     )
     if config and verbose:
         out(f"Using configuration from {config}.", bold=False, fg="blue")
+    if code is not None:
+        print(format_str(code, mode=mode))
+        ctx.exit(0)
     try:
         include_regex = re_compile_maybe_verbose(include)
     except re.error:
@@ -332,102 +447,105 @@ def main(
     if len(sources) == 1:
         reformat_one(
             src=sources.pop(),
-            line_length=line_length,
             fast=fast,
             write_back=write_back,
             mode=mode,
             report=report,
         )
     else:
-        loop = asyncio.get_event_loop()
-        executor = ProcessPoolExecutor(max_workers=os.cpu_count())
-        try:
-            loop.run_until_complete(
-                schedule_formatting(
-                    sources=sources,
-                    line_length=line_length,
-                    fast=fast,
-                    write_back=write_back,
-                    mode=mode,
-                    report=report,
-                    loop=loop,
-                    executor=executor,
-                )
-            )
-        finally:
-            shutdown(loop)
+        reformat_many(
+            sources=sources, fast=fast, write_back=write_back, mode=mode, report=report
+        )
+
     if verbose or not quiet:
-        bang = "💥 💔 💥" if report.return_code else "✨ 🍰 ✨"
-        out(f"All done! {bang}")
+        out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨")
         click.secho(str(report), err=True)
     ctx.exit(report.return_code)
 
 
 def reformat_one(
-    src: Path,
-    line_length: int,
-    fast: bool,
-    write_back: WriteBack,
-    mode: FileMode,
-    report: "Report",
+    src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
-    If `quiet` is True, non-error messages are not output. `line_length`,
-    `write_back`, `fast` and `pyi` options are passed to
+    `fast`, `write_back`, and `mode` options are passed to
     :func:`format_file_in_place` or :func:`format_stdin_to_stdout`.
     """
     try:
         changed = Changed.NO
         if not src.is_file() and str(src) == "-":
-            if format_stdin_to_stdout(
-                line_length=line_length, fast=fast, write_back=write_back, mode=mode
-            ):
+            if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                 changed = Changed.YES
         else:
             cache: Cache = {}
             if write_back != WriteBack.DIFF:
-                cache = read_cache(line_length, mode)
+                cache = read_cache(mode)
                 res_src = src.resolve()
                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
-                src,
-                line_length=line_length,
-                fast=fast,
-                write_back=write_back,
-                mode=mode,
+                src, fast=fast, write_back=write_back, mode=mode
             ):
                 changed = Changed.YES
             if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
                 write_back is WriteBack.CHECK and changed is Changed.NO
             ):
-                write_cache(cache, [src], line_length, mode)
+                write_cache(cache, [src], mode)
         report.done(src, changed)
     except Exception as exc:
         report.failed(src, str(exc))
 
 
+def reformat_many(
+    sources: Set[Path],
+    fast: bool,
+    write_back: WriteBack,
+    mode: FileMode,
+    report: "Report",
+) -> None:
+    """Reformat multiple files using a ProcessPoolExecutor."""
+    loop = asyncio.get_event_loop()
+    worker_count = os.cpu_count()
+    if sys.platform == "win32":
+        # Work around https://bugs.python.org/issue26903
+        worker_count = min(worker_count, 61)
+    executor = ProcessPoolExecutor(max_workers=worker_count)
+    try:
+        loop.run_until_complete(
+            schedule_formatting(
+                sources=sources,
+                fast=fast,
+                write_back=write_back,
+                mode=mode,
+                report=report,
+                loop=loop,
+                executor=executor,
+            )
+        )
+    finally:
+        shutdown(loop)
+        executor.shutdown()
+
+
 async def schedule_formatting(
     sources: Set[Path],
-    line_length: int,
     fast: bool,
     write_back: WriteBack,
     mode: FileMode,
     report: "Report",
-    loop: BaseEventLoop,
+    loop: asyncio.AbstractEventLoop,
     executor: Executor,
 ) -> None:
     """Run formatting of `sources` in parallel using the provided `executor`.
 
     (Use ProcessPoolExecutors for actual parallelism.)
 
-    `line_length`, `write_back`, `fast`, and `pyi` options are passed to
+    `write_back`, `fast`, and `mode` options are passed to
     :func:`format_file_in_place`.
     """
     cache: Cache = {}
     if write_back != WriteBack.DIFF:
-        cache = read_cache(line_length, mode)
+        cache = read_cache(mode)
         sources, cached = filter_cached(cache, sources)
         for src in sorted(cached):
             report.done(src, Changed.CACHED)
@@ -443,19 +561,14 @@ async def schedule_formatting(
         manager = Manager()
         lock = manager.Lock()
     tasks = {
-        loop.run_in_executor(
-            executor,
-            format_file_in_place,
-            src,
-            line_length,
-            fast,
-            write_back,
-            mode,
-            lock,
+        asyncio.ensure_future(
+            loop.run_in_executor(
+                executor, format_file_in_place, src, fast, mode, write_back, lock
+            )
         ): src
         for src in sorted(sources)
     }
-    pending: Iterable[asyncio.Task] = tasks.keys()
+    pending: Iterable[asyncio.Future] = tasks.keys()
     try:
         loop.add_signal_handler(signal.SIGINT, cancel, pending)
         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
@@ -482,33 +595,30 @@ async def schedule_formatting(
     if cancelled:
         await asyncio.gather(*cancelled, loop=loop, return_exceptions=True)
     if sources_to_cache:
-        write_cache(cache, sources_to_cache, line_length, mode)
+        write_cache(cache, sources_to_cache, mode)
 
 
 def format_file_in_place(
     src: Path,
-    line_length: int,
     fast: bool,
+    mode: FileMode,
     write_back: WriteBack = WriteBack.NO,
-    mode: FileMode = FileMode.AUTO_DETECT,
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     """Format file under `src` path. Return True if changed.
 
     If `write_back` is DIFF, write a diff to stdout. If it is YES, write reformatted
     code to the file.
-    `line_length` and `fast` options are passed to :func:`format_file_contents`.
+    `mode` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
-        mode |= FileMode.PYI
+        mode = evolve(mode, is_pyi=True)
 
     then = datetime.utcfromtimestamp(src.stat().st_mtime)
     with open(src, "rb") as buf:
         src_contents, encoding, newline = decode_bytes(buf.read())
     try:
-        dst_contents = format_file_contents(
-            src_contents, line_length=line_length, fast=fast, mode=mode
-        )
+        dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
     except NothingChanged:
         return False
 
@@ -520,9 +630,8 @@ def format_file_in_place(
         src_name = f"{src}\t{then} +0000"
         dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
-        if lock:
-            lock.acquire()
-        try:
+
+        with lock or nullcontext():
             f = io.TextIOWrapper(
                 sys.stdout.buffer,
                 encoding=encoding,
@@ -531,30 +640,24 @@ def format_file_in_place(
             )
             f.write(diff_contents)
             f.detach()
-        finally:
-            if lock:
-                lock.release()
+
     return True
 
 
 def format_stdin_to_stdout(
-    line_length: int,
-    fast: bool,
-    write_back: WriteBack = WriteBack.NO,
-    mode: FileMode = FileMode.AUTO_DETECT,
+    fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
     If `write_back` is YES, write reformatted code back to stdout. If it is DIFF,
-    write a diff to stdout.
-    `line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
+    write a diff to stdout. The `mode` argument is passed to
     :func:`format_file_contents`.
     """
     then = datetime.utcnow()
     src, encoding, newline = decode_bytes(sys.stdin.buffer.read())
     dst = src
     try:
-        dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
+        dst = format_file_contents(src, fast=fast, mode=mode)
         return True
 
     except NothingChanged:
@@ -575,63 +678,66 @@ def format_stdin_to_stdout(
 
 
 def format_file_contents(
-    src_contents: str,
-    *,
-    line_length: int,
-    fast: bool,
-    mode: FileMode = FileMode.AUTO_DETECT,
+    src_contents: str, *, fast: bool, mode: FileMode
 ) -> FileContent:
     """Reformat contents a file and return new contents.
 
     If `fast` is False, additionally confirm that the reformatted code is
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
-    `line_length` is passed to :func:`format_str`.
+    `mode` is passed to :func:`format_str`.
     """
     if src_contents.strip() == "":
         raise NothingChanged
 
-    dst_contents = format_str(src_contents, line_length=line_length, mode=mode)
+    dst_contents = format_str(src_contents, mode=mode)
     if src_contents == dst_contents:
         raise NothingChanged
 
     if not fast:
         assert_equivalent(src_contents, dst_contents)
-        assert_stable(src_contents, dst_contents, line_length=line_length, mode=mode)
+        assert_stable(src_contents, dst_contents, mode=mode)
     return dst_contents
 
 
-def format_str(
-    src_contents: str, line_length: int, *, mode: FileMode = FileMode.AUTO_DETECT
-) -> FileContent:
+def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
     """Reformat a string and return new contents.
 
-    `line_length` determines how many characters per line are allowed.
+    `mode` determines formatting options, such as how many characters per line are
+    allowed.
     """
-    src_node = lib2to3_parse(src_contents)
-    dst_contents = ""
+    src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
+    dst_contents = []
     future_imports = get_future_imports(src_node)
-    is_pyi = bool(mode & FileMode.PYI)
-    py36 = bool(mode & FileMode.PYTHON36) or is_python36(src_node)
-    normalize_strings = not bool(mode & FileMode.NO_STRING_NORMALIZATION)
+    if mode.target_versions:
+        versions = mode.target_versions
+    else:
+        versions = detect_target_versions(src_node)
     normalize_fmt_off(src_node)
     lines = LineGenerator(
-        remove_u_prefix=py36 or "unicode_literals" in future_imports,
-        is_pyi=is_pyi,
-        normalize_strings=normalize_strings,
-        allow_underscores=py36,
+        remove_u_prefix="unicode_literals" in future_imports
+        or supports_feature(versions, Feature.UNICODE_LITERALS),
+        is_pyi=mode.is_pyi,
+        normalize_strings=mode.string_normalization,
     )
-    elt = EmptyLineTracker(is_pyi=is_pyi)
+    elt = EmptyLineTracker(is_pyi=mode.is_pyi)
     empty_line = Line()
     after = 0
+    split_line_features = {
+        feature
+        for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
+        if supports_feature(versions, feature)
+    }
     for current_line in lines.visit(src_node):
         for _ in range(after):
-            dst_contents += str(empty_line)
+            dst_contents.append(str(empty_line))
         before, after = elt.maybe_empty_lines(current_line)
         for _ in range(before):
-            dst_contents += str(empty_line)
-        for line in split_line(current_line, line_length=line_length, py36=py36):
-            dst_contents += str(line)
-    return dst_contents
+            dst_contents.append(str(empty_line))
+        for line in split_line(
+            current_line, line_length=mode.line_length, features=split_line_features
+        ):
+            dst_contents.append(str(line))
+    return "".join(dst_contents)
 
 
 def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
@@ -651,19 +757,50 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
         return tiow.read(), encoding, newline
 
 
-GRAMMARS = [
-    pygram.python_grammar_no_print_statement_no_exec_statement,
-    pygram.python_grammar_no_print_statement,
-    pygram.python_grammar,
-]
+def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
+    if not target_versions:
+        # No target_version specified, so try all grammars.
+        return [
+            # Python 3.7+
+            pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords,
+            # Python 3.0-3.6
+            pygram.python_grammar_no_print_statement_no_exec_statement,
+            # Python 2.7 with future print_function import
+            pygram.python_grammar_no_print_statement,
+            # Python 2.7
+            pygram.python_grammar,
+        ]
+    elif all(version.is_python2() for version in target_versions):
+        # Python 2-only code, so try Python 2 grammars.
+        return [
+            # Python 2.7 with future print_function import
+            pygram.python_grammar_no_print_statement,
+            # Python 2.7
+            pygram.python_grammar,
+        ]
+    else:
+        # Python 3-compatible code, so only try Python 3 grammar.
+        grammars = []
+        # If we have to parse both, try to parse async as a keyword first
+        if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS):
+            # Python 3.7+
+            grammars.append(
+                pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords  # noqa: B950
+            )
+        if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS):
+            # Python 3.0-3.6
+            grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement)
+        # At least one of the above branches must have been taken, because every Python
+        # version has exactly one of the two 'ASYNC_*' flags
+        return grammars
 
 
-def lib2to3_parse(src_txt: str) -> Node:
+def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
     """Given a string with source, return the lib2to3 Node."""
-    grammar = pygram.python_grammar_no_print_statement
     if src_txt[-1:] != "\n":
         src_txt += "\n"
-    for grammar in GRAMMARS:
+
+    for grammar in get_grammars(set(target_versions)):
         drv = driver.Driver(grammar, pytree.convert)
         try:
             result = drv.parse_string(src_txt, True)
@@ -676,7 +813,7 @@ def lib2to3_parse(src_txt: str) -> Node:
                 faulty_line = lines[lineno - 1]
             except IndexError:
                 faulty_line = "<line number missing in source>"
-            exc = ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}")
+            exc = InvalidInput(f"Cannot parse: {lineno}:{column}: {faulty_line}")
     else:
         raise exc from None
 
@@ -756,9 +893,7 @@ class DebugVisitor(Visitor[T]):
         list(v.visit(code))
 
 
-KEYWORDS = set(keyword.kwlist)
 WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
-FLOW_CONTROL = {"return", "raise", "break", "continue"}
 STATEMENT = {
     syms.if_stmt,
     syms.while_stmt,
@@ -877,8 +1012,8 @@ class BracketTracker:
     bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
     delimiters: Dict[LeafID, Priority] = Factory(dict)
     previous: Optional[Leaf] = None
-    _for_loop_variable: int = 0
-    _lambda_arguments: int = 0
+    _for_loop_depths: List[int] = Factory(list)
+    _lambda_argument_depths: List[int] = Factory(list)
 
     def mark(self, leaf: Leaf) -> None:
         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
@@ -924,7 +1059,7 @@ class BracketTracker:
         """Return True if there is an yet unmatched open bracket on the line."""
         return bool(self.bracket_match)
 
-    def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> int:
+    def max_delimiter_priority(self, exclude: Iterable[LeafID] = ()) -> Priority:
         """Return the highest priority of a delimiter found on the line.
 
         Values are consistent with what `is_split_*_delimiter()` return.
@@ -932,7 +1067,7 @@ class BracketTracker:
         """
         return max(v for k, v in self.delimiters.items() if k not in exclude)
 
-    def delimiter_count_with_priority(self, priority: int = 0) -> int:
+    def delimiter_count_with_priority(self, priority: Priority = 0) -> int:
         """Return the number of delimiters with the given `priority`.
 
         If no `priority` is passed, defaults to max priority on the line.
@@ -951,16 +1086,21 @@ class BracketTracker:
         """
         if leaf.type == token.NAME and leaf.value == "for":
             self.depth += 1
-            self._for_loop_variable += 1
+            self._for_loop_depths.append(self.depth)
             return True
 
         return False
 
     def maybe_decrement_after_for_loop_variable(self, leaf: Leaf) -> bool:
         """See `maybe_increment_for_loop_variable` above for explanation."""
-        if self._for_loop_variable and leaf.type == token.NAME and leaf.value == "in":
+        if (
+            self._for_loop_depths
+            and self._for_loop_depths[-1] == self.depth
+            and leaf.type == token.NAME
+            and leaf.value == "in"
+        ):
             self.depth -= 1
-            self._for_loop_variable -= 1
+            self._for_loop_depths.pop()
             return True
 
         return False
@@ -973,16 +1113,20 @@ class BracketTracker:
         """
         if leaf.type == token.NAME and leaf.value == "lambda":
             self.depth += 1
-            self._lambda_arguments += 1
+            self._lambda_argument_depths.append(self.depth)
             return True
 
         return False
 
     def maybe_decrement_after_lambda_arguments(self, leaf: Leaf) -> bool:
         """See `maybe_increment_lambda_arguments` above for explanation."""
-        if self._lambda_arguments and leaf.type == token.COLON:
+        if (
+            self._lambda_argument_depths
+            and self._lambda_argument_depths[-1] == self.depth
+            and leaf.type == token.COLON
+        ):
             self.depth -= 1
-            self._lambda_arguments -= 1
+            self._lambda_argument_depths.pop()
             return True
 
         return False
@@ -998,7 +1142,7 @@ class Line:
 
     depth: int = 0
     leaves: List[Leaf] = Factory(list)
-    comments: List[Tuple[Index, Leaf]] = Factory(list)
+    comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
     bracket_tracker: BracketTracker = Factory(BracketTracker)
     inside_brackets: bool = False
     should_explode: bool = False
@@ -1129,6 +1273,32 @@ class Line:
             if leaf.type == STANDALONE_COMMENT:
                 if leaf.bracket_depth <= depth_limit:
                     return True
+        return False
+
+    def contains_inner_type_comments(self) -> bool:
+        ignored_ids = set()
+        try:
+            last_leaf = self.leaves[-1]
+            ignored_ids.add(id(last_leaf))
+            if last_leaf.type == token.COMMA or (
+                last_leaf.type == token.RPAR and not last_leaf.value
+            ):
+                # When trailing commas or optional parens are inserted by Black for
+                # consistency, comments after the previous last element are not moved
+                # (they don't have to, rendering will still be correct).  So we ignore
+                # trailing commas and invisible.
+                last_leaf = self.leaves[-2]
+                ignored_ids.add(id(last_leaf))
+        except IndexError:
+            return False
+
+        for leaf_id, comments in self.comments.items():
+            if leaf_id in ignored_ids:
+                continue
+
+            for comment in comments:
+                if is_type_comment(comment):
+                    return True
 
         return False
 
@@ -1164,7 +1334,7 @@ class Line:
             self.remove_trailing_comma()
             return True
 
-        # Otheriwsse, if the trailing one is the only one, we might mistakenly
+        # Otherwise, if the trailing one is the only one, we might mistakenly
         # change a tuple into a different type by removing the comma.
         depth = closing.bracket_depth + 1
         commas = 0
@@ -1183,7 +1353,10 @@ class Line:
             bracket_depth = leaf.bracket_depth
             if bracket_depth == depth and leaf.type == token.COMMA:
                 commas += 1
-                if leaf.parent and leaf.parent.type == syms.arglist:
+                if leaf.parent and leaf.parent.type in {
+                    syms.arglist,
+                    syms.typedargslist,
+                }:
                     commas += 1
                     break
 
@@ -1205,44 +1378,41 @@ class Line:
         if comment.type != token.COMMENT:
             return False
 
-        after = len(self.leaves) - 1
-        if after == -1:
+        if not self.leaves:
             comment.type = STANDALONE_COMMENT
             comment.prefix = ""
             return False
 
-        else:
-            self.comments.append((after, comment))
-            return True
-
-    def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
-        """Generate comments that should appear directly after `leaf`.
-
-        Provide a non-negative leaf `_index` to speed up the function.
-        """
-        if not self.comments:
-            return
-
-        if _index == -1:
-            for _index, _leaf in enumerate(self.leaves):
-                if leaf is _leaf:
-                    break
-
-            else:
-                return
+        last_leaf = self.leaves[-1]
+        if (
+            last_leaf.type == token.RPAR
+            and not last_leaf.value
+            and last_leaf.parent
+            and len(list(last_leaf.parent.leaves())) <= 3
+            and not is_type_comment(comment)
+        ):
+            # Comments on an optional parens wrapping a single leaf should belong to
+            # the wrapped node except if it's a type comment. Pinning the comment like
+            # this avoids unstable formatting caused by comment migration.
+            if len(self.leaves) < 2:
+                comment.type = STANDALONE_COMMENT
+                comment.prefix = ""
+                return False
+            last_leaf = self.leaves[-2]
+        self.comments.setdefault(id(last_leaf), []).append(comment)
+        return True
 
-        for index, comment_after in self.comments:
-            if _index == index:
-                yield comment_after
+    def comments_after(self, leaf: Leaf) -> List[Leaf]:
+        """Generate comments that should appear directly after `leaf`."""
+        return self.comments.get(id(leaf), [])
 
     def remove_trailing_comma(self) -> None:
         """Remove the trailing comma and moves the comments attached to it."""
-        comma_index = len(self.leaves) - 1
-        for i in range(len(self.comments)):
-            comment_index, comment = self.comments[i]
-            if comment_index == comma_index:
-                self.comments[i] = (comma_index - 1, comment)
-        self.leaves.pop()
+        trailing_comma = self.leaves.pop()
+        trailing_comma_comments = self.comments.pop(id(trailing_comma), [])
+        self.comments.setdefault(id(self.leaves[-1]), []).extend(
+            trailing_comma_comments
+        )
 
     def is_complex_subscript(self, leaf: Leaf) -> bool:
         """Return True iff `leaf` is part of a slice with non-trivial exprs."""
@@ -1273,7 +1443,7 @@ class Line:
         res = f"{first.prefix}{indent}{first.value}"
         for leaf in leaves:
             res += str(leaf)
-        for _, comment in self.comments:
+        for comment in itertools.chain.from_iterable(self.comments.values()):
             res += str(comment)
         return res + "\n"
 
@@ -1377,7 +1547,7 @@ class EmptyLineTracker:
                 newlines = 1
             elif current_line.is_class or self.previous_line.is_class:
                 if current_line.is_stub_class and self.previous_line.is_stub_class:
-                    # No blank line between classes with an emty body
+                    # No blank line between classes with an empty body
                     newlines = 0
                 else:
                     newlines = 1
@@ -1405,7 +1575,6 @@ class LineGenerator(Visitor[Line]):
     normalize_strings: bool = True
     current_line: Line = Factory(Line)
     remove_u_prefix: bool = False
-    allow_underscores: bool = False
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
@@ -1448,11 +1617,44 @@ class LineGenerator(Visitor[Line]):
                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                 normalize_string_quotes(node)
             if node.type == token.NUMBER:
-                normalize_numeric_literal(node, self.allow_underscores)
+                normalize_numeric_literal(node)
             if node.type not in WHITESPACE:
                 self.current_line.append(node)
         yield from super().visit_default(node)
 
+    def visit_atom(self, node: Node) -> Iterator[Line]:
+        # Always make parentheses invisible around a single node, because it should
+        # not be needed (except in the case of yield, where removing the parentheses
+        # produces a SyntaxError).
+        if (
+            len(node.children) == 3
+            and isinstance(node.children[0], Leaf)
+            and node.children[0].type == token.LPAR
+            and isinstance(node.children[2], Leaf)
+            and node.children[2].type == token.RPAR
+            and isinstance(node.children[1], Leaf)
+            and not (
+                node.children[1].type == token.NAME
+                and node.children[1].value == "yield"
+            )
+        ):
+            node.children[0].value = ""
+            node.children[2].value = ""
+        yield from super().visit_default(node)
+
+    def visit_factor(self, node: Node) -> Iterator[Line]:
+        """Force parentheses between a unary op and a binary power:
+
+        -2 ** 8 -> -(2 ** 8)
+        """
+        child = node.children[1]
+        if child.type == syms.power and len(child.children) == 3:
+            lpar = Leaf(token.LPAR, "(")
+            rpar = Leaf(token.RPAR, ")")
+            index = child.remove() or 0
+            node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+        yield from self.visit_default(node)
+
     def visit_INDENT(self, node: Node) -> Iterator[Line]:
         """Increase indentation level, maybe yield a line."""
         # In blib2to3 INDENT never holds comments.
@@ -1572,6 +1774,7 @@ class LineGenerator(Visitor[Line]):
         self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS)
         self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"})
         self.visit_import_from = partial(v, keywords=Ø, parens={"import"})
+        self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"})
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
@@ -1584,7 +1787,7 @@ BRACKETS = OPENING_BRACKETS | CLOSING_BRACKETS
 ALWAYS_NO_SPACE = CLOSING_BRACKETS | {token.COMMA, STANDALONE_COMMENT}
 
 
-def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa C901
+def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
     """Return whitespace prefix if needed for the given `leaf`.
 
     `complex_subscript` signals whether the given leaf is part of a subscription
@@ -1865,7 +2068,7 @@ def container_of(leaf: Leaf) -> LN:
     return container
 
 
-def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
+def is_split_after_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
     """Return the priority of the `leaf` delimiter, given a line break after it.
 
     The delimiter priorities returned here are from those delimiters that would
@@ -1879,8 +2082,8 @@ def is_split_after_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
     return 0
 
 
-def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
-    """Return the priority of the `leaf` delimiter, given a line before after it.
+def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Priority:
+    """Return the priority of the `leaf` delimiter, given a line break before it.
 
     The delimiter priorities returned here are from those delimiters that would
     cause a line break before themselves.
@@ -1917,15 +2120,20 @@ def is_split_before_delimiter(leaf: Leaf, previous: Leaf = None) -> int:
     ):
         return STRING_PRIORITY
 
-    if leaf.type != token.NAME:
+    if leaf.type not in {token.NAME, token.ASYNC}:
         return 0
 
     if (
         leaf.value == "for"
         and leaf.parent
         and leaf.parent.type in {syms.comp_for, syms.old_comp_for}
+        or leaf.type == token.ASYNC
     ):
-        return COMPREHENSION_PRIORITY
+        if (
+            not isinstance(leaf.prev_sibling, Leaf)
+            or leaf.prev_sibling.value != "async"
+        ):
+            return COMPREHENSION_PRIORITY
 
     if (
         leaf.value == "if"
@@ -1999,6 +2207,16 @@ def generate_comments(leaf: LN) -> Iterator[Leaf]:
 
 @dataclass
 class ProtoComment:
+    """Describes a piece of syntax that is a comment.
+
+    It's not a :class:`blib2to3.pytree.Leaf` so that:
+
+    * it can be cached (`Leaf` objects should not be reused more than once as
+      they store their lineno, column, prefix, and parent information);
+    * `newlines` and `consumed` fields are kept separate from the `value`. This
+      simplifies handling of special marker comments like ``# fmt: off/on``.
+    """
+
     type: int  # token.COMMENT or STANDALONE_COMMENT
     value: str  # content of the comment
     newlines: int  # how many newlines before the comment
@@ -2007,21 +2225,28 @@ class ProtoComment:
 
 @lru_cache(maxsize=4096)
 def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
+    """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`."""
     result: List[ProtoComment] = []
     if not prefix or "#" not in prefix:
         return result
 
     consumed = 0
     nlines = 0
+    ignored_lines = 0
     for index, line in enumerate(prefix.split("\n")):
         consumed += len(line) + 1  # adding the length of the split '\n'
         line = line.lstrip()
         if not line:
             nlines += 1
         if not line.startswith("#"):
+            # Escaped newlines outside of a comment are not really newlines at
+            # all. We treat a single-line comment following an escaped newline
+            # as a simple trailing comment.
+            if line.endswith("\\"):
+                ignored_lines += 1
             continue
 
-        if index == 0 and not is_endmarker:
+        if index == ignored_lines and not is_endmarker:
             comment_type = token.COMMENT  # simple trailing comment
         else:
             comment_type = STANDALONE_COMMENT
@@ -2038,8 +2263,8 @@ def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
 def make_comment(content: str) -> str:
     """Return a consistently formatted comment from the given `content` string.
 
-    All comments (except for "##", "#!", "#:") should have a single space between
-    the hash sign and the content.
+    All comments (except for "##", "#!", "#:", '#'", "#%%") should have a single
+    space between the hash sign and the content.
 
     If `content` didn't start with a hash sign, one is provided.
     """
@@ -2049,13 +2274,16 @@ def make_comment(content: str) -> str:
 
     if content[0] == "#":
         content = content[1:]
-    if content and content[0] not in " !:#":
+    if content and content[0] not in " !:#'%":
         content = " " + content
     return "#" + content
 
 
 def split_line(
-    line: Line, line_length: int, inner: bool = False, py36: bool = False
+    line: Line,
+    line_length: int,
+    inner: bool = False,
+    features: Collection[Feature] = (),
 ) -> Iterator[Line]:
     """Split a `line` into potentially many lines.
 
@@ -2064,16 +2292,18 @@ def split_line(
     current `line`, possibly transitively. This means we can fallback to splitting
     by delimiters if the LHS/RHS don't yield any results.
 
-    If `py36` is True, splitting may generate syntax that is only compatible
-    with Python 3.6 and later.
+    `features` are syntactical features that may be used in the output.
     """
     if line.is_comment:
         yield line
         return
 
     line_str = str(line).strip("\n")
-    if not line.should_explode and is_line_short_enough(
-        line, line_length=line_length, line_str=line_str
+
+    if (
+        not line.contains_inner_type_comments()
+        and not line.should_explode
+        and is_line_short_enough(line, line_length=line_length, line_str=line_str)
     ):
         yield line
         return
@@ -2083,9 +2313,9 @@ def split_line(
         split_funcs = [left_hand_split]
     else:
 
-        def rhs(line: Line, py36: bool = False) -> Iterator[Line]:
+        def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
             for omit in generate_trailers_to_omit(line, line_length):
-                lines = list(right_hand_split(line, line_length, py36, omit=omit))
+                lines = list(right_hand_split(line, line_length, features, omit=omit))
                 if is_line_short_enough(lines[0], line_length=line_length):
                     yield from lines
                     return
@@ -2093,7 +2323,7 @@ def split_line(
             # All splits failed, best effort split with no omits.
             # This mostly happens to multiline strings that are by definition
             # reported as not fitting a single line.
-            yield from right_hand_split(line, py36)
+            yield from right_hand_split(line, line_length, features=features)
 
         if line.inside_brackets:
             split_funcs = [delimiter_split, standalone_comment_split, rhs]
@@ -2105,14 +2335,16 @@ def split_line(
         # split altogether.
         result: List[Line] = []
         try:
-            for l in split_func(line, py36):
+            for l in split_func(line, features):
                 if str(l).strip("\n") == line_str:
                     raise CannotSplit("Split function returned an unchanged result")
 
                 result.extend(
-                    split_line(l, line_length=line_length, inner=True, py36=py36)
+                    split_line(
+                        l, line_length=line_length, inner=True, features=features
+                    )
                 )
-        except CannotSplit as cs:
+        except CannotSplit:
             continue
 
         else:
@@ -2123,16 +2355,13 @@ def split_line(
         yield line
 
 
-def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
     """Split line into many lines, starting with the first matching bracket pair.
 
     Note: this usually looks weird, only use this for function definitions.
     Prefer RHS otherwise.  This is why this function is not symmetrical with
     :func:`right_hand_split` which also handles optional parentheses.
     """
-    head = Line(depth=line.depth)
-    body = Line(depth=line.depth + 1, inside_brackets=True)
-    tail = Line(depth=line.depth)
     tail_leaves: List[Leaf] = []
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
@@ -2150,15 +2379,12 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
             if leaf.type in OPENING_BRACKETS:
                 matching_bracket = leaf
                 current_leaves = body_leaves
-    # Since body is a new indent level, remove spurious leading whitespace.
-    if body_leaves:
-        normalize_prefix(body_leaves[0], inside_brackets=True)
-    # Build the new lines.
-    for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
-        for leaf in leaves:
-            result.append(leaf, preformatted=True)
-            for comment_after in line.comments_after(leaf):
-                result.append(comment_after, preformatted=True)
+    if not matching_bracket:
+        raise CannotSplit("No brackets found")
+
+    head = bracket_split_build_line(head_leaves, line, matching_bracket)
+    body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True)
+    tail = bracket_split_build_line(tail_leaves, line, matching_bracket)
     bracket_split_succeeded_or_raise(head, body, tail)
     for result in (head, body, tail):
         if result:
@@ -2166,7 +2392,10 @@ def left_hand_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
 
 def right_hand_split(
-    line: Line, line_length: int, py36: bool = False, omit: Collection[LeafID] = ()
+    line: Line,
+    line_length: int,
+    features: Collection[Feature] = (),
+    omit: Collection[LeafID] = (),
 ) -> Iterator[Line]:
     """Split line into many lines, starting with the last matching bracket pair.
 
@@ -2176,9 +2405,6 @@ def right_hand_split(
 
     Note: running this function modifies `bracket_depth` on the leaves of `line`.
     """
-    head = Line(depth=line.depth)
-    body = Line(depth=line.depth + 1, inside_brackets=True)
-    tail = Line(depth=line.depth)
     tail_leaves: List[Leaf] = []
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
@@ -2195,25 +2421,18 @@ def right_hand_split(
                 opening_bracket = leaf.opening_bracket
                 closing_bracket = leaf
                 current_leaves = body_leaves
-    tail_leaves.reverse()
-    body_leaves.reverse()
-    head_leaves.reverse()
-    # Since body is a new indent level, remove spurious leading whitespace.
-    if body_leaves:
-        normalize_prefix(body_leaves[0], inside_brackets=True)
-    if not head_leaves:
-        # No `head` means the split failed. Either `tail` has all content or
+    if not (opening_bracket and closing_bracket and head_leaves):
+        # If there is no opening or closing_bracket that means the split failed and
+        # all content is in the tail.  Otherwise, if `head_leaves` are empty, it means
         # the matching `opening_bracket` wasn't available on `line` anymore.
         raise CannotSplit("No brackets found")
 
-    # Build the new lines.
-    for result, leaves in (head, head_leaves), (body, body_leaves), (tail, tail_leaves):
-        for leaf in leaves:
-            result.append(leaf, preformatted=True)
-            for comment_after in line.comments_after(leaf):
-                result.append(comment_after, preformatted=True)
-    assert opening_bracket and closing_bracket
-    body.should_explode = should_explode(body, opening_bracket)
+    tail_leaves.reverse()
+    body_leaves.reverse()
+    head_leaves.reverse()
+    head = bracket_split_build_line(head_leaves, line, opening_bracket)
+    body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True)
+    tail = bracket_split_build_line(tail_leaves, line, opening_bracket)
     bracket_split_succeeded_or_raise(head, body, tail)
     if (
         # the body shouldn't be exploded
@@ -2234,7 +2453,7 @@ def right_hand_split(
     ):
         omit = {id(closing_bracket), *omit}
         try:
-            yield from right_hand_split(line, line_length, py36=py36, omit=omit)
+            yield from right_hand_split(line, line_length, features=features, omit=omit)
             return
 
         except CannotSplit:
@@ -2287,6 +2506,46 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None
             )
 
 
+def bracket_split_build_line(
+    leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False
+) -> Line:
+    """Return a new line with given `leaves` and respective comments from `original`.
+
+    If `is_body` is True, the result line is one-indented inside brackets and as such
+    has its first leaf's prefix normalized and a trailing comma added when expected.
+    """
+    result = Line(depth=original.depth)
+    if is_body:
+        result.inside_brackets = True
+        result.depth += 1
+        if leaves:
+            # Since body is a new indent level, remove spurious leading whitespace.
+            normalize_prefix(leaves[0], inside_brackets=True)
+            # Ensure a trailing comma for imports and standalone function arguments, but
+            # be careful not to add one after any comments.
+            no_commas = original.is_def and not any(
+                l.type == token.COMMA for l in leaves
+            )
+
+            if original.is_import or no_commas:
+                for i in range(len(leaves) - 1, -1, -1):
+                    if leaves[i].type == STANDALONE_COMMENT:
+                        continue
+                    elif leaves[i].type == token.COMMA:
+                        break
+                    else:
+                        leaves.insert(i + 1, Leaf(token.COMMA, ","))
+                        break
+    # Populate the line
+    for leaf in leaves:
+        result.append(leaf, preformatted=True)
+        for comment_after in original.comments_after(leaf):
+            result.append(comment_after, preformatted=True)
+    if is_body:
+        result.should_explode = should_explode(result, opening_bracket)
+    return result
+
+
 def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
     """Normalize prefix of the first leaf in every line returned by `split_func`.
 
@@ -2294,8 +2553,8 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
     """
 
     @wraps(split_func)
-    def split_wrapper(line: Line, py36: bool = False) -> Iterator[Line]:
-        for l in split_func(line, py36):
+    def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
+        for l in split_func(line, features):
             normalize_prefix(l.leaves[0], inside_brackets=True)
             yield l
 
@@ -2303,11 +2562,11 @@ def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc:
 
 
 @dont_increase_indentation
-def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]:
     """Split according to delimiters of the highest priority.
 
-    If `py36` is True, the split will add trailing commas also in function
-    signatures that contain `*` and `**`.
+    If the appropriate Features are given, the split will add trailing commas
+    also in function signatures and calls that contain `*` and `**`.
     """
     try:
         last_leaf = line.leaves[-1]
@@ -2333,23 +2592,29 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
         nonlocal current_line
         try:
             current_line.append_safe(leaf, preformatted=True)
-        except ValueError as ve:
+        except ValueError:
             yield current_line
 
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
             current_line.append(leaf)
 
-    for index, leaf in enumerate(line.leaves):
+    for leaf in line.leaves:
         yield from append_to_line(leaf)
 
-        for comment_after in line.comments_after(leaf, index):
+        for comment_after in line.comments_after(leaf):
             yield from append_to_line(comment_after)
 
         lowest_depth = min(lowest_depth, leaf.bracket_depth)
-        if leaf.bracket_depth == lowest_depth and is_vararg(
-            leaf, within=VARARGS_PARENTS
-        ):
-            trailing_comma_safe = trailing_comma_safe and py36
+        if leaf.bracket_depth == lowest_depth:
+            if is_vararg(leaf, within={syms.typedargslist}):
+                trailing_comma_safe = (
+                    trailing_comma_safe and Feature.TRAILING_COMMA_IN_DEF in features
+                )
+            elif is_vararg(leaf, within={syms.arglist, syms.argument}):
+                trailing_comma_safe = (
+                    trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features
+                )
+
         leaf_priority = bt.delimiters.get(id(leaf))
         if leaf_priority == delimiter_priority:
             yield current_line
@@ -2367,7 +2632,9 @@ def delimiter_split(line: Line, py36: bool = False) -> Iterator[Line]:
 
 
 @dont_increase_indentation
-def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
+def standalone_comment_split(
+    line: Line, features: Collection[Feature] = ()
+) -> Iterator[Line]:
     """Split standalone comments from the rest of the line."""
     if not line.contains_standalone_comments(0):
         raise CannotSplit("Line does not have any standalone comments")
@@ -2379,16 +2646,16 @@ def standalone_comment_split(line: Line, py36: bool = False) -> Iterator[Line]:
         nonlocal current_line
         try:
             current_line.append_safe(leaf, preformatted=True)
-        except ValueError as ve:
+        except ValueError:
             yield current_line
 
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
             current_line.append(leaf)
 
-    for index, leaf in enumerate(line.leaves):
+    for leaf in line.leaves:
         yield from append_to_line(leaf)
 
-        for comment_after in line.comments_after(leaf, index):
+        for comment_after in line.comments_after(leaf):
             yield from append_to_line(comment_after)
 
     if current_line:
@@ -2409,6 +2676,14 @@ def is_import(leaf: Leaf) -> bool:
     )
 
 
+def is_type_comment(leaf: Leaf) -> bool:
+    """Return True if the given leaf is a special comment.
+    Only returns true for type comments for now."""
+    t = leaf.type
+    v = leaf.value
+    return t in {token.COMMENT, t == STANDALONE_COMMENT} and v.startswith("# type:")
+
+
 def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None:
     """Leave existing extra newlines if not `inside_brackets`. Remove everything
     else.
@@ -2491,7 +2766,15 @@ def normalize_string_quotes(leaf: Leaf) -> None:
         new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
         new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
     if "f" in prefix.casefold():
-        matches = re.findall(r"[^{]\{(.*?)\}[^}]", new_body)
+        matches = re.findall(
+            r"""
+            (?:[^{]|^)\{  # start of the string or a non-{ followed by a single {
+                ([^{].*?)  # contents of the brackets except if begins with {{
+            \}(?:[^}]|$)  # A } followed by end of the string or a non-}
+            """,
+            new_body,
+            re.VERBOSE,
+        )
         for m in matches:
             if "\\" in str(m):
                 # Do not introduce backslashes in interpolated expressions
@@ -2510,65 +2793,55 @@ def normalize_string_quotes(leaf: Leaf) -> None:
     leaf.value = f"{prefix}{new_quote}{new_body}{new_quote}"
 
 
-def normalize_numeric_literal(leaf: Leaf, allow_underscores: bool) -> None:
-    """Normalizes numeric (float, int, and complex) literals."""
-    # We want all letters (e in exponents, j in complex literals, a-f
-    # in hex literals) to be lowercase.
+def normalize_numeric_literal(leaf: Leaf) -> None:
+    """Normalizes numeric (float, int, and complex) literals.
+
+    All letters used in the representation are normalized to lowercase (except
+    in Python 2 long literals).
+    """
     text = leaf.value.lower()
-    if text.startswith(("0o", "0x", "0b")):
-        # Leave octal, hex, and binary literals alone for now.
+    if text.startswith(("0o", "0b")):
+        # Leave octal and binary literals alone.
         pass
+    elif text.startswith("0x"):
+        # Change hex literals to upper case.
+        before, after = text[:2], text[2:]
+        text = f"{before}{after.upper()}"
     elif "e" in text:
         before, after = text.split("e")
+        sign = ""
         if after.startswith("-"):
             after = after[1:]
             sign = "-"
         elif after.startswith("+"):
             after = after[1:]
-            sign = ""
-        else:
-            sign = ""
-        before = format_float_or_int_string(before, allow_underscores)
-        after = format_int_string(after, allow_underscores)
+        before = format_float_or_int_string(before)
         text = f"{before}e{sign}{after}"
-    # Complex numbers and Python 2 longs
-    elif "j" in text or "l" in text:
+    elif text.endswith(("j", "l")):
         number = text[:-1]
         suffix = text[-1]
-        text = f"{format_float_or_int_string(number, allow_underscores)}{suffix}"
+        # Capitalize in "2L" because "l" looks too similar to "1".
+        if suffix == "l":
+            suffix = "L"
+        text = f"{format_float_or_int_string(number)}{suffix}"
     else:
-        text = format_float_or_int_string(text, allow_underscores)
+        text = format_float_or_int_string(text)
     leaf.value = text
 
 
-def format_float_or_int_string(text: str, allow_underscores: bool) -> str:
+def format_float_or_int_string(text: str) -> str:
     """Formats a float string like "1.0"."""
     if "." not in text:
-        return format_int_string(text, allow_underscores)
-    before, after = text.split(".")
-    before = format_int_string(before, allow_underscores) if before else "0"
-    after = format_int_string(after, allow_underscores) if after else "0"
-    return f"{before}.{after}"
-
-
-def format_int_string(text: str, allow_underscores: bool) -> str:
-    """Normalizes underscores in a string to e.g. 1_000_000.
-
-    Input must be a string consisting only of digits and underscores.
-    """
-    if not allow_underscores:
-        return text
-    text = text.replace("_", "")
-    if len(text) <= 6:
-        # No underscores for numbers <= 6 digits long.
         return text
-    return format(int(text), "3_")
+
+    before, after = text.split(".")
+    return f"{before or 0}.{after or 0}"
 
 
 def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
     """Make existing optional parentheses invisible or create new ones.
 
-    `parens_after` is a set of string leaf values immeditely after which parens
+    `parens_after` is a set of string leaf values immediately after which parens
     should be put.
 
     Standardizes on visible parentheses for single-element tuples, and keeps
@@ -2581,9 +2854,21 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
 
     check_lpar = False
     for index, child in enumerate(list(node.children)):
+        # Add parentheses around long tuple unpacking in assignments.
+        if (
+            index == 0
+            and isinstance(child, Node)
+            and child.type == syms.testlist_star_expr
+        ):
+            check_lpar = True
+
         if check_lpar:
             if child.type == syms.atom:
-                maybe_make_parens_invisible_in_atom(child)
+                if maybe_make_parens_invisible_in_atom(child, parent=node):
+                    lpar = Leaf(token.LPAR, "")
+                    rpar = Leaf(token.RPAR, "")
+                    index = child.remove() or 0
+                    node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
             elif is_one_tuple(child):
                 # wrap child in visible parentheses
                 lpar = Leaf(token.LPAR, "(")
@@ -2608,7 +2893,11 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
                 lpar = Leaf(token.LPAR, "")
                 rpar = Leaf(token.RPAR, "")
                 index = child.remove() or 0
-                node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+                prefix = child.prefix
+                child.prefix = ""
+                new_child = Node(syms.atom, [lpar, child, rpar])
+                new_child.prefix = prefix
+                node.insert_child(index, new_child)
 
         check_lpar = isinstance(child, Leaf) and child.value in parens_after
 
@@ -2690,13 +2979,17 @@ def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
         container = container.next_sibling
 
 
-def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
-    """If it's safe, make the parens in the atom `node` invisible, recursively."""
+def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
+    """If it's safe, make the parens in the atom `node` invisible, recursively.
+
+    Returns whether the node should itself be wrapped in invisible parentheses.
+
+    """
     if (
         node.type != syms.atom
         or is_empty_tuple(node)
         or is_one_tuple(node)
-        or is_yield(node)
+        or (is_yield(node) and parent.type != syms.expr_stmt)
         or max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY
     ):
         return False
@@ -2708,10 +3001,10 @@ def maybe_make_parens_invisible_in_atom(node: LN) -> bool:
         first.value = ""  # type: ignore
         last.value = ""  # type: ignore
         if len(node.children) > 1:
-            maybe_make_parens_invisible_in_atom(node.children[1])
-        return True
+            maybe_make_parens_invisible_in_atom(node.children[1], parent=parent)
+        return False
 
-    return False
+    return True
 
 
 def is_empty_tuple(node: LN) -> bool:
@@ -2826,7 +3119,7 @@ def is_stub_body(node: LN) -> bool:
     )
 
 
-def max_delimiter_priority_in_atom(node: LN) -> int:
+def max_delimiter_priority_in_atom(node: LN) -> Priority:
     """Return maximum delimiter priority inside `node`.
 
     This is specific to atoms with contents contained in a pair of parentheses.
@@ -2858,7 +3151,7 @@ def ensure_visible(leaf: Leaf) -> None:
     """Make sure parentheses are visible.
 
     They could be invisible as part of some statements (see
-    :func:`normalize_invible_parens` and :func:`visit_import_from`).
+    :func:`normalize_invisible_parens` and :func:`visit_import_from`).
     """
     if leaf.type == token.LPAR:
         leaf.value = "("
@@ -2868,6 +3161,7 @@ def ensure_visible(leaf: Leaf) -> None:
 
 def should_explode(line: Line, opening_bracket: Leaf) -> bool:
     """Should `line` immediately be split with `delimiter_split()` after RHS?"""
+
     if not (
         opening_bracket.parent
         and opening_bracket.parent.type in {syms.atom, syms.import_from}
@@ -2885,34 +3179,53 @@ def should_explode(line: Line, opening_bracket: Leaf) -> bool:
     return max_priority == COMMA_PRIORITY
 
 
-def is_python36(node: Node) -> bool:
-    """Return True if the current file is using Python 3.6+ features.
+def get_features_used(node: Node) -> Set[Feature]:
+    """Return a set of (relatively) new Python features used in this file.
 
     Currently looking for:
-    - f-strings; and
+    - f-strings;
+    - underscores in numeric literals; and
     - trailing commas after * or ** in function signatures and calls.
     """
+    features: Set[Feature] = set()
     for n in node.pre_order():
         if n.type == token.STRING:
             value_head = n.value[:2]  # type: ignore
             if value_head in {'f"', 'F"', "f'", "F'", "rf", "fr", "RF", "FR"}:
-                return True
+                features.add(Feature.F_STRINGS)
+
+        elif n.type == token.NUMBER:
+            if "_" in n.value:  # type: ignore
+                features.add(Feature.NUMERIC_UNDERSCORES)
 
         elif (
             n.type in {syms.typedargslist, syms.arglist}
             and n.children
             and n.children[-1].type == token.COMMA
         ):
+            if n.type == syms.typedargslist:
+                feature = Feature.TRAILING_COMMA_IN_DEF
+            else:
+                feature = Feature.TRAILING_COMMA_IN_CALL
+
             for ch in n.children:
                 if ch.type in STARS:
-                    return True
+                    features.add(feature)
 
                 if ch.type == syms.argument:
                     for argch in ch.children:
                         if argch.type in STARS:
-                            return True
+                            features.add(feature)
 
-    return False
+    return features
+
+
+def detect_target_versions(node: Node) -> Set[TargetVersion]:
+    """Detect the version to target based on the nodes used."""
+    features = get_features_used(node)
+    return {
+        version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
+    }
 
 
 def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]:
@@ -2931,7 +3244,6 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
     length = 4 * line.depth
     opening_bracket = None
     closing_bracket = None
-    optional_brackets: Set[LeafID] = set()
     inner_brackets: Set[LeafID] = set()
     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
         length += leaf_length
@@ -2942,17 +3254,12 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
         if leaf.type == STANDALONE_COMMENT or has_inline_comment:
             break
 
-        optional_brackets.discard(id(leaf))
         if opening_bracket:
             if leaf is opening_bracket:
                 opening_bracket = None
             elif leaf.type in CLOSING_BRACKETS:
                 inner_brackets.add(id(leaf))
         elif leaf.type in CLOSING_BRACKETS:
-            if not leaf.value:
-                optional_brackets.add(id(opening_bracket))
-                continue
-
             if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS:
                 # Empty brackets would fail a split so treat them as "inner"
                 # brackets (e.g. only add them to the `omit` set if another
@@ -2960,13 +3267,15 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
                 inner_brackets.add(id(leaf))
                 continue
 
-            opening_bracket = leaf.opening_bracket
             if closing_bracket:
                 omit.add(id(closing_bracket))
                 omit.update(inner_brackets)
                 inner_brackets.clear()
                 yield omit
-            closing_bracket = leaf
+
+            if leaf.value:
+                opening_bracket = leaf.opening_bracket
+                closing_bracket = leaf
 
 
 def get_future_imports(node: Node) -> Set[str]:
@@ -2986,7 +3295,7 @@ def get_future_imports(node: Node) -> Set[str]:
             elif child.type == syms.import_as_names:
                 yield from get_imports_from_children(child.children)
             else:
-                assert False, "Invalid syntax parsing imports"
+                raise AssertionError("Invalid syntax parsing imports")
 
     for child in node.children:
         if child.type != syms.simple_stmt:
@@ -3131,7 +3440,7 @@ class Report:
         - otherwise return 0.
         """
         # According to http://tldp.org/LDP/abs/html/exitcodes.html starting with
-        # 126 we have special returncodes reserved by the shell.
+        # 126 we have special return codes reserved by the shell.
         if self.failure_count:
             return 123
 
@@ -3170,17 +3479,32 @@ class Report:
         return ", ".join(report) + "."
 
 
+def parse_ast(src: str) -> Union[ast3.AST, ast27.AST]:
+    for feature_version in (7, 6):
+        try:
+            return ast3.parse(src, feature_version=feature_version)
+        except SyntaxError:
+            continue
+
+    return ast27.parse(src)
+
+
 def assert_equivalent(src: str, dst: str) -> None:
     """Raise AssertionError if `src` and `dst` aren't equivalent."""
 
-    import ast
-    import traceback
-
-    def _v(node: ast.AST, depth: int = 0) -> Iterator[str]:
+    def _v(node: Union[ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]:
         """Simple visitor generating strings to compare ASTs by content."""
         yield f"{'  ' * depth}{node.__class__.__name__}("
 
         for field in sorted(node._fields):
+            # TypeIgnore has only one field 'lineno' which breaks this comparison
+            if isinstance(node, (ast3.TypeIgnore, ast27.TypeIgnore)):
+                break
+
+            # Ignore str kind which is case sensitive / and ignores unicode_literals
+            if isinstance(node, (ast3.Str, ast27.Str, ast3.Bytes)) and field == "kind":
+                continue
+
             try:
                 value = getattr(node, field)
             except AttributeError:
@@ -3190,10 +3514,19 @@ def assert_equivalent(src: str, dst: str) -> None:
 
             if isinstance(value, list):
                 for item in value:
-                    if isinstance(item, ast.AST):
+                    # Ignore nested tuples within del statements, because we may insert
+                    # parentheses and they change the AST.
+                    if (
+                        field == "targets"
+                        and isinstance(node, (ast3.Delete, ast27.Delete))
+                        and isinstance(item, (ast3.Tuple, ast27.Tuple))
+                    ):
+                        for item in item.elts:
+                            yield from _v(item, depth + 2)
+                    elif isinstance(item, (ast3.AST, ast27.AST)):
                         yield from _v(item, depth + 2)
 
-            elif isinstance(value, ast.AST):
+            elif isinstance(value, (ast3.AST, ast27.AST)):
                 yield from _v(value, depth + 2)
 
             else:
@@ -3202,22 +3535,20 @@ def assert_equivalent(src: str, dst: str) -> None:
         yield f"{'  ' * depth})  # /{node.__class__.__name__}"
 
     try:
-        src_ast = ast.parse(src)
+        src_ast = parse_ast(src)
     except Exception as exc:
-        major, minor = sys.version_info[:2]
         raise AssertionError(
-            f"cannot use --safe with this file; failed to parse source file "
-            f"with Python {major}.{minor}'s builtin AST. Re-run with --fast "
-            f"or stop using deprecated Python 2 syntax. AST error message: {exc}"
+            f"cannot use --safe with this file; failed to parse source file.  "
+            f"AST error message: {exc}"
         )
 
     try:
-        dst_ast = ast.parse(dst)
+        dst_ast = parse_ast(dst)
     except Exception as exc:
         log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
         raise AssertionError(
             f"INTERNAL ERROR: Black produced invalid code: {exc}. "
-            f"Please report a bug on https://github.com/ambv/black/issues.  "
+            f"Please report a bug on https://github.com/psf/black/issues.  "
             f"This invalid output might be helpful: {log}"
         ) from None
 
@@ -3228,16 +3559,14 @@ def assert_equivalent(src: str, dst: str) -> None:
         raise AssertionError(
             f"INTERNAL ERROR: Black produced code that is not equivalent to "
             f"the source.  "
-            f"Please report a bug on https://github.com/ambv/black/issues.  "
+            f"Please report a bug on https://github.com/psf/black/issues.  "
             f"This diff might be helpful: {log}"
         ) from None
 
 
-def assert_stable(
-    src: str, dst: str, line_length: int, mode: FileMode = FileMode.AUTO_DETECT
-) -> None:
+def assert_stable(src: str, dst: str, mode: FileMode) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
-    newdst = format_str(dst, line_length=line_length, mode=mode)
+    newdst = format_str(dst, mode=mode)
     if dst != newdst:
         log = dump_to_file(
             diff(src, dst, "source", "first pass"),
@@ -3246,15 +3575,13 @@ def assert_stable(
         raise AssertionError(
             f"INTERNAL ERROR: Black produced different code on the second pass "
             f"of the formatter.  "
-            f"Please report a bug on https://github.com/ambv/black/issues.  "
+            f"Please report a bug on https://github.com/psf/black/issues.  "
             f"This diff might be helpful: {log}"
         ) from None
 
 
 def dump_to_file(*output: str) -> str:
     """Dump `output` to a temporary file. Return path to the file."""
-    import tempfile
-
     with tempfile.NamedTemporaryFile(
         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
     ) as f:
@@ -3265,6 +3592,13 @@ def dump_to_file(*output: str) -> str:
     return f.name
 
 
+@contextmanager
+def nullcontext() -> Iterator[None]:
+    """Return context manager that does nothing.
+    Similar to `nullcontext` from python 3.7"""
+    yield
+
+
 def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     """Return a unified diff string between strings `a` and `b`."""
     import difflib
@@ -3283,11 +3617,15 @@ def cancel(tasks: Iterable[asyncio.Task]) -> None:
         task.cancel()
 
 
-def shutdown(loop: BaseEventLoop) -> None:
+def shutdown(loop: asyncio.AbstractEventLoop) -> None:
     """Cancel all pending tasks on `loop`, wait for them, and close the loop."""
     try:
+        if sys.version_info[:2] >= (3, 7):
+            all_tasks = asyncio.all_tasks
+        else:
+            all_tasks = asyncio.Task.all_tasks
         # This part is borrowed from asyncio/runners.py in Python 3.7b2.
-        to_cancel = [task for task in asyncio.Task.all_tasks(loop) if not task.done()]
+        to_cancel = [task for task in all_tasks(loop) if not task.done()]
         if not to_cancel:
             return
 
@@ -3348,8 +3686,7 @@ def enumerate_with_length(
         if "\n" in leaf.value:
             return  # Multiline strings, we can't continue.
 
-        comment: Optional[Leaf]
-        for comment in line.comments_after(leaf, index):
+        for comment in line.comments_after(leaf):
             length += len(comment.value)
 
         yield index, leaf, length
@@ -3494,16 +3831,16 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     return False
 
 
-def get_cache_file(line_length: int, mode: FileMode) -> Path:
-    return CACHE_DIR / f"cache.{line_length}.{mode.value}.pickle"
+def get_cache_file(mode: FileMode) -> Path:
+    return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
 
 
-def read_cache(line_length: int, mode: FileMode) -> Cache:
+def read_cache(mode: FileMode) -> Cache:
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """
-    cache_file = get_cache_file(line_length, mode)
+    cache_file = get_cache_file(mode)
     if not cache_file.exists():
         return {}
 
@@ -3538,17 +3875,15 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set
     return todo, done
 
 
-def write_cache(
-    cache: Cache, sources: Iterable[Path], line_length: int, mode: FileMode
-) -> None:
+def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
     """Update the cache file."""
-    cache_file = get_cache_file(line_length, mode)
+    cache_file = get_cache_file(mode)
     try:
-        if not CACHE_DIR.exists():
-            CACHE_DIR.mkdir(parents=True)
+        CACHE_DIR.mkdir(parents=True, exist_ok=True)
         new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
-        with cache_file.open("wb") as fobj:
-            pickle.dump(new_cache, fobj, protocol=pickle.HIGHEST_PROTOCOL)
+        with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
+            pickle.dump(new_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
+        os.replace(f.name, cache_file)
     except OSError:
         pass
 
@@ -3575,6 +3910,11 @@ def patch_click() -> None:
             module._verify_python3_env = lambda: None
 
 
-if __name__ == "__main__":
+def patched_main() -> None:
+    freeze_support()
     patch_click()
     main()
+
+
+if __name__ == "__main__":
+    patched_main()