def validate_cell(src: str, mode: Mode) -> None:
    """Reject cells that cannot be safely round-tripped through the formatter.

    Two kinds of cell are refused by raising ``NothingChanged``:

    - Cells that already contain text identical to what IPython's
      ``TransformerManager`` produces for a magic (any entry of
      ``TRANSFORMED_MAGICS``). ``!ls`` is rewritten to
      ``get_ipython().system('ls')``, but a cell that literally contains
      ``get_ipython().system('ls')`` is rewritten to exactly the same text::

          >>> TransformerManager().transform_cell("get_ipython().system('ls')")
          "get_ipython().system('ls')\\n"
          >>> TransformerManager().transform_cell("!ls")
          "get_ipython().system('ls')\\n"

      After formatting there is no way to tell which form was original, so
      such cells are skipped.
    - Cells starting with a ``%%`` cell magic whose name is in neither
      ``PYTHON_CELL_MAGICS`` nor ``mode.python_cell_magics``. Their body is
      not Python and could break ``tokenize_rt`` (e.g. through indentation).

    Args:
        src: Source text of a single notebook cell.
        mode: Formatting mode; ``mode.python_cell_magics`` extends the set of
            cell magics whose body is treated as Python.

    Raises:
        NothingChanged: if the cell must be left untouched.
    """
    if any(magic in src for magic in TRANSFORMED_MAGICS):
        raise NothingChanged

    if not src.startswith("%%"):
        return

    # First whitespace-separated token, minus the leading "%%", is the magic name.
    magic_name = src.split()[0][2:]
    if magic_name not in PYTHON_CELL_MAGICS | mode.python_cell_magics:
        raise NothingChanged
+
+
def format_cell(src: str, *, fast: bool, mode: Mode) -> str:
    """Return the formatted source of a single Jupyter code cell.

    Processing steps:

    1. ``validate_cell`` rejects cells that cannot be round-tripped.
    2. A trailing semicolon, if present, is split off.
    3. IPython magics are masked so the remainder parses as Python.
    4. The masked code is formatted. Unless ``fast`` is set, the result is
       also checked for stability and equivalence.
    5. The magics and then the trailing semicolon are restored.
    6. Trailing newlines are stripped.

    Cells whose masking fails with ``SyntaxError`` are not processed: they
    may be automagics or multi-line magics, which are not supported.

    Args:
        src: Source text of the cell.
        fast: Skip the stability/equivalence check when true.
        mode: Formatting mode forwarded to ``format_str``.

    Returns:
        The formatted cell source with trailing newlines removed.

    Raises:
        NothingChanged: if the cell is skipped or formatting does not alter it.
    """
    validate_cell(src, mode)
    stripped_src, had_semicolon = remove_trailing_semicolon(src)

    try:
        masked, replacements = mask_cell(stripped_src)
    except SyntaxError:
        # Possibly an automagic or multi-line magic; leave the cell alone.
        raise NothingChanged from None

    formatted = format_str(masked, mode=mode)
    if not fast:
        check_stability_and_equivalence(masked, formatted, mode=mode)

    restored = put_trailing_semicolon_back(
        unmask_cell(formatted, replacements), had_semicolon
    )
    result = restored.rstrip("\n")
    if result == src:
        raise NothingChanged from None
    return result
+
+
def validate_metadata(nb: MutableMapping[str, Any]) -> None:
    """Refuse to format notebooks whose metadata names a non-Python language.

    Every nbformat metadata field is optional (see
    https://nbformat.readthedocs.io/en/latest/format_description.html).
    A notebook missing ``metadata``, ``metadata.language_info`` or
    ``metadata.language_info.name`` is therefore assumed to be Python and
    is processed.

    Args:
        nb: The parsed notebook JSON.

    Raises:
        NothingChanged: if ``metadata.language_info.name`` is present and is
            not ``"python"``.
    """
    metadata = nb.get("metadata", {})
    language_info = metadata.get("language_info", {})
    language = language_info.get("name")
    if language is None or language == "python":
        return
    raise NothingChanged from None
+
+
def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
    """Format a Jupyter notebook given as JSON text.

    Works cell by cell and touches only code cells. The whole notebook is
    skipped when ``validate_metadata`` identifies it as non-Python. A cell
    that ``format_cell`` refuses, or leaves unchanged, keeps its original
    source. If the input ended with a newline, the output does too.

    Args:
        src_contents: Raw ``.ipynb`` file contents.
        fast: Skip per-cell stability/equivalence checks when true.
        mode: Formatting mode forwarded to every cell.

    Returns:
        The notebook re-serialized with ``json.dumps(indent=1,
        ensure_ascii=False)``.

    Raises:
        NothingChanged: if the input is empty, the notebook is marked as
            non-Python, or no code cell was modified.
    """
    if not src_contents:
        # An empty file has nothing to format; without this guard the
        # ``[-1]``-style trailing-newline check would fail with IndexError.
        raise NothingChanged

    trailing_newline = src_contents.endswith("\n")
    nb = json.loads(src_contents)
    validate_metadata(nb)

    modified = False
    for cell in nb["cells"]:
        if cell.get("cell_type") != "code":
            continue
        # nbformat allows ``source`` to be a single string or a list of
        # lines; ``"".join`` produces the full text in both cases.
        src = "".join(cell["source"])
        try:
            dst = format_cell(src, fast=fast, mode=mode)
        except NothingChanged:
            continue
        cell["source"] = dst.splitlines(keepends=True)
        modified = True

    if not modified:
        raise NothingChanged

    dst_contents = json.dumps(nb, indent=1, ensure_ascii=False)
    if trailing_newline:
        dst_contents += "\n"
    return dst_contents
+
+