git.madduck.net Git - etc/vim.git/blob - src/black/handle_ipynb_magics.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Update link to VS Code formatting instructions (#3921)
[etc/vim.git] / src / black / handle_ipynb_magics.py
1 """Functions to process IPython magics with."""
2
3 import ast
4 import collections
5 import dataclasses
6 import secrets
7 import sys
8 from functools import lru_cache
9 from importlib.util import find_spec
10 from typing import Dict, List, Optional, Tuple
11
12 if sys.version_info >= (3, 10):
13     from typing import TypeGuard
14 else:
15     from typing_extensions import TypeGuard
16
17 from black.output import out
18 from black.report import NothingChanged
19
# Call prefixes that IPython's TransformerManager produces in place of magics
# (presumably consulted elsewhere in the package; not used in this module section).
TRANSFORMED_MAGICS = frozenset(
    {
        "get_ipython().run_cell_magic",
        "get_ipython().system",
        "get_ipython().getoutput",
        "get_ipython().run_line_magic",
    }
)
# tokenize_rt token names skipped when searching for a cell's last significant
# token (used by the trailing-semicolon helpers).
TOKENS_TO_IGNORE = frozenset(
    {
        "ENDMARKER",
        "NL",
        "NEWLINE",
        "COMMENT",
        "DEDENT",
        "UNIMPORTANT_WS",
        "ESCAPED_NL",
    }
)
# Cell magics that, judging by their names, wrap Python source; not referenced in
# this part of the file — TODO confirm usage at call sites.
PYTHON_CELL_MAGICS = frozenset(
    {
        "capture",
        "prun",
        "pypy",
        "python",
        "python3",
        "time",
        "timeit",
    }
)
# Random hex generator used by get_token; kept as a module-level name,
# presumably so it can be swapped out (e.g. in tests).
TOKEN_HEX = secrets.token_hex
51
52
@dataclasses.dataclass(frozen=True)
class Replacement:
    """A masked IPython magic: the placeholder token and the text it hides."""

    mask: str  # placeholder token substituted into the cell source
    src: str  # original magic text that unmasking puts back
58
@lru_cache
def jupyter_dependencies_are_installed(*, warn: bool) -> bool:
    """Report whether the optional Jupyter dependencies can be imported.

    Looks up ``tokenize_rt`` and ``IPython`` with ``find_spec`` (without
    importing them), stopping at the first one that is missing. If something
    is missing and ``warn`` is true, prints a hint about installing
    ``black[jupyter]``. Results are memoized per ``warn`` value via
    ``lru_cache``, so the hint is emitted at most once.
    """
    missing = any(find_spec(module) is None for module in ("tokenize_rt", "IPython"))
    if warn and missing:
        out(
            "Skipping .ipynb files as Jupyter dependencies are not installed.\n"
            'You can fix this by running ``pip install "black[jupyter]"``'
        )
    return not missing
71
72
def remove_trailing_semicolon(src: str) -> Tuple[str, bool]:
    """Strip a cell's trailing semicolon, reporting whether one was removed.

    For example,

        fig, ax = plt.subplots()
        ax.plot(x_data, y_data);  # plot data

    would become

        fig, ax = plt.subplots()
        ax.plot(x_data, y_data)  # plot data

    Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses
    ``tokenize_rt`` so that round-tripping works fine.

    Returns the (possibly unchanged) source and a flag that is True only when
    the last significant token was a ``;`` and has been dropped.
    """
    from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src

    tokens = src_to_tokens(src)
    for position, tok in reversed_enumerate(tokens):
        if tok.name in TOKENS_TO_IGNORE:
            continue
        # First significant token from the end: only a ";" operator is removed.
        if tok.name != "OP" or tok.src != ";":
            return src, False
        del tokens[position]
        return tokens_to_src(tokens), True
    # Nothing but ignorable tokens in the cell.
    return src, False
103
104
def put_trailing_semicolon_back(src: str, has_trailing_semicolon: bool) -> str:
    """Re-append the ``;`` removed by ``remove_trailing_semicolon``, if any.

    The semicolon is glued onto the last token whose name is not in
    ``TOKENS_TO_IGNORE``. When ``has_trailing_semicolon`` is false, ``src`` is
    returned untouched.

    Mirrors the logic in `quiet` from `IPython.core.displayhook`, but uses
    ``tokenize_rt`` so that round-tripping works fine.

    Raises:
        AssertionError: if the source has no significant token to attach to.
    """
    if not has_trailing_semicolon:
        return src
    from tokenize_rt import reversed_enumerate, src_to_tokens, tokens_to_src

    tokens = src_to_tokens(src)
    last_significant = next(
        (
            position
            for position, tok in reversed_enumerate(tokens)
            if tok.name not in TOKENS_TO_IGNORE
        ),
        None,
    )
    if last_significant is None:  # pragma: nocover
        raise AssertionError(
            "INTERNAL ERROR: Was not able to reinstate trailing semicolon. "
            "Please report a bug on https://github.com/psf/black/issues.  "
        ) from None
    target = tokens[last_significant]
    tokens[last_significant] = target._replace(src=target.src + ";")
    return str(tokens_to_src(tokens))
127
128
def mask_cell(src: str) -> Tuple[str, List[Replacement]]:
    """Mask IPython magics so content becomes parseable Python code.

    For example,

        %matplotlib inline
        'foo'

    becomes

        "25716f358c32750e"
        'foo'

    The replacements are returned, along with the transformed code. If ``src``
    already parses as Python, it is returned unchanged with no replacements.

    Raises:
        NothingChanged: if a magic spans several lines (the masked cell would
            not line up with the original), which is not supported.
    """
    replacements: List[Replacement] = []
    try:
        ast.parse(src)
    except SyntaxError:
        # Might have IPython magics, will process below.
        pass
    else:
        # Syntax is fine, nothing to mask, early return.
        return src, replacements

    from IPython.core.inputtransformer2 import TransformerManager

    transformer_manager = TransformerManager()
    # A side effect of ``transform_cell`` is that it also removes any empty lines
    # at the beginning of the cell, so the line-count check below must ignore
    # leading (and trailing) blank lines on both sides.
    transformed = transformer_manager.transform_cell(src)
    transformed, cell_magic_replacements = replace_cell_magics(transformed)
    replacements += cell_magic_replacements
    transformed = transformer_manager.transform_cell(transformed)
    transformed, magic_replacements = replace_magics(transformed)
    # Comparing raw line counts would flag every cell with leading blank lines
    # as a "multi-line magic" and skip it entirely.
    if len(transformed.strip().splitlines()) != len(src.strip().splitlines()):
        # Multi-line magic, not supported.
        raise NothingChanged
    replacements += magic_replacements
    return transformed, replacements
167
168
def get_token(src: str, magic: str) -> str:
    """Return randomly generated token to mask IPython magic with.

    For example, if 'magic' was `%matplotlib inline`, then a possible
    token to mask it with would be `"43fdd17f7e5ddc83"`. The quoted token
    has the same length as the magic (for magics of four or more characters),
    and we make sure its hex part does not already occur anywhere in the cell.

    Raises:
        AssertionError: if no collision-free token is found after repeated
            attempts.
    """
    assert magic
    hex_bytes = max(len(magic) // 2 - 1, 1)
    candidate = TOKEN_HEX(hex_bytes)
    # Up to 101 regenerations are allowed before giving up.
    for _ in range(101):
        if candidate not in src:
            break
        candidate = TOKEN_HEX(hex_bytes)
    else:
        raise AssertionError(
            "INTERNAL ERROR: Black was not able to replace IPython magic. "
            "Please report a bug on https://github.com/psf/black/issues.  "
            f"The magic might be helpful: {magic}"
        ) from None
    # Odd-length magics are one character short; pad with a "." to match.
    padding = "." if len(candidate) + 2 < len(magic) else ""
    return f'"{candidate}{padding}"'
193
194
def replace_cell_magics(src: str) -> Tuple[str, List[Replacement]]:
    """Replace a cell magic's header with a token, keeping its body.

    Note that 'src' will already have been processed by IPython's
    TransformerManager().transform_cell.

    Example,

        get_ipython().run_cell_magic('t', '-n1', 'ls =!ls\\n')

    becomes

        "a794."
        ls =!ls

    Returns the rewritten code together with the replacement made (an empty
    list, and ``src`` unchanged, when there is no cell magic).
    """
    finder = CellMagicFinder()
    finder.visit(ast.parse(src))
    found = finder.cell_magic
    if found is None:
        return src, []
    header = found.header
    mask = get_token(src, header)
    return f"{mask}\n{found.body}", [Replacement(mask=mask, src=header)]
224
225
def replace_magics(src: str) -> Tuple[str, List[Replacement]]:
    """Replace line magics within the body of a cell with tokens.

    Note that 'src' will already have been processed by IPython's
    TransformerManager().transform_cell.

    Example, this

        get_ipython().run_line_magic('matplotlib', 'inline')
        'foo'

    becomes

        "5e67db56d490fd39"
        'foo'

    On each line holding a magic, everything from the magic's column onwards
    is replaced by the token. The lines are re-joined with "\\n" (no trailing
    newline). Returns the new code and the replacements made.
    """
    replacements: List[Replacement] = []
    finder = MagicFinder()
    finder.visit(ast.parse(src))
    rewritten: List[str] = []
    for lineno, line in enumerate(src.splitlines(), start=1):
        on_this_line = finder.magics.get(lineno)
        if on_this_line is not None:
            if len(on_this_line) != 1:  # pragma: nocover
                raise AssertionError(
                    f"Expecting one magic per line, got: {on_this_line}\n"
                    "Please report a bug on https://github.com/psf/black/issues."
                )
            (entry,) = on_this_line
            mask = get_token(src, entry.magic)
            replacements.append(Replacement(mask=mask, src=entry.magic))
            line = line[: entry.col_offset] + mask
        rewritten.append(line)
    return "\n".join(rewritten), replacements
265
266
def unmask_cell(src: str, replacements: List[Replacement]) -> str:
    """Swap every mask token back for the magic it replaced.

    Replacements are applied in order, each substituting all occurrences of
    its mask. For example

        "9b20"
        foo = bar

    becomes

        %%time
        foo = bar
    """
    restored = src
    for item in replacements:
        restored = restored.replace(item.mask, item.src)
    return restored
283
284
285 def _is_ipython_magic(node: ast.expr) -> TypeGuard[ast.Attribute]:
286     """Check if attribute is IPython magic.
287
288     Note that the source of the abstract syntax tree
289     will already have been processed by IPython's
290     TransformerManager().transform_cell.
291     """
292     return (
293         isinstance(node, ast.Attribute)
294         and isinstance(node.value, ast.Call)
295         and isinstance(node.value.func, ast.Name)
296         and node.value.func.id == "get_ipython"
297     )
298
299
def _get_str_args(args: List[ast.expr]) -> List[str]:
    """Return the values of call arguments that must all be string literals.

    String literals are parsed as ``ast.Constant`` nodes holding a ``str`` on
    every Python version Black supports. The old ``ast.Str`` alias (and its
    ``.s`` attribute) has been deprecated since 3.8, emits DeprecationWarning
    since 3.12 and is removed in 3.14, so it must not be used here.

    Raises:
        AssertionError: if any argument is not a ``str`` literal.
    """
    str_args: List[str] = []
    for arg in args:
        assert isinstance(arg, ast.Constant) and isinstance(arg.value, str)
        str_args.append(arg.value)
    return str_args
306
307
@dataclasses.dataclass(frozen=True)
class CellMagic:
    """A cell magic split into its name, header parameters and body."""

    name: str  # magic name, without the leading "%%"
    params: Optional[str]  # text after the name on the header line, if any
    body: str  # remainder of the cell below the header line

    @property
    def header(self) -> str:
        """Rebuild the ``%%name params`` header line (params omitted if empty)."""
        pieces = [f"%%{self.name}"]
        if self.params:
            pieces.append(self.params)
        return " ".join(pieces)
319
320
# ast.NodeVisitor + dataclass = breakage under mypyc.
class CellMagicFinder(ast.NodeVisitor):
    """Locate a ``run_cell_magic`` call and capture it as a ``CellMagic``.

    Note that the source of the abstract syntax tree
    will already have been processed by IPython's
    TransformerManager().transform_cell.

    For example,

        %%time\n
        foo()

    would have been transformed to

        get_ipython().run_cell_magic('time', '', 'foo()\\n')

    and we look for instances of the latter. If several are found, the last
    one visited is kept in ``self.cell_magic``.
    """

    def __init__(self, cell_magic: Optional[CellMagic] = None) -> None:
        self.cell_magic = cell_magic

    def visit_Expr(self, node: ast.Expr) -> None:
        """Record a cell-magic call's name, params and body, then keep walking."""
        call = node.value
        if (
            isinstance(call, ast.Call)
            and _is_ipython_magic(call.func)
            and call.func.attr == "run_cell_magic"
        ):
            str_args = _get_str_args(call.args)
            self.cell_magic = CellMagic(
                name=str_args[0], params=str_args[1], body=str_args[2]
            )
        self.generic_visit(node)
354
355
@dataclasses.dataclass(frozen=True)
class OffsetAndMagic:
    """A line magic's starting column together with its original text."""

    col_offset: int  # column where the ``get_ipython()`` call begins on its line
    magic: str  # magic as originally written, e.g. "%matplotlib inline"
360
361
# Unsurprisingly, subclassing ast.NodeVisitor means we can't use dataclasses here
# as mypyc will generate broken code.
class MagicFinder(ast.NodeVisitor):
    """Collect the line magics hidden behind ``get_ipython()`` calls.

    Note that the source of the abstract syntax tree
    will already have been processed by IPython's
    TransformerManager().transform_cell.

    For example,

        %matplotlib inline

    would have been transformed to

        get_ipython().run_line_magic('matplotlib', 'inline')

    and we look for instances of the latter (and likewise for other
    types of magics). Findings are stored in ``self.magics``, keyed by the
    1-based line number on which each call starts.
    """

    def __init__(self) -> None:
        # line number -> magics whose call begins on that line
        self.magics: Dict[int, List[OffsetAndMagic]] = collections.defaultdict(list)

    def visit_Assign(self, node: ast.Assign) -> None:
        """Record magics used on the right-hand side of an assignment.

        For example,

            black_version = !black --version
            env = %env var

        would have been (respectively) transformed to

            black_version = get_ipython().getoutput('black --version')
            env = get_ipython().run_line_magic('env', 'var')

        and we look for instances of any of the latter. Any other
        ``get_ipython()`` method in this position is treated as a bug.
        """
        call = node.value
        if isinstance(call, ast.Call) and _is_ipython_magic(call.func):
            str_args = _get_str_args(call.args)
            method = call.func.attr
            if method == "getoutput":
                original = f"!{str_args[0]}"
            elif method == "run_line_magic":
                original = f"%{str_args[0]}" + (f" {str_args[1]}" if str_args[1] else "")
            else:
                raise AssertionError(
                    f"Unexpected IPython magic {method!r} found. "
                    "Please report a bug on https://github.com/psf/black/issues."
                ) from None
            self.magics[call.lineno].append(OffsetAndMagic(call.col_offset, original))
        self.generic_visit(node)

    def visit_Expr(self, node: ast.Expr) -> None:
        """Record magics used as standalone expression statements.

        For examples,

            !ls
            !!ls
            ?ls
            ??ls

        would (respectively) get transformed to

            get_ipython().system('ls')
            get_ipython().getoutput('ls')
            get_ipython().run_line_magic('pinfo', 'ls')
            get_ipython().run_line_magic('pinfo2', 'ls')

        and we look for instances of any of the latter. Any other
        ``get_ipython()`` method raises ``NothingChanged`` (unsupported).
        """
        call = node.value
        if isinstance(call, ast.Call) and _is_ipython_magic(call.func):
            str_args = _get_str_args(call.args)
            method = call.func.attr
            if method == "run_line_magic":
                magic_name = str_args[0]
                if magic_name == "pinfo":
                    original = f"?{str_args[1]}"
                elif magic_name == "pinfo2":
                    original = f"??{str_args[1]}"
                else:
                    original = f"%{magic_name}" + (
                        f" {str_args[1]}" if str_args[1] else ""
                    )
            elif method == "system":
                original = f"!{str_args[0]}"
            elif method == "getoutput":
                original = f"!!{str_args[0]}"
            else:
                raise NothingChanged  # unsupported magic.
            self.magics[call.lineno].append(OffsetAndMagic(call.col_offset, original))
        self.generic_visit(node)
442                     src = f"?{args[1]}"
443                 elif args[0] == "pinfo2":
444                     src = f"??{args[1]}"
445                 else:
446                     src = f"%{args[0]}"
447                     if args[1]:
448                         src += f" {args[1]}"
449             elif node.value.func.attr == "system":
450                 src = f"!{args[0]}"
451             elif node.value.func.attr == "getoutput":
452                 src = f"!!{args[0]}"
453             else:
454                 raise NothingChanged  # unsupported magic.
455             self.magics[node.value.lineno].append(
456                 OffsetAndMagic(node.value.col_offset, src)
457             )
458         self.generic_visit(node)