##// END OF EJS Templates
Rename callbacks -> events (mostly), fire -> trigger
Thomas Kluyver -
Show More

The requested changes are too big and content was truncated. Show full diff

@@ -1,129 +1,129 b''
1 """Infrastructure for registering and firing callbacks.
1 """Infrastructure for registering and firing callbacks on application events.
2 2
3 3 Unlike :mod:`IPython.core.hooks`, which lets end users set single functions to
4 4 be called at specific times, or a collection of alternative methods to try,
5 5 callbacks are designed to be used by extension authors. A number of callbacks
6 6 can be registered for the same event without needing to be aware of one another.
7 7
8 8 The functions defined in this module are no-ops indicating the names of available
9 9 events and the arguments which will be passed to them.
10 10 """
11 11 from __future__ import print_function
12 12
13 class CallbackManager(object):
13 class EventManager(object):
14 14 """Manage a collection of events and a sequence of callbacks for each.
15 15
16 16 This is attached to :class:`~IPython.core.interactiveshell.InteractiveShell`
17 17 instances as a ``callbacks`` attribute.
18 18 """
19 def __init__(self, shell, available_callbacks):
19 def __init__(self, shell, available_events):
20 20 """Initialise the :class:`CallbackManager`.
21 21
22 22 Parameters
23 23 ----------
24 24 shell
25 25 The :class:`~IPython.core.interactiveshell.InteractiveShell` instance
26 26 available_callbacks
27 27 An iterable of names for callback events.
28 28 """
29 29 self.shell = shell
30 self.callbacks = {n:[] for n in available_callbacks}
30 self.callbacks = {n:[] for n in available_events}
31 31
32 def register(self, name, function):
32 def register(self, event, function):
33 33 """Register a new callback
34 34
35 35 Parameters
36 36 ----------
37 name : str
37 event : str
38 38 The event for which to register this callback.
39 39 function : callable
40 40 A function to be called on the given event. It should take the same
41 41 parameters as the appropriate callback prototype.
42 42
43 43 Raises
44 44 ------
45 45 TypeError
46 46 If ``function`` is not callable.
47 47 KeyError
48 If ``name`` is not one of the known callback events.
48 If ``event`` is not one of the known events.
49 49 """
50 50 if not callable(function):
51 51 raise TypeError('Need a callable, got %r' % function)
52 self.callbacks[name].append(function)
52 self.callbacks[event].append(function)
53 53
54 def unregister(self, name, function):
54 def unregister(self, event, function):
55 55 """Remove a callback from the given event."""
56 self.callbacks[name].remove(function)
56 self.callbacks[event].remove(function)
57 57
58 def reset(self, name):
58 def reset(self, event):
59 59 """Clear all callbacks for the given event."""
60 self.callbacks[name] = []
60 self.callbacks[event] = []
61 61
62 62 def reset_all(self):
63 63 """Clear all callbacks for all events."""
64 64 self.callbacks = {n:[] for n in self.callbacks}
65 65
66 def fire(self, name, *args, **kwargs):
67 """Call callbacks for the event ``name``.
66 def trigger(self, event, *args, **kwargs):
67 """Call callbacks for ``event``.
68 68
69 69 Any additional arguments are passed to all callbacks registered for this
70 70 event. Exceptions raised by callbacks are caught, and a message printed.
71 71 """
72 for func in self.callbacks[name]:
72 for func in self.callbacks[event]:
73 73 try:
74 74 func(*args, **kwargs)
75 75 except Exception:
76 print("Error in callback {} (for {}):".format(func, name))
76 print("Error in callback {} (for {}):".format(func, event))
77 77 self.shell.showtraceback()
78 78
79 79 # event_name -> prototype mapping
80 available_callbacks = {}
80 available_events = {}
81 81
82 82 def _collect(callback_proto):
83 available_callbacks[callback_proto.__name__] = callback_proto
83 available_events[callback_proto.__name__] = callback_proto
84 84 return callback_proto
85 85
86 86 # ------------------------------------------------------------------------------
87 87 # Callback prototypes
88 88 #
89 89 # No-op functions which describe the names of available events and the
90 90 # signatures of callbacks for those events.
91 91 # ------------------------------------------------------------------------------
92 92
93 93 @_collect
94 94 def pre_execute():
95 95 """Fires before code is executed in response to user/frontend action.
96 96
97 97 This includes comm and widget messages as well as user code cells."""
98 98 pass
99 99
100 100 @_collect
101 101 def pre_execute_explicit():
102 102 """Fires before user-entered code runs."""
103 103 pass
104 104
105 105 @_collect
106 106 def post_execute():
107 107 """Fires after code is executed in response to user/frontend action.
108 108
109 109 This includes comm and widget messages as well as user code cells."""
110 110 pass
111 111
112 112 @_collect
113 113 def post_execute_explicit():
114 114 """Fires after user-entered code runs."""
115 115 pass
116 116
117 117 @_collect
118 118 def shell_initialised(ip):
119 119 """Fires after initialisation of :class:`~IPython.core.interactiveshell.InteractiveShell`.
120 120
121 121 This is before extensions and startup scripts are loaded, so it can only be
122 122 set by subclassing.
123 123
124 124 Parameters
125 125 ----------
126 126 ip : :class:`~IPython.core.interactiveshell.InteractiveShell`
127 127 The newly initialised shell.
128 128 """
129 129 pass
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
@@ -1,46 +1,46 b''
1 1 import unittest
2 2 try: # Python 3.3 +
3 3 from unittest.mock import Mock
4 4 except ImportError:
5 5 from mock import Mock
6 6
7 from IPython.core import callbacks
7 from IPython.core import events
8 8 import IPython.testing.tools as tt
9 9
10 10 def ping_received():
11 11 pass
12 12
13 13 class CallbackTests(unittest.TestCase):
14 14 def setUp(self):
15 self.cbm = callbacks.CallbackManager(get_ipython(), {'ping_received': ping_received})
15 self.em = events.EventManager(get_ipython(), {'ping_received': ping_received})
16 16
17 17 def test_register_unregister(self):
18 18 cb = Mock()
19 19
20 self.cbm.register('ping_received', cb)
21 self.cbm.fire('ping_received')
20 self.em.register('ping_received', cb)
21 self.em.trigger('ping_received')
22 22 self.assertEqual(cb.call_count, 1)
23 23
24 self.cbm.unregister('ping_received', cb)
25 self.cbm.fire('ping_received')
24 self.em.unregister('ping_received', cb)
25 self.em.trigger('ping_received')
26 26 self.assertEqual(cb.call_count, 1)
27 27
28 28 def test_reset(self):
29 29 cb = Mock()
30 self.cbm.register('ping_received', cb)
31 self.cbm.reset('ping_received')
32 self.cbm.fire('ping_received')
30 self.em.register('ping_received', cb)
31 self.em.reset('ping_received')
32 self.em.trigger('ping_received')
33 33 assert not cb.called
34 34
35 35 def test_reset_all(self):
36 36 cb = Mock()
37 self.cbm.register('ping_received', cb)
38 self.cbm.reset_all()
39 self.cbm.fire('ping_received')
37 self.em.register('ping_received', cb)
38 self.em.reset_all()
39 self.em.trigger('ping_received')
40 40 assert not cb.called
41 41
42 42 def test_cb_error(self):
43 43 cb = Mock(side_effect=ValueError)
44 self.cbm.register('ping_received', cb)
44 self.em.register('ping_received', cb)
45 45 with tt.AssertPrints("Error in callback"):
46 self.cbm.fire('ping_received') No newline at end of file
46 self.em.trigger('ping_received') No newline at end of file
@@ -1,730 +1,730 b''
1 1 # -*- coding: utf-8 -*-
2 2 """Tests for the key interactiveshell module.
3 3
4 4 Historically the main classes in interactiveshell have been under-tested. This
5 5 module should grow as many single-method tests as possible to trap many of the
6 6 recurring bugs we seem to encounter with high-level interaction.
7 7
8 8 Authors
9 9 -------
10 10 * Fernando Perez
11 11 """
12 12 #-----------------------------------------------------------------------------
13 13 # Copyright (C) 2011 The IPython Development Team
14 14 #
15 15 # Distributed under the terms of the BSD License. The full license is in
16 16 # the file COPYING, distributed as part of this software.
17 17 #-----------------------------------------------------------------------------
18 18
19 19 #-----------------------------------------------------------------------------
20 20 # Imports
21 21 #-----------------------------------------------------------------------------
22 22 # stdlib
23 23 import ast
24 24 import os
25 25 import signal
26 26 import shutil
27 27 import sys
28 28 import tempfile
29 29 import unittest
30 30 try:
31 31 from unittest import mock
32 32 except ImportError:
33 33 import mock
34 34 from os.path import join
35 35
36 36 # third-party
37 37 import nose.tools as nt
38 38
39 39 # Our own
40 40 from IPython.core.inputtransformer import InputTransformer
41 41 from IPython.testing.decorators import skipif, skip_win32, onlyif_unicode_paths
42 42 from IPython.testing import tools as tt
43 43 from IPython.utils import io
44 44 from IPython.utils import py3compat
45 45 from IPython.utils.py3compat import unicode_type, PY3
46 46
47 47 if PY3:
48 48 from io import StringIO
49 49 else:
50 50 from StringIO import StringIO
51 51
52 52 #-----------------------------------------------------------------------------
53 53 # Globals
54 54 #-----------------------------------------------------------------------------
55 55 # This is used by every single test, no point repeating it ad nauseam
56 56 ip = get_ipython()
57 57
58 58 #-----------------------------------------------------------------------------
59 59 # Tests
60 60 #-----------------------------------------------------------------------------
61 61
62 62 class InteractiveShellTestCase(unittest.TestCase):
63 63 def test_naked_string_cells(self):
64 64 """Test that cells with only naked strings are fully executed"""
65 65 # First, single-line inputs
66 66 ip.run_cell('"a"\n')
67 67 self.assertEqual(ip.user_ns['_'], 'a')
68 68 # And also multi-line cells
69 69 ip.run_cell('"""a\nb"""\n')
70 70 self.assertEqual(ip.user_ns['_'], 'a\nb')
71 71
72 72 def test_run_empty_cell(self):
73 73 """Just make sure we don't get a horrible error with a blank
74 74 cell of input. Yes, I did overlook that."""
75 75 old_xc = ip.execution_count
76 76 ip.run_cell('')
77 77 self.assertEqual(ip.execution_count, old_xc)
78 78
79 79 def test_run_cell_multiline(self):
80 80 """Multi-block, multi-line cells must execute correctly.
81 81 """
82 82 src = '\n'.join(["x=1",
83 83 "y=2",
84 84 "if 1:",
85 85 " x += 1",
86 86 " y += 1",])
87 87 ip.run_cell(src)
88 88 self.assertEqual(ip.user_ns['x'], 2)
89 89 self.assertEqual(ip.user_ns['y'], 3)
90 90
91 91 def test_multiline_string_cells(self):
92 92 "Code sprinkled with multiline strings should execute (GH-306)"
93 93 ip.run_cell('tmp=0')
94 94 self.assertEqual(ip.user_ns['tmp'], 0)
95 95 ip.run_cell('tmp=1;"""a\nb"""\n')
96 96 self.assertEqual(ip.user_ns['tmp'], 1)
97 97
98 98 def test_dont_cache_with_semicolon(self):
99 99 "Ending a line with semicolon should not cache the returned object (GH-307)"
100 100 oldlen = len(ip.user_ns['Out'])
101 101 a = ip.run_cell('1;', store_history=True)
102 102 newlen = len(ip.user_ns['Out'])
103 103 self.assertEqual(oldlen, newlen)
104 104 #also test the default caching behavior
105 105 ip.run_cell('1', store_history=True)
106 106 newlen = len(ip.user_ns['Out'])
107 107 self.assertEqual(oldlen+1, newlen)
108 108
109 109 def test_In_variable(self):
110 110 "Verify that In variable grows with user input (GH-284)"
111 111 oldlen = len(ip.user_ns['In'])
112 112 ip.run_cell('1;', store_history=True)
113 113 newlen = len(ip.user_ns['In'])
114 114 self.assertEqual(oldlen+1, newlen)
115 115 self.assertEqual(ip.user_ns['In'][-1],'1;')
116 116
117 117 def test_magic_names_in_string(self):
118 118 ip.run_cell('a = """\n%exit\n"""')
119 119 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
120 120
121 121 def test_trailing_newline(self):
122 122 """test that running !(command) does not raise a SyntaxError"""
123 123 ip.run_cell('!(true)\n', False)
124 124 ip.run_cell('!(true)\n\n\n', False)
125 125
126 126 def test_gh_597(self):
127 127 """Pretty-printing lists of objects with non-ascii reprs may cause
128 128 problems."""
129 129 class Spam(object):
130 130 def __repr__(self):
131 131 return "\xe9"*50
132 132 import IPython.core.formatters
133 133 f = IPython.core.formatters.PlainTextFormatter()
134 134 f([Spam(),Spam()])
135 135
136 136
137 137 def test_future_flags(self):
138 138 """Check that future flags are used for parsing code (gh-777)"""
139 139 ip.run_cell('from __future__ import print_function')
140 140 try:
141 141 ip.run_cell('prfunc_return_val = print(1,2, sep=" ")')
142 142 assert 'prfunc_return_val' in ip.user_ns
143 143 finally:
144 144 # Reset compiler flags so we don't mess up other tests.
145 145 ip.compile.reset_compiler_flags()
146 146
147 147 def test_future_unicode(self):
148 148 """Check that unicode_literals is imported from __future__ (gh #786)"""
149 149 try:
150 150 ip.run_cell(u'byte_str = "a"')
151 151 assert isinstance(ip.user_ns['byte_str'], str) # string literals are byte strings by default
152 152 ip.run_cell('from __future__ import unicode_literals')
153 153 ip.run_cell(u'unicode_str = "a"')
154 154 assert isinstance(ip.user_ns['unicode_str'], unicode_type) # strings literals are now unicode
155 155 finally:
156 156 # Reset compiler flags so we don't mess up other tests.
157 157 ip.compile.reset_compiler_flags()
158 158
159 159 def test_can_pickle(self):
160 160 "Can we pickle objects defined interactively (GH-29)"
161 161 ip = get_ipython()
162 162 ip.reset()
163 163 ip.run_cell(("class Mylist(list):\n"
164 164 " def __init__(self,x=[]):\n"
165 165 " list.__init__(self,x)"))
166 166 ip.run_cell("w=Mylist([1,2,3])")
167 167
168 168 from pickle import dumps
169 169
170 170 # We need to swap in our main module - this is only necessary
171 171 # inside the test framework, because IPython puts the interactive module
172 172 # in place (but the test framework undoes this).
173 173 _main = sys.modules['__main__']
174 174 sys.modules['__main__'] = ip.user_module
175 175 try:
176 176 res = dumps(ip.user_ns["w"])
177 177 finally:
178 178 sys.modules['__main__'] = _main
179 179 self.assertTrue(isinstance(res, bytes))
180 180
181 181 def test_global_ns(self):
182 182 "Code in functions must be able to access variables outside them."
183 183 ip = get_ipython()
184 184 ip.run_cell("a = 10")
185 185 ip.run_cell(("def f(x):\n"
186 186 " return x + a"))
187 187 ip.run_cell("b = f(12)")
188 188 self.assertEqual(ip.user_ns["b"], 22)
189 189
190 190 def test_bad_custom_tb(self):
191 191 """Check that InteractiveShell is protected from bad custom exception handlers"""
192 192 from IPython.utils import io
193 193 save_stderr = io.stderr
194 194 try:
195 195 # capture stderr
196 196 io.stderr = StringIO()
197 197 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
198 198 self.assertEqual(ip.custom_exceptions, (IOError,))
199 199 ip.run_cell(u'raise IOError("foo")')
200 200 self.assertEqual(ip.custom_exceptions, ())
201 201 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
202 202 finally:
203 203 io.stderr = save_stderr
204 204
205 205 def test_bad_custom_tb_return(self):
206 206 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
207 207 from IPython.utils import io
208 208 save_stderr = io.stderr
209 209 try:
210 210 # capture stderr
211 211 io.stderr = StringIO()
212 212 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
213 213 self.assertEqual(ip.custom_exceptions, (NameError,))
214 214 ip.run_cell(u'a=abracadabra')
215 215 self.assertEqual(ip.custom_exceptions, ())
216 216 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
217 217 finally:
218 218 io.stderr = save_stderr
219 219
220 220 def test_drop_by_id(self):
221 221 myvars = {"a":object(), "b":object(), "c": object()}
222 222 ip.push(myvars, interactive=False)
223 223 for name in myvars:
224 224 assert name in ip.user_ns, name
225 225 assert name in ip.user_ns_hidden, name
226 226 ip.user_ns['b'] = 12
227 227 ip.drop_by_id(myvars)
228 228 for name in ["a", "c"]:
229 229 assert name not in ip.user_ns, name
230 230 assert name not in ip.user_ns_hidden, name
231 231 assert ip.user_ns['b'] == 12
232 232 ip.reset()
233 233
234 234 def test_var_expand(self):
235 235 ip.user_ns['f'] = u'Ca\xf1o'
236 236 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
237 237 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
238 238 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
239 239 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
240 240
241 241 ip.user_ns['f'] = b'Ca\xc3\xb1o'
242 242 # This should not raise any exception:
243 243 ip.var_expand(u'echo $f')
244 244
245 245 def test_var_expand_local(self):
246 246 """Test local variable expansion in !system and %magic calls"""
247 247 # !system
248 248 ip.run_cell('def test():\n'
249 249 ' lvar = "ttt"\n'
250 250 ' ret = !echo {lvar}\n'
251 251 ' return ret[0]\n')
252 252 res = ip.user_ns['test']()
253 253 nt.assert_in('ttt', res)
254 254
255 255 # %magic
256 256 ip.run_cell('def makemacro():\n'
257 257 ' macroname = "macro_var_expand_locals"\n'
258 258 ' %macro {macroname} codestr\n')
259 259 ip.user_ns['codestr'] = "str(12)"
260 260 ip.run_cell('makemacro()')
261 261 nt.assert_in('macro_var_expand_locals', ip.user_ns)
262 262
263 263 def test_var_expand_self(self):
264 264 """Test variable expansion with the name 'self', which was failing.
265 265
266 266 See https://github.com/ipython/ipython/issues/1878#issuecomment-7698218
267 267 """
268 268 ip.run_cell('class cTest:\n'
269 269 ' classvar="see me"\n'
270 270 ' def test(self):\n'
271 271 ' res = !echo Variable: {self.classvar}\n'
272 272 ' return res[0]\n')
273 273 nt.assert_in('see me', ip.user_ns['cTest']().test())
274 274
275 275 def test_bad_var_expand(self):
276 276 """var_expand on invalid formats shouldn't raise"""
277 277 # SyntaxError
278 278 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
279 279 # NameError
280 280 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
281 281 # ZeroDivisionError
282 282 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
283 283
284 284 def test_silent_postexec(self):
285 285 """run_cell(silent=True) doesn't invoke pre/post_execute_explicit callbacks"""
286 286 pre_explicit = mock.Mock()
287 287 pre_always = mock.Mock()
288 288 post_explicit = mock.Mock()
289 289 post_always = mock.Mock()
290 290
291 ip.callbacks.register('pre_execute_explicit', pre_explicit)
292 ip.callbacks.register('pre_execute', pre_always)
293 ip.callbacks.register('post_execute_explicit', post_explicit)
294 ip.callbacks.register('post_execute', post_always)
291 ip.events.register('pre_execute_explicit', pre_explicit)
292 ip.events.register('pre_execute', pre_always)
293 ip.events.register('post_execute_explicit', post_explicit)
294 ip.events.register('post_execute', post_always)
295 295
296 296 try:
297 297 ip.run_cell("1", silent=True)
298 298 assert pre_always.called
299 299 assert not pre_explicit.called
300 300 assert post_always.called
301 301 assert not post_explicit.called
302 302 # double-check that non-silent exec did what we expected
303 303 # silent to avoid
304 304 ip.run_cell("1")
305 305 assert pre_explicit.called
306 306 assert post_explicit.called
307 307 finally:
308 308 # remove post-exec
309 ip.callbacks.reset_all()
309 ip.events.reset_all()
310 310
311 311 def test_silent_noadvance(self):
312 312 """run_cell(silent=True) doesn't advance execution_count"""
313 313 ec = ip.execution_count
314 314 # silent should force store_history=False
315 315 ip.run_cell("1", store_history=True, silent=True)
316 316
317 317 self.assertEqual(ec, ip.execution_count)
318 318 # double-check that non-silent exec did what we expected
319 319 # silent to avoid
320 320 ip.run_cell("1", store_history=True)
321 321 self.assertEqual(ec+1, ip.execution_count)
322 322
323 323 def test_silent_nodisplayhook(self):
324 324 """run_cell(silent=True) doesn't trigger displayhook"""
325 325 d = dict(called=False)
326 326
327 327 trap = ip.display_trap
328 328 save_hook = trap.hook
329 329
330 330 def failing_hook(*args, **kwargs):
331 331 d['called'] = True
332 332
333 333 try:
334 334 trap.hook = failing_hook
335 335 ip.run_cell("1", silent=True)
336 336 self.assertFalse(d['called'])
337 337 # double-check that non-silent exec did what we expected
338 338 # silent to avoid
339 339 ip.run_cell("1")
340 340 self.assertTrue(d['called'])
341 341 finally:
342 342 trap.hook = save_hook
343 343
344 344 @skipif(sys.version_info[0] >= 3, "softspace removed in py3")
345 345 def test_print_softspace(self):
346 346 """Verify that softspace is handled correctly when executing multiple
347 347 statements.
348 348
349 349 In [1]: print 1; print 2
350 350 1
351 351 2
352 352
353 353 In [2]: print 1,; print 2
354 354 1 2
355 355 """
356 356
357 357 def test_ofind_line_magic(self):
358 358 from IPython.core.magic import register_line_magic
359 359
360 360 @register_line_magic
361 361 def lmagic(line):
362 362 "A line magic"
363 363
364 364 # Get info on line magic
365 365 lfind = ip._ofind('lmagic')
366 366 info = dict(found=True, isalias=False, ismagic=True,
367 367 namespace = 'IPython internal', obj= lmagic.__wrapped__,
368 368 parent = None)
369 369 nt.assert_equal(lfind, info)
370 370
371 371 def test_ofind_cell_magic(self):
372 372 from IPython.core.magic import register_cell_magic
373 373
374 374 @register_cell_magic
375 375 def cmagic(line, cell):
376 376 "A cell magic"
377 377
378 378 # Get info on cell magic
379 379 find = ip._ofind('cmagic')
380 380 info = dict(found=True, isalias=False, ismagic=True,
381 381 namespace = 'IPython internal', obj= cmagic.__wrapped__,
382 382 parent = None)
383 383 nt.assert_equal(find, info)
384 384
385 385 def test_custom_exception(self):
386 386 called = []
387 387 def my_handler(shell, etype, value, tb, tb_offset=None):
388 388 called.append(etype)
389 389 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
390 390
391 391 ip.set_custom_exc((ValueError,), my_handler)
392 392 try:
393 393 ip.run_cell("raise ValueError('test')")
394 394 # Check that this was called, and only once.
395 395 self.assertEqual(called, [ValueError])
396 396 finally:
397 397 # Reset the custom exception hook
398 398 ip.set_custom_exc((), None)
399 399
400 400 @skipif(sys.version_info[0] >= 3, "no differences with __future__ in py3")
401 401 def test_future_environment(self):
402 402 "Can we run code with & without the shell's __future__ imports?"
403 403 ip.run_cell("from __future__ import division")
404 404 ip.run_cell("a = 1/2", shell_futures=True)
405 405 self.assertEqual(ip.user_ns['a'], 0.5)
406 406 ip.run_cell("b = 1/2", shell_futures=False)
407 407 self.assertEqual(ip.user_ns['b'], 0)
408 408
409 409 ip.compile.reset_compiler_flags()
410 410 # This shouldn't leak to the shell's compiler
411 411 ip.run_cell("from __future__ import division \nc=1/2", shell_futures=False)
412 412 self.assertEqual(ip.user_ns['c'], 0.5)
413 413 ip.run_cell("d = 1/2", shell_futures=True)
414 414 self.assertEqual(ip.user_ns['d'], 0)
415 415
416 416
417 417 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
418 418
419 419 @onlyif_unicode_paths
420 420 def setUp(self):
421 421 self.BASETESTDIR = tempfile.mkdtemp()
422 422 self.TESTDIR = join(self.BASETESTDIR, u"Γ₯Àâ")
423 423 os.mkdir(self.TESTDIR)
424 424 with open(join(self.TESTDIR, u"Γ₯Àâtestscript.py"), "w") as sfile:
425 425 sfile.write("pass\n")
426 426 self.oldpath = py3compat.getcwd()
427 427 os.chdir(self.TESTDIR)
428 428 self.fname = u"Γ₯Àâtestscript.py"
429 429
430 430 def tearDown(self):
431 431 os.chdir(self.oldpath)
432 432 shutil.rmtree(self.BASETESTDIR)
433 433
434 434 @onlyif_unicode_paths
435 435 def test_1(self):
436 436 """Test safe_execfile with non-ascii path
437 437 """
438 438 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
439 439
440 440 class ExitCodeChecks(tt.TempFileMixin):
441 441 def test_exit_code_ok(self):
442 442 self.system('exit 0')
443 443 self.assertEqual(ip.user_ns['_exit_code'], 0)
444 444
445 445 def test_exit_code_error(self):
446 446 self.system('exit 1')
447 447 self.assertEqual(ip.user_ns['_exit_code'], 1)
448 448
449 449 @skipif(not hasattr(signal, 'SIGALRM'))
450 450 def test_exit_code_signal(self):
451 451 self.mktmp("import signal, time\n"
452 452 "signal.setitimer(signal.ITIMER_REAL, 0.1)\n"
453 453 "time.sleep(1)\n")
454 454 self.system("%s %s" % (sys.executable, self.fname))
455 455 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGALRM)
456 456
457 457 class TestSystemRaw(unittest.TestCase, ExitCodeChecks):
458 458 system = ip.system_raw
459 459
460 460 @onlyif_unicode_paths
461 461 def test_1(self):
462 462 """Test system_raw with non-ascii cmd
463 463 """
464 464 cmd = u'''python -c "'Γ₯Àâ'" '''
465 465 ip.system_raw(cmd)
466 466
467 467 # TODO: Exit codes are currently ignored on Windows.
468 468 class TestSystemPipedExitCode(unittest.TestCase, ExitCodeChecks):
469 469 system = ip.system_piped
470 470
471 471 @skip_win32
472 472 def test_exit_code_ok(self):
473 473 ExitCodeChecks.test_exit_code_ok(self)
474 474
475 475 @skip_win32
476 476 def test_exit_code_error(self):
477 477 ExitCodeChecks.test_exit_code_error(self)
478 478
479 479 @skip_win32
480 480 def test_exit_code_signal(self):
481 481 ExitCodeChecks.test_exit_code_signal(self)
482 482
483 483 class TestModules(unittest.TestCase, tt.TempFileMixin):
484 484 def test_extraneous_loads(self):
485 485 """Test we're not loading modules on startup that we shouldn't.
486 486 """
487 487 self.mktmp("import sys\n"
488 488 "print('numpy' in sys.modules)\n"
489 489 "print('IPython.parallel' in sys.modules)\n"
490 490 "print('IPython.kernel.zmq' in sys.modules)\n"
491 491 )
492 492 out = "False\nFalse\nFalse\n"
493 493 tt.ipexec_validate(self.fname, out)
494 494
495 495 class Negator(ast.NodeTransformer):
496 496 """Negates all number literals in an AST."""
497 497 def visit_Num(self, node):
498 498 node.n = -node.n
499 499 return node
500 500
501 501 class TestAstTransform(unittest.TestCase):
502 502 def setUp(self):
503 503 self.negator = Negator()
504 504 ip.ast_transformers.append(self.negator)
505 505
506 506 def tearDown(self):
507 507 ip.ast_transformers.remove(self.negator)
508 508
509 509 def test_run_cell(self):
510 510 with tt.AssertPrints('-34'):
511 511 ip.run_cell('print (12 + 22)')
512 512
513 513 # A named reference to a number shouldn't be transformed.
514 514 ip.user_ns['n'] = 55
515 515 with tt.AssertNotPrints('-55'):
516 516 ip.run_cell('print (n)')
517 517
518 518 def test_timeit(self):
519 519 called = set()
520 520 def f(x):
521 521 called.add(x)
522 522 ip.push({'f':f})
523 523
524 524 with tt.AssertPrints("best of "):
525 525 ip.run_line_magic("timeit", "-n1 f(1)")
526 526 self.assertEqual(called, set([-1]))
527 527 called.clear()
528 528
529 529 with tt.AssertPrints("best of "):
530 530 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
531 531 self.assertEqual(called, set([-2, -3]))
532 532
533 533 def test_time(self):
534 534 called = []
535 535 def f(x):
536 536 called.append(x)
537 537 ip.push({'f':f})
538 538
539 539 # Test with an expression
540 540 with tt.AssertPrints("Wall time: "):
541 541 ip.run_line_magic("time", "f(5+9)")
542 542 self.assertEqual(called, [-14])
543 543 called[:] = []
544 544
545 545 # Test with a statement (different code path)
546 546 with tt.AssertPrints("Wall time: "):
547 547 ip.run_line_magic("time", "a = f(-3 + -2)")
548 548 self.assertEqual(called, [5])
549 549
550 550 def test_macro(self):
551 551 ip.push({'a':10})
552 552 # The AST transformation makes this do a+=-1
553 553 ip.define_macro("amacro", "a+=1\nprint(a)")
554 554
555 555 with tt.AssertPrints("9"):
556 556 ip.run_cell("amacro")
557 557 with tt.AssertPrints("8"):
558 558 ip.run_cell("amacro")
559 559
560 560 class IntegerWrapper(ast.NodeTransformer):
561 561 """Wraps all integers in a call to Integer()"""
562 562 def visit_Num(self, node):
563 563 if isinstance(node.n, int):
564 564 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
565 565 args=[node], keywords=[])
566 566 return node
567 567
568 568 class TestAstTransform2(unittest.TestCase):
569 569 def setUp(self):
570 570 self.intwrapper = IntegerWrapper()
571 571 ip.ast_transformers.append(self.intwrapper)
572 572
573 573 self.calls = []
574 574 def Integer(*args):
575 575 self.calls.append(args)
576 576 return args
577 577 ip.push({"Integer": Integer})
578 578
579 579 def tearDown(self):
580 580 ip.ast_transformers.remove(self.intwrapper)
581 581 del ip.user_ns['Integer']
582 582
583 583 def test_run_cell(self):
584 584 ip.run_cell("n = 2")
585 585 self.assertEqual(self.calls, [(2,)])
586 586
587 587 # This shouldn't throw an error
588 588 ip.run_cell("o = 2.0")
589 589 self.assertEqual(ip.user_ns['o'], 2.0)
590 590
591 591 def test_timeit(self):
592 592 called = set()
593 593 def f(x):
594 594 called.add(x)
595 595 ip.push({'f':f})
596 596
597 597 with tt.AssertPrints("best of "):
598 598 ip.run_line_magic("timeit", "-n1 f(1)")
599 599 self.assertEqual(called, set([(1,)]))
600 600 called.clear()
601 601
602 602 with tt.AssertPrints("best of "):
603 603 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
604 604 self.assertEqual(called, set([(2,), (3,)]))
605 605
606 606 class ErrorTransformer(ast.NodeTransformer):
607 607 """Throws an error when it sees a number."""
608 608 def visit_Num(self):
609 609 raise ValueError("test")
610 610
611 611 class TestAstTransformError(unittest.TestCase):
612 612 def test_unregistering(self):
613 613 err_transformer = ErrorTransformer()
614 614 ip.ast_transformers.append(err_transformer)
615 615
616 616 with tt.AssertPrints("unregister", channel='stderr'):
617 617 ip.run_cell("1 + 2")
618 618
619 619 # This should have been removed.
620 620 nt.assert_not_in(err_transformer, ip.ast_transformers)
621 621
622 622 def test__IPYTHON__():
623 623 # This shouldn't raise a NameError, that's all
624 624 __IPYTHON__
625 625
626 626
627 627 class DummyRepr(object):
628 628 def __repr__(self):
629 629 return "DummyRepr"
630 630
631 631 def _repr_html_(self):
632 632 return "<b>dummy</b>"
633 633
634 634 def _repr_javascript_(self):
635 635 return "console.log('hi');", {'key': 'value'}
636 636
637 637
638 638 def test_user_variables():
639 639 # enable all formatters
640 640 ip.display_formatter.active_types = ip.display_formatter.format_types
641 641
642 642 ip.user_ns['dummy'] = d = DummyRepr()
643 643 keys = set(['dummy', 'doesnotexist'])
644 644 r = ip.user_variables(keys)
645 645
646 646 nt.assert_equal(keys, set(r.keys()))
647 647 dummy = r['dummy']
648 648 nt.assert_equal(set(['status', 'data', 'metadata']), set(dummy.keys()))
649 649 nt.assert_equal(dummy['status'], 'ok')
650 650 data = dummy['data']
651 651 metadata = dummy['metadata']
652 652 nt.assert_equal(data.get('text/html'), d._repr_html_())
653 653 js, jsmd = d._repr_javascript_()
654 654 nt.assert_equal(data.get('application/javascript'), js)
655 655 nt.assert_equal(metadata.get('application/javascript'), jsmd)
656 656
657 657 dne = r['doesnotexist']
658 658 nt.assert_equal(dne['status'], 'error')
659 659 nt.assert_equal(dne['ename'], 'KeyError')
660 660
661 661 # back to text only
662 662 ip.display_formatter.active_types = ['text/plain']
663 663
664 664 def test_user_expression():
665 665 # enable all formatters
666 666 ip.display_formatter.active_types = ip.display_formatter.format_types
667 667 query = {
668 668 'a' : '1 + 2',
669 669 'b' : '1/0',
670 670 }
671 671 r = ip.user_expressions(query)
672 672 import pprint
673 673 pprint.pprint(r)
674 674 nt.assert_equal(set(r.keys()), set(query.keys()))
675 675 a = r['a']
676 676 nt.assert_equal(set(['status', 'data', 'metadata']), set(a.keys()))
677 677 nt.assert_equal(a['status'], 'ok')
678 678 data = a['data']
679 679 metadata = a['metadata']
680 680 nt.assert_equal(data.get('text/plain'), '3')
681 681
682 682 b = r['b']
683 683 nt.assert_equal(b['status'], 'error')
684 684 nt.assert_equal(b['ename'], 'ZeroDivisionError')
685 685
686 686 # back to text only
687 687 ip.display_formatter.active_types = ['text/plain']
688 688
689 689
690 690
691 691
692 692
693 693 class TestSyntaxErrorTransformer(unittest.TestCase):
694 694 """Check that SyntaxError raised by an input transformer is handled by run_cell()"""
695 695
696 696 class SyntaxErrorTransformer(InputTransformer):
697 697
698 698 def push(self, line):
699 699 pos = line.find('syntaxerror')
700 700 if pos >= 0:
701 701 e = SyntaxError('input contains "syntaxerror"')
702 702 e.text = line
703 703 e.offset = pos + 1
704 704 raise e
705 705 return line
706 706
707 707 def reset(self):
708 708 pass
709 709
710 710 def setUp(self):
711 711 self.transformer = TestSyntaxErrorTransformer.SyntaxErrorTransformer()
712 712 ip.input_splitter.python_line_transforms.append(self.transformer)
713 713 ip.input_transformer_manager.python_line_transforms.append(self.transformer)
714 714
715 715 def tearDown(self):
716 716 ip.input_splitter.python_line_transforms.remove(self.transformer)
717 717 ip.input_transformer_manager.python_line_transforms.remove(self.transformer)
718 718
719 719 def test_syntaxerror_input_transformer(self):
720 720 with tt.AssertPrints('1234'):
721 721 ip.run_cell('1234')
722 722 with tt.AssertPrints('SyntaxError: invalid syntax'):
723 723 ip.run_cell('1 2 3') # plain python syntax error
724 724 with tt.AssertPrints('SyntaxError: input contains "syntaxerror"'):
725 725 ip.run_cell('2345 # syntaxerror') # input transformer syntax error
726 726 with tt.AssertPrints('3456'):
727 727 ip.run_cell('3456')
728 728
729 729
730 730
@@ -1,505 +1,505 b''
1 1 """IPython extension to reload modules before executing user code.
2 2
3 3 ``autoreload`` reloads modules automatically before entering the execution of
4 4 code typed at the IPython prompt.
5 5
6 6 This makes for example the following workflow possible:
7 7
8 8 .. sourcecode:: ipython
9 9
10 10 In [1]: %load_ext autoreload
11 11
12 12 In [2]: %autoreload 2
13 13
14 14 In [3]: from foo import some_function
15 15
16 16 In [4]: some_function()
17 17 Out[4]: 42
18 18
19 19 In [5]: # open foo.py in an editor and change some_function to return 43
20 20
21 21 In [6]: some_function()
22 22 Out[6]: 43
23 23
24 24 The module was reloaded without reloading it explicitly, and the object
25 25 imported with ``from foo import ...`` was also updated.
26 26
27 27 Usage
28 28 =====
29 29
30 30 The following magic commands are provided:
31 31
32 32 ``%autoreload``
33 33
34 34 Reload all modules (except those excluded by ``%aimport``)
35 35 automatically now.
36 36
37 37 ``%autoreload 0``
38 38
39 39 Disable automatic reloading.
40 40
41 41 ``%autoreload 1``
42 42
43 43 Reload all modules imported with ``%aimport`` every time before
44 44 executing the Python code typed.
45 45
46 46 ``%autoreload 2``
47 47
48 48 Reload all modules (except those excluded by ``%aimport``) every
49 49 time before executing the Python code typed.
50 50
51 51 ``%aimport``
52 52
53 53 List modules which are to be automatically imported or not to be imported.
54 54
55 55 ``%aimport foo``
56 56
57 57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
58 58
59 59 ``%aimport -foo``
60 60
61 61 Mark module 'foo' to not be autoreloaded.
62 62
63 63 Caveats
64 64 =======
65 65
66 66 Reloading Python modules in a reliable way is in general difficult,
67 67 and unexpected things may occur. ``%autoreload`` tries to work around
68 68 common pitfalls by replacing function code objects and parts of
69 69 classes previously in the module with new versions. This makes the
70 70 following things to work:
71 71
72 72 - Functions and classes imported via 'from xxx import foo' are upgraded
73 73 to new versions when 'xxx' is reloaded.
74 74
75 75 - Methods and properties of classes are upgraded on reload, so that
76 76 calling 'c.foo()' on an object 'c' created before the reload causes
77 77 the new code for 'foo' to be executed.
78 78
79 79 Some of the known remaining caveats are:
80 80
81 81 - Replacing code objects does not always succeed: changing a @property
82 82 in a class to an ordinary method or a method to a member variable
83 83 can cause problems (but in old objects only).
84 84
85 85 - Functions that are removed (eg. via monkey-patching) from a module
86 86 before it is reloaded are not upgraded.
87 87
88 88 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
89 89 """
90 90 from __future__ import print_function
91 91
92 92 skip_doctest = True
93 93
94 94 #-----------------------------------------------------------------------------
95 95 # Copyright (C) 2000 Thomas Heller
96 96 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
97 97 # Copyright (C) 2012 The IPython Development Team
98 98 #
99 99 # Distributed under the terms of the BSD License. The full license is in
100 100 # the file COPYING, distributed as part of this software.
101 101 #-----------------------------------------------------------------------------
102 102 #
103 103 # This IPython module is written by Pauli Virtanen, based on the autoreload
104 104 # code by Thomas Heller.
105 105
106 106 #-----------------------------------------------------------------------------
107 107 # Imports
108 108 #-----------------------------------------------------------------------------
109 109
110 110 import os
111 111 import sys
112 112 import traceback
113 113 import types
114 114 import weakref
115 115
116 116 try:
117 117 # Reload is not defined by default in Python3.
118 118 reload
119 119 except NameError:
120 120 from imp import reload
121 121
122 122 from IPython.utils import openpy
123 123 from IPython.utils.py3compat import PY3
124 124
125 125 #------------------------------------------------------------------------------
126 126 # Autoreload functionality
127 127 #------------------------------------------------------------------------------
128 128
129 129 class ModuleReloader(object):
130 130 enabled = False
131 131 """Whether this reloader is enabled"""
132 132
133 133 failed = {}
134 134 """Modules that failed to reload: {module: mtime-on-failed-reload, ...}"""
135 135
136 136 modules = {}
137 137 """Modules specially marked as autoreloadable."""
138 138
139 139 skip_modules = {}
140 140 """Modules specially marked as not autoreloadable."""
141 141
142 142 check_all = True
143 143 """Autoreload all modules, not just those listed in 'modules'"""
144 144
145 145 old_objects = {}
146 146 """(module-name, name) -> weakref, for replacing old code objects"""
147 147
148 148 def mark_module_skipped(self, module_name):
149 149 """Skip reloading the named module in the future"""
150 150 try:
151 151 del self.modules[module_name]
152 152 except KeyError:
153 153 pass
154 154 self.skip_modules[module_name] = True
155 155
156 156 def mark_module_reloadable(self, module_name):
157 157 """Reload the named module in the future (if it is imported)"""
158 158 try:
159 159 del self.skip_modules[module_name]
160 160 except KeyError:
161 161 pass
162 162 self.modules[module_name] = True
163 163
164 164 def aimport_module(self, module_name):
165 165 """Import a module, and mark it reloadable
166 166
167 167 Returns
168 168 -------
169 169 top_module : module
170 170 The imported module if it is top-level, or the top-level
171 171 top_name : module
172 172 Name of top_module
173 173
174 174 """
175 175 self.mark_module_reloadable(module_name)
176 176
177 177 __import__(module_name)
178 178 top_name = module_name.split('.')[0]
179 179 top_module = sys.modules[top_name]
180 180 return top_module, top_name
181 181
182 182 def check(self, check_all=False):
183 183 """Check whether some modules need to be reloaded."""
184 184
185 185 if not self.enabled and not check_all:
186 186 return
187 187
188 188 if check_all or self.check_all:
189 189 modules = list(sys.modules.keys())
190 190 else:
191 191 modules = list(self.modules.keys())
192 192
193 193 for modname in modules:
194 194 m = sys.modules.get(modname, None)
195 195
196 196 if modname in self.skip_modules:
197 197 continue
198 198
199 199 if not hasattr(m, '__file__'):
200 200 continue
201 201
202 202 if m.__name__ == '__main__':
203 203 # we cannot reload(__main__)
204 204 continue
205 205
206 206 filename = m.__file__
207 207 path, ext = os.path.splitext(filename)
208 208
209 209 if ext.lower() == '.py':
210 210 pyc_filename = openpy.cache_from_source(filename)
211 211 py_filename = filename
212 212 else:
213 213 pyc_filename = filename
214 214 try:
215 215 py_filename = openpy.source_from_cache(filename)
216 216 except ValueError:
217 217 continue
218 218
219 219 try:
220 220 pymtime = os.stat(py_filename).st_mtime
221 221 if pymtime <= os.stat(pyc_filename).st_mtime:
222 222 continue
223 223 if self.failed.get(py_filename, None) == pymtime:
224 224 continue
225 225 except OSError:
226 226 continue
227 227
228 228 try:
229 229 superreload(m, reload, self.old_objects)
230 230 if py_filename in self.failed:
231 231 del self.failed[py_filename]
232 232 except:
233 233 print("[autoreload of %s failed: %s]" % (
234 234 modname, traceback.format_exc(1)), file=sys.stderr)
235 235 self.failed[py_filename] = pymtime
236 236
237 237 #------------------------------------------------------------------------------
238 238 # superreload
239 239 #------------------------------------------------------------------------------
240 240
241 241 if PY3:
242 242 func_attrs = ['__code__', '__defaults__', '__doc__',
243 243 '__closure__', '__globals__', '__dict__']
244 244 else:
245 245 func_attrs = ['func_code', 'func_defaults', 'func_doc',
246 246 'func_closure', 'func_globals', 'func_dict']
247 247
248 248
249 249 def update_function(old, new):
250 250 """Upgrade the code object of a function"""
251 251 for name in func_attrs:
252 252 try:
253 253 setattr(old, name, getattr(new, name))
254 254 except (AttributeError, TypeError):
255 255 pass
256 256
257 257
258 258 def update_class(old, new):
259 259 """Replace stuff in the __dict__ of a class, and upgrade
260 260 method code objects"""
261 261 for key in list(old.__dict__.keys()):
262 262 old_obj = getattr(old, key)
263 263
264 264 try:
265 265 new_obj = getattr(new, key)
266 266 except AttributeError:
267 267 # obsolete attribute: remove it
268 268 try:
269 269 delattr(old, key)
270 270 except (AttributeError, TypeError):
271 271 pass
272 272 continue
273 273
274 274 if update_generic(old_obj, new_obj): continue
275 275
276 276 try:
277 277 setattr(old, key, getattr(new, key))
278 278 except (AttributeError, TypeError):
279 279 pass # skip non-writable attributes
280 280
281 281
282 282 def update_property(old, new):
283 283 """Replace get/set/del functions of a property"""
284 284 update_generic(old.fdel, new.fdel)
285 285 update_generic(old.fget, new.fget)
286 286 update_generic(old.fset, new.fset)
287 287
288 288
289 289 def isinstance2(a, b, typ):
290 290 return isinstance(a, typ) and isinstance(b, typ)
291 291
292 292
293 293 UPDATE_RULES = [
294 294 (lambda a, b: isinstance2(a, b, type),
295 295 update_class),
296 296 (lambda a, b: isinstance2(a, b, types.FunctionType),
297 297 update_function),
298 298 (lambda a, b: isinstance2(a, b, property),
299 299 update_property),
300 300 ]
301 301
302 302
303 303 if PY3:
304 304 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
305 305 lambda a, b: update_function(a.__func__, b.__func__)),
306 306 ])
307 307 else:
308 308 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.ClassType),
309 309 update_class),
310 310 (lambda a, b: isinstance2(a, b, types.MethodType),
311 311 lambda a, b: update_function(a.__func__, b.__func__)),
312 312 ])
313 313
314 314
315 315 def update_generic(a, b):
316 316 for type_check, update in UPDATE_RULES:
317 317 if type_check(a, b):
318 318 update(a, b)
319 319 return True
320 320 return False
321 321
322 322
323 323 class StrongRef(object):
324 324 def __init__(self, obj):
325 325 self.obj = obj
326 326 def __call__(self):
327 327 return self.obj
328 328
329 329
330 330 def superreload(module, reload=reload, old_objects={}):
331 331 """Enhanced version of the builtin reload function.
332 332
333 333 superreload remembers objects previously in the module, and
334 334
335 335 - upgrades the class dictionary of every old class in the module
336 336 - upgrades the code object of every old function and method
337 337 - clears the module's namespace before reloading
338 338
339 339 """
340 340
341 341 # collect old objects in the module
342 342 for name, obj in list(module.__dict__.items()):
343 343 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
344 344 continue
345 345 key = (module.__name__, name)
346 346 try:
347 347 old_objects.setdefault(key, []).append(weakref.ref(obj))
348 348 except TypeError:
349 349 # weakref doesn't work for all types;
350 350 # create strong references for 'important' cases
351 351 if not PY3 and isinstance(obj, types.ClassType):
352 352 old_objects.setdefault(key, []).append(StrongRef(obj))
353 353
354 354 # reload module
355 355 try:
356 356 # clear namespace first from old cruft
357 357 old_dict = module.__dict__.copy()
358 358 old_name = module.__name__
359 359 module.__dict__.clear()
360 360 module.__dict__['__name__'] = old_name
361 361 module.__dict__['__loader__'] = old_dict['__loader__']
362 362 except (TypeError, AttributeError, KeyError):
363 363 pass
364 364
365 365 try:
366 366 module = reload(module)
367 367 except:
368 368 # restore module dictionary on failed reload
369 369 module.__dict__.update(old_dict)
370 370 raise
371 371
372 372 # iterate over all objects and update functions & classes
373 373 for name, new_obj in list(module.__dict__.items()):
374 374 key = (module.__name__, name)
375 375 if key not in old_objects: continue
376 376
377 377 new_refs = []
378 378 for old_ref in old_objects[key]:
379 379 old_obj = old_ref()
380 380 if old_obj is None: continue
381 381 new_refs.append(old_ref)
382 382 update_generic(old_obj, new_obj)
383 383
384 384 if new_refs:
385 385 old_objects[key] = new_refs
386 386 else:
387 387 del old_objects[key]
388 388
389 389 return module
390 390
391 391 #------------------------------------------------------------------------------
392 392 # IPython connectivity
393 393 #------------------------------------------------------------------------------
394 394
395 395 from IPython.core.magic import Magics, magics_class, line_magic
396 396
397 397 @magics_class
398 398 class AutoreloadMagics(Magics):
399 399 def __init__(self, *a, **kw):
400 400 super(AutoreloadMagics, self).__init__(*a, **kw)
401 401 self._reloader = ModuleReloader()
402 402 self._reloader.check_all = False
403 403
404 404 @line_magic
405 405 def autoreload(self, parameter_s=''):
406 406 r"""%autoreload => Reload modules automatically
407 407
408 408 %autoreload
409 409 Reload all modules (except those excluded by %aimport) automatically
410 410 now.
411 411
412 412 %autoreload 0
413 413 Disable automatic reloading.
414 414
415 415 %autoreload 1
416 416 Reload all modules imported with %aimport every time before executing
417 417 the Python code typed.
418 418
419 419 %autoreload 2
420 420 Reload all modules (except those excluded by %aimport) every time
421 421 before executing the Python code typed.
422 422
423 423 Reloading Python modules in a reliable way is in general
424 424 difficult, and unexpected things may occur. %autoreload tries to
425 425 work around common pitfalls by replacing function code objects and
426 426 parts of classes previously in the module with new versions. This
427 427 makes the following things to work:
428 428
429 429 - Functions and classes imported via 'from xxx import foo' are upgraded
430 430 to new versions when 'xxx' is reloaded.
431 431
432 432 - Methods and properties of classes are upgraded on reload, so that
433 433 calling 'c.foo()' on an object 'c' created before the reload causes
434 434 the new code for 'foo' to be executed.
435 435
436 436 Some of the known remaining caveats are:
437 437
438 438 - Replacing code objects does not always succeed: changing a @property
439 439 in a class to an ordinary method or a method to a member variable
440 440 can cause problems (but in old objects only).
441 441
442 442 - Functions that are removed (eg. via monkey-patching) from a module
443 443 before it is reloaded are not upgraded.
444 444
445 445 - C extension modules cannot be reloaded, and so cannot be
446 446 autoreloaded.
447 447
448 448 """
449 449 if parameter_s == '':
450 450 self._reloader.check(True)
451 451 elif parameter_s == '0':
452 452 self._reloader.enabled = False
453 453 elif parameter_s == '1':
454 454 self._reloader.check_all = False
455 455 self._reloader.enabled = True
456 456 elif parameter_s == '2':
457 457 self._reloader.check_all = True
458 458 self._reloader.enabled = True
459 459
460 460 @line_magic
461 461 def aimport(self, parameter_s='', stream=None):
462 462 """%aimport => Import modules for automatic reloading.
463 463
464 464 %aimport
465 465 List modules to automatically import and not to import.
466 466
467 467 %aimport foo
468 468 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
469 469
470 470 %aimport -foo
471 471 Mark module 'foo' to not be autoreloaded for %autoreload 1
472 472 """
473 473 modname = parameter_s
474 474 if not modname:
475 475 to_reload = sorted(self._reloader.modules.keys())
476 476 to_skip = sorted(self._reloader.skip_modules.keys())
477 477 if stream is None:
478 478 stream = sys.stdout
479 479 if self._reloader.check_all:
480 480 stream.write("Modules to reload:\nall-except-skipped\n")
481 481 else:
482 482 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
483 483 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
484 484 elif modname.startswith('-'):
485 485 modname = modname[1:]
486 486 self._reloader.mark_module_skipped(modname)
487 487 else:
488 488 top_module, top_name = self._reloader.aimport_module(modname)
489 489
490 490 # Inject module to user namespace
491 491 self.shell.push({top_name: top_module})
492 492
493 493 def pre_execute_explicit(self):
494 494 if self._reloader.enabled:
495 495 try:
496 496 self._reloader.check()
497 497 except:
498 498 pass
499 499
500 500
501 501 def load_ipython_extension(ip):
502 502 """Load the extension in IPython."""
503 503 auto_reload = AutoreloadMagics(ip)
504 504 ip.register_magics(auto_reload)
505 ip.callbacks.register('pre_execute_explicit', auto_reload.pre_execute_explicit)
505 ip.events.register('pre_execute_explicit', auto_reload.pre_execute_explicit)
@@ -1,321 +1,321 b''
1 1 """Tests for autoreload extension.
2 2 """
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (c) 2012 IPython Development Team.
5 5 #
6 6 # Distributed under the terms of the Modified BSD License.
7 7 #
8 8 # The full license is in the file COPYING.txt, distributed with this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Imports
13 13 #-----------------------------------------------------------------------------
14 14
15 15 import os
16 16 import sys
17 17 import tempfile
18 18 import shutil
19 19 import random
20 20 import time
21 21
22 22 import nose.tools as nt
23 23 import IPython.testing.tools as tt
24 24
25 25 from IPython.extensions.autoreload import AutoreloadMagics
26 from IPython.core.callbacks import CallbackManager, pre_execute_explicit
26 from IPython.core.events import EventManager, pre_execute_explicit
27 27 from IPython.utils.py3compat import PY3
28 28
29 29 if PY3:
30 30 from io import StringIO
31 31 else:
32 32 from StringIO import StringIO
33 33
34 34 #-----------------------------------------------------------------------------
35 35 # Test fixture
36 36 #-----------------------------------------------------------------------------
37 37
38 38 noop = lambda *a, **kw: None
39 39
40 40 class FakeShell(object):
41 41
42 42 def __init__(self):
43 43 self.ns = {}
44 self.callbacks = CallbackManager(self, {'pre_execute_explicit', pre_execute_explicit})
44 self.events = EventManager(self, {'pre_execute_explicit', pre_execute_explicit})
45 45 self.auto_magics = AutoreloadMagics(shell=self)
46 self.callbacks.register('pre_execute_explicit', self.auto_magics.pre_execute_explicit)
46 self.events.register('pre_execute_explicit', self.auto_magics.pre_execute_explicit)
47 47
48 48 register_magics = set_hook = noop
49 49
50 50 def run_code(self, code):
51 self.callbacks.fire('pre_execute_explicit')
51 self.events.trigger('pre_execute_explicit')
52 52 exec(code, self.ns)
53 53
54 54 def push(self, items):
55 55 self.ns.update(items)
56 56
57 57 def magic_autoreload(self, parameter):
58 58 self.auto_magics.autoreload(parameter)
59 59
60 60 def magic_aimport(self, parameter, stream=None):
61 61 self.auto_magics.aimport(parameter, stream=stream)
62 62
63 63
64 64 class Fixture(object):
65 65 """Fixture for creating test module files"""
66 66
67 67 test_dir = None
68 68 old_sys_path = None
69 69 filename_chars = "abcdefghijklmopqrstuvwxyz0123456789"
70 70
71 71 def setUp(self):
72 72 self.test_dir = tempfile.mkdtemp()
73 73 self.old_sys_path = list(sys.path)
74 74 sys.path.insert(0, self.test_dir)
75 75 self.shell = FakeShell()
76 76
77 77 def tearDown(self):
78 78 shutil.rmtree(self.test_dir)
79 79 sys.path = self.old_sys_path
80 80
81 81 self.test_dir = None
82 82 self.old_sys_path = None
83 83 self.shell = None
84 84
85 85 def get_module(self):
86 86 module_name = "tmpmod_" + "".join(random.sample(self.filename_chars,20))
87 87 if module_name in sys.modules:
88 88 del sys.modules[module_name]
89 89 file_name = os.path.join(self.test_dir, module_name + ".py")
90 90 return module_name, file_name
91 91
92 92 def write_file(self, filename, content):
93 93 """
94 94 Write a file, and force a timestamp difference of at least one second
95 95
96 96 Notes
97 97 -----
98 98 Python's .pyc files record the timestamp of their compilation
99 99 with a time resolution of one second.
100 100
101 101 Therefore, we need to force a timestamp difference between .py
102 102 and .pyc, without having the .py file be timestamped in the
103 103 future, and without changing the timestamp of the .pyc file
104 104 (because that is stored in the file). The only reliable way
105 105 to achieve this seems to be to sleep.
106 106 """
107 107
108 108 # Sleep one second + eps
109 109 time.sleep(1.05)
110 110
111 111 # Write
112 112 f = open(filename, 'w')
113 113 try:
114 114 f.write(content)
115 115 finally:
116 116 f.close()
117 117
118 118 def new_module(self, code):
119 119 mod_name, mod_fn = self.get_module()
120 120 f = open(mod_fn, 'w')
121 121 try:
122 122 f.write(code)
123 123 finally:
124 124 f.close()
125 125 return mod_name, mod_fn
126 126
127 127 #-----------------------------------------------------------------------------
128 128 # Test automatic reloading
129 129 #-----------------------------------------------------------------------------
130 130
131 131 class TestAutoreload(Fixture):
132 132 def _check_smoketest(self, use_aimport=True):
133 133 """
134 134 Functional test for the automatic reloader using either
135 135 '%autoreload 1' or '%autoreload 2'
136 136 """
137 137
138 138 mod_name, mod_fn = self.new_module("""
139 139 x = 9
140 140
141 141 z = 123 # this item will be deleted
142 142
143 143 def foo(y):
144 144 return y + 3
145 145
146 146 class Baz(object):
147 147 def __init__(self, x):
148 148 self.x = x
149 149 def bar(self, y):
150 150 return self.x + y
151 151 @property
152 152 def quux(self):
153 153 return 42
154 154 def zzz(self):
155 155 '''This method will be deleted below'''
156 156 return 99
157 157
158 158 class Bar: # old-style class: weakref doesn't work for it on Python < 2.7
159 159 def foo(self):
160 160 return 1
161 161 """)
162 162
163 163 #
164 164 # Import module, and mark for reloading
165 165 #
166 166 if use_aimport:
167 167 self.shell.magic_autoreload("1")
168 168 self.shell.magic_aimport(mod_name)
169 169 stream = StringIO()
170 170 self.shell.magic_aimport("", stream=stream)
171 171 nt.assert_true(("Modules to reload:\n%s" % mod_name) in
172 172 stream.getvalue())
173 173
174 174 nt.assert_raises(
175 175 ImportError,
176 176 self.shell.magic_aimport, "tmpmod_as318989e89ds")
177 177 else:
178 178 self.shell.magic_autoreload("2")
179 179 self.shell.run_code("import %s" % mod_name)
180 180 stream = StringIO()
181 181 self.shell.magic_aimport("", stream=stream)
182 182 nt.assert_true("Modules to reload:\nall-except-skipped" in
183 183 stream.getvalue())
184 184 nt.assert_in(mod_name, self.shell.ns)
185 185
186 186 mod = sys.modules[mod_name]
187 187
188 188 #
189 189 # Test module contents
190 190 #
191 191 old_foo = mod.foo
192 192 old_obj = mod.Baz(9)
193 193 old_obj2 = mod.Bar()
194 194
195 195 def check_module_contents():
196 196 nt.assert_equal(mod.x, 9)
197 197 nt.assert_equal(mod.z, 123)
198 198
199 199 nt.assert_equal(old_foo(0), 3)
200 200 nt.assert_equal(mod.foo(0), 3)
201 201
202 202 obj = mod.Baz(9)
203 203 nt.assert_equal(old_obj.bar(1), 10)
204 204 nt.assert_equal(obj.bar(1), 10)
205 205 nt.assert_equal(obj.quux, 42)
206 206 nt.assert_equal(obj.zzz(), 99)
207 207
208 208 obj2 = mod.Bar()
209 209 nt.assert_equal(old_obj2.foo(), 1)
210 210 nt.assert_equal(obj2.foo(), 1)
211 211
212 212 check_module_contents()
213 213
214 214 #
215 215 # Simulate a failed reload: no reload should occur and exactly
216 216 # one error message should be printed
217 217 #
218 218 self.write_file(mod_fn, """
219 219 a syntax error
220 220 """)
221 221
222 222 with tt.AssertPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
223 223 self.shell.run_code("pass") # trigger reload
224 224 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
225 225 self.shell.run_code("pass") # trigger another reload
226 226 check_module_contents()
227 227
228 228 #
229 229 # Rewrite module (this time reload should succeed)
230 230 #
231 231 self.write_file(mod_fn, """
232 232 x = 10
233 233
234 234 def foo(y):
235 235 return y + 4
236 236
237 237 class Baz(object):
238 238 def __init__(self, x):
239 239 self.x = x
240 240 def bar(self, y):
241 241 return self.x + y + 1
242 242 @property
243 243 def quux(self):
244 244 return 43
245 245
246 246 class Bar: # old-style class
247 247 def foo(self):
248 248 return 2
249 249 """)
250 250
251 251 def check_module_contents():
252 252 nt.assert_equal(mod.x, 10)
253 253 nt.assert_false(hasattr(mod, 'z'))
254 254
255 255 nt.assert_equal(old_foo(0), 4) # superreload magic!
256 256 nt.assert_equal(mod.foo(0), 4)
257 257
258 258 obj = mod.Baz(9)
259 259 nt.assert_equal(old_obj.bar(1), 11) # superreload magic!
260 260 nt.assert_equal(obj.bar(1), 11)
261 261
262 262 nt.assert_equal(old_obj.quux, 43)
263 263 nt.assert_equal(obj.quux, 43)
264 264
265 265 nt.assert_false(hasattr(old_obj, 'zzz'))
266 266 nt.assert_false(hasattr(obj, 'zzz'))
267 267
268 268 obj2 = mod.Bar()
269 269 nt.assert_equal(old_obj2.foo(), 2)
270 270 nt.assert_equal(obj2.foo(), 2)
271 271
272 272 self.shell.run_code("pass") # trigger reload
273 273 check_module_contents()
274 274
275 275 #
276 276 # Another failure case: deleted file (shouldn't reload)
277 277 #
278 278 os.unlink(mod_fn)
279 279
280 280 self.shell.run_code("pass") # trigger reload
281 281 check_module_contents()
282 282
283 283 #
284 284 # Disable autoreload and rewrite module: no reload should occur
285 285 #
286 286 if use_aimport:
287 287 self.shell.magic_aimport("-" + mod_name)
288 288 stream = StringIO()
289 289 self.shell.magic_aimport("", stream=stream)
290 290 nt.assert_true(("Modules to skip:\n%s" % mod_name) in
291 291 stream.getvalue())
292 292
293 293 # This should succeed, although no such module exists
294 294 self.shell.magic_aimport("-tmpmod_as318989e89ds")
295 295 else:
296 296 self.shell.magic_autoreload("0")
297 297
298 298 self.write_file(mod_fn, """
299 299 x = -99
300 300 """)
301 301
302 302 self.shell.run_code("pass") # trigger reload
303 303 self.shell.run_code("pass")
304 304 check_module_contents()
305 305
306 306 #
307 307 # Re-enable autoreload: reload should now occur
308 308 #
309 309 if use_aimport:
310 310 self.shell.magic_aimport(mod_name)
311 311 else:
312 312 self.shell.magic_autoreload("")
313 313
314 314 self.shell.run_code("pass") # trigger reload
315 315 nt.assert_equal(mod.x, -99)
316 316
317 317 def test_smoketest_aimport(self):
318 318 self._check_smoketest(use_aimport=True)
319 319
320 320 def test_smoketest_autoreload(self):
321 321 self._check_smoketest(use_aimport=False)
@@ -1,142 +1,142 b''
1 1 """Base class for a Comm"""
2 2
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (C) 2013 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-----------------------------------------------------------------------------
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Imports
12 12 #-----------------------------------------------------------------------------
13 13
14 14 import uuid
15 15
16 16 from IPython.config import LoggingConfigurable
17 17 from IPython.core.getipython import get_ipython
18 18
19 19 from IPython.utils.traitlets import Instance, Unicode, Bytes, Bool, Dict, Any
20 20
21 21 #-----------------------------------------------------------------------------
22 22 # Code
23 23 #-----------------------------------------------------------------------------
24 24
25 25 class Comm(LoggingConfigurable):
26 26
27 27 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
28 28 def _shell_default(self):
29 29 return get_ipython()
30 30
31 31 iopub_socket = Any()
32 32 def _iopub_socket_default(self):
33 33 return self.shell.kernel.iopub_socket
34 34 session = Instance('IPython.kernel.zmq.session.Session')
35 35 def _session_default(self):
36 36 if self.shell is None:
37 37 return
38 38 return self.shell.kernel.session
39 39
40 40 target_name = Unicode('comm')
41 41
42 42 topic = Bytes()
43 43 def _topic_default(self):
44 44 return ('comm-%s' % self.comm_id).encode('ascii')
45 45
46 46 _open_data = Dict(help="data dict, if any, to be included in comm_open")
47 47 _close_data = Dict(help="data dict, if any, to be included in comm_close")
48 48
49 49 _msg_callback = Any()
50 50 _close_callback = Any()
51 51
52 52 _closed = Bool(False)
53 53 comm_id = Unicode()
54 54 def _comm_id_default(self):
55 55 return uuid.uuid4().hex
56 56
57 57 primary = Bool(True, help="Am I the primary or secondary Comm?")
58 58
59 59 def __init__(self, target_name='', data=None, **kwargs):
60 60 if target_name:
61 61 kwargs['target_name'] = target_name
62 62 super(Comm, self).__init__(**kwargs)
63 63 get_ipython().comm_manager.register_comm(self)
64 64 if self.primary:
65 65 # I am primary, open my peer.
66 66 self.open(data)
67 67
68 68 def _publish_msg(self, msg_type, data=None, metadata=None, **keys):
69 69 """Helper for sending a comm message on IOPub"""
70 70 data = {} if data is None else data
71 71 metadata = {} if metadata is None else metadata
72 72 self.session.send(self.iopub_socket, msg_type,
73 73 dict(data=data, comm_id=self.comm_id, **keys),
74 74 metadata=metadata,
75 75 parent=self.shell.get_parent(),
76 76 ident=self.topic,
77 77 )
78 78
79 79 def __del__(self):
80 80 """trigger close on gc"""
81 81 self.close()
82 82
83 83 # publishing messages
84 84
85 85 def open(self, data=None, metadata=None):
86 86 """Open the frontend-side version of this comm"""
87 87 if data is None:
88 88 data = self._open_data
89 89 self._publish_msg('comm_open', data, metadata, target_name=self.target_name)
90 90
91 91 def close(self, data=None, metadata=None):
92 92 """Close the frontend-side version of this comm"""
93 93 if self._closed:
94 94 # only close once
95 95 return
96 96 if data is None:
97 97 data = self._close_data
98 98 self._publish_msg('comm_close', data, metadata)
99 99 self._closed = True
100 100
101 101 def send(self, data=None, metadata=None):
102 102 """Send a message to the frontend-side version of this comm"""
103 103 self._publish_msg('comm_msg', data, metadata)
104 104
105 105 # registering callbacks
106 106
107 107 def on_close(self, callback):
108 108 """Register a callback for comm_close
109 109
110 110 Will be called with the `data` of the close message.
111 111
112 112 Call `on_close(None)` to disable an existing callback.
113 113 """
114 114 self._close_callback = callback
115 115
116 116 def on_msg(self, callback):
117 117 """Register a callback for comm_msg
118 118
119 119 Will be called with the `data` of any comm_msg messages.
120 120
121 121 Call `on_msg(None)` to disable an existing callback.
122 122 """
123 123 self._msg_callback = callback
124 124
125 125 # handling of incoming messages
126 126
127 127 def handle_close(self, msg):
128 128 """Handle a comm_close message"""
129 129 self.log.debug("handle_close[%s](%s)", self.comm_id, msg)
130 130 if self._close_callback:
131 131 self._close_callback(msg)
132 132
133 133 def handle_msg(self, msg):
134 134 """Handle a comm_msg message"""
135 135 self.log.debug("handle_msg[%s](%s)", self.comm_id, msg)
136 136 if self._msg_callback:
137 self.shell.callbacks.fire('pre_execute')
137 self.shell.events.trigger('pre_execute')
138 138 self._msg_callback(msg)
139 self.shell.callbacks.fire('post_execute')
139 self.shell.events.trigger('post_execute')
140 140
141 141
142 142 __all__ = ['Comm']
General Comments 0
You need to be logged in to leave comments. Login now