git.madduck.net Git - etc/vim.git/blob - src/black/handle_ipynb_magics.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

make black[jupyter] installation cross-shell (#3394)
[etc/vim.git] / src / black / handle_ipynb_magics.py
1 """Functions to process IPython magics with."""
2
3 import ast
4 import collections
5 import dataclasses
6 import secrets
7 import sys
8 from functools import lru_cache
9 from typing import Dict, List, Optional, Tuple
10
11 if sys.version_info >= (3, 10):
12     from typing import TypeGuard
13 else:
14     from typing_extensions import TypeGuard
15
16 from black.output import out
17 from black.report import NothingChanged
18
# Call prefixes that IPython's input transformer emits when it rewrites
# magics into plain Python (see the examples in the docstrings below).
TRANSFORMED_MAGICS = frozenset(
    {
        "get_ipython().run_cell_magic",
        "get_ipython().system",
        "get_ipython().getoutput",
        "get_ipython().run_line_magic",
    }
)
# tokenize_rt token names that are skipped when scanning backwards for the
# last significant token of a cell (trailing-semicolon handling).
TOKENS_TO_IGNORE = frozenset(
    {
        "ENDMARKER",
        "NL",
        "NEWLINE",
        "COMMENT",
        "DEDENT",
        "UNIMPORTANT_WS",
        "ESCAPED_NL",
    }
)
# Names of cell magics whose bodies are presumably plain Python code
# (not referenced in this chunk; confirm against its callers).
PYTHON_CELL_MAGICS = frozenset(
    {
        "capture",
        "prun",
        "pypy",
        "python",
        "python3",
        "time",
        "timeit",
    }
)
# Random hex generator used for masks. It is bound at module level,
# presumably so it can be substituted (e.g. in tests). Confirm before relying on it.
TOKEN_HEX = secrets.token_hex
50
51
@dataclasses.dataclass(frozen=True)
class Replacement:
    """A placeholder written into a cell together with the text it hides.

    ``unmask_cell`` swaps every ``mask`` back to its ``src``.
    """

    # Quoted placeholder string inserted into the cell.
    mask: str
    # Original magic text that the placeholder stands in for.
    src: str
56
57
@lru_cache()
def jupyter_dependencies_are_installed(*, verbose: bool, quiet: bool) -> bool:
    """Report whether both ``IPython`` and ``tokenize_rt`` are importable.

    The answer is cached per (verbose, quiet) combination. When a dependency
    is missing and the caller is verbose or not quiet, a hint explaining how
    to install the extra is printed via ``out``.
    """
    try:
        import IPython  # noqa:F401
        import tokenize_rt  # noqa:F401
    except ModuleNotFoundError:
        should_warn = verbose or not quiet
        if should_warn:
            out(
                "Skipping .ipynb files as Jupyter dependencies are not installed.\n"
                'You can fix this by running ``pip install "black[jupyter]"``'
            )
        return False
    return True
73
74
def remove_trailing_semicolon(src: str) -> Tuple[str, bool]:
    """Strip a cell's final semicolon, reporting whether one was present.

    For example,

        fig, ax = plt.subplots()
        ax.plot(x_data, y_data);  # plot data

    would become

        fig, ax = plt.subplots()
        ax.plot(x_data, y_data)  # plot data

    Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses
    ``tokenize_rt`` so that round-tripping works fine.

    Returns ``(new_source, True)`` if the last significant token was ``;``
    (which is then deleted), otherwise ``(src, False)`` with ``src`` untouched.
    """
    from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src

    tokens = src_to_tokens(src)
    for position, tok in reversed_enumerate(tokens):
        if tok.name in TOKENS_TO_IGNORE:
            continue
        # First significant token from the end decides the outcome.
        if tok.name != "OP" or tok.src != ";":
            return src, False
        del tokens[position]
        return tokens_to_src(tokens), True
    return src, False
105
106
def put_trailing_semicolon_back(src: str, has_trailing_semicolon: bool) -> str:
    """Re-append the semicolon stripped by ``remove_trailing_semicolon``.

    Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses
    ``tokenize_rt`` so that round-tripping works fine.

    When ``has_trailing_semicolon`` is false, ``src`` is returned unchanged.
    Otherwise ``;`` is glued onto the last significant token. An
    ``AssertionError`` is raised if no such token exists.
    """
    if not has_trailing_semicolon:
        return src
    from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src

    tokens = src_to_tokens(src)
    for position, tok in reversed_enumerate(tokens):
        if tok.name not in TOKENS_TO_IGNORE:
            tokens[position] = tok._replace(src=tok.src + ";")
            return str(tokens_to_src(tokens))
    raise AssertionError(  # pragma: nocover
        "INTERNAL ERROR: Was not able to reinstate trailing semicolon. "
        "Please report a bug on https://github.com/psf/black/issues.  "
    ) from None
129
130
def mask_cell(src: str) -> Tuple[str, List[Replacement]]:
    """Mask IPython magics so content becomes parseable Python code.

    For example,

        %matplotlib inline
        'foo'

    becomes

        "25716f358c32750e"
        'foo'

    The replacements are returned, along with the transformed code.

    If ``src`` already parses as Python, it is returned unchanged with no
    replacements. Raises ``NothingChanged`` when masking would change the
    number of (non-blank-bounded) lines, i.e. for multi-line magics.
    """
    replacements: List[Replacement] = []
    try:
        ast.parse(src)
    except SyntaxError:
        # Might have IPython magics, will process below.
        pass
    else:
        # Syntax is fine, nothing to mask, early return.
        return src, replacements

    from IPython.core.inputtransformer2 import TransformerManager

    transformer_manager = TransformerManager()
    transformed = transformer_manager.transform_cell(src)
    transformed, cell_magic_replacements = replace_cell_magics(transformed)
    replacements += cell_magic_replacements
    transformed = transformer_manager.transform_cell(transformed)
    transformed, magic_replacements = replace_magics(transformed)
    # Blank lines at either end of the cell are not preserved one-for-one:
    # IPython's transformer is understood to drop leading empty lines, and
    # ``replace_magics`` re-joins lines with "\n", losing a trailing empty line.
    # Neither is a multi-line magic. Compare stripped line counts so such cells
    # are still formatted instead of being skipped.
    if len(transformed.strip().splitlines()) != len(src.strip().splitlines()):
        # Multi-line magic, not supported.
        raise NothingChanged
    replacements += magic_replacements
    return transformed, replacements
169
170
def get_token(src: str, magic: str) -> str:
    """Return a random, double-quoted hex string to stand in for ``magic``.

    For example, for `%matplotlib inline` a possible result is
    `"43fdd17f7e5ddc83"`. For magics of four or more characters the result
    has the same length as the magic (a trailing ``.`` pads odd lengths).
    Shorter magics get a four-character mask. Candidates that already occur
    in ``src`` are rejected and regenerated. After too many rejections an
    ``AssertionError`` is raised.
    """
    assert magic
    n_bytes = max(len(magic) // 2 - 1, 1)
    candidate = TOKEN_HEX(n_bytes)
    attempts = 0
    while candidate in src:
        candidate = TOKEN_HEX(n_bytes)
        attempts += 1
        if attempts > 100:
            raise AssertionError(
                "INTERNAL ERROR: Black was not able to replace IPython magic. "
                "Please report a bug on https://github.com/psf/black/issues.  "
                f"The magic might be helpful: {magic}"
            ) from None
    # Pad odd-length magics so the quoted mask matches the magic's length.
    if len(candidate) + 2 < len(magic):
        candidate += "."
    return f'"{candidate}"'
195
196
def replace_cell_magics(src: str) -> Tuple[str, List[Replacement]]:
    """Swap a transformed cell magic for a mask line followed by its body.

    ``src`` is expected to have gone through IPython's
    TransformerManager().transform_cell already, so

        get_ipython().run_cell_magic('t', '-n1', 'ls =!ls\\n')

    turns into

        "a794."
        ls =!ls

    Returns the rewritten source and a list holding the single replacement,
    or ``src`` itself and an empty list when no cell magic is present.
    """
    finder = CellMagicFinder()
    finder.visit(ast.parse(src))
    found = finder.cell_magic
    if found is None:
        return src, []
    header = found.header
    mask = get_token(src, header)
    return f"{mask}\n{found.body}", [Replacement(mask=mask, src=header)]
226
227
def replace_magics(src: str) -> Tuple[str, List[Replacement]]:
    """Replace magics within body of cell.

    Note that 'src' will already have been processed by IPython's
    TransformerManager().transform_cell.

    Example, this

        get_ipython().run_line_magic('matplotlib', 'inline')
        'foo'

    becomes

        "5e67db56d490fd39"
        'foo'

    The replacement, along with the transformed code, are returned. Every mask
    issued for the cell is distinct, so each magic can be restored
    unambiguously.
    """
    replacements: List[Replacement] = []
    magic_finder = MagicFinder()
    magic_finder.visit(ast.parse(src))
    # ``get_token`` only guarantees that a new token does not occur in the text
    # it is given. Magics of up to five characters (``%ls``, ``!pwd``...) get
    # two-hex-digit tokens, so two magics in one cell could draw the same mask.
    # ``unmask_cell`` would then turn both back into the first magic. Check
    # each new token against the cell *and* every mask handed out so far.
    taken = src
    new_srcs = []
    for i, line in enumerate(src.splitlines(), start=1):
        if i in magic_finder.magics:
            offsets_and_magics = magic_finder.magics[i]
            if len(offsets_and_magics) != 1:  # pragma: nocover
                raise AssertionError(
                    f"Expecting one magic per line, got: {offsets_and_magics}\n"
                    "Please report a bug on https://github.com/psf/black/issues."
                )
            col_offset, magic = (
                offsets_and_magics[0].col_offset,
                offsets_and_magics[0].magic,
            )
            mask = get_token(taken, magic)
            taken += "\n" + mask
            replacements.append(Replacement(mask=mask, src=magic))
            # The magic call runs to the end of the line; keep any prefix
            # (e.g. ``x = ``) and replace the call with the mask.
            line = line[:col_offset] + mask
        new_srcs.append(line)
    return "\n".join(new_srcs), replacements
267
268
def unmask_cell(src: str, replacements: List[Replacement]) -> str:
    """Undo masking by substituting each mask with its original magic text.

    Replacements are applied in list order, each one replacing every
    occurrence of its mask. For example

        "9b20"
        foo = bar

    becomes

        %%time
        foo = bar
    """
    unmasked = src
    for entry in replacements:
        unmasked = unmasked.replace(entry.mask, entry.src)
    return unmasked
285
286
287 def _is_ipython_magic(node: ast.expr) -> TypeGuard[ast.Attribute]:
288     """Check if attribute is IPython magic.
289
290     Note that the source of the abstract syntax tree
291     will already have been processed by IPython's
292     TransformerManager().transform_cell.
293     """
294     return (
295         isinstance(node, ast.Attribute)
296         and isinstance(node.value, ast.Call)
297         and isinstance(node.value.func, ast.Name)
298         and node.value.func.id == "get_ipython"
299     )
300
301
def _get_str_args(args: List[ast.expr]) -> List[str]:
    """Return the values of ``args``, each of which must be a string literal.

    A non-string-literal argument fails the ``assert`` (``AssertionError``
    when assertions are enabled).
    """
    str_args: List[str] = []
    for arg in args:
        if sys.version_info >= (3, 8):
            # ``ast.Str`` is deprecated since 3.8 (warns on 3.12, removed in
            # 3.14). String literals are ``ast.Constant`` nodes with a str value.
            assert isinstance(arg, ast.Constant) and isinstance(arg.value, str)
            str_args.append(arg.value)
        else:  # pragma: no cover
            # Python 3.7 still parses string literals as ``ast.Str``.
            assert isinstance(arg, ast.Str)
            str_args.append(arg.s)
    return str_args
308
309
@dataclasses.dataclass(frozen=True)
class CellMagic:
    """A cell magic recovered from a ``run_cell_magic`` call."""

    # Magic name without the ``%%`` prefix, e.g. ``time``.
    name: str
    # Argument string following the name; may be empty or None.
    params: Optional[str]
    # The cell body the magic applies to.
    body: str

    @property
    def header(self) -> str:
        """Rebuild the ``%%name [params]`` first line of the cell."""
        pieces = [f"%%{self.name}"]
        if self.params:
            pieces.append(self.params)
        return " ".join(pieces)
321
322
# ast.NodeVisitor + dataclass = breakage under mypyc.
class CellMagicFinder(ast.NodeVisitor):
    """Locate an IPython cell magic in an already-transformed cell.

    The tree visited here comes from source that IPython's
    TransformerManager().transform_cell has already rewritten, so a cell like

        %%time\nfoo()

    shows up as

        get_ipython().run_cell_magic('time', '', 'foo()\\n')

    Every such call is recorded in ``self.cell_magic``. If several are found,
    the last one visited wins.
    """

    def __init__(self, cell_magic: Optional[CellMagic] = None) -> None:
        self.cell_magic = cell_magic

    def visit_Expr(self, node: ast.Expr) -> None:
        """Record name/params/body of a ``run_cell_magic`` call, then recurse."""
        call = node.value
        if (
            isinstance(call, ast.Call)
            and _is_ipython_magic(call.func)
            and call.func.attr == "run_cell_magic"
        ):
            magic_args = _get_str_args(call.args)
            self.cell_magic = CellMagic(
                name=magic_args[0], params=magic_args[1], body=magic_args[2]
            )
        self.generic_visit(node)
354         self.generic_visit(node)
355
356
@dataclasses.dataclass(frozen=True)
class OffsetAndMagic:
    """Location and reconstructed text of one line magic."""

    # Column at which the ``get_ipython()...`` call starts on its line.
    col_offset: int
    # Magic source rebuilt from the call, e.g. ``%matplotlib inline``.
    magic: str
361
362
# Unsurprisingly, subclassing ast.NodeVisitor means we can't use dataclasses here
# as mypyc will generate broken code.
class MagicFinder(ast.NodeVisitor):
    """Visit cell to look for get_ipython calls.

    Note that the source of the abstract syntax tree
    will already have been processed by IPython's
    TransformerManager().transform_cell.

    For example,

        %matplotlib inline

    would have been transformed to

        get_ipython().run_line_magic('matplotlib', 'inline')

    and we look for instances of the latter (and likewise for other
    types of magics).

    Found magics are stored in ``self.magics``, keyed by line number. A
    ``get_ipython()`` method that IPython's transformer does not emit (for
    example, user code calling ``get_ipython()`` directly) raises
    ``NothingChanged`` so the cell is skipped rather than crashing.
    """

    def __init__(self) -> None:
        self.magics: Dict[int, List[OffsetAndMagic]] = collections.defaultdict(list)

    def visit_Assign(self, node: ast.Assign) -> None:
        """Look for system assign magics.

        For example,

            black_version = !black --version
            env = %env var

        would have been (respectively) transformed to

            black_version = get_ipython().getoutput('black --version')
            env = get_ipython().run_line_magic('env', 'var')

        and we look for instances of any of the latter.
        """
        if isinstance(node.value, ast.Call) and _is_ipython_magic(node.value.func):
            attr = node.value.func.attr
            # Check the method before reading arguments: user-written calls
            # may have non-literal args, which ``_get_str_args`` would reject
            # with an AssertionError. Unsupported calls skip the cell, as in
            # ``visit_Expr``.
            if attr not in ("getoutput", "run_line_magic"):
                raise NothingChanged  # unsupported magic.
            args = _get_str_args(node.value.args)
            if attr == "getoutput":
                src = f"!{args[0]}"
            else:
                src = f"%{args[0]}"
                if args[1]:
                    src += f" {args[1]}"
            self.magics[node.value.lineno].append(
                OffsetAndMagic(node.value.col_offset, src)
            )
        self.generic_visit(node)

    def visit_Expr(self, node: ast.Expr) -> None:
        """Look for magics in body of cell.

        For examples,

            !ls
            !!ls
            ?ls
            ??ls

        would (respectively) get transformed to

            get_ipython().system('ls')
            get_ipython().getoutput('ls')
            get_ipython().run_line_magic('pinfo', 'ls')
            get_ipython().run_line_magic('pinfo2', 'ls')

        and we look for instances of any of the latter.
        """
        if isinstance(node.value, ast.Call) and _is_ipython_magic(node.value.func):
            attr = node.value.func.attr
            # Reject unsupported methods before inspecting their arguments.
            if attr not in ("run_line_magic", "system", "getoutput"):
                raise NothingChanged  # unsupported magic.
            args = _get_str_args(node.value.args)
            if attr == "run_line_magic":
                if args[0] == "pinfo":
                    src = f"?{args[1]}"
                elif args[0] == "pinfo2":
                    src = f"??{args[1]}"
                else:
                    src = f"%{args[0]}"
                    if args[1]:
                        src += f" {args[1]}"
            elif attr == "system":
                src = f"!{args[0]}"
            else:
                src = f"!!{args[0]}"
            self.magics[node.value.lineno].append(
                OffsetAndMagic(node.value.col_offset, src)
            )
        self.generic_visit(node)
443                     src = f"?{args[1]}"
444                 elif args[0] == "pinfo2":
445                     src = f"??{args[1]}"
446                 else:
447                     src = f"%{args[0]}"
448                     if args[1]:
449                         src += f" {args[1]}"
450             elif node.value.func.attr == "system":
451                 src = f"!{args[0]}"
452             elif node.value.func.attr == "getoutput":
453                 src = f"!!{args[0]}"
454             else:
455                 raise NothingChanged  # unsupported magic.
456             self.magics[node.value.lineno].append(
457                 OffsetAndMagic(node.value.col_offset, src)
458             )
459         self.generic_visit(node)