X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/ae56983a5f36d69e2a8cfa762f255a800039b0df..b1d060101626aa1c332f52e4bdf0ae5e4cc07990:/src/black/__init__.py diff --git a/src/black/__init__.py b/src/black/__init__.py index 51384fb..29fb244 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -1,4 +1,6 @@ import asyncio +from json.decoder import JSONDecodeError +import json from concurrent.futures import Executor, ThreadPoolExecutor, ProcessPoolExecutor from contextlib import contextmanager from datetime import datetime @@ -18,6 +20,7 @@ from typing import ( Generator, Iterator, List, + MutableMapping, Optional, Pattern, Set, @@ -39,13 +42,21 @@ from black.mode import Mode, TargetVersion from black.mode import Feature, supports_feature, VERSION_TO_FEATURES from black.cache import read_cache, write_cache, get_cache_info, filter_cached, Cache from black.concurrency import cancel, shutdown, maybe_install_uvloop -from black.output import dump_to_file, diff, color_diff, out, err -from black.report import Report, Changed +from black.output import dump_to_file, ipynb_diff, diff, color_diff, out, err +from black.report import Report, Changed, NothingChanged from black.files import find_project_root, find_pyproject_toml, parse_pyproject_toml from black.files import gen_python_files, get_gitignore, normalize_path_maybe_ignore from black.files import wrap_stream_for_windows from black.parsing import InvalidInput # noqa F401 from black.parsing import lib2to3_parse, parse_ast, stringify_ast +from black.handle_ipynb_magics import ( + mask_cell, + unmask_cell, + remove_trailing_semicolon, + put_trailing_semicolon_back, + TRANSFORMED_MAGICS, + jupyter_dependencies_are_installed, +) # lib2to3 fork @@ -60,10 +71,6 @@ Encoding = str NewLine = str -class NothingChanged(UserWarning): - """Raised when reformatted code is the same as source.""" - - class WriteBack(Enum): NO = 0 YES = 1 @@ -196,6 +203,14 @@ def validate_regex( " when piping source on standard 
input)." ), ) +@click.option( + "--ipynb", + is_flag=True, + help=( + "Format all input files like Jupyter Notebooks regardless of file extension " + "(useful when piping source on standard input)." + ), +) @click.option( "-S", "--skip-string-normalization", @@ -355,6 +370,7 @@ def main( color: bool, fast: bool, pyi: bool, + ipynb: bool, skip_string_normalization: bool, skip_magic_trailing_comma: bool, experimental_string_processing: bool, @@ -380,6 +396,9 @@ def main( f" the running version `{__version__}`!" ) ctx.exit(1) + if ipynb and pyi: + err("Cannot pass both `pyi` and `ipynb` flags!") + ctx.exit(1) write_back = WriteBack.from_configuration(check=check, diff=diff, color=color) if target_version: @@ -391,6 +410,7 @@ def main( target_versions=versions, line_length=line_length, is_pyi=pyi, + is_ipynb=ipynb, string_normalization=not skip_string_normalization, magic_trailing_comma=not skip_magic_trailing_comma, experimental_string_processing=experimental_string_processing, @@ -504,6 +524,11 @@ def get_sources( if is_stdin: p = Path(f"{STDIN_PLACEHOLDER}{str(p)}") + if p.suffix == ".ipynb" and not jupyter_dependencies_are_installed( + verbose=verbose, quiet=quiet + ): + continue + sources.add(p) elif p.is_dir(): sources.update( @@ -516,6 +541,8 @@ def get_sources( force_exclude, report, gitignore, + verbose=verbose, + quiet=quiet, ) ) elif s == "-": @@ -585,6 +612,8 @@ def reformat_one( if is_stdin: if src.suffix == ".pyi": mode = replace(mode, is_pyi=True) + elif src.suffix == ".ipynb": + mode = replace(mode, is_ipynb=True) if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode): changed = Changed.YES else: @@ -733,6 +762,8 @@ def format_file_in_place( """ if src.suffix == ".pyi": mode = replace(mode, is_pyi=True) + elif src.suffix == ".ipynb": + mode = replace(mode, is_ipynb=True) then = datetime.utcfromtimestamp(src.stat().st_mtime) with open(src, "rb") as buf: @@ -741,6 +772,8 @@ def format_file_in_place( dst_contents = 
format_file_contents(src_contents, fast=fast, mode=mode) except NothingChanged: return False + except JSONDecodeError: + raise ValueError(f"File '{src}' cannot be parsed as valid Jupyter notebook.") if write_back == WriteBack.YES: with open(src, "w", encoding=encoding, newline=newline) as f: @@ -749,7 +782,10 @@ def format_file_in_place( now = datetime.utcnow() src_name = f"{src}\t{then} +0000" dst_name = f"{src}\t{now} +0000" - diff_contents = diff(src_contents, dst_contents, src_name, dst_name) + if mode.is_ipynb: + diff_contents = ipynb_diff(src_contents, dst_contents, src_name, dst_name) + else: + diff_contents = diff(src_contents, dst_contents, src_name, dst_name) if write_back == WriteBack.COLOR_DIFF: diff_contents = color_diff(diff_contents) @@ -819,6 +855,29 @@ def format_stdin_to_stdout( f.detach() +def check_stability_and_equivalence( + src_contents: str, dst_contents: str, *, mode: Mode +) -> None: + """Perform stability and equivalence checks. + + Raise AssertionError if source and destination contents are not + equivalent, or if a second pass of the formatter would format the + content differently. + """ + assert_equivalent(src_contents, dst_contents) + + # Forced second pass to work around optional trailing commas (becoming + # forced trailing commas on pass 2) interacting differently with optional + # parentheses. Admittedly ugly. + dst_contents_pass2 = format_str(dst_contents, mode=mode) + if dst_contents != dst_contents_pass2: + dst_contents = dst_contents_pass2 + assert_equivalent(src_contents, dst_contents, pass_num=2) + assert_stable(src_contents, dst_contents, mode=mode) + # Note: no need to explicitly call `assert_stable` if `dst_contents` was + # the same as `dst_contents_pass2`. + + def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: """Reformat contents of a file and return new contents. 
@@ -829,26 +888,116 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo if not src_contents.strip(): raise NothingChanged - dst_contents = format_str(src_contents, mode=mode) + if mode.is_ipynb: + dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode) + else: + dst_contents = format_str(src_contents, mode=mode) if src_contents == dst_contents: raise NothingChanged - if not fast: - assert_equivalent(src_contents, dst_contents) - - # Forced second pass to work around optional trailing commas (becoming - # forced trailing commas on pass 2) interacting differently with optional - # parentheses. Admittedly ugly. - dst_contents_pass2 = format_str(dst_contents, mode=mode) - if dst_contents != dst_contents_pass2: - dst_contents = dst_contents_pass2 - assert_equivalent(src_contents, dst_contents, pass_num=2) - assert_stable(src_contents, dst_contents, mode=mode) - # Note: no need to explicitly call `assert_stable` if `dst_contents` was - # the same as `dst_contents_pass2`. + if not fast and not mode.is_ipynb: + # Jupyter notebooks will already have been checked above. + check_stability_and_equivalence(src_contents, dst_contents, mode=mode) return dst_contents +def validate_cell(src: str) -> None: + """Check that cell does not already contain TransformerManager transformations. + + If a cell contains ``!ls``, then it'll be transformed to + ``get_ipython().system('ls')``. However, if the cell originally contained + ``get_ipython().system('ls')``, then it would get transformed in the same way: + + >>> TransformerManager().transform_cell("get_ipython().system('ls')") + "get_ipython().system('ls')\n" + >>> TransformerManager().transform_cell("!ls") + "get_ipython().system('ls')\n" + + Due to the impossibility of safely roundtripping in such situations, cells + containing transformed magics will be ignored. 
+ """ + if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): + raise NothingChanged + + +def format_cell(src: str, *, fast: bool, mode: Mode) -> str: + """Format code in given cell of Jupyter notebook. + + General idea is: + + - if cell has trailing semicolon, remove it; + - if cell has IPython magics, mask them; + - format cell; + - reinstate IPython magics; + - reinstate trailing semicolon (if originally present); + - strip trailing newlines. + + Cells with syntax errors will not be processed, as they + could potentially be automagics or multi-line magics, which + are currently not supported. + """ + validate_cell(src) + src_without_trailing_semicolon, has_trailing_semicolon = remove_trailing_semicolon( + src + ) + try: + masked_src, replacements = mask_cell(src_without_trailing_semicolon) + except SyntaxError: + raise NothingChanged + masked_dst = format_str(masked_src, mode=mode) + if not fast: + check_stability_and_equivalence(masked_src, masked_dst, mode=mode) + dst_without_trailing_semicolon = unmask_cell(masked_dst, replacements) + dst = put_trailing_semicolon_back( + dst_without_trailing_semicolon, has_trailing_semicolon + ) + dst = dst.rstrip("\n") + if dst == src: + raise NothingChanged + return dst + + +def validate_metadata(nb: MutableMapping[str, Any]) -> None: + """If notebook is marked as non-Python, don't format it. + + All notebook metadata fields are optional, see + https://nbformat.readthedocs.io/en/latest/format_description.html. So + if a notebook has empty metadata, we will try to parse it anyway. + """ + language = nb.get("metadata", {}).get("language_info", {}).get("name", None) + if language is not None and language != "python": + raise NothingChanged + + +def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: + """Format Jupyter notebook. + + Operate cell-by-cell, only on code cells, only for Python notebooks. + If the ``.ipynb`` originally had a trailing newline, it'll be preserved. 
+ """ + trailing_newline = src_contents[-1] == "\n" + modified = False + nb = json.loads(src_contents) + validate_metadata(nb) + for cell in nb["cells"]: + if cell.get("cell_type", None) == "code": + try: + src = "".join(cell["source"]) + dst = format_cell(src, fast=fast, mode=mode) + except NothingChanged: + pass + else: + cell["source"] = dst.splitlines(keepends=True) + modified = True + if modified: + dst_contents = json.dumps(nb, indent=1, ensure_ascii=False) + if trailing_newline: + dst_contents = dst_contents + "\n" + return dst_contents + else: + raise NothingChanged + + def format_str(src_contents: str, *, mode: Mode) -> FileContent: """Reformat a string and return new contents.