X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/6ebdc5a644b5e53b83a99eb93d63c91d0751da16..5d978df8b54d4fb102142508353a04932a4b5b04:/src/black/__init__.py?ds=inline diff --git a/src/black/__init__.py b/src/black/__init__.py index 2b2d3d8..24e9d4e 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -65,7 +65,7 @@ if TYPE_CHECKING: import colorama # noqa: F401 DEFAULT_LINE_LENGTH = 88 -DEFAULT_EXCLUDES = r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950 +DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950 DEFAULT_INCLUDES = r"\.pyi?$" CACHE_DIR = Path(user_cache_dir("black", version=__version__)) @@ -112,6 +112,10 @@ class InvalidInput(ValueError): """Raised when input source code fails all parse attempts.""" +class BracketMatchError(KeyError): + """Raised when an opening bracket is unable to be matched to a closing bracket.""" + + T = TypeVar("T") E = TypeVar("E", bound=Exception) @@ -174,14 +178,12 @@ class TargetVersion(Enum): PY36 = 6 PY37 = 7 PY38 = 8 + PY39 = 9 def is_python2(self) -> bool: return self is TargetVersion.PY27 -PY36_VERSIONS = {TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38} - - class Feature(Enum): # All string literals are unicode UNICODE_LITERALS = 1 @@ -195,6 +197,8 @@ class Feature(Enum): ASYNC_KEYWORDS = 7 ASSIGNMENT_EXPRESSIONS = 8 POS_ONLY_ARGUMENTS = 9 + RELAXED_DECORATORS = 10 + FORCE_OPTIONAL_PARENTHESES = 50 VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { @@ -232,6 +236,17 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { Feature.ASSIGNMENT_EXPRESSIONS, Feature.POS_ONLY_ARGUMENTS, }, + TargetVersion.PY39: { + Feature.UNICODE_LITERALS, + Feature.F_STRINGS, + Feature.NUMERIC_UNDERSCORES, + Feature.TRAILING_COMMA_IN_CALL, + Feature.TRAILING_COMMA_IN_DEF, + Feature.ASYNC_KEYWORDS, + Feature.ASSIGNMENT_EXPRESSIONS, + Feature.RELAXED_DECORATORS, + Feature.POS_ONLY_ARGUMENTS, + }, } @@ -240,6 +255,7 @@ class Mode: target_versions: Set[TargetVersion] = field(default_factory=set) line_length: int = DEFAULT_LINE_LENGTH string_normalization: bool = True + experimental_string_processing: bool = False is_pyi: bool = False def get_cache_key(self) -> str: @@ -376,6 +392,15 @@ def target_version_option_callback( is_flag=True, help="Don't normalize string quotes or prefixes.", ) +@click.option( + "--experimental-string-processing", + is_flag=True, + hidden=True, + help=( + "Experimental option that performs more normalization on string literals." + " Currently disabled because it leads to some crashes." + ), +) @click.option( "--check", is_flag=True, @@ -429,7 +454,7 @@ def target_version_option_callback( type=str, help=( "Like --exclude, but files and directories matching this regex will be " - "excluded even when they are passed explicitly as arguments" + "excluded even when they are passed explicitly as arguments." ), ) @click.option( @@ -471,7 +496,7 @@ def target_version_option_callback( ), is_eager=True, callback=read_pyproject_toml, - help="Read configuration from PATH.", + help="Read configuration from FILE path.", ) @click.pass_context def main( @@ -485,6 +510,7 @@ def main( fast: bool, pyi: bool, skip_string_normalization: bool, + experimental_string_processing: bool, quiet: bool, verbose: bool, include: str, @@ -505,6 +531,7 @@ def main( line_length=line_length, is_pyi=pyi, string_normalization=not skip_string_normalization, + experimental_string_processing=experimental_string_processing, ) if config and verbose: out(f"Using configuration from {config}.", bold=False, fg="blue") @@ -583,9 +610,7 @@ def get_sources( root = find_project_root(src) sources: Set[Path] = set() path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx) - exclude_regexes = [exclude_regex] - if force_exclude_regex is not None: - exclude_regexes.append(force_exclude_regex) + gitignore = get_gitignore(root) for s in src: p = Path(s) @@ -595,19 +620,30 @@ def get_sources( p.iterdir(), root, include_regex, - exclude_regexes, + exclude_regex, + force_exclude_regex, report, - get_gitignore(root), + gitignore, ) ) elif s == "-": sources.add(p) elif p.is_file(): - sources.update( - gen_python_files( - [p], root, None, exclude_regexes, report, get_gitignore(root) - ) - ) + normalized_path = normalize_path_maybe_ignore(p, root, report) + if normalized_path is None: + continue + + normalized_path = "/" + normalized_path + # Hard-exclude any files that matches the `--force-exclude` regex. + if force_exclude_regex: + force_exclude_match = force_exclude_regex.search(normalized_path) + else: + force_exclude_match = None + if force_exclude_match and force_exclude_match.group(0): + report.path_ignored(p, "matches the --force-exclude regular expression") + continue + + sources.add(p) else: err(f"invalid path: {s}") return sources @@ -619,10 +655,9 @@ def path_empty( """ Exit if there is no `src` provided for formatting """ - if len(src) == 0: - if verbose or not quiet: - out(msg) - ctx.exit(0) + if not src and (verbose or not quiet): + out(msg) + ctx.exit(0) def reformat_one( @@ -640,7 +675,7 @@ def reformat_one( changed = Changed.YES else: cache: Cache = {} - if write_back != WriteBack.DIFF: + if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF): cache = read_cache(mode) res_src = src.resolve() if res_src in cache and cache[res_src] == get_cache_info(res_src): @@ -655,6 +690,8 @@ def reformat_one( write_cache(cache, [src], mode) report.done(src, changed) except Exception as exc: + if report.verbose: + traceback.print_exc() report.failed(src, str(exc)) @@ -712,7 +749,7 @@ async def schedule_formatting( :func:`format_file_in_place`. """ cache: Cache = {} - if write_back != WriteBack.DIFF: + if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF): cache = read_cache(mode) sources, cached = filter_cached(cache, sources) for src in sorted(cached): @@ -723,7 +760,7 @@ async def schedule_formatting( cancelled = [] sources_to_cache = [] lock = None - if write_back == WriteBack.DIFF: + if write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF): # For diff output, we need locks to ensure we don't interleave output # from different processes. manager = Manager() @@ -822,9 +859,9 @@ def color_diff(contents: str) -> str: for i, line in enumerate(lines): if line.startswith("+++") or line.startswith("---"): line = "\033[1;37m" + line + "\033[0m" # bold white, reset - if line.startswith("@@"): + elif line.startswith("@@"): line = "\033[36m" + line + "\033[0m" # cyan, reset - if line.startswith("+"): + elif line.startswith("+"): line = "\033[32m" + line + "\033[0m" # green, reset elif line.startswith("-"): line = "\033[31m" + line + "\033[0m" # red, reset @@ -834,30 +871,22 @@ def color_diff(contents: str) -> str: def wrap_stream_for_windows( f: io.TextIOWrapper, -) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]: +) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32"]: """ - Wrap the stream in colorama's wrap_stream so colors are shown on Windows. + Wrap stream with colorama's wrap_stream so colors are shown on Windows. - If `colorama` is not found, then no change is made. If `colorama` does - exist, then it handles the logic to determine whether or not to change - things. + If `colorama` is unavailable, the original stream is returned unmodified. + Otherwise, the `wrap_stream()` function determines whether the stream needs + to be wrapped for a Windows environment and will accordingly either return + an `AnsiToWin32` wrapper or the original stream. """ try: - from colorama import initialise - - # We set `strip=False` so that we can don't have to modify - # test_express_diff_with_color. - f = initialise.wrap_stream( - f, convert=None, strip=False, autoreset=False, wrap=True - ) - - # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object - # which does not have a `detach()` method. So we fake one. - f.detach = lambda *args, **kwargs: None # type: ignore + from colorama.initialise import wrap_stream except ImportError: - pass - - return f + return f + else: + # Set `strip=False` to avoid needing to modify test_express_diff_with_color. + return wrap_stream(f, convert=None, strip=False, autoreset=False, wrap=True) def format_stdin_to_stdout( @@ -898,13 +927,13 @@ def format_stdin_to_stdout( def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: - """Reformat contents a file and return new contents. + """Reformat contents of a file and return new contents. If `fast` is False, additionally confirm that the reformatted code is valid by calling :func:`assert_equivalent` and :func:`assert_stable` on it. `mode` is passed to :func:`format_str`. """ - if src_contents.strip() == "": + if not src_contents.strip(): raise NothingChanged dst_contents = format_str(src_contents, mode=mode) @@ -929,6 +958,7 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent: ... A more complex example: + >>> print( ... black.format_str( ... "def f(arg:str='')->None: hey", @@ -973,10 +1003,7 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent: before, after = elt.maybe_empty_lines(current_line) dst_contents.append(str(empty_line) * before) for line in transform_line( - current_line, - line_length=mode.line_length, - normalize_strings=mode.string_normalization, - features=split_line_features, + current_line, mode=mode, features=split_line_features ): dst_contents.append(str(line)) return "".join(dst_contents) @@ -1040,7 +1067,7 @@ def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]: def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node: """Given a string with source, return the lib2to3 Node.""" - if src_txt[-1:] != "\n": + if not src_txt.endswith("\n"): src_txt += "\n" for grammar in get_grammars(set(target_versions)): @@ -1263,6 +1290,7 @@ class BracketTracker: previous: Optional[Leaf] = None _for_loop_depths: List[int] = field(default_factory=list) _lambda_argument_depths: List[int] = field(default_factory=list) + invisible: List[Leaf] = field(default_factory=list) def mark(self, leaf: Leaf) -> None: """Mark `leaf` with bracket-related metadata. Keep track of delimiters. @@ -1286,8 +1314,16 @@ class BracketTracker: self.maybe_decrement_after_lambda_arguments(leaf) if leaf.type in CLOSING_BRACKETS: self.depth -= 1 - opening_bracket = self.bracket_match.pop((self.depth, leaf.type)) + try: + opening_bracket = self.bracket_match.pop((self.depth, leaf.type)) + except KeyError as e: + raise BracketMatchError( + "Unable to match a closing bracket to the following opening" + f" bracket: {leaf}" + ) from e leaf.opening_bracket = opening_bracket + if not leaf.value: + self.invisible.append(leaf) leaf.bracket_depth = self.depth if self.depth == 0: delim = is_split_before_delimiter(leaf, self.previous) @@ -1300,6 +1336,8 @@ class BracketTracker: if leaf.type in OPENING_BRACKETS: self.bracket_match[self.depth, BRACKET[leaf.type]] = leaf self.depth += 1 + if not leaf.value: + self.invisible.append(leaf) self.previous = leaf self.maybe_increment_lambda_arguments(leaf) self.maybe_increment_for_loop_variable(leaf) @@ -1421,7 +1459,8 @@ class Line: ) if self.inside_brackets or not preformatted: self.bracket_tracker.mark(leaf) - self.maybe_remove_trailing_comma(leaf) + if self.maybe_should_explode(leaf): + self.should_explode = True if not self.append_comment(leaf): self.leaves.append(leaf) @@ -1473,69 +1512,6 @@ class Line: Leaf(token.DOT, ".") for _ in range(3) ] - @property - def is_collection_with_optional_trailing_comma(self) -> bool: - """Is this line a collection literal with a trailing comma that's optional? - - Note that the trailing comma in a 1-tuple is not optional. - """ - if not self.leaves or len(self.leaves) < 4: - return False - - # Look for and address a trailing colon. - if self.leaves[-1].type == token.COLON: - closer = self.leaves[-2] - close_index = -2 - else: - closer = self.leaves[-1] - close_index = -1 - if closer.type not in CLOSING_BRACKETS or self.inside_brackets: - return False - - if closer.type == token.RPAR: - # Tuples require an extra check, because if there's only - # one element in the tuple removing the comma unmakes the - # tuple. - # - # We also check for parens before looking for the trailing - # comma because in some cases (eg assigning a dict - # literal) the literal gets wrapped in temporary parens - # during parsing. This case is covered by the - # collections.py test data. - opener = closer.opening_bracket - for _open_index, leaf in enumerate(self.leaves): - if leaf is opener: - break - - else: - # Couldn't find the matching opening paren, play it safe. - return False - - commas = 0 - comma_depth = self.leaves[close_index - 1].bracket_depth - for leaf in self.leaves[_open_index + 1 : close_index]: - if leaf.bracket_depth == comma_depth and leaf.type == token.COMMA: - commas += 1 - if commas > 1: - # We haven't looked yet for the trailing comma because - # we might also have caught noop parens. - return self.leaves[close_index - 1].type == token.COMMA - - elif commas == 1: - return False # it's either a one-tuple or didn't have a trailing comma - - if self.leaves[close_index - 1].type in CLOSING_BRACKETS: - close_index -= 1 - closer = self.leaves[close_index] - if closer.type == token.RPAR: - # TODO: this is a gut feeling. Will we ever see this? - return False - - if self.leaves[close_index - 1].type != token.COMMA: - return False - - return True - @property def is_def(self) -> bool: """Is this a function definition? (Also returns True for async defs.)""" @@ -1660,42 +1636,28 @@ class Line: def contains_multiline_strings(self) -> bool: return any(is_multiline_string(leaf) for leaf in self.leaves) - def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: - """Remove trailing comma if there is one and it's safe.""" - if not (self.leaves and self.leaves[-1].type == token.COMMA): - return False - - # We remove trailing commas only in the case of importing a - # single name from a module. + def maybe_should_explode(self, closing: Leaf) -> bool: + """Return True if this line should explode (always be split), that is when: + - there's a trailing comma here; and + - it's not a one-tuple. + """ if not ( - self.leaves - and self.is_import - and len(self.leaves) > 4 + closing.type in CLOSING_BRACKETS + and self.leaves and self.leaves[-1].type == token.COMMA - and closing.type in CLOSING_BRACKETS - and self.leaves[-4].type == token.NAME - and ( - # regular `from foo import bar,` - self.leaves[-4].value == "import" - # `from foo import (bar as baz,) - or ( - len(self.leaves) > 6 - and self.leaves[-6].value == "import" - and self.leaves[-3].value == "as" - ) - # `from foo import bar as baz,` - or ( - len(self.leaves) > 5 - and self.leaves[-5].value == "import" - and self.leaves[-3].value == "as" - ) - ) - and closing.type == token.RPAR ): return False - self.remove_trailing_comma() - return True + if closing.type in {token.RBRACE, token.RSQB}: + return True + + if self.is_import: + return True + + if not is_one_tuple_between(closing.opening_bracket, closing, self.leaves): + return True + + return False def append_comment(self, comment: Leaf) -> bool: """Add an inline or standalone comment to the line.""" @@ -1874,6 +1836,10 @@ class EmptyLineTracker: return 0, 0 if self.previous_line.is_decorator: + if self.is_pyi and current_line.is_stub_class: + # Insert an empty line after a decorated stub class + return 0, 1 + return 0, 0 if self.previous_line.depth < current_line.depth and ( @@ -1897,8 +1863,11 @@ class EmptyLineTracker: newlines = 0 else: newlines = 1 - elif current_line.is_def and not self.previous_line.is_def: - # Blank line between a block of functions and a block of non-functions + elif ( + current_line.is_def or current_line.is_decorator + ) and not self.previous_line.is_def: + # Blank line between a block of functions (maybe with preceding + # decorators) and a block of non-functions newlines = 1 else: newlines = 0 @@ -2085,14 +2054,20 @@ class LineGenerator(Visitor[Line]): yield from self.visit_default(node) def visit_STRING(self, leaf: Leaf) -> Iterator[Line]: - # Check if it's a docstring - if prev_siblings_are( - leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt] - ) and is_multiline_string(leaf): - prefix = " " * self.current_line.depth - docstring = fix_docstring(leaf.value[3:-3], prefix) - leaf.value = leaf.value[0:3] + docstring + leaf.value[-3:] - normalize_string_quotes(leaf) + if is_docstring(leaf) and "\\\n" not in leaf.value: + # We're ignoring docstrings with backslash newline escapes because changing + # indentation of those changes the AST representation of the code. + prefix = get_string_prefix(leaf.value) + lead_len = len(prefix) + 3 + tail_len = -3 + indent = " " * 4 * self.current_line.depth + docstring = fix_docstring(leaf.value[lead_len:tail_len], indent) + if docstring: + if leaf.value[lead_len - 1] == docstring[0]: + docstring = " " + docstring + if leaf.value[tail_len + 1] == docstring[-1]: + docstring = docstring + " " + leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:] yield from self.visit_default(leaf) @@ -2211,6 +2186,9 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool) -> str: # noqa: C901 ): # Python 2 print chevron return NO + elif prevp.type == token.AT and p.parent and p.parent.type == syms.decorator: + # no space in decorators + return NO elif prev.type in OPENING_BRACKETS: return NO @@ -2638,10 +2616,7 @@ def make_comment(content: str) -> str: def transform_line( - line: Line, - line_length: int, - normalize_strings: bool, - features: Collection[Feature] = (), + line: Line, mode: Mode, features: Collection[Feature] = () ) -> Iterator[Line]: """Transform a `line`, potentially splitting it into many lines. @@ -2657,7 +2632,7 @@ def transform_line( def init_st(ST: Type[StringTransformer]) -> StringTransformer: """Initialize StringTransformer""" - return ST(line_length, normalize_strings) + return ST(mode.line_length, mode.string_normalization) string_merge = init_st(StringMerger) string_paren_strip = init_st(StringParenStripper) @@ -2668,72 +2643,79 @@ def transform_line( if ( not line.contains_uncollapsable_type_comments() and not line.should_explode - and not line.is_collection_with_optional_trailing_comma and ( - is_line_short_enough(line, line_length=line_length, line_str=line_str) + is_line_short_enough(line, line_length=mode.line_length, line_str=line_str) or line.contains_unsplittable_type_ignore() ) - and not (line.contains_standalone_comments() and line.inside_brackets) + and not (line.inside_brackets and line.contains_standalone_comments()) ): # Only apply basic string preprocessing, since lines shouldn't be split here. - transformers = [string_merge, string_paren_strip] + if mode.experimental_string_processing: + transformers = [string_merge, string_paren_strip] + else: + transformers = [] elif line.is_def: transformers = [left_hand_split] else: def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]: - for omit in generate_trailers_to_omit(line, line_length): - lines = list(right_hand_split(line, line_length, features, omit=omit)) - if is_line_short_enough(lines[0], line_length=line_length): + """Wraps calls to `right_hand_split`. + + The calls increasingly `omit` right-hand trailers (bracket pairs with + content), meaning the trailers get glued together to split on another + bracket pair instead. + """ + for omit in generate_trailers_to_omit(line, mode.line_length): + lines = list( + right_hand_split(line, mode.line_length, features, omit=omit) + ) + # Note: this check is only able to figure out if the first line of the + # *current* transformation fits in the line length. This is true only + # for simple cases. All others require running more transforms via + # `transform_line()`. This check doesn't know if those would succeed. + if is_line_short_enough(lines[0], line_length=mode.line_length): yield from lines return # All splits failed, best effort split with no omits. # This mostly happens to multiline strings that are by definition - # reported as not fitting a single line. - # line_length=1 here was historically a bug that somehow became a feature. - # See #762 and #781 for the full story. - yield from right_hand_split(line, line_length=1, features=features) - - if line.inside_brackets: - transformers = [ - string_merge, - string_paren_strip, - delimiter_split, - standalone_comment_split, - string_split, - string_paren_wrap, - rhs, - ] + # reported as not fitting a single line, as well as lines that contain + # trailing commas (those have to be exploded). + yield from right_hand_split( + line, line_length=mode.line_length, features=features + ) + + if mode.experimental_string_processing: + if line.inside_brackets: + transformers = [ + string_merge, + string_paren_strip, + string_split, + delimiter_split, + standalone_comment_split, + string_paren_wrap, + rhs, + ] + else: + transformers = [ + string_merge, + string_paren_strip, + string_split, + string_paren_wrap, + rhs, + ] else: - transformers = [ - string_merge, - string_paren_strip, - string_split, - string_paren_wrap, - rhs, - ] + if line.inside_brackets: + transformers = [delimiter_split, standalone_comment_split, rhs] + else: + transformers = [rhs] for transform in transformers: # We are accumulating lines in `result` because we might want to abort # mission and return the original line in the end, or attempt a different # split altogether. - result: List[Line] = [] try: - for transformed_line in transform(line, features): - if str(transformed_line).strip("\n") == line_str: - raise CannotTransform( - "Line transformer returned an unchanged result" - ) - - result.extend( - transform_line( - transformed_line, - line_length=line_length, - normalize_strings=normalize_strings, - features=features, - ) - ) + result = run_transformer(line, transform, mode, features, line_str=line_str) except CannotTransform: continue else: @@ -2774,6 +2756,7 @@ class StringTransformer(ABC): line_length: int normalize_strings: bool + __name__ = "StringTransformer" @abstractmethod def do_match(self, line: Line) -> TMatchResult: @@ -2928,11 +2911,8 @@ class StringMerger(CustomSplitMapMixin, StringTransformer): """StringTransformer that merges strings together. Requirements: - (A) The line contains adjacent strings such that at most one substring - has inline comments AND none of those inline comments are pragmas AND - the set of all substring prefixes is either of length 1 or equal to - {"", "f"} AND none of the substrings are raw strings (i.e. are prefixed - with 'r'). + (A) The line contains adjacent strings such that ALL of the validation checks + listed in StringMerger.__validate_msg(...)'s docstring pass. OR (B) The line contains a string which uses line continuation backslashes. @@ -3020,7 +3000,7 @@ class StringMerger(CustomSplitMapMixin, StringTransformer): ) new_line = line.clone() - new_line.comments = line.comments + new_line.comments = line.comments.copy() append_leaves(new_line, line, LL) new_string_leaf = new_line.leaves[string_idx] @@ -3181,6 +3161,7 @@ class StringMerger(CustomSplitMapMixin, StringTransformer): * Ok(None), if ALL validation checks (listed below) pass. OR * Err(CannotTransform), if any of the following are true: + - The target string group does not contain ANY stand-alone comments. - The target string is not in a string group (i.e. it has no adjacent strings). - The string group has more than one inline comment. @@ -3189,6 +3170,26 @@ class StringMerger(CustomSplitMapMixin, StringTransformer): length greater than one and is not equal to {"", "f"}. - The string group consists of raw strings. """ + # We first check for "inner" stand-alone comments (i.e. stand-alone + # comments that have a string leaf before them AND after them). + for inc in [1, -1]: + i = string_idx + found_sa_comment = False + is_valid_index = is_valid_index_factory(line.leaves) + while is_valid_index(i) and line.leaves[i].type in [ + token.STRING, + STANDALONE_COMMENT, + ]: + if line.leaves[i].type == STANDALONE_COMMENT: + found_sa_comment = True + elif found_sa_comment: + return TErr( + "StringMerger does NOT merge string groups which contain " + "stand-alone comments." + ) + + i += inc + num_of_inline_string_comments = 0 set_of_prefixes = set() num_of_strings = 0 @@ -3238,7 +3239,9 @@ class StringParenStripper(StringTransformer): Requirements: The line contains a string which is surrounded by parentheses and: - The target string is NOT the only argument to a function call). - - The RPAR is NOT followed by an attribute access (i.e. a dot). + - If the target string contains a PERCENT, the brackets are not + preceeded or followed by an operator with higher precedence than + PERCENT. Transformations: The parentheses mentioned in the 'Requirements' section are stripped. @@ -3281,14 +3284,51 @@ class StringParenStripper(StringTransformer): string_parser = StringParser() next_idx = string_parser.parse(LL, string_idx) + # if the leaves in the parsed string include a PERCENT, we need to + # make sure the initial LPAR is NOT preceded by an operator with + # higher or equal precedence to PERCENT + if is_valid_index(idx - 2): + # mypy can't quite follow unless we name this + before_lpar = LL[idx - 2] + if token.PERCENT in {leaf.type for leaf in LL[idx - 1 : next_idx]} and ( + ( + before_lpar.type + in { + token.STAR, + token.AT, + token.SLASH, + token.DOUBLESLASH, + token.PERCENT, + token.TILDE, + token.DOUBLESTAR, + token.AWAIT, + token.LSQB, + token.LPAR, + } + ) + or ( + # only unary PLUS/MINUS + before_lpar.parent + and before_lpar.parent.type == syms.factor + and (before_lpar.type in {token.PLUS, token.MINUS}) + ) + ): + continue + # Should be followed by a non-empty RPAR... if ( is_valid_index(next_idx) and LL[next_idx].type == token.RPAR and not is_empty_rpar(LL[next_idx]) ): - # That RPAR should NOT be followed by a '.' symbol. - if is_valid_index(next_idx + 1) and LL[next_idx + 1].type == token.DOT: + # That RPAR should NOT be followed by anything with higher + # precedence than PERCENT + if is_valid_index(next_idx + 1) and LL[next_idx + 1].type in { + token.DOUBLESTAR, + token.LSQB, + token.LPAR, + token.DOT, + }: continue return Ok(string_idx) @@ -3306,11 +3346,17 @@ class StringParenStripper(StringTransformer): yield TErr( "Will not strip parentheses which have comments attached to them." ) + return new_line = line.clone() new_line.comments = line.comments.copy() - - append_leaves(new_line, line, LL[: string_idx - 1]) + try: + append_leaves(new_line, line, LL[: string_idx - 1]) + except BracketMatchError: + # HACK: I believe there is currently a bug somewhere in + # right_hand_split() that is causing brackets to not be tracked + # properly by a shared BracketTracker. + append_leaves(new_line, line, LL[: string_idx - 1], preformatted=True) string_leaf = Leaf(token.STRING, LL[string_idx].value) LL[string_idx - 1].remove() @@ -3318,7 +3364,7 @@ class StringParenStripper(StringTransformer): new_line.append(string_leaf) append_leaves( - new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :], + new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :] ) LL[rpar_idx].remove() @@ -3473,9 +3519,12 @@ class BaseStringSplitter(StringTransformer): # WMA4 a single space. offset += 1 - # WMA4 the lengths of any leaves that came before that space. - for leaf in LL[: p_idx + 1]: + # WMA4 the lengths of any leaves that came before that space, + # but after any closing bracket before that space. + for leaf in reversed(LL[: p_idx + 1]): offset += len(str(leaf)) + if leaf.type in CLOSING_BRACKETS: + break if is_valid_index(string_idx + 1): N = LL[string_idx + 1] @@ -3991,12 +4040,13 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): def do_splitter_match(self, line: Line) -> TMatchResult: LL = line.leaves - string_idx = None - string_idx = string_idx or self._return_match(LL) - string_idx = string_idx or self._else_match(LL) - string_idx = string_idx or self._assert_match(LL) - string_idx = string_idx or self._assign_match(LL) - string_idx = string_idx or self._dict_match(LL) + string_idx = ( + self._return_match(LL) + or self._else_match(LL) + or self._assert_match(LL) + or self._assign_match(LL) + or self._dict_match(LL) + ) if string_idx is not None: string_value = line.leaves[string_idx].value @@ -4195,7 +4245,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): is_valid_index = is_valid_index_factory(LL) insert_str_child = insert_str_child_factory(LL[string_idx]) - comma_idx = len(LL) - 1 + comma_idx = -1 ends_with_comma = False if LL[comma_idx].type == token.COMMA: ends_with_comma = True @@ -4580,7 +4630,9 @@ def line_to_string(line: Line) -> str: return str(line).strip("\n") -def append_leaves(new_line: Line, old_line: Line, leaves: List[Leaf]) -> None: +def append_leaves( + new_line: Line, old_line: Line, leaves: List[Leaf], preformatted: bool = False +) -> None: """ Append leaves (taken from @old_line) to @new_line, making sure to fix the underlying Node structure where appropriate. @@ -4594,11 +4646,9 @@ def append_leaves(new_line: Line, old_line: Line, leaves: List[Leaf]) -> None: set(@leaves) is a subset of set(@old_line.leaves). """ for old_leaf in leaves: - assert old_leaf in old_line.leaves - new_leaf = Leaf(old_leaf.type, old_leaf.value) replace_child(old_leaf, new_leaf) - new_line.append(new_leaf) + new_line.append(new_leaf, preformatted=preformatted) for comment_leaf in old_line.comments_after(old_leaf): new_line.append(comment_leaf, preformatted=True) @@ -4755,8 +4805,7 @@ def right_hand_split( tail = bracket_split_build_line(tail_leaves, line, opening_bracket) bracket_split_succeeded_or_raise(head, body, tail) if ( - # the body shouldn't be exploded - not body.should_explode + Feature.FORCE_OPTIONAL_PARENTHESES not in features # the opening bracket is an optional paren and opening_bracket.type == token.LPAR and not opening_bracket.value @@ -4769,7 +4818,7 @@ def right_hand_split( # there are no standalone comments in the body and not body.contains_standalone_comments(0) # and we can actually remove the parens - and can_omit_invisible_parens(body, line_length) + and can_omit_invisible_parens(body, line_length, omit_on_explode=omit) ): omit = {id(closing_bracket), *omit} try: @@ -4855,7 +4904,8 @@ def bracket_split_build_line( continue if leaves[i].type != token.COMMA: - leaves.insert(i + 1, Leaf(token.COMMA, ",")) + new_comma = Leaf(token.COMMA, ",") + leaves.insert(i + 1, new_comma) break # Populate the line @@ -4863,8 +4913,8 @@ def bracket_split_build_line( result.append(leaf, preformatted=True) for comment_after in original.comments_after(leaf): result.append(comment_after, preformatted=True) - if is_body: - result.should_explode = should_explode(result, opening_bracket) + if is_body and should_split_body_explode(result, opening_bracket): + result.should_explode = True return result @@ -4949,7 +4999,8 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ and current_line.leaves[-1].type != token.COMMA and current_line.leaves[-1].type != STANDALONE_COMMENT ): - current_line.append(Leaf(token.COMMA, ",")) + new_comma = Leaf(token.COMMA, ",") + current_line.append(new_comma) yield current_line @@ -5191,9 +5242,9 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: if check_lpar: if is_walrus_assignment(child): - continue + pass - if child.type == syms.atom: + elif child.type == syms.atom: if maybe_make_parens_invisible_in_atom(child, parent=node): wrap_in_parentheses(node, child, visible=False) elif is_one_tuple(child): @@ -5453,6 +5504,49 @@ def is_walrus_assignment(node: LN) -> bool: return inner is not None and inner.type == syms.namedexpr_test +def is_simple_decorator_trailer(node: LN, last: bool = False) -> bool: + """Return True iff `node` is a trailer valid in a simple decorator""" + return node.type == syms.trailer and ( + ( + len(node.children) == 2 + and node.children[0].type == token.DOT + and node.children[1].type == token.NAME + ) + # last trailer can be arguments + or ( + last + and len(node.children) == 3 + and node.children[0].type == token.LPAR + # and node.children[1].type == syms.argument + and node.children[2].type == token.RPAR + ) + ) + + +def is_simple_decorator_expression(node: LN) -> bool: + """Return True iff `node` could be a 'dotted name' decorator + + This function takes the node of the 'namedexpr_test' of the new decorator + grammar and test if it would be valid under the old decorator grammar. + + The old grammar was: decorator: @ dotted_name [arguments] NEWLINE + The new grammar is : decorator: @ namedexpr_test NEWLINE + """ + if node.type == token.NAME: + return True + if node.type == syms.power: + if node.children: + return ( + node.children[0].type == token.NAME + and all(map(is_simple_decorator_trailer, node.children[1:-1])) + and ( + len(node.children) < 2 + or is_simple_decorator_trailer(node.children[-1], last=True) + ) + ) + return False + + def is_yield(node: LN) -> bool: """Return True if `node` holds a `yield` or `yield from` expression.""" if node.type == syms.yield_expr: @@ -5571,24 +5665,63 @@ def ensure_visible(leaf: Leaf) -> None: leaf.value = ")" -def should_explode(line: Line, opening_bracket: Leaf) -> bool: - """Should `line` immediately be split with `delimiter_split()` after RHS?""" +def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool: + """Should `line` be immediately split with `delimiter_split()` after RHS?""" - if not ( - opening_bracket.parent - and opening_bracket.parent.type in {syms.atom, syms.import_from} - and opening_bracket.value in "[{(" - ): + if not (opening_bracket.parent and opening_bracket.value in "[{("): return False + # We're essentially checking if the body is delimited by commas and there's more + # than one of them (we're excluding the trailing comma and if the delimiter priority + # is still commas, that means there's more). + exclude = set() + trailing_comma = False try: last_leaf = line.leaves[-1] - exclude = {id(last_leaf)} if last_leaf.type == token.COMMA else set() + if last_leaf.type == token.COMMA: + trailing_comma = True + exclude.add(id(last_leaf)) max_priority = line.bracket_tracker.max_delimiter_priority(exclude=exclude) except (IndexError, ValueError): return False - return max_priority == COMMA_PRIORITY + return max_priority == COMMA_PRIORITY and ( + trailing_comma + # always explode imports + or opening_bracket.parent.type in {syms.atom, syms.import_from} + ) + + +def is_one_tuple_between(opening: Leaf, closing: Leaf, leaves: List[Leaf]) -> bool: + """Return True if content between `opening` and `closing` looks like a one-tuple.""" + if opening.type != token.LPAR and closing.type != token.RPAR: + return False + + depth = closing.bracket_depth + 1 + for _opening_index, leaf in enumerate(leaves): + if leaf is opening: + break + + else: + raise LookupError("Opening paren not found in `leaves`") + + commas = 0 + _opening_index += 1 + for leaf in leaves[_opening_index:]: + if leaf is closing: + break + + bracket_depth = leaf.bracket_depth + if bracket_depth == depth and leaf.type == token.COMMA: + commas += 1 + if leaf.parent and leaf.parent.type in { + syms.arglist, + syms.typedargslist, + }: + commas += 1 + break + + return commas < 2 def get_features_used(node: Node) -> Set[Feature]: @@ -5599,6 +5732,8 @@ def get_features_used(node: Node) -> Set[Feature]: - underscores in numeric literals; - trailing commas after * or ** in function signatures and calls; - positional only arguments in function signatures and lambdas; + - assignment expression; + - relaxed decorator syntax; """ features: Set[Feature] = set() for n in node.pre_order(): @@ -5618,6 +5753,12 @@ def get_features_used(node: Node) -> Set[Feature]: elif n.type == token.COLONEQUAL: features.add(Feature.ASSIGNMENT_EXPRESSIONS) + elif n.type == syms.decorator: + if len(n.children) > 1 and not is_simple_decorator_expression( + n.children[1] + ): + features.add(Feature.RELAXED_DECORATORS) + elif ( n.type in {syms.typedargslist, syms.arglist} and n.children @@ -5655,11 +5796,13 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf a preceding closing bracket fits in one line. Yielded sets are cumulative (contain results of previous yields, too). First - set is empty. + set is empty, unless the line should explode, in which case bracket pairs until + the one that needs to explode are omitted. """ omit: Set[LeafID] = set() - yield omit + if not line.should_explode: + yield omit length = 4 * line.depth opening_bracket: Optional[Leaf] = None @@ -5678,9 +5821,23 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf if leaf is opening_bracket: opening_bracket = None elif leaf.type in CLOSING_BRACKETS: + prev = line.leaves[index - 1] if index > 0 else None + if ( + line.should_explode + and prev + and prev.type == token.COMMA + and not is_one_tuple_between( + leaf.opening_bracket, leaf, line.leaves + ) + ): + # Never omit bracket pairs with trailing commas. + # We need to explode on those. + break + inner_brackets.add(id(leaf)) elif leaf.type in CLOSING_BRACKETS: - if index > 0 and line.leaves[index - 1].type in OPENING_BRACKETS: + prev = line.leaves[index - 1] if index > 0 else None + if prev and prev.type in OPENING_BRACKETS: # Empty brackets would fail a split so treat them as "inner" # brackets (e.g. only add them to the `omit` set if another # pair of brackets was good enough. @@ -5693,6 +5850,16 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf inner_brackets.clear() yield omit + if ( + line.should_explode + and prev + and prev.type == token.COMMA + and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves) + ): + # Never omit bracket pairs with trailing commas. + # We need to explode on those. + break + if leaf.value: opening_bracket = leaf.opening_bracket closing_bracket = leaf @@ -5759,16 +5926,41 @@ def get_gitignore(root: Path) -> PathSpec: return PathSpec.from_lines("gitwildmatch", lines) +def normalize_path_maybe_ignore( + path: Path, root: Path, report: "Report" +) -> Optional[str]: + """Normalize `path`. May return `None` if `path` was ignored. + + `report` is where "path ignored" output goes. + """ + try: + abspath = path if path.is_absolute() else Path.cwd() / path + normalized_path = abspath.resolve().relative_to(root).as_posix() + except OSError as e: + report.path_ignored(path, f"cannot be read because {e}") + return None + + except ValueError: + if path.is_symlink(): + report.path_ignored(path, f"is a symbolic link that points outside {root}") + return None + + raise + + return normalized_path + + def gen_python_files( paths: Iterable[Path], root: Path, include: Optional[Pattern[str]], - exclude_regexes: Iterable[Pattern[str]], + exclude: Pattern[str], + force_exclude: Optional[Pattern[str]], report: "Report", gitignore: PathSpec, ) -> Iterator[Path]: """Generate all files under `path` whose paths are not excluded by the - `exclude` regex, but are included by the `include` regex. + `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex. Symbolic links pointing outside of the `root` directory are ignored. @@ -5776,43 +5968,41 @@ def gen_python_files( """ assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}" for child in paths: - # Then ignore with `exclude` option. - try: - normalized_path = child.resolve().relative_to(root).as_posix() - except OSError as e: - report.path_ignored(child, f"cannot be read because {e}") + normalized_path = normalize_path_maybe_ignore(child, root, report) + if normalized_path is None: continue - except ValueError: - if child.is_symlink(): - report.path_ignored( - child, f"is a symbolic link that points outside {root}" - ) - continue - - raise # First ignore files matching .gitignore if gitignore.match_file(normalized_path): report.path_ignored(child, "matches the .gitignore file content") continue + # Then ignore with `--exclude` and `--force-exclude` options. normalized_path = "/" + normalized_path if child.is_dir(): normalized_path += "/" - is_excluded = False - for exclude in exclude_regexes: - exclude_match = exclude.search(normalized_path) if exclude else None - if exclude_match and exclude_match.group(0): - report.path_ignored(child, "matches the --exclude regular expression") - is_excluded = True - break - if is_excluded: + exclude_match = exclude.search(normalized_path) if exclude else None + if exclude_match and exclude_match.group(0): + report.path_ignored(child, "matches the --exclude regular expression") + continue + + force_exclude_match = ( + force_exclude.search(normalized_path) if force_exclude else None + ) + if force_exclude_match and force_exclude_match.group(0): + report.path_ignored(child, "matches the --force-exclude regular expression") continue if child.is_dir(): yield from gen_python_files( - child.iterdir(), root, include, exclude_regexes, report, gitignore + child.iterdir(), + root, + include, + exclude, + force_exclude, + report, + gitignore, ) elif child.is_file(): @@ -5825,8 +6015,8 @@ def gen_python_files( def find_project_root(srcs: Iterable[str]) -> Path: """Return a directory containing .git, .hg, or pyproject.toml. - That directory can be one of the directories passed in `srcs` or their - common parent. + That directory will be a common parent of all files and directories + passed in `srcs`. If no directory in the tree contains a marker that would specify it's the project root, the root of the file system is returned. @@ -5834,11 +6024,20 @@ def find_project_root(srcs: Iterable[str]) -> Path: if not srcs: return Path("/").resolve() - common_base = min(Path(src).resolve() for src in srcs) - if common_base.is_dir(): - # Append a fake file so `parents` below returns `common_base_dir`, too. - common_base /= "fake-file" - for directory in common_base.parents: + path_srcs = [Path(Path.cwd(), src).resolve() for src in srcs] + + # A list of lists of parents for each 'src'. 'src' is included as a + # "parent" of itself if it is a directory + src_parents = [ + list(path.parents) + ([path] if path.is_dir() else []) for path in path_srcs + ] + + common_base = max( + set.intersection(*(set(parents) for parents in src_parents)), + key=lambda path: path.parts, + ) + + for directory in (common_base, *common_base.parents): if (directory / ".git").exists(): return directory @@ -6023,7 +6222,7 @@ def _stringify_ast( and field == "value" and isinstance(value, str) ): - normalized = re.sub(r" *\n[ \t]+", "\n ", value).strip() + normalized = re.sub(r" *\n[ \t]*", "\n", value).strip() else: normalized = value yield f"{' ' * (depth+2)}{normalized!r}, # {value.__class__.__name__}" @@ -6067,6 +6266,7 @@ def assert_stable(src: str, dst: str, mode: Mode) -> None: newdst = format_str(dst, mode=mode) if dst != newdst: log = dump_to_file( + str(mode), diff(src, dst, "source", "first pass"), diff(dst, newdst, "first pass", "second pass"), ) @@ -6243,7 +6443,11 @@ def can_be_split(line: Line) -> bool: return True -def can_omit_invisible_parens(line: Line, line_length: int) -> bool: +def can_omit_invisible_parens( + line: Line, + line_length: int, + omit_on_explode: Collection[LeafID] = (), +) -> bool: """Does `line` have a shape safe to reformat without optional parens around it? Returns True for only a subset of potentially nice looking formattings but @@ -6266,37 +6470,27 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool: assert len(line.leaves) >= 2, "Stranded delimiter" - first = line.leaves[0] - second = line.leaves[1] - penultimate = line.leaves[-2] - last = line.leaves[-1] - # With a single delimiter, omit if the expression starts or ends with # a bracket. + first = line.leaves[0] + second = line.leaves[1] if first.type in OPENING_BRACKETS and second.type not in CLOSING_BRACKETS: - remainder = False - length = 4 * line.depth - for _index, leaf, leaf_length in enumerate_with_length(line): - if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first: - remainder = True - if remainder: - length += leaf_length - if length > line_length: - break - - if leaf.type in OPENING_BRACKETS: - # There are brackets we can further split on. - remainder = False - - else: - # checked the entire string and line length wasn't exceeded - if len(line.leaves) == _index + 1: - return True + if _can_omit_opening_paren(line, first=first, line_length=line_length): + return True # Note: we are not returning False here because a line might have *both* # a leading opening bracket and a trailing closing bracket. If the # opening bracket doesn't match our rule, maybe the closing will. + penultimate = line.leaves[-2] + last = line.leaves[-1] + if line.should_explode: + try: + penultimate, last = last_two_except(line.leaves, omit=omit_on_explode) + except LookupError: + # Turns out we'd omit everything. We cannot skip the optional parentheses. + return False + if ( last.type == token.RPAR or last.type == token.RBRACE @@ -6317,21 +6511,120 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool: # unnecessary. return True - length = 4 * line.depth - seen_other_brackets = False - for _index, leaf, leaf_length in enumerate_with_length(line): + if line.should_explode and penultimate.type == token.COMMA: + # The rightmost non-omitted bracket pair is the one we want to explode on. + return True + + if _can_omit_closing_paren(line, last=last, line_length=line_length): + return True + + return False + + +def _can_omit_opening_paren(line: Line, *, first: Leaf, line_length: int) -> bool: + """See `can_omit_invisible_parens`.""" + remainder = False + length = 4 * line.depth + _index = -1 + for _index, leaf, leaf_length in enumerate_with_length(line): + if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first: + remainder = True + if remainder: length += leaf_length - if leaf is last.opening_bracket: - if seen_other_brackets or length <= line_length: - return True + if length > line_length: + break - elif leaf.type in OPENING_BRACKETS: + if leaf.type in OPENING_BRACKETS: # There are brackets we can further split on. - seen_other_brackets = True + remainder = False + + else: + # checked the entire string and line length wasn't exceeded + if len(line.leaves) == _index + 1: + return True + + return False + + +def _can_omit_closing_paren(line: Line, *, last: Leaf, line_length: int) -> bool: + """See `can_omit_invisible_parens`.""" + length = 4 * line.depth + seen_other_brackets = False + for _index, leaf, leaf_length in enumerate_with_length(line): + length += leaf_length + if leaf is last.opening_bracket: + if seen_other_brackets or length <= line_length: + return True + + elif leaf.type in OPENING_BRACKETS: + # There are brackets we can further split on. + seen_other_brackets = True return False +def last_two_except(leaves: List[Leaf], omit: Collection[LeafID]) -> Tuple[Leaf, Leaf]: + """Return (penultimate, last) leaves skipping brackets in `omit` and contents.""" + stop_after = None + last = None + for leaf in reversed(leaves): + if stop_after: + if leaf is stop_after: + stop_after = None + continue + + if last: + return leaf, last + + if id(leaf) in omit: + stop_after = leaf.opening_bracket + else: + last = leaf + else: + raise LookupError("Last two leaves were also skipped") + + +def run_transformer( + line: Line, + transform: Transformer, + mode: Mode, + features: Collection[Feature], + *, + line_str: str = "", +) -> List[Line]: + if not line_str: + line_str = line_to_string(line) + result: List[Line] = [] + for transformed_line in transform(line, features): + if str(transformed_line).strip("\n") == line_str: + raise CannotTransform("Line transformer returned an unchanged result") + + result.extend(transform_line(transformed_line, mode=mode, features=features)) + + if not ( + transform.__name__ == "rhs" + and line.bracket_tracker.invisible + and not any(bracket.value for bracket in line.bracket_tracker.invisible) + and not line.contains_multiline_strings() + and not result[0].contains_uncollapsable_type_comments() + and not result[0].contains_unsplittable_type_ignore() + and not is_line_short_enough(result[0], line_length=mode.line_length) + ): + return result + + line_copy = line.clone() + append_leaves(line_copy, line, line.leaves) + features_fop = set(features) | {Feature.FORCE_OPTIONAL_PARENTHESES} + second_opinion = run_transformer( + line_copy, transform, mode, features_fop, line_str=line_str + ) + if all( + is_line_short_enough(ln, line_length=mode.line_length) for ln in second_opinion + ): + result = second_opinion + return result + + def get_cache_file(mode: Mode) -> Path: return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle" @@ -6417,6 +6710,26 @@ def patched_main() -> None: main() +def is_docstring(leaf: Leaf) -> bool: + if not is_multiline_string(leaf): + # For the purposes of docstring re-indentation, we don't need to do anything + # with single-line docstrings. + return False + + if prev_siblings_are( + leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt] + ): + return True + + # Multiline docstring on the same line as the `def`. + if prev_siblings_are(leaf.parent, [syms.parameters, token.COLON, syms.simple_stmt]): + # `syms.parameters` is only used in funcdefs and async_funcdefs in the Python + # grammar. We're safe to return True without further checks. + return True + + return False + + def fix_docstring(docstring: str, prefix: str) -> str: # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation if not docstring: @@ -6440,7 +6753,6 @@ def fix_docstring(docstring: str, prefix: str) -> str: trimmed.append(prefix + stripped_line) else: trimmed.append("") - # Return a single string: return "\n".join(trimmed)