git.madduck.net Git - etc/vim.git/blobdiff - src/black/__init__.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Document some culprits with pre-commit (#1783)
[etc/vim.git] / src / black / __init__.py
index e37caa98a2c00386afd2acd52a4174e7d556d82c..24e9d4edaaaeb5fa8792e9f0c21f0e9e863d1047 100644 (file)
@@ -112,6 +112,10 @@ class InvalidInput(ValueError):
     """Raised when input source code fails all parse attempts."""
 
 
+class BracketMatchError(KeyError):
+    """Raised when an opening bracket is unable to be matched to a closing bracket."""
+
+
 T = TypeVar("T")
 E = TypeVar("E", bound=Exception)
 
@@ -174,14 +178,12 @@ class TargetVersion(Enum):
     PY36 = 6
     PY37 = 7
     PY38 = 8
+    PY39 = 9
 
     def is_python2(self) -> bool:
         return self is TargetVersion.PY27
 
 
-PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
-
-
 class Feature(Enum):
     # All string literals are unicode
     UNICODE_LITERALS = 1
@@ -195,6 +197,7 @@ class Feature(Enum):
     ASYNC_KEYWORDS = 7
     ASSIGNMENT_EXPRESSIONS = 8
     POS_ONLY_ARGUMENTS = 9
+    RELAXED_DECORATORS = 10
     FORCE_OPTIONAL_PARENTHESES = 50
 
 
@@ -233,6 +236,17 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
         Feature.ASSIGNMENT_EXPRESSIONS,
         Feature.POS_ONLY_ARGUMENTS,
     },
+    TargetVersion.PY39: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
+        Feature.ASYNC_KEYWORDS,
+        Feature.ASSIGNMENT_EXPRESSIONS,
+        Feature.RELAXED_DECORATORS,
+        Feature.POS_ONLY_ARGUMENTS,
+    },
 }
 
 
@@ -440,7 +454,7 @@ def target_version_option_callback(
     type=str,
     help=(
         "Like --exclude, but files and directories matching this regex will be "
-        "excluded even when they are passed explicitly as arguments"
+        "excluded even when they are passed explicitly as arguments."
     ),
 )
 @click.option(
@@ -641,10 +655,9 @@ def path_empty(
     """
     Exit if there is no `src` provided for formatting
     """
-    if len(src) == 0:
-        if verbose or not quiet:
-            out(msg)
-            ctx.exit(0)
+    if not src and (verbose or not quiet):
+        out(msg)
+        ctx.exit(0)
 
 
 def reformat_one(
@@ -662,7 +675,7 @@ def reformat_one(
                 changed = Changed.YES
         else:
             cache: Cache = {}
-            if write_back != WriteBack.DIFF:
+            if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                 cache = read_cache(mode)
                 res_src = src.resolve()
                 if res_src in cache and cache[res_src] == get_cache_info(res_src):
@@ -736,7 +749,7 @@ async def schedule_formatting(
     :func:`format_file_in_place`.
     """
     cache: Cache = {}
-    if write_back != WriteBack.DIFF:
+    if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
         cache = read_cache(mode)
         sources, cached = filter_cached(cache, sources)
         for src in sorted(cached):
@@ -747,7 +760,7 @@ async def schedule_formatting(
     cancelled = []
     sources_to_cache = []
     lock = None
-    if write_back == WriteBack.DIFF:
+    if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
         # For diff output, we need locks to ensure we don't interleave output
         # from different processes.
         manager = Manager()
@@ -846,9 +859,9 @@ def color_diff(contents: str) -> str:
     for i, line in enumerate(lines):
         if line.startswith("+++") or line.startswith("---"):
             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
-        if line.startswith("@@"):
+        elif line.startswith("@@"):
             line = "\033[36m" + line + "\033[0m"  # cyan, reset
-        if line.startswith("+"):
+        elif line.startswith("+"):
             line = "\033[32m" + line + "\033[0m"  # green, reset
         elif line.startswith("-"):
             line = "\033[31m" + line + "\033[0m"  # red, reset
@@ -858,30 +871,22 @@ def color_diff(contents: str) -> str:
 
 def wrap_stream_for_windows(
     f: io.TextIOWrapper,
-) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]:
+) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32"]:
     """
-    Wrap the stream in colorama's wrap_stream so colors are shown on Windows.
+    Wrap stream with colorama's wrap_stream so colors are shown on Windows.
 
-    If `colorama` is not found, then no change is made. If `colorama` does
-    exist, then it handles the logic to determine whether or not to change
-    things.
+    If `colorama` is unavailable, the original stream is returned unmodified.
+    Otherwise, the `wrap_stream()` function determines whether the stream needs
+    to be wrapped for a Windows environment and will accordingly either return
+    an `AnsiToWin32` wrapper or the original stream.
     """
     try:
-        from colorama import initialise
-
-        # We set `strip=False` so that we can don't have to modify
-        # test_express_diff_with_color.
-        f = initialise.wrap_stream(
-            f, convert=None, strip=False, autoreset=False, wrap=True
-        )
-
-        # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object
-        # which does not have a `detach()` method. So we fake one.
-        f.detach = lambda *args, **kwargs: None  # type: ignore
+        from colorama.initialise import wrap_stream
     except ImportError:
-        pass
-
-    return f
+        return f
+    else:
+        # Set `strip=False` to avoid needing to modify test_express_diff_with_color.
+        return wrap_stream(f, convert=None, strip=False, autoreset=False, wrap=True)
 
 
 def format_stdin_to_stdout(
@@ -922,13 +927,13 @@ def format_stdin_to_stdout(
 
 
 def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
-    """Reformat contents a file and return new contents.
+    """Reformat contents of a file and return new contents.
 
     If `fast` is False, additionally confirm that the reformatted code is
     valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it.
     `mode` is passed to :func:`format_str`.
     """
-    if src_contents.strip() == "":
+    if not src_contents.strip():
         raise NothingChanged
 
     dst_contents = format_str(src_contents, mode=mode)
@@ -1062,7 +1067,7 @@ def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]:
 
 def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node:
     """Given a string with source, return the lib2to3 Node."""
-    if src_txt[-1:] != "\n":
+    if not src_txt.endswith("\n"):
         src_txt += "\n"
 
     for grammar in get_grammars(set(target_versions)):
@@ -1309,7 +1314,13 @@ class BracketTracker:
         self.maybe_decrement_after_lambda_arguments(leaf)
         if leaf.type in CLOSING_BRACKETS:
             self.depth -= 1
-            opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
+            try:
+                opening_bracket = self.bracket_match.pop((self.depth, leaf.type))
+            except KeyError as e:
+                raise BracketMatchError(
+                    "Unable to match a closing bracket to the following opening"
+                    f" bracket: {leaf}"
+                ) from e
             leaf.opening_bracket = opening_bracket
             if not leaf.value:
                 self.invisible.append(leaf)
@@ -1627,14 +1638,13 @@ class Line:
 
     def maybe_should_explode(self, closing: Leaf) -> bool:
         """Return True if this line should explode (always be split), that is when:
-        - there's a pre-existing trailing comma here; and
+        - there's a trailing comma here; and
         - it's not a one-tuple.
         """
         if not (
             closing.type in CLOSING_BRACKETS
             and self.leaves
             and self.leaves[-1].type == token.COMMA
-            and not self.leaves[-1].was_checked  # pre-existing
         ):
             return False
 
@@ -1826,6 +1836,10 @@ class EmptyLineTracker:
             return 0, 0
 
         if self.previous_line.is_decorator:
+            if self.is_pyi and current_line.is_stub_class:
+                # Insert an empty line after a decorated stub class
+                return 0, 1
+
             return 0, 0
 
         if self.previous_line.depth < current_line.depth and (
@@ -1849,8 +1863,11 @@ class EmptyLineTracker:
                     newlines = 0
                 else:
                     newlines = 1
-            elif current_line.is_def and not self.previous_line.is_def:
-                # Blank line between a block of functions and a block of non-functions
+            elif (
+                current_line.is_def or current_line.is_decorator
+            ) and not self.previous_line.is_def:
+                # Blank line between a block of functions (maybe with preceding
+                # decorators) and a block of non-functions
                 newlines = 1
             else:
                 newlines = 0
@@ -2037,14 +2054,20 @@ class LineGenerator(Visitor[Line]):
         yield from self.visit_default(node)
 
     def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:
-        # Check if it's a docstring
-        if prev_siblings_are(
-            leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
-        ) and is_multiline_string(leaf):
-            prefix = "    " * self.current_line.depth
-            docstring = fix_docstring(leaf.value[3:-3], prefix)
-            leaf.value = leaf.value[0:3] + docstring + leaf.value[-3:]
-            normalize_string_quotes(leaf)
+        if is_docstring(leaf) and "\\\n" not in leaf.value:
+            # We're ignoring docstrings with backslash newline escapes because changing
+            # indentation of those changes the AST representation of the code.
+            prefix = get_string_prefix(leaf.value)
+            lead_len = len(prefix) + 3
+            tail_len = -3
+            indent = " " * 4 * self.current_line.depth
+            docstring = fix_docstring(leaf.value[lead_len:tail_len], indent)
+            if docstring:
+                if leaf.value[lead_len - 1] == docstring[0]:
+                    docstring = " " + docstring
+                if leaf.value[tail_len + 1] == docstring[-1]:
+                    docstring = docstring + " "
+            leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:]
 
         yield from self.visit_default(leaf)
 
@@ -2163,6 +2186,9 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
         ):
             # Python 2 print chevron
             return NO
+        elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:
+            # no space in decorators
+            return NO
 
     elif prev.type in OPENING_BRACKETS:
         return NO
@@ -2654,7 +2680,7 @@ def transform_line(
             # All splits failed, best effort split with no omits.
             # This mostly happens to multiline strings that are by definition
             # reported as not fitting a single line, as well as lines that contain
-            # pre-existing trailing commas (those have to be exploded).
+            # trailing commas (those have to be exploded).
             yield from right_hand_split(
                 line, line_length=mode.line_length, features=features
             )
@@ -2664,9 +2690,9 @@ def transform_line(
                 transformers = [
                     string_merge,
                     string_paren_strip,
+                    string_split,
                     delimiter_split,
                     standalone_comment_split,
-                    string_split,
                     string_paren_wrap,
                     rhs,
                 ]
@@ -2885,11 +2911,8 @@ class StringMerger(CustomSplitMapMixin, StringTransformer):
     """StringTransformer that merges strings together.
 
     Requirements:
-        (A) The line contains adjacent strings such that at most one substring
-        has inline comments AND none of those inline comments are pragmas AND
-        the set of all substring prefixes is either of length 1 or equal to
-        {"", "f"} AND none of the substrings are raw strings (i.e. are prefixed
-        with 'r').
+        (A) The line contains adjacent strings such that ALL of the validation checks
+        listed in StringMerger.__validate_msg(...)'s docstring pass.
             OR
         (B) The line contains a string which uses line continuation backslashes.
 
@@ -3138,6 +3161,7 @@ class StringMerger(CustomSplitMapMixin, StringTransformer):
             * Ok(None), if ALL validation checks (listed below) pass.
                 OR
             * Err(CannotTransform), if any of the following are true:
+                - The target string group does not contain ANY stand-alone comments.
                 - The target string is not in a string group (i.e. it has no
                   adjacent strings).
                 - The string group has more than one inline comment.
@@ -3146,6 +3170,26 @@ class StringMerger(CustomSplitMapMixin, StringTransformer):
                   length greater than one and is not equal to {"", "f"}.
                 - The string group consists of raw strings.
         """
+        # We first check for "inner" stand-alone comments (i.e. stand-alone
+        # comments that have a string leaf before them AND after them).
+        for inc in [1, -1]:
+            i = string_idx
+            found_sa_comment = False
+            is_valid_index = is_valid_index_factory(line.leaves)
+            while is_valid_index(i) and line.leaves[i].type in [
+                token.STRING,
+                STANDALONE_COMMENT,
+            ]:
+                if line.leaves[i].type == STANDALONE_COMMENT:
+                    found_sa_comment = True
+                elif found_sa_comment:
+                    return TErr(
+                        "StringMerger does NOT merge string groups which contain "
+                        "stand-alone comments."
+                    )
+
+                i += inc
+
         num_of_inline_string_comments = 0
         set_of_prefixes = set()
         num_of_strings = 0
@@ -3302,10 +3346,17 @@ class StringParenStripper(StringTransformer):
                 yield TErr(
                     "Will not strip parentheses which have comments attached to them."
                 )
+                return
 
         new_line = line.clone()
         new_line.comments = line.comments.copy()
-        append_leaves(new_line, line, LL[: string_idx - 1])
+        try:
+            append_leaves(new_line, line, LL[: string_idx - 1])
+        except BracketMatchError:
+            # HACK: I believe there is currently a bug somewhere in
+            # right_hand_split() that is causing brackets to not be tracked
+            # properly by a shared BracketTracker.
+            append_leaves(new_line, line, LL[: string_idx - 1], preformatted=True)
 
         string_leaf = Leaf(token.STRING, LL[string_idx].value)
         LL[string_idx - 1].remove()
@@ -3468,9 +3519,12 @@ class BaseStringSplitter(StringTransformer):
                 # WMA4 a single space.
                 offset += 1
 
-                # WMA4 the lengths of any leaves that came before that space.
-                for leaf in LL[: p_idx + 1]:
+                # WMA4 the lengths of any leaves that came before that space,
+                # but after any closing bracket before that space.
+                for leaf in reversed(LL[: p_idx + 1]):
                     offset += len(str(leaf))
+                    if leaf.type in CLOSING_BRACKETS:
+                        break
 
         if is_valid_index(string_idx + 1):
             N = LL[string_idx + 1]
@@ -3986,12 +4040,13 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
     def do_splitter_match(self, line: Line) -> TMatchResult:
         LL = line.leaves
 
-        string_idx = None
-        string_idx = string_idx or self._return_match(LL)
-        string_idx = string_idx or self._else_match(LL)
-        string_idx = string_idx or self._assert_match(LL)
-        string_idx = string_idx or self._assign_match(LL)
-        string_idx = string_idx or self._dict_match(LL)
+        string_idx = (
+            self._return_match(LL)
+            or self._else_match(LL)
+            or self._assert_match(LL)
+            or self._assign_match(LL)
+            or self._dict_match(LL)
+        )
 
         if string_idx is not None:
             string_value = line.leaves[string_idx].value
@@ -4190,7 +4245,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
         is_valid_index = is_valid_index_factory(LL)
         insert_str_child = insert_str_child_factory(LL[string_idx])
 
-        comma_idx = len(LL) - 1
+        comma_idx = -1
         ends_with_comma = False
         if LL[comma_idx].type == token.COMMA:
             ends_with_comma = True
@@ -4575,7 +4630,9 @@ def line_to_string(line: Line) -> str:
     return str(line).strip("\n")
 
 
-def append_leaves(new_line: Line, old_line: Line, leaves: List[Leaf]) -> None:
+def append_leaves(
+    new_line: Line, old_line: Line, leaves: List[Leaf], preformatted: bool = False
+) -> None:
     """
     Append leaves (taken from @old_line) to @new_line, making sure to fix the
     underlying Node structure where appropriate.
@@ -4591,7 +4648,7 @@ def append_leaves(new_line: Line, old_line: Line, leaves: List[Leaf]) -> None:
     for old_leaf in leaves:
         new_leaf = Leaf(old_leaf.type, old_leaf.value)
         replace_child(old_leaf, new_leaf)
-        new_line.append(new_leaf)
+        new_line.append(new_leaf, preformatted=preformatted)
 
         for comment_leaf in old_line.comments_after(old_leaf):
             new_line.append(comment_leaf, preformatted=True)
@@ -4848,7 +4905,6 @@ def bracket_split_build_line(
 
                     if leaves[i].type != token.COMMA:
                         new_comma = Leaf(token.COMMA, ",")
-                        new_comma.was_checked = True
                         leaves.insert(i + 1, new_comma)
                     break
 
@@ -4944,7 +5000,6 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[
             and current_line.leaves[-1].type != STANDALONE_COMMENT
         ):
             new_comma = Leaf(token.COMMA, ",")
-            new_comma.was_checked = True
             current_line.append(new_comma)
         yield current_line
 
@@ -5187,9 +5242,9 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
 
         if check_lpar:
             if is_walrus_assignment(child):
-                continue
+                pass
 
-            if child.type == syms.atom:
+            elif child.type == syms.atom:
                 if maybe_make_parens_invisible_in_atom(child, parent=node):
                     wrap_in_parentheses(node, child, visible=False)
             elif is_one_tuple(child):
@@ -5449,6 +5504,49 @@ def is_walrus_assignment(node: LN) -> bool:
     return inner is not None and inner.type == syms.namedexpr_test
 
 
+def is_simple_decorator_trailer(node: LN, last: bool = False) -> bool:
+    """Return True iff `node` is a trailer valid in a simple decorator"""
+    return node.type == syms.trailer and (
+        (
+            len(node.children) == 2
+            and node.children[0].type == token.DOT
+            and node.children[1].type == token.NAME
+        )
+        # last trailer can be arguments
+        or (
+            last
+            and len(node.children) == 3
+            and node.children[0].type == token.LPAR
+            # and node.children[1].type == syms.argument
+            and node.children[2].type == token.RPAR
+        )
+    )
+
+
+def is_simple_decorator_expression(node: LN) -> bool:
+    """Return True iff `node` could be a 'dotted name' decorator
+
+    This function takes the node of the 'namedexpr_test' of the new decorator
+    grammar and test if it would be valid under the old decorator grammar.
+
+    The old grammar was: decorator: @ dotted_name [arguments] NEWLINE
+    The new grammar is : decorator: @ namedexpr_test NEWLINE
+    """
+    if node.type == token.NAME:
+        return True
+    if node.type == syms.power:
+        if node.children:
+            return (
+                node.children[0].type == token.NAME
+                and all(map(is_simple_decorator_trailer, node.children[1:-1]))
+                and (
+                    len(node.children) < 2
+                    or is_simple_decorator_trailer(node.children[-1], last=True)
+                )
+            )
+    return False
+
+
 def is_yield(node: LN) -> bool:
     """Return True if `node` holds a `yield` or `yield from` expression."""
     if node.type == syms.yield_expr:
@@ -5577,20 +5675,20 @@ def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool:
     # than one of them (we're excluding the trailing comma and if the delimiter priority
     # is still commas, that means there's more).
     exclude = set()
-    pre_existing_trailing_comma = False
+    trailing_comma = False
     try:
         last_leaf = line.leaves[-1]
         if last_leaf.type == token.COMMA:
-            pre_existing_trailing_comma = not last_leaf.was_checked
+            trailing_comma = True
             exclude.add(id(last_leaf))
         max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude)
     except (IndexError, ValueError):
         return False
 
     return max_priority == COMMA_PRIORITY and (
+        trailing_comma
         # always explode imports
-        opening_bracket.parent.type in {syms.atom, syms.import_from}
-        or pre_existing_trailing_comma
+        or opening_bracket.parent.type in {syms.atom, syms.import_from}
     )
 
 
@@ -5634,6 +5732,8 @@ def get_features_used(node: Node) -> Set[Feature]:
     - underscores in numeric literals;
     - trailing commas after * or ** in function signatures and calls;
     - positional only arguments in function signatures and lambdas;
+    - assignment expression;
+    - relaxed decorator syntax;
     """
     features: Set[Feature] = set()
     for n in node.pre_order():
@@ -5653,6 +5753,12 @@ def get_features_used(node: Node) -> Set[Feature]:
         elif n.type == token.COLONEQUAL:
             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
 
+        elif n.type == syms.decorator:
+            if len(n.children) > 1 and not is_simple_decorator_expression(
+                n.children[1]
+            ):
+                features.add(Feature.RELAXED_DECORATORS)
+
         elif (
             n.type in {syms.typedargslist, syms.arglist}
             and n.children
@@ -5720,12 +5826,11 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
                     line.should_explode
                     and prev
                     and prev.type == token.COMMA
-                    and not prev.was_checked
                     and not is_one_tuple_between(
                         leaf.opening_bracket, leaf, line.leaves
                     )
                 ):
-                    # Never omit bracket pairs with pre-existing trailing commas.
+                    # Never omit bracket pairs with trailing commas.
                     # We need to explode on those.
                     break
 
@@ -5749,10 +5854,9 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
                 line.should_explode
                 and prev
                 and prev.type == token.COMMA
-                and not prev.was_checked
                 and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
             ):
-                # Never omit bracket pairs with pre-existing trailing commas.
+                # Never omit bracket pairs with trailing commas.
                 # We need to explode on those.
                 break
 
@@ -5830,7 +5934,8 @@ def normalize_path_maybe_ignore(
     `report` is where "path ignored" output goes.
     """
     try:
-        normalized_path = path.resolve().relative_to(root).as_posix()
+        abspath = path if path.is_absolute() else Path.cwd() / path
+        normalized_path = abspath.resolve().relative_to(root).as_posix()
     except OSError as e:
         report.path_ignored(path, f"cannot be read because {e}")
         return None
@@ -6161,6 +6266,7 @@ def assert_stable(src: str, dst: str, mode: Mode) -> None:
     newdst = format_str(dst, mode=mode)
     if dst != newdst:
         log = dump_to_file(
+            str(mode),
             diff(src, dst, "source", "first pass"),
             diff(dst, newdst, "first pass", "second pass"),
         )
@@ -6405,11 +6511,7 @@ def can_omit_invisible_parens(
             # unnecessary.
             return True
 
-        if (
-            line.should_explode
-            and penultimate.type == token.COMMA
-            and not penultimate.was_checked
-        ):
+        if line.should_explode and penultimate.type == token.COMMA:
             # The rightmost non-omitted bracket pair is the one we want to explode on.
             return True
 
@@ -6608,6 +6710,26 @@ def patched_main() -> None:
     main()
 
 
+def is_docstring(leaf: Leaf) -> bool:
+    if not is_multiline_string(leaf):
+        # For the purposes of docstring re-indentation, we don't need to do anything
+        # with single-line docstrings.
+        return False
+
+    if prev_siblings_are(
+        leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt]
+    ):
+        return True
+
+    # Multiline docstring on the same line as the `def`.
+    if prev_siblings_are(leaf.parent, [syms.parameters, token.COLON, syms.simple_stmt]):
+        # `syms.parameters` is only used in funcdefs and async_funcdefs in the Python
+        # grammar. We're safe to return True without further checks.
+        return True
+
+    return False
+
+
 def fix_docstring(docstring: str, prefix: str) -> str:
     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
     if not docstring:
@@ -6631,7 +6753,6 @@ def fix_docstring(docstring: str, prefix: str) -> str:
                 trimmed.append(prefix + stripped_line)
             else:
                 trimmed.append("")
-    # Return a single string:
     return "\n".join(trimmed)