git.madduck.net Git - etc/vim.git/blobdiff - black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Update heredoc marker case to conform with vim patch 8.1.1723 (#1348)
[etc/vim.git] / black.py
index d837c89609f3e7237786b7da1657527382612c52..d9348a37b42d4c45a4bf4fd126f23c8c6626d848 100644 (file)
--- a/black.py
+++ b/black.py
@@ -37,9 +37,11 @@ from typing import (
     Union,
     cast,
 )
     Union,
     cast,
 )
+from typing_extensions import Final
+from mypy_extensions import mypyc_attr
 
 from appdirs import user_cache_dir
 
 from appdirs import user_cache_dir
-from attr import dataclass, evolve, Factory
+from dataclasses import dataclass, field, replace
 import click
 import toml
 from typed_ast import ast3, ast27
 import click
 import toml
 from typed_ast import ast3, ast27
@@ -184,8 +186,8 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
 
 
 @dataclass
 
 
 @dataclass
-class FileMode:
-    target_versions: Set[TargetVersion] = Factory(set)
+class Mode:
+    target_versions: Set[TargetVersion] = field(default_factory=set)
     line_length: int = DEFAULT_LINE_LENGTH
     string_normalization: bool = True
     is_pyi: bool = False
     line_length: int = DEFAULT_LINE_LENGTH
     string_normalization: bool = True
     is_pyi: bool = False
@@ -207,10 +209,31 @@ class FileMode:
         return ".".join(parts)
 
 
         return ".".join(parts)
 
 
+# Legacy name, left for integrations.
+FileMode = Mode
+
+
 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
 
 
 def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool:
     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
 
 
+def find_pyproject_toml(path_search_start: str) -> Optional[str]:
+    """Find the absolute filepath to a pyproject.toml if it exists"""
+    path_project_root = find_project_root(path_search_start)
+    path_pyproject_toml = path_project_root / "pyproject.toml"
+    return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
+
+
+def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
+    """Parse a pyproject toml file, pulling out relevant parts for Black
+
+    If parsing fails, will raise a toml.TomlDecodeError
+    """
+    pyproject_toml = toml.load(path_config)
+    config = pyproject_toml.get("tool", {}).get("black", {})
+    return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
+
+
 def read_pyproject_toml(
     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
 ) -> Optional[str]:
 def read_pyproject_toml(
     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
 ) -> Optional[str]:
@@ -221,16 +244,12 @@ def read_pyproject_toml(
     """
     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
     if not value:
     """
     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
     if not value:
-        root = find_project_root(ctx.params.get("src", ()))
-        path = root / "pyproject.toml"
-        if path.is_file():
-            value = str(path)
-        else:
+        value = find_pyproject_toml(ctx.params.get("src", ()))
+        if value is None:
             return None
 
     try:
             return None
 
     try:
-        pyproject_toml = toml.load(value)
-        config = pyproject_toml.get("tool", {}).get("black", {})
+        config = parse_pyproject_toml(value)
     except (toml.TomlDecodeError, OSError) as e:
         raise click.FileError(
             filename=value, hint=f"Error reading configuration file: {e}"
     except (toml.TomlDecodeError, OSError) as e:
         raise click.FileError(
             filename=value, hint=f"Error reading configuration file: {e}"
@@ -241,12 +260,21 @@ def read_pyproject_toml(
 
     if ctx.default_map is None:
         ctx.default_map = {}
 
     if ctx.default_map is None:
         ctx.default_map = {}
-    ctx.default_map.update(  # type: ignore  # bad types in .pyi
-        {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
-    )
+    ctx.default_map.update(config)  # type: ignore  # bad types in .pyi
     return value
 
 
     return value
 
 
+def target_version_option_callback(
+    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
+) -> List[TargetVersion]:
+    """Compute the target versions from a --target-version flag.
+
+    This is its own function because mypy couldn't infer the type correctly
+    when it was a lambda, causing mypyc trouble.
+    """
+    return [TargetVersion[val.upper()] for val in v]
+
+
 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
 @click.option(
 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
 @click.option(
@@ -261,7 +289,7 @@ def read_pyproject_toml(
     "-t",
     "--target-version",
     type=click.Choice([v.name.lower() for v in TargetVersion]),
     "-t",
     "--target-version",
     type=click.Choice([v.name.lower() for v in TargetVersion]),
-    callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
+    callback=target_version_option_callback,
     multiple=True,
     help=(
         "Python versions that should be supported by Black's output. [default: "
     multiple=True,
     help=(
         "Python versions that should be supported by Black's output. [default: "
@@ -388,14 +416,14 @@ def main(
     verbose: bool,
     include: str,
     exclude: str,
     verbose: bool,
     include: str,
     exclude: str,
-    src: Tuple[str],
+    src: Tuple[str, ...],
     config: Optional[str],
 ) -> None:
     """The uncompromising code formatter."""
     write_back = WriteBack.from_configuration(check=check, diff=diff)
     if target_version:
         if py36:
     config: Optional[str],
 ) -> None:
     """The uncompromising code formatter."""
     write_back = WriteBack.from_configuration(check=check, diff=diff)
     if target_version:
         if py36:
-            err(f"Cannot use both --target-version and --py36")
+            err("Cannot use both --target-version and --py36")
             ctx.exit(2)
         else:
             versions = set(target_version)
             ctx.exit(2)
         else:
             versions = set(target_version)
@@ -408,7 +436,7 @@ def main(
     else:
         # We'll autodetect later.
         versions = set()
     else:
         # We'll autodetect later.
         versions = set()
-    mode = FileMode(
+    mode = Mode(
         target_versions=versions,
         line_length=line_length,
         is_pyi=pyi,
         target_versions=versions,
         line_length=line_length,
         is_pyi=pyi,
@@ -429,7 +457,7 @@ def main(
     except re.error:
         err(f"Invalid regular expression for exclude given: {exclude!r}")
         ctx.exit(2)
     except re.error:
         err(f"Invalid regular expression for exclude given: {exclude!r}")
         ctx.exit(2)
-    report = Report(check=check, quiet=quiet, verbose=verbose)
+    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
     root = find_project_root(src)
     sources: Set[Path] = set()
     path_empty(src, quiet, verbose, ctx)
     root = find_project_root(src)
     sources: Set[Path] = set()
     path_empty(src, quiet, verbose, ctx)
@@ -470,7 +498,9 @@ def main(
     ctx.exit(report.return_code)
 
 
     ctx.exit(report.return_code)
 
 
-def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None:
+def path_empty(
+    src: Tuple[str, ...], quiet: bool, verbose: bool, ctx: click.Context
+) -> None:
     """
     Exit if there is no `src` provided for formatting
     """
     """
     Exit if there is no `src` provided for formatting
     """
@@ -481,7 +511,7 @@ def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context)
 
 
 def reformat_one(
 
 
 def reformat_one(
-    src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report"
+    src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
 ) -> None:
     """Reformat a single file under `src` without spawning child processes.
 
@@ -514,11 +544,7 @@ def reformat_one(
 
 
 def reformat_many(
 
 
 def reformat_many(
-    sources: Set[Path],
-    fast: bool,
-    write_back: WriteBack,
-    mode: FileMode,
-    report: "Report",
+    sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
 ) -> None:
     """Reformat multiple files using a ProcessPoolExecutor."""
     loop = asyncio.get_event_loop()
 ) -> None:
     """Reformat multiple files using a ProcessPoolExecutor."""
     loop = asyncio.get_event_loop()
@@ -548,7 +574,7 @@ async def schedule_formatting(
     sources: Set[Path],
     fast: bool,
     write_back: WriteBack,
     sources: Set[Path],
     fast: bool,
     write_back: WriteBack,
-    mode: FileMode,
+    mode: Mode,
     report: "Report",
     loop: asyncio.AbstractEventLoop,
     executor: Executor,
     report: "Report",
     loop: asyncio.AbstractEventLoop,
     executor: Executor,
@@ -585,7 +611,7 @@ async def schedule_formatting(
         ): src
         for src in sorted(sources)
     }
         ): src
         for src in sorted(sources)
     }
-    pending: Iterable[asyncio.Future] = tasks.keys()
+    pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
     try:
         loop.add_signal_handler(signal.SIGINT, cancel, pending)
         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
     try:
         loop.add_signal_handler(signal.SIGINT, cancel, pending)
         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
@@ -618,7 +644,7 @@ async def schedule_formatting(
 def format_file_in_place(
     src: Path,
     fast: bool,
 def format_file_in_place(
     src: Path,
     fast: bool,
-    mode: FileMode,
+    mode: Mode,
     write_back: WriteBack = WriteBack.NO,
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
     write_back: WriteBack = WriteBack.NO,
     lock: Any = None,  # multiprocessing.Manager().Lock() is some crazy proxy
 ) -> bool:
@@ -629,7 +655,7 @@ def format_file_in_place(
     `mode` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
     `mode` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
-        mode = evolve(mode, is_pyi=True)
+        mode = replace(mode, is_pyi=True)
 
     then = datetime.utcfromtimestamp(src.stat().st_mtime)
     with open(src, "rb") as buf:
 
     then = datetime.utcfromtimestamp(src.stat().st_mtime)
     with open(src, "rb") as buf:
@@ -639,10 +665,10 @@ def format_file_in_place(
     except NothingChanged:
         return False
 
     except NothingChanged:
         return False
 
-    if write_back == write_back.YES:
+    if write_back == WriteBack.YES:
         with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
         with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
-    elif write_back == write_back.DIFF:
+    elif write_back == WriteBack.DIFF:
         now = datetime.utcnow()
         src_name = f"{src}\t{then} +0000"
         dst_name = f"{src}\t{now} +0000"
         now = datetime.utcnow()
         src_name = f"{src}\t{then} +0000"
         dst_name = f"{src}\t{now} +0000"
@@ -662,7 +688,7 @@ def format_file_in_place(
 
 
 def format_stdin_to_stdout(
 
 
 def format_stdin_to_stdout(
-    fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode
+    fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
 ) -> bool:
     """Format file on stdin. Return True if changed.
 
@@ -694,9 +720,7 @@ def format_stdin_to_stdout(
         f.detach()
 
 
         f.detach()
 
 
-def format_file_contents(
-    src_contents: str, *, fast: bool, mode: FileMode
-) -> FileContent:
+def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
     """Reformat contents a file and return new contents.
 
     If `fast` is False, additionally confirm that the reformatted code is
     """Reformat contents a file and return new contents.
 
     If `fast` is False, additionally confirm that the reformatted code is
@@ -716,11 +740,34 @@ def format_file_contents(
     return dst_contents
 
 
     return dst_contents
 
 
-def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
+def format_str(src_contents: str, *, mode: Mode) -> FileContent:
     """Reformat a string and return new contents.
 
     `mode` determines formatting options, such as how many characters per line are
     """Reformat a string and return new contents.
 
     `mode` determines formatting options, such as how many characters per line are
-    allowed.
+    allowed.  Example:
+
+    >>> import black
+    >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
+    def f(arg: str = "") -> None:
+        ...
+
+    A more complex example:
+    >>> print(
+    ...   black.format_str(
+    ...     "def f(arg:str='')->None: hey",
+    ...     mode=black.Mode(
+    ...       target_versions={black.TargetVersion.PY36},
+    ...       line_length=10,
+    ...       string_normalization=False,
+    ...       is_pyi=False,
+    ...     ),
+    ...   ),
+    ... )
+    def f(
+        arg: str = '',
+    ) -> None:
+        hey
+
     """
     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
     dst_contents = []
     """
     src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
     dst_contents = []
@@ -745,11 +792,9 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
         if supports_feature(versions, feature)
     }
     for current_line in lines.visit(src_node):
         if supports_feature(versions, feature)
     }
     for current_line in lines.visit(src_node):
-        for _ in range(after):
-            dst_contents.append(str(empty_line))
+        dst_contents.append(str(empty_line) * after)
         before, after = elt.maybe_empty_lines(current_line)
         before, after = elt.maybe_empty_lines(current_line)
-        for _ in range(before):
-            dst_contents.append(str(empty_line))
+        dst_contents.append(str(empty_line) * before)
         for line in split_line(
             current_line, line_length=mode.line_length, features=split_line_features
         ):
         for line in split_line(
             current_line, line_length=mode.line_length, features=split_line_features
         ):
@@ -865,8 +910,16 @@ class Visitor(Generic[T]):
         if node.type < 256:
             name = token.tok_name[node.type]
         else:
         if node.type < 256:
             name = token.tok_name[node.type]
         else:
-            name = type_repr(node.type)
-        yield from getattr(self, f"visit_{name}", self.visit_default)(node)
+            name = str(type_repr(node.type))
+        # We explicitly branch on whether a visitor exists (instead of
+        # using self.visit_default as the default arg to getattr) in order
+        # to save needing to create a bound method object and so mypyc can
+        # generate a native call to visit_default.
+        visitf = getattr(self, f"visit_{name}", None)
+        if visitf:
+            yield from visitf(node)
+        else:
+            yield from self.visit_default(node)
 
     def visit_default(self, node: LN) -> Iterator[T]:
         """Default `visit_*()` implementation. Recurses to children of `node`."""
 
     def visit_default(self, node: LN) -> Iterator[T]:
         """Default `visit_*()` implementation. Recurses to children of `node`."""
@@ -911,8 +964,8 @@ class DebugVisitor(Visitor[T]):
         list(v.visit(code))
 
 
         list(v.visit(code))
 
 
-WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
-STATEMENT = {
+WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
+STATEMENT: Final = {
     syms.if_stmt,
     syms.while_stmt,
     syms.for_stmt,
     syms.if_stmt,
     syms.while_stmt,
     syms.for_stmt,
@@ -922,10 +975,10 @@ STATEMENT = {
     syms.funcdef,
     syms.classdef,
 }
     syms.funcdef,
     syms.classdef,
 }
-STANDALONE_COMMENT = 153
+STANDALONE_COMMENT: Final = 153
 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
-LOGIC_OPERATORS = {"and", "or"}
-COMPARATORS = {
+LOGIC_OPERATORS: Final = {"and", "or"}
+COMPARATORS: Final = {
     token.LESS,
     token.GREATER,
     token.EQEQUAL,
     token.LESS,
     token.GREATER,
     token.EQEQUAL,
@@ -933,7 +986,7 @@ COMPARATORS = {
     token.LESSEQUAL,
     token.GREATEREQUAL,
 }
     token.LESSEQUAL,
     token.GREATEREQUAL,
 }
-MATH_OPERATORS = {
+MATH_OPERATORS: Final = {
     token.VBAR,
     token.CIRCUMFLEX,
     token.AMPER,
     token.VBAR,
     token.CIRCUMFLEX,
     token.AMPER,
@@ -949,23 +1002,23 @@ MATH_OPERATORS = {
     token.TILDE,
     token.DOUBLESTAR,
 }
     token.TILDE,
     token.DOUBLESTAR,
 }
-STARS = {token.STAR, token.DOUBLESTAR}
-VARARGS_SPECIALS = STARS | {token.SLASH}
-VARARGS_PARENTS = {
+STARS: Final = {token.STAR, token.DOUBLESTAR}
+VARARGS_SPECIALS: Final = STARS | {token.SLASH}
+VARARGS_PARENTS: Final = {
     syms.arglist,
     syms.argument,  # double star in arglist
     syms.trailer,  # single argument to call
     syms.typedargslist,
     syms.varargslist,  # lambdas
 }
     syms.arglist,
     syms.argument,  # double star in arglist
     syms.trailer,  # single argument to call
     syms.typedargslist,
     syms.varargslist,  # lambdas
 }
-UNPACKING_PARENTS = {
+UNPACKING_PARENTS: Final = {
     syms.atom,  # single element of a list or set literal
     syms.dictsetmaker,
     syms.listmaker,
     syms.testlist_gexp,
     syms.testlist_star_expr,
 }
     syms.atom,  # single element of a list or set literal
     syms.dictsetmaker,
     syms.listmaker,
     syms.testlist_gexp,
     syms.testlist_star_expr,
 }
-TEST_DESCENDANTS = {
+TEST_DESCENDANTS: Final = {
     syms.test,
     syms.lambdef,
     syms.or_test,
     syms.test,
     syms.lambdef,
     syms.or_test,
@@ -982,7 +1035,7 @@ TEST_DESCENDANTS = {
     syms.term,
     syms.power,
 }
     syms.term,
     syms.power,
 }
-ASSIGNMENTS = {
+ASSIGNMENTS: Final = {
     "=",
     "+=",
     "-=",
     "=",
     "+=",
     "-=",
@@ -998,13 +1051,13 @@ ASSIGNMENTS = {
     "**=",
     "//=",
 }
     "**=",
     "//=",
 }
-COMPREHENSION_PRIORITY = 20
-COMMA_PRIORITY = 18
-TERNARY_PRIORITY = 16
-LOGIC_PRIORITY = 14
-STRING_PRIORITY = 12
-COMPARATOR_PRIORITY = 10
-MATH_PRIORITIES = {
+COMPREHENSION_PRIORITY: Final = 20
+COMMA_PRIORITY: Final = 18
+TERNARY_PRIORITY: Final = 16
+LOGIC_PRIORITY: Final = 14
+STRING_PRIORITY: Final = 12
+COMPARATOR_PRIORITY: Final = 10
+MATH_PRIORITIES: Final = {
     token.VBAR: 9,
     token.CIRCUMFLEX: 8,
     token.AMPER: 7,
     token.VBAR: 9,
     token.CIRCUMFLEX: 8,
     token.AMPER: 7,
@@ -1020,7 +1073,7 @@ MATH_PRIORITIES = {
     token.TILDE: 3,
     token.DOUBLESTAR: 2,
 }
     token.TILDE: 3,
     token.DOUBLESTAR: 2,
 }
-DOT_PRIORITY = 1
+DOT_PRIORITY: Final = 1
 
 
 @dataclass
 
 
 @dataclass
@@ -1028,11 +1081,11 @@ class BracketTracker:
     """Keeps track of brackets on a line."""
 
     depth: int = 0
     """Keeps track of brackets on a line."""
 
     depth: int = 0
-    bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
-    delimiters: Dict[LeafID, Priority] = Factory(dict)
+    bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
+    delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
     previous: Optional[Leaf] = None
     previous: Optional[Leaf] = None
-    _for_loop_depths: List[int] = Factory(list)
-    _lambda_argument_depths: List[int] = Factory(list)
+    _for_loop_depths: List[int] = field(default_factory=list)
+    _lambda_argument_depths: List[int] = field(default_factory=list)
 
     def mark(self, leaf: Leaf) -> None:
         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
 
     def mark(self, leaf: Leaf) -> None:
         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
@@ -1160,9 +1213,10 @@ class Line:
     """Holds leaves and comments. Can be printed with `str(line)`."""
 
     depth: int = 0
     """Holds leaves and comments. Can be printed with `str(line)`."""
 
     depth: int = 0
-    leaves: List[Leaf] = Factory(list)
-    comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
-    bracket_tracker: BracketTracker = Factory(BracketTracker)
+    leaves: List[Leaf] = field(default_factory=list)
+    # keys ordered like `leaves`
+    comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
+    bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
     inside_brackets: bool = False
     should_explode: bool = False
 
     inside_brackets: bool = False
     should_explode: bool = False
 
@@ -1383,7 +1437,10 @@ class Line:
         for leaf_id, comments in self.comments.items():
             for comment in comments:
                 if is_type_comment(comment):
         for leaf_id, comments in self.comments.items():
             for comment in comments:
                 if is_type_comment(comment):
-                    if leaf_id not in ignored_ids or comment_seen:
+                    if comment_seen or (
+                        not is_type_comment(comment, " ignore")
+                        and leaf_id not in ignored_ids
+                    ):
                         return True
 
                 comment_seen = True
                         return True
 
                 comment_seen = True
@@ -1422,11 +1479,7 @@ class Line:
         return False
 
     def contains_multiline_strings(self) -> bool:
         return False
 
     def contains_multiline_strings(self) -> bool:
-        for leaf in self.leaves:
-            if is_multiline_string(leaf):
-                return True
-
-        return False
+        return any(is_multiline_string(leaf) for leaf in self.leaves)
 
     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
         """Remove trailing comma if there is one and it's safe."""
 
     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
         """Remove trailing comma if there is one and it's safe."""
@@ -1565,7 +1618,7 @@ class EmptyLineTracker:
     is_pyi: bool = False
     previous_line: Optional[Line] = None
     previous_after: int = 0
     is_pyi: bool = False
     previous_line: Optional[Line] = None
     previous_after: int = 0
-    previous_defs: List[int] = Factory(list)
+    previous_defs: List[int] = field(default_factory=list)
 
     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
         """Return the number of extra empty lines before and after the `current_line`.
 
     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
         """Return the number of extra empty lines before and after the `current_line`.
@@ -1679,7 +1732,7 @@ class LineGenerator(Visitor[Line]):
 
     is_pyi: bool = False
     normalize_strings: bool = True
 
     is_pyi: bool = False
     normalize_strings: bool = True
-    current_line: Line = Factory(Line)
+    current_line: Line = field(default_factory=Line)
     remove_u_prefix: bool = False
 
     def line(self, indent: int = 0) -> Iterator[Line]:
     remove_u_prefix: bool = False
 
     def line(self, indent: int = 0) -> Iterator[Line]:
@@ -1728,13 +1781,13 @@ class LineGenerator(Visitor[Line]):
                 self.current_line.append(node)
         yield from super().visit_default(node)
 
                 self.current_line.append(node)
         yield from super().visit_default(node)
 
-    def visit_INDENT(self, node: Node) -> Iterator[Line]:
+    def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
         """Increase indentation level, maybe yield a line."""
         # In blib2to3 INDENT never holds comments.
         yield from self.line(+1)
         yield from self.visit_default(node)
 
         """Increase indentation level, maybe yield a line."""
         # In blib2to3 INDENT never holds comments.
         yield from self.line(+1)
         yield from self.visit_default(node)
 
-    def visit_DEDENT(self, node: Node) -> Iterator[Line]:
+    def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
         """Decrease indentation level, maybe yield a line."""
         # The current line might still wait for trailing comments.  At DEDENT time
         # there won't be any (they would be prefixes on the preceding NEWLINE).
         """Decrease indentation level, maybe yield a line."""
         # The current line might still wait for trailing comments.  At DEDENT time
         # there won't be any (they would be prefixes on the preceding NEWLINE).
@@ -1844,7 +1897,7 @@ class LineGenerator(Visitor[Line]):
             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
         yield from self.visit_default(node)
 
             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
         yield from self.visit_default(node)
 
-    def __attrs_post_init__(self) -> None:
+    def __post_init__(self) -> None:
         """You are in a twisty little maze of passages."""
         v = self.visit_stmt
         Ø: Set[str] = set()
         """You are in a twisty little maze of passages."""
         v = self.visit_stmt
         Ø: Set[str] = set()
@@ -2462,7 +2515,7 @@ def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
     current_leaves = head_leaves
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
     current_leaves = head_leaves
-    matching_bracket = None
+    matching_bracket: Optional[Leaf] = None
     for leaf in line.leaves:
         if (
             current_leaves is body_leaves
     for leaf in line.leaves:
         if (
             current_leaves is body_leaves
@@ -2505,8 +2558,8 @@ def right_hand_split(
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
     current_leaves = tail_leaves
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
     current_leaves = tail_leaves
-    opening_bracket = None
-    closing_bracket = None
+    opening_bracket: Optional[Leaf] = None
+    closing_bracket: Optional[Leaf] = None
     for leaf in reversed(line.leaves):
         if current_leaves is body_leaves:
             if leaf is opening_bracket:
     for leaf in reversed(line.leaves):
         if current_leaves is body_leaves:
             if leaf is opening_bracket:
@@ -2810,7 +2863,7 @@ def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
     assert match is not None, f"failed to match string {leaf.value!r}"
     orig_prefix = match.group(1)
     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
     assert match is not None, f"failed to match string {leaf.value!r}"
     orig_prefix = match.group(1)
-    new_prefix = orig_prefix.lower()
+    new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
     if remove_u_prefix:
         new_prefix = new_prefix.replace("u", "")
     leaf.value = f"{new_prefix}{match.group(2)}"
     if remove_u_prefix:
         new_prefix = new_prefix.replace("u", "")
     leaf.value = f"{new_prefix}{match.group(2)}"
@@ -2966,16 +3019,9 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
 
             if child.type == syms.atom:
                 if maybe_make_parens_invisible_in_atom(child, parent=node):
 
             if child.type == syms.atom:
                 if maybe_make_parens_invisible_in_atom(child, parent=node):
-                    lpar = Leaf(token.LPAR, "")
-                    rpar = Leaf(token.RPAR, "")
-                    index = child.remove() or 0
-                    node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+                    wrap_in_parentheses(node, child, visible=False)
             elif is_one_tuple(child):
             elif is_one_tuple(child):
-                # wrap child in visible parentheses
-                lpar = Leaf(token.LPAR, "(")
-                rpar = Leaf(token.RPAR, ")")
-                child.remove()
-                node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+                wrap_in_parentheses(node, child, visible=True)
             elif node.type == syms.import_from:
                 # "import from" nodes store parentheses directly as part of
                 # the statement
             elif node.type == syms.import_from:
                 # "import from" nodes store parentheses directly as part of
                 # the statement
@@ -2990,15 +3036,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
                 break
 
             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
                 break
 
             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
-                # wrap child in invisible parentheses
-                lpar = Leaf(token.LPAR, "")
-                rpar = Leaf(token.RPAR, "")
-                index = child.remove() or 0
-                prefix = child.prefix
-                child.prefix = ""
-                new_child = Node(syms.atom, [lpar, child, rpar])
-                new_child.prefix = prefix
-                node.insert_child(index, new_child)
+                wrap_in_parentheses(node, child, visible=False)
 
         check_lpar = isinstance(child, Leaf) and child.value in parens_after
 
 
         check_lpar = isinstance(child, Leaf) and child.value in parens_after
 
@@ -3042,7 +3080,7 @@ def convert_one_fmt_off_pair(node: Node) -> bool:
                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
                     # leaf (possibly followed by a DEDENT).
                     hidden_value = hidden_value[:-1]
                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
                     # leaf (possibly followed by a DEDENT).
                     hidden_value = hidden_value[:-1]
-                first_idx = None
+                first_idx: Optional[int] = None
                 for ignored in ignored_nodes:
                     index = ignored.remove()
                     if first_idx is None:
                 for ignored in ignored_nodes:
                     index = ignored.remove()
                     if first_idx is None:
@@ -3071,9 +3109,14 @@ def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
     """
     container: Optional[LN] = container_of(leaf)
     while container is not None and container.type != token.ENDMARKER:
     """
     container: Optional[LN] = container_of(leaf)
     while container is not None and container.type != token.ENDMARKER:
+        is_fmt_on = False
         for comment in list_comments(container.prefix, is_endmarker=False):
             if comment.value in FMT_ON:
         for comment in list_comments(container.prefix, is_endmarker=False):
             if comment.value in FMT_ON:
-                return
+                is_fmt_on = True
+            elif comment.value in FMT_OFF:
+                is_fmt_on = False
+        if is_fmt_on:
+            return
 
         yield container
 
 
         yield container
 
@@ -3158,6 +3201,24 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
     return wrapped
 
 
     return wrapped
 
 
+def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
+    """Wrap `child` in parentheses.
+
+    This replaces `child` with an atom holding the parentheses and the old
+    child.  That requires moving the prefix.
+
+    If `visible` is False, the leaves will be valueless (and thus invisible).
+    """
+    lpar = Leaf(token.LPAR, "(" if visible else "")
+    rpar = Leaf(token.RPAR, ")" if visible else "")
+    prefix = child.prefix
+    child.prefix = ""
+    index = child.remove() or 0
+    new_child = Node(syms.atom, [lpar, child, rpar])
+    new_child.prefix = prefix
+    parent.insert_child(index, new_child)
+
+
 def is_one_tuple(node: LN) -> bool:
     """Return True if `node` holds a tuple with one element, with or without parens."""
     if node.type == syms.atom:
 def is_one_tuple(node: LN) -> bool:
     """Return True if `node` holds a tuple with one element, with or without parens."""
     if node.type == syms.atom:
@@ -3390,8 +3451,8 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
     yield omit
 
     length = 4 * line.depth
     yield omit
 
     length = 4 * line.depth
-    opening_bracket = None
-    closing_bracket = None
+    opening_bracket: Optional[Leaf] = None
+    closing_bracket: Optional[Leaf] = None
     inner_brackets: Set[LeafID] = set()
     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
         length += leaf_length
     inner_brackets: Set[LeafID] = set()
     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
         length += leaf_length
@@ -3506,7 +3567,7 @@ def gen_python_files_in_dir(
     for child in path.iterdir():
         # First ignore files matching .gitignore
         if gitignore.match_file(child.as_posix()):
     for child in path.iterdir():
         # First ignore files matching .gitignore
         if gitignore.match_file(child.as_posix()):
-            report.path_ignored(child, f"matches the .gitignore file content")
+            report.path_ignored(child, "matches the .gitignore file content")
             continue
 
         # Then ignore with `exclude` option.
             continue
 
         # Then ignore with `exclude` option.
@@ -3530,7 +3591,7 @@ def gen_python_files_in_dir(
 
         exclude_match = exclude.search(normalized_path)
         if exclude_match and exclude_match.group(0):
 
         exclude_match = exclude.search(normalized_path)
         if exclude_match and exclude_match.group(0):
-            report.path_ignored(child, f"matches the --exclude regular expression")
+            report.path_ignored(child, "matches the --exclude regular expression")
             continue
 
         if child.is_dir():
             continue
 
         if child.is_dir():
@@ -3562,7 +3623,7 @@ def find_project_root(srcs: Iterable[str]) -> Path:
         # Append a fake file so `parents` below returns `common_base_dir`, too.
         common_base /= "fake-file"
     for directory in common_base.parents:
         # Append a fake file so `parents` below returns `common_base_dir`, too.
         common_base /= "fake-file"
     for directory in common_base.parents:
-        if (directory / ".git").is_dir():
+        if (directory / ".git").exists():
             return directory
 
         if (directory / ".hg").is_dir():
             return directory
 
         if (directory / ".hg").is_dir():
@@ -3579,6 +3640,7 @@ class Report:
     """Provides a reformatting counter. Can be rendered with `str(report)`."""
 
     check: bool = False
     """Provides a reformatting counter. Can be rendered with `str(report)`."""
 
     check: bool = False
+    diff: bool = False
     quiet: bool = False
     verbose: bool = False
     change_count: int = 0
     quiet: bool = False
     verbose: bool = False
     change_count: int = 0
@@ -3588,7 +3650,7 @@ class Report:
     def done(self, src: Path, changed: Changed) -> None:
         """Increment the counter for successful reformatting. Write out a message."""
         if changed is Changed.YES:
     def done(self, src: Path, changed: Changed) -> None:
         """Increment the counter for successful reformatting. Write out a message."""
         if changed is Changed.YES:
-            reformatted = "would reformat" if self.check else "reformatted"
+            reformatted = "would reformat" if self.check or self.diff else "reformatted"
             if self.verbose or not self.quiet:
                 out(f"{reformatted} {src}")
             self.change_count += 1
             if self.verbose or not self.quiet:
                 out(f"{reformatted} {src}")
             self.change_count += 1
@@ -3634,7 +3696,7 @@ class Report:
 
         Use `click.unstyle` to remove colors.
         """
 
         Use `click.unstyle` to remove colors.
         """
-        if self.check:
+        if self.check or self.diff:
             reformatted = "would be reformatted"
             unchanged = "would be left unchanged"
             failed = "would fail to reformat"
             reformatted = "would be reformatted"
             unchanged = "would be left unchanged"
             failed = "would fail to reformat"
@@ -3704,7 +3766,7 @@ def assert_equivalent(src: str, dst: str) -> None:
 
         yield f"{'  ' * depth}{node.__class__.__name__}("
 
 
         yield f"{'  ' * depth}{node.__class__.__name__}("
 
-        for field in sorted(node._fields):
+        for field in sorted(node._fields):  # noqa: F402
             # TypeIgnore has only one field 'lineno' which breaks this comparison
             type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
             if sys.version_info >= (3, 8):
             # TypeIgnore has only one field 'lineno' which breaks this comparison
             type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
             if sys.version_info >= (3, 8):
@@ -3772,7 +3834,7 @@ def assert_equivalent(src: str, dst: str) -> None:
         ) from None
 
 
         ) from None
 
 
-def assert_stable(src: str, dst: str, mode: FileMode) -> None:
+def assert_stable(src: str, dst: str, mode: Mode) -> None:
     """Raise AssertionError if `dst` reformats differently the second time."""
     newdst = format_str(dst, mode=mode)
     if dst != newdst:
     """Raise AssertionError if `dst` reformats differently the second time."""
     newdst = format_str(dst, mode=mode)
     if dst != newdst:
@@ -3788,6 +3850,7 @@ def assert_stable(src: str, dst: str, mode: FileMode) -> None:
         ) from None
 
 
         ) from None
 
 
+@mypyc_attr(patchable=True)
 def dump_to_file(*output: str) -> str:
     """Dump `output` to a temporary file. Return path to the file."""
     with tempfile.NamedTemporaryFile(
 def dump_to_file(*output: str) -> str:
     """Dump `output` to a temporary file. Return path to the file."""
     with tempfile.NamedTemporaryFile(
@@ -3813,14 +3876,14 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     """Return a unified diff string between strings `a` and `b`."""
     import difflib
 
     """Return a unified diff string between strings `a` and `b`."""
     import difflib
 
-    a_lines = [line + "\n" for line in a.split("\n")]
-    b_lines = [line + "\n" for line in b.split("\n")]
+    a_lines = [line + "\n" for line in a.splitlines()]
+    b_lines = [line + "\n" for line in b.splitlines()]
     return "".join(
         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
     )
 
 
     return "".join(
         difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
     )
 
 
-def cancel(tasks: Iterable[asyncio.Task]) -> None:
+def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
     err("Aborted!")
     for task in tasks:
     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
     err("Aborted!")
     for task in tasks:
@@ -4042,11 +4105,11 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool:
     return False
 
 
     return False
 
 
-def get_cache_file(mode: FileMode) -> Path:
+def get_cache_file(mode: Mode) -> Path:
     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
 
 
     return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle"
 
 
-def read_cache(mode: FileMode) -> Cache:
+def read_cache(mode: Mode) -> Cache:
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
     """Read the cache if it exists and is well formed.
 
     If it is not well formed, the call to write_cache later should resolve the issue.
@@ -4086,7 +4149,7 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set
     return todo, done
 
 
     return todo, done
 
 
-def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None:
+def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
     """Update the cache file."""
     cache_file = get_cache_file(mode)
     try:
     """Update the cache file."""
     cache_file = get_cache_file(mode)
     try: