X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/adce126949623da2f310e6ea3f9f83c37972582a..d29303c9884e1ef715851d69acc5d54f84441720:/black.py?ds=sidebyside

diff --git a/black.py b/black.py
index b3699d5..b7cacf7 100644
--- a/black.py
+++ b/black.py
@@ -37,9 +37,11 @@ from typing import (
     Union,
     cast,
 )
+from typing_extensions import Final
+from mypy_extensions import mypyc_attr
 
 from appdirs import user_cache_dir
-from attr import dataclass, evolve, Factory
+from dataclasses import dataclass, field, replace
 import click
 import toml
 from typed_ast import ast3, ast27
@@ -185,7 +187,7 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
 
 @dataclass
 class FileMode:
-    target_versions: Set[TargetVersion] = Factory(set)
+    target_versions: Set[TargetVersion] = field(default_factory=set)
     line_length: int = DEFAULT_LINE_LENGTH
     string_normalization: bool = True
     is_pyi: bool = False
@@ -211,6 +213,23 @@ def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> b
     return all(feature in VERSION_TO_FEATURES[version] for version in target_versions)
 
 
+def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]:
+    """Return the project root's pyproject.toml path as a string, or None if absent."""
+    path_project_root = find_project_root(path_search_start)  # takes the src paths
+    path_pyproject_toml = path_project_root / "pyproject.toml"
+    return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None
+
+
+def parse_pyproject_toml(path_config: str) -> Dict[str, Any]:
+    """Load the TOML file at `path_config` and return Black's options from it.
+
+    Raises toml.TomlDecodeError if the file cannot be parsed as TOML.
+    """
+    section = toml.load(path_config).get("tool", {}).get("black", {})
+    # Turn option spellings such as "--line-length" into "line_length".
+    return {k.replace("--", "").replace("-", "_"): v for k, v in section.items()}
+
+
 def read_pyproject_toml(
     ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None]
 ) -> Optional[str]:
@@ -221,16 +240,12 @@ def read_pyproject_toml(
     """
     assert not isinstance(value, (int, bool)), "Invalid parameter type passed"
     if not value:
-        root = find_project_root(ctx.params.get("src", ()))
-        path = root / "pyproject.toml"
-        if path.is_file():
-            value = str(path)
-        else:
+        value = find_pyproject_toml(ctx.params.get("src", ()))
+        if value is None:
             return None
 
     try:
-        pyproject_toml = toml.load(value)
-        config = pyproject_toml.get("tool", {}).get("black", {})
+        config = parse_pyproject_toml(value)
     except (toml.TomlDecodeError, OSError) as e:
         raise click.FileError(
             filename=value, hint=f"Error reading configuration file: {e}"
@@ -241,12 +256,21 @@ def read_pyproject_toml(
 
     if ctx.default_map is None:
         ctx.default_map = {}
-    ctx.default_map.update(  # type: ignore  # bad types in .pyi
-        {k.replace("--", "").replace("-", "_"): v for k, v in config.items()}
-    )
+    ctx.default_map.update(config)  # type: ignore  # bad types in .pyi
     return value
 
 
+def target_version_option_callback(
+    c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
+) -> List[TargetVersion]:
+    """Convert the raw values of `--target-version` into TargetVersion members.
+
+    Kept as a named function rather than a lambda because mypy could not infer
+    the lambda's type correctly, which caused trouble for mypyc.
+    """
+    return [TargetVersion[name.upper()] for name in v]
+
+
 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
 @click.option(
@@ -261,7 +285,7 @@ def read_pyproject_toml(
     "-t",
     "--target-version",
     type=click.Choice([v.name.lower() for v in TargetVersion]),
-    callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v],
+    callback=target_version_option_callback,
     multiple=True,
     help=(
         "Python versions that should be supported by Black's output. [default: "
@@ -388,7 +412,7 @@ def main(
     verbose: bool,
     include: str,
     exclude: str,
-    src: Tuple[str],
+    src: Tuple[str, ...],
     config: Optional[str],
 ) -> None:
     """The uncompromising code formatter."""
@@ -429,7 +453,7 @@ def main(
     except re.error:
         err(f"Invalid regular expression for exclude given: {exclude!r}")
         ctx.exit(2)
-    report = Report(check=check, quiet=quiet, verbose=verbose)
+    report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose)
     root = find_project_root(src)
     sources: Set[Path] = set()
     path_empty(src, quiet, verbose, ctx)
@@ -470,7 +494,9 @@ def main(
     ctx.exit(report.return_code)
 
 
-def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None:
+def path_empty(
+    src: Tuple[str, ...], quiet: bool, verbose: bool, ctx: click.Context
+) -> None:
     """
     Exit if there is no `src` provided for formatting
     """
@@ -585,7 +611,7 @@ async def schedule_formatting(
         ): src
         for src in sorted(sources)
     }
-    pending: Iterable[asyncio.Future] = tasks.keys()
+    pending: Iterable["asyncio.Future[bool]"] = tasks.keys()
     try:
         loop.add_signal_handler(signal.SIGINT, cancel, pending)
         loop.add_signal_handler(signal.SIGTERM, cancel, pending)
@@ -629,7 +655,7 @@ def format_file_in_place(
     `mode` and `fast` options are passed to :func:`format_file_contents`.
     """
     if src.suffix == ".pyi":
-        mode = evolve(mode, is_pyi=True)
+        mode = replace(mode, is_pyi=True)
 
     then = datetime.utcfromtimestamp(src.stat().st_mtime)
     with open(src, "rb") as buf:
@@ -639,10 +665,10 @@ def format_file_in_place(
     except NothingChanged:
         return False
 
-    if write_back == write_back.YES:
+    if write_back == WriteBack.YES:
         with open(src, "w", encoding=encoding, newline=newline) as f:
             f.write(dst_contents)
-    elif write_back == write_back.DIFF:
+    elif write_back == WriteBack.DIFF:
         now = datetime.utcnow()
         src_name = f"{src}\t{then} +0000"
         dst_name = f"{src}\t{now} +0000"
@@ -745,11 +771,9 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent:
         if supports_feature(versions, feature)
     }
     for current_line in lines.visit(src_node):
-        for _ in range(after):
-            dst_contents.append(str(empty_line))
+        dst_contents.append(str(empty_line) * after)
         before, after = elt.maybe_empty_lines(current_line)
-        for _ in range(before):
-            dst_contents.append(str(empty_line))
+        dst_contents.append(str(empty_line) * before)
         for line in split_line(
             current_line, line_length=mode.line_length, features=split_line_features
         ):
@@ -865,8 +889,16 @@ class Visitor(Generic[T]):
         if node.type < 256:
             name = token.tok_name[node.type]
         else:
-            name = type_repr(node.type)
-        yield from getattr(self, f"visit_{name}", self.visit_default)(node)
+            name = str(type_repr(node.type))
+        # We explicitly branch on whether a visitor exists (instead of
+        # using self.visit_default as the default arg to getattr) to
+        # avoid creating a bound method object and so that mypyc can
+        # generate a native call to visit_default.
+        visitf = getattr(self, f"visit_{name}", None)
+        if visitf:
+            yield from visitf(node)
+        else:
+            yield from self.visit_default(node)
 
     def visit_default(self, node: LN) -> Iterator[T]:
         """Default `visit_*()` implementation. Recurses to children of `node`."""
@@ -911,8 +943,8 @@ class DebugVisitor(Visitor[T]):
         list(v.visit(code))
 
 
-WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE}
-STATEMENT = {
+WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE}
+STATEMENT: Final = {
     syms.if_stmt,
     syms.while_stmt,
     syms.for_stmt,
@@ -922,10 +954,10 @@ STATEMENT = {
     syms.funcdef,
     syms.classdef,
 }
-STANDALONE_COMMENT = 153
+STANDALONE_COMMENT: Final = 153
 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT"
-LOGIC_OPERATORS = {"and", "or"}
-COMPARATORS = {
+LOGIC_OPERATORS: Final = {"and", "or"}
+COMPARATORS: Final = {
     token.LESS,
     token.GREATER,
     token.EQEQUAL,
@@ -933,7 +965,7 @@ COMPARATORS = {
     token.LESSEQUAL,
     token.GREATEREQUAL,
 }
-MATH_OPERATORS = {
+MATH_OPERATORS: Final = {
     token.VBAR,
     token.CIRCUMFLEX,
     token.AMPER,
@@ -949,23 +981,23 @@ MATH_OPERATORS = {
     token.TILDE,
     token.DOUBLESTAR,
 }
-STARS = {token.STAR, token.DOUBLESTAR}
-VARARGS_SPECIALS = STARS | {token.SLASH}
-VARARGS_PARENTS = {
+STARS: Final = {token.STAR, token.DOUBLESTAR}
+VARARGS_SPECIALS: Final = STARS | {token.SLASH}
+VARARGS_PARENTS: Final = {
     syms.arglist,
     syms.argument,  # double star in arglist
     syms.trailer,  # single argument to call
     syms.typedargslist,
     syms.varargslist,  # lambdas
 }
-UNPACKING_PARENTS = {
+UNPACKING_PARENTS: Final = {
     syms.atom,  # single element of a list or set literal
     syms.dictsetmaker,
     syms.listmaker,
     syms.testlist_gexp,
     syms.testlist_star_expr,
 }
-TEST_DESCENDANTS = {
+TEST_DESCENDANTS: Final = {
     syms.test,
     syms.lambdef,
     syms.or_test,
@@ -982,7 +1014,7 @@ TEST_DESCENDANTS = {
     syms.term,
     syms.power,
 }
-ASSIGNMENTS = {
+ASSIGNMENTS: Final = {
     "=",
     "+=",
     "-=",
@@ -998,13 +1030,13 @@ ASSIGNMENTS = {
     "**=",
     "//=",
 }
-COMPREHENSION_PRIORITY = 20
-COMMA_PRIORITY = 18
-TERNARY_PRIORITY = 16
-LOGIC_PRIORITY = 14
-STRING_PRIORITY = 12
-COMPARATOR_PRIORITY = 10
-MATH_PRIORITIES = {
+COMPREHENSION_PRIORITY: Final = 20
+COMMA_PRIORITY: Final = 18
+TERNARY_PRIORITY: Final = 16
+LOGIC_PRIORITY: Final = 14
+STRING_PRIORITY: Final = 12
+COMPARATOR_PRIORITY: Final = 10
+MATH_PRIORITIES: Final = {
     token.VBAR: 9,
     token.CIRCUMFLEX: 8,
     token.AMPER: 7,
@@ -1020,7 +1052,7 @@ MATH_PRIORITIES = {
     token.TILDE: 3,
     token.DOUBLESTAR: 2,
 }
-DOT_PRIORITY = 1
+DOT_PRIORITY: Final = 1
 
 
 @dataclass
@@ -1028,11 +1060,11 @@ class BracketTracker:
     """Keeps track of brackets on a line."""
 
     depth: int = 0
-    bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict)
-    delimiters: Dict[LeafID, Priority] = Factory(dict)
+    bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
+    delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
     previous: Optional[Leaf] = None
-    _for_loop_depths: List[int] = Factory(list)
-    _lambda_argument_depths: List[int] = Factory(list)
+    _for_loop_depths: List[int] = field(default_factory=list)
+    _lambda_argument_depths: List[int] = field(default_factory=list)
 
     def mark(self, leaf: Leaf) -> None:
         """Mark `leaf` with bracket-related metadata. Keep track of delimiters.
@@ -1160,9 +1192,10 @@ class Line:
     """Holds leaves and comments. Can be printed with `str(line)`."""
 
     depth: int = 0
-    leaves: List[Leaf] = Factory(list)
-    comments: Dict[LeafID, List[Leaf]] = Factory(dict)  # keys ordered like `leaves`
-    bracket_tracker: BracketTracker = Factory(BracketTracker)
+    leaves: List[Leaf] = field(default_factory=list)
+    # keys ordered like `leaves`
+    comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
+    bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
     inside_brackets: bool = False
     should_explode: bool = False
 
@@ -1250,6 +1283,7 @@ class Line:
         """
         if not self.leaves or len(self.leaves) < 4:
             return False
+
         # Look for and address a trailing colon.
         if self.leaves[-1].type == token.COLON:
             closer = self.leaves[-2]
@@ -1259,6 +1293,7 @@ class Line:
             close_index = -1
         if closer.type not in CLOSING_BRACKETS or self.inside_brackets:
             return False
+
         if closer.type == token.RPAR:
             # Tuples require an extra check, because if there's only
             # one element in the tuple removing the comma unmakes the
@@ -1273,9 +1308,11 @@ class Line:
             for _open_index, leaf in enumerate(self.leaves):
                 if leaf is opener:
                     break
+
             else:
                 # Couldn't find the matching opening paren, play it safe.
                 return False
+
             commas = 0
             comma_depth = self.leaves[close_index - 1].bracket_depth
             for leaf in self.leaves[_open_index + 1 : close_index]:
@@ -1285,16 +1322,20 @@ class Line:
                 # We haven't looked yet for the trailing comma because
                 # we might also have caught noop parens.
                 return self.leaves[close_index - 1].type == token.COMMA
+
             elif commas == 1:
                 return False  # it's either a one-tuple or didn't have a trailing comma
+
             if self.leaves[close_index - 1].type in CLOSING_BRACKETS:
                 close_index -= 1
                 closer = self.leaves[close_index]
                 if closer.type == token.RPAR:
                     # TODO: this is a gut feeling. Will we ever see this?
                     return False
+
         if self.leaves[close_index - 1].type != token.COMMA:
             return False
+
         return True
 
     @property
@@ -1344,9 +1385,9 @@ class Line:
     def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool:
         """If so, needs to be split before emitting."""
         for leaf in self.leaves:
-            if leaf.type == STANDALONE_COMMENT:
-                if leaf.bracket_depth <= depth_limit:
-                    return True
+            if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit:
+                return True
+
         return False
 
     def contains_uncollapsable_type_comments(self) -> bool:
@@ -1375,7 +1416,10 @@ class Line:
         for leaf_id, comments in self.comments.items():
             for comment in comments:
                 if is_type_comment(comment):
-                    if leaf_id not in ignored_ids or comment_seen:
+                    if comment_seen or (
+                        not is_type_comment(comment, " ignore")
+                        and leaf_id not in ignored_ids
+                    ):
                         return True
 
                 comment_seen = True
@@ -1414,16 +1458,13 @@ class Line:
         return False
 
     def contains_multiline_strings(self) -> bool:
-        for leaf in self.leaves:
-            if is_multiline_string(leaf):
-                return True
-
-        return False
+        return any(is_multiline_string(leaf) for leaf in self.leaves)
 
     def maybe_remove_trailing_comma(self, closing: Leaf) -> bool:
         """Remove trailing comma if there is one and it's safe."""
         if not (self.leaves and self.leaves[-1].type == token.COMMA):
             return False
+
         # We remove trailing commas only in the case of importing a
         # single name from a module.
         if not (
@@ -1488,6 +1529,7 @@ class Line:
                 comment.type = STANDALONE_COMMENT
                 comment.prefix = ""
                 return False
+
             last_leaf = self.leaves[-2]
         self.comments.setdefault(id(last_leaf), []).append(comment)
         return True
@@ -1555,7 +1597,7 @@ class EmptyLineTracker:
     is_pyi: bool = False
     previous_line: Optional[Line] = None
     previous_after: int = 0
-    previous_defs: List[int] = Factory(list)
+    previous_defs: List[int] = field(default_factory=list)
 
     def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]:
         """Return the number of extra empty lines before and after the `current_line`.
@@ -1669,7 +1711,7 @@ class LineGenerator(Visitor[Line]):
 
     is_pyi: bool = False
     normalize_strings: bool = True
-    current_line: Line = Factory(Line)
+    current_line: Line = field(default_factory=Line)
     remove_u_prefix: bool = False
 
     def line(self, indent: int = 0) -> Iterator[Line]:
@@ -1718,13 +1760,13 @@ class LineGenerator(Visitor[Line]):
                 self.current_line.append(node)
         yield from super().visit_default(node)
 
-    def visit_INDENT(self, node: Node) -> Iterator[Line]:
+    def visit_INDENT(self, node: Leaf) -> Iterator[Line]:
         """Increase indentation level, maybe yield a line."""
         # In blib2to3 INDENT never holds comments.
         yield from self.line(+1)
         yield from self.visit_default(node)
 
-    def visit_DEDENT(self, node: Node) -> Iterator[Line]:
+    def visit_DEDENT(self, node: Leaf) -> Iterator[Line]:
         """Decrease indentation level, maybe yield a line."""
         # The current line might still wait for trailing comments.  At DEDENT time
         # there won't be any (they would be prefixes on the preceding NEWLINE).
@@ -1834,7 +1876,7 @@ class LineGenerator(Visitor[Line]):
             node.insert_child(index, Node(syms.atom, [lpar, operand, rpar]))
         yield from self.visit_default(node)
 
-    def __attrs_post_init__(self) -> None:
+    def __post_init__(self) -> None:
         """You are in a twisty little maze of passages."""
         v = self.visit_stmt
         Ø: Set[str] = set()
@@ -2452,7 +2494,7 @@ def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
     current_leaves = head_leaves
-    matching_bracket = None
+    matching_bracket: Optional[Leaf] = None
     for leaf in line.leaves:
         if (
             current_leaves is body_leaves
@@ -2495,8 +2537,8 @@ def right_hand_split(
     body_leaves: List[Leaf] = []
     head_leaves: List[Leaf] = []
     current_leaves = tail_leaves
-    opening_bracket = None
-    closing_bracket = None
+    opening_bracket: Optional[Leaf] = None
+    closing_bracket: Optional[Leaf] = None
     for leaf in reversed(line.leaves):
         if current_leaves is body_leaves:
             if leaf is opening_bracket:
@@ -2619,11 +2661,11 @@ def bracket_split_build_line(
                 for i in range(len(leaves) - 1, -1, -1):
                     if leaves[i].type == STANDALONE_COMMENT:
                         continue
-                    elif leaves[i].type == token.COMMA:
-                        break
-                    else:
+
+                    if leaves[i].type != token.COMMA:
                         leaves.insert(i + 1, Leaf(token.COMMA, ","))
-                        break
+                    break
+
     # Populate the line
     for leaf in leaves:
         result.append(leaf, preformatted=True)
@@ -2800,7 +2842,7 @@ def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None:
     match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL)
     assert match is not None, f"failed to match string {leaf.value!r}"
     orig_prefix = match.group(1)
-    new_prefix = orig_prefix.lower()
+    new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u")
     if remove_u_prefix:
         new_prefix = new_prefix.replace("u", "")
     leaf.value = f"{new_prefix}{match.group(2)}"
@@ -2867,6 +2909,7 @@ def normalize_string_quotes(leaf: Leaf) -> None:
             if "\\" in str(m):
                 # Do not introduce backslashes in interpolated expressions
                 return
+
     if new_quote == '"""' and new_body[-1:] == '"':
         # edge case:
         new_body = new_body[:-1] + '\\"'
@@ -2939,7 +2982,6 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
         if pc.value in FMT_OFF:
             # This `node` has a prefix with `# fmt: off`, don't mess with parens.
             return
-
     check_lpar = False
     for index, child in enumerate(list(node.children)):
         # Add parentheses around long tuple unpacking in assignments.
@@ -2953,26 +2995,12 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
         if check_lpar:
             if is_walrus_assignment(child):
                 continue
-            if child.type == syms.atom:
-                # Determines if the underlying atom should be surrounded with
-                # invisible params - also makes parens invisible recursively
-                # within the atom and removes repeated invisible parens within
-                # the atom
-                should_surround_with_parens = maybe_make_parens_invisible_in_atom(
-                    child, parent=node
-                )
 
-                if should_surround_with_parens:
-                    lpar = Leaf(token.LPAR, "")
-                    rpar = Leaf(token.RPAR, "")
-                    index = child.remove() or 0
-                    node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+            if child.type == syms.atom:
+                if maybe_make_parens_invisible_in_atom(child, parent=node):
+                    wrap_in_parentheses(node, child, visible=False)
             elif is_one_tuple(child):
-                # wrap child in visible parentheses
-                lpar = Leaf(token.LPAR, "(")
-                rpar = Leaf(token.RPAR, ")")
-                child.remove()
-                node.insert_child(index, Node(syms.atom, [lpar, child, rpar]))
+                wrap_in_parentheses(node, child, visible=True)
             elif node.type == syms.import_from:
                 # "import from" nodes store parentheses directly as part of
                 # the statement
@@ -2987,15 +3015,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
                 break
 
             elif not (isinstance(child, Leaf) and is_multiline_string(child)):
-                # wrap child in invisible parentheses
-                lpar = Leaf(token.LPAR, "")
-                rpar = Leaf(token.RPAR, "")
-                index = child.remove() or 0
-                prefix = child.prefix
-                child.prefix = ""
-                new_child = Node(syms.atom, [lpar, child, rpar])
-                new_child.prefix = prefix
-                node.insert_child(index, new_child)
+                wrap_in_parentheses(node, child, visible=False)
 
         check_lpar = isinstance(child, Leaf) and child.value in parens_after
 
@@ -3039,7 +3059,7 @@ def convert_one_fmt_off_pair(node: Node) -> bool:
                     # That happens when one of the `ignored_nodes` ended with a NEWLINE
                     # leaf (possibly followed by a DEDENT).
                     hidden_value = hidden_value[:-1]
-                first_idx = None
+                first_idx: Optional[int] = None
                 for ignored in ignored_nodes:
                     index = ignored.remove()
                     if first_idx is None:
@@ -3068,9 +3088,14 @@ def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
     """
     container: Optional[LN] = container_of(leaf)
     while container is not None and container.type != token.ENDMARKER:
+        is_fmt_on = False
         for comment in list_comments(container.prefix, is_endmarker=False):
             if comment.value in FMT_ON:
-                return
+                is_fmt_on = True
+            elif comment.value in FMT_OFF:
+                is_fmt_on = False
+        if is_fmt_on:
+            return
 
         yield container
 
@@ -3147,6 +3172,7 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
     Parenthesis can be optional. Returns None otherwise"""
     if len(node.children) != 3:
         return None
+
     lpar, wrapped, rpar = node.children
     if not (lpar.type == token.LPAR and rpar.type == token.RPAR):
         return None
@@ -3154,6 +3180,24 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
     return wrapped
 
 
+def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
+    """Wrap `child` in parentheses.
+
+    `child` is replaced in `parent` by a new atom holding an opening parenthesis,
+    `child` itself and a closing parenthesis; the prefix moves to that atom.
+
+    If `visible` is False, both parenthesis leaves get an empty value.
+    """
+    lpar_value, rpar_value = ("(", ")") if visible else ("", "")
+    lpar, rpar = Leaf(token.LPAR, lpar_value), Leaf(token.RPAR, rpar_value)
+    prefix = child.prefix
+    child.prefix = ""
+    index = child.remove() or 0
+    wrapped = Node(syms.atom, [lpar, child, rpar])
+    wrapped.prefix = prefix
+    parent.insert_child(index, wrapped)
+
+
 def is_one_tuple(node: LN) -> bool:
     """Return True if `node` holds a tuple with one element, with or without parens."""
     if node.type == syms.atom:
@@ -3386,8 +3430,8 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
     yield omit
 
     length = 4 * line.depth
-    opening_bracket = None
-    closing_bracket = None
+    opening_bracket: Optional[Leaf] = None
+    closing_bracket: Optional[Leaf] = None
     inner_brackets: Set[LeafID] = set()
     for index, leaf, leaf_length in enumerate_with_length(line, reversed=True):
         length += leaf_length
@@ -3431,19 +3475,23 @@ def get_future_imports(node: Node) -> Set[str]:
             if isinstance(child, Leaf):
                 if child.type == token.NAME:
                     yield child.value
+
             elif child.type == syms.import_as_name:
                 orig_name = child.children[0]
                 assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
                 assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
                 yield orig_name.value
+
             elif child.type == syms.import_as_names:
                 yield from get_imports_from_children(child.children)
+
             else:
                 raise AssertionError("Invalid syntax parsing imports")
 
     for child in node.children:
         if child.type != syms.simple_stmt:
             break
+
         first_child = child.children[0]
         if isinstance(first_child, Leaf):
             # Continue looking if we see a docstring; otherwise stop.
@@ -3453,15 +3501,18 @@ def get_future_imports(node: Node) -> Set[str]:
                 and child.children[1].type == token.NEWLINE
             ):
                 continue
-            else:
-                break
+
+            break
+
         elif first_child.type == syms.import_from:
             module_name = first_child.children[1]
             if not isinstance(module_name, Leaf) or module_name.value != "__future__":
                 break
+
             imports |= set(get_imports_from_children(first_child.children[3:]))
         else:
             break
+
     return imports
 
 
@@ -3469,10 +3520,11 @@ def get_future_imports(node: Node) -> Set[str]:
 def get_gitignore(root: Path) -> PathSpec:
     """ Return a PathSpec matching gitignore content if present."""
     gitignore = root / ".gitignore"
-    if not gitignore.is_file():
-        return PathSpec.from_lines("gitwildmatch", [])
-    else:
-        return PathSpec.from_lines("gitwildmatch", gitignore.open())
+    lines: List[str] = []
+    if gitignore.is_file():
+        with gitignore.open(encoding="utf-8") as gf:  # don't depend on the locale
+            lines = gf.readlines()
+    return PathSpec.from_lines("gitwildmatch", lines)
 
 
 def gen_python_files_in_dir(
@@ -3503,6 +3555,7 @@ def gen_python_files_in_dir(
         except OSError as e:
             report.path_ignored(child, f"cannot be read because {e}")
             continue
+
         except ValueError:
             if child.is_symlink():
                 report.path_ignored(
@@ -3549,7 +3602,7 @@ def find_project_root(srcs: Iterable[str]) -> Path:
         # Append a fake file so `parents` below returns `common_base_dir`, too.
         common_base /= "fake-file"
     for directory in common_base.parents:
-        if (directory / ".git").is_dir():
+        if (directory / ".git").exists():
             return directory
 
         if (directory / ".hg").is_dir():
@@ -3566,6 +3619,7 @@ class Report:
     """Provides a reformatting counter. Can be rendered with `str(report)`."""
 
     check: bool = False
+    diff: bool = False
     quiet: bool = False
     verbose: bool = False
     change_count: int = 0
@@ -3575,7 +3629,7 @@ class Report:
     def done(self, src: Path, changed: Changed) -> None:
         """Increment the counter for successful reformatting. Write out a message."""
         if changed is Changed.YES:
-            reformatted = "would reformat" if self.check else "reformatted"
+            reformatted = "would reformat" if self.check or self.diff else "reformatted"
             if self.verbose or not self.quiet:
                 out(f"{reformatted} {src}")
             self.change_count += 1
@@ -3621,7 +3675,7 @@ class Report:
 
         Use `click.unstyle` to remove colors.
         """
-        if self.check:
+        if self.check or self.diff:
             reformatted = "would be reformatted"
             unchanged = "would be left unchanged"
             failed = "would fail to reformat"
@@ -3671,10 +3725,13 @@ def _fixup_ast_constants(
     """Map ast nodes deprecated in 3.8 to Constant."""
     if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)):
         return ast.Constant(value=node.s)
+
     if isinstance(node, (ast.Num, ast3.Num, ast27.Num)):
         return ast.Constant(value=node.n)
+
     if isinstance(node, (ast.NameConstant, ast3.NameConstant)):
         return ast.Constant(value=node.value)
+
     return node
 
 
@@ -3688,7 +3745,7 @@ def assert_equivalent(src: str, dst: str) -> None:
 
         yield f"{'  ' * depth}{node.__class__.__name__}("
 
-        for field in sorted(node._fields):
+        for field in sorted(node._fields):  # noqa: F402
             # TypeIgnore has only one field 'lineno' which breaks this comparison
             type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore)
             if sys.version_info >= (3, 8):
@@ -3714,6 +3771,7 @@ def assert_equivalent(src: str, dst: str) -> None:
                     ):
                         for item in item.elts:
                             yield from _v(item, depth + 2)
+
                     elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)):
                         yield from _v(item, depth + 2)
 
@@ -3771,6 +3829,7 @@ def assert_stable(src: str, dst: str, mode: FileMode) -> None:
         ) from None
 
 
+@mypyc_attr(patchable=True)
 def dump_to_file(*output: str) -> str:
     """Dump `output` to a temporary file. Return path to the file."""
     with tempfile.NamedTemporaryFile(
@@ -3803,7 +3862,7 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     )
 
 
-def cancel(tasks: Iterable[asyncio.Task]) -> None:
+def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
     """asyncio signal handler that cancels all `tasks` and reports to stderr."""
     err("Aborted!")
     for task in tasks: