X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/b1d060101626aa1c332f52e4bdf0ae5e4cc07990..aedb4ff7f061b321ea5804bc4fc4943c52c6a786:/src/black/handle_ipynb_magics.py diff --git a/src/black/handle_ipynb_magics.py b/src/black/handle_ipynb_magics.py index ad93c44..f10eaed 100644 --- a/src/black/handle_ipynb_magics.py +++ b/src/black/handle_ipynb_magics.py @@ -1,15 +1,19 @@ """Functions to process IPython magics with.""" + from functools import lru_cache import dataclasses import ast -from typing import Dict +from typing import Dict, List, Tuple, Optional import secrets -from typing import List, Tuple +import sys import collections -from typing import Optional -from typing_extensions import TypeGuard +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + from black.report import NothingChanged from black.output import out @@ -35,20 +39,21 @@ TOKENS_TO_IGNORE = frozenset( ) NON_PYTHON_CELL_MAGICS = frozenset( ( - "%%bash", - "%%html", - "%%javascript", - "%%js", - "%%latex", - "%%markdown", - "%%perl", - "%%ruby", - "%%script", - "%%sh", - "%%svg", - "%%writefile", + "bash", + "html", + "javascript", + "js", + "latex", + "markdown", + "perl", + "ruby", + "script", + "sh", + "svg", + "writefile", ) ) +TOKEN_HEX = secrets.token_hex @dataclasses.dataclass(frozen=True) @@ -184,10 +189,10 @@ def get_token(src: str, magic: str) -> str: """ assert magic nbytes = max(len(magic) // 2 - 1, 1) - token = secrets.token_hex(nbytes) + token = TOKEN_HEX(nbytes) counter = 0 - while token in src: # pragma: nocover - token = secrets.token_hex(nbytes) + while token in src: + token = TOKEN_HEX(nbytes) counter += 1 if counter > 100: raise AssertionError( @@ -225,10 +230,11 @@ def replace_cell_magics(src: str) -> Tuple[str, List[Replacement]]: cell_magic_finder.visit(tree) if cell_magic_finder.cell_magic is None: return src, replacements - if cell_magic_finder.cell_magic.header.split()[0] in NON_PYTHON_CELL_MAGICS: + if cell_magic_finder.cell_magic.name in NON_PYTHON_CELL_MAGICS: raise NothingChanged - mask = get_token(src, cell_magic_finder.cell_magic.header) - replacements.append(Replacement(mask=mask, src=cell_magic_finder.cell_magic.header)) + header = cell_magic_finder.cell_magic.header + mask = get_token(src, header) + replacements.append(Replacement(mask=mask, src=header)) return f"{mask}\n{cell_magic_finder.cell_magic.body}", replacements @@ -306,11 +312,26 @@ def _is_ipython_magic(node: ast.expr) -> TypeGuard[ast.Attribute]: ) +def _get_str_args(args: List[ast.expr]) -> List[str]: + str_args = [] + for arg in args: + assert isinstance(arg, ast.Str) + str_args.append(arg.s) + return str_args + + @dataclasses.dataclass(frozen=True) class CellMagic: - header: str + name: str + params: Optional[str] body: str + @property + def header(self) -> str: + if self.params: + return f"%%{self.name} {self.params}" + return f"%%{self.name}" + @dataclasses.dataclass class CellMagicFinder(ast.NodeVisitor): @@ -340,14 +361,8 @@ class CellMagicFinder(ast.NodeVisitor): and _is_ipython_magic(node.value.func) and node.value.func.attr == "run_cell_magic" ): - args = [] - for arg in node.value.args: - assert isinstance(arg, ast.Str) - args.append(arg.s) - header = f"%%{args[0]}" - if args[1]: - header += f" {args[1]}" - self.cell_magic = CellMagic(header=header, body=args[2]) + args = _get_str_args(node.value.args) + self.cell_magic = CellMagic(name=args[0], params=args[1], body=args[2]) self.generic_visit(node) @@ -399,12 +414,8 @@ class MagicFinder(ast.NodeVisitor): and _is_ipython_magic(node.value.func) and node.value.func.attr == "getoutput" ): - args = [] - for arg in node.value.args: - assert isinstance(arg, ast.Str) - args.append(arg.s) - assert args - src = f"!{args[0]}" + (arg,) = _get_str_args(node.value.args) + src = f"!{arg}" self.magics[node.value.lineno].append( OffsetAndMagic(node.value.col_offset, src) ) @@ -430,11 +441,7 @@ class MagicFinder(ast.NodeVisitor): and we look for instances of any of the latter. """ if isinstance(node.value, ast.Call) and _is_ipython_magic(node.value.func): - args = [] - for arg in node.value.args: - assert isinstance(arg, ast.Str) - args.append(arg.s) - assert args + args = _get_str_args(node.value.args) if node.value.func.attr == "run_line_magic": if args[0] == "pinfo": src = f"?{args[1]}"