X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/811decd7f10fb2fb3ae343b9d9d0a3ae53b86a53..5446a92f0161e398de765bf9532d8c76c5652333:/src/black/__init__.py?ds=sidebyside

diff --git a/src/black/__init__.py b/src/black/__init__.py
index ffaafe2..a8f4f89 100644
--- a/src/black/__init__.py
+++ b/src/black/__init__.py
@@ -42,7 +42,6 @@ from typing import (
     cast,
     TYPE_CHECKING,
 )
-from typing_extensions import Final
 from mypy_extensions import mypyc_attr
 
 from appdirs import user_cache_dir
@@ -61,6 +60,11 @@ from blib2to3.pgen2.parse import ParseError
 
 from _black_version import version as __version__
 
+if sys.version_info < (3, 8):
+    from typing_extensions import Final
+else:
+    from typing import Final
+
 if TYPE_CHECKING:
     import colorama  # noqa: F401
 
@@ -68,6 +72,7 @@ DEFAULT_LINE_LENGTH = 88
 DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/"  # noqa: B950
 DEFAULT_INCLUDES = r"\.pyi?$"
 CACHE_DIR = Path(user_cache_dir("black", version=__version__))
+STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__"
 
 STRING_PREFIX_CHARS: Final = "furbFURB"  # All possible string prefix characters.
 
@@ -88,7 +93,7 @@ Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]]
 Timestamp = float
 FileSize = int
 CacheInfo = Tuple[Timestamp, FileSize]
-Cache = Dict[Path, CacheInfo]
+Cache = Dict[str, CacheInfo]
 out = partial(click.secho, bold=True, err=True)
 err = partial(click.secho, fg="red", err=True)
 
@@ -178,14 +183,12 @@ class TargetVersion(Enum):
     PY36 = 6
     PY37 = 7
     PY38 = 8
+    PY39 = 9
 
     def is_python2(self) -> bool:
         return self is TargetVersion.PY27
 
 
-PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38}
-
-
 class Feature(Enum):
     # All string literals are unicode
     UNICODE_LITERALS = 1
@@ -199,6 +202,7 @@ class Feature(Enum):
     ASYNC_KEYWORDS = 7
     ASSIGNMENT_EXPRESSIONS = 8
     POS_ONLY_ARGUMENTS = 9
+    RELAXED_DECORATORS = 10
     FORCE_OPTIONAL_PARENTHESES = 50
 
 
@@ -237,6 +241,17 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = {
         Feature.ASSIGNMENT_EXPRESSIONS,
         Feature.POS_ONLY_ARGUMENTS,
     },
+    TargetVersion.PY39: {
+        Feature.UNICODE_LITERALS,
+        Feature.F_STRINGS,
+        Feature.NUMERIC_UNDERSCORES,
+        Feature.TRAILING_COMMA_IN_CALL,
+        Feature.TRAILING_COMMA_IN_DEF,
+        Feature.ASYNC_KEYWORDS,
+        Feature.ASSIGNMENT_EXPRESSIONS,
+        Feature.RELAXED_DECORATORS,
+        Feature.POS_ONLY_ARGUMENTS,
+    },
 }
 
 
@@ -245,6 +260,7 @@ class Mode:
     target_versions: Set[TargetVersion] = field(default_factory=set)
     line_length: int = DEFAULT_LINE_LENGTH
     string_normalization: bool = True
+    magic_trailing_comma: bool = True
     experimental_string_processing: bool = False
     is_pyi: bool = False
 
@@ -347,6 +363,17 @@ def target_version_option_callback(
     return [TargetVersion[val.upper()] for val in v]
 
 
+def validate_regex(
+    ctx: click.Context, param: click.Parameter, value: Optional[str]
+) -> Optional[Pattern]:
+    """Click option callback: compile `value` to a regex; pass None through."""
+    try:
+        return re_compile_maybe_verbose(value) if value is not None else None
+    except re.error as exc:
+        # Chain the cause so the underlying re.error stays in the traceback.
+        raise click.BadParameter("Not a valid regular expression") from exc
+
+
 @click.command(context_settings=dict(help_option_names=["-h", "--help"]))
 @click.option("-c", "--code", type=str, help="Format the code passed in as a string.")
 @click.option(
@@ -382,6 +409,12 @@ def target_version_option_callback(
     is_flag=True,
     help="Don't normalize string quotes or prefixes.",
 )
+@click.option(
+    "-C",
+    "--skip-magic-trailing-comma",
+    is_flag=True,
+    help="Don't use trailing commas as a reason to split lines.",
+)
 @click.option(
     "--experimental-string-processing",
     is_flag=True,
@@ -419,6 +452,7 @@ def target_version_option_callback(
     "--include",
     type=str,
     default=DEFAULT_INCLUDES,
+    callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
         " included on recursive searches.  An empty value means all files are included"
@@ -431,6 +465,7 @@ def target_version_option_callback(
     "--exclude",
     type=str,
     default=DEFAULT_EXCLUDES,
+    callback=validate_regex,
     help=(
         "A regular expression that matches files and directories that should be"
         " excluded on recursive searches.  An empty value means no paths are excluded."
@@ -439,12 +474,31 @@ def target_version_option_callback(
     ),
     show_default=True,
 )
+@click.option(
+    "--extend-exclude",
+    type=str,
+    callback=validate_regex,
+    help=(
+        "Like --exclude, but adds additional files and directories on top of the"
+        " excluded ones. (Useful if you simply want to add to the default)"
+    ),
+)
 @click.option(
     "--force-exclude",
     type=str,
+    callback=validate_regex,
     help=(
         "Like --exclude, but files and directories matching this regex will be "
-        "excluded even when they are passed explicitly as arguments"
+        "excluded even when they are passed explicitly as arguments."
+    ),
+)
+@click.option(
+    "--stdin-filename",
+    type=str,
+    help=(
+        "The name of the file when passing it through stdin. Useful to make "
+        "sure Black will respect --force-exclude option on some "
+        "editors that rely on using stdin."
     ),
 )
 @click.option(
@@ -462,7 +516,7 @@ def target_version_option_callback(
     is_flag=True,
     help=(
         "Also emit messages to stderr about files that were not changed or were ignored"
-        " due to --exclude=."
+        " due to exclusion patterns."
     ),
 )
 @click.version_option(version=__version__)
@@ -500,12 +554,15 @@ def main(
     fast: bool,
     pyi: bool,
     skip_string_normalization: bool,
+    skip_magic_trailing_comma: bool,
     experimental_string_processing: bool,
     quiet: bool,
     verbose: bool,
-    include: str,
-    exclude: str,
-    force_exclude: Optional[str],
+    include: Pattern,
+    exclude: Pattern,
+    extend_exclude: Optional[Pattern],
+    force_exclude: Optional[Pattern],
+    stdin_filename: Optional[str],
     src: Tuple[str, ...],
     config: Optional[str],
 ) -> None:
@@ -521,6 +578,7 @@ def main(
         line_length=line_length,
         is_pyi=pyi,
         string_normalization=not skip_string_normalization,
+        magic_trailing_comma=not skip_magic_trailing_comma,
         experimental_string_processing=experimental_string_processing,
     )
     if config and verbose:
@@ -536,8 +594,10 @@ def main(
         verbose=verbose,
         include=include,
         exclude=exclude,
+        extend_exclude=extend_exclude,
         force_exclude=force_exclude,
         report=report,
+        stdin_filename=stdin_filename,
     )
 
     path_empty(
@@ -573,29 +633,14 @@ def get_sources(
     src: Tuple[str, ...],
     quiet: bool,
     verbose: bool,
-    include: str,
-    exclude: str,
-    force_exclude: Optional[str],
+    include: Pattern[str],
+    exclude: Pattern[str],
+    extend_exclude: Optional[Pattern[str]],
+    force_exclude: Optional[Pattern[str]],
     report: "Report",
+    stdin_filename: Optional[str],
 ) -> Set[Path]:
     """Compute the set of files to be formatted."""
-    try:
-        include_regex = re_compile_maybe_verbose(include)
-    except re.error:
-        err(f"Invalid regular expression for include given: {include!r}")
-        ctx.exit(2)
-    try:
-        exclude_regex = re_compile_maybe_verbose(exclude)
-    except re.error:
-        err(f"Invalid regular expression for exclude given: {exclude!r}")
-        ctx.exit(2)
-    try:
-        force_exclude_regex = (
-            re_compile_maybe_verbose(force_exclude) if force_exclude else None
-        )
-    except re.error:
-        err(f"Invalid regular expression for force_exclude given: {force_exclude!r}")
-        ctx.exit(2)
 
     root = find_project_root(src)
     sources: Set[Path] = set()
@@ -603,36 +648,46 @@ def get_sources(
     gitignore = get_gitignore(root)
 
     for s in src:
-        p = Path(s)
-        if p.is_dir():
-            sources.update(
-                gen_python_files(
-                    p.iterdir(),
-                    root,
-                    include_regex,
-                    exclude_regex,
-                    force_exclude_regex,
-                    report,
-                    gitignore,
-                )
-            )
-        elif s == "-":
-            sources.add(p)
-        elif p.is_file():
+        if s == "-" and stdin_filename:
+            p = Path(stdin_filename)
+            is_stdin = True
+        else:
+            p = Path(s)
+            is_stdin = False
+
+        if is_stdin or p.is_file():
             normalized_path = normalize_path_maybe_ignore(p, root, report)
             if normalized_path is None:
                 continue
 
             normalized_path = "/" + normalized_path
             # Hard-exclude any files that matches the `--force-exclude` regex.
-            if force_exclude_regex:
-                force_exclude_match = force_exclude_regex.search(normalized_path)
+            if force_exclude:
+                force_exclude_match = force_exclude.search(normalized_path)
             else:
                 force_exclude_match = None
             if force_exclude_match and force_exclude_match.group(0):
                 report.path_ignored(p, "matches the --force-exclude regular expression")
                 continue
 
+            if is_stdin:
+                p = Path(f"{STDIN_PLACEHOLDER}{str(p)}")
+
+            sources.add(p)
+        elif p.is_dir():
+            sources.update(
+                gen_python_files(
+                    p.iterdir(),
+                    root,
+                    include,
+                    exclude,
+                    extend_exclude,
+                    force_exclude,
+                    report,
+                    gitignore,
+                )
+            )
+        elif s == "-":
             sources.add(p)
         else:
             err(f"invalid path: {s}")
@@ -660,7 +715,18 @@ def reformat_one(
     """
     try:
         changed = Changed.NO
-        if not src.is_file() and str(src) == "-":
+
+        if str(src) == "-":
+            is_stdin = True
+        elif str(src).startswith(STDIN_PLACEHOLDER):
+            is_stdin = True
+            # Use the original name again in case we want to print something
+            # to the user
+            src = Path(str(src)[len(STDIN_PLACEHOLDER) :])
+        else:
+            is_stdin = False
+
+        if is_stdin:
             if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
                 changed = Changed.YES
         else:
@@ -668,7 +734,8 @@ def reformat_one(
             if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
                 cache = read_cache(mode)
                 res_src = src.resolve()
-                if res_src in cache and cache[res_src] == get_cache_info(res_src):
+                res_src_s = str(res_src)
+                if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
                     changed = Changed.CACHED
             if changed is not Changed.CACHED and format_file_in_place(
                 src, fast=fast, write_back=write_back, mode=mode
@@ -694,7 +761,7 @@ def reformat_many(
     worker_count = os.cpu_count()
     if sys.platform == "win32":
         # Work around https://bugs.python.org/issue26903
-        worker_count = min(worker_count, 61)
+        worker_count = min(worker_count, 60)
     try:
         executor = ProcessPoolExecutor(max_workers=worker_count)
     except (ImportError, OSError):
@@ -826,7 +893,7 @@ def format_file_in_place(
         dst_name = f"{src}\t{now} +0000"
         diff_contents = diff(src_contents, dst_contents, src_name, dst_name)
 
-        if write_back == write_back.COLOR_DIFF:
+        if write_back == WriteBack.COLOR_DIFF:
             diff_contents = color_diff(diff_contents)
 
         with lock or nullcontext():
@@ -849,9 +916,9 @@ def color_diff(contents: str) -> str:
     for i, line in enumerate(lines):
         if line.startswith("+++") or line.startswith("---"):
             line = "\033[1;37m" + line + "\033[0m"  # bold white, reset
-        if line.startswith("@@"):
+        elif line.startswith("@@"):
             line = "\033[36m" + line + "\033[0m"  # cyan, reset
-        if line.startswith("+"):
+        elif line.startswith("+"):
             line = "\033[32m" + line + "\033[0m"  # green, reset
         elif line.startswith("-"):
             line = "\033[31m" + line + "\033[0m"  # red, reset
@@ -861,30 +928,22 @@ def color_diff(contents: str) -> str:
 
 def wrap_stream_for_windows(
     f: io.TextIOWrapper,
-) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]:
+) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32"]:
     """
-    Wrap the stream in colorama's wrap_stream so colors are shown on Windows.
+    Wrap stream with colorama's wrap_stream so colors are shown on Windows.
 
-    If `colorama` is not found, then no change is made. If `colorama` does
-    exist, then it handles the logic to determine whether or not to change
-    things.
+    If `colorama` is unavailable, the original stream is returned unmodified.
+    Otherwise, the `wrap_stream()` function determines whether the stream needs
+    to be wrapped for a Windows environment and will accordingly either return
+    an `AnsiToWin32` wrapper or the original stream.
     """
     try:
-        from colorama import initialise
-
-        # We set `strip=False` so that we can don't have to modify
-        # test_express_diff_with_color.
-        f = initialise.wrap_stream(
-            f, convert=None, strip=False, autoreset=False, wrap=True
-        )
-
-        # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object
-        # which does not have a `detach()` method. So we fake one.
-        f.detach = lambda *args, **kwargs: None  # type: ignore
+        from colorama.initialise import wrap_stream
     except ImportError:
-        pass
-
-    return f
+        return f
+    else:
+        # Set `strip=False` to avoid needing to modify test_express_diff_with_color.
+        return wrap_stream(f, convert=None, strip=False, autoreset=False, wrap=True)
 
 
 def format_stdin_to_stdout(
@@ -951,7 +1010,7 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent:
     allowed.  Example:
 
     >>> import black
-    >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode()))
+    >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode()))
     def f(arg: str = "") -> None:
         ...
 
@@ -983,13 +1042,12 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent:
         versions = detect_target_versions(src_node)
     normalize_fmt_off(src_node)
     lines = LineGenerator(
+        mode=mode,
         remove_u_prefix="unicode_literals" in future_imports
         or supports_feature(versions, Feature.UNICODE_LITERALS),
-        is_pyi=mode.is_pyi,
-        normalize_strings=mode.string_normalization,
     )
     elt = EmptyLineTracker(is_pyi=mode.is_pyi)
-    empty_line = Line()
+    empty_line = Line(mode=mode)
     after = 0
     split_line_features = {
         feature
@@ -1425,13 +1483,15 @@ class BracketTracker:
 class Line:
     """Holds leaves and comments. Can be printed with `str(line)`."""
 
+    mode: Mode
     depth: int = 0
     leaves: List[Leaf] = field(default_factory=list)
     # keys ordered like `leaves`
     comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict)
     bracket_tracker: BracketTracker = field(default_factory=BracketTracker)
     inside_brackets: bool = False
-    should_explode: bool = False
+    should_split_rhs: bool = False
+    magic_trailing_comma: Optional[Leaf] = None
 
     def append(self, leaf: Leaf, preformatted: bool = False) -> None:
         """Add a new `leaf` to the end of the line.
@@ -1457,8 +1517,11 @@ class Line:
             )
         if self.inside_brackets or not preformatted:
             self.bracket_tracker.mark(leaf)
-            if self.maybe_should_explode(leaf):
-                self.should_explode = True
+            if self.mode.magic_trailing_comma:
+                if self.has_magic_trailing_comma(leaf):
+                    self.magic_trailing_comma = leaf
+            elif self.has_magic_trailing_comma(leaf, ensure_removable=True):
+                self.remove_trailing_comma()
         if not self.append_comment(leaf):
             self.leaves.append(leaf)
 
@@ -1634,10 +1697,14 @@ class Line:
     def contains_multiline_strings(self) -> bool:
         return any(is_multiline_string(leaf) for leaf in self.leaves)
 
-    def maybe_should_explode(self, closing: Leaf) -> bool:
-        """Return True if this line should explode (always be split), that is when:
-        - there's a trailing comma here; and
-        - it's not a one-tuple.
+    def has_magic_trailing_comma(
+        self, closing: Leaf, ensure_removable: bool = False
+    ) -> bool:
+        """Return True if we have a magic trailing comma, that is when:
+        - there's a trailing comma here
+        - it's not a one-tuple
+        Additionally, if ensure_removable:
+        - it's not from square bracket indexing
         """
         if not (
             closing.type in CLOSING_BRACKETS
@@ -1646,9 +1713,15 @@ class Line:
         ):
             return False
 
-        if closing.type in {token.RBRACE, token.RSQB}:
+        if closing.type == token.RBRACE:
             return True
 
+        if closing.type == token.RSQB:
+            if not ensure_removable:
+                return True
+            comma = self.leaves[-1]
+            return bool(comma.parent and comma.parent.type == syms.listmaker)
+
         if self.is_import:
             return True
 
@@ -1726,9 +1799,11 @@ class Line:
 
     def clone(self) -> "Line":
         return Line(
+            mode=self.mode,
             depth=self.depth,
             inside_brackets=self.inside_brackets,
-            should_explode=self.should_explode,
+            should_split_rhs=self.should_split_rhs,
+            magic_trailing_comma=self.magic_trailing_comma,
         )
 
     def __str__(self) -> str:
@@ -1884,10 +1959,9 @@ class LineGenerator(Visitor[Line]):
     in ways that will no longer stringify to valid Python code on the tree.
     """
 
-    is_pyi: bool = False
-    normalize_strings: bool = True
-    current_line: Line = field(default_factory=Line)
+    mode: Mode
     remove_u_prefix: bool = False
+    current_line: Line = field(init=False)
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
@@ -1902,7 +1976,7 @@ class LineGenerator(Visitor[Line]):
             return  # Line is empty, don't emit. Creating a new one unnecessary.
 
         complete_line = self.current_line
-        self.current_line = Line(depth=complete_line.depth + indent)
+        self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent)
         yield complete_line
 
     def visit_default(self, node: LN) -> Iterator[Line]:
@@ -1926,7 +2000,7 @@ class LineGenerator(Visitor[Line]):
                     yield from self.line()
 
             normalize_prefix(node, inside_brackets=any_open_brackets)
-            if self.normalize_strings and node.type == token.STRING:
+            if self.mode.string_normalization and node.type == token.STRING:
                 normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix)
                 normalize_string_quotes(node)
             if node.type == token.NUMBER:
@@ -1978,16 +2052,18 @@ class LineGenerator(Visitor[Line]):
 
     def visit_suite(self, node: Node) -> Iterator[Line]:
         """Visit a suite."""
-        if self.is_pyi and is_stub_suite(node):
+        if self.mode.is_pyi and is_stub_suite(node):
             yield from self.visit(node.children[2])
         else:
             yield from self.visit_default(node)
 
     def visit_simple_stmt(self, node: Node) -> Iterator[Line]:
         """Visit a statement without nested statements."""
+        if first_child_is_arith(node):
+            wrap_in_parentheses(node, node.children[0], visible=False)
         is_suite_like = node.parent and node.parent.type in STATEMENT
         if is_suite_like:
-            if self.is_pyi and is_stub_body(node):
+            if self.mode.is_pyi and is_stub_body(node):
                 yield from self.visit_default(node)
             else:
                 yield from self.line(+1)
@@ -1995,7 +2071,11 @@ class LineGenerator(Visitor[Line]):
                 yield from self.line(-1)
 
         else:
-            if not self.is_pyi or not node.parent or not is_stub_suite(node.parent):
+            if (
+                not self.mode.is_pyi
+                or not node.parent
+                or not is_stub_suite(node.parent)
+            ):
                 yield from self.line()
             yield from self.visit_default(node)
 
@@ -2071,6 +2151,8 @@ class LineGenerator(Visitor[Line]):
 
     def __post_init__(self) -> None:
         """You are in a twisty little maze of passages."""
+        self.current_line = Line(mode=self.mode)
+
         v = self.visit_stmt
         Ø: Set[str] = set()
         self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","})
@@ -2184,6 +2266,9 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str:  # noqa: C901
         ):
             # Python 2 print chevron
             return NO
+        elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator:
+            # no space in decorators
+            return NO
 
     elif prev.type in OPENING_BRACKETS:
         return NO
@@ -2510,6 +2595,8 @@ def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Pr
 
 
 FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"}
+FMT_SKIP = {"# fmt: skip", "# fmt:skip"}
+FMT_PASS = {*FMT_OFF, *FMT_SKIP}
 FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"}
 
 
@@ -2564,7 +2651,7 @@ def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]:
     consumed = 0
     nlines = 0
     ignored_lines = 0
-    for index, line in enumerate(prefix.split("\n")):
+    for index, line in enumerate(re.split("\r?\n", prefix)):
         consumed += len(line) + 1  # adding the length of the split '\n'
         line = line.lstrip()
         if not line:
@@ -2637,7 +2724,8 @@ def transform_line(
     transformers: List[Transformer]
     if (
         not line.contains_uncollapsable_type_comments()
-        and not line.should_explode
+        and not line.should_split_rhs
+        and not line.magic_trailing_comma
         and (
             is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
             or line.contains_unsplittable_type_ignore()
@@ -3233,7 +3321,8 @@ class StringParenStripper(StringTransformer):
 
     Requirements:
         The line contains a string which is surrounded by parentheses and:
-            - The target string is NOT the only argument to a function call).
+            - The target string is NOT the only argument to a function call.
+            - The target string is NOT a "pointless" string.
             - If the target string contains a PERCENT, the brackets are not
               preceeded or followed by an operator with higher precedence than
               PERCENT.
@@ -3257,6 +3346,14 @@ class StringParenStripper(StringTransformer):
             if leaf.type != token.STRING:
                 continue
 
+            # If this is a "pointless" string...
+            if (
+                leaf.parent
+                and leaf.parent.parent
+                and leaf.parent.parent.type == syms.simple_stmt
+            ):
+                continue
+
             # Should be preceded by a non-empty LPAR...
             if (
                 not is_valid_index(idx - 1)
@@ -3585,7 +3682,8 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         MIN_SUBSTR_SIZE characters.
 
         The string will ONLY be split on spaces (i.e. each new substring should
-        start with a space).
+        start with a space). Note that the string will NOT be split on a space
+        which is escaped with a backslash.
 
         If the string is an f-string, it will NOT be split in the middle of an
         f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x
@@ -3605,13 +3703,14 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
     MIN_SUBSTR_SIZE = 6
     # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
     RE_FEXPR = r"""
-    (?<!\{)\{
+    (?<!\{) (?:\{\{)* \{ (?!\{)
         (?:
             [^\{\}]
             | \{\{
             | \}\}
+            | (?R)
         )+?
-    (?<!\})(?:\}\})*\}(?!\})
+    (?<!\}) \} (?:\}\})* (?!\})
     """
 
     def do_splitter_match(self, line: Line) -> TMatchResult:
@@ -3925,11 +4024,23 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
                 section of this classes' docstring would be be met by returning @i.
             """
             is_space = string[i] == " "
+
+            is_not_escaped = True
+            j = i - 1
+            while is_valid_index(j) and string[j] == "\\":
+                is_not_escaped = not is_not_escaped
+                j -= 1
+
             is_big_enough = (
                 len(string[i:]) >= self.MIN_SUBSTR_SIZE
                 and len(string[:i]) >= self.MIN_SUBSTR_SIZE
             )
-            return is_space and is_big_enough and not breaks_fstring_expression(i)
+            return (
+                is_space
+                and is_not_escaped
+                and is_big_enough
+                and not breaks_fstring_expression(i)
+            )
 
         # First, we check all indices BELOW @max_break_idx.
         break_idx = max_break_idx
@@ -4285,9 +4396,11 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
         # `StringSplitter` will break it down further if necessary.
         string_value = LL[string_idx].value
         string_line = Line(
+            mode=line.mode,
             depth=line.depth + 1,
             inside_brackets=True,
-            should_explode=line.should_explode,
+            should_split_rhs=line.should_split_rhs,
+            magic_trailing_comma=line.magic_trailing_comma,
         )
         string_leaf = Leaf(token.STRING, string_value)
         insert_str_child(string_leaf)
@@ -4878,7 +4991,7 @@ def bracket_split_build_line(
     If `is_body` is True, the result line is one-indented inside brackets and as such
     has its first leaf's prefix normalized and a trailing comma added when expected.
     """
-    result = Line(depth=original.depth)
+    result = Line(mode=original.mode, depth=original.depth)
     if is_body:
         result.inside_brackets = True
         result.depth += 1
@@ -4908,8 +5021,8 @@ def bracket_split_build_line(
         result.append(leaf, preformatted=True)
         for comment_after in original.comments_after(leaf):
             result.append(comment_after, preformatted=True)
-    if is_body and should_split_body_explode(result, opening_bracket):
-        result.should_explode = True
+    if is_body and should_split_line(result, opening_bracket):
+        result.should_split_rhs = True
     return result
 
 
@@ -4950,7 +5063,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[
         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
             raise CannotSplit("Splitting a single attribute from its owner looks wrong")
 
-    current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+    current_line = Line(
+        mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+    )
     lowest_depth = sys.maxsize
     trailing_comma_safe = True
 
@@ -4962,7 +5077,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[
         except ValueError:
             yield current_line
 
-            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+            current_line = Line(
+                mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+            )
             current_line.append(leaf)
 
     for leaf in line.leaves:
@@ -4986,7 +5103,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[
         if leaf_priority == delimiter_priority:
             yield current_line
 
-            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+            current_line = Line(
+                mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+            )
     if current_line:
         if (
             trailing_comma_safe
@@ -5007,7 +5126,9 @@ def standalone_comment_split(
     if not line.contains_standalone_comments(0):
         raise CannotSplit("Line does not have any standalone comments")
 
-    current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+    current_line = Line(
+        mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+    )
 
     def append_to_line(leaf: Leaf) -> Iterator[Line]:
         """Append `leaf` to current line or to new line if appending impossible."""
@@ -5017,7 +5138,9 @@ def standalone_comment_split(
         except ValueError:
             yield current_line
 
-            current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
+            current_line = Line(
+                line.mode, depth=line.depth, inside_brackets=line.inside_brackets
+            )
             current_line.append(leaf)
 
     for leaf in line.leaves:
@@ -5173,31 +5296,52 @@ def normalize_numeric_literal(leaf: Leaf) -> None:
         # Leave octal and binary literals alone.
         pass
     elif text.startswith("0x"):
-        # Change hex literals to upper case.
-        before, after = text[:2], text[2:]
-        text = f"{before}{after.upper()}"
+        text = format_hex(text)
     elif "e" in text:
-        before, after = text.split("e")
-        sign = ""
-        if after.startswith("-"):
-            after = after[1:]
-            sign = "-"
-        elif after.startswith("+"):
-            after = after[1:]
-        before = format_float_or_int_string(before)
-        text = f"{before}e{sign}{after}"
+        text = format_scientific_notation(text)
     elif text.endswith(("j", "l")):
-        number = text[:-1]
-        suffix = text[-1]
-        # Capitalize in "2L" because "l" looks too similar to "1".
-        if suffix == "l":
-            suffix = "L"
-        text = f"{format_float_or_int_string(number)}{suffix}"
+        text = format_long_or_complex_number(text)
     else:
         text = format_float_or_int_string(text)
     leaf.value = text
 
 
+def format_hex(text: str) -> str:
+    """Return the hexadecimal literal `text` (like "0x12b3") in lowercase.
+
+    The first two characters (the "0x" prefix) are kept as they are; only
+    what follows is lowercased.  Lowercase is used because of the similarity
+    between "B" and "8", which can cause security issues.
+
+    See: https://github.com/psf/black/issues/1692
+    """
+    prefix = text[:2]
+    return prefix + text[2:].lower()
+
+
+def format_scientific_notation(text: str) -> str:
+    """Formats a numeric string utilizing scientific notation"""
+    mantissa, exponent = text.split("e")
+    if exponent.startswith("-"):
+        sign, exponent = "-", exponent[1:]
+    elif exponent.startswith("+"):
+        sign, exponent = "", exponent[1:]  # an explicit "+" sign is dropped
+    else:
+        sign = ""
+    mantissa = format_float_or_int_string(mantissa)
+    return f"{mantissa}e{sign}{exponent}"
+
+
+def format_long_or_complex_number(text: str) -> str:
+    """Format a literal ending in a one-character suffix, like `10L` or `10j`.
+
+    The suffix "l" becomes "L" because "l" looks too similar to "1".
+    """
+    digits, tail = text[:-1], text[-1]
+    marker = "L" if tail == "l" else tail
+    return f"{format_float_or_int_string(digits)}{marker}"
+
+
 def format_float_or_int_string(text: str) -> str:
     """Formats a float string like "1.0"."""
     if "." not in text:
@@ -5236,10 +5380,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
             check_lpar = True
 
         if check_lpar:
-            if is_walrus_assignment(child):
-                pass
-
-            elif child.type == syms.atom:
+            if child.type == syms.atom:
                 if maybe_make_parens_invisible_in_atom(child, parent=node):
                     wrap_in_parentheses(node, child, visible=False)
             elif is_one_tuple(child):
@@ -5278,58 +5419,80 @@ def convert_one_fmt_off_pair(node: Node) -> bool:
     for leaf in node.leaves():
         previous_consumed = 0
         for comment in list_comments(leaf.prefix, is_endmarker=False):
-            if comment.value in FMT_OFF:
-                # We only want standalone comments. If there's no previous leaf or
-                # the previous leaf is indentation, it's a standalone comment in
-                # disguise.
-                if comment.type != STANDALONE_COMMENT:
-                    prev = preceding_leaf(leaf)
-                    if prev and prev.type not in WHITESPACE:
+            if comment.value not in FMT_PASS:
+                previous_consumed = comment.consumed
+                continue
+            # We only want standalone comments. If there's no previous leaf or
+            # the previous leaf is indentation, it's a standalone comment in
+            # disguise.
+            if comment.value in FMT_PASS and comment.type != STANDALONE_COMMENT:
+                prev = preceding_leaf(leaf)
+                if prev:
+                    if comment.value in FMT_OFF and prev.type not in WHITESPACE:
+                        continue
+                    if comment.value in FMT_SKIP and prev.type in WHITESPACE:
                         continue
 
-                ignored_nodes = list(generate_ignored_nodes(leaf))
-                if not ignored_nodes:
-                    continue
-
-                first = ignored_nodes[0]  # Can be a container node with the `leaf`.
-                parent = first.parent
-                prefix = first.prefix
-                first.prefix = prefix[comment.consumed :]
-                hidden_value = (
-                    comment.value + "\n" + "".join(str(n) for n in ignored_nodes)
-                )
-                if hidden_value.endswith("\n"):
-                    # That happens when one of the `ignored_nodes` ended with a NEWLINE
-                    # leaf (possibly followed by a DEDENT).
-                    hidden_value = hidden_value[:-1]
-                first_idx: Optional[int] = None
-                for ignored in ignored_nodes:
-                    index = ignored.remove()
-                    if first_idx is None:
-                        first_idx = index
-                assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
-                assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
-                parent.insert_child(
-                    first_idx,
-                    Leaf(
-                        STANDALONE_COMMENT,
-                        hidden_value,
-                        prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
-                    ),
-                )
-                return True
+            ignored_nodes = list(generate_ignored_nodes(leaf, comment))
+            if not ignored_nodes:
+                continue
 
-            previous_consumed = comment.consumed
+            first = ignored_nodes[0]  # Can be a container node with the `leaf`.
+            parent = first.parent
+            prefix = first.prefix
+            first.prefix = prefix[comment.consumed :]
+            hidden_value = "".join(str(n) for n in ignored_nodes)
+            if comment.value in FMT_OFF:
+                hidden_value = comment.value + "\n" + hidden_value
+            if comment.value in FMT_SKIP:
+                hidden_value += "  " + comment.value
+            if hidden_value.endswith("\n"):
+                # That happens when one of the `ignored_nodes` ended with a NEWLINE
+                # leaf (possibly followed by a DEDENT).
+                hidden_value = hidden_value[:-1]
+            first_idx: Optional[int] = None
+            for ignored in ignored_nodes:
+                index = ignored.remove()
+                if first_idx is None:
+                    first_idx = index
+            assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)"
+            assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)"
+            parent.insert_child(
+                first_idx,
+                Leaf(
+                    STANDALONE_COMMENT,
+                    hidden_value,
+                    prefix=prefix[:previous_consumed] + "\n" * comment.newlines,
+                ),
+            )
+            return True
 
     return False
 
 
-def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]:
+def generate_ignored_nodes(leaf: Leaf, comment: ProtoComment) -> Iterator[LN]:
     """Starting from the container of `leaf`, generate all leaves until `# fmt: on`.
 
+    If comment is skip, yields `leaf`'s preceding siblings or its parent instead.
     Stops at the end of the block.
     """
     container: Optional[LN] = container_of(leaf)
+    if comment.value in FMT_SKIP:
+        prev_sibling = leaf.prev_sibling
+        if comment.value in leaf.prefix and prev_sibling is not None:
+            leaf.prefix = leaf.prefix.replace(comment.value, "")
+            siblings = [prev_sibling]
+            while (
+                "\n" not in prev_sibling.prefix
+                and prev_sibling.prev_sibling is not None
+            ):
+                prev_sibling = prev_sibling.prev_sibling
+                siblings.insert(0, prev_sibling)
+            for sibling in siblings:
+                yield sibling
+        elif leaf.parent is not None:
+            yield leaf.parent
+        return
     while container is not None and container.type != token.ENDMARKER:
         if is_fmt_on(container):
             return
@@ -5389,6 +5552,7 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
     Returns whether the node should itself be wrapped in invisible parentheses.
 
     """
+
     if (
         node.type != syms.atom
         or is_empty_tuple(node)
@@ -5398,6 +5562,10 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
     ):
         return False
 
+    if is_walrus_assignment(node):
+        if parent.type in [syms.annassign, syms.expr_stmt]:
+            return False
+
     first = node.children[0]
     last = node.children[-1]
     if first.type == token.LPAR and last.type == token.RPAR:
@@ -5459,6 +5627,17 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]:
     return wrapped
 
 
+def first_child_is_arith(node: Node) -> bool:
+    """Whether first child is an arithmetic or a binary arithmetic expression"""
+    expr_types = {
+        syms.arith_expr,
+        syms.shift_expr,
+        syms.xor_expr,
+        syms.and_expr,
+    }
+    return bool(node.children and node.children[0].type in expr_types)
+
+
 def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None:
     """Wrap `child` in parentheses.
 
@@ -5499,6 +5678,49 @@ def is_walrus_assignment(node: LN) -> bool:
     return inner is not None and inner.type == syms.namedexpr_test
 
 
+def is_simple_decorator_trailer(node: LN, last: bool = False) -> bool:
+    """Return True iff `node` is a trailer valid in a simple decorator"""
+    return node.type == syms.trailer and (
+        (
+            len(node.children) == 2
+            and node.children[0].type == token.DOT
+            and node.children[1].type == token.NAME
+        )
+        # last trailer can be arguments
+        or (
+            last
+            and len(node.children) == 3
+            and node.children[0].type == token.LPAR
+            # and node.children[1].type == syms.argument
+            and node.children[2].type == token.RPAR
+        )
+    )
+
+
+def is_simple_decorator_expression(node: LN) -> bool:
+    """Return True iff `node` could be a 'dotted name' decorator
+
+    This function takes the node of the 'namedexpr_test' of the new decorator
+    grammar and tests if it would be valid under the old decorator grammar.
+
+    The old grammar was: decorator: @ dotted_name [arguments] NEWLINE
+    The new grammar is : decorator: @ namedexpr_test NEWLINE
+    """
+    if node.type == token.NAME:
+        return True
+    if node.type == syms.power:
+        if node.children:
+            return (
+                node.children[0].type == token.NAME
+                and all(map(is_simple_decorator_trailer, node.children[1:-1]))
+                and (
+                    len(node.children) < 2
+                    or is_simple_decorator_trailer(node.children[-1], last=True)
+                )
+            )
+    return False
+
+
 def is_yield(node: LN) -> bool:
     """Return True if `node` holds a `yield` or `yield from` expression."""
     if node.type == syms.yield_expr:
@@ -5617,7 +5839,7 @@ def ensure_visible(leaf: Leaf) -> None:
         leaf.value = ")"
 
 
-def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool:
+def should_split_line(line: Line, opening_bracket: Leaf) -> bool:
     """Should `line` be immediately split with `delimiter_split()` after RHS?"""
 
     if not (opening_bracket.parent and opening_bracket.value in "[{("):
@@ -5638,7 +5860,7 @@ def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool:
         return False
 
     return max_priority == COMMA_PRIORITY and (
-        trailing_comma
+        (line.mode.magic_trailing_comma and trailing_comma)
         # always explode imports
         or opening_bracket.parent.type in {syms.atom, syms.import_from}
     )
@@ -5684,6 +5906,8 @@ def get_features_used(node: Node) -> Set[Feature]:
     - underscores in numeric literals;
     - trailing commas after * or ** in function signatures and calls;
     - positional only arguments in function signatures and lambdas;
+    - assignment expression;
+    - relaxed decorator syntax;
     """
     features: Set[Feature] = set()
     for n in node.pre_order():
@@ -5703,6 +5927,12 @@ def get_features_used(node: Node) -> Set[Feature]:
         elif n.type == token.COLONEQUAL:
             features.add(Feature.ASSIGNMENT_EXPRESSIONS)
 
+        elif n.type == syms.decorator:
+            if len(n.children) > 1 and not is_simple_decorator_expression(
+                n.children[1]
+            ):
+                features.add(Feature.RELAXED_DECORATORS)
+
         elif (
             n.type in {syms.typedargslist, syms.arglist}
             and n.children
@@ -5745,7 +5975,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
     """
 
     omit: Set[LeafID] = set()
-    if not line.should_explode:
+    if not line.magic_trailing_comma:
         yield omit
 
     length = 4 * line.depth
@@ -5767,8 +5997,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
             elif leaf.type in CLOSING_BRACKETS:
                 prev = line.leaves[index - 1] if index > 0 else None
                 if (
-                    line.should_explode
-                    and prev
+                    prev
                     and prev.type == token.COMMA
                     and not is_one_tuple_between(
                         leaf.opening_bracket, leaf, line.leaves
@@ -5795,8 +6024,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf
                 yield omit
 
             if (
-                line.should_explode
-                and prev
+                prev
                 and prev.type == token.COMMA
                 and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves)
             ):
@@ -5894,17 +6122,27 @@ def normalize_path_maybe_ignore(
     return normalized_path
 
 
+def path_is_excluded(
+    normalized_path: str,
+    pattern: Optional[Pattern[str]],
+) -> bool:
+    match = pattern.search(normalized_path) if pattern else None
+    return bool(match and match.group(0))
+
+
 def gen_python_files(
     paths: Iterable[Path],
     root: Path,
     include: Optional[Pattern[str]],
     exclude: Pattern[str],
+    extend_exclude: Optional[Pattern[str]],
     force_exclude: Optional[Pattern[str]],
     report: "Report",
     gitignore: PathSpec,
 ) -> Iterator[Path]:
     """Generate all files under `path` whose paths are not excluded by the
-    `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex.
+    `exclude`, `extend_exclude`, or `force_exclude` regexes,
+    but are included by the `include` regex.
 
     Symbolic links pointing outside of the `root` directory are ignored.
 
@@ -5921,20 +6159,22 @@ def gen_python_files(
             report.path_ignored(child, "matches the .gitignore file content")
             continue
 
-        # Then ignore with `--exclude` and `--force-exclude` options.
+        # Then ignore with `--exclude`, `--extend-exclude`, `--force-exclude` options.
         normalized_path = "/" + normalized_path
         if child.is_dir():
             normalized_path += "/"
 
-        exclude_match = exclude.search(normalized_path) if exclude else None
-        if exclude_match and exclude_match.group(0):
+        if path_is_excluded(normalized_path, exclude):
             report.path_ignored(child, "matches the --exclude regular expression")
             continue
 
-        force_exclude_match = (
-            force_exclude.search(normalized_path) if force_exclude else None
-        )
-        if force_exclude_match and force_exclude_match.group(0):
+        if path_is_excluded(normalized_path, extend_exclude):
+            report.path_ignored(
+                child, "matches the --extend-exclude regular expression"
+            )
+            continue
+
+        if path_is_excluded(normalized_path, force_exclude):
             report.path_ignored(child, "matches the --force-exclude regular expression")
             continue
 
@@ -5944,6 +6184,7 @@ def gen_python_files(
                 root,
                 include,
                 exclude,
+                extend_exclude,
                 force_exclude,
                 report,
                 gitignore,
@@ -6222,14 +6463,14 @@ def assert_stable(src: str, dst: str, mode: Mode) -> None:
 
 
 @mypyc_attr(patchable=True)
-def dump_to_file(*output: str) -> str:
+def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str:
     """Dump `output` to a temporary file. Return path to the file."""
     with tempfile.NamedTemporaryFile(
         mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8"
     ) as f:
         for lines in output:
             f.write(lines)
-            if lines and lines[-1] != "\n":
+            if ensure_final_newline and lines and lines[-1] != "\n":
                 f.write("\n")
     return f.name
 
@@ -6247,11 +6488,20 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str:
     """Return a unified diff string between strings `a` and `b`."""
     import difflib
 
-    a_lines = [line + "\n" for line in a.splitlines()]
-    b_lines = [line + "\n" for line in b.splitlines()]
-    return "".join(
-        difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5)
-    )
+    a_lines = [line for line in a.splitlines(keepends=True)]
+    b_lines = [line for line in b.splitlines(keepends=True)]
+    diff_lines = []
+    for line in difflib.unified_diff(
+        a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5
+    ):
+        # Work around https://bugs.python.org/issue2142
+        # See https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html
+        if line[-1] == "\n":
+            diff_lines.append(line)
+        else:
+            diff_lines.append(line + "\n")
+            diff_lines.append("\\ No newline at end of file\n")
+    return "".join(diff_lines)
 
 
 def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None:
@@ -6428,7 +6678,7 @@ def can_omit_invisible_parens(
 
     penultimate = line.leaves[-2]
     last = line.leaves[-1]
-    if line.should_explode:
+    if line.magic_trailing_comma:
         try:
             penultimate, last = last_two_except(line.leaves, omit=omit_on_explode)
         except LookupError:
@@ -6455,7 +6705,7 @@ def can_omit_invisible_parens(
             # unnecessary.
             return True
 
-        if line.should_explode and penultimate.type == token.COMMA:
+        if line.magic_trailing_comma and penultimate.type == token.COMMA:
             # The rightmost non-omitted bracket pair is the one we want to explode on.
             return True
 
@@ -6605,8 +6855,8 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set
     """
     todo, done = set(), set()
     for src in sources:
-        src = src.resolve()
-        if cache.get(src) != get_cache_info(src):
+        res_src = src.resolve()
+        if cache.get(str(res_src)) != get_cache_info(res_src):
             todo.add(src)
         else:
             done.add(src)
@@ -6618,7 +6868,10 @@ def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
     cache_file = get_cache_file(mode)
     try:
         CACHE_DIR.mkdir(parents=True, exist_ok=True)
-        new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}}
+        new_cache = {
+            **cache,
+            **{str(src.resolve()): get_cache_info(src) for src in sources},
+        }
         with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
             pickle.dump(new_cache, f, protocol=4)
         os.replace(f.name, cache_file)
@@ -6674,13 +6927,33 @@ def is_docstring(leaf: Leaf) -> bool:
     return False
 
 
+def lines_with_leading_tabs_expanded(s: str) -> List[str]:
+    """
+    Splits string into lines and expands only leading tabs (following the normal
+    Python rules)
+    """
+    lines = []
+    for line in s.splitlines():
+        # Find the index of the first non-whitespace character after a string of
+        # whitespace that includes at least one tab
+        match = re.match(r"\s*\t+\s*(\S)", line)
+        if match:
+            first_non_whitespace_idx = match.start(1)
+
+            lines.append(
+                line[:first_non_whitespace_idx].expandtabs()
+                + line[first_non_whitespace_idx:]
+            )
+        else:
+            lines.append(line)
+    return lines
+
+
 def fix_docstring(docstring: str, prefix: str) -> str:
     # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
     if not docstring:
         return ""
-    # Convert tabs to spaces (following the normal Python rules)
-    # and split into a list of lines:
-    lines = docstring.expandtabs().splitlines()
+    lines = lines_with_leading_tabs_expanded(docstring)
     # Determine minimum indentation (first line doesn't count):
     indent = sys.maxsize
     for line in lines[1:]: