git.madduck.net Git - etc/vim.git/blob - src/black/handle_ipynb_magics.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Refactor Jupyter magic handling (#2545)
[etc/vim.git] / src / black / handle_ipynb_magics.py
1 """Functions to process IPython magics with."""
2
3 from functools import lru_cache
4 import dataclasses
5 import ast
6 from typing import Dict, List, Tuple, Optional
7
8 import secrets
9 import sys
10 import collections
11
12 if sys.version_info >= (3, 10):
13     from typing import TypeGuard
14 else:
15     from typing_extensions import TypeGuard
16
17 from black.report import NothingChanged
18 from black.output import out
19
20
# ``get_ipython()`` method calls that stand in for IPython magics once a cell
# has been run through IPython's TransformerManager().transform_cell.
TRANSFORMED_MAGICS = frozenset(
    {
        "get_ipython().run_cell_magic",
        "get_ipython().system",
        "get_ipython().getoutput",
        "get_ipython().run_line_magic",
    }
)

# Token names that are skipped when scanning a cell backwards for its last
# significant token (see the trailing-semicolon helpers).
TOKENS_TO_IGNORE = frozenset(
    {
        "ENDMARKER",
        "NL",
        "NEWLINE",
        "COMMENT",
        "DEDENT",
        "UNIMPORTANT_WS",
        "ESCAPED_NL",
    }
)

# Cell magics for which masking is abandoned (NothingChanged is raised).
NON_PYTHON_CELL_MAGICS = frozenset(
    {
        "bash",
        "html",
        "javascript",
        "js",
        "latex",
        "markdown",
        "perl",
        "ruby",
        "script",
        "sh",
        "svg",
        "writefile",
    }
)

# Module-level alias for the random hex generator used to build masks.
# NOTE(review): presumably kept as an alias so it can be swapped out in
# tests -- confirm before inlining.
TOKEN_HEX = secrets.token_hex
57
58
@dataclasses.dataclass(frozen=True)
class Replacement:
    """Pairing of a mask token with the magic source it stands in for."""

    # Quoted placeholder string that was written into the cell.
    mask: str
    # Original magic text that ``unmask_cell`` substitutes back for ``mask``.
    src: str
63
64
@lru_cache()
def jupyter_dependencies_are_installed(*, verbose: bool, quiet: bool) -> bool:
    """Report whether the optional Jupyter dependencies can be imported.

    When either ``IPython`` or ``tokenize_rt`` is missing, a hint is printed
    (unless the caller asked for quiet, non-verbose output) and False is
    returned. Results are memoised per ``(verbose, quiet)`` combination, so
    the hint is emitted at most once for a given combination.
    """
    try:
        import IPython  # noqa:F401
        import tokenize_rt  # noqa:F401
    except ModuleNotFoundError:
        should_report = verbose or not quiet
        if should_report:
            out(
                "Skipping .ipynb files as Jupyter dependencies are not installed.\n"
                "You can fix this by running ``pip install black[jupyter]``"
            )
        return False
    return True
80
81
def remove_trailing_semicolon(src: str) -> Tuple[str, bool]:
    """Strip the semicolon that ends a Jupyter notebook cell, if there is one.

    A cell such as

        fig, ax = plt.subplots()
        ax.plot(x_data, y_data);  # plot data

    is turned into

        fig, ax = plt.subplots()
        ax.plot(x_data, y_data)  # plot data

    Returns the (possibly modified) source together with a flag telling
    whether a semicolon was removed.

    Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses
    ``tokenize_rt`` so that round-tripping works fine.
    """
    from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src

    tokens = src_to_tokens(src)
    for position, tok in reversed_enumerate(tokens):
        if tok.name in TOKENS_TO_IGNORE:
            continue
        # ``tok`` is the last significant token of the cell; only a bare
        # ``;`` operator in that position counts as a trailing semicolon.
        if tok.name == "OP" and tok.src == ";":
            del tokens[position]
            return tokens_to_src(tokens), True
        break
    return src, False
116
117
def put_trailing_semicolon_back(src: str, has_trailing_semicolon: bool) -> str:
    """Re-append the trailing semicolon when the cell originally ended in one.

    ``src`` is returned untouched when ``has_trailing_semicolon`` is false.
    Otherwise a ``;`` is glued onto the last token whose name is not in
    ``TOKENS_TO_IGNORE``.

    Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses
    ``tokenize_rt`` so that round-tripping works fine.
    """
    if not has_trailing_semicolon:
        return src
    from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src

    tokens = src_to_tokens(src)
    for position, tok in reversed_enumerate(tokens):
        if tok.name not in TOKENS_TO_IGNORE:
            tokens[position] = tok._replace(src=tok.src + ";")
            return str(tokens_to_src(tokens))
    # Reached only when every token is ignorable, i.e. there is nothing to
    # attach the semicolon to.
    raise AssertionError(  # pragma: nocover
        "INTERNAL ERROR: Was not able to reinstate trailing semicolon. "
        "Please report a bug on https://github.com/psf/black/issues.  "
    ) from None
140
141
def mask_cell(src: str) -> Tuple[str, List[Replacement]]:
    """Hide IPython magics behind string tokens so the cell parses as Python.

    For instance,

        %matplotlib inline
        'foo'

    is rewritten to

        "25716f358c32750e"
        'foo'

    Returns the rewritten code and the list of replacements needed to undo
    the masking. Source that already parses is returned unchanged with an
    empty replacement list. Raises NothingChanged when masking changes the
    number of lines (multi-line magics are not supported).
    """
    try:
        ast.parse(src)
    except SyntaxError:
        # Might have IPython magics, will process below.
        pass
    else:
        # Already valid Python: nothing to mask.
        return src, []

    from IPython.core.inputtransformer2 import TransformerManager

    manager = TransformerManager()
    # First pass handles a cell magic header; the second pass handles the
    # line magics left in the body that the first pass exposed.
    masked, cell_level = replace_cell_magics(manager.transform_cell(src))
    masked, line_level = replace_magics(manager.transform_cell(masked))
    if len(masked.splitlines()) != len(src.splitlines()):
        # Multi-line magic, not supported.
        raise NothingChanged
    return masked, cell_level + line_level
180
181
def get_token(src: str, magic: str) -> str:
    """Produce a random, double-quoted token to stand in for ``magic``.

    For example, if 'magic' was `%matplotlib inline`, then a possible
    token to mask it with would be `"43fdd17f7e5ddc83"`. The hex payload is
    sized from ``len(magic)`` and padded with a ``.`` when it would otherwise
    come out shorter than the magic, and it is regenerated until it does not
    occur anywhere in ``src``.
    """
    assert magic
    nbytes = max(len(magic) // 2 - 1, 1)
    candidate = TOKEN_HEX(nbytes)
    # Up to 101 collision checks, each followed by a fresh draw on collision.
    for _ in range(101):
        if candidate not in src:
            break
        candidate = TOKEN_HEX(nbytes)
    else:
        raise AssertionError(
            "INTERNAL ERROR: Black was not able to replace IPython magic. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"The magic might be helpful: {magic}"
        ) from None
    # Hex output always has even length; pad so odd-length magics line up.
    padding = "." if len(candidate) + 2 < len(magic) else ""
    return f'"{candidate}{padding}"'
206
207
def replace_cell_magics(src: str) -> Tuple[str, List[Replacement]]:
    """Swap a cell magic header for a mask token.

    'src' is expected to have been through IPython's
    TransformerManager().transform_cell already.

    As an illustration,

        get_ipython().run_cell_magic('t', '-n1', 'ls =!ls\\n')

    turns into

        "a794."
        ls =!ls

    Returns the new source plus the replacement that was made (an empty list
    when no cell magic is present). Raises NothingChanged for cell magics
    listed in NON_PYTHON_CELL_MAGICS.
    """
    finder = CellMagicFinder()
    finder.visit(ast.parse(src))
    found = finder.cell_magic
    if found is None:
        return src, []
    if found.name in NON_PYTHON_CELL_MAGICS:
        raise NothingChanged
    header = found.header
    mask = get_token(src, header)
    return f"{mask}\n{found.body}", [Replacement(mask=mask, src=header)]
239
240
def replace_magics(src: str) -> Tuple[str, List[Replacement]]:
    """Swap magics found in the body of a cell for mask tokens.

    'src' is expected to have been through IPython's
    TransformerManager().transform_cell already.

    As an illustration,

        get_ipython().run_line_magic('matplotlib', 'inline')
        'foo'

    turns into

        "5e67db56d490fd39"
        'foo'

    Returns the new source plus the replacements that were made.
    """
    finder = MagicFinder()
    finder.visit(ast.parse(src))

    replacements = []
    masked_lines = []
    for lineno, line in enumerate(src.splitlines(), start=1):
        if lineno not in finder.magics:
            masked_lines.append(line)
            continue
        found = finder.magics[lineno]
        if len(found) != 1:  # pragma: nocover
            raise AssertionError(
                f"Expecting one magic per line, got: {found}\n"
                "Please report a bug on https://github.com/psf/black/issues."
            )
        (entry,) = found
        mask = get_token(src, entry.magic)
        replacements.append(Replacement(mask=mask, src=entry.magic))
        # Everything from the magic's column to the end of the line is
        # replaced by the mask.
        masked_lines.append(line[: entry.col_offset] + mask)
    return "\n".join(masked_lines), replacements
280
281
def unmask_cell(src: str, replacements: List[Replacement]) -> str:
    """Substitute every mask in the cell with the magic it was hiding.

    As an illustration,

        "9b20"
        foo = bar

    turns back into

        %%time
        foo = bar

    Replacements are applied in list order, each one to all occurrences of
    its mask.
    """
    restored = src
    for item in replacements:
        restored = restored.replace(item.mask, item.src)
    return restored
298
299
300 def _is_ipython_magic(node: ast.expr) -> TypeGuard[ast.Attribute]:
301     """Check if attribute is IPython magic.
302
303     Note that the source of the abstract syntax tree
304     will already have been processed by IPython's
305     TransformerManager().transform_cell.
306     """
307     return (
308         isinstance(node, ast.Attribute)
309         and isinstance(node.value, ast.Call)
310         and isinstance(node.value.func, ast.Name)
311         and node.value.func.id == "get_ipython"
312     )
313
314
315 def _get_str_args(args: List[ast.expr]) -> List[str]:
316     str_args = []
317     for arg in args:
318         assert isinstance(arg, ast.Str)
319         str_args.append(arg.s)
320     return str_args
321
322
@dataclasses.dataclass(frozen=True)
class CellMagic:
    """A cell magic split into its name, its parameters and the cell body."""

    name: str
    params: Optional[str]
    body: str

    @property
    def header(self) -> str:
        """The ``%%name params`` line, without a trailing space if no params."""
        prefix = f"%%{self.name}"
        return f"{prefix} {self.params}" if self.params else prefix
334
335
@dataclasses.dataclass
class CellMagicFinder(ast.NodeVisitor):
    """Locate a cell magic in an already-transformed cell.

    The tree being visited must come from source that has been through
    IPython's TransformerManager().transform_cell. A cell that starts with
    ``%%time`` followed by ``foo()`` on the next line will by then read

        get_ipython().run_cell_magic('time', '', 'foo()\\n')

    and it is calls of that form that are picked up here. The result is
    stored in ``cell_magic``, which stays None when nothing is found.
    """

    cell_magic: Optional[CellMagic] = None

    def visit_Expr(self, node: ast.Expr) -> None:
        """Record the cell magic (name, params, body) if ``node`` is one."""
        call = node.value
        if (
            isinstance(call, ast.Call)
            and _is_ipython_magic(call.func)
            and call.func.attr == "run_cell_magic"
        ):
            magic_args = _get_str_args(call.args)
            self.cell_magic = CellMagic(
                name=magic_args[0], params=magic_args[1], body=magic_args[2]
            )
        self.generic_visit(node)
367
368
@dataclasses.dataclass(frozen=True)
class OffsetAndMagic:
    """A magic reconstructed from the AST, plus the column where it starts."""

    # ``col_offset`` of the ``get_ipython()...`` call node on its line.
    col_offset: int
    # Magic source text as rebuilt by ``MagicFinder`` (e.g. ``!ls``).
    magic: str
373
374
@dataclasses.dataclass
class MagicFinder(ast.NodeVisitor):
    """Collect ``get_ipython()`` calls that represent magics in a cell.

    The tree being visited must come from source that has been through
    IPython's TransformerManager().transform_cell. By then a line such as

        %matplotlib inline

    reads

        get_ipython().run_line_magic('matplotlib', 'inline')

    and calls of that form (as well as the ones for the other kinds of
    magics) are gathered into ``magics``, keyed by line number.
    """

    magics: Dict[int, List[OffsetAndMagic]] = dataclasses.field(
        default_factory=lambda: collections.defaultdict(list)
    )

    def visit_Assign(self, node: ast.Assign) -> None:
        """Pick up system commands whose output is assigned to a name.

        A line like

            black_version = !black --version

        will have been rewritten to

            black_version = get_ipython().getoutput('black --version')

        which is the form detected here.
        """
        call = node.value
        if (
            isinstance(call, ast.Call)
            and _is_ipython_magic(call.func)
            and call.func.attr == "getoutput"
        ):
            (command,) = _get_str_args(call.args)
            found = OffsetAndMagic(call.col_offset, f"!{command}")
            self.magics[call.lineno].append(found)
        self.generic_visit(node)

    def visit_Expr(self, node: ast.Expr) -> None:
        """Pick up magics that appear as stand-alone statements.

        The lines

            !ls
            !!ls
            ?ls
            ??ls

        will have been rewritten, in that order, to

            get_ipython().system('ls')
            get_ipython().getoutput('ls')
            get_ipython().run_line_magic('pinfo', 'ls')
            get_ipython().run_line_magic('pinfo2', 'ls')

        and any of these forms is detected here. Other ``get_ipython()``
        methods are unsupported and raise NothingChanged.
        """
        call = node.value
        if isinstance(call, ast.Call) and _is_ipython_magic(call.func):
            args = _get_str_args(call.args)
            method = call.func.attr
            if method == "run_line_magic":
                name, rest = args[0], args[1]
                if name == "pinfo":
                    src = f"?{rest}"
                elif name == "pinfo2":
                    src = f"??{rest}"
                else:
                    # Ordinary line magic; arguments are optional.
                    src = f"%{name} {rest}" if rest else f"%{name}"
            elif method == "system":
                src = f"!{args[0]}"
            elif method == "getoutput":
                src = f"!!{args[0]}"
            else:
                raise NothingChanged  # unsupported magic.
            found = OffsetAndMagic(call.col_offset, src)
            self.magics[call.lineno].append(found)
        self.generic_visit(node)