X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/58fd7885cd7ec3da248a1ce4d82a4e8d3dfd5af9..7f4b27541330e2ff35b1a010c9de0c5f618f9f4c:/black.py diff --git a/black.py b/black.py index f6f687b..3ab4bc7 100644 --- a/black.py +++ b/black.py @@ -1,6 +1,8 @@ import ast import asyncio -from concurrent.futures import Executor, ProcessPoolExecutor +from abc import ABC, abstractmethod +from collections import defaultdict +from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor from contextlib import contextmanager from datetime import datetime from enum import Enum @@ -32,14 +34,19 @@ from typing import ( Pattern, Sequence, Set, + Sized, Tuple, + Type, TypeVar, Union, cast, + TYPE_CHECKING, ) +from typing_extensions import Final +from mypy_extensions import mypyc_attr from appdirs import user_cache_dir -from attr import dataclass, evolve, Factory +from dataclasses import dataclass, field, replace import click import toml from typed_ast import ast3, ast27 @@ -54,11 +61,16 @@ from blib2to3.pgen2.parse import ParseError from _black_version import version as __version__ +if TYPE_CHECKING: + import colorama # noqa: F401 + DEFAULT_LINE_LENGTH = 88 DEFAULT_EXCLUDES = r"/(\.eggs|\.git|\.hg|\.mypy_cache|\.nox|\.tox|\.venv|\.svn|_build|buck-out|build|dist)/" # noqa: B950 DEFAULT_INCLUDES = r"\.pyi?$" CACHE_DIR = Path(user_cache_dir("black", version=__version__)) +STRING_PREFIX_CHARS: Final = "furbFURB" # All possible string prefix characters. + # types FileContent = str @@ -66,11 +78,13 @@ Encoding = str NewLine = str Depth = int NodeType = int +ParserState = int LeafID = int +StringID = int Priority = int Index = int LN = Union[Leaf, Node] -SplitFunc = Callable[["Line", Collection["Feature"]], Iterator["Line"]] +Transformer = Callable[["Line", Collection["Feature"]], Iterator["Line"]] Timestamp = float FileSize = int CacheInfo = Tuple[Timestamp, FileSize] @@ -86,7 +100,11 @@ class NothingChanged(UserWarning): """Raised when reformatted code is the same as source.""" -class CannotSplit(Exception): +class CannotTransform(Exception): + """Base class for errors raised by Transformers.""" + + +class CannotSplit(CannotTransform): """A readable split that fits the allotted line length is impossible.""" @@ -94,17 +112,51 @@ class InvalidInput(ValueError): """Raised when input source code fails all parse attempts.""" +T = TypeVar("T") +E = TypeVar("E", bound=Exception) + + +class Ok(Generic[T]): + def __init__(self, value: T) -> None: + self._value = value + + def ok(self) -> T: + return self._value + + +class Err(Generic[E]): + def __init__(self, e: E) -> None: + self._e = e + + def err(self) -> E: + return self._e + + +# The 'Result' return type is used to implement an error-handling model heavily +# influenced by that used by the Rust programming language +# (see https://doc.rust-lang.org/book/ch09-00-error-handling.html). +Result = Union[Ok[T], Err[E]] +TResult = Result[T, CannotTransform] # (T)ransform Result +TMatchResult = TResult[Index] + + class WriteBack(Enum): NO = 0 YES = 1 DIFF = 2 CHECK = 3 + COLOR_DIFF = 4 @classmethod - def from_configuration(cls, *, check: bool, diff: bool) -> "WriteBack": + def from_configuration( + cls, *, check: bool, diff: bool, color: bool = False + ) -> "WriteBack": if check and not diff: return cls.CHECK + if diff and color: + return cls.COLOR_DIFF + return cls.DIFF if diff else cls.YES @@ -184,8 +236,8 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { @dataclass -class FileMode: - target_versions: Set[TargetVersion] = Factory(set) +class Mode: + target_versions: Set[TargetVersion] = field(default_factory=set) line_length: int = DEFAULT_LINE_LENGTH string_normalization: bool = True is_pyi: bool = False @@ -207,30 +259,46 @@ class FileMode: return ".".join(parts) +# Legacy name, left for integrations. +FileMode = Mode + + def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool: return all(feature in VERSION_TO_FEATURES[version] for version in target_versions) +def find_pyproject_toml(path_search_start: str) -> Optional[str]: + """Find the absolute filepath to a pyproject.toml if it exists""" + path_project_root = find_project_root(path_search_start) + path_pyproject_toml = path_project_root / "pyproject.toml" + return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None + + +def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: + """Parse a pyproject toml file, pulling out relevant parts for Black + + If parsing fails, will raise a toml.TomlDecodeError + """ + pyproject_toml = toml.load(path_config) + config = pyproject_toml.get("tool", {}).get("black", {}) + return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} + + def read_pyproject_toml( - ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None] + ctx: click.Context, param: click.Parameter, value: Optional[str] ) -> Optional[str]: """Inject Black configuration from "pyproject.toml" into defaults in `ctx`. Returns the path to a successfully found and read configuration file, None otherwise. """ - assert not isinstance(value, (int, bool)), "Invalid parameter type passed" if not value: - root = find_project_root(ctx.params.get("src", ())) - path = root / "pyproject.toml" - if path.is_file(): - value = str(path) - else: + value = find_pyproject_toml(ctx.params.get("src", ())) + if value is None: return None try: - pyproject_toml = toml.load(value) - config = pyproject_toml.get("tool", {}).get("black", {}) + config = parse_pyproject_toml(value) except (toml.TomlDecodeError, OSError) as e: raise click.FileError( filename=value, hint=f"Error reading configuration file: {e}" @@ -239,14 +307,32 @@ def read_pyproject_toml( if not config: return None - if ctx.default_map is None: - ctx.default_map = {} - ctx.default_map.update( # type: ignore # bad types in .pyi - {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} - ) + target_version = config.get("target_version") + if target_version is not None and not isinstance(target_version, list): + raise click.BadOptionUsage( + "target-version", f"Config key target-version must be a list" + ) + + default_map: Dict[str, Any] = {} + if ctx.default_map: + default_map.update(ctx.default_map) + default_map.update(config) + + ctx.default_map = default_map return value +def target_version_option_callback( + c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...] +) -> List[TargetVersion]: + """Compute the target versions from a --target-version flag. + + This is its own function because mypy couldn't infer the type correctly + when it was a lambda, causing mypyc trouble. + """ + return [TargetVersion[val.upper()] for val in v] + + @click.command(context_settings=dict(help_option_names=["-h", "--help"])) @click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( @@ -261,29 +347,19 @@ def read_pyproject_toml( "-t", "--target-version", type=click.Choice([v.name.lower() for v in TargetVersion]), - callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v], + callback=target_version_option_callback, multiple=True, help=( - "Python versions that should be supported by Black's output. [default: " - "per-file auto-detection]" - ), -) -@click.option( - "--py36", - is_flag=True, - help=( - "Allow using Python 3.6-only syntax on all input files. This will put " - "trailing commas in function signatures and calls also after *args and " - "**kwargs. Deprecated; use --target-version instead. " - "[default: per-file auto-detection]" + "Python versions that should be supported by Black's output. [default: per-file" + " auto-detection]" ), ) @click.option( "--pyi", is_flag=True, help=( - "Format all input files like typing stubs regardless of file extension " - "(useful when piping source on standard input)." + "Format all input files like typing stubs regardless of file extension (useful" + " when piping source on standard input)." ), ) @click.option( @@ -296,9 +372,9 @@ def read_pyproject_toml( "--check", is_flag=True, help=( - "Don't write the files back, just return the status. Return code 0 " - "means nothing would change. Return code 1 means some files would be " - "reformatted. Return code 123 means there was an internal error." + "Don't write the files back, just return the status. Return code 0 means" + " nothing would change. Return code 1 means some files would be reformatted." + " Return code 123 means there was an internal error." ), ) @click.option( @@ -306,6 +382,11 @@ def read_pyproject_toml( is_flag=True, help="Don't write the files back, just output a diff for each file on stdout.", ) +@click.option( + "--color/--no-color", + is_flag=True, + help="Show colored diff. Only applies when `--diff` is given.", +) @click.option( "--fast/--safe", is_flag=True, @@ -316,11 +397,10 @@ def read_pyproject_toml( type=str, default=DEFAULT_INCLUDES, help=( - "A regular expression that matches files and directories that should be " - "included on recursive searches. An empty value means all files are " - "included regardless of the name. Use forward slashes for directories on " - "all platforms (Windows, too). Exclusions are calculated first, inclusions " - "later." + "A regular expression that matches files and directories that should be" + " included on recursive searches. An empty value means all files are included" + " regardless of the name. Use forward slashes for directories on all platforms" + " (Windows, too). Exclusions are calculated first, inclusions later." ), show_default=True, ) @@ -329,20 +409,28 @@ def read_pyproject_toml( type=str, default=DEFAULT_EXCLUDES, help=( - "A regular expression that matches files and directories that should be " - "excluded on recursive searches. An empty value means no paths are excluded. " - "Use forward slashes for directories on all platforms (Windows, too). " - "Exclusions are calculated first, inclusions later." + "A regular expression that matches files and directories that should be" + " excluded on recursive searches. An empty value means no paths are excluded." + " Use forward slashes for directories on all platforms (Windows, too). " + " Exclusions are calculated first, inclusions later." ), show_default=True, ) +@click.option( + "--force-exclude", + type=str, + help=( + "Like --exclude, but files and directories matching this regex will be " + "excluded even when they are passed explicitly as arguments" + ), +) @click.option( "-q", "--quiet", is_flag=True, help=( - "Don't emit non-error messages to stderr. Errors are still emitted; " - "silence those with 2>/dev/null." + "Don't emit non-error messages to stderr. Errors are still emitted; silence" + " those with 2>/dev/null." ), ) @click.option( @@ -350,8 +438,8 @@ def read_pyproject_toml( "--verbose", is_flag=True, help=( - "Also emit messages to stderr about files that were not changed or were " - "ignored due to --exclude=." + "Also emit messages to stderr about files that were not changed or were ignored" + " due to --exclude=." ), ) @click.version_option(version=__version__) @@ -366,7 +454,12 @@ def read_pyproject_toml( @click.option( "--config", type=click.Path( - exists=False, file_okay=True, dir_okay=False, readable=True, allow_dash=False + exists=True, + file_okay=True, + dir_okay=False, + readable=True, + allow_dash=False, + path_type=str, ), is_eager=True, callback=read_pyproject_toml, @@ -380,35 +473,26 @@ def main( target_version: List[TargetVersion], check: bool, diff: bool, + color: bool, fast: bool, pyi: bool, - py36: bool, skip_string_normalization: bool, quiet: bool, verbose: bool, include: str, exclude: str, - src: Tuple[str], + force_exclude: Optional[str], + src: Tuple[str, ...], config: Optional[str], ) -> None: """The uncompromising code formatter.""" - write_back = WriteBack.from_configuration(check=check, diff=diff) + write_back = WriteBack.from_configuration(check=check, diff=diff, color=color) if target_version: - if py36: - err(f"Cannot use both --target-version and --py36") - ctx.exit(2) - else: - versions = set(target_version) - elif py36: - err( - "--py36 is deprecated and will be removed in a future version. " - "Use --target-version py36 instead." - ) - versions = PY36_VERSIONS + versions = set(target_version) else: # We'll autodetect later. versions = set() - mode = FileMode( + mode = Mode( target_versions=versions, line_length=line_length, is_pyi=pyi, @@ -419,6 +503,57 @@ def main( if code is not None: print(format_str(code, mode=mode)) ctx.exit(0) + report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose) + sources = get_sources( + ctx=ctx, + src=src, + quiet=quiet, + verbose=verbose, + include=include, + exclude=exclude, + force_exclude=force_exclude, + report=report, + ) + + path_empty( + sources, + "No Python files are present to be formatted. Nothing to do 😴", + quiet, + verbose, + ctx, + ) + + if len(sources) == 1: + reformat_one( + src=sources.pop(), + fast=fast, + write_back=write_back, + mode=mode, + report=report, + ) + else: + reformat_many( + sources=sources, fast=fast, write_back=write_back, mode=mode, report=report + ) + + if verbose or not quiet: + out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨") + click.secho(str(report), err=True) + ctx.exit(report.return_code) + + +def get_sources( + *, + ctx: click.Context, + src: Tuple[str, ...], + quiet: bool, + verbose: bool, + include: str, + exclude: str, + force_exclude: Optional[str], + report: "Report", +) -> Set[Path]: + """Compute the set of files to be formatted.""" try: include_regex = re_compile_maybe_verbose(include) except re.error: @@ -429,59 +564,61 @@ def main( except re.error: err(f"Invalid regular expression for exclude given: {exclude!r}") ctx.exit(2) - report = Report(check=check, quiet=quiet, verbose=verbose) + try: + force_exclude_regex = ( + re_compile_maybe_verbose(force_exclude) if force_exclude else None + ) + except re.error: + err(f"Invalid regular expression for force_exclude given: {force_exclude!r}") + ctx.exit(2) + root = find_project_root(src) sources: Set[Path] = set() - path_empty(src, quiet, verbose, ctx) + path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx) + exclude_regexes = [exclude_regex] + if force_exclude_regex is not None: + exclude_regexes.append(force_exclude_regex) + for s in src: p = Path(s) if p.is_dir(): sources.update( - gen_python_files_in_dir( - p, root, include_regex, exclude_regex, report, get_gitignore(root) + gen_python_files( + p.iterdir(), + root, + include_regex, + exclude_regexes, + report, + get_gitignore(root), ) ) - elif p.is_file() or s == "-": - # if a file was explicitly given, we don't care about its extension + elif s == "-": sources.add(p) + elif p.is_file(): + sources.update( + gen_python_files( + [p], root, None, exclude_regexes, report, get_gitignore(root) + ) + ) else: err(f"invalid path: {s}") - if len(sources) == 0: - if verbose or not quiet: - out("No Python files are present to be formatted. Nothing to do 😴") - ctx.exit(0) - - if len(sources) == 1: - reformat_one( - src=sources.pop(), - fast=fast, - write_back=write_back, - mode=mode, - report=report, - ) - else: - reformat_many( - sources=sources, fast=fast, write_back=write_back, mode=mode, report=report - ) - - if verbose or not quiet: - out("Oh no! 💥 💔 💥" if report.return_code else "All done! ✨ 🍰 ✨") - click.secho(str(report), err=True) - ctx.exit(report.return_code) + return sources -def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None: +def path_empty( + src: Sized, msg: str, quiet: bool, verbose: bool, ctx: click.Context +) -> None: """ Exit if there is no `src` provided for formatting """ - if not src: + if len(src) == 0: if verbose or not quiet: - out("No Path provided. Nothing to do 😴") + out(msg) ctx.exit(0) def reformat_one( - src: Path, fast: bool, write_back: WriteBack, mode: FileMode, report: "Report" + src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report" ) -> None: """Reformat a single file under `src` without spawning child processes. @@ -514,19 +651,24 @@ def reformat_one( def reformat_many( - sources: Set[Path], - fast: bool, - write_back: WriteBack, - mode: FileMode, - report: "Report", + sources: Set[Path], fast: bool, write_back: WriteBack, mode: Mode, report: "Report" ) -> None: """Reformat multiple files using a ProcessPoolExecutor.""" + executor: Executor loop = asyncio.get_event_loop() worker_count = os.cpu_count() if sys.platform == "win32": # Work around https://bugs.python.org/issue26903 worker_count = min(worker_count, 61) - executor = ProcessPoolExecutor(max_workers=worker_count) + try: + executor = ProcessPoolExecutor(max_workers=worker_count) + except OSError: + # we arrive here if the underlying system does not support multi-processing + # like in AWS Lambda, in which case we gracefully fallback to + # a ThreadPollExecutor with just a single worker (more workers would not do us + # any good due to the Global Interpreter Lock) + executor = ThreadPoolExecutor(max_workers=1) + try: loop.run_until_complete( schedule_formatting( @@ -541,14 +683,15 @@ def reformat_many( ) finally: shutdown(loop) - executor.shutdown() + if executor is not None: + executor.shutdown() async def schedule_formatting( sources: Set[Path], fast: bool, write_back: WriteBack, - mode: FileMode, + mode: Mode, report: "Report", loop: asyncio.AbstractEventLoop, executor: Executor, @@ -585,7 +728,7 @@ async def schedule_formatting( ): src for src in sorted(sources) } - pending: Iterable[asyncio.Future] = tasks.keys() + pending: Iterable["asyncio.Future[bool]"] = tasks.keys() try: loop.add_signal_handler(signal.SIGINT, cancel, pending) loop.add_signal_handler(signal.SIGTERM, cancel, pending) @@ -618,7 +761,7 @@ async def schedule_formatting( def format_file_in_place( src: Path, fast: bool, - mode: FileMode, + mode: Mode, write_back: WriteBack = WriteBack.NO, lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy ) -> bool: @@ -629,7 +772,7 @@ def format_file_in_place( `mode` and `fast` options are passed to :func:`format_file_contents`. """ if src.suffix == ".pyi": - mode = evolve(mode, is_pyi=True) + mode = replace(mode, is_pyi=True) then = datetime.utcfromtimestamp(src.stat().st_mtime) with open(src, "rb") as buf: @@ -639,15 +782,18 @@ def format_file_in_place( except NothingChanged: return False - if write_back == write_back.YES: + if write_back == WriteBack.YES: with open(src, "w", encoding=encoding, newline=newline) as f: f.write(dst_contents) - elif write_back == write_back.DIFF: + elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF): now = datetime.utcnow() src_name = f"{src}\t{then} +0000" dst_name = f"{src}\t{now} +0000" diff_contents = diff(src_contents, dst_contents, src_name, dst_name) + if write_back == write_back.COLOR_DIFF: + diff_contents = color_diff(diff_contents) + with lock or nullcontext(): f = io.TextIOWrapper( sys.stdout.buffer, @@ -655,14 +801,59 @@ def format_file_in_place( newline=newline, write_through=True, ) + f = wrap_stream_for_windows(f) f.write(diff_contents) f.detach() return True +def color_diff(contents: str) -> str: + """Inject the ANSI color codes to the diff.""" + lines = contents.split("\n") + for i, line in enumerate(lines): + if line.startswith("+++") or line.startswith("---"): + line = "\033[1;37m" + line + "\033[0m" # bold white, reset + if line.startswith("@@"): + line = "\033[36m" + line + "\033[0m" # cyan, reset + if line.startswith("+"): + line = "\033[32m" + line + "\033[0m" # green, reset + elif line.startswith("-"): + line = "\033[31m" + line + "\033[0m" # red, reset + lines[i] = line + return "\n".join(lines) + + +def wrap_stream_for_windows( + f: io.TextIOWrapper, +) -> Union[io.TextIOWrapper, "colorama.AnsiToWin32.AnsiToWin32"]: + """ + Wrap the stream in colorama's wrap_stream so colors are shown on Windows. + + If `colorama` is not found, then no change is made. If `colorama` does + exist, then it handles the logic to determine whether or not to change + things. + """ + try: + from colorama import initialise + + # We set `strip=False` so that we can don't have to modify + # test_express_diff_with_color. + f = initialise.wrap_stream( + f, convert=None, strip=False, autoreset=False, wrap=True + ) + + # wrap_stream returns a `colorama.AnsiToWin32.AnsiToWin32` object + # which does not have a `detach()` method. So we fake one. + f.detach = lambda *args, **kwargs: None # type: ignore + except ImportError: + pass + + return f + + def format_stdin_to_stdout( - fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: FileMode + fast: bool, *, write_back: WriteBack = WriteBack.NO, mode: Mode ) -> bool: """Format file on stdin. Return True if changed. @@ -686,17 +877,19 @@ def format_stdin_to_stdout( ) if write_back == WriteBack.YES: f.write(dst) - elif write_back == WriteBack.DIFF: + elif write_back in (WriteBack.DIFF, WriteBack.COLOR_DIFF): now = datetime.utcnow() src_name = f"STDIN\t{then} +0000" dst_name = f"STDOUT\t{now} +0000" - f.write(diff(src, dst, src_name, dst_name)) + d = diff(src, dst, src_name, dst_name) + if write_back == WriteBack.COLOR_DIFF: + d = color_diff(d) + f = wrap_stream_for_windows(f) + f.write(d) f.detach() -def format_file_contents( - src_contents: str, *, fast: bool, mode: FileMode -) -> FileContent: +def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: """Reformat contents a file and return new contents. If `fast` is False, additionally confirm that the reformatted code is @@ -716,11 +909,34 @@ def format_file_contents( return dst_contents -def format_str(src_contents: str, *, mode: FileMode) -> FileContent: +def format_str(src_contents: str, *, mode: Mode) -> FileContent: """Reformat a string and return new contents. `mode` determines formatting options, such as how many characters per line are - allowed. + allowed. Example: + + >>> import black + >>> print(black.format_str("def f(arg:str='')->None:...", mode=Mode())) + def f(arg: str = "") -> None: + ... + + A more complex example: + >>> print( + ... black.format_str( + ... "def f(arg:str='')->None: hey", + ... mode=black.Mode( + ... target_versions={black.TargetVersion.PY36}, + ... line_length=10, + ... string_normalization=False, + ... is_pyi=False, + ... ), + ... ), + ... ) + def f( + arg: str = '', + ) -> None: + hey + """ src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) dst_contents = [] @@ -745,13 +961,14 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent: if supports_feature(versions, feature) } for current_line in lines.visit(src_node): - for _ in range(after): - dst_contents.append(str(empty_line)) + dst_contents.append(str(empty_line) * after) before, after = elt.maybe_empty_lines(current_line) - for _ in range(before): - dst_contents.append(str(empty_line)) - for line in split_line( - current_line, line_length=mode.line_length, features=split_line_features + dst_contents.append(str(empty_line) * before) + for line in transform_line( + current_line, + line_length=mode.line_length, + normalize_strings=mode.string_normalization, + features=split_line_features, ): dst_contents.append(str(line)) return "".join(dst_contents) @@ -787,7 +1004,8 @@ def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]: # Python 2.7 pygram.python_grammar, ] - elif all(version.is_python2() for version in target_versions): + + if all(version.is_python2() for version in target_versions): # Python 2-only code, so try Python 2 grammars. return [ # Python 2.7 with future print_function import @@ -795,21 +1013,21 @@ def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]: # Python 2.7 pygram.python_grammar, ] - else: - # Python 3-compatible code, so only try Python 3 grammar. - grammars = [] - # If we have to parse both, try to parse async as a keyword first - if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS): - # Python 3.7+ - grammars.append( - pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords # noqa: B950 - ) - if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS): - # Python 3.0-3.6 - grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement) - # At least one of the above branches must have been taken, because every Python - # version has exactly one of the two 'ASYNC_*' flags - return grammars + + # Python 3-compatible code, so only try Python 3 grammar. + grammars = [] + # If we have to parse both, try to parse async as a keyword first + if not supports_feature(target_versions, Feature.ASYNC_IDENTIFIERS): + # Python 3.7+ + grammars.append( + pygram.python_grammar_no_print_statement_no_exec_statement_async_keywords + ) + if not supports_feature(target_versions, Feature.ASYNC_KEYWORDS): + # Python 3.0-3.6 + grammars.append(pygram.python_grammar_no_print_statement_no_exec_statement) + # At least one of the above branches must have been taken, because every Python + # version has exactly one of the two 'ASYNC_*' flags + return grammars def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node: @@ -845,9 +1063,6 @@ def lib2to3_unparse(node: Node) -> str: return code -T = TypeVar("T") - - class Visitor(Generic[T]): """Basic lib2to3 visitor that yields things of type `T` on `visit()`.""" @@ -864,8 +1079,16 @@ class Visitor(Generic[T]): if node.type < 256: name = token.tok_name[node.type] else: - name = type_repr(node.type) - yield from getattr(self, f"visit_{name}", self.visit_default)(node) + name = str(type_repr(node.type)) + # We explicitly branch on whether a visitor exists (instead of + # using self.visit_default as the default arg to getattr) in order + # to save needing to create a bound method object and so mypyc can + # generate a native call to visit_default. + visitf = getattr(self, f"visit_{name}", None) + if visitf: + yield from visitf(node) + else: + yield from self.visit_default(node) def visit_default(self, node: LN) -> Iterator[T]: """Default `visit_*()` implementation. Recurses to children of `node`.""" @@ -910,8 +1133,8 @@ class DebugVisitor(Visitor[T]): list(v.visit(code)) -WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE} -STATEMENT = { +WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE} +STATEMENT: Final = { syms.if_stmt, syms.while_stmt, syms.for_stmt, @@ -921,10 +1144,10 @@ STATEMENT = { syms.funcdef, syms.classdef, } -STANDALONE_COMMENT = 153 +STANDALONE_COMMENT: Final = 153 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT" -LOGIC_OPERATORS = {"and", "or"} -COMPARATORS = { +LOGIC_OPERATORS: Final = {"and", "or"} +COMPARATORS: Final = { token.LESS, token.GREATER, token.EQEQUAL, @@ -932,7 +1155,7 @@ COMPARATORS = { token.LESSEQUAL, token.GREATEREQUAL, } -MATH_OPERATORS = { +MATH_OPERATORS: Final = { token.VBAR, token.CIRCUMFLEX, token.AMPER, @@ -948,23 +1171,23 @@ MATH_OPERATORS = { token.TILDE, token.DOUBLESTAR, } -STARS = {token.STAR, token.DOUBLESTAR} -VARARGS_SPECIALS = STARS | {token.SLASH} -VARARGS_PARENTS = { +STARS: Final = {token.STAR, token.DOUBLESTAR} +VARARGS_SPECIALS: Final = STARS | {token.SLASH} +VARARGS_PARENTS: Final = { syms.arglist, syms.argument, # double star in arglist syms.trailer, # single argument to call syms.typedargslist, syms.varargslist, # lambdas } -UNPACKING_PARENTS = { +UNPACKING_PARENTS: Final = { syms.atom, # single element of a list or set literal syms.dictsetmaker, syms.listmaker, syms.testlist_gexp, syms.testlist_star_expr, } -TEST_DESCENDANTS = { +TEST_DESCENDANTS: Final = { syms.test, syms.lambdef, syms.or_test, @@ -981,7 +1204,7 @@ TEST_DESCENDANTS = { syms.term, syms.power, } -ASSIGNMENTS = { +ASSIGNMENTS: Final = { "=", "+=", "-=", @@ -997,13 +1220,13 @@ ASSIGNMENTS = { "**=", "//=", } -COMPREHENSION_PRIORITY = 20 -COMMA_PRIORITY = 18 -TERNARY_PRIORITY = 16 -LOGIC_PRIORITY = 14 -STRING_PRIORITY = 12 -COMPARATOR_PRIORITY = 10 -MATH_PRIORITIES = { +COMPREHENSION_PRIORITY: Final = 20 +COMMA_PRIORITY: Final = 18 +TERNARY_PRIORITY: Final = 16 +LOGIC_PRIORITY: Final = 14 +STRING_PRIORITY: Final = 12 +COMPARATOR_PRIORITY: Final = 10 +MATH_PRIORITIES: Final = { token.VBAR: 9, token.CIRCUMFLEX: 8, token.AMPER: 7, @@ -1019,7 +1242,7 @@ MATH_PRIORITIES = { token.TILDE: 3, token.DOUBLESTAR: 2, } -DOT_PRIORITY = 1 +DOT_PRIORITY: Final = 1 @dataclass @@ -1027,11 +1250,11 @@ class BracketTracker: """Keeps track of brackets on a line.""" depth: int = 0 - bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict) - delimiters: Dict[LeafID, Priority] = Factory(dict) + bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict) + delimiters: Dict[LeafID, Priority] = field(default_factory=dict) previous: Optional[Leaf] = None - _for_loop_depths: List[int] = Factory(list) - _lambda_argument_depths: List[int] = Factory(list) + _for_loop_depths: List[int] = field(default_factory=list) + _lambda_argument_depths: List[int] = field(default_factory=list) def mark(self, leaf: Leaf) -> None: """Mark `leaf` with bracket-related metadata. Keep track of delimiters. @@ -1159,9 +1382,10 @@ class Line: """Holds leaves and comments. Can be printed with `str(line)`.""" depth: int = 0 - leaves: List[Leaf] = Factory(list) - comments: Dict[LeafID, List[Leaf]] = Factory(dict) # keys ordered like `leaves` - bracket_tracker: BracketTracker = Factory(BracketTracker) + leaves: List[Leaf] = field(default_factory=list) + # keys ordered like `leaves` + comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict) + bracket_tracker: BracketTracker = field(default_factory=BracketTracker) inside_brackets: bool = False should_explode: bool = False @@ -1249,6 +1473,7 @@ class Line: """ if not self.leaves or len(self.leaves) < 4: return False + # Look for and address a trailing colon. if self.leaves[-1].type == token.COLON: closer = self.leaves[-2] @@ -1258,6 +1483,7 @@ class Line: close_index = -1 if closer.type not in CLOSING_BRACKETS or self.inside_brackets: return False + if closer.type == token.RPAR: # Tuples require an extra check, because if there's only # one element in the tuple removing the comma unmakes the @@ -1272,9 +1498,11 @@ class Line: for _open_index, leaf in enumerate(self.leaves): if leaf is opener: break + else: # Couldn't find the matching opening paren, play it safe. return False + commas = 0 comma_depth = self.leaves[close_index - 1].bracket_depth for leaf in self.leaves[_open_index + 1 : close_index]: @@ -1284,16 +1512,20 @@ class Line: # We haven't looked yet for the trailing comma because # we might also have caught noop parens. return self.leaves[close_index - 1].type == token.COMMA + elif commas == 1: return False # it's either a one-tuple or didn't have a trailing comma + if self.leaves[close_index - 1].type in CLOSING_BRACKETS: close_index -= 1 closer = self.leaves[close_index] if closer.type == token.RPAR: # TODO: this is a gut feeling. Will we ever see this? return False + if self.leaves[close_index - 1].type != token.COMMA: return False + return True @property @@ -1343,9 +1575,9 @@ class Line: def contains_standalone_comments(self, depth_limit: int = sys.maxsize) -> bool: """If so, needs to be split before emitting.""" for leaf in self.leaves: - if leaf.type == STANDALONE_COMMENT: - if leaf.bracket_depth <= depth_limit: - return True + if leaf.type == STANDALONE_COMMENT and leaf.bracket_depth <= depth_limit: + return True + return False def contains_uncollapsable_type_comments(self) -> bool: @@ -1374,7 +1606,10 @@ class Line: for leaf_id, comments in self.comments.items(): for comment in comments: if is_type_comment(comment): - if leaf_id not in ignored_ids or comment_seen: + if comment_seen or ( + not is_type_comment(comment, " ignore") + and leaf_id not in ignored_ids + ): return True comment_seen = True @@ -1413,16 +1648,13 @@ class Line: return False def contains_multiline_strings(self) -> bool: - for leaf in self.leaves: - if is_multiline_string(leaf): - return True - - return False + return any(is_multiline_string(leaf) for leaf in self.leaves) def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: """Remove trailing comma if there is one and it's safe.""" if not (self.leaves and self.leaves[-1].type == token.COMMA): return False + # We remove trailing commas only in the case of importing a # single name from a module. if not ( @@ -1487,6 +1719,7 @@ class Line: comment.type = STANDALONE_COMMENT comment.prefix = "" return False + last_leaf = self.leaves[-2] self.comments.setdefault(id(last_leaf), []).append(comment) return True @@ -1521,6 +1754,13 @@ class Line: n.type in TEST_DESCENDANTS for n in subscript_start.pre_order() ) + def clone(self) -> "Line": + return Line( + depth=self.depth, + inside_brackets=self.inside_brackets, + should_explode=self.should_explode, + ) + def __str__(self) -> str: """Render the line.""" if not self: @@ -1534,6 +1774,7 @@ class Line: res += str(leaf) for comment in itertools.chain.from_iterable(self.comments.values()): res += str(comment) + return res + "\n" def __bool__(self) -> bool: @@ -1554,7 +1795,7 @@ class EmptyLineTracker: is_pyi: bool = False previous_line: Optional[Line] = None previous_after: int = 0 - previous_defs: List[int] = Factory(list) + previous_defs: List[int] = field(default_factory=list) def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: """Return the number of extra empty lines before and after the `current_line`. @@ -1668,7 +1909,7 @@ class LineGenerator(Visitor[Line]): is_pyi: bool = False normalize_strings: bool = True - current_line: Line = Factory(Line) + current_line: Line = field(default_factory=Line) remove_u_prefix: bool = False def line(self, indent: int = 0) -> Iterator[Line]: @@ -1717,26 +1958,13 @@ class LineGenerator(Visitor[Line]): self.current_line.append(node) yield from super().visit_default(node) - def visit_factor(self, node: Node) -> Iterator[Line]: - """Force parentheses between a unary op and a binary power: - - -2 ** 8 -> -(2 ** 8) - """ - child = node.children[1] - if child.type == syms.power and len(child.children) == 3: - lpar = Leaf(token.LPAR, "(") - rpar = Leaf(token.RPAR, ")") - index = child.remove() or 0 - node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) - yield from self.visit_default(node) - - def visit_INDENT(self, node: Node) -> Iterator[Line]: + def visit_INDENT(self, node: Leaf) -> Iterator[Line]: """Increase indentation level, maybe yield a line.""" # In blib2to3 INDENT never holds comments. yield from self.line(+1) yield from self.visit_default(node) - def visit_DEDENT(self, node: Node) -> Iterator[Line]: + def visit_DEDENT(self, node: Leaf) -> Iterator[Line]: """Decrease indentation level, maybe yield a line.""" # The current line might still wait for trailing comments. At DEDENT time # there won't be any (they would be prefixes on the preceding NEWLINE). @@ -1829,7 +2057,36 @@ class LineGenerator(Visitor[Line]): yield from self.line() yield from self.visit_default(leaf) - def __attrs_post_init__(self) -> None: + def visit_factor(self, node: Node) -> Iterator[Line]: + """Force parentheses between a unary op and a binary power: + + -2 ** 8 -> -(2 ** 8) + """ + _operator, operand = node.children + if ( + operand.type == syms.power + and len(operand.children) == 3 + and operand.children[1].type == token.DOUBLESTAR + ): + lpar = Leaf(token.LPAR, "(") + rpar = Leaf(token.RPAR, ")") + index = operand.remove() or 0 + node.insert_child(index, Node(syms.atom, [lpar, operand, rpar])) + yield from self.visit_default(node) + + def visit_STRING(self, leaf: Leaf) -> Iterator[Line]: + # Check if it's a docstring + if prev_siblings_are( + leaf.parent, [None, token.NEWLINE, token.INDENT, syms.simple_stmt] + ) and is_multiline_string(leaf): + prefix = " " * self.current_line.depth + docstring = fix_docstring(leaf.value[3:-3], prefix) + leaf.value = leaf.value[0:3] + docstring + leaf.value[-3:] + normalize_string_quotes(leaf) + + yield from self.visit_default(leaf) + + def __post_init__(self) -> None: """You are in a twisty little maze of passages.""" v = self.visit_stmt Ø: Set[str] = set() @@ -2110,6 +2367,22 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: return None +def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool: + """Return if the `node` and its previous siblings match types against the provided + list of tokens; the provided `node`has its type matched against the last element in + the list. `None` can be used as the first element to declare that the start of the + list is anchored at the start of its parent's children.""" + if not tokens: + return True + if tokens[-1] is None: + return node is None + if not node: + return False + if node.type != tokens[-1]: + return False + return prev_siblings_are(node.prev_sibling, tokens[:-1]) + + def child_towards(ancestor: Node, descendant: LN) -> Optional[LN]: """Return the child of `ancestor` that contains `descendant`.""" node: Optional[LN] = descendant @@ -2354,18 +2627,15 @@ def make_comment(content: str) -> str: return "#" + content -def split_line( +def transform_line( line: Line, line_length: int, - inner: bool = False, + normalize_strings: bool, features: Collection[Feature] = (), ) -> Iterator[Line]: - """Split a `line` into potentially many lines. + """Transform a `line`, potentially splitting it into many lines. They should fit in the allotted `line_length` but might not be able to. - `inner` signifies that there were a pair of brackets somewhere around the - current `line`, possibly transitively. This means we can fallback to splitting - by delimiters if the LHS/RHS don't yield any results. `features` are syntactical features that may be used in the output. """ @@ -2373,8 +2643,18 @@ def split_line( yield line return - line_str = str(line).strip("\n") + line_str = line_to_string(line) + def init_st(ST: Type[StringTransformer]) -> StringTransformer: + """Initialize StringTransformer""" + return ST(line_length, normalize_strings) + + string_merge = init_st(StringMerger) + string_paren_strip = init_st(StringParenStripper) + string_split = init_st(StringSplitter) + string_paren_wrap = init_st(StringParenWrapper) + + transformers: List[Transformer] if ( not line.contains_uncollapsable_type_comments() and not line.should_explode @@ -2383,13 +2663,12 @@ def split_line( is_line_short_enough(line, line_length=line_length, line_str=line_str) or line.contains_unsplittable_type_ignore() ) + and not (line.contains_standalone_comments() and line.inside_brackets) ): - yield line - return - - split_funcs: List[SplitFunc] - if line.is_def: - split_funcs = [left_hand_split] + # Only apply basic string preprocessing, since lines shouldn't be split here. + transformers = [string_merge, string_paren_strip] + elif line.is_def: + transformers = [left_hand_split] else: def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]: @@ -2402,30 +2681,51 @@ def split_line( # All splits failed, best effort split with no omits. # This mostly happens to multiline strings that are by definition # reported as not fitting a single line. - yield from right_hand_split(line, line_length, features=features) + # line_length=1 here was historically a bug that somehow became a feature. + # See #762 and #781 for the full story. + yield from right_hand_split(line, line_length=1, features=features) if line.inside_brackets: - split_funcs = [delimiter_split, standalone_comment_split, rhs] + transformers = [ + string_merge, + string_paren_strip, + delimiter_split, + standalone_comment_split, + string_split, + string_paren_wrap, + rhs, + ] else: - split_funcs = [rhs] - for split_func in split_funcs: + transformers = [ + string_merge, + string_paren_strip, + string_split, + string_paren_wrap, + rhs, + ] + + for transform in transformers: # We are accumulating lines in `result` because we might want to abort # mission and return the original line in the end, or attempt a different # split altogether. result: List[Line] = [] try: - for l in split_func(line, features): + for l in transform(line, features): if str(l).strip("\n") == line_str: - raise CannotSplit("Split function returned an unchanged result") + raise CannotTransform( + "Line transformer returned an unchanged result" + ) result.extend( - split_line( - l, line_length=line_length, inner=True, features=features + transform_line( + l, + line_length=line_length, + normalize_strings=normalize_strings, + features=features, ) ) - except CannotSplit: + except CannotTransform: continue - else: yield from result break @@ -2434,77 +2734,2008 @@ def split_line( yield line -def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]: - """Split line into many lines, starting with the first matching bracket pair. - - Note: this usually looks weird, only use this for function definitions. - Prefer RHS otherwise. This is why this function is not symmetrical with - :func:`right_hand_split` which also handles optional parentheses. +@dataclass # type: ignore +class StringTransformer(ABC): """ - tail_leaves: List[Leaf] = [] - body_leaves: List[Leaf] = [] - head_leaves: List[Leaf] = [] - current_leaves = head_leaves - matching_bracket = None - for leaf in line.leaves: - if ( - current_leaves is body_leaves - and leaf.type in CLOSING_BRACKETS - and leaf.opening_bracket is matching_bracket - ): - current_leaves = tail_leaves if body_leaves else head_leaves - current_leaves.append(leaf) - if current_leaves is head_leaves: - if leaf.type in OPENING_BRACKETS: - matching_bracket = leaf - current_leaves = body_leaves - if not matching_bracket: - raise CannotSplit("No brackets found") + An implementation of the Transformer protocol that relies on its + subclasses overriding the template methods `do_match(...)` and + `do_transform(...)`. - head = bracket_split_build_line(head_leaves, line, matching_bracket) - body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True) - tail = bracket_split_build_line(tail_leaves, line, matching_bracket) - bracket_split_succeeded_or_raise(head, body, tail) - for result in (head, body, tail): - if result: - yield result + This Transformer works exclusively on strings (for example, by merging + or splitting them). + The following sections can be found among the docstrings of each concrete + StringTransformer subclass. -def right_hand_split( - line: Line, - line_length: int, - features: Collection[Feature] = (), - omit: Collection[LeafID] = (), -) -> Iterator[Line]: - """Split line into many lines, starting with the last matching bracket pair. + Requirements: + Which requirements must be met of the given Line for this + StringTransformer to be applied? - If the split was by optional parentheses, attempt splitting without them, too. - `omit` is a collection of closing bracket IDs that shouldn't be considered for - this split. + Transformations: + If the given Line meets all of the above requirments, which string + transformations can you expect to be applied to it by this + StringTransformer? - Note: running this function modifies `bracket_depth` on the leaves of `line`. + Collaborations: + What contractual agreements does this StringTransformer have with other + StringTransfomers? Such collaborations should be eliminated/minimized + as much as possible. """ - tail_leaves: List[Leaf] = [] - body_leaves: List[Leaf] = [] - head_leaves: List[Leaf] = [] - current_leaves = tail_leaves - opening_bracket = None - closing_bracket = None - for leaf in reversed(line.leaves): - if current_leaves is body_leaves: - if leaf is opening_bracket: - current_leaves = head_leaves if body_leaves else tail_leaves - current_leaves.append(leaf) - if current_leaves is tail_leaves: - if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit: - opening_bracket = leaf.opening_bracket - closing_bracket = leaf - current_leaves = body_leaves - if not (opening_bracket and closing_bracket and head_leaves): - # If there is no opening or closing_bracket that means the split failed and - # all content is in the tail. Otherwise, if `head_leaves` are empty, it means - # the matching `opening_bracket` wasn't available on `line` anymore. - raise CannotSplit("No brackets found") + + line_length: int + normalize_strings: bool + + @abstractmethod + def do_match(self, line: Line) -> TMatchResult: + """ + Returns: + * Ok(string_idx) such that `line.leaves[string_idx]` is our target + string, if a match was able to be made. + OR + * Err(CannotTransform), if a match was not able to be made. + """ + + @abstractmethod + def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + """ + Yields: + * Ok(new_line) where new_line is the new transformed line. + OR + * Err(CannotTransform) if the transformation failed for some reason. The + `do_match(...)` template method should usually be used to reject + the form of the given Line, but in some cases it is difficult to + know whether or not a Line meets the StringTransformer's + requirements until the transformation is already midway. + + Side Effects: + This method should NOT mutate @line directly, but it MAY mutate the + Line's underlying Node structure. (WARNING: If the underlying Node + structure IS altered, then this method should NOT be allowed to + yield an CannotTransform after that point.) + """ + + def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]: + """ + StringTransformer instances have a call signature that mirrors that of + the Transformer type. + + Raises: + CannotTransform(...) if the concrete StringTransformer class is unable + to transform @line. + """ + # Optimization to avoid calling `self.do_match(...)` when the line does + # not contain any string. + if not any(leaf.type == token.STRING for leaf in line.leaves): + raise CannotTransform("There are no strings in this line.") + + match_result = self.do_match(line) + + if isinstance(match_result, Err): + cant_transform = match_result.err() + raise CannotTransform( + f"The string transformer {self.__class__.__name__} does not recognize" + " this line as one that it can transform." + ) from cant_transform + + string_idx = match_result.ok() + + for line_result in self.do_transform(line, string_idx): + if isinstance(line_result, Err): + cant_transform = line_result.err() + raise CannotTransform( + "StringTransformer failed while attempting to transform string." + ) from cant_transform + line = line_result.ok() + yield line + + +@dataclass +class CustomSplit: + """A custom (i.e. manual) string split. + + A single CustomSplit instance represents a single substring. + + Examples: + Consider the following string: + ``` + "Hi there friend." + " This is a custom" + f" string {split}." + ``` + + This string will correspond to the following three CustomSplit instances: + ``` + CustomSplit(False, 16) + CustomSplit(False, 17) + CustomSplit(True, 16) + ``` + """ + + has_prefix: bool + break_idx: int + + +class CustomSplitMapMixin: + """ + This mixin class is used to map merged strings to a sequence of + CustomSplits, which will then be used to re-split the strings iff none of + the resultant substrings go over the configured max line length. + """ + + _Key = Tuple[StringID, str] + _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple) + + @staticmethod + def _get_key(string: str) -> "CustomSplitMapMixin._Key": + """ + Returns: + A unique identifier that is used internally to map @string to a + group of custom splits. + """ + return (id(string), string) + + def add_custom_splits( + self, string: str, custom_splits: Iterable[CustomSplit] + ) -> None: + """Custom Split Map Setter Method + + Side Effects: + Adds a mapping from @string to the custom splits @custom_splits. + """ + key = self._get_key(string) + self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits) + + def pop_custom_splits(self, string: str) -> List[CustomSplit]: + """Custom Split Map Getter Method + + Returns: + * A list of the custom splits that are mapped to @string, if any + exist. + OR + * [], otherwise. + + Side Effects: + Deletes the mapping between @string and its associated custom + splits (which are returned to the caller). + """ + key = self._get_key(string) + + custom_splits = self._CUSTOM_SPLIT_MAP[key] + del self._CUSTOM_SPLIT_MAP[key] + + return list(custom_splits) + + def has_custom_splits(self, string: str) -> bool: + """ + Returns: + True iff @string is associated with a set of custom splits. + """ + key = self._get_key(string) + return key in self._CUSTOM_SPLIT_MAP + + +class StringMerger(CustomSplitMapMixin, StringTransformer): + """StringTransformer that merges strings together. + + Requirements: + (A) The line contains adjacent strings such that at most one substring + has inline comments AND none of those inline comments are pragmas AND + the set of all substring prefixes is either of length 1 or equal to + {"", "f"} AND none of the substrings are raw strings (i.e. are prefixed + with 'r'). + OR + (B) The line contains a string which uses line continuation backslashes. + + Transformations: + Depending on which of the two requirements above where met, either: + + (A) The string group associated with the target string is merged. + OR + (B) All line-continuation backslashes are removed from the target string. + + Collaborations: + StringMerger provides custom split information to StringSplitter. + """ + + def do_match(self, line: Line) -> TMatchResult: + LL = line.leaves + + is_valid_index = is_valid_index_factory(LL) + + for (i, leaf) in enumerate(LL): + if ( + leaf.type == token.STRING + and is_valid_index(i + 1) + and LL[i + 1].type == token.STRING + ): + return Ok(i) + + if leaf.type == token.STRING and "\\\n" in leaf.value: + return Ok(i) + + return TErr("This line has no strings that need merging.") + + def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + new_line = line + rblc_result = self.__remove_backslash_line_continuation_chars( + new_line, string_idx + ) + if isinstance(rblc_result, Ok): + new_line = rblc_result.ok() + + msg_result = self.__merge_string_group(new_line, string_idx) + if isinstance(msg_result, Ok): + new_line = msg_result.ok() + + if isinstance(rblc_result, Err) and isinstance(msg_result, Err): + msg_cant_transform = msg_result.err() + rblc_cant_transform = rblc_result.err() + cant_transform = CannotTransform( + "StringMerger failed to merge any strings in this line." + ) + + # Chain the errors together using `__cause__`. + msg_cant_transform.__cause__ = rblc_cant_transform + cant_transform.__cause__ = msg_cant_transform + + yield Err(cant_transform) + else: + yield Ok(new_line) + + @staticmethod + def __remove_backslash_line_continuation_chars( + line: Line, string_idx: int + ) -> TResult[Line]: + """ + Merge strings that were split across multiple lines using + line-continuation backslashes. + + Returns: + Ok(new_line), if @line contains backslash line-continuation + characters. + OR + Err(CannotTransform), otherwise. + """ + LL = line.leaves + + string_leaf = LL[string_idx] + if not ( + string_leaf.type == token.STRING + and "\\\n" in string_leaf.value + and not has_triple_quotes(string_leaf.value) + ): + return TErr( + f"String leaf {string_leaf} does not contain any backslash line" + " continuation characters." + ) + + new_line = line.clone() + new_line.comments = line.comments + append_leaves(new_line, line, LL) + + new_string_leaf = new_line.leaves[string_idx] + new_string_leaf.value = new_string_leaf.value.replace("\\\n", "") + + return Ok(new_line) + + def __merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]: + """ + Merges string group (i.e. set of adjacent strings) where the first + string in the group is `line.leaves[string_idx]`. + + Returns: + Ok(new_line), if ALL of the validation checks found in + __validate_msg(...) pass. + OR + Err(CannotTransform), otherwise. + """ + LL = line.leaves + + is_valid_index = is_valid_index_factory(LL) + + vresult = self.__validate_msg(line, string_idx) + if isinstance(vresult, Err): + return vresult + + # If the string group is wrapped inside an Atom node, we must make sure + # to later replace that Atom with our new (merged) string leaf. + atom_node = LL[string_idx].parent + + # We will place BREAK_MARK in between every two substrings that we + # merge. We will then later go through our final result and use the + # various instances of BREAK_MARK we find to add the right values to + # the custom split map. + BREAK_MARK = "@@@@@ BLACK BREAKPOINT MARKER @@@@@" + + QUOTE = LL[string_idx].value[-1] + + def make_naked(string: str, string_prefix: str) -> str: + """Strip @string (i.e. make it a "naked" string) + + Pre-conditions: + * assert_is_leaf_string(@string) + + Returns: + A string that is identical to @string except that + @string_prefix has been stripped, the surrounding QUOTE + characters have been removed, and any remaining QUOTE + characters have been escaped. + """ + assert_is_leaf_string(string) + + RE_EVEN_BACKSLASHES = r"(?:(?= 0 + ), "Logic error while filling the custom string breakpoint cache." + + temp_string = temp_string[mark_idx + len(BREAK_MARK) :] + breakpoint_idx = mark_idx + (len(prefix) if has_prefix else 0) + 1 + custom_splits.append(CustomSplit(has_prefix, breakpoint_idx)) + + string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, "")) + + if atom_node is not None: + replace_child(atom_node, string_leaf) + + # Build the final line ('new_line') that this method will later return. + new_line = line.clone() + for (i, leaf) in enumerate(LL): + if i == string_idx: + new_line.append(string_leaf) + + if string_idx <= i < string_idx + num_of_strings: + for comment_leaf in line.comments_after(LL[i]): + new_line.append(comment_leaf, preformatted=True) + continue + + append_leaves(new_line, line, [leaf]) + + self.add_custom_splits(string_leaf.value, custom_splits) + return Ok(new_line) + + @staticmethod + def __validate_msg(line: Line, string_idx: int) -> TResult[None]: + """Validate (M)erge (S)tring (G)roup + + Transform-time string validation logic for __merge_string_group(...). + + Returns: + * Ok(None), if ALL validation checks (listed below) pass. + OR + * Err(CannotTransform), if any of the following are true: + - The target string is not in a string group (i.e. it has no + adjacent strings). + - The string group has more than one inline comment. + - The string group has an inline comment that appears to be a pragma. + - The set of all string prefixes in the string group is of + length greater than one and is not equal to {"", "f"}. + - The string group consists of raw strings. + """ + num_of_inline_string_comments = 0 + set_of_prefixes = set() + num_of_strings = 0 + for leaf in line.leaves[string_idx:]: + if leaf.type != token.STRING: + # If the string group is trailed by a comma, we count the + # comments trailing the comma to be one of the string group's + # comments. + if leaf.type == token.COMMA and id(leaf) in line.comments: + num_of_inline_string_comments += 1 + break + + if has_triple_quotes(leaf.value): + return TErr("StringMerger does NOT merge multiline strings.") + + num_of_strings += 1 + prefix = get_string_prefix(leaf.value) + if "r" in prefix: + return TErr("StringMerger does NOT merge raw strings.") + + set_of_prefixes.add(prefix) + + if id(leaf) in line.comments: + num_of_inline_string_comments += 1 + if contains_pragma_comment(line.comments[id(leaf)]): + return TErr("Cannot merge strings which have pragma comments.") + + if num_of_strings < 2: + return TErr( + f"Not enough strings to merge (num_of_strings={num_of_strings})." + ) + + if num_of_inline_string_comments > 1: + return TErr( + f"Too many inline string comments ({num_of_inline_string_comments})." + ) + + if len(set_of_prefixes) > 1 and set_of_prefixes != {"", "f"}: + return TErr(f"Too many different prefixes ({set_of_prefixes}).") + + return Ok(None) + + +class StringParenStripper(StringTransformer): + """StringTransformer that strips surrounding parentheses from strings. + + Requirements: + The line contains a string which is surrounded by parentheses and: + - The target string is NOT the only argument to a function call). + - The RPAR is NOT followed by an attribute access (i.e. a dot). + + Transformations: + The parentheses mentioned in the 'Requirements' section are stripped. + + Collaborations: + StringParenStripper has its own inherent usefulness, but it is also + relied on to clean up the parentheses created by StringParenWrapper (in + the event that they are no longer needed). + """ + + def do_match(self, line: Line) -> TMatchResult: + LL = line.leaves + + is_valid_index = is_valid_index_factory(LL) + + for (idx, leaf) in enumerate(LL): + # Should be a string... + if leaf.type != token.STRING: + continue + + # Should be preceded by a non-empty LPAR... + if ( + not is_valid_index(idx - 1) + or LL[idx - 1].type != token.LPAR + or is_empty_lpar(LL[idx - 1]) + ): + continue + + # That LPAR should NOT be preceded by a function name or a closing + # bracket (which could be a function which returns a function or a + # list/dictionary that contains a function)... + if is_valid_index(idx - 2) and ( + LL[idx - 2].type == token.NAME or LL[idx - 2].type in CLOSING_BRACKETS + ): + continue + + string_idx = idx + + # Skip the string trailer, if one exists. + string_parser = StringParser() + next_idx = string_parser.parse(LL, string_idx) + + # Should be followed by a non-empty RPAR... + if ( + is_valid_index(next_idx) + and LL[next_idx].type == token.RPAR + and not is_empty_rpar(LL[next_idx]) + ): + # That RPAR should NOT be followed by a '.' symbol. + if is_valid_index(next_idx + 1) and LL[next_idx + 1].type == token.DOT: + continue + + return Ok(string_idx) + + return TErr("This line has no strings wrapped in parens.") + + def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + LL = line.leaves + + string_parser = StringParser() + rpar_idx = string_parser.parse(LL, string_idx) + + for leaf in (LL[string_idx - 1], LL[rpar_idx]): + if line.comments_after(leaf): + yield TErr( + "Will not strip parentheses which have comments attached to them." + ) + + new_line = line.clone() + new_line.comments = line.comments.copy() + + append_leaves(new_line, line, LL[: string_idx - 1]) + + string_leaf = Leaf(token.STRING, LL[string_idx].value) + LL[string_idx - 1].remove() + replace_child(LL[string_idx], string_leaf) + new_line.append(string_leaf) + + append_leaves( + new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :], + ) + + LL[rpar_idx].remove() + + yield Ok(new_line) + + +class BaseStringSplitter(StringTransformer): + """ + Abstract class for StringTransformers which transform a Line's strings by splitting + them or placing them on their own lines where necessary to avoid going over + the configured line length. + + Requirements: + * The target string value is responsible for the line going over the + line length limit. It follows that after all of black's other line + split methods have been exhausted, this line (or one of the resulting + lines after all line splits are performed) would still be over the + line_length limit unless we split this string. + AND + * The target string is NOT a "pointless" string (i.e. a string that has + no parent or siblings). + AND + * The target string is not followed by an inline comment that appears + to be a pragma. + AND + * The target string is not a multiline (i.e. triple-quote) string. + """ + + @abstractmethod + def do_splitter_match(self, line: Line) -> TMatchResult: + """ + BaseStringSplitter asks its clients to override this method instead of + `StringTransformer.do_match(...)`. + + Follows the same protocol as `StringTransformer.do_match(...)`. + + Refer to `help(StringTransformer.do_match)` for more information. + """ + + def do_match(self, line: Line) -> TMatchResult: + match_result = self.do_splitter_match(line) + if isinstance(match_result, Err): + return match_result + + string_idx = match_result.ok() + vresult = self.__validate(line, string_idx) + if isinstance(vresult, Err): + return vresult + + return match_result + + def __validate(self, line: Line, string_idx: int) -> TResult[None]: + """ + Checks that @line meets all of the requirements listed in this classes' + docstring. Refer to `help(BaseStringSplitter)` for a detailed + description of those requirements. + + Returns: + * Ok(None), if ALL of the requirements are met. + OR + * Err(CannotTransform), if ANY of the requirements are NOT met. + """ + LL = line.leaves + + string_leaf = LL[string_idx] + + max_string_length = self.__get_max_string_length(line, string_idx) + if len(string_leaf.value) <= max_string_length: + return TErr( + "The string itself is not what is causing this line to be too long." + ) + + if not string_leaf.parent or [L.type for L in string_leaf.parent.children] == [ + token.STRING, + token.NEWLINE, + ]: + return TErr( + f"This string ({string_leaf.value}) appears to be pointless (i.e. has" + " no parent)." + ) + + if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment( + line.comments[id(line.leaves[string_idx])] + ): + return TErr( + "Line appears to end with an inline pragma comment. Splitting the line" + " could modify the pragma's behavior." + ) + + if has_triple_quotes(string_leaf.value): + return TErr("We cannot split multiline strings.") + + return Ok(None) + + def __get_max_string_length(self, line: Line, string_idx: int) -> int: + """ + Calculates the max string length used when attempting to determine + whether or not the target string is responsible for causing the line to + go over the line length limit. + + WARNING: This method is tightly coupled to both StringSplitter and + (especially) StringParenWrapper. There is probably a better way to + accomplish what is being done here. + + Returns: + max_string_length: such that `line.leaves[string_idx].value > + max_string_length` implies that the target string IS responsible + for causing this line to exceed the line length limit. + """ + LL = line.leaves + + is_valid_index = is_valid_index_factory(LL) + + # We use the shorthand "WMA4" in comments to abbreviate "We must + # account for". When giving examples, we use STRING to mean some/any + # valid string. + # + # Finally, we use the following convenience variables: + # + # P: The leaf that is before the target string leaf. + # N: The leaf that is after the target string leaf. + # NN: The leaf that is after N. + + # WMA4 the whitespace at the beginning of the line. + offset = line.depth * 4 + + if is_valid_index(string_idx - 1): + p_idx = string_idx - 1 + if ( + LL[string_idx - 1].type == token.LPAR + and LL[string_idx - 1].value == "" + and string_idx >= 2 + ): + # If the previous leaf is an empty LPAR placeholder, we should skip it. + p_idx -= 1 + + P = LL[p_idx] + if P.type == token.PLUS: + # WMA4 a space and a '+' character (e.g. `+ STRING`). + offset += 2 + + if P.type == token.COMMA: + # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`]. + offset += 3 + + if P.type in [token.COLON, token.EQUAL, token.NAME]: + # This conditional branch is meant to handle dictionary keys, + # variable assignments, 'return STRING' statement lines, and + # 'else STRING' ternary expression lines. + + # WMA4 a single space. + offset += 1 + + # WMA4 the lengths of any leaves that came before that space. + for leaf in LL[: p_idx + 1]: + offset += len(str(leaf)) + + if is_valid_index(string_idx + 1): + N = LL[string_idx + 1] + if N.type == token.RPAR and N.value == "" and len(LL) > string_idx + 2: + # If the next leaf is an empty RPAR placeholder, we should skip it. + N = LL[string_idx + 2] + + if N.type == token.COMMA: + # WMA4 a single comma at the end of the string (e.g `STRING,`). + offset += 1 + + if is_valid_index(string_idx + 2): + NN = LL[string_idx + 2] + + if N.type == token.DOT and NN.type == token.NAME: + # This conditional branch is meant to handle method calls invoked + # off of a string literal up to and including the LPAR character. + + # WMA4 the '.' character. + offset += 1 + + if ( + is_valid_index(string_idx + 3) + and LL[string_idx + 3].type == token.LPAR + ): + # WMA4 the left parenthesis character. + offset += 1 + + # WMA4 the length of the method's name. + offset += len(NN.value) + + has_comments = False + for comment_leaf in line.comments_after(LL[string_idx]): + if not has_comments: + has_comments = True + # WMA4 two spaces before the '#' character. + offset += 2 + + # WMA4 the length of the inline comment. + offset += len(comment_leaf.value) + + max_string_length = self.line_length - offset + return max_string_length + + +class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): + """ + StringTransformer that splits "atom" strings (i.e. strings which exist on + lines by themselves). + + Requirements: + * The line consists ONLY of a single string (with the exception of a + '+' symbol which MAY exist at the start of the line), MAYBE a string + trailer, and MAYBE a trailing comma. + AND + * All of the requirements listed in BaseStringSplitter's docstring. + + Transformations: + The string mentioned in the 'Requirements' section is split into as + many substrings as necessary to adhere to the configured line length. + + In the final set of substrings, no substring should be smaller than + MIN_SUBSTR_SIZE characters. + + The string will ONLY be split on spaces (i.e. each new substring should + start with a space). + + If the string is an f-string, it will NOT be split in the middle of an + f-expression (e.g. in f"FooBar: {foo() if x else bar()}", {foo() if x + else bar()} is an f-expression). + + If the string that is being split has an associated set of custom split + records and those custom splits will NOT result in any line going over + the configured line length, those custom splits are used. Otherwise the + string is split as late as possible (from left-to-right) while still + adhering to the transformation rules listed above. + + Collaborations: + StringSplitter relies on StringMerger to construct the appropriate + CustomSplit objects and add them to the custom split map. + """ + + MIN_SUBSTR_SIZE = 6 + # Matches an "f-expression" (e.g. {var}) that might be found in an f-string. + RE_FEXPR = r""" + (? TMatchResult: + LL = line.leaves + + is_valid_index = is_valid_index_factory(LL) + + idx = 0 + + # The first leaf MAY be a '+' symbol... + if is_valid_index(idx) and LL[idx].type == token.PLUS: + idx += 1 + + # The next/first leaf MAY be an empty LPAR... + if is_valid_index(idx) and is_empty_lpar(LL[idx]): + idx += 1 + + # The next/first leaf MUST be a string... + if not is_valid_index(idx) or LL[idx].type != token.STRING: + return TErr("Line does not start with a string.") + + string_idx = idx + + # Skip the string trailer, if one exists. + string_parser = StringParser() + idx = string_parser.parse(LL, string_idx) + + # That string MAY be followed by an empty RPAR... + if is_valid_index(idx) and is_empty_rpar(LL[idx]): + idx += 1 + + # That string / empty RPAR leaf MAY be followed by a comma... + if is_valid_index(idx) and LL[idx].type == token.COMMA: + idx += 1 + + # But no more leaves are allowed... + if is_valid_index(idx): + return TErr("This line does not end with a string.") + + return Ok(string_idx) + + def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + LL = line.leaves + + QUOTE = LL[string_idx].value[-1] + + is_valid_index = is_valid_index_factory(LL) + insert_str_child = insert_str_child_factory(LL[string_idx]) + + prefix = get_string_prefix(LL[string_idx].value) + + # We MAY choose to drop the 'f' prefix from substrings that don't + # contain any f-expressions, but ONLY if the original f-string + # containes at least one f-expression. Otherwise, we will alter the AST + # of the program. + drop_pointless_f_prefix = ("f" in prefix) and re.search( + self.RE_FEXPR, LL[string_idx].value, re.VERBOSE + ) + + first_string_line = True + starts_with_plus = LL[0].type == token.PLUS + + def line_needs_plus() -> bool: + return first_string_line and starts_with_plus + + def maybe_append_plus(new_line: Line) -> None: + """ + Side Effects: + If @line starts with a plus and this is the first line we are + constructing, this function appends a PLUS leaf to @new_line + and replaces the old PLUS leaf in the node structure. Otherwise + this function does nothing. + """ + if line_needs_plus(): + plus_leaf = Leaf(token.PLUS, "+") + replace_child(LL[0], plus_leaf) + new_line.append(plus_leaf) + + ends_with_comma = ( + is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA + ) + + def max_last_string() -> int: + """ + Returns: + The max allowed length of the string value used for the last + line we will construct. + """ + result = self.line_length + result -= line.depth * 4 + result -= 1 if ends_with_comma else 0 + result -= 2 if line_needs_plus() else 0 + return result + + # --- Calculate Max Break Index (for string value) + # We start with the line length limit + max_break_idx = self.line_length + # The last index of a string of length N is N-1. + max_break_idx -= 1 + # Leading whitespace is not present in the string value (e.g. Leaf.value). + max_break_idx -= line.depth * 4 + if max_break_idx < 0: + yield TErr( + f"Unable to split {LL[string_idx].value} at such high of a line depth:" + f" {line.depth}" + ) + return + + # Check if StringMerger registered any custom splits. + custom_splits = self.pop_custom_splits(LL[string_idx].value) + # We use them ONLY if none of them would produce lines that exceed the + # line limit. + use_custom_breakpoints = bool( + custom_splits + and all(csplit.break_idx <= max_break_idx for csplit in custom_splits) + ) + + # Temporary storage for the remaining chunk of the string line that + # can't fit onto the line currently being constructed. + rest_value = LL[string_idx].value + + def more_splits_should_be_made() -> bool: + """ + Returns: + True iff `rest_value` (the remaining string value from the last + split), should be split again. + """ + if use_custom_breakpoints: + return len(custom_splits) > 1 + else: + return len(rest_value) > max_last_string() + + string_line_results: List[Ok[Line]] = [] + while more_splits_should_be_made(): + if use_custom_breakpoints: + # Custom User Split (manual) + csplit = custom_splits.pop(0) + break_idx = csplit.break_idx + else: + # Algorithmic Split (automatic) + max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx + maybe_break_idx = self.__get_break_idx(rest_value, max_bidx) + if maybe_break_idx is None: + # If we are unable to algorthmically determine a good split + # and this string has custom splits registered to it, we + # fall back to using them--which means we have to start + # over from the beginning. + if custom_splits: + rest_value = LL[string_idx].value + string_line_results = [] + first_string_line = True + use_custom_breakpoints = True + continue + + # Otherwise, we stop splitting here. + break + + break_idx = maybe_break_idx + + # --- Construct `next_value` + next_value = rest_value[:break_idx] + QUOTE + if ( + # Are we allowed to try to drop a pointless 'f' prefix? + drop_pointless_f_prefix + # If we are, will we be successful? + and next_value != self.__normalize_f_string(next_value, prefix) + ): + # If the current custom split did NOT originally use a prefix, + # then `csplit.break_idx` will be off by one after removing + # the 'f' prefix. + break_idx = ( + break_idx + 1 + if use_custom_breakpoints and not csplit.has_prefix + else break_idx + ) + next_value = rest_value[:break_idx] + QUOTE + next_value = self.__normalize_f_string(next_value, prefix) + + # --- Construct `next_leaf` + next_leaf = Leaf(token.STRING, next_value) + insert_str_child(next_leaf) + self.__maybe_normalize_string_quotes(next_leaf) + + # --- Construct `next_line` + next_line = line.clone() + maybe_append_plus(next_line) + next_line.append(next_leaf) + string_line_results.append(Ok(next_line)) + + rest_value = prefix + QUOTE + rest_value[break_idx:] + first_string_line = False + + yield from string_line_results + + if drop_pointless_f_prefix: + rest_value = self.__normalize_f_string(rest_value, prefix) + + rest_leaf = Leaf(token.STRING, rest_value) + insert_str_child(rest_leaf) + + # NOTE: I could not find a test case that verifies that the following + # line is actually necessary, but it seems to be. Otherwise we risk + # not normalizing the last substring, right? + self.__maybe_normalize_string_quotes(rest_leaf) + + last_line = line.clone() + maybe_append_plus(last_line) + + # If there are any leaves to the right of the target string... + if is_valid_index(string_idx + 1): + # We use `temp_value` here to determine how long the last line + # would be if we were to append all the leaves to the right of the + # target string to the last string line. + temp_value = rest_value + for leaf in LL[string_idx + 1 :]: + temp_value += str(leaf) + if leaf.type == token.LPAR: + break + + # Try to fit them all on the same line with the last substring... + if ( + len(temp_value) <= max_last_string() + or LL[string_idx + 1].type == token.COMMA + ): + last_line.append(rest_leaf) + append_leaves(last_line, line, LL[string_idx + 1 :]) + yield Ok(last_line) + # Otherwise, place the last substring on one line and everything + # else on a line below that... + else: + last_line.append(rest_leaf) + yield Ok(last_line) + + non_string_line = line.clone() + append_leaves(non_string_line, line, LL[string_idx + 1 :]) + yield Ok(non_string_line) + # Else the target string was the last leaf... + else: + last_line.append(rest_leaf) + last_line.comments = line.comments.copy() + yield Ok(last_line) + + def __get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]: + """ + This method contains the algorithm that StringSplitter uses to + determine which character to split each string at. + + Args: + @string: The substring that we are attempting to split. + @max_break_idx: The ideal break index. We will return this value if it + meets all the necessary conditions. In the likely event that it + doesn't we will try to find the closest index BELOW @max_break_idx + that does. If that fails, we will expand our search by also + considering all valid indices ABOVE @max_break_idx. + + Pre-Conditions: + * assert_is_leaf_string(@string) + * 0 <= @max_break_idx < len(@string) + + Returns: + break_idx, if an index is able to be found that meets all of the + conditions listed in the 'Transformations' section of this classes' + docstring. + OR + None, otherwise. + """ + is_valid_index = is_valid_index_factory(string) + + assert is_valid_index(max_break_idx) + assert_is_leaf_string(string) + + _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None + + def fexpr_slices() -> Iterator[Tuple[Index, Index]]: + """ + Yields: + All ranges of @string which, if @string were to be split there, + would result in the splitting of an f-expression (which is NOT + allowed). + """ + nonlocal _fexpr_slices + + if _fexpr_slices is None: + _fexpr_slices = [] + for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE): + _fexpr_slices.append(match.span()) + + yield from _fexpr_slices + + is_fstring = "f" in get_string_prefix(string) + + def breaks_fstring_expression(i: Index) -> bool: + """ + Returns: + True iff returning @i would result in the splitting of an + f-expression (which is NOT allowed). + """ + if not is_fstring: + return False + + for (start, end) in fexpr_slices(): + if start <= i < end: + return True + + return False + + def passes_all_checks(i: Index) -> bool: + """ + Returns: + True iff ALL of the conditions listed in the 'Transformations' + section of this classes' docstring would be be met by returning @i. + """ + is_space = string[i] == " " + is_big_enough = ( + len(string[i:]) >= self.MIN_SUBSTR_SIZE + and len(string[:i]) >= self.MIN_SUBSTR_SIZE + ) + return is_space and is_big_enough and not breaks_fstring_expression(i) + + # First, we check all indices BELOW @max_break_idx. + break_idx = max_break_idx + while is_valid_index(break_idx - 1) and not passes_all_checks(break_idx): + break_idx -= 1 + + if not passes_all_checks(break_idx): + # If that fails, we check all indices ABOVE @max_break_idx. + # + # If we are able to find a valid index here, the next line is going + # to be longer than the specified line length, but it's probably + # better than doing nothing at all. + break_idx = max_break_idx + 1 + while is_valid_index(break_idx + 1) and not passes_all_checks(break_idx): + break_idx += 1 + + if not is_valid_index(break_idx) or not passes_all_checks(break_idx): + return None + + return break_idx + + def __maybe_normalize_string_quotes(self, leaf: Leaf) -> None: + if self.normalize_strings: + normalize_string_quotes(leaf) + + def __normalize_f_string(self, string: str, prefix: str) -> str: + """ + Pre-Conditions: + * assert_is_leaf_string(@string) + + Returns: + * If @string is an f-string that contains no f-expressions, we + return a string identical to @string except that the 'f' prefix + has been stripped and all double braces (i.e. '{{' or '}}') have + been normalized (i.e. turned into '{' or '}'). + OR + * Otherwise, we return @string. + """ + assert_is_leaf_string(string) + + if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE): + new_prefix = prefix.replace("f", "") + + temp = string[len(prefix) :] + temp = re.sub(r"\{\{", "{", temp) + temp = re.sub(r"\}\}", "}", temp) + new_string = temp + + return f"{new_prefix}{new_string}" + else: + return string + + +class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): + """ + StringTransformer that splits non-"atom" strings (i.e. strings that do not + exist on lines by themselves). + + Requirements: + All of the requirements listed in BaseStringSplitter's docstring in + addition to the requirements listed below: + + * The line is a return/yield statement, which returns/yields a string. + OR + * The line is part of a ternary expression (e.g. `x = y if cond else + z`) such that the line starts with `else `, where is + some string. + OR + * The line is an assert statement, which ends with a string. + OR + * The line is an assignment statement (e.g. `x = ` or `x += + `) such that the variable is being assigned the value of some + string. + OR + * The line is a dictionary key assignment where some valid key is being + assigned the value of some string. + + Transformations: + The chosen string is wrapped in parentheses and then split at the LPAR. + + We then have one line which ends with an LPAR and another line that + starts with the chosen string. The latter line is then split again at + the RPAR. This results in the RPAR (and possibly a trailing comma) + being placed on its own line. + + NOTE: If any leaves exist to the right of the chosen string (except + for a trailing comma, which would be placed after the RPAR), those + leaves are placed inside the parentheses. In effect, the chosen + string is not necessarily being "wrapped" by parentheses. We can, + however, count on the LPAR being placed directly before the chosen + string. + + In other words, StringParenWrapper creates "atom" strings. These + can then be split again by StringSplitter, if necessary. + + Collaborations: + In the event that a string line split by StringParenWrapper is + changed such that it no longer needs to be given its own line, + StringParenWrapper relies on StringParenStripper to clean up the + parentheses it created. + """ + + def do_splitter_match(self, line: Line) -> TMatchResult: + LL = line.leaves + + string_idx = None + string_idx = string_idx or self._return_match(LL) + string_idx = string_idx or self._else_match(LL) + string_idx = string_idx or self._assert_match(LL) + string_idx = string_idx or self._assign_match(LL) + string_idx = string_idx or self._dict_match(LL) + + if string_idx is not None: + string_value = line.leaves[string_idx].value + # If the string has no spaces... + if " " not in string_value: + # And will still violate the line length limit when split... + max_string_length = self.line_length - ((line.depth + 1) * 4) + if len(string_value) > max_string_length: + # And has no associated custom splits... + if not self.has_custom_splits(string_value): + # Then we should NOT put this string on its own line. + return TErr( + "We do not wrap long strings in parentheses when the" + " resultant line would still be over the specified line" + " length and can't be split further by StringSplitter." + ) + return Ok(string_idx) + + return TErr("This line does not contain any non-atomic strings.") + + @staticmethod + def _return_match(LL: List[Leaf]) -> Optional[int]: + """ + Returns: + string_idx such that @LL[string_idx] is equal to our target (i.e. + matched) string, if this line matches the return/yield statement + requirements listed in the 'Requirements' section of this classes' + docstring. + OR + None, otherwise. + """ + # If this line is apart of a return/yield statement and the first leaf + # contains either the "return" or "yield" keywords... + if parent_type(LL[0]) in [syms.return_stmt, syms.yield_expr] and LL[ + 0 + ].value in ["return", "yield"]: + is_valid_index = is_valid_index_factory(LL) + + idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1 + # The next visible leaf MUST contain a string... + if is_valid_index(idx) and LL[idx].type == token.STRING: + return idx + + return None + + @staticmethod + def _else_match(LL: List[Leaf]) -> Optional[int]: + """ + Returns: + string_idx such that @LL[string_idx] is equal to our target (i.e. + matched) string, if this line matches the ternary expression + requirements listed in the 'Requirements' section of this classes' + docstring. + OR + None, otherwise. + """ + # If this line is apart of a ternary expression and the first leaf + # contains the "else" keyword... + if ( + parent_type(LL[0]) == syms.test + and LL[0].type == token.NAME + and LL[0].value == "else" + ): + is_valid_index = is_valid_index_factory(LL) + + idx = 2 if is_valid_index(1) and is_empty_par(LL[1]) else 1 + # The next visible leaf MUST contain a string... + if is_valid_index(idx) and LL[idx].type == token.STRING: + return idx + + return None + + @staticmethod + def _assert_match(LL: List[Leaf]) -> Optional[int]: + """ + Returns: + string_idx such that @LL[string_idx] is equal to our target (i.e. + matched) string, if this line matches the assert statement + requirements listed in the 'Requirements' section of this classes' + docstring. + OR + None, otherwise. + """ + # If this line is apart of an assert statement and the first leaf + # contains the "assert" keyword... + if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert": + is_valid_index = is_valid_index_factory(LL) + + for (i, leaf) in enumerate(LL): + # We MUST find a comma... + if leaf.type == token.COMMA: + idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1 + + # That comma MUST be followed by a string... + if is_valid_index(idx) and LL[idx].type == token.STRING: + string_idx = idx + + # Skip the string trailer, if one exists. + string_parser = StringParser() + idx = string_parser.parse(LL, string_idx) + + # But no more leaves are allowed... + if not is_valid_index(idx): + return string_idx + + return None + + @staticmethod + def _assign_match(LL: List[Leaf]) -> Optional[int]: + """ + Returns: + string_idx such that @LL[string_idx] is equal to our target (i.e. + matched) string, if this line matches the assignment statement + requirements listed in the 'Requirements' section of this classes' + docstring. + OR + None, otherwise. + """ + # If this line is apart of an expression statement or is a function + # argument AND the first leaf contains a variable name... + if ( + parent_type(LL[0]) in [syms.expr_stmt, syms.argument, syms.power] + and LL[0].type == token.NAME + ): + is_valid_index = is_valid_index_factory(LL) + + for (i, leaf) in enumerate(LL): + # We MUST find either an '=' or '+=' symbol... + if leaf.type in [token.EQUAL, token.PLUSEQUAL]: + idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1 + + # That symbol MUST be followed by a string... + if is_valid_index(idx) and LL[idx].type == token.STRING: + string_idx = idx + + # Skip the string trailer, if one exists. + string_parser = StringParser() + idx = string_parser.parse(LL, string_idx) + + # The next leaf MAY be a comma iff this line is apart + # of a function argument... + if ( + parent_type(LL[0]) == syms.argument + and is_valid_index(idx) + and LL[idx].type == token.COMMA + ): + idx += 1 + + # But no more leaves are allowed... + if not is_valid_index(idx): + return string_idx + + return None + + @staticmethod + def _dict_match(LL: List[Leaf]) -> Optional[int]: + """ + Returns: + string_idx such that @LL[string_idx] is equal to our target (i.e. + matched) string, if this line matches the dictionary key assignment + statement requirements listed in the 'Requirements' section of this + classes' docstring. + OR + None, otherwise. + """ + # If this line is apart of a dictionary key assignment... + if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]: + is_valid_index = is_valid_index_factory(LL) + + for (i, leaf) in enumerate(LL): + # We MUST find a colon... + if leaf.type == token.COLON: + idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1 + + # That colon MUST be followed by a string... + if is_valid_index(idx) and LL[idx].type == token.STRING: + string_idx = idx + + # Skip the string trailer, if one exists. + string_parser = StringParser() + idx = string_parser.parse(LL, string_idx) + + # That string MAY be followed by a comma... + if is_valid_index(idx) and LL[idx].type == token.COMMA: + idx += 1 + + # But no more leaves are allowed... + if not is_valid_index(idx): + return string_idx + + return None + + def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + LL = line.leaves + + is_valid_index = is_valid_index_factory(LL) + insert_str_child = insert_str_child_factory(LL[string_idx]) + + comma_idx = len(LL) - 1 + ends_with_comma = False + if LL[comma_idx].type == token.COMMA: + ends_with_comma = True + + leaves_to_steal_comments_from = [LL[string_idx]] + if ends_with_comma: + leaves_to_steal_comments_from.append(LL[comma_idx]) + + # --- First Line + first_line = line.clone() + left_leaves = LL[:string_idx] + + # We have to remember to account for (possibly invisible) LPAR and RPAR + # leaves that already wrapped the target string. If these leaves do + # exist, we will replace them with our own LPAR and RPAR leaves. + old_parens_exist = False + if left_leaves and left_leaves[-1].type == token.LPAR: + old_parens_exist = True + leaves_to_steal_comments_from.append(left_leaves[-1]) + left_leaves.pop() + + append_leaves(first_line, line, left_leaves) + + lpar_leaf = Leaf(token.LPAR, "(") + if old_parens_exist: + replace_child(LL[string_idx - 1], lpar_leaf) + else: + insert_str_child(lpar_leaf) + first_line.append(lpar_leaf) + + # We throw inline comments that were originally to the right of the + # target string to the top line. They will now be shown to the right of + # the LPAR. + for leaf in leaves_to_steal_comments_from: + for comment_leaf in line.comments_after(leaf): + first_line.append(comment_leaf, preformatted=True) + + yield Ok(first_line) + + # --- Middle (String) Line + # We only need to yield one (possibly too long) string line, since the + # `StringSplitter` will break it down further if necessary. + string_value = LL[string_idx].value + string_line = Line( + depth=line.depth + 1, + inside_brackets=True, + should_explode=line.should_explode, + ) + string_leaf = Leaf(token.STRING, string_value) + insert_str_child(string_leaf) + string_line.append(string_leaf) + + old_rpar_leaf = None + if is_valid_index(string_idx + 1): + right_leaves = LL[string_idx + 1 :] + if ends_with_comma: + right_leaves.pop() + + if old_parens_exist: + assert ( + right_leaves and right_leaves[-1].type == token.RPAR + ), "Apparently, old parentheses do NOT exist?!" + old_rpar_leaf = right_leaves.pop() + + append_leaves(string_line, line, right_leaves) + + yield Ok(string_line) + + # --- Last Line + last_line = line.clone() + last_line.bracket_tracker = first_line.bracket_tracker + + new_rpar_leaf = Leaf(token.RPAR, ")") + if old_rpar_leaf is not None: + replace_child(old_rpar_leaf, new_rpar_leaf) + else: + insert_str_child(new_rpar_leaf) + last_line.append(new_rpar_leaf) + + # If the target string ended with a comma, we place this comma to the + # right of the RPAR on the last line. + if ends_with_comma: + comma_leaf = Leaf(token.COMMA, ",") + replace_child(LL[comma_idx], comma_leaf) + last_line.append(comma_leaf) + + yield Ok(last_line) + + +class StringParser: + """ + A state machine that aids in parsing a string's "trailer", which can be + either non-existant, an old-style formatting sequence (e.g. `% varX` or `% + (varX, varY)`), or a method-call / attribute access (e.g. `.format(varX, + varY)`). + + NOTE: A new StringParser object MUST be instantiated for each string + trailer we need to parse. + + Examples: + We shall assume that `line` equals the `Line` object that corresponds + to the following line of python code: + ``` + x = "Some {}.".format("String") + some_other_string + ``` + + Furthermore, we will assume that `string_idx` is some index such that: + ``` + assert line.leaves[string_idx].value == "Some {}." + ``` + + The following code snippet then holds: + ``` + string_parser = StringParser() + idx = string_parser.parse(line.leaves, string_idx) + assert line.leaves[idx].type == token.PLUS + ``` + """ + + DEFAULT_TOKEN = -1 + + # String Parser States + START = 1 + DOT = 2 + NAME = 3 + PERCENT = 4 + SINGLE_FMT_ARG = 5 + LPAR = 6 + RPAR = 7 + DONE = 8 + + # Lookup Table for Next State + _goto: Dict[Tuple[ParserState, NodeType], ParserState] = { + # A string trailer may start with '.' OR '%'. + (START, token.DOT): DOT, + (START, token.PERCENT): PERCENT, + (START, DEFAULT_TOKEN): DONE, + # A '.' MUST be followed by an attribute or method name. + (DOT, token.NAME): NAME, + # A method name MUST be followed by an '(', whereas an attribute name + # is the last symbol in the string trailer. + (NAME, token.LPAR): LPAR, + (NAME, DEFAULT_TOKEN): DONE, + # A '%' symbol can be followed by an '(' or a single argument (e.g. a + # string or variable name). + (PERCENT, token.LPAR): LPAR, + (PERCENT, DEFAULT_TOKEN): SINGLE_FMT_ARG, + # If a '%' symbol is followed by a single argument, that argument is + # the last leaf in the string trailer. + (SINGLE_FMT_ARG, DEFAULT_TOKEN): DONE, + # If present, a ')' symbol is the last symbol in a string trailer. + # (NOTE: LPARS and nested RPARS are not included in this lookup table, + # since they are treated as a special case by the parsing logic in this + # classes' implementation.) + (RPAR, DEFAULT_TOKEN): DONE, + } + + def __init__(self) -> None: + self._state = self.START + self._unmatched_lpars = 0 + + def parse(self, leaves: List[Leaf], string_idx: int) -> int: + """ + Pre-conditions: + * @leaves[@string_idx].type == token.STRING + + Returns: + The index directly after the last leaf which is apart of the string + trailer, if a "trailer" exists. + OR + @string_idx + 1, if no string "trailer" exists. + """ + assert leaves[string_idx].type == token.STRING + + idx = string_idx + 1 + while idx < len(leaves) and self._next_state(leaves[idx]): + idx += 1 + return idx + + def _next_state(self, leaf: Leaf) -> bool: + """ + Pre-conditions: + * On the first call to this function, @leaf MUST be the leaf that + was directly after the string leaf in question (e.g. if our target + string is `line.leaves[i]` then the first call to this method must + be `line.leaves[i + 1]`). + * On the next call to this function, the leaf paramater passed in + MUST be the leaf directly following @leaf. + + Returns: + True iff @leaf is apart of the string's trailer. + """ + # We ignore empty LPAR or RPAR leaves. + if is_empty_par(leaf): + return True + + next_token = leaf.type + if next_token == token.LPAR: + self._unmatched_lpars += 1 + + current_state = self._state + + # The LPAR parser state is a special case. We will return True until we + # find the matching RPAR token. + if current_state == self.LPAR: + if next_token == token.RPAR: + self._unmatched_lpars -= 1 + if self._unmatched_lpars == 0: + self._state = self.RPAR + # Otherwise, we use a lookup table to determine the next state. + else: + # If the lookup table matches the current state to the next + # token, we use the lookup table. + if (current_state, next_token) in self._goto: + self._state = self._goto[current_state, next_token] + else: + # Otherwise, we check if a the current state was assigned a + # default. + if (current_state, self.DEFAULT_TOKEN) in self._goto: + self._state = self._goto[current_state, self.DEFAULT_TOKEN] + # If no default has been assigned, then this parser has a logic + # error. + else: + raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!") + + if self._state == self.DONE: + return False + + return True + + +def TErr(err_msg: str) -> Err[CannotTransform]: + """(T)ransform Err + + Convenience function used when working with the TResult type. + """ + cant_transform = CannotTransform(err_msg) + return Err(cant_transform) + + +def contains_pragma_comment(comment_list: List[Leaf]) -> bool: + """ + Returns: + True iff one of the comments in @comment_list is a pragma used by one + of the more common static analysis tools for python (e.g. mypy, flake8, + pylint). + """ + for comment in comment_list: + if comment.value.startswith(("# type:", "# noqa", "# pylint:")): + return True + + return False + + +def insert_str_child_factory(string_leaf: Leaf) -> Callable[[LN], None]: + """ + Factory for a convenience function that is used to orphan @string_leaf + and then insert multiple new leaves into the same part of the node + structure that @string_leaf had originally occupied. + + Examples: + Let `string_leaf = Leaf(token.STRING, '"foo"')` and `N = + string_leaf.parent`. Assume the node `N` has the following + original structure: + + Node( + expr_stmt, [ + Leaf(NAME, 'x'), + Leaf(EQUAL, '='), + Leaf(STRING, '"foo"'), + ] + ) + + We then run the code snippet shown below. + ``` + insert_str_child = insert_str_child_factory(string_leaf) + + lpar = Leaf(token.LPAR, '(') + insert_str_child(lpar) + + bar = Leaf(token.STRING, '"bar"') + insert_str_child(bar) + + rpar = Leaf(token.RPAR, ')') + insert_str_child(rpar) + ``` + + After which point, it follows that `string_leaf.parent is None` and + the node `N` now has the following structure: + + Node( + expr_stmt, [ + Leaf(NAME, 'x'), + Leaf(EQUAL, '='), + Leaf(LPAR, '('), + Leaf(STRING, '"bar"'), + Leaf(RPAR, ')'), + ] + ) + """ + string_parent = string_leaf.parent + string_child_idx = string_leaf.remove() + + def insert_str_child(child: LN) -> None: + nonlocal string_child_idx + + assert string_parent is not None + assert string_child_idx is not None + + string_parent.insert_child(string_child_idx, child) + string_child_idx += 1 + + return insert_str_child + + +def has_triple_quotes(string: str) -> bool: + """ + Returns: + True iff @string starts with three quotation characters. + """ + raw_string = string.lstrip(STRING_PREFIX_CHARS) + return raw_string[:3] in {'"""', "'''"} + + +def parent_type(node: Optional[LN]) -> Optional[NodeType]: + """ + Returns: + @node.parent.type, if @node is not None and has a parent. + OR + None, otherwise. + """ + if node is None or node.parent is None: + return None + + return node.parent.type + + +def is_empty_par(leaf: Leaf) -> bool: + return is_empty_lpar(leaf) or is_empty_rpar(leaf) + + +def is_empty_lpar(leaf: Leaf) -> bool: + return leaf.type == token.LPAR and leaf.value == "" + + +def is_empty_rpar(leaf: Leaf) -> bool: + return leaf.type == token.RPAR and leaf.value == "" + + +def is_valid_index_factory(seq: Sequence[Any]) -> Callable[[int], bool]: + """ + Examples: + ``` + my_list = [1, 2, 3] + + is_valid_index = is_valid_index_factory(my_list) + + assert is_valid_index(0) + assert is_valid_index(2) + + assert not is_valid_index(3) + assert not is_valid_index(-1) + ``` + """ + + def is_valid_index(idx: int) -> bool: + """ + Returns: + True iff @idx is positive AND seq[@idx] does NOT raise an + IndexError. + """ + return 0 <= idx < len(seq) + + return is_valid_index + + +def line_to_string(line: Line) -> str: + """Returns the string representation of @line. + + WARNING: This is known to be computationally expensive. + """ + return str(line).strip("\n") + + +def append_leaves(new_line: Line, old_line: Line, leaves: List[Leaf]) -> None: + """ + Append leaves (taken from @old_line) to @new_line, making sure to fix the + underlying Node structure where appropriate. + + All of the leaves in @leaves are duplicated. The duplicates are then + appended to @new_line and used to replace their originals in the underlying + Node structure. Any comments attatched to the old leaves are reattached to + the new leaves. + + Pre-conditions: + set(@leaves) is a subset of set(@old_line.leaves). + """ + for old_leaf in leaves: + assert old_leaf in old_line.leaves + + new_leaf = Leaf(old_leaf.type, old_leaf.value) + replace_child(old_leaf, new_leaf) + new_line.append(new_leaf) + + for comment_leaf in old_line.comments_after(old_leaf): + new_line.append(comment_leaf, preformatted=True) + + +def replace_child(old_child: LN, new_child: LN) -> None: + """ + Side Effects: + * If @old_child.parent is set, replace @old_child with @new_child in + @old_child's underlying Node structure. + OR + * Otherwise, this function does nothing. + """ + parent = old_child.parent + if not parent: + return + + child_idx = old_child.remove() + if child_idx is not None: + parent.insert_child(child_idx, new_child) + + +def get_string_prefix(string: str) -> str: + """ + Pre-conditions: + * assert_is_leaf_string(@string) + + Returns: + @string's prefix (e.g. '', 'r', 'f', or 'rf'). + """ + assert_is_leaf_string(string) + + prefix = "" + prefix_idx = 0 + while string[prefix_idx] in STRING_PREFIX_CHARS: + prefix += string[prefix_idx].lower() + prefix_idx += 1 + + return prefix + + +def assert_is_leaf_string(string: str) -> None: + """ + Checks the pre-condition that @string has the format that you would expect + of `leaf.value` where `leaf` is some Leaf such that `leaf.type == + token.STRING`. A more precise description of the pre-conditions that are + checked are listed below. + + Pre-conditions: + * @string starts with either ', ", ', or " where + `set()` is some subset of `set(STRING_PREFIX_CHARS)`. + * @string ends with a quote character (' or "). + + Raises: + AssertionError(...) if the pre-conditions listed above are not + satisfied. + """ + dquote_idx = string.find('"') + squote_idx = string.find("'") + if -1 in [dquote_idx, squote_idx]: + quote_idx = max(dquote_idx, squote_idx) + else: + quote_idx = min(squote_idx, dquote_idx) + + assert ( + 0 <= quote_idx < len(string) - 1 + ), f"{string!r} is missing a starting quote character (' or \")." + assert string[-1] in ( + "'", + '"', + ), f"{string!r} is missing an ending quote character (' or \")." + assert set(string[:quote_idx]).issubset( + set(STRING_PREFIX_CHARS) + ), f"{set(string[:quote_idx])} is NOT a subset of {set(STRING_PREFIX_CHARS)}." + + +def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]: + """Split line into many lines, starting with the first matching bracket pair. + + Note: this usually looks weird, only use this for function definitions. + Prefer RHS otherwise. This is why this function is not symmetrical with + :func:`right_hand_split` which also handles optional parentheses. + """ + tail_leaves: List[Leaf] = [] + body_leaves: List[Leaf] = [] + head_leaves: List[Leaf] = [] + current_leaves = head_leaves + matching_bracket: Optional[Leaf] = None + for leaf in line.leaves: + if ( + current_leaves is body_leaves + and leaf.type in CLOSING_BRACKETS + and leaf.opening_bracket is matching_bracket + ): + current_leaves = tail_leaves if body_leaves else head_leaves + current_leaves.append(leaf) + if current_leaves is head_leaves: + if leaf.type in OPENING_BRACKETS: + matching_bracket = leaf + current_leaves = body_leaves + if not matching_bracket: + raise CannotSplit("No brackets found") + + head = bracket_split_build_line(head_leaves, line, matching_bracket) + body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True) + tail = bracket_split_build_line(tail_leaves, line, matching_bracket) + bracket_split_succeeded_or_raise(head, body, tail) + for result in (head, body, tail): + if result: + yield result + + +def right_hand_split( + line: Line, + line_length: int, + features: Collection[Feature] = (), + omit: Collection[LeafID] = (), +) -> Iterator[Line]: + """Split line into many lines, starting with the last matching bracket pair. + + If the split was by optional parentheses, attempt splitting without them, too. + `omit` is a collection of closing bracket IDs that shouldn't be considered for + this split. + + Note: running this function modifies `bracket_depth` on the leaves of `line`. + """ + tail_leaves: List[Leaf] = [] + body_leaves: List[Leaf] = [] + head_leaves: List[Leaf] = [] + current_leaves = tail_leaves + opening_bracket: Optional[Leaf] = None + closing_bracket: Optional[Leaf] = None + for leaf in reversed(line.leaves): + if current_leaves is body_leaves: + if leaf is opening_bracket: + current_leaves = head_leaves if body_leaves else tail_leaves + current_leaves.append(leaf) + if current_leaves is tail_leaves: + if leaf.type in CLOSING_BRACKETS and id(leaf) not in omit: + opening_bracket = leaf.opening_bracket + closing_bracket = leaf + current_leaves = body_leaves + if not (opening_bracket and closing_bracket and head_leaves): + # If there is no opening or closing_bracket that means the split failed and + # all content is in the tail. Otherwise, if `head_leaves` are empty, it means + # the matching `opening_bracket` wasn't available on `line` anymore. + raise CannotSplit("No brackets found") tail_leaves.reverse() body_leaves.reverse() @@ -2546,10 +4777,10 @@ def right_hand_split( elif head.contains_multiline_strings() or tail.contains_multiline_strings(): raise CannotSplit( - "The current optional pair of parentheses is bound to fail to " - "satisfy the splitting algorithm because the head or the tail " - "contains multiline strings which by definition never fit one " - "line." + "The current optional pair of parentheses is bound to fail to" + " satisfy the splitting algorithm because the head or the tail" + " contains multiline strings which by definition never fit one" + " line." ) ensure_visible(opening_bracket) @@ -2580,8 +4811,8 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None elif tail_len < 3: raise CannotSplit( - f"Splitting brackets on an empty body to save " - f"{tail_len} characters is not worth it" + f"Splitting brackets on an empty body to save {tail_len} characters is" + " not worth it" ) @@ -2612,11 +4843,11 @@ def bracket_split_build_line( for i in range(len(leaves) - 1, -1, -1): if leaves[i].type == STANDALONE_COMMENT: continue - elif leaves[i].type == token.COMMA: - break - else: + + if leaves[i].type != token.COMMA: leaves.insert(i + 1, Leaf(token.COMMA, ",")) - break + break + # Populate the line for leaf in leaves: result.append(leaf, preformatted=True) @@ -2627,7 +4858,7 @@ def bracket_split_build_line( return result -def dont_increase_indentation(split_func: SplitFunc) -> SplitFunc: +def dont_increase_indentation(split_func: Transformer) -> Transformer: """Normalize prefix of the first leaf in every line returned by `split_func`. This is a decorator over relevant split functions. @@ -2790,10 +5021,10 @@ def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None: Note: Mutates its argument. """ - match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL) + match = re.match(r"^([" + STRING_PREFIX_CHARS + r"]*)(.*)$", leaf.value, re.DOTALL) assert match is not None, f"failed to match string {leaf.value!r}" orig_prefix = match.group(1) - new_prefix = orig_prefix.lower() + new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u") if remove_u_prefix: new_prefix = new_prefix.replace("u", "") leaf.value = f"{new_prefix}{match.group(2)}" @@ -2807,7 +5038,7 @@ def normalize_string_quotes(leaf: Leaf) -> None: Note: Mutates its argument. """ - value = leaf.value.lstrip("furbFURB") + value = leaf.value.lstrip(STRING_PREFIX_CHARS) if value[:3] == '"""': return @@ -2860,6 +5091,7 @@ def normalize_string_quotes(leaf: Leaf) -> None: if "\\" in str(m): # Do not introduce backslashes in interpolated expressions return + if new_quote == '"""' and new_body[-1:] == '"': # edge case: new_body = new_body[:-1] + '\\"' @@ -2932,9 +5164,13 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: if pc.value in FMT_OFF: # This `node` has a prefix with `# fmt: off`, don't mess with parens. return - check_lpar = False for index, child in enumerate(list(node.children)): + # Fixes a bug where invisible parens are not properly stripped from + # assignment statements that contain type annotations. + if isinstance(child, Node) and child.type == syms.annassign: + normalize_invisible_parens(child, parens_after=parens_after) + # Add parentheses around long tuple unpacking in assignments. if ( index == 0 @@ -2946,26 +5182,12 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: if check_lpar: if is_walrus_assignment(child): continue - if child.type == syms.atom: - # Determines if the underlying atom should be surrounded with - # invisible params - also makes parens invisible recursively - # within the atom and removes repeated invisible parens within - # the atom - should_surround_with_parens = maybe_make_parens_invisible_in_atom( - child, parent=node - ) - if should_surround_with_parens: - lpar = Leaf(token.LPAR, "") - rpar = Leaf(token.RPAR, "") - index = child.remove() or 0 - node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + if child.type == syms.atom: + if maybe_make_parens_invisible_in_atom(child, parent=node): + wrap_in_parentheses(node, child, visible=False) elif is_one_tuple(child): - # wrap child in visible parentheses - lpar = Leaf(token.LPAR, "(") - rpar = Leaf(token.RPAR, ")") - child.remove() - node.insert_child(index, Node(syms.atom, [lpar, child, rpar])) + wrap_in_parentheses(node, child, visible=True) elif node.type == syms.import_from: # "import from" nodes store parentheses directly as part of # the statement @@ -2980,15 +5202,7 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None: break elif not (isinstance(child, Leaf) and is_multiline_string(child)): - # wrap child in invisible parentheses - lpar = Leaf(token.LPAR, "") - rpar = Leaf(token.RPAR, "") - index = child.remove() or 0 - prefix = child.prefix - child.prefix = "" - new_child = Node(syms.atom, [lpar, child, rpar]) - new_child.prefix = prefix - node.insert_child(index, new_child) + wrap_in_parentheses(node, child, visible=False) check_lpar = isinstance(child, Leaf) and child.value in parens_after @@ -3032,7 +5246,7 @@ def convert_one_fmt_off_pair(node: Node) -> bool: # That happens when one of the `ignored_nodes` ended with a NEWLINE # leaf (possibly followed by a DEDENT). hidden_value = hidden_value[:-1] - first_idx = None + first_idx: Optional[int] = None for ignored in ignored_nodes: index = ignored.remove() if first_idx is None: @@ -3061,13 +5275,54 @@ def generate_ignored_nodes(leaf: Leaf) -> Iterator[LN]: """ container: Optional[LN] = container_of(leaf) while container is not None and container.type != token.ENDMARKER: - for comment in list_comments(container.prefix, is_endmarker=False): - if comment.value in FMT_ON: - return + if is_fmt_on(container): + return + + # fix for fmt: on in children + if contains_fmt_on_at_column(container, leaf.column): + for child in container.children: + if contains_fmt_on_at_column(child, leaf.column): + return + yield child + else: + yield container + container = container.next_sibling + + +def is_fmt_on(container: LN) -> bool: + """Determine whether formatting is switched on within a container. + Determined by whether the last `# fmt:` comment is `on` or `off`. + """ + fmt_on = False + for comment in list_comments(container.prefix, is_endmarker=False): + if comment.value in FMT_ON: + fmt_on = True + elif comment.value in FMT_OFF: + fmt_on = False + return fmt_on + + +def contains_fmt_on_at_column(container: LN, column: int) -> bool: + """Determine if children at a given column have formatting switched on.""" + for child in container.children: + if ( + isinstance(child, Node) + and first_leaf_column(child) == column + or isinstance(child, Leaf) + and child.column == column + ): + if is_fmt_on(child): + return True + + return False - yield container - container = container.next_sibling +def first_leaf_column(node: Node) -> Optional[int]: + """Returns the column of the first leaf child of a node.""" + for child in node.children: + if isinstance(child, Leaf): + return child.column + return None def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool: @@ -3140,6 +5395,7 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]: Parenthesis can be optional. Returns None otherwise""" if len(node.children) != 3: return None + lpar, wrapped, rpar = node.children if not (lpar.type == token.LPAR and rpar.type == token.RPAR): return None @@ -3147,6 +5403,24 @@ def unwrap_singleton_parenthesis(node: LN) -> Optional[LN]: return wrapped +def wrap_in_parentheses(parent: Node, child: LN, *, visible: bool = True) -> None: + """Wrap `child` in parentheses. + + This replaces `child` with an atom holding the parentheses and the old + child. That requires moving the prefix. + + If `visible` is False, the leaves will be valueless (and thus invisible). + """ + lpar = Leaf(token.LPAR, "(" if visible else "") + rpar = Leaf(token.RPAR, ")" if visible else "") + prefix = child.prefix + child.prefix = "" + index = child.remove() or 0 + new_child = Node(syms.atom, [lpar, child, rpar]) + new_child.prefix = prefix + parent.insert_child(index, new_child) + + def is_one_tuple(node: LN) -> bool: """Return True if `node` holds a tuple with one element, with or without parens.""" if node.type == syms.atom: @@ -3215,8 +5489,7 @@ def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool: def is_multiline_string(leaf: Leaf) -> bool: """Return True if `leaf` is a multiline string that actually spans many lines.""" - value = leaf.value.lstrip("furbFURB") - return value[:3] in {'"""', "'''"} and "\n" in value + return has_triple_quotes(leaf.value) and "\n" in leaf.value def is_stub_suite(node: Node) -> bool: @@ -3379,8 +5652,8 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf yield omit length = 4 * line.depth - opening_bracket = None - closing_bracket = None + opening_bracket: Optional[Leaf] = None + closing_bracket: Optional[Leaf] = None inner_brackets: Set[LeafID] = set() for index, leaf, leaf_length in enumerate_with_length(line, reversed=True): length += leaf_length @@ -3424,19 +5697,23 @@ def get_future_imports(node: Node) -> Set[str]: if isinstance(child, Leaf): if child.type == token.NAME: yield child.value + elif child.type == syms.import_as_name: orig_name = child.children[0] assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports" assert orig_name.type == token.NAME, "Invalid syntax parsing imports" yield orig_name.value + elif child.type == syms.import_as_names: yield from get_imports_from_children(child.children) + else: raise AssertionError("Invalid syntax parsing imports") for child in node.children: if child.type != syms.simple_stmt: break + first_child = child.children[0] if isinstance(first_child, Leaf): # Continue looking if we see a docstring; otherwise stop. @@ -3446,15 +5723,18 @@ def get_future_imports(node: Node) -> Set[str]: and child.children[1].type == token.NEWLINE ): continue - else: - break + + break + elif first_child.type == syms.import_from: module_name = first_child.children[1] if not isinstance(module_name, Leaf) or module_name.value != "__future__": break + imports |= set(get_imports_from_children(first_child.children[3:])) else: break + return imports @@ -3462,17 +5742,18 @@ def get_future_imports(node: Node) -> Set[str]: def get_gitignore(root: Path) -> PathSpec: """ Return a PathSpec matching gitignore content if present.""" gitignore = root / ".gitignore" - if not gitignore.is_file(): - return PathSpec.from_lines("gitwildmatch", []) - else: - return PathSpec.from_lines("gitwildmatch", gitignore.open()) + lines: List[str] = [] + if gitignore.is_file(): + with gitignore.open() as gf: + lines = gf.readlines() + return PathSpec.from_lines("gitwildmatch", lines) -def gen_python_files_in_dir( - path: Path, +def gen_python_files( + paths: Iterable[Path], root: Path, - include: Pattern[str], - exclude: Pattern[str], + include: Optional[Pattern[str]], + exclude_regexes: Iterable[Pattern[str]], report: "Report", gitignore: PathSpec, ) -> Iterator[Path]: @@ -3484,15 +5765,10 @@ def gen_python_files_in_dir( `report` is where output about exclusions goes. """ assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}" - for child in path.iterdir(): - # First ignore files matching .gitignore - if gitignore.match_file(child.as_posix()): - report.path_ignored(child, f"matches the .gitignore file content") - continue - + for child in paths: # Then ignore with `exclude` option. try: - normalized_path = "/" + child.resolve().relative_to(root).as_posix() + normalized_path = child.resolve().relative_to(root).as_posix() except OSError as e: report.path_ignored(child, f"cannot be read because {e}") continue @@ -3505,21 +5781,32 @@ def gen_python_files_in_dir( raise + # First ignore files matching .gitignore + if gitignore.match_file(normalized_path): + report.path_ignored(child, "matches the .gitignore file content") + continue + + normalized_path = "/" + normalized_path if child.is_dir(): normalized_path += "/" - exclude_match = exclude.search(normalized_path) - if exclude_match and exclude_match.group(0): - report.path_ignored(child, f"matches the --exclude regular expression") + is_excluded = False + for exclude in exclude_regexes: + exclude_match = exclude.search(normalized_path) if exclude else None + if exclude_match and exclude_match.group(0): + report.path_ignored(child, "matches the --exclude regular expression") + is_excluded = True + break + if is_excluded: continue if child.is_dir(): - yield from gen_python_files_in_dir( - child, root, include, exclude, report, gitignore + yield from gen_python_files( + child.iterdir(), root, include, exclude_regexes, report, gitignore ) elif child.is_file(): - include_match = include.search(normalized_path) + include_match = include.search(normalized_path) if include else True if include_match: yield child @@ -3542,7 +5829,7 @@ def find_project_root(srcs: Iterable[str]) -> Path: # Append a fake file so `parents` below returns `common_base_dir`, too. common_base /= "fake-file" for directory in common_base.parents: - if (directory / ".git").is_dir(): + if (directory / ".git").exists(): return directory if (directory / ".hg").is_dir(): @@ -3559,6 +5846,7 @@ class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" check: bool = False + diff: bool = False quiet: bool = False verbose: bool = False change_count: int = 0 @@ -3568,7 +5856,7 @@ class Report: def done(self, src: Path, changed: Changed) -> None: """Increment the counter for successful reformatting. Write out a message.""" if changed is Changed.YES: - reformatted = "would reformat" if self.check else "reformatted" + reformatted = "would reformat" if self.check or self.diff else "reformatted" if self.verbose or not self.quiet: out(f"{reformatted} {src}") self.change_count += 1 @@ -3614,7 +5902,7 @@ class Report: Use `click.unstyle` to remove colors. """ - if self.check: + if self.check or self.diff: reformatted = "would be reformatted" unchanged = "would be left unchanged" failed = "would fail to reformat" @@ -3662,70 +5950,84 @@ def _fixup_ast_constants( node: Union[ast.AST, ast3.AST, ast27.AST] ) -> Union[ast.AST, ast3.AST, ast27.AST]: """Map ast nodes deprecated in 3.8 to Constant.""" - # casts are required until this is released: - # https://github.com/python/typeshed/pull/3142 if isinstance(node, (ast.Str, ast3.Str, ast27.Str, ast.Bytes, ast3.Bytes)): - return cast(ast.AST, ast.Constant(value=node.s)) - elif isinstance(node, (ast.Num, ast3.Num, ast27.Num)): - return cast(ast.AST, ast.Constant(value=node.n)) - elif isinstance(node, (ast.NameConstant, ast3.NameConstant)): - return cast(ast.AST, ast.Constant(value=node.value)) + return ast.Constant(value=node.s) + + if isinstance(node, (ast.Num, ast3.Num, ast27.Num)): + return ast.Constant(value=node.n) + + if isinstance(node, (ast.NameConstant, ast3.NameConstant)): + return ast.Constant(value=node.value) + return node -def assert_equivalent(src: str, dst: str) -> None: - """Raise AssertionError if `src` and `dst` aren't equivalent.""" +def _stringify_ast( + node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0 +) -> Iterator[str]: + """Simple visitor generating strings to compare ASTs by content.""" - def _v(node: Union[ast.AST, ast3.AST, ast27.AST], depth: int = 0) -> Iterator[str]: - """Simple visitor generating strings to compare ASTs by content.""" + node = _fixup_ast_constants(node) - node = _fixup_ast_constants(node) + yield f"{' ' * depth}{node.__class__.__name__}(" - yield f"{' ' * depth}{node.__class__.__name__}(" + for field in sorted(node._fields): # noqa: F402 + # TypeIgnore has only one field 'lineno' which breaks this comparison + type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore) + if sys.version_info >= (3, 8): + type_ignore_classes += (ast.TypeIgnore,) + if isinstance(node, type_ignore_classes): + break - for field in sorted(node._fields): - # TypeIgnore has only one field 'lineno' which breaks this comparison - type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore) - if sys.version_info >= (3, 8): - type_ignore_classes += (ast.TypeIgnore,) - if isinstance(node, type_ignore_classes): - break + try: + value = getattr(node, field) + except AttributeError: + continue - try: - value = getattr(node, field) - except AttributeError: - continue + yield f"{' ' * (depth+1)}{field}=" - yield f"{' ' * (depth+1)}{field}=" + if isinstance(value, list): + for item in value: + # Ignore nested tuples within del statements, because we may insert + # parentheses and they change the AST. + if ( + field == "targets" + and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete)) + and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple)) + ): + for item in item.elts: + yield from _stringify_ast(item, depth + 2) - if isinstance(value, list): - for item in value: - # Ignore nested tuples within del statements, because we may insert - # parentheses and they change the AST. - if ( - field == "targets" - and isinstance(node, (ast.Delete, ast3.Delete, ast27.Delete)) - and isinstance(item, (ast.Tuple, ast3.Tuple, ast27.Tuple)) - ): - for item in item.elts: - yield from _v(item, depth + 2) - elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)): - yield from _v(item, depth + 2) + elif isinstance(item, (ast.AST, ast3.AST, ast27.AST)): + yield from _stringify_ast(item, depth + 2) - elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)): - yield from _v(value, depth + 2) + elif isinstance(value, (ast.AST, ast3.AST, ast27.AST)): + yield from _stringify_ast(value, depth + 2) + else: + # Constant strings may be indented across newlines, if they are + # docstrings; fold spaces after newlines when comparing + if ( + isinstance(node, ast.Constant) + and field == "value" + and isinstance(value, str) + ): + normalized = re.sub(r"\n[ \t]+", "\n ", value) else: - yield f"{' ' * (depth+2)}{value!r}, # {value.__class__.__name__}" + normalized = value + yield f"{' ' * (depth+2)}{normalized!r}, # {value.__class__.__name__}" - yield f"{' ' * depth}) # /{node.__class__.__name__}" + yield f"{' ' * depth}) # /{node.__class__.__name__}" + +def assert_equivalent(src: str, dst: str) -> None: + """Raise AssertionError if `src` and `dst` aren't equivalent.""" try: src_ast = parse_ast(src) except Exception as exc: raise AssertionError( - f"cannot use --safe with this file; failed to parse source file. " - f"AST error message: {exc}" + "cannot use --safe with this file; failed to parse source file. AST" + f" error message: {exc}" ) try: @@ -3733,24 +6035,23 @@ def assert_equivalent(src: str, dst: str) -> None: except Exception as exc: log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst) raise AssertionError( - f"INTERNAL ERROR: Black produced invalid code: {exc}. " - f"Please report a bug on https://github.com/psf/black/issues. " - f"This invalid output might be helpful: {log}" + f"INTERNAL ERROR: Black produced invalid code: {exc}. Please report a bug" + " on https://github.com/psf/black/issues. This invalid output might be" + f" helpful: {log}" ) from None - src_ast_str = "\n".join(_v(src_ast)) - dst_ast_str = "\n".join(_v(dst_ast)) + src_ast_str = "\n".join(_stringify_ast(src_ast)) + dst_ast_str = "\n".join(_stringify_ast(dst_ast)) if src_ast_str != dst_ast_str: log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst")) raise AssertionError( - f"INTERNAL ERROR: Black produced code that is not equivalent to " - f"the source. " - f"Please report a bug on https://github.com/psf/black/issues. " - f"This diff might be helpful: {log}" + "INTERNAL ERROR: Black produced code that is not equivalent to the" + " source. Please report a bug on https://github.com/psf/black/issues. " + f" This diff might be helpful: {log}" ) from None -def assert_stable(src: str, dst: str, mode: FileMode) -> None: +def assert_stable(src: str, dst: str, mode: Mode) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" newdst = format_str(dst, mode=mode) if dst != newdst: @@ -3759,13 +6060,13 @@ def assert_stable(src: str, dst: str, mode: FileMode) -> None: diff(dst, newdst, "first pass", "second pass"), ) raise AssertionError( - f"INTERNAL ERROR: Black produced different code on the second pass " - f"of the formatter. " - f"Please report a bug on https://github.com/psf/black/issues. " - f"This diff might be helpful: {log}" + "INTERNAL ERROR: Black produced different code on the second pass of the" + " formatter. Please report a bug on https://github.com/psf/black/issues." + f" This diff might be helpful: {log}" ) from None +@mypyc_attr(patchable=True) def dump_to_file(*output: str) -> str: """Dump `output` to a temporary file. Return path to the file.""" with tempfile.NamedTemporaryFile( @@ -3791,14 +6092,14 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: """Return a unified diff string between strings `a` and `b`.""" import difflib - a_lines = [line + "\n" for line in a.split("\n")] - b_lines = [line + "\n" for line in b.split("\n")] + a_lines = [line + "\n" for line in a.splitlines()] + b_lines = [line + "\n" for line in b.splitlines()] return "".join( difflib.unified_diff(a_lines, b_lines, fromfile=a_name, tofile=b_name, n=5) ) -def cancel(tasks: Iterable[asyncio.Task]) -> None: +def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None: """asyncio signal handler that cancels all `tasks` and reports to stderr.""" err("Aborted!") for task in tasks: @@ -3887,7 +6188,7 @@ def is_line_short_enough(line: Line, *, line_length: int, line_str: str = "") -> Uses the provided `line_str` rendering, if any, otherwise computes a new one. """ if not line_str: - line_str = str(line).strip("\n") + line_str = line_to_string(line) return ( len(line_str) <= line_length and "\n" not in line_str # multiline strings @@ -4020,11 +6321,11 @@ def can_omit_invisible_parens(line: Line, line_length: int) -> bool: return False -def get_cache_file(mode: FileMode) -> Path: +def get_cache_file(mode: Mode) -> Path: return CACHE_DIR / f"cache.{mode.get_cache_key()}.pickle" -def read_cache(mode: FileMode) -> Cache: +def read_cache(mode: Mode) -> Cache: """Read the cache if it exists and is well formed. If it is not well formed, the call to write_cache later should resolve the issue. @@ -4064,7 +6365,7 @@ def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set return todo, done -def write_cache(cache: Cache, sources: Iterable[Path], mode: FileMode) -> None: +def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None: """Update the cache file.""" cache_file = get_cache_file(mode) try: @@ -4105,5 +6406,32 @@ def patched_main() -> None: main() +def fix_docstring(docstring: str, prefix: str) -> str: + # https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation + if not docstring: + return "" + # Convert tabs to spaces (following the normal Python rules) + # and split into a list of lines: + lines = docstring.expandtabs().splitlines() + # Determine minimum indentation (first line doesn't count): + indent = sys.maxsize + for line in lines[1:]: + stripped = line.lstrip() + if stripped: + indent = min(indent, len(line) - len(stripped)) + # Remove indentation (first line is special): + trimmed = [lines[0].strip()] + if indent < sys.maxsize: + last_line_idx = len(lines) - 2 + for i, line in enumerate(lines[1:]): + stripped_line = line[indent:].rstrip() + if stripped_line or i == last_line_idx: + trimmed.append(prefix + stripped_line) + else: + trimmed.append("") + # Return a single string: + return "\n".join(trimmed) + + if __name__ == "__main__": patched_main()