X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/bc138d12630ec38cb1b11b2fb2ca8762ac2ae7a4..82a53999ea971895e2836df39be1ec9110ec102c:/src/black/__init__.py?ds=sidebyside diff --git a/src/black/__init__.py b/src/black/__init__.py index dce7a24..015ca41 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -42,14 +42,26 @@ from typing import ( cast, TYPE_CHECKING, ) -from typing_extensions import Final from mypy_extensions import mypyc_attr from appdirs import user_cache_dir from dataclasses import dataclass, field, replace import click import toml -from typed_ast import ast3, ast27 + +try: + from typed_ast import ast3, ast27 +except ImportError: + if sys.version_info < (3, 8): + print( + "The typed_ast package is not installed.\n" + "You can install it with `python3 -m pip install typed-ast`.", + file=sys.stderr, + ) + sys.exit(1) + else: + ast3 = ast27 = ast + from pathspec import PathSpec # lib2to3 fork @@ -61,13 +73,19 @@ from blib2to3.pgen2.parse import ParseError from _black_version import version as __version__ +if sys.version_info < (3, 8): + from typing_extensions import Final +else: + from typing import Final + if TYPE_CHECKING: import colorama # noqa: F401 DEFAULT_LINE_LENGTH = 88 -DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950 +DEFAULT_EXCLUDES = r"/(\.direnv|\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950 DEFAULT_INCLUDES = r"\.pyi?$" CACHE_DIR = Path(user_cache_dir("black", version=__version__)) +STDIN_PLACEHOLDER = "__BLACK_STDIN_FILENAME__" STRING_PREFIX_CHARS: Final = "furbFURB" # All possible string prefix characters. @@ -88,7 +106,7 @@ Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]] Timestamp = float FileSize = int CacheInfo = Tuple[Timestamp, FileSize] -Cache = Dict[Path, CacheInfo] +Cache = Dict[str, CacheInfo] out = partial(click.secho, bold=True, err=True) err = partial(click.secho, fg="red", err=True) @@ -255,8 +273,9 @@ class Mode: target_versions: Set[TargetVersion] = field(default_factory=set) line_length: int = DEFAULT_LINE_LENGTH string_normalization: bool = True - experimental_string_processing: bool = False is_pyi: bool = False + magic_trailing_comma: bool = True + experimental_string_processing: bool = False def get_cache_key(self) -> str: if self.target_versions: @@ -271,6 +290,8 @@ class Mode: str(self.line_length), str(int(self.string_normalization)), str(int(self.is_pyi)), + str(int(self.magic_trailing_comma)), + str(int(self.experimental_string_processing)), ] return ".".join(parts) @@ -283,11 +304,15 @@ def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> b return all(feature in VERSION_TO_FEATURES[version] for version in target_versions) -def find_pyproject_toml(path_search_start: Iterable[str]) -> Optional[str]: +def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]: """Find the absolute filepath to a pyproject.toml if it exists""" path_project_root = find_project_root(path_search_start) path_pyproject_toml = path_project_root / "pyproject.toml" - return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None + if path_pyproject_toml.is_file(): + return str(path_pyproject_toml) + + path_user_pyproject_toml = find_user_pyproject_toml() + return str(path_user_pyproject_toml) if path_user_pyproject_toml.is_file() else None def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: @@ -357,6 +382,17 @@ def target_version_option_callback( return [TargetVersion[val.upper()] for val in v] +def validate_regex( + ctx: click.Context, + param: click.Parameter, + value: Optional[str], +) -> Optional[Pattern]: + try: + return re_compile_maybe_verbose(value) if value is not None else None + except re.error: + raise click.BadParameter("Not a valid regular expression") + + @click.command(context_settings=dict(help_option_names=["-h", "--help"])) @click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( @@ -392,6 +428,12 @@ def target_version_option_callback( is_flag=True, help="Don't normalize string quotes or prefixes.", ) +@click.option( + "-C", + "--skip-magic-trailing-comma", + is_flag=True, + help="Don't use trailing commas as a reason to split lines.", +) @click.option( "--experimental-string-processing", is_flag=True, @@ -405,8 +447,8 @@ def target_version_option_callback( "--check", is_flag=True, help=( - "Don't write the files back, just return the status. Return code 0 means" - " nothing would change. Return code 1 means some files would be reformatted." + "Don't write the files back, just return the status. Return code 0 means" + " nothing would change. Return code 1 means some files would be reformatted." " Return code 123 means there was an internal error." ), ) @@ -429,11 +471,12 @@ def target_version_option_callback( "--include", type=str, default=DEFAULT_INCLUDES, + callback=validate_regex, help=( "A regular expression that matches files and directories that should be" - " included on recursive searches. An empty value means all files are included" - " regardless of the name. Use forward slashes for directories on all platforms" - " (Windows, too). Exclusions are calculated first, inclusions later." + " included on recursive searches. An empty value means all files are included" + " regardless of the name. Use forward slashes for directories on all platforms" + " (Windows, too). Exclusions are calculated first, inclusions later." ), show_default=True, ) @@ -441,22 +484,42 @@ def target_version_option_callback( "--exclude", type=str, default=DEFAULT_EXCLUDES, + callback=validate_regex, help=( "A regular expression that matches files and directories that should be" - " excluded on recursive searches. An empty value means no paths are excluded." - " Use forward slashes for directories on all platforms (Windows, too). " + " excluded on recursive searches. An empty value means no paths are excluded." + " Use forward slashes for directories on all platforms (Windows, too)." " Exclusions are calculated first, inclusions later." ), show_default=True, ) +@click.option( + "--extend-exclude", + type=str, + callback=validate_regex, + help=( + "Like --exclude, but adds additional files and directories on top of the" + " excluded ones. (Useful if you simply want to add to the default)" + ), +) @click.option( "--force-exclude", type=str, + callback=validate_regex, help=( "Like --exclude, but files and directories matching this regex will be " "excluded even when they are passed explicitly as arguments." ), ) +@click.option( + "--stdin-filename", + type=str, + help=( + "The name of the file when passing it through stdin. Useful to make " + "sure Black will respect --force-exclude option on some " + "editors that rely on using stdin." + ), +) @click.option( "-q", "--quiet", @@ -472,7 +535,7 @@ def target_version_option_callback( is_flag=True, help=( "Also emit messages to stderr about files that were not changed or were ignored" - " due to --exclude=." + " due to exclusion patterns." ), ) @click.version_option(version=__version__) @@ -510,12 +573,15 @@ def main( fast: bool, pyi: bool, skip_string_normalization: bool, + skip_magic_trailing_comma: bool, experimental_string_processing: bool, quiet: bool, verbose: bool, - include: str, - exclude: str, - force_exclude: Optional[str], + include: Pattern, + exclude: Pattern, + extend_exclude: Optional[Pattern], + force_exclude: Optional[Pattern], + stdin_filename: Optional[str], src: Tuple[str, ...], config: Optional[str], ) -> None: @@ -531,6 +597,7 @@ def main( line_length=line_length, is_pyi=pyi, string_normalization=not skip_string_normalization, + magic_trailing_comma=not skip_magic_trailing_comma, experimental_string_processing=experimental_string_processing, ) if config and verbose: @@ -546,8 +613,10 @@ def main( verbose=verbose, include=include, exclude=exclude, + extend_exclude=extend_exclude, force_exclude=force_exclude, report=report, + stdin_filename=stdin_filename, ) path_empty( @@ -583,29 +652,14 @@ def get_sources( src: Tuple[str, ...], quiet: bool, verbose: bool, - include: str, - exclude: str, - force_exclude: Optional[str], + include: Pattern[str], + exclude: Pattern[str], + extend_exclude: Optional[Pattern[str]], + force_exclude: Optional[Pattern[str]], report: "Report", + stdin_filename: Optional[str], ) -> Set[Path]: """Compute the set of files to be formatted.""" - try: - include_regex = re_compile_maybe_verbose(include) - except re.error: - err(f"Invalid regular expression for include given: {include!r}") - ctx.exit(2) - try: - exclude_regex = re_compile_maybe_verbose(exclude) - except re.error: - err(f"Invalid regular expression for exclude given: {exclude!r}") - ctx.exit(2) - try: - force_exclude_regex = ( - re_compile_maybe_verbose(force_exclude) if force_exclude else None - ) - except re.error: - err(f"Invalid regular expression for force_exclude given: {force_exclude!r}") - ctx.exit(2) root = find_project_root(src) sources: Set[Path] = set() @@ -613,36 +667,46 @@ def get_sources( gitignore = get_gitignore(root) for s in src: - p = Path(s) - if p.is_dir(): - sources.update( - gen_python_files( - p.iterdir(), - root, - include_regex, - exclude_regex, - force_exclude_regex, - report, - gitignore, - ) - ) - elif s == "-": - sources.add(p) - elif p.is_file(): + if s == "-" and stdin_filename: + p = Path(stdin_filename) + is_stdin = True + else: + p = Path(s) + is_stdin = False + + if is_stdin or p.is_file(): normalized_path = normalize_path_maybe_ignore(p, root, report) if normalized_path is None: continue normalized_path = "/" + normalized_path # Hard-exclude any files that matches the `--force-exclude` regex. - if force_exclude_regex: - force_exclude_match = force_exclude_regex.search(normalized_path) + if force_exclude: + force_exclude_match = force_exclude.search(normalized_path) else: force_exclude_match = None if force_exclude_match and force_exclude_match.group(0): report.path_ignored(p, "matches the --force-exclude regular expression") continue + if is_stdin: + p = Path(f"{STDIN_PLACEHOLDER}{str(p)}") + + sources.add(p) + elif p.is_dir(): + sources.update( + gen_python_files( + p.iterdir(), + root, + include, + exclude, + extend_exclude, + force_exclude, + report, + gitignore, + ) + ) + elif s == "-": sources.add(p) else: err(f"invalid path: {s}") @@ -670,7 +734,18 @@ def reformat_one( """ try: changed = Changed.NO - if not src.is_file() and str(src) == "-": + + if str(src) == "-": + is_stdin = True + elif str(src).startswith(STDIN_PLACEHOLDER): + is_stdin = True + # Use the original name again in case we want to print something + # to the user + src = Path(str(src)[len(STDIN_PLACEHOLDER) :]) + else: + is_stdin = False + + if is_stdin: if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode): changed = Changed.YES else: @@ -678,7 +753,8 @@ def reformat_one( if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF): cache = read_cache(mode) res_src = src.resolve() - if res_src in cache and cache[res_src] == get_cache_info(res_src): + res_src_s = str(res_src) + if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src): changed = Changed.CACHED if changed is not Changed.CACHED and format_file_in_place( src, fast=fast, write_back=write_back, mode=mode @@ -704,13 +780,13 @@ def reformat_many( worker_count = os.cpu_count() if sys.platform == "win32": # Work around https://bugs.python.org/issue26903 - worker_count = min(worker_count, 61) + worker_count = min(worker_count, 60) try: executor = ProcessPoolExecutor(max_workers=worker_count) except (ImportError, OSError): # we arrive here if the underlying system does not support multi-processing # like in AWS Lambda or Termux, in which case we gracefully fallback to - # a ThreadPollExecutor with just a single worker (more workers would not do us + # a ThreadPoolExecutor with just a single worker (more workers would not do us # any good due to the Global Interpreter Lock) executor = ThreadPoolExecutor(max_workers=1) @@ -773,7 +849,7 @@ async def schedule_formatting( ): src for src in sorted(sources) } - pending: Iterable["asyncio.Future[bool]"] = tasks.keys() + pending = tasks.keys() try: loop.add_signal_handler(signal.SIGINT, cancel, pending) loop.add_signal_handler(signal.SIGTERM, cancel, pending) @@ -836,7 +912,7 @@ def format_file_in_place( dst_name = f"{src}\t{now} +0000" diff_contents = diff(src_contents, dst_contents, src_name, dst_name) - if write_back == write_back.COLOR_DIFF: + if write_back == WriteBack.COLOR_DIFF: diff_contents = color_diff(diff_contents) with lock or nullcontext(): @@ -859,9 +935,9 @@ def color_diff(contents: str) -> str: for i, line in enumerate(lines): if line.startswith("+++") or line.startswith("---"): line = "\033[1;37m" + line + "\033[0m" # bold white, reset - if line.startswith("@@"): + elif line.startswith("@@"): line = "\033[36m" + line + "\033[0m" # cyan, reset - if line.startswith("+"): + elif line.startswith("+"): line = "\033[32m" + line + "\033[0m" # green, reset elif line.startswith("-"): line = "\033[31m" + line + "\033[0m" # red, reset @@ -871,30 +947,22 @@ def color_diff(contents: str) -> str: def wrap_stream_for_windows( f: io.TextIOWrapper, -) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]: +) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32"]: """ - Wrap the stream in colorama's wrap_stream so colors are shown on Windows. + Wrap stream with colorama's wrap_stream so colors are shown on Windows. - If `colorama` is not found, then no change is made. If `colorama` does - exist, then it handles the logic to determine whether or not to change - things. + If `colorama` is unavailable, the original stream is returned unmodified. + Otherwise, the `wrap_stream()` function determines whether the stream needs + to be wrapped for a Windows environment and will accordingly either return + an `AnsiToWin32` wrapper or the original stream. """ try: - from colorama import initialise - - # We set `strip=False` so that we can don't have to modify - # test_express_diff_with_color. - f = initialise.wrap_stream( - f, convert=None, strip=False, autoreset=False, wrap=True - ) - - # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object - # which does not have a `detach()` method. So we fake one. - f.detach = lambda *args, **kwargs: None # type: ignore + from colorama.initialise import wrap_stream except ImportError: - pass - - return f + return f + else: + # Set `strip=False` to avoid needing to modify test_express_diff_with_color. + return wrap_stream(f, convert=None, strip=False, autoreset=False, wrap=True) def format_stdin_to_stdout( @@ -950,7 +1018,17 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo if not fast: assert_equivalent(src_contents, dst_contents) - assert_stable(src_contents, dst_contents, mode=mode) + + # Forced second pass to work around optional trailing commas (becoming + # forced trailing commas on pass 2) interacting differently with optional + # parentheses. Admittedly ugly. + dst_contents_pass2 = format_str(dst_contents, mode=mode) + if dst_contents != dst_contents_pass2: + dst_contents = dst_contents_pass2 + assert_equivalent(src_contents, dst_contents, pass_num=2) + assert_stable(src_contents, dst_contents, mode=mode) + # Note: no need to explicitly call `assert_stable` if `dst_contents` was + # the same as `dst_contents_pass2`. return dst_contents @@ -961,7 +1039,7 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent: allowed. Example: >>> import black - >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode())) + >>> print(black.format_str("def f(arg:str='')->None:...", mode=black.Mode())) def f(arg: str = "") -> None: ... @@ -993,13 +1071,12 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent: versions = detect_target_versions(src_node) normalize_fmt_off(src_node) lines = LineGenerator( + mode=mode, remove_u_prefix="unicode_literals" in future_imports or supports_feature(versions, Feature.UNICODE_LITERALS), - is_pyi=mode.is_pyi, - normalize_strings=mode.string_normalization, ) elt = EmptyLineTracker(is_pyi=mode.is_pyi) - empty_line = Line() + empty_line = Line(mode=mode) after = 0 split_line_features = { feature @@ -1435,13 +1512,15 @@ class BracketTracker: class Line: """Holds leaves and comments. Can be printed with `str(line)`.""" + mode: Mode depth: int = 0 leaves: List[Leaf] = field(default_factory=list) # keys ordered like `leaves` comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict) bracket_tracker: BracketTracker = field(default_factory=BracketTracker) inside_brackets: bool = False - should_explode: bool = False + should_split_rhs: bool = False + magic_trailing_comma: Optional[Leaf] = None def append(self, leaf: Leaf, preformatted: bool = False) -> None: """Add a new `leaf` to the end of the line. @@ -1467,8 +1546,11 @@ class Line: ) if self.inside_brackets or not preformatted: self.bracket_tracker.mark(leaf) - if self.maybe_should_explode(leaf): - self.should_explode = True + if self.mode.magic_trailing_comma: + if self.has_magic_trailing_comma(leaf): + self.magic_trailing_comma = leaf + elif self.has_magic_trailing_comma(leaf, ensure_removable=True): + self.remove_trailing_comma() if not self.append_comment(leaf): self.leaves.append(leaf) @@ -1644,10 +1726,14 @@ class Line: def contains_multiline_strings(self) -> bool: return any(is_multiline_string(leaf) for leaf in self.leaves) - def maybe_should_explode(self, closing: Leaf) -> bool: - """Return True if this line should explode (always be split), that is when: - - there's a trailing comma here; and - - it's not a one-tuple. + def has_magic_trailing_comma( + self, closing: Leaf, ensure_removable: bool = False + ) -> bool: + """Return True if we have a magic trailing comma, that is when: + - there's a trailing comma here + - it's not a one-tuple + Additionally, if ensure_removable: + - it's not from square bracket indexing """ if not ( closing.type in CLOSING_BRACKETS @@ -1656,9 +1742,15 @@ class Line: ): return False - if closing.type in {token.RBRACE, token.RSQB}: + if closing.type == token.RBRACE: return True + if closing.type == token.RSQB: + if not ensure_removable: + return True + comma = self.leaves[-1] + return bool(comma.parent and comma.parent.type == syms.listmaker) + if self.is_import: return True @@ -1736,9 +1828,11 @@ class Line: def clone(self) -> "Line": return Line( + mode=self.mode, depth=self.depth, inside_brackets=self.inside_brackets, - should_explode=self.should_explode, + should_split_rhs=self.should_split_rhs, + magic_trailing_comma=self.magic_trailing_comma, ) def __str__(self) -> str: @@ -1894,10 +1988,9 @@ class LineGenerator(Visitor[Line]): in ways that will no longer stringify to valid Python code on the tree. """ - is_pyi: bool = False - normalize_strings: bool = True - current_line: Line = field(default_factory=Line) + mode: Mode remove_u_prefix: bool = False + current_line: Line = field(init=False) def line(self, indent: int = 0) -> Iterator[Line]: """Generate a line. @@ -1912,7 +2005,7 @@ class LineGenerator(Visitor[Line]): return # Line is empty, don't emit. Creating a new one unnecessary. complete_line = self.current_line - self.current_line = Line(depth=complete_line.depth + indent) + self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent) yield complete_line def visit_default(self, node: LN) -> Iterator[Line]: @@ -1936,7 +2029,7 @@ class LineGenerator(Visitor[Line]): yield from self.line() normalize_prefix(node, inside_brackets=any_open_brackets) - if self.normalize_strings and node.type == token.STRING: + if self.mode.string_normalization and node.type == token.STRING: normalize_string_prefix(node, remove_u_prefix=self.remove_u_prefix) normalize_string_quotes(node) if node.type == token.NUMBER: @@ -1988,16 +2081,18 @@ class LineGenerator(Visitor[Line]): def visit_suite(self, node: Node) -> Iterator[Line]: """Visit a suite.""" - if self.is_pyi and is_stub_suite(node): + if self.mode.is_pyi and is_stub_suite(node): yield from self.visit(node.children[2]) else: yield from self.visit_default(node) def visit_simple_stmt(self, node: Node) -> Iterator[Line]: """Visit a statement without nested statements.""" + if first_child_is_arith(node): + wrap_in_parentheses(node, node.children[0], visible=False) is_suite_like = node.parent and node.parent.type in STATEMENT if is_suite_like: - if self.is_pyi and is_stub_body(node): + if self.mode.is_pyi and is_stub_body(node): yield from self.visit_default(node) else: yield from self.line(+1) @@ -2005,7 +2100,11 @@ class LineGenerator(Visitor[Line]): yield from self.line(-1) else: - if not self.is_pyi or not node.parent or not is_stub_suite(node.parent): + if ( + not self.mode.is_pyi + or not node.parent + or not is_stub_suite(node.parent) + ): yield from self.line() yield from self.visit_default(node) @@ -2066,21 +2165,48 @@ class LineGenerator(Visitor[Line]): # We're ignoring docstrings with backslash newline escapes because changing # indentation of those changes the AST representation of the code. prefix = get_string_prefix(leaf.value) - lead_len = len(prefix) + 3 - tail_len = -3 - indent = " " * 4 * self.current_line.depth - docstring = fix_docstring(leaf.value[lead_len:tail_len], indent) + docstring = leaf.value[len(prefix) :] # Remove the prefix + quote_char = docstring[0] + # A natural way to remove the outer quotes is to do: + # docstring = docstring.strip(quote_char) + # but that breaks on """""x""" (which is '""x'). + # So we actually need to remove the first character and the next two + # characters but only if they are the same as the first. + quote_len = 1 if docstring[1] != quote_char else 3 + docstring = docstring[quote_len:-quote_len] + + if is_multiline_string(leaf): + indent = " " * 4 * self.current_line.depth + docstring = fix_docstring(docstring, indent) + else: + docstring = docstring.strip() + if docstring: - if leaf.value[lead_len - 1] == docstring[0]: + # Add some padding if the docstring starts / ends with a quote mark. + if docstring[0] == quote_char: docstring = " " + docstring - if leaf.value[tail_len + 1] == docstring[-1]: - docstring = docstring + " " - leaf.value = leaf.value[0:lead_len] + docstring + leaf.value[tail_len:] + if docstring[-1] == quote_char: + docstring += " " + if docstring[-1] == "\\": + backslash_count = len(docstring) - len(docstring.rstrip("\\")) + if backslash_count % 2: + # Odd number of tailing backslashes, add some padding to + # avoid escaping the closing string quote. + docstring += " " + else: + # Add some padding if the docstring is empty. + docstring = " " + + # We could enforce triple quotes at this point. + quote = quote_char * quote_len + leaf.value = prefix + quote + docstring + quote yield from self.visit_default(leaf) def __post_init__(self) -> None: """You are in a twisty little maze of passages.""" + self.current_line = Line(mode=self.mode) + v = self.visit_stmt Ø: Set[str] = set() self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","}) @@ -2523,6 +2649,8 @@ def is_split_before_delimiter(leaf: Leaf, previous: Optional[Leaf] = None) -> Pr FMT_OFF = {"# fmt: off", "# fmt:off", "# yapf: disable"} +FMT_SKIP = {"# fmt: skip", "# fmt:skip"} +FMT_PASS = {*FMT_OFF, *FMT_SKIP} FMT_ON = {"# fmt: on", "# fmt:on", "# yapf: enable"} @@ -2577,7 +2705,7 @@ def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]: consumed = 0 nlines = 0 ignored_lines = 0 - for index, line in enumerate(prefix.split("\n")): + for index, line in enumerate(re.split("\r?\n", prefix)): consumed += len(line) + 1 # adding the length of the split '\n' line = line.lstrip() if not line: @@ -2618,6 +2746,13 @@ def make_comment(content: str) -> str: if content[0] == "#": content = content[1:] + NON_BREAKING_SPACE = " " + if ( + content + and content[0] == NON_BREAKING_SPACE + and not content.lstrip().startswith("type:") + ): + content = " " + content[1:] # Replace NBSP by a simple space if content and content[0] not in " !:#'%": content = " " + content return "#" + content @@ -2650,7 +2785,8 @@ def transform_line( transformers: List[Transformer] if ( not line.contains_uncollapsable_type_comments() - and not line.should_explode + and not line.should_split_rhs + and not line.magic_trailing_comma and ( is_line_short_enough(line, line_length=mode.line_length, line_str=line_str) or line.contains_unsplittable_type_ignore() @@ -3246,7 +3382,8 @@ class StringParenStripper(StringTransformer): Requirements: The line contains a string which is surrounded by parentheses and: - - The target string is NOT the only argument to a function call). + - The target string is NOT the only argument to a function call. + - The target string is NOT a "pointless" string. - If the target string contains a PERCENT, the brackets are not preceeded or followed by an operator with higher precedence than PERCENT. @@ -3270,6 +3407,14 @@ class StringParenStripper(StringTransformer): if leaf.type != token.STRING: continue + # If this is a "pointless" string... + if ( + leaf.parent + and leaf.parent.parent + and leaf.parent.parent.type == syms.simple_stmt + ): + continue + # Should be preceded by a non-empty LPAR... if ( not is_valid_index(idx - 1) @@ -3598,7 +3743,8 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): MIN_SUBSTR_SIZE characters. The string will ONLY be split on spaces (i.e. each new substring should - start with a space). + start with a space). Note that the string will NOT be split on a space + which is escaped with a backslash. If the string is an f-string, it will NOT be split in the middle of an f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x @@ -3618,13 +3764,14 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): MIN_SUBSTR_SIZE = 6 # Matches an "f-expression" (e.g. {var}) that might be found in an f-string. RE_FEXPR = r""" - (? TMatchResult: @@ -3938,11 +4085,23 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): section of this classes' docstring would be be met by returning @i. """ is_space = string[i] == " " + + is_not_escaped = True + j = i - 1 + while is_valid_index(j) and string[j] == "\\": + is_not_escaped = not is_not_escaped + j -= 1 + is_big_enough = ( len(string[i:]) >= self.MIN_SUBSTR_SIZE and len(string[:i]) >= self.MIN_SUBSTR_SIZE ) - return is_space and is_big_enough and not breaks_fstring_expression(i) + return ( + is_space + and is_not_escaped + and is_big_enough + and not breaks_fstring_expression(i) + ) # First, we check all indices BELOW @max_break_idx. break_idx = max_break_idx @@ -4298,9 +4457,11 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): # `StringSplitter` will break it down further if necessary. string_value = LL[string_idx].value string_line = Line( + mode=line.mode, depth=line.depth + 1, inside_brackets=True, - should_explode=line.should_explode, + should_split_rhs=line.should_split_rhs, + magic_trailing_comma=line.magic_trailing_comma, ) string_leaf = Leaf(token.STRING, string_value) insert_str_child(string_leaf) @@ -4891,7 +5052,7 @@ def bracket_split_build_line( If `is_body` is True, the result line is one-indented inside brackets and as such has its first leaf's prefix normalized and a trailing comma added when expected. """ - result = Line(depth=original.depth) + result = Line(mode=original.mode, depth=original.depth) if is_body: result.inside_brackets = True result.depth += 1 @@ -4921,8 +5082,8 @@ def bracket_split_build_line( result.append(leaf, preformatted=True) for comment_after in original.comments_after(leaf): result.append(comment_after, preformatted=True) - if is_body and should_split_body_explode(result, opening_bracket): - result.should_explode = True + if is_body and should_split_line(result, opening_bracket): + result.should_split_rhs = True return result @@ -4963,7 +5124,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ if bt.delimiter_count_with_priority(delimiter_priority) == 1: raise CannotSplit("Splitting a single attribute from its owner looks wrong") - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) lowest_depth = sys.maxsize trailing_comma_safe = True @@ -4975,7 +5138,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ except ValueError: yield current_line - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) current_line.append(leaf) for leaf in line.leaves: @@ -4999,7 +5164,9 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ if leaf_priority == delimiter_priority: yield current_line - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) if current_line: if ( trailing_comma_safe @@ -5020,7 +5187,9 @@ def standalone_comment_split( if not line.contains_standalone_comments(0): raise CannotSplit("Line does not have any standalone comments") - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) def append_to_line(leaf: Leaf) -> Iterator[Line]: """Append `leaf` to current line or to new line if appending impossible.""" @@ -5030,7 +5199,9 @@ def standalone_comment_split( except ValueError: yield current_line - current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets) + current_line = Line( + line.mode, depth=line.depth, inside_brackets=line.inside_brackets + ) current_line.append(leaf) for leaf in line.leaves: @@ -5186,31 +5357,47 @@ def normalize_numeric_literal(leaf: Leaf) -> None: # Leave octal and binary literals alone. pass elif text.startswith("0x"): - # Change hex literals to upper case. - before, after = text[:2], text[2:] - text = f"{before}{after.upper()}" + text = format_hex(text) elif "e" in text: - before, after = text.split("e") - sign = "" - if after.startswith("-"): - after = after[1:] - sign = "-" - elif after.startswith("+"): - after = after[1:] - before = format_float_or_int_string(before) - text = f"{before}e{sign}{after}" + text = format_scientific_notation(text) elif text.endswith(("j", "l")): - number = text[:-1] - suffix = text[-1] - # Capitalize in "2L" because "l" looks too similar to "1". - if suffix == "l": - suffix = "L" - text = f"{format_float_or_int_string(number)}{suffix}" + text = format_long_or_complex_number(text) else: text = format_float_or_int_string(text) leaf.value = text +def format_hex(text: str) -> str: + """ + Formats a hexadecimal string like "0x12B3" + """ + before, after = text[:2], text[2:] + return f"{before}{after.upper()}" + + +def format_scientific_notation(text: str) -> str: + """Formats a numeric string utilizing scentific notation""" + before, after = text.split("e") + sign = "" + if after.startswith("-"): + after = after[1:] + sign = "-" + elif after.startswith("+"): + after = after[1:] + before = format_float_or_int_string(before) + return f"{before}e{sign}{after}" + + +def format_long_or_complex_number(text: str) -> str: + """Formats a long or complex string like `10L` or `10j`""" + number = text[:-1] + suffix = text[-1] + # Capitalize in "2L" because "l" looks too similar to "1". + if suffix == "l": + suffix = "L" + return f"{format_float_or_int_string(number)}{suffix}" + + def format_float_or_int_string(text: str) -> str: """Formats a float string like "1.0".""" if "." not in text: @@ -5249,10 +5436,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: check_lpar = True if check_lpar: - if is_walrus_assignment(child): - pass - - elif child.type == syms.atom: + if child.type == syms.atom: if maybe_make_parens_invisible_in_atom(child, parent=node): wrap_in_parentheses(node, child, visible=False) elif is_one_tuple(child): @@ -5291,58 +5475,80 @@ def convert_one_fmt_off_pair(node: Node) -> bool: for leaf in node.leaves(): previous_consumed = 0 for comment in list_comments(leaf.prefix, is_endmarker=False): - if comment.value in FMT_OFF: - # We only want standalone comments. If there's no previous leaf or - # the previous leaf is indentation, it's a standalone comment in - # disguise. - if comment.type != STANDALONE_COMMENT: - prev = preceding_leaf(leaf) - if prev and prev.type not in WHITESPACE: + if comment.value not in FMT_PASS: + previous_consumed = comment.consumed + continue + # We only want standalone comments. If there's no previous leaf or + # the previous leaf is indentation, it's a standalone comment in + # disguise. + if comment.value in FMT_PASS and comment.type != STANDALONE_COMMENT: + prev = preceding_leaf(leaf) + if prev: + if comment.value in FMT_OFF and prev.type not in WHITESPACE: + continue + if comment.value in FMT_SKIP and prev.type in WHITESPACE: continue - ignored_nodes = list(generate_ignored_nodes(leaf)) - if not ignored_nodes: - continue - - first = ignored_nodes[0] # Can be a container node with the `leaf`. - parent = first.parent - prefix = first.prefix - first.prefix = prefix[comment.consumed :] - hidden_value = ( - comment.value + "\n" + "".join(str(n) for n in ignored_nodes) - ) - if hidden_value.endswith("\n"): - # That happens when one of the `ignored_nodes` ended with a NEWLINE - # leaf (possibly followed by a DEDENT). - hidden_value = hidden_value[:-1] - first_idx: Optional[int] = None - for ignored in ignored_nodes: - index = ignored.remove() - if first_idx is None: - first_idx = index - assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)" - assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)" - parent.insert_child( - first_idx, - Leaf( - STANDALONE_COMMENT, - hidden_value, - prefix=prefix[:previous_consumed] + "\n" * comment.newlines, - ), - ) - return True + ignored_nodes = list(generate_ignored_nodes(leaf, comment)) + if not ignored_nodes: + continue - previous_consumed = comment.consumed + first = ignored_nodes[0] # Can be a container node with the `leaf`. + parent = first.parent + prefix = first.prefix + first.prefix = prefix[comment.consumed :] + hidden_value = "".join(str(n) for n in ignored_nodes) + if comment.value in FMT_OFF: + hidden_value = comment.value + "\n" + hidden_value + if comment.value in FMT_SKIP: + hidden_value += " " + comment.value + if hidden_value.endswith("\n"): + # That happens when one of the `ignored_nodes` ended with a NEWLINE + # leaf (possibly followed by a DEDENT). + hidden_value = hidden_value[:-1] + first_idx: Optional[int] = None + for ignored in ignored_nodes: + index = ignored.remove() + if first_idx is None: + first_idx = index + assert parent is not None, "INTERNAL ERROR: fmt: on/off handling (1)" + assert first_idx is not None, "INTERNAL ERROR: fmt: on/off handling (2)" + parent.insert_child( + first_idx, + Leaf( + STANDALONE_COMMENT, + hidden_value, + prefix=prefix[:previous_consumed] + "\n" * comment.newlines, + ), + ) + return True return False -def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]: +def generate_ignored_nodes(leaf: Leaf, comment: ProtoComment) -> Iterator[LN]: """Starting from the container of `leaf`, generate all leaves until `# fmt: on`. + If comment is skip, returns leaf only. Stops at the end of the block. """ container: Optional[LN] = container_of(leaf) + if comment.value in FMT_SKIP: + prev_sibling = leaf.prev_sibling + if comment.value in leaf.prefix and prev_sibling is not None: + leaf.prefix = leaf.prefix.replace(comment.value, "") + siblings = [prev_sibling] + while ( + "\n" not in prev_sibling.prefix + and prev_sibling.prev_sibling is not None + ): + prev_sibling = prev_sibling.prev_sibling + siblings.insert(0, prev_sibling) + for sibling in siblings: + yield sibling + elif leaf.parent is not None: + yield leaf.parent + return while container is not None and container.type != token.ENDMARKER: if is_fmt_on(container): return @@ -5402,6 +5608,7 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool: Returns whether the node should itself be wrapped in invisible parentheses. """ + if ( node.type != syms.atom or is_empty_tuple(node) @@ -5411,6 +5618,18 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool: ): return False + if is_walrus_assignment(node): + if parent.type in [ + syms.annassign, + syms.expr_stmt, + syms.assert_stmt, + syms.return_stmt, + # these ones aren't useful to end users, but they do please fuzzers + syms.for_stmt, + syms.del_stmt, + ]: + return False + first = node.children[0] last = node.children[-1] if first.type == token.LPAR and last.type == token.RPAR: @@ -5472,6 +5691,17 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]: return wrapped +def first_child_is_arith(node: Node) -> bool: + """Whether first child is an arithmetic or a binary arithmetic expression""" + expr_types = { + syms.arith_expr, + syms.shift_expr, + syms.xor_expr, + syms.and_expr, + } + return bool(node.children and node.children[0].type in expr_types) + + def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None: """Wrap `child` in parentheses. @@ -5673,7 +5903,7 @@ def ensure_visible(leaf: Leaf) -> None: leaf.value = ")" -def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool: +def should_split_line(line: Line, opening_bracket: Leaf) -> bool: """Should `line` be immediately split with `delimiter_split()` after RHS?""" if not (opening_bracket.parent and opening_bracket.value in "[{("): @@ -5694,7 +5924,7 @@ def should_split_body_explode(line: Line, opening_bracket: Leaf) -> bool: return False return max_priority == COMMA_PRIORITY and ( - trailing_comma + (line.mode.magic_trailing_comma and trailing_comma) # always explode imports or opening_bracket.parent.type in {syms.atom, syms.import_from} ) @@ -5809,7 +6039,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf """ omit: Set[LeafID] = set() - if not line.should_explode: + if not line.magic_trailing_comma: yield omit length = 4 * line.depth @@ -5831,8 +6061,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf elif leaf.type in CLOSING_BRACKETS: prev = line.leaves[index - 1] if index > 0 else None if ( - line.should_explode - and prev + prev and prev.type == token.COMMA and not is_one_tuple_between( leaf.opening_bracket, leaf, line.leaves @@ -5859,8 +6088,7 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf yield omit if ( - line.should_explode - and prev + prev and prev.type == token.COMMA and not is_one_tuple_between(leaf.opening_bracket, leaf, line.leaves) ): @@ -5925,7 +6153,7 @@ def get_future_imports(node: Node) -> Set[str]: @lru_cache() def get_gitignore(root: Path) -> PathSpec: - """ Return a PathSpec matching gitignore content if present.""" + """Return a PathSpec matching gitignore content if present.""" gitignore = root / ".gitignore" lines: List[str] = [] if gitignore.is_file(): @@ -5958,17 +6186,27 @@ def normalize_path_maybe_ignore( return normalized_path +def path_is_excluded( + normalized_path: str, + pattern: Optional[Pattern[str]], +) -> bool: + match = pattern.search(normalized_path) if pattern else None + return bool(match and match.group(0)) + + def gen_python_files( paths: Iterable[Path], root: Path, include: Optional[Pattern[str]], exclude: Pattern[str], + extend_exclude: Optional[Pattern[str]], force_exclude: Optional[Pattern[str]], report: "Report", gitignore: PathSpec, ) -> Iterator[Path]: """Generate all files under `path` whose paths are not excluded by the - `exclude_regex` or `force_exclude` regexes, but are included by the `include` regex. + `exclude_regex`, `extend_exclude`, or `force_exclude` regexes, + but are included by the `include` regex. Symbolic links pointing outside of the `root` directory are ignored. @@ -5985,20 +6223,22 @@ def gen_python_files( report.path_ignored(child, "matches the .gitignore file content") continue - # Then ignore with `--exclude` and `--force-exclude` options. + # Then ignore with `--exclude` `--extend-exclude` and `--force-exclude` options. normalized_path = "/" + normalized_path if child.is_dir(): normalized_path += "/" - exclude_match = exclude.search(normalized_path) if exclude else None - if exclude_match and exclude_match.group(0): + if path_is_excluded(normalized_path, exclude): report.path_ignored(child, "matches the --exclude regular expression") continue - force_exclude_match = ( - force_exclude.search(normalized_path) if force_exclude else None - ) - if force_exclude_match and force_exclude_match.group(0): + if path_is_excluded(normalized_path, extend_exclude): + report.path_ignored( + child, "matches the --extend-exclude regular expression" + ) + continue + + if path_is_excluded(normalized_path, force_exclude): report.path_ignored(child, "matches the --force-exclude regular expression") continue @@ -6008,6 +6248,7 @@ def gen_python_files( root, include, exclude, + extend_exclude, force_exclude, report, gitignore, @@ -6020,7 +6261,7 @@ def gen_python_files( @lru_cache() -def find_project_root(srcs: Iterable[str]) -> Path: +def find_project_root(srcs: Tuple[str, ...]) -> Path: """Return a directory containing .git, .hg, or pyproject.toml. That directory will be a common parent of all files and directories @@ -6058,6 +6299,22 @@ def find_project_root(srcs: Iterable[str]) -> Path: return directory +@lru_cache() +def find_user_pyproject_toml() -> Path: + r"""Return the path to the top-level user configuration for black. + + This looks for ~\.black on Windows and ~/.config/black on Linux and other + Unix systems. + """ + if sys.platform == "win32": + # Windows + user_config_path = Path.home() / ".black" + else: + config_root = os.environ.get("XDG_CONFIG_HOME", "~/.config") + user_config_path = Path(config_root).expanduser() / "black" + return user_config_path.resolve() + + @dataclass class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" @@ -6159,7 +6416,12 @@ def parse_ast(src: str) -> Union[ast.AST, ast3.AST, ast27.AST]: return ast3.parse(src, filename, feature_version=feature_version) except SyntaxError: continue - + if ast27.__name__ == "ast": + raise SyntaxError( + "The requested source code has invalid Python 3 syntax.\n" + "If you are trying to format Python 2 files please reinstall Black" + " with the 'python2' extra: `python3 -m pip install black[python2]`." + ) return ast27.parse(src) @@ -6225,12 +6487,22 @@ def _stringify_ast( # Constant strings may be indented across newlines, if they are # docstrings; fold spaces after newlines when comparing. Similarly, # trailing and leading space may be removed. + # Note that when formatting Python 2 code, at least with Windows + # line-endings, docstrings can end up here as bytes instead of + # str so make sure that we handle both cases. if ( isinstance(node, ast.Constant) and field == "value" - and isinstance(value, str) + and isinstance(value, (str, bytes)) ): - normalized = re.sub(r" *\n[ \t]*", "\n", value).strip() + lineend = "\n" if isinstance(value, str) else b"\n" + # To normalize, we strip any leading and trailing space from + # each line... + stripped = [line.strip() for line in value.splitlines()] + normalized = lineend.join(stripped) # type: ignore[attr-defined] + # ...and remove any blank lines at the beginning and end of + # the whole string + normalized = normalized.strip() else: normalized = value yield f"{' ' * (depth+2)}{normalized!r}, # {value.__class__.__name__}" @@ -6238,7 +6510,7 @@ def _stringify_ast( yield f"{' ' * depth}) # /{node.__class__.__name__}" -def assert_equivalent(src: str, dst: str) -> None: +def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None: """Raise AssertionError if `src` and `dst` aren't equivalent.""" try: src_ast = parse_ast(src) @@ -6253,9 +6525,9 @@ def assert_equivalent(src: str, dst: str) -> None: except Exception as exc: log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst) raise AssertionError( - f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug" - " on https://github.com/psf/black/issues. This invalid output might be" - f" helpful: {log}" + f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. " + "Please report a bug on https://github.com/psf/black/issues. " + f"This invalid output might be helpful: {log}" ) from None src_ast_str = "\n".join(_stringify_ast(src_ast)) @@ -6264,8 +6536,8 @@ def assert_equivalent(src: str, dst: str) -> None: log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst")) raise AssertionError( "INTERNAL ERROR: Black produced code that is not equivalent to the" - " source. Please report a bug on https://github.com/psf/black/issues. " - f" This diff might be helpful: {log}" + f" source on pass {pass_num}. Please report a bug on " + f"https://github.com/psf/black/issues. This diff might be helpful: {log}" ) from None @@ -6286,14 +6558,14 @@ def assert_stable(src: str, dst: str, mode: Mode) -> None: @mypyc_attr(patchable=True) -def dump_to_file(*output: str) -> str: +def dump_to_file(*output: str, ensure_final_newline: bool = True) -> str: """Dump `output` to a temporary file. Return path to the file.""" with tempfile.NamedTemporaryFile( mode="w", prefix="blk_", suffix=".log", delete=False, encoding="utf8" ) as f: for lines in output: f.write(lines) - if lines and lines[-1] != "\n": + if ensure_final_newline and lines and lines[-1] != "\n": f.write("\n") return f.name @@ -6311,11 +6583,20 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: """Return a unified diff string between strings `a` and `b`.""" import difflib - a_lines = [line + "\n" for line in a.splitlines()] - b_lines = [line + "\n" for line in b.splitlines()] - return "".join( - difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5) - ) + a_lines = [line for line in a.splitlines(keepends=True)] + b_lines = [line for line in b.splitlines(keepends=True)] + diff_lines = [] + for line in difflib.unified_diff( + a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5 + ): + # Work around https://bugs.python.org/issue2142 + # See https://www.gnu.org/software/diffutils/manual/html_node/Incomplete-Lines.html + if line[-1] == "\n": + diff_lines.append(line) + else: + diff_lines.append(line + "\n") + diff_lines.append("\\ No newline at end of file\n") + return "".join(diff_lines) def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None: @@ -6492,7 +6773,7 @@ def can_omit_invisible_parens( penultimate = line.leaves[-2] last = line.leaves[-1] - if line.should_explode: + if line.magic_trailing_comma: try: penultimate, last = last_two_except(line.leaves, omit=omit_on_explode) except LookupError: @@ -6519,7 +6800,7 @@ def can_omit_invisible_parens( # unnecessary. return True - if line.should_explode and penultimate.type == token.COMMA: + if line.magic_trailing_comma and penultimate.type == token.COMMA: # The rightmost non-omitted bracket pair is the one we want to explode on. return True @@ -6669,8 +6950,8 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set """ todo, done = set(), set() for src in sources: - src = src.resolve() - if cache.get(src) != get_cache_info(src): + res_src = src.resolve() + if cache.get(str(res_src)) != get_cache_info(res_src): todo.add(src) else: done.add(src) @@ -6682,7 +6963,10 @@ def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None: cache_file = get_cache_file(mode) try: CACHE_DIR.mkdir(parents=True, exist_ok=True) - new_cache = {**cache, **{src.resolve(): get_cache_info(src) for src in sources}} + new_cache = { + **cache, + **{str(src.resolve()): get_cache_info(src) for src in sources}, + } with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f: pickle.dump(new_cache, f, protocol=4) os.replace(f.name, cache_file) @@ -6719,11 +7003,6 @@ def patched_main() -> None: def is_docstring(leaf: Leaf) -> bool: - if not is_multiline_string(leaf): - # For the purposes of docstring re-indentation, we don't need to do anything - # with single-line docstrings. - return False - if prev_siblings_are( leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt] ): @@ -6738,13 +7017,33 @@ def is_docstring(leaf: Leaf) -> bool: return False +def lines_with_leading_tabs_expanded(s: str) -> List[str]: + """ + Splits string into lines and expands only leading tabs (following the normal + Python rules) + """ + lines = [] + for line in s.splitlines(): + # Find the index of the first non-whitespace character after a string of + # whitespace that includes at least one tab + match = re.match(r"\s*\t+\s*(\S)", line) + if match: + first_non_whitespace_idx = match.start(1) + + lines.append( + line[:first_non_whitespace_idx].expandtabs() + + line[first_non_whitespace_idx:] + ) + else: + lines.append(line) + return lines + + def fix_docstring(docstring: str, prefix: str) -> str: # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation if not docstring: return "" - # Convert tabs to spaces (following the normal Python rules) - # and split into a list of lines: - lines = docstring.expandtabs().splitlines() + lines = lines_with_leading_tabs_expanded(docstring) # Determine minimum indentation (first line doesn't count): indent = sys.maxsize for line in lines[1:]: