X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/73cb6e7734370108742d992d4fe1fa2829f100fd..e1506769a428889bc66964edabf76476433c031a:/src/black/__init__.py diff --git a/src/black/__init__.py b/src/black/__init__.py index 7024c9d..769e693 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -968,17 +968,7 @@ def check_stability_and_equivalence( content differently. """ assert_equivalent(src_contents, dst_contents) - - # Forced second pass to work around optional trailing commas (becoming - # forced trailing commas on pass 2) interacting differently with optional - # parentheses. Admittedly ugly. - dst_contents_pass2 = format_str(dst_contents, mode=mode) - if dst_contents != dst_contents_pass2: - dst_contents = dst_contents_pass2 - assert_equivalent(src_contents, dst_contents, pass_num=2) - assert_stable(src_contents, dst_contents, mode=mode) - # Note: no need to explicitly call `assert_stable` if `dst_contents` was - # the same as `dst_contents_pass2`. + assert_stable(src_contents, dst_contents, mode=mode) def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: @@ -1108,7 +1098,7 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon raise NothingChanged -def format_str(src_contents: str, *, mode: Mode) -> FileContent: +def format_str(src_contents: str, *, mode: Mode) -> str: """Reformat a string and return new contents. `mode` determines formatting options, such as how many characters per line are @@ -1138,6 +1128,16 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent: hey """ + dst_contents = _format_str_once(src_contents, mode=mode) + # Forced second pass to work around optional trailing commas (becoming + # forced trailing commas on pass 2) interacting differently with optional + # parentheses. Admittedly ugly. + if src_contents != dst_contents: + return _format_str_once(dst_contents, mode=mode) + return dst_contents + + +def _format_str_once(src_contents: str, *, mode: Mode) -> str: src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) dst_contents = [] future_imports = get_future_imports(src_node) @@ -1367,7 +1367,10 @@ def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None: def assert_stable(src: str, dst: str, mode: Mode) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" - newdst = format_str(dst, mode=mode) + # We shouldn't call format_str() here, because that formats the string + # twice and may hide a bug where we bounce back and forth between two + # versions. + newdst = _format_str_once(dst, mode=mode) if dst != newdst: log = dump_to_file( str(mode),