##// END OF EJS Templates
Fix/mpl integration (#14128)...
Matthias Bussonnier -
r28375:175d52c8 merge
parent child Browse files
Show More
@@ -1,425 +1,433 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Pylab (matplotlib) support utilities."""
2 """Pylab (matplotlib) support utilities."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 from io import BytesIO
7 from io import BytesIO
8 from binascii import b2a_base64
8 from binascii import b2a_base64
9 from functools import partial
9 from functools import partial
10 import warnings
10 import warnings
11
11
12 from IPython.core.display import _pngxy
12 from IPython.core.display import _pngxy
13 from IPython.utils.decorators import flag_calls
13 from IPython.utils.decorators import flag_calls
14
14
15 # If user specifies a GUI, that dictates the backend, otherwise we read the
15 # If user specifies a GUI, that dictates the backend, otherwise we read the
16 # user's mpl default from the mpl rc structure
16 # user's mpl default from the mpl rc structure
17 backends = {
17 backends = {
18 "tk": "TkAgg",
18 "tk": "TkAgg",
19 "gtk": "GTKAgg",
19 "gtk": "GTKAgg",
20 "gtk3": "GTK3Agg",
20 "gtk3": "GTK3Agg",
21 "gtk4": "GTK4Agg",
21 "gtk4": "GTK4Agg",
22 "wx": "WXAgg",
22 "wx": "WXAgg",
23 "qt4": "Qt4Agg",
23 "qt4": "Qt4Agg",
24 "qt5": "Qt5Agg",
24 "qt5": "Qt5Agg",
25 "qt6": "QtAgg",
25 "qt6": "QtAgg",
26 "qt": "Qt5Agg",
26 "qt": "QtAgg",
27 "osx": "MacOSX",
27 "osx": "MacOSX",
28 "nbagg": "nbAgg",
28 "nbagg": "nbAgg",
29 "webagg": "WebAgg",
29 "webagg": "WebAgg",
30 "notebook": "nbAgg",
30 "notebook": "nbAgg",
31 "agg": "agg",
31 "agg": "agg",
32 "svg": "svg",
32 "svg": "svg",
33 "pdf": "pdf",
33 "pdf": "pdf",
34 "ps": "ps",
34 "ps": "ps",
35 "inline": "module://matplotlib_inline.backend_inline",
35 "inline": "module://matplotlib_inline.backend_inline",
36 "ipympl": "module://ipympl.backend_nbagg",
36 "ipympl": "module://ipympl.backend_nbagg",
37 "widget": "module://ipympl.backend_nbagg",
37 "widget": "module://ipympl.backend_nbagg",
38 }
38 }
39
39
40 # We also need a reverse backends2guis mapping that will properly choose which
40 # We also need a reverse backends2guis mapping that will properly choose which
41 # GUI support to activate based on the desired matplotlib backend. For the
41 # GUI support to activate based on the desired matplotlib backend. For the
42 # most part it's just a reverse of the above dict, but we also need to add a
42 # most part it's just a reverse of the above dict, but we also need to add a
43 # few others that map to the same GUI manually:
43 # few others that map to the same GUI manually:
44 backend2gui = dict(zip(backends.values(), backends.keys()))
44 backend2gui = dict(zip(backends.values(), backends.keys()))
45 # In the reverse mapping, there are a few extra valid matplotlib backends that
45 # In the reverse mapping, there are a few extra valid matplotlib backends that
46 # map to the same GUI support
46 # map to the same GUI support
47 backend2gui["GTK"] = backend2gui["GTKCairo"] = "gtk"
47 backend2gui["GTK"] = backend2gui["GTKCairo"] = "gtk"
48 backend2gui["GTK3Cairo"] = "gtk3"
48 backend2gui["GTK3Cairo"] = "gtk3"
49 backend2gui["GTK4Cairo"] = "gtk4"
49 backend2gui["GTK4Cairo"] = "gtk4"
50 backend2gui["WX"] = "wx"
50 backend2gui["WX"] = "wx"
51 backend2gui["CocoaAgg"] = "osx"
51 backend2gui["CocoaAgg"] = "osx"
52 # There needs to be a hysteresis here as the new QtAgg Matplotlib backend
52 # There needs to be a hysteresis here as the new QtAgg Matplotlib backend
53 # supports either Qt5 or Qt6 and the IPython qt event loop support Qt4, Qt5,
53 # supports either Qt5 or Qt6 and the IPython qt event loop support Qt4, Qt5,
54 # and Qt6.
54 # and Qt6.
55 backend2gui["QtAgg"] = "qt"
55 backend2gui["QtAgg"] = "qt"
56 backend2gui["Qt4Agg"] = "qt"
56 backend2gui["Qt4Agg"] = "qt4"
57 backend2gui["Qt5Agg"] = "qt"
57 backend2gui["Qt5Agg"] = "qt5"
58
58
59 # And some backends that don't need GUI integration
59 # And some backends that don't need GUI integration
60 del backend2gui["nbAgg"]
60 del backend2gui["nbAgg"]
61 del backend2gui["agg"]
61 del backend2gui["agg"]
62 del backend2gui["svg"]
62 del backend2gui["svg"]
63 del backend2gui["pdf"]
63 del backend2gui["pdf"]
64 del backend2gui["ps"]
64 del backend2gui["ps"]
65 del backend2gui["module://matplotlib_inline.backend_inline"]
65 del backend2gui["module://matplotlib_inline.backend_inline"]
66 del backend2gui["module://ipympl.backend_nbagg"]
66 del backend2gui["module://ipympl.backend_nbagg"]
67
67
68 #-----------------------------------------------------------------------------
68 #-----------------------------------------------------------------------------
69 # Matplotlib utilities
69 # Matplotlib utilities
70 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
71
71
72
72
73 def getfigs(*fig_nums):
73 def getfigs(*fig_nums):
74 """Get a list of matplotlib figures by figure numbers.
74 """Get a list of matplotlib figures by figure numbers.
75
75
76 If no arguments are given, all available figures are returned. If the
76 If no arguments are given, all available figures are returned. If the
77 argument list contains references to invalid figures, a warning is printed
77 argument list contains references to invalid figures, a warning is printed
78 but the function continues pasting further figures.
78 but the function continues pasting further figures.
79
79
80 Parameters
80 Parameters
81 ----------
81 ----------
82 figs : tuple
82 figs : tuple
83 A tuple of ints giving the figure numbers of the figures to return.
83 A tuple of ints giving the figure numbers of the figures to return.
84 """
84 """
85 from matplotlib._pylab_helpers import Gcf
85 from matplotlib._pylab_helpers import Gcf
86 if not fig_nums:
86 if not fig_nums:
87 fig_managers = Gcf.get_all_fig_managers()
87 fig_managers = Gcf.get_all_fig_managers()
88 return [fm.canvas.figure for fm in fig_managers]
88 return [fm.canvas.figure for fm in fig_managers]
89 else:
89 else:
90 figs = []
90 figs = []
91 for num in fig_nums:
91 for num in fig_nums:
92 f = Gcf.figs.get(num)
92 f = Gcf.figs.get(num)
93 if f is None:
93 if f is None:
94 print('Warning: figure %s not available.' % num)
94 print('Warning: figure %s not available.' % num)
95 else:
95 else:
96 figs.append(f.canvas.figure)
96 figs.append(f.canvas.figure)
97 return figs
97 return figs
98
98
99
99
100 def figsize(sizex, sizey):
100 def figsize(sizex, sizey):
101 """Set the default figure size to be [sizex, sizey].
101 """Set the default figure size to be [sizex, sizey].
102
102
103 This is just an easy to remember, convenience wrapper that sets::
103 This is just an easy to remember, convenience wrapper that sets::
104
104
105 matplotlib.rcParams['figure.figsize'] = [sizex, sizey]
105 matplotlib.rcParams['figure.figsize'] = [sizex, sizey]
106 """
106 """
107 import matplotlib
107 import matplotlib
108 matplotlib.rcParams['figure.figsize'] = [sizex, sizey]
108 matplotlib.rcParams['figure.figsize'] = [sizex, sizey]
109
109
110
110
111 def print_figure(fig, fmt="png", bbox_inches="tight", base64=False, **kwargs):
111 def print_figure(fig, fmt="png", bbox_inches="tight", base64=False, **kwargs):
112 """Print a figure to an image, and return the resulting file data
112 """Print a figure to an image, and return the resulting file data
113
113
114 Returned data will be bytes unless ``fmt='svg'``,
114 Returned data will be bytes unless ``fmt='svg'``,
115 in which case it will be unicode.
115 in which case it will be unicode.
116
116
117 Any keyword args are passed to fig.canvas.print_figure,
117 Any keyword args are passed to fig.canvas.print_figure,
118 such as ``quality`` or ``bbox_inches``.
118 such as ``quality`` or ``bbox_inches``.
119
119
120 If `base64` is True, return base64-encoded str instead of raw bytes
120 If `base64` is True, return base64-encoded str instead of raw bytes
121 for binary-encoded image formats
121 for binary-encoded image formats
122
122
123 .. versionadded:: 7.29
123 .. versionadded:: 7.29
124 base64 argument
124 base64 argument
125 """
125 """
126 # When there's an empty figure, we shouldn't return anything, otherwise we
126 # When there's an empty figure, we shouldn't return anything, otherwise we
127 # get big blank areas in the qt console.
127 # get big blank areas in the qt console.
128 if not fig.axes and not fig.lines:
128 if not fig.axes and not fig.lines:
129 return
129 return
130
130
131 dpi = fig.dpi
131 dpi = fig.dpi
132 if fmt == 'retina':
132 if fmt == 'retina':
133 dpi = dpi * 2
133 dpi = dpi * 2
134 fmt = 'png'
134 fmt = 'png'
135
135
136 # build keyword args
136 # build keyword args
137 kw = {
137 kw = {
138 "format":fmt,
138 "format":fmt,
139 "facecolor":fig.get_facecolor(),
139 "facecolor":fig.get_facecolor(),
140 "edgecolor":fig.get_edgecolor(),
140 "edgecolor":fig.get_edgecolor(),
141 "dpi":dpi,
141 "dpi":dpi,
142 "bbox_inches":bbox_inches,
142 "bbox_inches":bbox_inches,
143 }
143 }
144 # **kwargs get higher priority
144 # **kwargs get higher priority
145 kw.update(kwargs)
145 kw.update(kwargs)
146
146
147 bytes_io = BytesIO()
147 bytes_io = BytesIO()
148 if fig.canvas is None:
148 if fig.canvas is None:
149 from matplotlib.backend_bases import FigureCanvasBase
149 from matplotlib.backend_bases import FigureCanvasBase
150 FigureCanvasBase(fig)
150 FigureCanvasBase(fig)
151
151
152 fig.canvas.print_figure(bytes_io, **kw)
152 fig.canvas.print_figure(bytes_io, **kw)
153 data = bytes_io.getvalue()
153 data = bytes_io.getvalue()
154 if fmt == 'svg':
154 if fmt == 'svg':
155 data = data.decode('utf-8')
155 data = data.decode('utf-8')
156 elif base64:
156 elif base64:
157 data = b2a_base64(data, newline=False).decode("ascii")
157 data = b2a_base64(data, newline=False).decode("ascii")
158 return data
158 return data
159
159
160 def retina_figure(fig, base64=False, **kwargs):
160 def retina_figure(fig, base64=False, **kwargs):
161 """format a figure as a pixel-doubled (retina) PNG
161 """format a figure as a pixel-doubled (retina) PNG
162
162
163 If `base64` is True, return base64-encoded str instead of raw bytes
163 If `base64` is True, return base64-encoded str instead of raw bytes
164 for binary-encoded image formats
164 for binary-encoded image formats
165
165
166 .. versionadded:: 7.29
166 .. versionadded:: 7.29
167 base64 argument
167 base64 argument
168 """
168 """
169 pngdata = print_figure(fig, fmt="retina", base64=False, **kwargs)
169 pngdata = print_figure(fig, fmt="retina", base64=False, **kwargs)
170 # Make sure that retina_figure acts just like print_figure and returns
170 # Make sure that retina_figure acts just like print_figure and returns
171 # None when the figure is empty.
171 # None when the figure is empty.
172 if pngdata is None:
172 if pngdata is None:
173 return
173 return
174 w, h = _pngxy(pngdata)
174 w, h = _pngxy(pngdata)
175 metadata = {"width": w//2, "height":h//2}
175 metadata = {"width": w//2, "height":h//2}
176 if base64:
176 if base64:
177 pngdata = b2a_base64(pngdata, newline=False).decode("ascii")
177 pngdata = b2a_base64(pngdata, newline=False).decode("ascii")
178 return pngdata, metadata
178 return pngdata, metadata
179
179
180
180
181 # We need a little factory function here to create the closure where
181 # We need a little factory function here to create the closure where
182 # safe_execfile can live.
182 # safe_execfile can live.
183 def mpl_runner(safe_execfile):
183 def mpl_runner(safe_execfile):
184 """Factory to return a matplotlib-enabled runner for %run.
184 """Factory to return a matplotlib-enabled runner for %run.
185
185
186 Parameters
186 Parameters
187 ----------
187 ----------
188 safe_execfile : function
188 safe_execfile : function
189 This must be a function with the same interface as the
189 This must be a function with the same interface as the
190 :meth:`safe_execfile` method of IPython.
190 :meth:`safe_execfile` method of IPython.
191
191
192 Returns
192 Returns
193 -------
193 -------
194 A function suitable for use as the ``runner`` argument of the %run magic
194 A function suitable for use as the ``runner`` argument of the %run magic
195 function.
195 function.
196 """
196 """
197
197
198 def mpl_execfile(fname,*where,**kw):
198 def mpl_execfile(fname,*where,**kw):
199 """matplotlib-aware wrapper around safe_execfile.
199 """matplotlib-aware wrapper around safe_execfile.
200
200
201 Its interface is identical to that of the :func:`execfile` builtin.
201 Its interface is identical to that of the :func:`execfile` builtin.
202
202
203 This is ultimately a call to execfile(), but wrapped in safeties to
203 This is ultimately a call to execfile(), but wrapped in safeties to
204 properly handle interactive rendering."""
204 properly handle interactive rendering."""
205
205
206 import matplotlib
206 import matplotlib
207 import matplotlib.pyplot as plt
207 import matplotlib.pyplot as plt
208
208
209 #print '*** Matplotlib runner ***' # dbg
209 #print '*** Matplotlib runner ***' # dbg
210 # turn off rendering until end of script
210 # turn off rendering until end of script
211 is_interactive = matplotlib.rcParams['interactive']
211 with matplotlib.rc_context({"interactive": False}):
212 matplotlib.interactive(False)
212 safe_execfile(fname, *where, **kw)
213 safe_execfile(fname,*where,**kw)
213
214 matplotlib.interactive(is_interactive)
214 if matplotlib.is_interactive():
215 plt.show()
216
215 # make rendering call now, if the user tried to do it
217 # make rendering call now, if the user tried to do it
216 if plt.draw_if_interactive.called:
218 if plt.draw_if_interactive.called:
217 plt.draw()
219 plt.draw()
218 plt.draw_if_interactive.called = False
220 plt.draw_if_interactive.called = False
219
221
220 # re-draw everything that is stale
222 # re-draw everything that is stale
221 try:
223 try:
222 da = plt.draw_all
224 da = plt.draw_all
223 except AttributeError:
225 except AttributeError:
224 pass
226 pass
225 else:
227 else:
226 da()
228 da()
227
229
228 return mpl_execfile
230 return mpl_execfile
229
231
230
232
231 def _reshow_nbagg_figure(fig):
233 def _reshow_nbagg_figure(fig):
232 """reshow an nbagg figure"""
234 """reshow an nbagg figure"""
233 try:
235 try:
234 reshow = fig.canvas.manager.reshow
236 reshow = fig.canvas.manager.reshow
235 except AttributeError as e:
237 except AttributeError as e:
236 raise NotImplementedError() from e
238 raise NotImplementedError() from e
237 else:
239 else:
238 reshow()
240 reshow()
239
241
240
242
241 def select_figure_formats(shell, formats, **kwargs):
243 def select_figure_formats(shell, formats, **kwargs):
242 """Select figure formats for the inline backend.
244 """Select figure formats for the inline backend.
243
245
244 Parameters
246 Parameters
245 ----------
247 ----------
246 shell : InteractiveShell
248 shell : InteractiveShell
247 The main IPython instance.
249 The main IPython instance.
248 formats : str or set
250 formats : str or set
249 One or a set of figure formats to enable: 'png', 'retina', 'jpeg', 'svg', 'pdf'.
251 One or a set of figure formats to enable: 'png', 'retina', 'jpeg', 'svg', 'pdf'.
250 **kwargs : any
252 **kwargs : any
251 Extra keyword arguments to be passed to fig.canvas.print_figure.
253 Extra keyword arguments to be passed to fig.canvas.print_figure.
252 """
254 """
253 import matplotlib
255 import matplotlib
254 from matplotlib.figure import Figure
256 from matplotlib.figure import Figure
255
257
256 svg_formatter = shell.display_formatter.formatters['image/svg+xml']
258 svg_formatter = shell.display_formatter.formatters['image/svg+xml']
257 png_formatter = shell.display_formatter.formatters['image/png']
259 png_formatter = shell.display_formatter.formatters['image/png']
258 jpg_formatter = shell.display_formatter.formatters['image/jpeg']
260 jpg_formatter = shell.display_formatter.formatters['image/jpeg']
259 pdf_formatter = shell.display_formatter.formatters['application/pdf']
261 pdf_formatter = shell.display_formatter.formatters['application/pdf']
260
262
261 if isinstance(formats, str):
263 if isinstance(formats, str):
262 formats = {formats}
264 formats = {formats}
263 # cast in case of list / tuple
265 # cast in case of list / tuple
264 formats = set(formats)
266 formats = set(formats)
265
267
266 [ f.pop(Figure, None) for f in shell.display_formatter.formatters.values() ]
268 [ f.pop(Figure, None) for f in shell.display_formatter.formatters.values() ]
267 mplbackend = matplotlib.get_backend().lower()
269 mplbackend = matplotlib.get_backend().lower()
268 if mplbackend == 'nbagg' or mplbackend == 'module://ipympl.backend_nbagg':
270 if mplbackend == 'nbagg' or mplbackend == 'module://ipympl.backend_nbagg':
269 formatter = shell.display_formatter.ipython_display_formatter
271 formatter = shell.display_formatter.ipython_display_formatter
270 formatter.for_type(Figure, _reshow_nbagg_figure)
272 formatter.for_type(Figure, _reshow_nbagg_figure)
271
273
272 supported = {'png', 'png2x', 'retina', 'jpg', 'jpeg', 'svg', 'pdf'}
274 supported = {'png', 'png2x', 'retina', 'jpg', 'jpeg', 'svg', 'pdf'}
273 bad = formats.difference(supported)
275 bad = formats.difference(supported)
274 if bad:
276 if bad:
275 bs = "%s" % ','.join([repr(f) for f in bad])
277 bs = "%s" % ','.join([repr(f) for f in bad])
276 gs = "%s" % ','.join([repr(f) for f in supported])
278 gs = "%s" % ','.join([repr(f) for f in supported])
277 raise ValueError("supported formats are: %s not %s" % (gs, bs))
279 raise ValueError("supported formats are: %s not %s" % (gs, bs))
278
280
279 if "png" in formats:
281 if "png" in formats:
280 png_formatter.for_type(
282 png_formatter.for_type(
281 Figure, partial(print_figure, fmt="png", base64=True, **kwargs)
283 Figure, partial(print_figure, fmt="png", base64=True, **kwargs)
282 )
284 )
283 if "retina" in formats or "png2x" in formats:
285 if "retina" in formats or "png2x" in formats:
284 png_formatter.for_type(Figure, partial(retina_figure, base64=True, **kwargs))
286 png_formatter.for_type(Figure, partial(retina_figure, base64=True, **kwargs))
285 if "jpg" in formats or "jpeg" in formats:
287 if "jpg" in formats or "jpeg" in formats:
286 jpg_formatter.for_type(
288 jpg_formatter.for_type(
287 Figure, partial(print_figure, fmt="jpg", base64=True, **kwargs)
289 Figure, partial(print_figure, fmt="jpg", base64=True, **kwargs)
288 )
290 )
289 if "svg" in formats:
291 if "svg" in formats:
290 svg_formatter.for_type(Figure, partial(print_figure, fmt="svg", **kwargs))
292 svg_formatter.for_type(Figure, partial(print_figure, fmt="svg", **kwargs))
291 if "pdf" in formats:
293 if "pdf" in formats:
292 pdf_formatter.for_type(
294 pdf_formatter.for_type(
293 Figure, partial(print_figure, fmt="pdf", base64=True, **kwargs)
295 Figure, partial(print_figure, fmt="pdf", base64=True, **kwargs)
294 )
296 )
295
297
296 #-----------------------------------------------------------------------------
298 #-----------------------------------------------------------------------------
297 # Code for initializing matplotlib and importing pylab
299 # Code for initializing matplotlib and importing pylab
298 #-----------------------------------------------------------------------------
300 #-----------------------------------------------------------------------------
299
301
300
302
301 def find_gui_and_backend(gui=None, gui_select=None):
303 def find_gui_and_backend(gui=None, gui_select=None):
302 """Given a gui string return the gui and mpl backend.
304 """Given a gui string return the gui and mpl backend.
303
305
304 Parameters
306 Parameters
305 ----------
307 ----------
306 gui : str
308 gui : str
307 Can be one of ('tk','gtk','wx','qt','qt4','inline','agg').
309 Can be one of ('tk','gtk','wx','qt','qt4','inline','agg').
308 gui_select : str
310 gui_select : str
309 Can be one of ('tk','gtk','wx','qt','qt4','inline').
311 Can be one of ('tk','gtk','wx','qt','qt4','inline').
310 This is any gui already selected by the shell.
312 This is any gui already selected by the shell.
311
313
312 Returns
314 Returns
313 -------
315 -------
314 A tuple of (gui, backend) where backend is one of ('TkAgg','GTKAgg',
316 A tuple of (gui, backend) where backend is one of ('TkAgg','GTKAgg',
315 'WXAgg','Qt4Agg','module://matplotlib_inline.backend_inline','agg').
317 'WXAgg','Qt4Agg','module://matplotlib_inline.backend_inline','agg').
316 """
318 """
317
319
318 import matplotlib
320 import matplotlib
319
321
322 has_unified_qt_backend = getattr(matplotlib, "__version_info__", (0, 0)) >= (3, 5)
323
324 backends_ = dict(backends)
325 if not has_unified_qt_backend:
326 backends_["qt"] = "qt5agg"
327
320 if gui and gui != 'auto':
328 if gui and gui != 'auto':
321 # select backend based on requested gui
329 # select backend based on requested gui
322 backend = backends[gui]
330 backend = backends_[gui]
323 if gui == 'agg':
331 if gui == 'agg':
324 gui = None
332 gui = None
325 else:
333 else:
326 # We need to read the backend from the original data structure, *not*
334 # We need to read the backend from the original data structure, *not*
327 # from mpl.rcParams, since a prior invocation of %matplotlib may have
335 # from mpl.rcParams, since a prior invocation of %matplotlib may have
328 # overwritten that.
336 # overwritten that.
329 # WARNING: this assumes matplotlib 1.1 or newer!!
337 # WARNING: this assumes matplotlib 1.1 or newer!!
330 backend = matplotlib.rcParamsOrig['backend']
338 backend = matplotlib.rcParamsOrig['backend']
331 # In this case, we need to find what the appropriate gui selection call
339 # In this case, we need to find what the appropriate gui selection call
332 # should be for IPython, so we can activate inputhook accordingly
340 # should be for IPython, so we can activate inputhook accordingly
333 gui = backend2gui.get(backend, None)
341 gui = backend2gui.get(backend, None)
334
342
335 # If we have already had a gui active, we need it and inline are the
343 # If we have already had a gui active, we need it and inline are the
336 # ones allowed.
344 # ones allowed.
337 if gui_select and gui != gui_select:
345 if gui_select and gui != gui_select:
338 gui = gui_select
346 gui = gui_select
339 backend = backends[gui]
347 backend = backends_[gui]
340
348
341 return gui, backend
349 return gui, backend
342
350
343
351
344 def activate_matplotlib(backend):
352 def activate_matplotlib(backend):
345 """Activate the given backend and set interactive to True."""
353 """Activate the given backend and set interactive to True."""
346
354
347 import matplotlib
355 import matplotlib
348 matplotlib.interactive(True)
356 matplotlib.interactive(True)
349
357
350 # Matplotlib had a bug where even switch_backend could not force
358 # Matplotlib had a bug where even switch_backend could not force
351 # the rcParam to update. This needs to be set *before* the module
359 # the rcParam to update. This needs to be set *before* the module
352 # magic of switch_backend().
360 # magic of switch_backend().
353 matplotlib.rcParams['backend'] = backend
361 matplotlib.rcParams['backend'] = backend
354
362
355 # Due to circular imports, pyplot may be only partially initialised
363 # Due to circular imports, pyplot may be only partially initialised
356 # when this function runs.
364 # when this function runs.
357 # So avoid needing matplotlib attribute-lookup to access pyplot.
365 # So avoid needing matplotlib attribute-lookup to access pyplot.
358 from matplotlib import pyplot as plt
366 from matplotlib import pyplot as plt
359
367
360 plt.switch_backend(backend)
368 plt.switch_backend(backend)
361
369
362 plt.show._needmain = False
370 plt.show._needmain = False
363 # We need to detect at runtime whether show() is called by the user.
371 # We need to detect at runtime whether show() is called by the user.
364 # For this, we wrap it into a decorator which adds a 'called' flag.
372 # For this, we wrap it into a decorator which adds a 'called' flag.
365 plt.draw_if_interactive = flag_calls(plt.draw_if_interactive)
373 plt.draw_if_interactive = flag_calls(plt.draw_if_interactive)
366
374
367
375
368 def import_pylab(user_ns, import_all=True):
376 def import_pylab(user_ns, import_all=True):
369 """Populate the namespace with pylab-related values.
377 """Populate the namespace with pylab-related values.
370
378
371 Imports matplotlib, pylab, numpy, and everything from pylab and numpy.
379 Imports matplotlib, pylab, numpy, and everything from pylab and numpy.
372
380
373 Also imports a few names from IPython (figsize, display, getfigs)
381 Also imports a few names from IPython (figsize, display, getfigs)
374
382
375 """
383 """
376
384
377 # Import numpy as np/pyplot as plt are conventions we're trying to
385 # Import numpy as np/pyplot as plt are conventions we're trying to
378 # somewhat standardize on. Making them available to users by default
386 # somewhat standardize on. Making them available to users by default
379 # will greatly help this.
387 # will greatly help this.
380 s = ("import numpy\n"
388 s = ("import numpy\n"
381 "import matplotlib\n"
389 "import matplotlib\n"
382 "from matplotlib import pylab, mlab, pyplot\n"
390 "from matplotlib import pylab, mlab, pyplot\n"
383 "np = numpy\n"
391 "np = numpy\n"
384 "plt = pyplot\n"
392 "plt = pyplot\n"
385 )
393 )
386 exec(s, user_ns)
394 exec(s, user_ns)
387
395
388 if import_all:
396 if import_all:
389 s = ("from matplotlib.pylab import *\n"
397 s = ("from matplotlib.pylab import *\n"
390 "from numpy import *\n")
398 "from numpy import *\n")
391 exec(s, user_ns)
399 exec(s, user_ns)
392
400
393 # IPython symbols to add
401 # IPython symbols to add
394 user_ns['figsize'] = figsize
402 user_ns['figsize'] = figsize
395 from IPython.display import display
403 from IPython.display import display
396 # Add display and getfigs to the user's namespace
404 # Add display and getfigs to the user's namespace
397 user_ns['display'] = display
405 user_ns['display'] = display
398 user_ns['getfigs'] = getfigs
406 user_ns['getfigs'] = getfigs
399
407
400
408
401 def configure_inline_support(shell, backend):
409 def configure_inline_support(shell, backend):
402 """
410 """
403 .. deprecated:: 7.23
411 .. deprecated:: 7.23
404
412
405 use `matplotlib_inline.backend_inline.configure_inline_support()`
413 use `matplotlib_inline.backend_inline.configure_inline_support()`
406
414
407 Configure an IPython shell object for matplotlib use.
415 Configure an IPython shell object for matplotlib use.
408
416
409 Parameters
417 Parameters
410 ----------
418 ----------
411 shell : InteractiveShell instance
419 shell : InteractiveShell instance
412 backend : matplotlib backend
420 backend : matplotlib backend
413 """
421 """
414 warnings.warn(
422 warnings.warn(
415 "`configure_inline_support` is deprecated since IPython 7.23, directly "
423 "`configure_inline_support` is deprecated since IPython 7.23, directly "
416 "use `matplotlib_inline.backend_inline.configure_inline_support()`",
424 "use `matplotlib_inline.backend_inline.configure_inline_support()`",
417 DeprecationWarning,
425 DeprecationWarning,
418 stacklevel=2,
426 stacklevel=2,
419 )
427 )
420
428
421 from matplotlib_inline.backend_inline import (
429 from matplotlib_inline.backend_inline import (
422 configure_inline_support as configure_inline_support_orig,
430 configure_inline_support as configure_inline_support_orig,
423 )
431 )
424
432
425 configure_inline_support_orig(shell, backend)
433 configure_inline_support_orig(shell, backend)
@@ -1,270 +1,270 b''
1 """Tests for pylab tools module.
1 """Tests for pylab tools module.
2 """
2 """
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7
7
8 from binascii import a2b_base64
8 from binascii import a2b_base64
9 from io import BytesIO
9 from io import BytesIO
10
10
11 import pytest
11 import pytest
12
12
13 matplotlib = pytest.importorskip("matplotlib")
13 matplotlib = pytest.importorskip("matplotlib")
14 matplotlib.use('Agg')
14 matplotlib.use('Agg')
15 from matplotlib.figure import Figure
15 from matplotlib.figure import Figure
16
16
17 from matplotlib import pyplot as plt
17 from matplotlib import pyplot as plt
18 from matplotlib_inline import backend_inline
18 from matplotlib_inline import backend_inline
19 import numpy as np
19 import numpy as np
20
20
21 from IPython.core.getipython import get_ipython
21 from IPython.core.getipython import get_ipython
22 from IPython.core.interactiveshell import InteractiveShell
22 from IPython.core.interactiveshell import InteractiveShell
23 from IPython.core.display import _PNG, _JPEG
23 from IPython.core.display import _PNG, _JPEG
24 from .. import pylabtools as pt
24 from .. import pylabtools as pt
25
25
26 from IPython.testing import decorators as dec
26 from IPython.testing import decorators as dec
27
27
28
28
29 def test_figure_to_svg():
29 def test_figure_to_svg():
30 # simple empty-figure test
30 # simple empty-figure test
31 fig = plt.figure()
31 fig = plt.figure()
32 assert pt.print_figure(fig, "svg") is None
32 assert pt.print_figure(fig, "svg") is None
33
33
34 plt.close('all')
34 plt.close('all')
35
35
36 # simple check for at least svg-looking output
36 # simple check for at least svg-looking output
37 fig = plt.figure()
37 fig = plt.figure()
38 ax = fig.add_subplot(1,1,1)
38 ax = fig.add_subplot(1,1,1)
39 ax.plot([1,2,3])
39 ax.plot([1,2,3])
40 plt.draw()
40 plt.draw()
41 svg = pt.print_figure(fig, "svg")[:100].lower()
41 svg = pt.print_figure(fig, "svg")[:100].lower()
42 assert "doctype svg" in svg
42 assert "doctype svg" in svg
43
43
44
44
45 def _check_pil_jpeg_bytes():
45 def _check_pil_jpeg_bytes():
46 """Skip if PIL can't write JPEGs to BytesIO objects"""
46 """Skip if PIL can't write JPEGs to BytesIO objects"""
47 # PIL's JPEG plugin can't write to BytesIO objects
47 # PIL's JPEG plugin can't write to BytesIO objects
48 # Pillow fixes this
48 # Pillow fixes this
49 from PIL import Image
49 from PIL import Image
50 buf = BytesIO()
50 buf = BytesIO()
51 img = Image.new("RGB", (4,4))
51 img = Image.new("RGB", (4,4))
52 try:
52 try:
53 img.save(buf, 'jpeg')
53 img.save(buf, 'jpeg')
54 except Exception as e:
54 except Exception as e:
55 ename = e.__class__.__name__
55 ename = e.__class__.__name__
56 raise pytest.skip("PIL can't write JPEG to BytesIO: %s: %s" % (ename, e)) from e
56 raise pytest.skip("PIL can't write JPEG to BytesIO: %s: %s" % (ename, e)) from e
57
57
58 @dec.skip_without("PIL.Image")
58 @dec.skip_without("PIL.Image")
59 def test_figure_to_jpeg():
59 def test_figure_to_jpeg():
60 _check_pil_jpeg_bytes()
60 _check_pil_jpeg_bytes()
61 # simple check for at least jpeg-looking output
61 # simple check for at least jpeg-looking output
62 fig = plt.figure()
62 fig = plt.figure()
63 ax = fig.add_subplot(1,1,1)
63 ax = fig.add_subplot(1,1,1)
64 ax.plot([1,2,3])
64 ax.plot([1,2,3])
65 plt.draw()
65 plt.draw()
66 jpeg = pt.print_figure(fig, 'jpeg', pil_kwargs={'optimize': 50})[:100].lower()
66 jpeg = pt.print_figure(fig, 'jpeg', pil_kwargs={'optimize': 50})[:100].lower()
67 assert jpeg.startswith(_JPEG)
67 assert jpeg.startswith(_JPEG)
68
68
69 def test_retina_figure():
69 def test_retina_figure():
70 # simple empty-figure test
70 # simple empty-figure test
71 fig = plt.figure()
71 fig = plt.figure()
72 assert pt.retina_figure(fig) == None
72 assert pt.retina_figure(fig) == None
73 plt.close('all')
73 plt.close('all')
74
74
75 fig = plt.figure()
75 fig = plt.figure()
76 ax = fig.add_subplot(1,1,1)
76 ax = fig.add_subplot(1,1,1)
77 ax.plot([1,2,3])
77 ax.plot([1,2,3])
78 plt.draw()
78 plt.draw()
79 png, md = pt.retina_figure(fig)
79 png, md = pt.retina_figure(fig)
80 assert png.startswith(_PNG)
80 assert png.startswith(_PNG)
81 assert "width" in md
81 assert "width" in md
82 assert "height" in md
82 assert "height" in md
83
83
84
84
85 _fmt_mime_map = {
85 _fmt_mime_map = {
86 'png': 'image/png',
86 'png': 'image/png',
87 'jpeg': 'image/jpeg',
87 'jpeg': 'image/jpeg',
88 'pdf': 'application/pdf',
88 'pdf': 'application/pdf',
89 'retina': 'image/png',
89 'retina': 'image/png',
90 'svg': 'image/svg+xml',
90 'svg': 'image/svg+xml',
91 }
91 }
92
92
93 def test_select_figure_formats_str():
93 def test_select_figure_formats_str():
94 ip = get_ipython()
94 ip = get_ipython()
95 for fmt, active_mime in _fmt_mime_map.items():
95 for fmt, active_mime in _fmt_mime_map.items():
96 pt.select_figure_formats(ip, fmt)
96 pt.select_figure_formats(ip, fmt)
97 for mime, f in ip.display_formatter.formatters.items():
97 for mime, f in ip.display_formatter.formatters.items():
98 if mime == active_mime:
98 if mime == active_mime:
99 assert Figure in f
99 assert Figure in f
100 else:
100 else:
101 assert Figure not in f
101 assert Figure not in f
102
102
103 def test_select_figure_formats_kwargs():
103 def test_select_figure_formats_kwargs():
104 ip = get_ipython()
104 ip = get_ipython()
105 kwargs = dict(bbox_inches="tight")
105 kwargs = dict(bbox_inches="tight")
106 pt.select_figure_formats(ip, "png", **kwargs)
106 pt.select_figure_formats(ip, "png", **kwargs)
107 formatter = ip.display_formatter.formatters["image/png"]
107 formatter = ip.display_formatter.formatters["image/png"]
108 f = formatter.lookup_by_type(Figure)
108 f = formatter.lookup_by_type(Figure)
109 cell = f.keywords
109 cell = f.keywords
110 expected = kwargs
110 expected = kwargs
111 expected["base64"] = True
111 expected["base64"] = True
112 expected["fmt"] = "png"
112 expected["fmt"] = "png"
113 assert cell == expected
113 assert cell == expected
114
114
115 # check that the formatter doesn't raise
115 # check that the formatter doesn't raise
116 fig = plt.figure()
116 fig = plt.figure()
117 ax = fig.add_subplot(1,1,1)
117 ax = fig.add_subplot(1,1,1)
118 ax.plot([1,2,3])
118 ax.plot([1,2,3])
119 plt.draw()
119 plt.draw()
120 formatter.enabled = True
120 formatter.enabled = True
121 png = formatter(fig)
121 png = formatter(fig)
122 assert isinstance(png, str)
122 assert isinstance(png, str)
123 png_bytes = a2b_base64(png)
123 png_bytes = a2b_base64(png)
124 assert png_bytes.startswith(_PNG)
124 assert png_bytes.startswith(_PNG)
125
125
126 def test_select_figure_formats_set():
126 def test_select_figure_formats_set():
127 ip = get_ipython()
127 ip = get_ipython()
128 for fmts in [
128 for fmts in [
129 {'png', 'svg'},
129 {'png', 'svg'},
130 ['png'],
130 ['png'],
131 ('jpeg', 'pdf', 'retina'),
131 ('jpeg', 'pdf', 'retina'),
132 {'svg'},
132 {'svg'},
133 ]:
133 ]:
134 active_mimes = {_fmt_mime_map[fmt] for fmt in fmts}
134 active_mimes = {_fmt_mime_map[fmt] for fmt in fmts}
135 pt.select_figure_formats(ip, fmts)
135 pt.select_figure_formats(ip, fmts)
136 for mime, f in ip.display_formatter.formatters.items():
136 for mime, f in ip.display_formatter.formatters.items():
137 if mime in active_mimes:
137 if mime in active_mimes:
138 assert Figure in f
138 assert Figure in f
139 else:
139 else:
140 assert Figure not in f
140 assert Figure not in f
141
141
142 def test_select_figure_formats_bad():
142 def test_select_figure_formats_bad():
143 ip = get_ipython()
143 ip = get_ipython()
144 with pytest.raises(ValueError):
144 with pytest.raises(ValueError):
145 pt.select_figure_formats(ip, 'foo')
145 pt.select_figure_formats(ip, 'foo')
146 with pytest.raises(ValueError):
146 with pytest.raises(ValueError):
147 pt.select_figure_formats(ip, {'png', 'foo'})
147 pt.select_figure_formats(ip, {'png', 'foo'})
148 with pytest.raises(ValueError):
148 with pytest.raises(ValueError):
149 pt.select_figure_formats(ip, ['retina', 'pdf', 'bar', 'bad'])
149 pt.select_figure_formats(ip, ['retina', 'pdf', 'bar', 'bad'])
150
150
151 def test_import_pylab():
151 def test_import_pylab():
152 ns = {}
152 ns = {}
153 pt.import_pylab(ns, import_all=False)
153 pt.import_pylab(ns, import_all=False)
154 assert "plt" in ns
154 assert "plt" in ns
155 assert ns["np"] == np
155 assert ns["np"] == np
156
156
157
157
158 class TestPylabSwitch(object):
158 class TestPylabSwitch(object):
159 class Shell(InteractiveShell):
159 class Shell(InteractiveShell):
160 def init_history(self):
160 def init_history(self):
161 """Sets up the command history, and starts regular autosaves."""
161 """Sets up the command history, and starts regular autosaves."""
162 self.config.HistoryManager.hist_file = ":memory:"
162 self.config.HistoryManager.hist_file = ":memory:"
163 super().init_history()
163 super().init_history()
164
164
165 def enable_gui(self, gui):
165 def enable_gui(self, gui):
166 pass
166 pass
167
167
168 def setup(self):
168 def setup(self):
169 import matplotlib
169 import matplotlib
170 def act_mpl(backend):
170 def act_mpl(backend):
171 matplotlib.rcParams['backend'] = backend
171 matplotlib.rcParams['backend'] = backend
172
172
173 # Save rcParams since they get modified
173 # Save rcParams since they get modified
174 self._saved_rcParams = matplotlib.rcParams
174 self._saved_rcParams = matplotlib.rcParams
175 self._saved_rcParamsOrig = matplotlib.rcParamsOrig
175 self._saved_rcParamsOrig = matplotlib.rcParamsOrig
176 matplotlib.rcParams = dict(backend='Qt4Agg')
176 matplotlib.rcParams = dict(backend="QtAgg")
177 matplotlib.rcParamsOrig = dict(backend='Qt4Agg')
177 matplotlib.rcParamsOrig = dict(backend="QtAgg")
178
178
179 # Mock out functions
179 # Mock out functions
180 self._save_am = pt.activate_matplotlib
180 self._save_am = pt.activate_matplotlib
181 pt.activate_matplotlib = act_mpl
181 pt.activate_matplotlib = act_mpl
182 self._save_ip = pt.import_pylab
182 self._save_ip = pt.import_pylab
183 pt.import_pylab = lambda *a,**kw:None
183 pt.import_pylab = lambda *a,**kw:None
184 self._save_cis = backend_inline.configure_inline_support
184 self._save_cis = backend_inline.configure_inline_support
185 backend_inline.configure_inline_support = lambda *a, **kw: None
185 backend_inline.configure_inline_support = lambda *a, **kw: None
186
186
187 def teardown(self):
187 def teardown(self):
188 pt.activate_matplotlib = self._save_am
188 pt.activate_matplotlib = self._save_am
189 pt.import_pylab = self._save_ip
189 pt.import_pylab = self._save_ip
190 backend_inline.configure_inline_support = self._save_cis
190 backend_inline.configure_inline_support = self._save_cis
191 import matplotlib
191 import matplotlib
192 matplotlib.rcParams = self._saved_rcParams
192 matplotlib.rcParams = self._saved_rcParams
193 matplotlib.rcParamsOrig = self._saved_rcParamsOrig
193 matplotlib.rcParamsOrig = self._saved_rcParamsOrig
194
194
195 def test_qt(self):
195 def test_qt(self):
196 s = self.Shell()
196 s = self.Shell()
197 gui, backend = s.enable_matplotlib(None)
197 gui, backend = s.enable_matplotlib(None)
198 assert gui == "qt"
198 assert gui == "qt"
199 assert s.pylab_gui_select == "qt"
199 assert s.pylab_gui_select == "qt"
200
200
201 gui, backend = s.enable_matplotlib("inline")
201 gui, backend = s.enable_matplotlib("inline")
202 assert gui == "inline"
202 assert gui == "inline"
203 assert s.pylab_gui_select == "qt"
203 assert s.pylab_gui_select == "qt"
204
204
205 gui, backend = s.enable_matplotlib("qt")
205 gui, backend = s.enable_matplotlib("qt")
206 assert gui == "qt"
206 assert gui == "qt"
207 assert s.pylab_gui_select == "qt"
207 assert s.pylab_gui_select == "qt"
208
208
209 gui, backend = s.enable_matplotlib("inline")
209 gui, backend = s.enable_matplotlib("inline")
210 assert gui == "inline"
210 assert gui == "inline"
211 assert s.pylab_gui_select == "qt"
211 assert s.pylab_gui_select == "qt"
212
212
213 gui, backend = s.enable_matplotlib()
213 gui, backend = s.enable_matplotlib()
214 assert gui == "qt"
214 assert gui == "qt"
215 assert s.pylab_gui_select == "qt"
215 assert s.pylab_gui_select == "qt"
216
216
217 def test_inline(self):
217 def test_inline(self):
218 s = self.Shell()
218 s = self.Shell()
219 gui, backend = s.enable_matplotlib("inline")
219 gui, backend = s.enable_matplotlib("inline")
220 assert gui == "inline"
220 assert gui == "inline"
221 assert s.pylab_gui_select == None
221 assert s.pylab_gui_select == None
222
222
223 gui, backend = s.enable_matplotlib("inline")
223 gui, backend = s.enable_matplotlib("inline")
224 assert gui == "inline"
224 assert gui == "inline"
225 assert s.pylab_gui_select == None
225 assert s.pylab_gui_select == None
226
226
227 gui, backend = s.enable_matplotlib("qt")
227 gui, backend = s.enable_matplotlib("qt")
228 assert gui == "qt"
228 assert gui == "qt"
229 assert s.pylab_gui_select == "qt"
229 assert s.pylab_gui_select == "qt"
230
230
231 def test_inline_twice(self):
231 def test_inline_twice(self):
232 "Using '%matplotlib inline' twice should not reset formatters"
232 "Using '%matplotlib inline' twice should not reset formatters"
233
233
234 ip = self.Shell()
234 ip = self.Shell()
235 gui, backend = ip.enable_matplotlib("inline")
235 gui, backend = ip.enable_matplotlib("inline")
236 assert gui == "inline"
236 assert gui == "inline"
237
237
238 fmts = {'png'}
238 fmts = {'png'}
239 active_mimes = {_fmt_mime_map[fmt] for fmt in fmts}
239 active_mimes = {_fmt_mime_map[fmt] for fmt in fmts}
240 pt.select_figure_formats(ip, fmts)
240 pt.select_figure_formats(ip, fmts)
241
241
242 gui, backend = ip.enable_matplotlib("inline")
242 gui, backend = ip.enable_matplotlib("inline")
243 assert gui == "inline"
243 assert gui == "inline"
244
244
245 for mime, f in ip.display_formatter.formatters.items():
245 for mime, f in ip.display_formatter.formatters.items():
246 if mime in active_mimes:
246 if mime in active_mimes:
247 assert Figure in f
247 assert Figure in f
248 else:
248 else:
249 assert Figure not in f
249 assert Figure not in f
250
250
251 def test_qt_gtk(self):
251 def test_qt_gtk(self):
252 s = self.Shell()
252 s = self.Shell()
253 gui, backend = s.enable_matplotlib("qt")
253 gui, backend = s.enable_matplotlib("qt")
254 assert gui == "qt"
254 assert gui == "qt"
255 assert s.pylab_gui_select == "qt"
255 assert s.pylab_gui_select == "qt"
256
256
257 gui, backend = s.enable_matplotlib("gtk")
257 gui, backend = s.enable_matplotlib("gtk")
258 assert gui == "qt"
258 assert gui == "qt"
259 assert s.pylab_gui_select == "qt"
259 assert s.pylab_gui_select == "qt"
260
260
261
261
262 def test_no_gui_backends():
262 def test_no_gui_backends():
263 for k in ['agg', 'svg', 'pdf', 'ps']:
263 for k in ['agg', 'svg', 'pdf', 'ps']:
264 assert k not in pt.backend2gui
264 assert k not in pt.backend2gui
265
265
266
266
267 def test_figure_no_canvas():
267 def test_figure_no_canvas():
268 fig = Figure()
268 fig = Figure()
269 fig.canvas = None
269 fig.canvas = None
270 pt.print_figure(fig)
270 pt.print_figure(fig)
General Comments 0
You need to be logged in to leave comments. Login now