X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/8c22d232b56104376a12d1e68eaf216d04979830..b4a6bb08fa704facbf3397f95b3216e13c3c964a:/src/black/__init__.py diff --git a/src/black/__init__.py b/src/black/__init__.py index 405a010..8c28b6b 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -24,6 +24,7 @@ from typing import ( MutableMapping, Optional, Pattern, + Sequence, Set, Sized, Tuple, @@ -225,6 +226,16 @@ def validate_regex( "(useful when piping source on standard input)." ), ) +@click.option( + "--python-cell-magics", + multiple=True, + help=( + "When processing Jupyter Notebooks, add the given magic to the list" + f" of known python-magics ({', '.join(PYTHON_CELL_MAGICS)})." + " Useful for formatting cells with custom python magics." + ), + default=[], +) @click.option( "-S", "--skip-string-normalization", @@ -241,16 +252,13 @@ def validate_regex( "--experimental-string-processing", is_flag=True, hidden=True, - help=( - "Experimental option that performs more normalization on string literals." - " Currently disabled because it leads to some crashes." - ), + help="(DEPRECATED and now included in --preview) Normalize string literals.", ) @click.option( "--preview", is_flag=True, help=( - "Enable potentially disruptive style changes that will be added to Black's main" + "Enable potentially disruptive style changes that may be added to Black's main" " functionality in the next major release." ), ) @@ -283,7 +291,8 @@ def validate_regex( type=str, help=( "Require a specific version of Black to be running (useful for unifying results" - " across many environments e.g. with a pyproject.toml file)." + " across many environments e.g. with a pyproject.toml file). It can be" + " either a major version number or an exact version." ), ) @click.option( @@ -404,6 +413,7 @@ def main( fast: bool, pyi: bool, ipynb: bool, + python_cell_magics: Sequence[str], skip_string_normalization: bool, skip_magic_trailing_comma: bool, experimental_string_processing: bool, @@ -422,6 +432,17 @@ def main( ) -> None: """The uncompromising code formatter.""" ctx.ensure_object(dict) + + if src and code is not None: + out( + main.get_usage(ctx) + + "\n\n'SRC' and 'code' cannot be passed simultaneously." + ) + ctx.exit(1) + if not src and code is None: + out(main.get_usage(ctx) + "\n\nOne of 'SRC' or 'code' is required.") + ctx.exit(1) + root, method = find_project_root(src) if code is None else (None, None) ctx.obj["root"] = root @@ -454,7 +475,11 @@ def main( out(f"Using configuration in '{config}'.", fg="blue") error_msg = "Oh no! 💥 💔 💥" - if required_version and required_version != __version__: + if ( + required_version + and required_version != __version__ + and required_version != __version__.split(".")[0] + ): err( f"{error_msg} The required version `{required_version}` does not match" f" the running version `{__version__}`!" @@ -479,6 +504,7 @@ def main( magic_trailing_comma=not skip_magic_trailing_comma, experimental_string_processing=experimental_string_processing, preview=preview, + python_cell_magics=set(python_cell_magics), ) if code is not None: @@ -559,7 +585,6 @@ def get_sources( ) -> Set[Path]: """Compute the set of files to be formatted.""" sources: Set[Path] = set() - path_empty(src, "No Path provided. Nothing to do 😴", quiet, verbose, ctx) if exclude is None: exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) @@ -948,17 +973,7 @@ def check_stability_and_equivalence( content differently. """ assert_equivalent(src_contents, dst_contents) - - # Forced second pass to work around optional trailing commas (becoming - # forced trailing commas on pass 2) interacting differently with optional - # parentheses. Admittedly ugly. - dst_contents_pass2 = format_str(dst_contents, mode=mode) - if dst_contents != dst_contents_pass2: - dst_contents = dst_contents_pass2 - assert_equivalent(src_contents, dst_contents, pass_num=2) - assert_stable(src_contents, dst_contents, mode=mode) - # Note: no need to explicitly call `assert_stable` if `dst_contents` was - # the same as `dst_contents_pass2`. + assert_stable(src_contents, dst_contents, mode=mode) def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: @@ -984,7 +999,7 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo return dst_contents -def validate_cell(src: str) -> None: +def validate_cell(src: str, mode: Mode) -> None: """Check that cell does not already contain TransformerManager transformations, or non-Python cell magics, which might cause tokenizer_rt to break because of indentations. @@ -1003,7 +1018,10 @@ def validate_cell(src: str) -> None: """ if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): raise NothingChanged - if src[:2] == "%%" and src.split()[0][2:] not in PYTHON_CELL_MAGICS: + if ( + src[:2] == "%%" + and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics + ): raise NothingChanged @@ -1023,7 +1041,7 @@ def format_cell(src: str, *, fast: bool, mode: Mode) -> str: could potentially be automagics or multi-line magics, which are currently not supported. """ - validate_cell(src) + validate_cell(src, mode) src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon( src ) @@ -1085,7 +1103,7 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon raise NothingChanged -def format_str(src_contents: str, *, mode: Mode) -> FileContent: +def format_str(src_contents: str, *, mode: Mode) -> str: """Reformat a string and return new contents. `mode` determines formatting options, such as how many characters per line are @@ -1115,6 +1133,16 @@ def format_str(src_contents: str, *, mode: Mode) -> FileContent: hey """ + dst_contents = _format_str_once(src_contents, mode=mode) + # Forced second pass to work around optional trailing commas (becoming + # forced trailing commas on pass 2) interacting differently with optional + # parentheses. Admittedly ugly. + if src_contents != dst_contents: + return _format_str_once(dst_contents, mode=mode) + return dst_contents + + +def _format_str_once(src_contents: str, *, mode: Mode) -> str: src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) dst_contents = [] future_imports = get_future_imports(src_node) @@ -1309,13 +1337,16 @@ def get_future_imports(node: Node) -> Set[str]: return imports -def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None: +def assert_equivalent(src: str, dst: str) -> None: """Raise AssertionError if `src` and `dst` aren't equivalent.""" try: src_ast = parse_ast(src) except Exception as exc: raise AssertionError( - f"cannot use --safe with this file; failed to parse source file: {exc}" + f"cannot use --safe with this file; failed to parse source file AST: " + f"{exc}\n" + f"This could be caused by running Black with an older Python version " + f"that does not support new syntax used in your source file." ) from exc try: @@ -1323,7 +1354,7 @@ def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None: except Exception as exc: log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst) raise AssertionError( - f"INTERNAL ERROR: Black produced invalid code on pass {pass_num}: {exc}. " + f"INTERNAL ERROR: Black produced invalid code: {exc}. " "Please report a bug on https://github.com/psf/black/issues. " f"This invalid output might be helpful: {log}" ) from None @@ -1334,14 +1365,17 @@ def assert_equivalent(src: str, dst: str, *, pass_num: int = 1) -> None: log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst")) raise AssertionError( "INTERNAL ERROR: Black produced code that is not equivalent to the" - f" source on pass {pass_num}. Please report a bug on " + f" source. Please report a bug on " f"https://github.com/psf/black/issues. This diff might be helpful: {log}" ) from None def assert_stable(src: str, dst: str, mode: Mode) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" - newdst = format_str(dst, mode=mode) + # We shouldn't call format_str() here, because that formats the string + # twice and may hide a bug where we bounce back and forth between two + # versions. + newdst = _format_str_once(dst, mode=mode) if dst != newdst: log = dump_to_file( str(mode),